Source code for qlty.qlty3DLarge

from typing import Optional, Tuple, Union

import einops
import numpy as np
import numpy.typing as npt
import torch
import zarr

from qlty.base import (
    compute_border_tensor_numpy,
    compute_chunk_times,
    compute_weight_matrix_numpy,
    normalize_border,
    validate_border_weight,
)


[docs] class LargeNCZYXQuilt: """ This class allows one to split larger tensors into smaller ones that perhaps do fit into memory. This class is aimed at handling tensors of type (N,C,Z,Y,X) This object is geared towards handling large datasets. """ def __init__( self, filename: str, N: int, Z: int, Y: int, X: int, window: Tuple[int, int, int], step: Tuple[int, int, int], border: Optional[Union[int, Tuple[int, int, int]]] = None, border_weight: float = 0.1, ) -> None: """ This class allows one to split larger tensors into smaller ones that perhaps do fit into memory. This class is aimed at handling tensors of type (N,C,Z,Y,X). Parameters ---------- filename: the base filename for storage. Z : number of elements in the Z direction Y : number of elements in the Y direction X : number of elements in the X direction window: The size of the sliding window, a tuple (Zsub, Ysub, Xsub) step: The step size at which we want to sample the sliding window (Zstep, Ystep,Xstep) border: Border pixels of the window we want to 'ignore' or down weight when stitching things back border_weight: The weight for the border pixels, should be between 0 and 1. The default of 0.1 should be fine """ self.filename = filename self.N = N self.Z = Z self.Y = Y self.X = X self.window = window self.step = step # Normalize and validate border self.border = normalize_border(border, ndim=3) self.border_weight = validate_border_weight(border_weight) # Compute chunk times self.nZ, self.nY, self.nX = compute_chunk_times( dimension_sizes=(Z, Y, X), window=window, step=step ) # Compute weight matrix (as torch tensor for compatibility) weight_np = compute_weight_matrix_numpy( window=window, border=self.border, border_weight=self.border_weight ) self.weight = torch.from_numpy(weight_np) self.N_chunks = self.N * self.nZ * self.nY * self.nX self.mean = None self.norma = None self.chunkerator = iter(np.arange(self.N_chunks))
[docs] def border_tensor(self) -> npt.NDArray[np.float64]: """Compute border tensor indicating valid (non-border) regions.""" return compute_border_tensor_numpy(window=self.window, border=self.border)
[docs] def get_times(self) -> Tuple[int, int, int]: """ Computes the number of chunks along Z, Y, and X dimensions, ensuring the last chunk is included by adjusting the starting points. """ return compute_chunk_times( dimension_sizes=(self.Z, self.Y, self.X), window=self.window, step=self.step )
[docs] def unstitch_and_clean_sparse_data_pair( self, tensor_in: torch.Tensor, tensor_out: torch.Tensor, missing_label: Union[int, float], ) -> Tuple[torch.Tensor, torch.Tensor]: """ Split input and output 3D tensors into patches, filtering out patches with no valid data. This method combines unstitching with sparse data filtering for 3D volumes. It: 1. Splits both tensors into patches 2. Marks border regions as missing 3. Filters out patches that contain only missing labels Parameters ---------- tensor_in : torch.Tensor Input tensor of shape (N, C, Z, Y, X). The tensor going into the network. tensor_out : torch.Tensor Output tensor of shape (N, C, Z, Y, X) or (N, Z, Y, X). The target tensor. Missing/invalid data should be marked with `missing_label`. missing_label : Union[int, float] Label value that indicates missing/invalid data. Patches containing only this value (including border regions) will be filtered out. Returns ------- Tuple[torch.Tensor, torch.Tensor] A tuple of (input_patches, output_patches) where: - input_patches: Shape (M, C, window[0], window[1], window[2]) - output_patches: Shape (M, C, window[0], window[1], window[2]) or (M, window[0], window[1], window[2]) where M <= N * nZ * nY * nX Examples -------- >>> quilt = LargeNCZYXQuilt("data", N=5, Z=64, Y=64, X=64, ... window=(32, 32, 32), step=(16, 16, 16), border=(4, 4, 4)) >>> input_data = torch.randn(5, 1, 64, 64, 64) >>> labels = torch.ones(5, 64, 64, 64) * (-1) # All missing >>> labels[:, 10:54, 10:54, 10:54] = 1.0 # Some valid data >>> inp_patches, lbl_patches = quilt.unstitch_and_clean_sparse_data_pair( ... input_data, labels, missing_label=-1 ... ) >>> print(f"Valid patches: {inp_patches.shape[0]}") """ rearranged = False if len(tensor_out.shape) == 4: tensor_out = tensor_out.unsqueeze(dim=1) rearranged = True assert len(tensor_out.shape) == 5 assert len(tensor_in.shape) == 5 assert tensor_in.shape[0] == tensor_out.shape[0] unstitched_in = [] unstitched_out = [] modsel = self.border_tensor() modsel = modsel < 0.5 for ii in range(self.N_chunks): out_chunk = self.unstitch(tensor_out, ii).clone() out_chunk[:, modsel] = missing_label NN = out_chunk.nelement() not_present = torch.sum(out_chunk == missing_label).item() if not_present != NN: unstitched_in.append(self.unstitch(tensor_in, ii)) unstitched_out.append(out_chunk) unstitched_in = einops.rearrange(unstitched_in, "N C Z Y X -> N C Z Y X") unstitched_out = einops.rearrange(unstitched_out, "N C Z Y X -> N C Z Y X") if rearranged: assert unstitched_out.shape[1] == 1 unstitched_out = unstitched_out.squeeze(dim=1) return unstitched_in, unstitched_out
[docs] def unstitch(self, tensor: torch.Tensor, index: int) -> torch.Tensor: """ Extract a single 3D patch from a tensor by index. This method is used internally by `unstitch_next()` but can also be called directly if you know the patch index. Parameters ---------- tensor : torch.Tensor Input tensor of shape (N, C, Z, Y, X) where: - N: Number of volumes - C: Number of channels - Z, Y, X: Must match self.Z, self.Y, and self.X index : int Linear index of the patch to extract. Must be in range [0, N_chunks). Returns ------- torch.Tensor Single patch of shape (C, window[0], window[1], window[2]) Examples -------- >>> quilt = LargeNCZYXQuilt("data", N=5, Z=64, Y=64, X=64, ... window=(32, 32, 32), step=(16, 16, 16)) >>> volume = torch.randn(5, 1, 64, 64, 64) >>> patch = quilt.unstitch(volume, index=0) >>> print(patch.shape) # (1, 32, 32, 32) """ N, C, Z, Y, X = tensor.shape out_shape = (N, self.nZ, self.nY, self.nX) n, zz, yy, xx = np.unravel_index(index, out_shape) # Adjust the starting point for the last chunk in each dimension start_z = min(zz * self.step[0], Z - self.window[0]) start_y = min(yy * self.step[1], Y - self.window[1]) start_x = min(xx * self.step[2], X - self.window[2]) stop_z = start_z + self.window[0] stop_y = start_y + self.window[1] stop_x = start_x + self.window[2] patch = tensor[n, :, start_z:stop_z, start_y:stop_y, start_x:stop_x] return patch
[docs] def stitch( self, patch: torch.Tensor, index_flat: int, patch_var: Optional[torch.Tensor] = None, ) -> None: C = patch.shape[1] if self.mean is None: # Initialization code remains the same... self.mean = zarr.open( self.filename + "_mean_cache.zarr", shape=(self.N, C, self.Z, self.Y, self.X), chunks=(1, C, self.window[0], self.window[1], self.window[2]), mode="w", fill_value=0, ) self.std = zarr.open( self.filename + "_std_cache.zarr", shape=(self.N, C, self.Z, self.Y, self.X), chunks=(1, C, self.window[0], self.window[1], self.window[2]), mode="w", fill_value=0, ) self.norma = zarr.open( self.filename + "_norma_cache.zarr", shape=(self.Z, self.Y, self.X), chunks=self.window, mode="w", fill_value=0, ) screen_shape = (self.N, self.nZ, self.nY, self.nX) n, zz, yy, xx = np.unravel_index(index_flat, screen_shape) # Adjust the starting point for the last chunk in each dimension start_z = min(zz * self.step[0], self.Z - self.window[0]) start_y = min(yy * self.step[1], self.Y - self.window[1]) start_x = min(xx * self.step[2], self.X - self.window[2]) stop_z = start_z + self.window[0] stop_y = start_y + self.window[1] stop_x = start_x + self.window[2] # Update the mean, std, and norma arrays self.mean[n : n + 1, :, start_z:stop_z, start_y:stop_y, start_x:stop_x] += ( patch.numpy() * self.weight.numpy() ) if patch_var is not None: self.std[n : n + 1, :, start_z:stop_z, start_y:stop_y, start_x:stop_x] += ( patch_var.numpy() * self.weight.numpy() ) if n == 0: self.norma[ start_z:stop_z, start_y:stop_y, start_x:stop_x ] += self.weight.numpy()
[docs] def unstitch_next(self, tensor: torch.Tensor) -> Tuple[int, torch.Tensor]: """ Get the next 3D patch in sequence (generator-like interface). This method maintains an internal iterator and returns the next patch each time it's called. Useful for processing large 3D datasets chunk by chunk. Parameters ---------- tensor : torch.Tensor Input tensor of shape (N, C, Z, Y, X) where N matches self.N Returns ------- Tuple[int, torch.Tensor] A tuple of (index, patch) where: - index: Linear index of the patch (0 to N_chunks-1) - patch: Patch tensor of shape (C, window[0], window[1], window[2]) Examples -------- >>> quilt = LargeNCZYXQuilt("data", N=5, Z=64, Y=64, X=64, ... window=(32, 32, 32), step=(16, 16, 16)) >>> volume = torch.randn(5, 1, 64, 64, 64) >>> for i in range(quilt.N_chunks): ... idx, patch = quilt.unstitch_next(volume) ... processed = model(patch.unsqueeze(0)) ... quilt.stitch(processed, idx) """ this_ind = next(self.chunkerator) tmp = self.unstitch(tensor, this_ind) return this_ind, tmp
[docs] def return_mean( self, std: bool = False, renormalize_channels: bool = False, eps: float = 1e-8 ) -> Union[ npt.NDArray[np.float64], Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]] ]: """ Compute and return the final stitched 3D result. After calling `stitch()` for all patches, this method computes the final averaged result. The result is normalized by the weight matrix to account for overlapping regions and border downweighting. Parameters ---------- std : bool, optional Whether to compute and return the standard deviation. Requires that `patch_var` was provided to `stitch()` calls. Default is False. renormalize_channels : bool, optional Whether to normalize the result so that values sum to 1.0 along the channel dimension. Useful for probability distributions. Default is False. eps : float, optional Small epsilon value to prevent division by zero. Default is 1e-8. Returns ------- Union[npt.NDArray, Tuple[npt.NDArray, npt.NDArray]] If std=False: Returns mean array of shape (N, C, Z, Y, X) If std=True: Returns tuple (mean, std) where both have shape (N, C, Z, Y, X) The result is a NumPy array (stored as Zarr array on disk). Notes ----- - This method uses Dask for parallel processing of the Zarr arrays - Results are saved to disk as Zarr arrays - The computation happens lazily and is only executed when needed Examples -------- >>> quilt = LargeNCZYXQuilt("data", N=5, Z=64, Y=64, X=64, ... window=(32, 32, 32), step=(16, 16, 16)) >>> # ... process all patches with quilt.stitch() ... >>> mean = quilt.return_mean() >>> mean, std = quilt.return_mean(std=True) >>> print(f"Mean shape: {mean.shape}") # (5, C, 64, 64, 64) """ import dask.array as da # Convert Zarr arrays to Dask arrays for parallel processing mean_dask = da.from_zarr(self.mean) norma_dask = da.from_zarr(self.norma) + eps norma_dask = da.expand_dims(norma_dask, axis=0) norma_dask = da.expand_dims(norma_dask, axis=0) std_dask = da.from_zarr(self.std) if std else None # Compute mean and std using Dask mean_accumulated = mean_dask / norma_dask if std: std_accumulated = da.sqrt(da.abs(std_dask / norma_dask)) # Renormalize if required if renormalize_channels: norm = da.sum(mean_accumulated, axis=1) mean_accumulated /= norm if std: std_accumulated /= norm # Define file paths for Zarr arrays mean_zarr_path = self.filename + "_mean.zarr" std_zarr_path = (self.filename + "_std.zarr") if std else None # Store the result into Zarr arrays on disk mean_zarr = mean_accumulated.compute() zarr.save(mean_zarr_path, mean_zarr) if std: std_zarr = std_accumulated.compute() zarr.save(std_zarr_path, std_zarr) return mean_zarr, std_zarr return mean_zarr
def tst(): data = np.random.uniform(0, 1, (2, 1, 100, 100, 100)) * 100 labels = np.zeros((2, 100, 100, 100)) - 1 labels[:, 0:51, 0:51, 0:51] = 1 Tdata = torch.Tensor(data) Tlabels = torch.tensor(labels) qobj = LargeNCZYXQuilt( "test", 2, 100, 100, 100, window=(50, 50, 50), step=(25, 35, 45), border=(1, 1, 1), ) d, n = qobj.unstitch_and_clean_sparse_data_pair(Tdata, Tlabels, -1) assert d.shape[0] == 16 for ii in range(qobj.N_chunks): ind, tmp = qobj.unstitch_next(Tdata) neural_network_result = tmp.unsqueeze(0) qobj.stitch(neural_network_result, ii) mean = qobj.return_mean() assert np.max(np.abs(mean - data)) < 1e-4 return True if __name__ == "__main__": tst() print("OK")