Examples ======== This page contains comprehensive examples for common use cases. Example 1: Basic 2D Image Processing -------------------------------------- Complete workflow for processing 2D images:: import torch from qlty import NCYXQuilt # Setup quilt = NCYXQuilt( Y=256, X=256, window=(64, 64), step=(32, 32), # 50% overlap border=(8, 8), border_weight=0.1 ) # Load data images = torch.randn(20, 3, 256, 256) # Split into patches patches = quilt.unstitch(images) print(f"Created {patches.shape[0]} patches from {images.shape[0]} images") # Process patches processed_patches = your_model(patches) # Stitch back together reconstructed, weights = quilt.stitch(processed_patches) assert reconstructed.shape[0] == images.shape[0] Example 2: Training with Input-Output Pairs -------------------------------------------- Training a model on unstitched patches:: from qlty import NCYXQuilt import torch quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16), border=(5, 5)) # Training data input_images = torch.randn(100, 3, 128, 128) target_labels = torch.randn(100, 128, 128) # Unstitch pairs input_patches, target_patches = quilt.unstitch_data_pair(input_images, target_labels) # Training loop model.train() optimizer = torch.optim.Adam(model.parameters()) for inp, tgt in zip(input_patches, target_patches): optimizer.zero_grad() output = model(inp.unsqueeze(0)) loss = criterion(output, tgt.unsqueeze(0)) loss.backward() optimizer.step() Example 3: Large Dataset with Disk Caching -------------------------------------------- Processing datasets too large for memory:: from qlty import LargeNCYXQuilt import torch import tempfile import os # Setup temp_dir = tempfile.mkdtemp() filename = os.path.join(temp_dir, "large_dataset") quilt = LargeNCYXQuilt( filename=filename, N=1000, # 1000 images Y=1024, X=1024, # Large images window=(256, 256), step=(128, 128), border=(20, 20), border_weight=0.1 ) # Load data (or iterate through dataset) data = torch.randn(1000, 3, 1024, 1024) # Process all chunks print(f"Processing {quilt.N_chunks} chunks...") for i in range(quilt.N_chunks): if i % 100 == 0: print(f"Progress: {i}/{quilt.N_chunks}") index, patch = quilt.unstitch_next(data) # Process patch with torch.no_grad(): processed = model(patch.unsqueeze(0)) # Accumulate quilt.stitch(processed, index) # Get final results mean_result = quilt.return_mean() mean_result, std_result = quilt.return_mean(std=True) print(f"Final shape: {mean_result.shape}") # Cleanup for suffix in ["_mean_cache.zarr", "_std_cache.zarr", "_norma_cache.zarr", "_mean.zarr", "_std.zarr"]: path = filename + suffix if os.path.exists(path): import shutil shutil.rmtree(path) Example 4: Handling Sparse/Missing Data ---------------------------------------- Filtering out patches with no valid data:: from qlty import NCYXQuilt, weed_sparse_classification_training_pairs_2D quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16), border=(5, 5)) # Data with missing labels input_data = torch.randn(50, 3, 128, 128) labels = torch.ones(50, 128, 128) * (-1) # All missing initially # Add some valid data labels[:, 30:98, 30:98] = torch.randint(0, 10, (50, 68, 68)).float() # Unstitch input_patches, label_patches = quilt.unstitch_data_pair( input_data, labels, missing_label=-1 ) print(f"Total patches: {input_patches.shape[0]}") # Filter valid patches border_tensor = quilt.border_tensor() valid_input, valid_labels, removed_mask = weed_sparse_classification_training_pairs_2D( input_patches, label_patches, missing_label=-1, border_tensor=border_tensor ) print(f"Valid patches: {valid_input.shape[0]}") print(f"Removed patches: {removed_mask.sum().item()}") Example 5: 3D Volume Processing -------------------------------- Processing 3D medical imaging or microscopy data:: from qlty import NCZYXQuilt import torch quilt = NCZYXQuilt( Z=128, Y=128, X=128, window=(64, 64, 64), step=(32, 32, 32), # 50% overlap in each dimension border=(8, 8, 8), border_weight=0.1 ) # 3D volume data volumes = torch.randn(10, 1, 128, 128, 128) # (N, C, Z, Y, X) # Process patches = quilt.unstitch(volumes) print(f"Created {patches.shape[0]} patches from {volumes.shape[0]} volumes") # Process with 3D model processed = your_3d_model(patches) # Stitch back reconstructed, weights = quilt.stitch(processed) assert reconstructed.shape == volumes.shape Example 6: Inference with Softmax Handling ------------------------------------------- Correct way to handle softmax when stitching:: from qlty import NCYXQuilt import torch.nn.functional as F quilt = NCYXQuilt(Y=256, X=256, window=(64, 64), step=(32, 32), border=(8, 8)) image = torch.randn(1, 3, 256, 256) patches = quilt.unstitch(image) # Process patches (get logits, NOT softmax) with torch.no_grad(): logits = model(patches) # Shape: (M, num_classes, 64, 64) # Stitch logits first stitched_logits, weights = quilt.stitch(logits) # THEN apply softmax probabilities = F.softmax(stitched_logits, dim=1) # This is correct! Averaging logits then softmaxing = softmax of averaged logits Example 7: Custom Border Weighting ----------------------------------- Experimenting with different border weights:: from qlty import NCYXQuilt # Test different border weights for border_weight in [0.0, 0.1, 0.5, 1.0]: quilt = NCYXQuilt( Y=128, X=128, window=(32, 32), step=(16, 16), border=(5, 5), border_weight=border_weight ) data = torch.randn(5, 3, 128, 128) patches = quilt.unstitch(data) reconstructed, weights = quilt.stitch(patches) # Evaluate reconstruction quality error = torch.mean(torch.abs(reconstructed - data)) print(f"Border weight {border_weight}: Error = {error:.6f}") Example 8: Batch Processing for Efficiency ------------------------------------------- Processing patches in batches for better GPU utilization:: from qlty import NCYXQuilt import torch quilt = NCYXQuilt(Y=512, X=512, window=(128, 128), step=(64, 64), border=(10, 10)) image = torch.randn(1, 3, 512, 512) patches = quilt.unstitch(image) # Process in batches batch_size = 32 processed_patches = [] for i in range(0, len(patches), batch_size): batch = patches[i:i+batch_size] with torch.no_grad(): output = model(batch) processed_patches.append(output) processed_patches = torch.cat(processed_patches, dim=0) result, weights = quilt.stitch(processed_patches) Example 9: Combining with DataLoaders -------------------------------------- Integrating with PyTorch DataLoaders:: from torch.utils.data import Dataset, DataLoader from qlty import NCYXQuilt class PatchedDataset(Dataset): def __init__(self, images, labels, quilt): self.quilt = quilt self.input_patches, self.label_patches = quilt.unstitch_data_pair( images, labels ) def __len__(self): return len(self.input_patches) def __getitem__(self, idx): return self.input_patches[idx], self.label_patches[idx] # Create dataset images = torch.randn(100, 3, 128, 128) labels = torch.randn(100, 128, 128) quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16), border=(5, 5)) dataset = PatchedDataset(images, labels, quilt) dataloader = DataLoader(dataset, batch_size=64, shuffle=True) # Train for batch_input, batch_labels in dataloader: # Training code... pass Example 10: Error Handling and Validation ------------------------------------------ Proper error handling:: from qlty import NCYXQuilt import torch # Valid usage try: quilt = NCYXQuilt( Y=128, X=128, window=(32, 32), step=(16, 16), border=(5, 5), border_weight=0.1 ) print("✓ Quilt created successfully") except ValueError as e: print(f"✗ Error: {e}") # Invalid border_weight try: quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16), border=(5, 5), border_weight=2.0) # Invalid! except ValueError as e: print(f"✓ Caught error: {e}") # Invalid border dimensions try: quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16), border=(1, 2, 3)) # Wrong size for 2D! except ValueError as e: print(f"✓ Caught error: {e}") Example 11: Pre-Tokenization for Patch Processing (2D) -------------------------------------------------------- **What and Why**: The ``pretokenizer_2d`` module prepares patches for tokenization by enabling sequence-based models (like transformers) to work with image patches. This is useful for: - **Self-supervised learning**: Learning representations from patch pairs with known geometric relationships - **Contrastive learning**: Using overlapping tokens as positive pairs - **Sequence models**: Converting 2D patches into token sequences with spatial awareness - **Efficient batch processing**: Processing many patch pairs in parallel with numba acceleration The key innovation is that it identifies which tokens overlap between two patches that have undergone a known rigid transformation (translation + rotation), providing the overlap information needed for training sequence-based models. **Basic Usage - Single Patch Pair**:: from qlty import extract_patch_pairs, build_sequence_pair, tokenize_patch import torch # Step 1: Extract patch pairs using qlty's existing functionality images = torch.randn(5, 3, 128, 128) patches1, patches2, deltas, rotations = extract_patch_pairs( images, window=(64, 64), num_patches=10, delta_range=(10.0, 20.0), random_seed=42 ) # Step 2: Build sequence pairs with overlap information # This tokenizes both patches and finds overlapping tokens result = build_sequence_pair( patches1[0], # First patch: (3, 64, 64) patches2[0], # Second patch: (3, 64, 64) dx=deltas[0, 0].item(), # Translation in x dy=deltas[0, 1].item(), # Translation in y rot_k90=rotations[0].item(), # Rotation (0, 1, 2, or 3 for 0°, 90°, 180°, 270°) patch_size=16, # Size of each token stride=8 # Stride for overlapping tokens (default: patch_size//2) ) # Result contains: print(f"Tokens from patch1: {result['tokens1'].shape}") # (T, D) where T=number of tokens print(f"Tokens from patch2: {result['tokens2'].shape}") # (T, D) print(f"Overlapping tokens: {result['overlap_mask1'].sum().item()} out of {result['tokens1'].shape[0]}") # Use for training: # - tokens1, tokens2: Input to your sequence model (e.g., transformer) # - coords1, coords2: Absolute coordinates for positional encoding # - overlap_mask1, overlap_mask2: Which tokens have corresponding overlaps # - overlap_indices1_to_2: Mapping from patch1 tokens to patch2 tokens # - overlap_fractions: How much each token overlaps (0.0 to 1.0) **Batch Processing - Efficient for Large Datasets**:: # Process all patch pairs at once (much faster!) batch_result = build_sequence_pair( patches1, # (50, 3, 64, 64) - batch of patches patches2, # (50, 3, 64, 64) dx=deltas[:, 0], # (50,) - x translations dy=deltas[:, 1], # (50,) - y translations rot_k90=rotations, # (50,) - rotations patch_size=16, stride=8 ) # Batch result has padded tensors for efficient processing print(f"Batch tokens1: {batch_result['tokens1'].shape}") # (50, T_max, D) print(f"Sequence lengths: {batch_result['sequence_lengths']}") # (50,) - actual lengths print(f"Overlap counts: {batch_result['overlap_pair_counts']}") # (50,) - overlaps per pair # Use sequence_lengths to mask padding in your model # Use overlap_pair_counts to understand data distribution **Tokenization Only - When You Just Need Tokens**:: # If you only need to tokenize a patch (no overlap computation) patch = torch.randn(3, 64, 64) tokens, coords = tokenize_patch(patch, patch_size=16, stride=8) print(f"Created {tokens.shape[0]} tokens") print(f"Token shape: {tokens.shape[1]}") # 3*16*16 = 768 dimensions print(f"Coordinates shape: {coords.shape}") # (T, 2) - (y, x) for each token # Use tokens as input to sequence models # Use coords for positional encoding **Real-World Use Case - Self-Supervised Learning**:: from qlty import extract_patch_pairs, build_sequence_pair import torch import torch.nn as nn # Extract patch pairs from unlabeled images images = torch.randn(100, 3, 256, 256) patches1, patches2, deltas, rotations = extract_patch_pairs( images, window=(128, 128), num_patches=20, delta_range=(20.0, 40.0) ) # Build sequence pairs batch_result = build_sequence_pair( patches1, patches2, deltas[:, 0], deltas[:, 1], rotations, patch_size=32, stride=16 ) # Train a transformer to predict overlapping tokens class PatchTransformer(nn.Module): def __init__(self, token_dim, hidden_dim): super().__init__() self.embedding = nn.Linear(token_dim, hidden_dim) self.transformer = nn.TransformerEncoder( nn.TransformerEncoderLayer(hidden_dim, nhead=8), num_layers=6 ) self.predictor = nn.Linear(hidden_dim, token_dim) def forward(self, tokens, coords, mask): # Add positional encoding from coords pos_enc = self.positional_encoding(coords) x = self.embedding(tokens) + pos_enc x = self.transformer(x) return self.predictor(x) model = PatchTransformer(token_dim=3*32*32, hidden_dim=512) # Training loop for epoch in range(10): for i in range(0, len(patches1), 32): # Process in batches batch_idx = slice(i, i+32) result = build_sequence_pair( patches1[batch_idx], patches2[batch_idx], deltas[batch_idx, 0], deltas[batch_idx, 1], rotations[batch_idx], patch_size=32, stride=16 ) # Get overlapping tokens tokens1 = result['tokens1'] # (32, T_max, D) tokens2 = result['tokens2'] # (32, T_max, D) overlap_mask = result['overlap_mask1'] # (32, T_max) overlap_indices = result['overlap_indices1_to_2'] # (32, T_max) # Predict tokens2 from tokens1 predicted = model(tokens1, result['coords1'], overlap_mask) # Loss only on overlapping tokens # (simplified - actual implementation would handle padding) loss = nn.functional.mse_loss( predicted[overlap_mask], tokens2[overlap_mask] ) # Backprop and update... **Performance Notes**: - **Batch processing is highly optimized**: Uses numba JIT compilation and parallel processing for large batches (N > 5) - **Automatic fallback**: Falls back to sequential processing for small batches or when numba is unavailable - **Memory efficient**: Batch tokenization reuses a single ``NCYXQuilt`` object - **GPU support**: All tensors maintain device placement (CPU/GPU) **When to Use**: - ✅ Training sequence models (transformers) on image patches - ✅ Self-supervised learning with geometric augmentations - ✅ Contrastive learning with patch pairs - ✅ Any task requiring token-level overlap information **When NOT to Use**: - ❌ Simple patch extraction (use ``NCYXQuilt.unstitch()`` instead) - ❌ Stitching patches back together (use ``NCYXQuilt.stitch()`` instead) - ❌ When you don't need overlap information