API Reference
This page provides detailed API documentation for all public classes and functions in qlty.
In-Memory Classes
NCYXQuilt
- class qlty.qlty2D.NCYXQuilt(Y: int, X: int, window: Tuple[int, int], step: Tuple[int, int], border: int | Tuple[int, int] | None, border_weight: float = 1.0)
Bases: object
This class splits large tensors into smaller patches that are more likely to fit into memory. It handles tensors of shape (N, C, Y, X).
- get_times() → Tuple[int, int]
Compute the number of patches along each spatial dimension.
This method calculates how many patches will be created in the Y and X dimensions, ensuring the last patch always fits within the image bounds.
Returns
- Tuple[int, int]
A tuple (nY, nX) where:
- nY: Number of patches in the Y (height) dimension
- nX: Number of patches in the X (width) dimension
The total number of patches per image is nY * nX.
Examples
>>> quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16), border=None)
>>> nY, nX = quilt.get_times()
>>> print(f"Patches per image: {nY * nX}")
>>> print(f"Total patches for 10 images: {10 * nY * nX}")
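The count returned by get_times() follows from where patches start along each axis. The sketch below shows one way to keep the last patch inside the image: regular steps, plus one extra start flush with the edge when the last regular patch falls short. patch_starts is a hypothetical helper, not part of qlty, and qlty's exact rounding may differ.

```python
from typing import List


def patch_starts(size: int, window: int, step: int) -> List[int]:
    """Hypothetical helper: start offsets of patches along one axis.

    Starts advance by `step`; if the last regular patch stops short of the
    edge, one more patch is added flush with the boundary, so every pixel is
    covered and no patch extends past the image.
    """
    if window > size:
        raise ValueError("window must not exceed the axis length")
    starts = list(range(0, size - window + 1, step))
    if starts[-1] + window < size:
        starts.append(size - window)
    return starts


# 128 x 128 image, 32 x 32 window, step 16 along both axes
nY = len(patch_starts(128, 32, 16))
nX = len(patch_starts(128, 32, 16))
print(nY, nX)  # 7 7
print(patch_starts(100, 32, 16))  # [0, 16, 32, 48, 64, 68]
```

With a size that is not a multiple of the step (100 above), the last start is pulled back to 68 so the final patch ends exactly at the border.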
- stitch(ml_tensor: Tensor, use_numba: bool = True) → Tuple[Tensor, Tensor]
Reassemble patches back into full-size tensors.
This method takes patches produced by unstitch() and stitches them back together, averaging overlapping regions using a weight matrix. Border regions are downweighted according to border_weight.
Typical workflow:
1. Unstitch the data:
   patches = quilt.unstitch(input_images)
2. Process patches with your model:
   output_patches = model(patches)
3. Stitch back together:
   reconstructed, weights = quilt.stitch(output_patches)
Parameters
- ml_tensor : torch.Tensor
Patches tensor of shape (M, C, window[0], window[1]) where:
- M must equal N * nY * nX (the number of patches)
- C: Number of channels
- window: Patch dimensions
- use_numba : bool, optional
Whether to use Numba JIT compilation for faster stitching. Default is True (recommended for performance).
Returns
- Tuple[torch.Tensor, torch.Tensor]
A tuple of (reconstructed, weights) where:
- reconstructed: Shape (N, C, Y, X), the stitched result
- weights: Shape (Y, X), normalization weights (the per-pixel sum of patch weights; with border_weight=1.0 this is the number of patches contributing to each pixel)
Notes
Important: When working with classification outputs:
Apply softmax AFTER stitching, not before
Averaging softmaxed tensors ≠ softmax of averaged tensors
Process logits, stitch them, then apply softmax to the final result
Example:
import torch.nn.functional as F
# CORRECT:
logits = model(patches)
stitched_logits, _ = quilt.stitch(logits)
probabilities = F.softmax(stitched_logits, dim=1)
# WRONG:
probs = F.softmax(model(patches), dim=1)
result, _ = quilt.stitch(probs)  # This is incorrect!
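A small numeric check of why the order matters: two overlapping patches predict logits for the same pixel, and plain averaging of the two stands in for what stitch() does in the overlap. The logit values are arbitrary illustrations.

```python
import torch
import torch.nn.functional as F

# Logits for one pixel (3 classes) as predicted by two overlapping patches
logits_a = torch.tensor([4.0, 0.0, 0.0])
logits_b = torch.tensor([0.0, 1.0, 0.0])

# Correct order: average (stitch) the logits first, then apply softmax
stitch_then_softmax = F.softmax((logits_a + logits_b) / 2, dim=0)

# Wrong order: softmax each patch, then average the probabilities
softmax_then_stitch = (F.softmax(logits_a, dim=0) + F.softmax(logits_b, dim=0)) / 2

# Both are valid distributions, but they are not the same distribution
print(torch.allclose(stitch_then_softmax, softmax_then_stitch))  # False
```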
Examples
>>> quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16), border=None)
>>> data = torch.randn(10, 3, 128, 128)
>>> patches = quilt.unstitch(data)
>>> processed = model(patches)
>>> reconstructed, weights = quilt.stitch(processed)
>>> print(reconstructed.shape)  # (10, C, 128, 128)
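To make the weighting in stitch() concrete, here is a minimal pure-PyTorch sketch of weighted overlap averaging. It assumes patches are ordered image-major, then by Y, then by X, and that border pixels carry a constant weight border_weight while interior pixels carry 1.0. The helpers starts_1d, weight_mask, unstitch_2d and stitch_2d are hypothetical and not part of qlty; the library's Numba-accelerated implementation may differ in detail.

```python
import torch


def starts_1d(size, window, step):
    # Regular starts, plus one flush with the edge if the last patch falls short
    starts = list(range(0, size - window + 1, step))
    if starts[-1] + window < size:
        starts.append(size - window)
    return starts


def weight_mask(window, border, border_weight):
    # 1.0 in the patch interior, border_weight in a band `border` pixels wide
    wy, wx = window
    by, bx = border
    w = torch.full((wy, wx), float(border_weight))
    w[by:wy - by, bx:wx - bx] = 1.0
    return w


def unstitch_2d(x, window, step):
    # (N, C, Y, X) -> (N * nY * nX, C, wy, wx), image-major, then Y, then X
    N, _, Y, X = x.shape
    ys = starts_1d(Y, window[0], step[0])
    xs = starts_1d(X, window[1], step[1])
    return torch.stack([x[n, :, y:y + window[0], s:s + window[1]]
                        for n in range(N) for y in ys for s in xs])


def stitch_2d(patches, N, Y, X, window, step, border=(0, 0), border_weight=1.0):
    # Weighted overlap-add: sum weighted patches, then divide by summed weights
    wy, wx = window
    ys = starts_1d(Y, wy, step[0])
    xs = starts_1d(X, wx, step[1])
    w = weight_mask(window, border, border_weight)
    result = torch.zeros(N, patches.shape[1], Y, X, dtype=patches.dtype)
    weights = torch.zeros(Y, X, dtype=patches.dtype)
    i = 0
    for n in range(N):
        for y in ys:
            for s in xs:
                result[n, :, y:y + wy, s:s + wx] += patches[i] * w
                if n == 0:  # the weight map is the same for every image
                    weights[y:y + wy, s:s + wx] += w
                i += 1
    return result / weights, weights


x = torch.randn(2, 3, 64, 64)
patches = unstitch_2d(x, (32, 32), (16, 16))
print(patches.shape)  # torch.Size([18, 3, 32, 32])
rec, weights = stitch_2d(patches, 2, 64, 64, (32, 32), (16, 16),
                         border=(4, 4), border_weight=0.1)
print(torch.allclose(rec, x, atol=1e-5))  # True
```

Because the same weights appear in the numerator and the denominator, unmodified patches reconstruct the input exactly; the border weights only matter once a model's predictions disagree in the overlaps. A border_weight greater than zero keeps pixels at the image edge, which are only ever covered by patch borders, from being divided by zero.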
- unstitch(tensor: Tensor) → Tensor
Split a tensor into smaller overlapping patches.
Parameters
- tensor : torch.Tensor
Input tensor of shape (N, C, Y, X) where:
- N: Number of images
- C: Number of channels
- Y: Height (must match self.Y)
- X: Width (must match self.X)
Returns
- torch.Tensor
Patches tensor of shape (M, C, window[0], window[1]) where:
- M = N * nY * nX (total number of patches)
- window[0], window[1]: Patch dimensions
Examples
>>> quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16), border=None)
>>> data = torch.randn(10, 3, 128, 128)
>>> patches = quilt.unstitch(data)
>>> print(patches.shape)  # (M, 3, 32, 32)
- unstitch_data_pair(tensor_in: Tensor, tensor_out: Tensor, missing_label: int | float | None = None) → Tuple[Tensor, Tensor]
Split input and output tensors into smaller overlapping patches.
This method is useful for training neural networks, where input and target tensors must be split into matching patches. If missing_label is given, pixels in the border region of the output patches are set to that value.
Parameters
- tensor_in : torch.Tensor
Input tensor of shape (N, C, Y, X). The tensor going into the network.
- tensor_out : torch.Tensor
Output tensor of shape (N, C, Y, X) or (N, Y, X). The target tensor. If 3D, will be automatically expanded to 4D.
- missing_label : Optional[Union[int, float]], optional
Label value that indicates missing/invalid data. If provided, pixels in the border region will be set to this value in the output patches. Default is None (no masking).
Returns
- Tuple[torch.Tensor, torch.Tensor]
A tuple of (input_patches, output_patches), where M = N * nY * nX:
- input_patches: Shape (M, C, window[0], window[1])
- output_patches: Shape (M, C, window[0], window[1]) or (M, window[0], window[1])
Examples
>>> quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16), border=(5, 5))
>>> input_data = torch.randn(10, 3, 128, 128)
>>> target_data = torch.randn(10, 128, 128)
>>> inp_patches, tgt_patches = quilt.unstitch_data_pair(input_data, target_data)
>>> print(inp_patches.shape)  # (M, 3, 32, 32)
>>> print(tgt_patches.shape)  # (M, 32, 32)
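When missing_label is passed, border pixels of the output patches are overwritten with it. A minimal sketch of that masking step applied to already-extracted label patches, assuming border is a (by, bx) pixel count; mask_label_borders is a hypothetical helper, not a qlty function.

```python
import torch


def mask_label_borders(labels: torch.Tensor, border, missing_label):
    """Hypothetical helper: set the border band of each label patch to missing_label.

    labels: (M, wy, wx) or (M, C, wy, wx); the last two dimensions are spatial.
    border: (by, bx), width of the band in pixels along each axis.
    """
    by, bx = border
    wy, wx = labels.shape[-2:]
    interior = torch.zeros(wy, wx, dtype=torch.bool)
    interior[by:wy - by, bx:wx - bx] = True
    # `interior` broadcasts over the leading patch (and channel) dimensions
    return torch.where(interior, labels, torch.full_like(labels, missing_label))


labels = torch.ones(4, 32, 32)
masked = mask_label_borders(labels, border=(5, 5), missing_label=-1)
print(int((masked[0] == 1).sum()))  # 484 interior pixels (22 x 22)
```

A loss such as torch.nn.CrossEntropyLoss(ignore_index=...) can then skip these pixels, which is one common reason to mark them rather than crop them.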
Example:
import torch
from qlty import NCYXQuilt
quilt = NCYXQuilt(
Y=128, X=128,
window=(32, 32),
step=(16, 16),
border=(5, 5),
border_weight=0.1
)
data = torch.randn(10, 3, 128, 128)
patches = quilt.unstitch(data)
reconstructed, weights = quilt.stitch(patches)
NCZYXQuilt
- class qlty.qlty3D.NCZYXQuilt(Z: int, Y: int, X: int, window: Tuple[int, int, int], step: Tuple[int, int, int], border: int | Tuple[int, int, int] | None, border_weight: float = 0.1)
Bases: object
This class splits large tensors into smaller patches that are more likely to fit into memory. It handles tensors of shape (N, C, Z, Y, X).
- get_times() → Tuple[int, int, int]
Computes the number of chunks along Z, Y, and X dimensions, ensuring the last chunk is included by adjusting the starting points.
- stitch(ml_tensor: Tensor) → Tuple[Tensor, Tensor]
Reassemble 3D patches back into full-size volumes.
This method takes patches produced by unstitch() and stitches them back together, averaging overlapping regions using a weight matrix.
Typical workflow:
1. Unstitch the data:
   patches = quilt.unstitch(volumes)
2. Process patches with your model:
   output_patches = model(patches)
3. Stitch back together:
   reconstructed, weights = quilt.stitch(output_patches)
Parameters
- ml_tensor : torch.Tensor
Patches tensor of shape (M, C, window[0], window[1], window[2]) where:
- M must equal N * nZ * nY * nX (the number of patches)
- C: Number of channels
- window: Patch dimensions in (Z, Y, X)
Returns
- Tuple[torch.Tensor, torch.Tensor]
A tuple of (reconstructed, weights) where:
- reconstructed: Shape (N, C, Z, Y, X), the stitched result
- weights: Shape (Z, Y, X), normalization weights
Notes
Important: When working with classification outputs:
Apply softmax AFTER stitching, not before
Averaging softmaxed tensors ≠ softmax of averaged tensors
Process logits, stitch them, then apply softmax to the final result
Examples
>>> quilt = NCZYXQuilt(Z=64, Y=64, X=64, window=(32, 32, 32), step=(16, 16, 16), border=None)
>>> volume = torch.randn(5, 1, 64, 64, 64)
>>> patches = quilt.unstitch(volume)
>>> processed = model(patches)
>>> reconstructed, weights = quilt.stitch(processed)
>>> print(reconstructed.shape)  # (5, C, 64, 64, 64)
- unstitch(tensor: Tensor) → Tensor
Split a 3D tensor into smaller overlapping patches.
Parameters
- tensor : torch.Tensor
Input tensor of shape (N, C, Z, Y, X) where:
- N: Number of volumes
- C: Number of channels
- Z, Y, X: Dimensions (must match self.Z, self.Y, self.X)
Returns
- torch.Tensor
Patches tensor of shape (M, C, window[0], window[1], window[2]) where:
- M = N * nZ * nY * nX (total number of patches)
- window[0], window[1], window[2]: Patch dimensions in Z, Y, X
Examples
>>> quilt = NCZYXQuilt(Z=64, Y=64, X=64, window=(32, 32, 32), step=(16, 16, 16), border=None)
>>> volume = torch.randn(5, 1, 64, 64, 64)
>>> patches = quilt.unstitch(volume)
>>> print(patches.shape)  # (M, 1, 32, 32, 32)
- unstitch_data_pair(tensor_in: Tensor, tensor_out: Tensor) → Tuple[Tensor, Tensor]
Split input and output 3D tensors into smaller overlapping patches.
This method is useful for training neural networks on 3D volumes where you need to process input-output pairs together.
Parameters
- tensor_in : torch.Tensor
Input tensor of shape (N, C, Z, Y, X). The tensor going into the network.
- tensor_out : torch.Tensor
Output tensor of shape (N, C, Z, Y, X) or (N, Z, Y, X). The target tensor. If 4D, will be automatically expanded to 5D.
Returns
- Tuple[torch.Tensor, torch.Tensor]
A tuple of (input_patches, output_patches), where M = N * nZ * nY * nX:
- input_patches: Shape (M, C, window[0], window[1], window[2])
- output_patches: Shape (M, C, window[0], window[1], window[2]) or (M, window[0], window[1], window[2])
Examples
>>> quilt = NCZYXQuilt(Z=64, Y=64, X=64, window=(32, 32, 32), step=(16, 16, 16), border=None)
>>> input_data = torch.randn(5, 1, 64, 64, 64)
>>> target_data = torch.randn(5, 64, 64, 64)
>>> inp_patches, tgt_patches = quilt.unstitch_data_pair(input_data, target_data)
>>> print(inp_patches.shape)  # (M, 1, 32, 32, 32)
>>> print(tgt_patches.shape)  # (M, 32, 32, 32)
Example:
import torch
from qlty import NCZYXQuilt
quilt = NCZYXQuilt(
Z=64, Y=64, X=64,
window=(32, 32, 32),
step=(16, 16, 16),
border=(4, 4, 4),
border_weight=0.1
)
volume = torch.randn(5, 1, 64, 64, 64)
patches = quilt.unstitch(volume)
reconstructed, weights = quilt.stitch(patches)
Disk-Cached Classes
LargeNCYXQuilt
- class qlty.qlty2DLarge.LargeNCYXQuilt(filename: str, N: int, Y: int, X: int, window: Tuple[int, int], step: Tuple[int, int], border: int | Tuple[int, int] | None, border_weight: float = 0.1)
Bases: object
This class splits large tensors into smaller patches that are more likely to fit into memory. It handles tensors of shape (N, C, Y, X).
It is designed for large datasets and stores intermediate results on disk.
- border_tensor() → ndarray[Any, dtype[float64]]
Compute border tensor indicating valid (non-border) regions.
- get_times() → Tuple[int, int]
Computes the number of chunks along Y and X dimensions, ensuring the last chunk is included by adjusting the starting points.
- return_mean(std: bool = False, normalize: bool = False, eps: float = 1e-08) → ndarray[Any, dtype[float64]] | Tuple[ndarray[Any, dtype[float64]], ndarray[Any, dtype[float64]]]
Compute and return the final stitched result.
After calling stitch() for all patches, this method computes the final averaged result. The result is normalized by the weight matrix to account for overlapping regions and border downweighting.
Parameters
- std : bool, optional
Whether to compute and return the standard deviation. Requires that patch_var was provided to stitch() calls. Default is False.
- normalize : bool, optional
Whether to normalize the result so that values sum to 1.0 along the channel dimension. Useful for probability distributions. Default is False.
- eps : float, optional
Small epsilon value to prevent division by zero. Default is 1e-8.
Returns
- Union[npt.NDArray, Tuple[npt.NDArray, npt.NDArray]]
If std=False: returns the mean array of shape (N, C, Y, X).
If std=True: returns a tuple (mean, std), both of shape (N, C, Y, X).
The result is a NumPy array (stored as Zarr array on disk).
Notes
This method uses Dask for parallel processing of the Zarr arrays
Results are saved to disk as Zarr arrays (filename + ‘_mean.zarr’ and ‘_std.zarr’)
The computation happens lazily and is only executed when needed
Examples
>>> quilt = LargeNCYXQuilt("data", N=10, Y=128, X=128,
...                        window=(32, 32), step=(16, 16), border=None)
>>> # ... process all patches with quilt.stitch() ...
>>> mean = quilt.return_mean()
>>> mean, std = quilt.return_mean(std=True)
>>> print(f"Mean shape: {mean.shape}")  # (10, C, 128, 128)
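The final step of return_mean() divides the accumulated weighted sum by the weight matrix, with eps guarding against division by zero, and normalize=True rescales the channels to sum to 1. A NumPy sketch of that step, assuming the accumulators are available as in-memory arrays named acc and norm (hypothetical names; the class itself keeps them in Zarr and uses Dask). The std path is omitted because it depends on patch_var, whose format this page does not document.

```python
import numpy as np


def finalize_mean(acc, norm, normalize=False, eps=1e-8):
    """Hypothetical sketch of the last step of a weighted stitch.

    acc:  (N, C, Y, X) accumulated sum of weight * prediction
    norm: (Y, X) accumulated sum of weights per pixel
    """
    mean = acc / (norm + eps)  # norm broadcasts over N and C
    if normalize:
        # Rescale so the channel values at each pixel sum to 1
        mean = mean / (mean.sum(axis=1, keepdims=True) + eps)
    return mean


acc = np.zeros((1, 2, 4, 4))
acc[:, 0] = 3.0
acc[:, 1] = 1.0
norm = np.full((4, 4), 2.0)

mean = finalize_mean(acc, norm)
probs = finalize_mean(acc, norm, normalize=True)
print(np.allclose(mean[0, :, 0, 0], [1.5, 0.5]))  # True
print(np.allclose(probs.sum(axis=1), 1.0))        # True
```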
- unstitch(tensor: Tensor, index: int) → Tensor
Extract a single patch from a tensor by index.
This method is used internally by unstitch_next() but can also be called directly if you know the patch index.
Parameters
- tensor : torch.Tensor
Input tensor of shape (N, C, Y, X) where:
- N: Number of images
- C: Number of channels
- Y, X: Must match self.Y and self.X
- index : int
Linear index of the patch to extract. Must be in range [0, N_chunks).
Returns
- torch.Tensor
Single patch of shape (C, window[0], window[1])
Examples
>>> quilt = LargeNCYXQuilt("data", N=10, Y=128, X=128,
...                        window=(32, 32), step=(16, 16), border=None)
>>> data = torch.randn(10, 3, 128, 128)
>>> patch = quilt.unstitch(data, index=0)
>>> print(patch.shape)  # (3, 32, 32)
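A linear patch index in [0, N_chunks) has to identify both the image and the patch position within it. The sketch below shows one such mapping, assuming image-major ordering followed by Y and then X; that ordering is an assumption for illustration, not something this page states, and decompose_index is a hypothetical helper.

```python
def decompose_index(index: int, nY: int, nX: int):
    """Hypothetical: split a linear patch index into (image, row, column)."""
    n, rem = divmod(index, nY * nX)  # which image
    iy, ix = divmod(rem, nX)         # which patch row and column in that image
    return n, iy, ix


# Hypothetical grid: 10 images with 7 x 7 patches each (490 indices in total)
print(decompose_index(0, 7, 7))    # (0, 0, 0)
print(decompose_index(50, 7, 7))   # (1, 0, 1)
print(decompose_index(489, 7, 7))  # (9, 6, 6)
```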
- unstitch_and_clean_sparse_data_pair(tensor_in: Tensor, tensor_out: Tensor, missing_label: int | float) → Tuple[Tensor | List, Tensor | List]
Split input and output tensors into patches, filtering out patches with no valid data.
This method combines unstitching with sparse-data filtering. It:
1. Splits both tensors into patches
2. Marks border regions as missing
3. Filters out patches that contain only missing labels
Parameters
- tensor_in : torch.Tensor
Input tensor of shape (N, C, Y, X). The tensor going into the network.
- tensor_out : torch.Tensor
Output tensor of shape (N, C, Y, X) or (N, Y, X). The target tensor. Missing/invalid data should be marked with missing_label.
- missing_label : Union[int, float]
Label value that indicates missing/invalid data. Patches whose valid (non-border) region contains only this value are filtered out.
Returns
- Tuple[Union[torch.Tensor, List], Union[torch.Tensor, List]]
A tuple of (input_patches, output_patches). If no valid patches are found, returns empty lists. Otherwise returns torch.Tensor objects.
input_patches: Shape (M, C, window[0], window[1]) where M <= N * nY * nX
output_patches: Shape (M, C, window[0], window[1]) or (M, window[0], window[1])
Notes
Border regions are automatically marked as missing in the output patches
Only patches with at least one non-missing label in the valid (non-border) region are kept
This is useful for training with sparse annotations where most of the image is unlabeled
Examples
>>> quilt = LargeNCYXQuilt("data", N=10, Y=128, X=128,
...                        window=(32, 32), step=(16, 16), border=(5, 5))
>>> input_data = torch.randn(10, 3, 128, 128)
>>> labels = torch.ones(10, 128, 128) * (-1)  # All missing
>>> labels[:, 20:108, 20:108] = 1.0  # Some valid data
>>> inp_patches, lbl_patches = quilt.unstitch_and_clean_sparse_data_pair(
...     input_data, labels, missing_label=-1
... )
>>> print(f"Valid patches: {len(inp_patches) if isinstance(inp_patches, list) else inp_patches.shape[0]}")
- unstitch_next(tensor: Tensor) → Tuple[int, Tensor]
Get the next patch in sequence (generator-like interface).
This method maintains an internal iterator and returns the next patch each time it’s called. Useful for processing large datasets chunk by chunk.
Parameters
- tensor : torch.Tensor
Input tensor of shape (N, C, Y, X) where N matches self.N
Returns
- Tuple[int, torch.Tensor]
A tuple of (index, patch) where:
- index: Linear index of the patch (0 to N_chunks - 1)
- patch: Patch tensor of shape (C, window[0], window[1])
Notes
The iterator resets after reaching the end. To process all patches:
for i in range(quilt.N_chunks):
    index, patch = quilt.unstitch_next(data)
    # Process patch...
Examples
>>> quilt = LargeNCYXQuilt("data", N=10, Y=128, X=128,
...                        window=(32, 32), step=(16, 16), border=None)
>>> data = torch.randn(10, 3, 128, 128)
>>> for i in range(quilt.N_chunks):
...     idx, patch = quilt.unstitch_next(data)
...     processed = model(patch.unsqueeze(0))
...     quilt.stitch(processed, idx)
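The note that the iterator resets after reaching the end describes a cursor that wraps around. A minimal sketch of that cycling behaviour, using a hypothetical PatchCursor class; the class's actual internal counter may be implemented differently.

```python
class PatchCursor:
    """Hypothetical cycling cursor over patch indices 0 .. n_chunks - 1."""

    def __init__(self, n_chunks: int):
        self.n_chunks = n_chunks
        self._next = 0

    def advance(self) -> int:
        index = self._next
        # Wrap back to 0 after the last patch, mirroring the documented reset
        self._next = (self._next + 1) % self.n_chunks
        return index


cursor = PatchCursor(3)
print([cursor.advance() for _ in range(5)])  # [0, 1, 2, 0, 1]
```

Calling the method exactly N_chunks times, as in the loop above, therefore visits every patch once and leaves the cursor back at the start.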
Example:
import torch
from qlty import LargeNCYXQuilt
import tempfile
import os
temp_dir = tempfile.mkdtemp()
filename = os.path.join(temp_dir, "dataset")
quilt = LargeNCYXQuilt(
filename=filename,
N=100,
Y=512, X=512,
window=(128, 128),
step=(64, 64),
border=(10, 10),
border_weight=0.1
)
data = torch.randn(100, 3, 512, 512)
model = torch.nn.Identity()  # placeholder: substitute your own network
for i in range(quilt.N_chunks):
    idx, patch = quilt.unstitch_next(data)
    processed = model(patch.unsqueeze(0))
    quilt.stitch(processed, idx)
result = quilt.return_mean()
LargeNCZYXQuilt
- class qlty.qlty3DLarge.LargeNCZYXQuilt(filename: str, N: int, Z: int, Y: int, X: int, window: Tuple[int, int, int], step: Tuple[int, int, int], border: int | Tuple[int, int, int] | None = None, border_weight: float = 0.1)
Bases: object
This class splits large tensors into smaller patches that are more likely to fit into memory. It handles tensors of shape (N, C, Z, Y, X).
It is designed for large datasets and stores intermediate results on disk.
- border_tensor() → ndarray[Any, dtype[float64]]
Compute border tensor indicating valid (non-border) regions.
- get_times() → Tuple[int, int, int]
Computes the number of chunks along Z, Y, and X dimensions, ensuring the last chunk is included by adjusting the starting points.
- return_mean(std: bool = False, renormalize_channels: bool = False, eps: float = 1e-08) → ndarray[Any, dtype[float64]] | Tuple[ndarray[Any, dtype[float64]], ndarray[Any, dtype[float64]]]
Compute and return the final stitched 3D result.
After calling stitch() for all patches, this method computes the final averaged result. The result is normalized by the weight matrix to account for overlapping regions and border downweighting.
Parameters
- std : bool, optional
Whether to compute and return the standard deviation. Requires that patch_var was provided to stitch() calls. Default is False.
- renormalize_channels : bool, optional
Whether to normalize the result so that values sum to 1.0 along the channel dimension. Useful for probability distributions. Default is False.
- eps : float, optional
Small epsilon value to prevent division by zero. Default is 1e-8.
Returns
- Union[npt.NDArray, Tuple[npt.NDArray, npt.NDArray]]
If std=False: returns the mean array of shape (N, C, Z, Y, X).
If std=True: returns a tuple (mean, std), both of shape (N, C, Z, Y, X).
The result is a NumPy array (stored as Zarr array on disk).
Notes
This method uses Dask for parallel processing of the Zarr arrays
Results are saved to disk as Zarr arrays
The computation happens lazily and is only executed when needed
Examples
>>> quilt = LargeNCZYXQuilt("data", N=5, Z=64, Y=64, X=64,
...                         window=(32, 32, 32), step=(16, 16, 16))
>>> # ... process all patches with quilt.stitch() ...
>>> mean = quilt.return_mean()
>>> mean, std = quilt.return_mean(std=True)
>>> print(f"Mean shape: {mean.shape}")  # (5, C, 64, 64, 64)
- unstitch(tensor: Tensor, index: int) → Tensor
Extract a single 3D patch from a tensor by index.
This method is used internally by unstitch_next() but can also be called directly if you know the patch index.
Parameters
- tensor : torch.Tensor
Input tensor of shape (N, C, Z, Y, X) where:
- N: Number of volumes
- C: Number of channels
- Z, Y, X: Must match self.Z, self.Y, and self.X
- index : int
Linear index of the patch to extract. Must be in range [0, N_chunks).
Returns
- torch.Tensor
Single patch of shape (C, window[0], window[1], window[2])
Examples
>>> quilt = LargeNCZYXQuilt("data", N=5, Z=64, Y=64, X=64,
...                         window=(32, 32, 32), step=(16, 16, 16))
>>> volume = torch.randn(5, 1, 64, 64, 64)
>>> patch = quilt.unstitch(volume, index=0)
>>> print(patch.shape)  # (1, 32, 32, 32)
- unstitch_and_clean_sparse_data_pair(tensor_in: Tensor, tensor_out: Tensor, missing_label: int | float) → Tuple[Tensor, Tensor]
Split input and output 3D tensors into patches, filtering out patches with no valid data.
This method combines unstitching with sparse-data filtering for 3D volumes. It:
1. Splits both tensors into patches
2. Marks border regions as missing
3. Filters out patches that contain only missing labels
Parameters
- tensor_in : torch.Tensor
Input tensor of shape (N, C, Z, Y, X). The tensor going into the network.
- tensor_out : torch.Tensor
Output tensor of shape (N, C, Z, Y, X) or (N, Z, Y, X). The target tensor. Missing/invalid data should be marked with missing_label.
- missing_label : Union[int, float]
Label value that indicates missing/invalid data. Patches whose valid (non-border) region contains only this value are filtered out.
Returns
- Tuple[torch.Tensor, torch.Tensor]
A tuple of (input_patches, output_patches), where M <= N * nZ * nY * nX:
- input_patches: Shape (M, C, window[0], window[1], window[2])
- output_patches: Shape (M, C, window[0], window[1], window[2]) or (M, window[0], window[1], window[2])
Examples
>>> quilt = LargeNCZYXQuilt("data", N=5, Z=64, Y=64, X=64,
...                         window=(32, 32, 32), step=(16, 16, 16), border=(4, 4, 4))
>>> input_data = torch.randn(5, 1, 64, 64, 64)
>>> labels = torch.ones(5, 64, 64, 64) * (-1)  # All missing
>>> labels[:, 10:54, 10:54, 10:54] = 1.0  # Some valid data
>>> inp_patches, lbl_patches = quilt.unstitch_and_clean_sparse_data_pair(
...     input_data, labels, missing_label=-1
... )
>>> print(f"Valid patches: {inp_patches.shape[0]}")
- unstitch_next(tensor: Tensor) → Tuple[int, Tensor]
Get the next 3D patch in sequence (generator-like interface).
This method maintains an internal iterator and returns the next patch each time it’s called. Useful for processing large 3D datasets chunk by chunk.
Parameters
- tensor : torch.Tensor
Input tensor of shape (N, C, Z, Y, X) where N matches self.N
Returns
- Tuple[int, torch.Tensor]
A tuple of (index, patch) where:
- index: Linear index of the patch (0 to N_chunks - 1)
- patch: Patch tensor of shape (C, window[0], window[1], window[2])
Examples
>>> quilt = LargeNCZYXQuilt("data", N=5, Z=64, Y=64, X=64,
...                         window=(32, 32, 32), step=(16, 16, 16))
>>> volume = torch.randn(5, 1, 64, 64, 64)
>>> for i in range(quilt.N_chunks):
...     idx, patch = quilt.unstitch_next(volume)
...     processed = model(patch.unsqueeze(0))
...     quilt.stitch(processed, idx)
Utility Functions
weed_sparse_classification_training_pairs_2D
- qlty.cleanup.weed_sparse_classification_training_pairs_2D(tensor_in: Tensor, tensor_out: Tensor, missing_label: int | float, border_tensor: Tensor) → Tuple[Tensor, Tensor, Tensor]
Filter out patches that contain no valid data after unstitching.
This function removes patches whose labels are all missing, including patches whose only non-missing labels lie in the border region. It is useful for training with sparse annotations where most of the image is unlabeled.
Parameters
- tensor_in : torch.Tensor
Input patches tensor, typically of shape (N, C, Y, X) or (N, Y, X)
- tensor_out : torch.Tensor
Output patches tensor with labels, shape (N, C, Y, X) or (N, Y, X). Missing/invalid data should be marked with missing_label.
- missing_label : Union[int, float]
Label value that indicates missing/invalid data (typically -1)
- border_tensor : torch.Tensor
Border mask tensor from NCYXQuilt.border_tensor(). Its shape should match the spatial size of the patches, i.e. (window[0], window[1]). For 3D patches, use weed_sparse_classification_training_pairs_3D instead.
Returns
- Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
A tuple of (valid_input, valid_output, removal_mask) where:
- valid_input: Filtered input patches (only patches with valid data)
- valid_output: Filtered output patches (only patches with valid data)
- removal_mask: Boolean tensor indicating which patches were removed
Notes
Only patches with at least one non-missing label in the valid (non-border) region are kept
Border regions are automatically excluded from the validity check
Useful for semi-supervised learning with sparse annotations
Examples
>>> from qlty import NCYXQuilt, weed_sparse_classification_training_pairs_2D
>>> quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16), border=(5, 5))
>>> input_patches = torch.randn(100, 3, 32, 32)
>>> label_patches = torch.ones(100, 32, 32) * (-1)  # All missing
>>> label_patches[0:50] = 1.0  # Some valid
>>> border_tensor = quilt.border_tensor()
>>> valid_in, valid_out, mask = weed_sparse_classification_training_pairs_2D(
...     input_patches, label_patches, missing_label=-1, border_tensor=border_tensor
... )
>>> print(f"Kept {valid_in.shape[0]} out of {input_patches.shape[0]} patches")
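The keep/remove rule in the Notes (keep a patch only if its non-border region holds at least one label different from missing_label) takes only a few tensor operations. A sketch for label patches of shape (M, wy, wx), assuming border_tensor is nonzero in the valid region and zero in the border, and that removal_mask is True for removed patches as the Returns entry describes; weed_sketch is a hypothetical re-implementation, not the library function.

```python
import torch


def weed_sketch(inputs, labels, missing_label, border_tensor):
    """Hypothetical version of the keep/remove rule.

    labels: (M, wy, wx); border_tensor: (wy, wx), nonzero = valid region.
    """
    valid_region = border_tensor.bool()
    has_label = (labels != missing_label) & valid_region  # broadcasts over M
    keep = has_label.flatten(1).any(dim=1)                 # (M,)
    removal_mask = ~keep
    return inputs[keep], labels[keep], removal_mask


inputs = torch.randn(4, 3, 8, 8)
labels = torch.full((4, 8, 8), -1.0)
labels[0, 4, 4] = 1.0  # label in the interior: kept
labels[1, 0, 0] = 1.0  # label only in the border: removed
border = torch.zeros(8, 8)
border[2:6, 2:6] = 1.0

valid_in, valid_out, removed = weed_sketch(inputs, labels, -1, border)
print(valid_in.shape[0], removed.tolist())  # 1 [False, True, True, True]
```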
Example:
import torch
from qlty import NCYXQuilt, weed_sparse_classification_training_pairs_2D
quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16), border=(5, 5))
input_patches = torch.randn(100, 3, 32, 32)
label_patches = torch.ones(100, 32, 32) * (-1) # Missing labels
label_patches[0:50] = 1.0 # Some valid
border_tensor = quilt.border_tensor()
valid_in, valid_out, mask = weed_sparse_classification_training_pairs_2D(
input_patches, label_patches, missing_label=-1, border_tensor=border_tensor
)
weed_sparse_classification_training_pairs_3D
- qlty.cleanup.weed_sparse_classification_training_pairs_3D(tensor_in: Tensor, tensor_out: Tensor, missing_label: int | float, border_tensor: Tensor) → Tuple[Tensor, Tensor, Tensor]
Filter out 3D patches that contain no valid data after unstitching.
This function removes 3D patches whose labels are all missing, including patches whose only non-missing labels lie in the border region. It is useful for training with sparse 3D annotations.
Parameters
- tensor_in : torch.Tensor
Input patches tensor, typically of shape (N, C, Z, Y, X) or (N, Z, Y, X)
- tensor_out : torch.Tensor
Output patches tensor with labels, shape (N, C, Z, Y, X) or (N, Z, Y, X). Missing/invalid data should be marked with missing_label.
- missing_label : Union[int, float]
Label value that indicates missing/invalid data (typically -1)
- border_tensor : torch.Tensor
Border mask tensor from NCZYXQuilt.border_tensor(). Its shape should match the spatial size of the patches, i.e. (window[0], window[1], window[2]).
Returns
- Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
A tuple of (valid_input, valid_output, removal_mask) where:
- valid_input: Filtered input patches (only patches with valid data)
- valid_output: Filtered output patches (only patches with valid data)
- removal_mask: Boolean tensor indicating which patches were removed
Examples
>>> from qlty import NCZYXQuilt, weed_sparse_classification_training_pairs_3D
>>> quilt = NCZYXQuilt(Z=64, Y=64, X=64, window=(32, 32, 32), step=(16, 16, 16), border=(4, 4, 4))
>>> input_patches = torch.randn(100, 1, 32, 32, 32)
>>> label_patches = torch.ones(100, 32, 32, 32) * (-1)  # All missing
>>> label_patches[0:50] = 1.0  # Some valid
>>> border_tensor = quilt.border_tensor()
>>> valid_in, valid_out, mask = weed_sparse_classification_training_pairs_3D(
...     input_patches, label_patches, missing_label=-1, border_tensor=border_tensor
... )
>>> print(f"Kept {valid_in.shape[0]} out of {input_patches.shape[0]} patches")
Parameter Details
Window and Step Sizes
window: Size of each patch in pixels
- 2D: (Y_size, X_size)
- 3D: (Z_size, Y_size, X_size)
step: Distance the window moves between patches
- 2D: (Y_step, X_step)
- 3D: (Z_step, Y_step, X_step)
- Common choice: step = window / 2 for 50% overlap
Border Parameters
border: Size of the border region to downweight
- Can be an int (same for all dimensions) or a tuple (per dimension)
- None or 0 means no border
- Typically 10-20% of the window size
border_weight: Weight for border pixels (0.0 to 1.0)
- 0.0: Completely exclude borders
- 0.1: Recommended default
- 1.0: Full weight (not recommended)
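The overlap implied by a step and the border size suggested above are simple arithmetic. A small sketch using a hypothetical describe_tiling helper: overlap per axis is 1 - step / window, and the border is taken as a fraction of the window (15% here, the middle of the suggested 10-20% range), rounded to whole pixels.

```python
def describe_tiling(window, step, border_fraction=0.15):
    """Hypothetical helper: overlap fraction per axis and a suggested border."""
    overlap = tuple(1 - s / w for w, s in zip(window, step))
    border = tuple(max(1, round(w * border_fraction)) for w in window)
    return overlap, border


overlap, border = describe_tiling(window=(32, 32), step=(16, 16))
print(overlap)  # (0.5, 0.5)
print(border)   # (5, 5)
```

For a 32 x 32 window with step 16 this gives 50% overlap and a 5-pixel border, the same values used in the 2D examples on this page.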
Return Types
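Return types depend on the class:
unstitch(): in-memory classes return a torch.Tensor of shape (M, C, …); Large classes return a single torch.Tensor patch of shape (C, …) for the requested index
stitch(): in-memory classes return Tuple[torch.Tensor, torch.Tensor] (result, weights); Large classes accumulate patches on disk, and the final result is obtained from return_mean() as NumPy arrays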
border_tensor(): Returns torch.Tensor (in-memory) or np.ndarray (Large)
get_times(): Returns Tuple[int, …] with number of patches per dimension