Usage
=====

qlty provides tools to unstitch and stitch PyTorch tensors efficiently.

Basic Import
------------

To use qlty in a project, import it::

    import qlty
    from qlty import NCYXQuilt, NCZYXQuilt
    from qlty import LargeNCYXQuilt, LargeNCZYXQuilt

2D In-Memory Processing
-----------------------

Basic Example
~~~~~~~~~~~~~

Let's make some mock data and process it::

    import torch
    import qlty

    # Create sample data
    x = torch.rand((10, 3, 128, 128))  # Input images: (N, C, Y, X)
    y = torch.rand((10, 1, 128, 128))  # Target labels: (N, C, Y, X)

Assume that x and y are data whose relation you are trying to learn with some network, such that after training you have::

    y_guess = net(x)

with::

    torch.sum(torch.abs(y_guess - y)) < a_small_number

If your data is too large to fit on your GPU, or if you need to chop it into smaller pieces for boundary detection, qlty can help. Let's take the above data and chop it into smaller bits::

    quilt = qlty.NCYXQuilt(
        Y=128,
        X=128,
        window=(16, 16),    # Patch size
        step=(4, 4),        # Step size (creates overlap)
        border=(4, 4),      # Border region
        border_weight=0.1,  # Weight for border pixels
    )

This object can now cut any input tensor of shape (N, C, Y, X) into smaller, overlapping patches of shape (M, C, Y_window, X_window). The moving window, here a 16x16 patch, slides across the input tensor in steps of (4, 4). In addition, we define a border region 4 pixels wide inside each patch. Pixels in this region are assigned weight ``border_weight`` (0.1 here) when the data is stitched back together.
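
To make ``window``, ``step``, ``border`` and ``border_weight`` concrete, here is a minimal pure-Python sketch of the same idea in one dimension. It is an illustration under stated assumptions, not qlty's implementation: the helper names are hypothetical, windows are assumed to start every ``step`` pixels (plus one extra window flush with the far edge if the last step falls short), and stitching is assumed to be a weighted average in which border pixels get ``border_weight`` and all other pixels get 1.0::

    def window_starts(length, window, step):
        """Start offsets of a window sliding along one axis (assumed rule).

        Assumes window <= length.
        """
        starts = list(range(0, length - window + 1, step))
        if starts[-1] + window < length:  # add one window flush with the far edge
            starts.append(length - window)
        return starts


    def pixel_weights(window, border, border_weight):
        """Weight of each pixel in a patch: border_weight near the edges, else 1.0."""
        return [border_weight if i < border or i >= window - border else 1.0
                for i in range(window)]


    def stitch_1d(patches, length, window, step, border, border_weight):
        """Weighted average of overlapping 1D patches, plus summed weights."""
        w = pixel_weights(window, border, border_weight)
        total = [0.0] * length
        weights = [0.0] * length
        for start, patch in zip(window_starts(length, window, step), patches):
            for i, value in enumerate(patch):
                total[start + i] += w[i] * value
                weights[start + i] += w[i]
        return [t / s for t, s in zip(total, weights)], weights


    # Same geometry as the 2D example: Y = X = 128, window 16, step 4.
    n_per_axis = len(window_starts(128, 16, 4))
    print(n_per_axis, 10 * n_per_axis * n_per_axis)  # 29 8410

    # Round trip on a small 1D signal: cut into patches, then stitch back.
    signal = [float(v) for v in range(12)]
    patches = [signal[s:s + 6] for s in window_starts(12, 6, 2)]
    restored, weights = stitch_1d(patches, 12, 6, 2, border=1, border_weight=0.1)

Because every patch here is an exact copy of the signal, the weighted average restores it unchanged, while ``weights`` shows that pixels near the image edges receive less total weight than interior pixels covered by several windows.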

Unstitching Data Pairs
~~~~~~~~~~~~~~~~~~~~~~

Let's unstitch the (x, y) training data pair::

    x_bits, y_bits = quilt.unstitch_data_pair(x, y)
    print("x shape:     ", x.shape)
    print("y shape:     ", y.shape)
    print("x_bits shape:", x_bits.shape)
    print("y_bits shape:", y_bits.shape)

Yielding::

    x shape:      torch.Size([10, 3, 128, 128])
    y shape:      torch.Size([10, 1, 128, 128])
    x_bits shape: torch.Size([8410, 3, 16, 16])
    y_bits shape: torch.Size([8410, 1, 16, 16])

Stitching Back Together
~~~~~~~~~~~~~~~~~~~~~~~

Now suppose a neural network has returned some output for these patches; here we mock a 17-channel output::

    y_mock = torch.rand((8410, 17, 16, 16))

We can stitch it back together into the original spatial shape, averaging overlapping areas and downweighting (or, with ``border_weight=0.0``, excluding) border areas::

    y_stitched, weights = quilt.stitch(y_mock)

which gives::

    print(y_stitched.shape)
    torch.Size([10, 17, 128, 128])

The ``weights`` tensor holds the total weight accumulated at each pixel during stitching: roughly, how many patches contributed to it, with border pixels counting only ``border_weight``.

Using Numba Acceleration
~~~~~~~~~~~~~~~~~~~~~~~~

The 2D ``stitch`` method can use Numba JIT compilation for faster processing::

    result, weights = quilt.stitch(y_mock, use_numba=True)   # Default
    result, weights = quilt.stitch(y_mock, use_numba=False)  # Pure PyTorch

3D Volume Processing
--------------------

For 3D volumes, use ``NCZYXQuilt``::

    import torch
    from qlty import NCZYXQuilt

    # Create 3D quilt object
    quilt = NCZYXQuilt(
        Z=64, Y=64, X=64,
        window=(32, 32, 32),  # 3D patch size
        step=(16, 16, 16),    # Step in Z, Y, X
        border=(4, 4, 4),     # Border in each dimension
        border_weight=0.1,
    )

    # Process 3D volume
    volume = torch.randn(5, 1, 64, 64, 64)  # (N, C, Z, Y, X)
    patches = quilt.unstitch(volume)

    # Process patches...
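
    # Expected size, assuming the same sliding-window rule that yields
    # 8410 patches in the 2D example: (64 - 32) // 16 + 1 = 3 window
    # positions per axis, so 3 * 3 * 3 = 27 patches per volume and
    # patches.shape should be (135, 1, 32, 32, 32) for the 5 volumes.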
    processed = your_model(patches)  # your_model: placeholder for your own network

    # Stitch back
    reconstructed, weights = quilt.stitch(processed)

Large Dataset Processing (Disk-Cached)
--------------------------------------

For very large datasets that don't fit in memory, use the Large classes::

    import os
    import tempfile

    import torch
    from qlty import LargeNCYXQuilt

    # Create temporary directory for the on-disk cache
    temp_dir = tempfile.mkdtemp()
    filename = os.path.join(temp_dir, "my_dataset")

    # Create Large quilt object
    quilt = LargeNCYXQuilt(
        filename=filename,
        N=100,         # Number of images
        Y=512, X=512,  # Image dimensions
        window=(128, 128),
        step=(64, 64),
        border=(10, 10),
        border_weight=0.1,
    )

    # Load your data
    data = torch.randn(100, 3, 512, 512)

    # Process all chunks
    for i in range(quilt.N_chunks):
        index, patch = quilt.unstitch_next(data)

        # Process patch (e.g., with a neural network; your_model is a placeholder)
        with torch.no_grad():
            processed = your_model(patch.unsqueeze(0))

        # Accumulate result
        quilt.stitch(processed, index)

    # Get final result
    mean_result = quilt.return_mean()
    mean_result, std_result = quilt.return_mean(std=True)

Handling Missing Data
---------------------

When working with sparse or incomplete data, you can filter out patches that contain no valid labels::

    import torch
    from qlty import NCYXQuilt, weed_sparse_classification_training_pairs_2D

    quilt = NCYXQuilt(
        Y=128, X=128,
        window=(32, 32),
        step=(16, 16),
        border=(5, 5),
        border_weight=0.1,
    )

    # Create data with missing labels
    input_data = torch.randn(10, 3, 128, 128)
    labels = torch.ones(10, 128, 128) * (-1)  # Missing label = -1
    labels[:, 20:108, 20:108] = 1.0           # Some valid data

    # Unstitch with missing-label handling
    input_patches, label_patches = quilt.unstitch_data_pair(
        input_data, labels, missing_label=-1
    )

    # Filter out patches with no valid data
    border_tensor = quilt.border_tensor()
    valid_input, valid_labels, mask = weed_sparse_classification_training_pairs_2D(
        input_patches, label_patches, missing_label=-1, border_tensor=border_tensor
    )

    print(f"Original patches: {input_patches.shape[0]}")
    print(f"Valid patches: {valid_input.shape[0]}")
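
The filtering idea can be sketched in plain Python, under the assumption that a patch is kept when at least one pixel outside the border region carries a label other than ``missing_label``. The helpers ``border_mask``, ``keep_patch`` and ``weed_patches`` are hypothetical illustrations, not part of qlty, and the exact criterion used by ``weed_sparse_classification_training_pairs_2D`` may differ::

    def border_mask(window, border):
        """1.0 for interior pixels, 0.0 in the border strip (like border_tensor)."""
        wy, wx = window
        by, bx = border
        return [
            [1.0 if by <= i < wy - by and bx <= j < wx - bx else 0.0
             for j in range(wx)]
            for i in range(wy)
        ]


    def keep_patch(label_patch, mask, missing_label=-1):
        """Hypothetical rule: keep a patch if any non-border pixel is labelled."""
        return any(
            m > 0 and value != missing_label
            for label_row, mask_row in zip(label_patch, mask)
            for value, m in zip(label_row, mask_row)
        )


    def weed_patches(input_patches, label_patches, mask, missing_label=-1):
        """Return the kept (input, label) pairs plus a keep/drop flag per patch."""
        flags = [keep_patch(lab, mask, missing_label) for lab in label_patches]
        kept_inputs = [p for p, f in zip(input_patches, flags) if f]
        kept_labels = [lab for lab, f in zip(label_patches, flags) if f]
        return kept_inputs, kept_labels, flags


    # Usage: three 4x4 label patches with a 1-pixel border.
    mask = border_mask((4, 4), (1, 1))
    empty = [[-1] * 4 for _ in range(4)]                  # no labels at all
    edge_only = [[1] * 4] + [[-1] * 4 for _ in range(3)]  # labels only in the border
    labelled = [[-1] * 4 for _ in range(4)]
    labelled[1][2] = 1                                    # one interior label
    inputs = ["a", "b", "c"]                              # stand-ins for image patches
    kept_in, kept_lbl, flags = weed_patches(inputs, [empty, edge_only, labelled], mask)
    print(flags)    # [False, False, True]
    print(kept_in)  # ['c']

Ignoring the border when deciding is a deliberate choice in this sketch: border pixels are downweighted at stitching time, so a patch whose only labels sit in its border contributes little useful supervision.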

Advanced: Working with Border Regions
-------------------------------------

The border tensor indicates which pixels of a patch lie in the border region::

    border_mask = quilt.border_tensor()
    print(border_mask.shape)  # (window_height, window_width)
    print(border_mask.sum())  # Number of valid (non-border) pixels

Border pixels are set to 0.0 and valid pixels to 1.0, so this tensor can be used to mask out border regions during training.

Computing Chunk Information
---------------------------

To find out how many patches will be created::

    nY, nX = quilt.get_times()
    print(f"Patches in Y direction: {nY}")
    print(f"Patches in X direction: {nX}")
    print(f"Total patches per image: {nY * nX}")

For a tensor with N images, the total number of patches is N * nY * nX.

Best Practices
--------------

1. **Overlap Strategy**:

   - Use step size = window / 2 for 50% overlap (a common choice)
   - More overlap gives smoother results but costs more computation
   - Less overlap is faster but may leave stitching artifacts

2. **Border Size**:

   - Typically 10-20% of the window size
   - Larger for networks sensitive to edge effects
   - Smaller for networks that handle edges well

3. **Border Weight**:

   - 0.1 is a good default
   - 0.0 excludes borders completely
   - 1.0 gives border pixels the same weight as interior pixels (not recommended)

4. **Memory Management**:

   - Use the in-memory classes (``NCYXQuilt``, ``NCZYXQuilt``) if your data fits in RAM
   - Use the Large classes for datasets that do not (typically several GB or more)
   - The Large classes use Zarr for efficient disk caching

5. **Softmax Warning**:

   - Apply softmax AFTER stitching, not before
   - Averaging softmaxed tensors ≠ softmax of averaged tensors
   - Stitch the logits, then apply softmax to the final result

Common Patterns
---------------

Training Loop Pattern
~~~~~~~~~~~~~~~~~~~~~

::

    quilt = NCYXQuilt(Y=256, X=256, window=(64, 64), step=(32, 32),
                      border=(8, 8), border_weight=0.1)

    for epoch in range(num_epochs):
        for images, labels in dataloader:
            # Unstitch the batch into training patches
            img_patches, lbl_patches = quilt.unstitch_data_pair(images, labels)

            # Train (one patch at a time here; batching patches is usually faster)
            for img, lbl in zip(img_patches, lbl_patches):
                output = model(img.unsqueeze(0))
                loss = criterion(output, lbl.unsqueeze(0))
                # ...

Inference Pattern
~~~~~~~~~~~~~~~~~

::

    quilt = NCYXQuilt(Y=512, X=512, window=(128, 128), step=(64, 64),
                      border=(10, 10), border_weight=0.1)

    # Unstitch (test_image has shape (N, C, 512, 512))
    patches = quilt.unstitch(test_image)

    # Process
    with torch.no_grad():
        outputs = model(patches)

    # Stitch
    result, weights = quilt.stitch(outputs)

Large Dataset Pattern
~~~~~~~~~~~~~~~~~~~~~

::

    quilt = LargeNCYXQuilt(filename, N=1000, Y=1024, X=1024,
                           window=(256, 256), step=(128, 128),
                           border=(20, 20), border_weight=0.1)

    # Process in chunks (data holds all 1000 images)
    for i in range(quilt.N_chunks):
        idx, patch = quilt.unstitch_next(data)
        with torch.no_grad():
            processed = model(patch.unsqueeze(0))
        quilt.stitch(processed, idx)

    # Get results
    mean = quilt.return_mean()
    mean, std = quilt.return_mean(std=True)
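
The softmax warning under Best Practices can be checked with a few numbers. The sketch below uses plain Python and two hypothetical two-class logit vectors, standing in for the predictions two overlapping patches make for the same pixel, combined with equal weights. It compares averaging after softmax with softmax after averaging::

    import math


    def softmax(logits):
        """Numerically stable softmax of a list of floats."""
        m = max(logits)
        exps = [math.exp(v - m) for v in logits]
        total = sum(exps)
        return [e / total for e in exps]


    def mean(vectors):
        """Element-wise mean of equally long lists."""
        return [sum(col) / len(vectors) for col in zip(*vectors)]


    # Hypothetical logits for one pixel, predicted by two overlapping patches.
    patch_a = [4.0, 0.0]
    patch_b = [0.0, 1.0]

    softmax_then_average = mean([softmax(patch_a), softmax(patch_b)])
    average_then_softmax = softmax(mean([patch_a, patch_b]))

    print([round(p, 3) for p in softmax_then_average])  # [0.625, 0.375]
    print([round(p, 3) for p in average_then_softmax])  # [0.818, 0.182]

The two orders give clearly different class probabilities for the same pixel, which is why the logits should be stitched first and softmax applied once to the stitched result.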