Usage
qlty provides tools to unstitch and stitch PyTorch tensors efficiently.
Basic Import
To use qlty in a project, import it:
import qlty
from qlty import NCYXQuilt, NCZYXQuilt
from qlty import LargeNCYXQuilt, LargeNCZYXQuilt
2D In-Memory Processing
Basic Example
Let’s make some mock data and process it:
import torch
import qlty
# Create sample data
x = torch.rand((10, 3, 128, 128)) # Input images: (N, C, Y, X)
y = torch.rand((10, 1, 128, 128)) # Target labels: (N, C, Y, X)
Assume that x and y are data whose relationship you want some network, net, to learn, so that after training:
y_guess = net(x)
with:
torch.sum(torch.abs(y_guess - y)) < a_small_number
If your data is too large to fit on your GPU, or you need to cut it into smaller pieces for boundary detection, qlty can help. Let's chop the data above into smaller patches:
quilt = qlty.NCYXQuilt(
Y=128,
X=128,
window=(16, 16), # Patch size
step=(4, 4), # Step size (creates overlap)
border=(4, 4), # Border region
border_weight=0.1 # Weight for border pixels
)
This object cuts any input tensor of shape (N, C, Y, X) into smaller, overlapping patches of shape (M, C, Y_window, X_window), where M is the total number of patches. The window, here 16x16 pixels, moves across each image in steps of 4 pixels along Y and X. Each patch also has a border region 4 pixels wide: when the data is stitched back together, pixels in this region are given weight border_weight (0.1 here) instead of the full weight of 1.0 used for interior pixels.
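To make the border concrete, here is a plain-Python sketch of the per-pixel weight map a single 16x16 patch would carry under the settings above. It assumes the border is a frame border pixels wide on every side with weight border_weight, and that the interior has weight 1.0; patch_weight_map is a hypothetical helper for illustration, not part of qlty.

```python
def patch_weight_map(window, border, border_weight):
    """Hypothetical helper: weight map for one patch (1.0 inside, border_weight on the frame)."""
    wy, wx = window
    by, bx = border
    weights = []
    for i in range(wy):
        row = []
        for j in range(wx):
            # A pixel is interior if it lies at least `border` pixels from every edge
            interior = by <= i < wy - by and bx <= j < wx - bx
            row.append(1.0 if interior else border_weight)
        weights.append(row)
    return weights


w = patch_weight_map(window=(16, 16), border=(4, 4), border_weight=0.1)
n_interior = sum(v == 1.0 for row in w for v in row)
print(n_interior)        # 64, i.e. an 8 x 8 interior
print(w[0][0], w[8][8])  # 0.1 1.0
```

With a 4-pixel border on a 16-pixel window, only the central 8x8 block keeps full weight; the remaining 192 pixels are downweighted.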
Unstitching Data Pairs
Let’s unstitch the (x, y) training data pair:
x_bits, y_bits = quilt.unstitch_data_pair(x, y)
print("x shape: ", x.shape)
print("y shape: ", y.shape)
print("x_bits shape:", x_bits.shape)
print("y_bits shape:", y_bits.shape)
Yielding:
x shape: torch.Size([10, 3, 128, 128])
y shape: torch.Size([10, 1, 128, 128])
x_bits shape: torch.Size([8410, 3, 16, 16])
y_bits shape: torch.Size([8410, 1, 16, 16])
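The 8410 patches follow from simple arithmetic. Assuming (size - window) is an exact multiple of step along each axis (true here, since (128 - 16) / 4 = 28), a 16-pixel window stepping by 4 fits 29 times along each axis. That gives 29 x 29 = 841 patches per image and 8410 for the 10 images. The sketch below reproduces the count with a hypothetical helper, positions_per_axis; it does not cover how qlty handles sizes that do not divide evenly.

```python
def positions_per_axis(size, window, step):
    # Number of window placements along one axis, assuming (size - window) is a multiple of step
    return (size - window) // step + 1


n_images = 10
n_y = positions_per_axis(128, 16, 4)
n_x = positions_per_axis(128, 16, 4)
total = n_images * n_y * n_x
print(n_y, n_x, total)  # 29 29 8410
```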
Stitching Back Together
Now suppose a neural network has returned the following mock output for the 8410 patches (here with 17 channels):
y_mock = torch.rand((8410, 17, 16, 16))
It can be stitched back into full-size images, with overlapping regions averaged and border pixels downweighted (or excluded when border_weight is 0):
y_stitched, weights = quilt.stitch(y_mock)
which gives:
print(y_stitched.shape)
torch.Size([10, 17, 128, 128])
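The weights tensor holds the total weight each pixel received from the overlapping patches: a patch contributes 1.0 at its interior pixels and border_weight at its border pixels. The stitched output is the weighted sum of patch values divided by this total.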
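A simplified 1D model of weighted overlap-averaging helps show what stitching computes. It assumes each patch pixel contributes its value times its weight (1.0 in the interior, border_weight on the border) and that the result is divided by the summed weights. stitch_1d is a hypothetical illustration in plain Python, not qlty's implementation.

```python
def stitch_1d(patches, starts, length, weight):
    """Weighted overlap-average of 1D patches placed at the given start offsets."""
    acc = [0.0] * length   # running sum of weight * value
    norm = [0.0] * length  # running sum of weights (the analogue of `weights`)
    for patch, start in zip(patches, starts):
        for k, value in enumerate(patch):
            acc[start + k] += weight[k] * value
            norm[start + k] += weight[k]
    result = [a / n if n > 0 else 0.0 for a, n in zip(acc, norm)]
    return result, norm


# Window 4, step 2, a 1-pixel border with border_weight 0.1, signal length 8
weight = [0.1, 1.0, 1.0, 0.1]
starts = [0, 2, 4]
patches = [[1.0] * 4, [3.0] * 4, [5.0] * 4]
result, norm = stitch_1d(patches, starts, 8, weight)
print([round(n, 2) for n in norm])    # [0.1, 1.0, 1.1, 1.1, 1.1, 1.1, 1.0, 0.1]
print([round(r, 2) for r in result])  # [1.0, 1.0, 1.18, 2.82, 3.18, 4.82, 5.0, 5.0]
```

At position 3, the interior prediction of the second patch (3.0) dominates the border prediction of the first patch (1.0), giving about 2.82. At the two ends only border pixels contribute, so the value is still recovered but rests on a small total weight.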
Using Numba Acceleration
The 2D stitch method can use Numba JIT compilation for faster processing:
result, weights = quilt.stitch(y_mock, use_numba=True)   # Default
result, weights = quilt.stitch(y_mock, use_numba=False)  # Pure PyTorch
3D Volume Processing
For 3D volumes, use NCZYXQuilt:
import torch
from qlty import NCZYXQuilt
# Create 3D quilt object
quilt = NCZYXQuilt(
Z=64, Y=64, X=64,
window=(32, 32, 32), # 3D patch size
step=(16, 16, 16), # Step in Z, Y, X
border=(4, 4, 4), # Border in each dimension
border_weight=0.1
)
# Process 3D volume
volume = torch.randn(5, 1, 64, 64, 64) # (N, C, Z, Y, X)
patches = quilt.unstitch(volume)
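# Process the patches; your_model is a placeholder for a network that
# returns outputs with the same spatial patch size (32, 32, 32)
processed = your_model(patches)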
# Stitch back
reconstructed, weights = quilt.stitch(processed)
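The same counting carries over to 3D. Assuming patch origins lie on a regular grid along Z, Y and X, a 32-voxel window stepping by 16 through a 64-voxel axis starts at 0, 16 and 32, giving 3 x 3 x 3 = 27 patches per volume. Under that assumption, patches above would be expected to hold 135 patches for the five volumes. axis_starts is a hypothetical helper used only for this illustration.

```python
from itertools import product


def axis_starts(size, window, step):
    # Window start offsets along one axis, assuming (size - window) is a multiple of step
    return list(range(0, size - window + 1, step))


starts_zyx = [axis_starts(64, 32, 16) for _ in ("Z", "Y", "X")]
corners = list(product(*starts_zyx))  # every (z0, y0, x0) patch origin
print(starts_zyx[0])     # [0, 16, 32]
print(len(corners))      # 27 patches per volume
print(5 * len(corners))  # 135 patches for the 5 volumes
```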
Large Dataset Processing (Disk-Cached)
For very large datasets that don't fit in memory, use the Large classes, which accumulate results in an on-disk cache:
import torch
import tempfile
import os
from qlty import LargeNCYXQuilt
# Create temporary directory for cache
temp_dir = tempfile.mkdtemp()
filename = os.path.join(temp_dir, "my_dataset")
# Create Large quilt object
quilt = LargeNCYXQuilt(
filename=filename,
N=100, # Number of images
Y=512, X=512, # Image dimensions
window=(128, 128),
step=(64, 64),
border=(10, 10),
border_weight=0.1
)
# Load your data (random values stand in for a real dataset here)
data = torch.randn(100, 3, 512, 512)
# Process all chunks, one patch at a time
for i in range(quilt.N_chunks):
    index, patch = quilt.unstitch_next(data)
    # Process the patch (e.g., with a neural network); your_model is a placeholder
    processed = your_model(patch.unsqueeze(0))
    # Accumulate the result
    quilt.stitch(processed, index)
# Get the final result
mean_result = quilt.return_mean()
mean_result, std_result = quilt.return_mean(std=True)
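Because the Large classes accumulate results chunk by chunk, a per-pixel mean and standard deviation can be recovered at the end from running sums. The sketch below shows that idea in plain Python with a hypothetical RunningPixelStats class. It is unweighted for simplicity, and qlty's own on-disk bookkeeping may differ, for example by applying border weights.

```python
import math


class RunningPixelStats:
    """Hypothetical per-pixel accumulator: sum, sum of squares and count."""

    def __init__(self, length):
        self.total = [0.0] * length
        self.total_sq = [0.0] * length
        self.count = [0] * length

    def add(self, start, values):
        # Fold one processed chunk into the running totals
        for k, v in enumerate(values):
            self.total[start + k] += v
            self.total_sq[start + k] += v * v
            self.count[start + k] += 1

    def mean_std(self):
        means, stds = [], []
        for s, sq, n in zip(self.total, self.total_sq, self.count):
            m = s / n if n else 0.0
            var = sq / n - m * m if n else 0.0
            means.append(m)
            stds.append(math.sqrt(max(var, 0.0)))  # clamp tiny negative round-off
        return means, stds


stats = RunningPixelStats(length=4)
stats.add(0, [1.0, 2.0, 3.0])  # chunk covering pixels 0-2
stats.add(1, [4.0, 6.0, 5.0])  # overlapping chunk covering pixels 1-3
mean, std = stats.mean_std()
print(mean)  # [1.0, 3.0, 4.5, 5.0]
print(std)   # [0.0, 1.0, 1.5, 0.0]
```

Only three numbers per pixel need to be stored, which is what makes chunked, disk-backed accumulation practical. The sum-of-squares formula can lose precision when the variance is small relative to the mean; Welford's algorithm is a more robust alternative.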
Handling Missing Data
When working with sparse or incomplete data, you can filter out patches with no valid data:
import torch
from qlty import NCYXQuilt, weed_sparse_classification_training_pairs_2D
quilt = NCYXQuilt(
Y=128, X=128,
window=(32, 32),
step=(16, 16),
border=(5, 5),
border_weight=0.1
)
# Create data with missing labels
input_data = torch.randn(10, 3, 128, 128)
labels = torch.ones(10, 128, 128) * (-1) # Missing label = -1
labels[:, 20:108, 20:108] = 1.0 # Some valid data
# Unstitch with missing label handling
input_patches, label_patches = quilt.unstitch_data_pair(
input_data, labels, missing_label=-1
)
# Filter out patches with no valid data
border_tensor = quilt.border_tensor()
valid_input, valid_labels, mask = weed_sparse_classification_training_pairs_2D(
input_patches, label_patches, missing_label=-1, border_tensor=border_tensor
)
print(f"Original patches: {input_patches.shape[0]}")
print(f"Valid patches: {valid_input.shape[0]}")
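The filtering step can be pictured with the plain-Python sketch below. It assumes one plausible rule: a patch is dropped when every non-border label equals missing_label. The helpers has_valid_label and weed_patches are hypothetical, and weed_sparse_classification_training_pairs_2D may apply a different criterion or return a differently shaped mask.

```python
def has_valid_label(label_patch, border_mask, missing_label=-1):
    """True if any non-border pixel holds a label other than missing_label."""
    for label_row, mask_row in zip(label_patch, border_mask):
        for label, inside in zip(label_row, mask_row):
            if inside and label != missing_label:
                return True
    return False


def weed_patches(input_patches, label_patches, border_mask, missing_label=-1):
    """Hypothetical filter: keep only pairs whose labels are not entirely missing."""
    keep = [has_valid_label(lbl, border_mask, missing_label) for lbl in label_patches]
    kept_inputs = [x for x, k in zip(input_patches, keep) if k]
    kept_labels = [y for y, k in zip(label_patches, keep) if k]
    return kept_inputs, kept_labels, keep


# 4 x 4 patches with a 1-pixel border (mask 0 on the border, 1 inside)
mask = [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]
all_missing = [[-1] * 4 for _ in range(4)]
border_only = [[1, -1, -1, -1]] + [[-1] * 4 for _ in range(3)]
labelled = [[-1] * 4, [-1, 1, -1, -1], [-1] * 4, [-1] * 4]
kept_in, kept_lbl, keep = weed_patches(["a", "b", "c"], [all_missing, border_only, labelled], mask)
print(keep)     # [False, False, True]
print(kept_in)  # ['c']
```

Note that the patch whose only valid label sits on the border is discarded too, since border pixels are downweighted or ignored when training.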
Advanced: Working with Border Regions
The border tensor indicates which pixels are in the border region:
border_mask = quilt.border_tensor()
print(border_mask.shape) # (window_height, window_width)
print(border_mask.sum()) # Number of valid (non-border) pixels
Border regions are set to 0.0, valid regions to 1.0. This can be used to mask out border regions during training.
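As an example of masking during training, the sketch below computes an L1 loss that ignores border pixels. The mask is built by hand, assuming it matches what border_tensor() returns for the quilt above (window (32, 32), border (5, 5)): a 5-pixel frame of 0.0 around a 22 x 22 block of 1.0. masked_l1 is a hypothetical helper written in plain Python for clarity; in practice the same masking would be done with tensor operations.

```python
def masked_l1(prediction, target, border_mask):
    """Mean absolute error over pixels where border_mask is 1.0; border pixels are ignored."""
    total, count = 0.0, 0.0
    for p_row, t_row, m_row in zip(prediction, target, border_mask):
        for p, t, m in zip(p_row, t_row, m_row):
            total += m * abs(p - t)
            count += m
    return total / count


# Assumed layout for window (32, 32) and border (5, 5): 0.0 on a 5-pixel frame, 1.0 inside
window, border = 32, 5
mask = [[1.0 if border <= i < window - border and border <= j < window - border else 0.0
         for j in range(window)] for i in range(window)]
print(sum(map(sum, mask)))  # 484.0, i.e. 22 x 22 valid pixels

prediction = [[0.0] * window for _ in range(window)]
target = [[1.0 if m else 10.0 for m in row] for row in mask]  # large errors only on the border
print(masked_l1(prediction, target, mask))  # 1.0
```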
Computing Chunk Information
To know how many patches will be created:
nY, nX = quilt.get_times()
print(f"Patches in Y direction: {nY}")
print(f"Patches in X direction: {nX}")
print(f"Total patches per image: {nY * nX}")
For a tensor with N images, the total number of patches will be N * nY * nX.
Best Practices
Overlap Strategy:
- Use step size = window/2 for 50% overlap (a common choice)
- More overlap gives smoother results but costs more computation
- Less overlap is faster but may leave stitching artifacts
Border Size:
- Typically 10-20% of the window size
- Larger for networks sensitive to edge effects
- Smaller for networks that handle edges well
Border Weight:
- 0.1 is a good default
- 0.0 excludes border pixels completely
- 1.0 weights border pixels the same as interior pixels (not recommended)
Memory Management:
- Use the in-memory classes (NCYXQuilt, NCZYXQuilt) if the data fits in RAM
- Use the Large classes for datasets larger than several GB
- The Large classes use Zarr for efficient disk caching
Softmax Warning:
- Apply softmax AFTER stitching, not before
- Averaging softmaxed tensors ≠ softmax of averaged tensors
- Stitch the logits, then apply softmax to the final result
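A two-class example shows why the order matters. Suppose two overlapping patches predict different logits for the same pixel; the sketch compares averaging their softmax outputs with applying softmax to their averaged logits. The plain-Python helpers softmax and average are written only for this illustration.

```python
import math


def softmax(logits):
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def average(vectors):
    # Element-wise mean of equal-length vectors
    return [sum(column) / len(vectors) for column in zip(*vectors)]


# Two overlapping patches predict different logits for the same pixel (two classes)
a = [4.0, 0.0]
b = [0.0, 2.0]
mean_of_softmax = average([softmax(a), softmax(b)])
softmax_of_mean = softmax(average([a, b]))
print([round(p, 2) for p in mean_of_softmax])  # [0.55, 0.45]
print([round(p, 2) for p in softmax_of_mean])  # [0.73, 0.27]
```

Both results are valid probability vectors, but they disagree noticeably, so the order of softmax and stitching changes the final prediction.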
Common Patterns
Training Loop Pattern
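# model, criterion, optimizer, dataloader and num_epochs are assumed to be defined elsewhere
quilt = NCYXQuilt(Y=256, X=256, window=(64, 64), step=(32, 32), border=(8, 8))
for epoch in range(num_epochs):
    for images, labels in dataloader:
        # Unstitch the full-size images and labels into matching patches
        img_patches, lbl_patches = quilt.unstitch_data_pair(images, labels)
        # Train on the patches one at a time
        for img, lbl in zip(img_patches, lbl_patches):
            output = model(img.unsqueeze(0))
            loss = criterion(output, lbl.unsqueeze(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()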
Inference Pattern
quilt = NCYXQuilt(Y=512, X=512, window=(128, 128), step=(64, 64), border=(10, 10))
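# Unstitch; test_image is a (N, C, 512, 512) tensor
patches = quilt.unstitch(test_image)
# Process all patches without tracking gradients
with torch.no_grad():
    outputs = model(patches)
# Stitch the predictions back into full-size images
result, weights = quilt.stitch(outputs)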
Large Dataset Pattern
# filename, data and model are assumed to be defined as in the examples above
quilt = LargeNCYXQuilt(filename, N=1000, Y=1024, X=1024,
                       window=(256, 256), step=(128, 128), border=(20, 20))
# Process in chunks
for i in range(quilt.N_chunks):
    idx, patch = quilt.unstitch_next(data)
    processed = model(patch.unsqueeze(0))
    quilt.stitch(processed, idx)
# Get results
mean = quilt.return_mean()
mean, std = quilt.return_mean(std=True)