from typing import Optional, Tuple, Union
import einops
import torch
from qlty.base import (
compute_border_tensor_torch,
compute_chunk_times,
compute_weight_matrix_torch,
normalize_border,
validate_border_weight,
)
class NCZYXQuilt:
    """
    Split large (N, C, Z, Y, X) tensors into overlapping 3D patches and
    stitch processed patches back into full-size volumes.

    Patches are laid out on a regular grid defined by ``window`` and ``step``.
    The start of a patch is never allowed to exceed ``size - window`` along
    any axis, so a trailing patch is shifted back inside the volume instead of
    running past its edge.
    """

    def __init__(
        self,
        Z: int,
        Y: int,
        X: int,
        window: Tuple[int, int, int],
        step: Tuple[int, int, int],
        border: Optional[Union[int, Tuple[int, int, int]]],
        border_weight: float = 0.1,
    ) -> None:
        """
        Configure the patch grid and the stitching weights.

        Parameters
        ----------
        Z : number of elements in the Z direction
        Y : number of elements in the Y direction
        X : number of elements in the X direction
        window : The size of the sliding window, a tuple (Zsub, Ysub, Xsub)
        step : The step size at which the sliding window is sampled,
            a tuple (Zstep, Ystep, Xstep)
        border : Border pixels of the window to 'ignore' or down-weight when
            stitching patches back together
        border_weight : The weight for the border pixels, should be between
            0 and 1. The default of 0.1 should be fine
        """
        self.Z = Z
        self.Y = Y
        self.X = X
        self.window = window
        self.step = step

        # Normalisation/validation rules live in qlty.base and are not
        # visible here; the returned values are stored as-is.
        self.border = normalize_border(border, ndim=3)
        self.border_weight = validate_border_weight(border_weight)

        # Number of patches along each axis.
        self.nZ, self.nY, self.nX = compute_chunk_times(
            dimension_sizes=(Z, Y, X), window=window, step=step
        )

        # Per-voxel weight applied to every patch while stitching.
        self.weight = compute_weight_matrix_torch(
            window=window, border=self.border, border_weight=self.border_weight
        )

    def border_tensor(self) -> torch.Tensor:
        """Compute border tensor indicating valid (non-border) regions."""
        return compute_border_tensor_torch(window=self.window, border=self.border)

    def get_times(self) -> Tuple[int, int, int]:
        """
        Return the number of chunks along the Z, Y and X dimensions.

        The counts are recomputed with ``compute_chunk_times`` from the stored
        volume size, window and step.
        """
        return compute_chunk_times(
            dimension_sizes=(self.Z, self.Y, self.X), window=self.window, step=self.step
        )

    def _patch_origins(self) -> Tuple[Tuple[int, int, int], ...]:
        """
        Return the (z, y, x) start index of every patch of one volume.

        The order is Z-major, then Y, then X. ``unstitch`` and ``stitch`` both
        use this single source of truth so that they always agree on where a
        patch lives.
        """

        def axis_starts(count: int, size: int, win: int, stride: int) -> Tuple[int, ...]:
            # Clamp so the last patch ends at the volume edge instead of
            # running past it; max(..., 0) avoids a negative (wrapping) start
            # when the window is larger than the volume.
            last_start = max(size - win, 0)
            return tuple(min(i * stride, last_start) for i in range(count))

        z_starts = axis_starts(self.nZ, self.Z, self.window[0], self.step[0])
        y_starts = axis_starts(self.nY, self.Y, self.window[1], self.step[1])
        x_starts = axis_starts(self.nX, self.X, self.window[2], self.step[2])
        return tuple((z, y, x) for z in z_starts for y in y_starts for x in x_starts)

    def unstitch_data_pair(
        self, tensor_in: torch.Tensor, tensor_out: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Split input and output 3D tensors into smaller overlapping patches.

        This method is useful for training neural networks on 3D volumes where
        input-output pairs need to be processed together.

        Parameters
        ----------
        tensor_in : torch.Tensor
            Input tensor of shape (N, C, Z, Y, X). The tensor going into the network.
        tensor_out : torch.Tensor
            Output tensor of shape (N, C, Z, Y, X) or (N, Z, Y, X). The target tensor.
            If 4D, a channel axis is added for patch extraction and removed
            again before returning.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            A tuple of (input_patches, output_patches) where:
            - input_patches: Shape (M, C, window[0], window[1], window[2])
            - output_patches: Shape (M, C, window[0], window[1], window[2])
              or (M, window[0], window[1], window[2]) for a 4D target
            where M = N * nZ * nY * nX

        Examples
        --------
        >>> quilt = NCZYXQuilt(Z=64, Y=64, X=64, window=(32, 32, 32),
        ...                    step=(16, 16, 16), border=(4, 4, 4))
        >>> input_data = torch.randn(5, 1, 64, 64, 64)
        >>> target_data = torch.randn(5, 64, 64, 64)
        >>> inp_patches, tgt_patches = quilt.unstitch_data_pair(input_data, target_data)
        >>> print(inp_patches.shape)  # (M, 1, 32, 32, 32)
        >>> print(tgt_patches.shape)  # (M, 32, 32, 32)
        """
        added_channel_axis = len(tensor_out.shape) == 4
        if added_channel_axis:
            # (N, Z, Y, X) -> (N, 1, Z, Y, X)
            tensor_out = tensor_out.unsqueeze(1)

        assert len(tensor_out.shape) == 5, "tensor_out must be 4D or 5D"
        assert len(tensor_in.shape) == 5, "tensor_in must be 5D"
        assert tensor_in.shape[0] == tensor_out.shape[0], "batch sizes differ"

        unstitched_in = self.unstitch(tensor_in)
        unstitched_out = self.unstitch(tensor_out)

        if added_channel_axis:
            assert unstitched_out.shape[1] == 1
            unstitched_out = unstitched_out.squeeze(dim=1)
        return unstitched_in, unstitched_out

    def unstitch(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Split a 3D tensor into smaller overlapping patches.

        Parameters
        ----------
        tensor : torch.Tensor
            Input tensor of shape (N, C, Z, Y, X) where:
            - N: Number of volumes
            - C: Number of channels
            - Z, Y, X: Dimensions (expected to match self.Z, self.Y, self.X;
              this is not checked)

        Returns
        -------
        torch.Tensor
            Patches tensor of shape (M, C, window[0], window[1], window[2]) where:
            - M = N * nZ * nY * nX (total number of patches)
            - window[0], window[1], window[2]: Patch dimensions in Z, Y, X
            Patches are ordered volume-major, then Z, Y, X.

        Examples
        --------
        >>> quilt = NCZYXQuilt(Z=64, Y=64, X=64, window=(32, 32, 32),
        ...                    step=(16, 16, 16), border=(4, 4, 4))
        >>> volume = torch.randn(5, 1, 64, 64, 64)
        >>> patches = quilt.unstitch(volume)
        >>> print(patches.shape)  # (M, 1, 32, 32, 32)
        """
        # Unpacking doubles as a rank check: a non-5D tensor raises here.
        n_volumes, channels, _, _, _ = tensor.shape
        win_z, win_y, win_x = self.window
        origins = self._patch_origins()

        patches = [
            tensor[n, :, z : z + win_z, y : y + win_y, x : x + win_x]
            for n in range(n_volumes)
            for (z, y, x) in origins
        ]
        if not patches:
            # Nothing to stack (no volumes or an empty patch grid).
            return tensor.new_empty((0, channels, win_z, win_y, win_x))
        return torch.stack(patches)

    def stitch(self, ml_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Reassemble 3D patches back into full-size volumes.

        This method takes patches produced by `unstitch()` and stitches them
        back together, averaging overlapping regions using the weight matrix.

        Typical workflow:

        1. Unstitch the data::

            patches = quilt.unstitch(volumes)

        2. Process patches with your model::

            output_patches = model(patches)

        3. Stitch back together::

            reconstructed, weights = quilt.stitch(output_patches)

        Parameters
        ----------
        ml_tensor : torch.Tensor
            Patches tensor of shape (M, C, window[0], window[1], window[2]) where:
            - M must be a multiple of nZ * nY * nX (patches per volume)
            - C: Number of channels
            - window: Patch dimensions in (Z, Y, X)

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            A tuple of (reconstructed, weights) where:
            - reconstructed: Shape (N, C, Z, Y, X) - the stitched result
            - weights: Shape (Z, Y, X) - normalization weights
            Both are allocated with torch's default dtype on the device of
            ``ml_tensor``.

        Notes
        -----
        **Important**: When working with classification outputs:

        - Apply softmax AFTER stitching, not before
        - Averaging softmaxed tensors ≠ softmax of averaged tensors
        - Process logits, stitch them, then apply softmax to the final result

        Voxels whose accumulated weight is zero are divided by zero and
        therefore come out as NaN/inf.

        Examples
        --------
        >>> quilt = NCZYXQuilt(Z=64, Y=64, X=64, window=(32, 32, 32),
        ...                    step=(16, 16, 16), border=(4, 4, 4))
        >>> volume = torch.randn(5, 1, 64, 64, 64)
        >>> patches = quilt.unstitch(volume)
        >>> processed = model(patches)
        >>> reconstructed, weights = quilt.stitch(processed)
        >>> print(reconstructed.shape)  # (5, C, 64, 64, 64)
        """
        # Unpacking doubles as a rank check: a non-5D tensor raises here.
        n_patches, channels, _, _, _ = ml_tensor.shape

        # Work out how many full volumes the patches stitch back into.
        patches_per_volume = self.nZ * self.nY * self.nX
        n_volumes = n_patches // patches_per_volume
        assert n_patches % patches_per_volume == 0, (
            "number of patches must be a multiple of nZ * nY * nX"
        )

        origins = self._patch_origins()
        win_z, win_y, win_x = self.window

        # Keep accumulators and weights on the same device as the patches.
        # NOTE(review): assumes self.weight is a torch.Tensor - confirm
        # against compute_weight_matrix_torch.
        device = ml_tensor.device
        weight = self.weight.to(device)
        result = torch.zeros((n_volumes, channels, self.Z, self.Y, self.X), device=device)
        norma = torch.zeros((self.Z, self.Y, self.X), device=device)

        # The normalisation map is identical for every volume: build it once.
        for z, y, x in origins:
            norma[z : z + win_z, y : y + win_y, x : x + win_x] += weight

        for volume_index in range(n_volumes):
            first_patch = volume_index * patches_per_volume
            for offset, (z, y, x) in enumerate(origins):
                patch = ml_tensor[first_patch + offset, ...]
                result[
                    volume_index,
                    :,
                    z : z + win_z,
                    y : y + win_y,
                    x : x + win_x,
                ] += patch * weight

        result = result / norma
        return result, norma