Source code for qlty.pretokenizer_2d.sequences

"""Pre-tokenization utilities for patch processing.

This module provides functions to prepare patches for tokenization by splitting them
into subpatches (tokens) and computing overlap information between patch pairs.
The actual tokenization (conversion to embeddings) is done by downstream models.
"""

from typing import Dict, Tuple, Union

import torch

from qlty.qlty2D import NCYXQuilt

# Try to import numba for JIT compilation
try:
    import numpy as np
    from numba import njit, prange

    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False


[docs] def tokenize_patch( patch: torch.Tensor, patch_size: int, stride: int = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ Pre-tokenize a patch by splitting it into a sequence of subpatches with absolute coordinates. This is a pre-tokenization step that splits the patch into square subpatches (potentially overlapping) using a sliding window approach. The subpatches are returned as a sequence with their absolute coordinates within the patch. These subpatches can then be tokenized (converted to embeddings) by downstream models. Subpatches are extracted such that they never extend beyond patch boundaries. This function uses qlty's NCYXQuilt framework for patch extraction, ensuring consistency with the rest of the qlty codebase. Parameters ---------- patch : torch.Tensor Input patch of shape (C, H, W) where: - C: Number of channels - H: Height of patch - W: Width of patch patch_size : int Size of each token in pixels stride : int, optional Stride for sliding window extraction. Defaults to patch_size // 2. Must be positive and <= patch_size. Returns ------- Tuple[torch.Tensor, torch.Tensor] A tuple containing: - tokens: Tensor of shape (T, C * patch_size * patch_size) where T is the number of tokens. Tokens are in row-major order. - coords: Tensor of shape (T, 2) containing absolute pixel coordinates (y, x) of the top-left corner of each token within the patch. Examples -------- >>> patch = torch.randn(3, 16, 16) # 3 channels, 16x16 patch >>> tokens, coords = tokenize_patch(patch, patch_size=4) >>> print(tokens.shape) # (25, 48) - 25 tokens with stride=2, each 3*4*4=48 dims >>> print(coords.shape) # (25, 2) - coordinates for each token >>> # coords[0] = [0, 0] for top-left token >>> # coords[1] = [0, 2] for next token to the right (with stride=2) """ if len(patch.shape) != 3: raise ValueError(f"patch must be 3D (C, H, W), got shape {patch.shape}") C, H, W = patch.shape if patch_size <= 0: raise ValueError(f"patch_size must be positive, got {patch_size}") if stride is None: stride = patch_size // 2 if stride <= 0: raise ValueError(f"stride must be positive, got {stride}") if stride > patch_size: raise ValueError(f"stride ({stride}) must be <= patch_size ({patch_size})") if patch_size > H or patch_size > W: raise ValueError( f"patch_size ({patch_size}) must be <= patch dimensions ({H}, {W})" ) # Use qlty's NCYXQuilt framework for patch extraction quilt = NCYXQuilt( Y=H, X=W, window=(patch_size, patch_size), step=(stride, stride), border=None, # No border weighting needed for tokenization ) # Add batch dimension: (C, H, W) -> (1, C, H, W) patch_batch = patch.unsqueeze(0) # Extract patches using qlty's unstitch: (1, C, H, W) -> (T, C, patch_size, patch_size) patches = quilt.unstitch(patch_batch) # Flatten patches: (T, C, patch_size, patch_size) -> (T, C * patch_size * patch_size) T = patches.shape[0] tokens = patches.view(T, C * patch_size * patch_size) # Compute coordinates using the same logic as NCYXQuilt.unstitch() # This ensures consistency with how patches are extracted coords_list = [] nY, nX = quilt.get_times() for yy in range(nY): for xx in range(nX): start_y = min(yy * stride, H - patch_size) start_x = min(xx * stride, W - patch_size) coords_list.append([start_y, start_x]) coords = torch.tensor(coords_list, dtype=torch.long, device=patch.device) return tokens, coords
def _find_overlapping_tokens( coords1: torch.Tensor, coords2: torch.Tensor, dx: float, dy: float, rot_k90: int, patch_size: int, patch_shape: Tuple[int, int], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Find overlapping tokens between two patches given their geometric relationship. Parameters ---------- coords1 : torch.Tensor Token coordinates from patch1, shape (T1, 2) with (y, x) pixel coordinates coords2 : torch.Tensor Token coordinates from patch2, shape (T2, 2) with (y, x) pixel coordinates dx : float Translation in pixels along x-axis dy : float Translation in pixels along y-axis rot_k90 : int Rotation applied to patch2 in 90-degree increments (0, 1, 2, or 3) patch_size : int Size of each token in pixels patch_shape : Tuple[int, int] Shape of the patch (H, W) in pixels Returns ------- Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] A tuple containing: - overlap_mask1: Boolean tensor of shape (T1,) indicating which tokens in patch1 have overlaps in patch2 - overlap_indices1_to_2: Long tensor of shape (T1,) where overlap_indices1_to_2[i] is the index into coords2/tokens2 for the overlapping token, or -1 if no overlap exists - overlap_mask2: Boolean tensor of shape (T2,) indicating which tokens in patch2 have overlaps in patch1 - overlap_indices2_to_1: Long tensor of shape (T2,) where overlap_indices2_to_1[j] is the index into coords1/tokens1 for the overlapping token, or -1 if no overlap exists - overlap_fractions: Float tensor of shape (T1,) containing the fraction of overlap for each token in patch1 (0.0 to 1.0), or 0.0 if no overlap """ T1 = coords1.shape[0] T2 = coords2.shape[0] H, W = patch_shape rot_k90 = rot_k90 % 4 overlap_mask1 = torch.zeros(T1, dtype=torch.bool, device=coords1.device) overlap_indices1_to_2 = torch.full( (T1,), -1, dtype=torch.long, device=coords1.device ) overlap_mask2 = torch.zeros(T2, dtype=torch.bool, device=coords2.device) overlap_indices2_to_1 = torch.full( (T2,), -1, dtype=torch.long, device=coords2.device ) overlap_fractions = torch.zeros(T1, dtype=torch.float32, device=coords1.device) token_area = float(patch_size * patch_size) # For each token in patch1, find if it overlaps with any token in patch2 for i in range(T1): y1 = coords1[i, 0].item() x1 = coords1[i, 1].item() # Token region in patch1: [y1, y1+patch_size) x [x1, x1+patch_size) # Transform the four corners of this token to patch2's coordinate system corners1 = [ (x1, y1), # top-left (x1 + patch_size, y1), # top-right (x1, y1 + patch_size), # bottom-left (x1 + patch_size, y1 + patch_size), # bottom-right ] # Transform corners to patch2's coordinate system corners2 = [] for x, y in corners1: # Apply inverse translation x_unrot = x - dx y_unrot = y - dy # Apply inverse rotation if rot_k90 == 0: x2 = x_unrot y2 = y_unrot elif rot_k90 == 1: # 90° clockwise rotation: (x, y) -> (y, W - x) # Inverse: (x, y) -> (W - y, x) x2 = W - y_unrot y2 = x_unrot elif rot_k90 == 2: # 180° rotation: (x, y) -> (W - x, H - y) x2 = W - x_unrot y2 = H - y_unrot elif rot_k90 == 3: # 270° clockwise: (x, y) -> (H - y, x) # Inverse: (x, y) -> (y, H - x) x2 = y_unrot y2 = H - x_unrot else: raise ValueError(f"Invalid rotation: {rot_k90}") corners2.append((x2, y2)) # Find bounding box of transformed token in patch2 x2_min = min(x for x, y in corners2) x2_max = max(x for x, y in corners2) y2_min = min(y for x, y in corners2) y2_max = max(y for x, y in corners2) # Check which tokens in patch2 overlap with this bounding box best_overlap = 0.0 best_j = -1 for j in range(coords2.shape[0]): y2 = coords2[j, 0].item() x2 = coords2[j, 1].item() # Token region in patch2: [y2, y2+patch_size) x [x2, x2+patch_size) # Compute intersection with transformed token from patch1 # Intersection in patch2 coordinates intersect_x_min = max(x2_min, x2) intersect_x_max = min(x2_max, x2 + patch_size) intersect_y_min = max(y2_min, y2) intersect_y_max = min(y2_max, y2 + patch_size) if intersect_x_min < intersect_x_max and intersect_y_min < intersect_y_max: # There is an intersection # Transform intersection back to patch1 coordinates to compute area # For simplicity, approximate using the intersection in patch2 # (this is exact for integer translations and rotations) intersect_area = (intersect_x_max - intersect_x_min) * ( intersect_y_max - intersect_y_min ) overlap_frac = intersect_area / token_area if overlap_frac > best_overlap: best_overlap = overlap_frac best_j = j if best_overlap > 0.0: overlap_mask1[i] = True overlap_indices1_to_2[i] = best_j overlap_fractions[i] = best_overlap # Also set reverse mapping (use same fraction since tokens are same size) if not overlap_mask2[best_j]: overlap_mask2[best_j] = True overlap_indices2_to_1[best_j] = i return ( overlap_mask1, overlap_indices1_to_2, overlap_mask2, overlap_indices2_to_1, overlap_fractions, ) if HAS_NUMBA: @njit(parallel=True, fastmath=True) def _numba_find_overlaps_batch( coords: np.ndarray, # (T, 2) float64 dx: np.ndarray, # (N,) float32 dy: np.ndarray, # (N,) float32 rot_k90: np.ndarray, # (N,) int64 patch_size: int, H: int, W: int, overlap_mask1_out: np.ndarray, # (N, T) bool overlap_indices1_to_2_out: np.ndarray, # (N, T) int64 overlap_mask2_out: np.ndarray, # (N, T) bool overlap_indices2_to_1_out: np.ndarray, # (N, T) int64 overlap_fractions_out: np.ndarray, # (N, T) float32 ): """ Numba-accelerated batch overlap computation. Processes all pairs in parallel using prange. """ N = dx.shape[0] T = coords.shape[0] token_area = float(patch_size * patch_size) # Process each patch pair in parallel for n in prange(N): dx_val = dx[n] dy_val = dy[n] rot = int(rot_k90[n]) % 4 # For each token in patch1, find best overlap in patch2 for i in range(T): y1 = coords[i, 0] x1 = coords[i, 1] # Transform four corners to patch2's coordinate system # Corner coordinates: top-left, top-right, bottom-left, bottom-right x1_tl = x1 y1_tl = y1 x1_tr = x1 + patch_size y1_tr = y1 x1_bl = x1 y1_bl = y1 + patch_size x1_br = x1 + patch_size y1_br = y1 + patch_size # Apply inverse translation and rotation to each corner # Transform corner 1 (top-left) x_unrot_tl = x1_tl - dx_val y_unrot_tl = y1_tl - dy_val if rot == 0: x2_tl, y2_tl = x_unrot_tl, y_unrot_tl elif rot == 1: x2_tl, y2_tl = W - y_unrot_tl, x_unrot_tl elif rot == 2: x2_tl, y2_tl = W - x_unrot_tl, H - y_unrot_tl elif rot == 3: x2_tl, y2_tl = y_unrot_tl, H - x_unrot_tl else: x2_tl, y2_tl = x_unrot_tl, y_unrot_tl # Transform corner 2 (top-right) x_unrot_tr = x1_tr - dx_val y_unrot_tr = y1_tr - dy_val if rot == 0: x2_tr, y2_tr = x_unrot_tr, y_unrot_tr elif rot == 1: x2_tr, y2_tr = W - y_unrot_tr, x_unrot_tr elif rot == 2: x2_tr, y2_tr = W - x_unrot_tr, H - y_unrot_tr elif rot == 3: x2_tr, y2_tr = y_unrot_tr, H - x_unrot_tr else: x2_tr, y2_tr = x_unrot_tr, y_unrot_tr # Transform corner 3 (bottom-left) x_unrot_bl = x1_bl - dx_val y_unrot_bl = y1_bl - dy_val if rot == 0: x2_bl, y2_bl = x_unrot_bl, y_unrot_bl elif rot == 1: x2_bl, y2_bl = W - y_unrot_bl, x_unrot_bl elif rot == 2: x2_bl, y2_bl = W - x_unrot_bl, H - y_unrot_bl elif rot == 3: x2_bl, y2_bl = y_unrot_bl, H - x_unrot_bl else: x2_bl, y2_bl = x_unrot_bl, y_unrot_bl # Transform corner 4 (bottom-right) x_unrot_br = x1_br - dx_val y_unrot_br = y1_br - dy_val if rot == 0: x2_br, y2_br = x_unrot_br, y_unrot_br elif rot == 1: x2_br, y2_br = W - y_unrot_br, x_unrot_br elif rot == 2: x2_br, y2_br = W - x_unrot_br, H - y_unrot_br elif rot == 3: x2_br, y2_br = y_unrot_br, H - x_unrot_br else: x2_br, y2_br = x_unrot_br, y_unrot_br # Find bounding box x2_min = min(x2_tl, x2_tr, x2_bl, x2_br) x2_max = max(x2_tl, x2_tr, x2_bl, x2_br) y2_min = min(y2_tl, y2_tr, y2_bl, y2_br) y2_max = max(y2_tl, y2_tr, y2_bl, y2_br) # Find best overlap with tokens in patch2 best_overlap = 0.0 best_j = -1 for j in range(T): y2 = coords[j, 0] x2 = coords[j, 1] # Compute intersection intersect_x_min = max(x2_min, x2) intersect_x_max = min(x2_max, x2 + patch_size) intersect_y_min = max(y2_min, y2) intersect_y_max = min(y2_max, y2 + patch_size) if ( intersect_x_min < intersect_x_max and intersect_y_min < intersect_y_max ): intersect_area = (intersect_x_max - intersect_x_min) * ( intersect_y_max - intersect_y_min ) overlap_frac = intersect_area / token_area if overlap_frac > best_overlap: best_overlap = overlap_frac best_j = j # Store results if best_overlap > 0.0: overlap_mask1_out[n, i] = True overlap_indices1_to_2_out[n, i] = best_j overlap_fractions_out[n, i] = best_overlap # Set reverse mapping (only if not already set by another token) if not overlap_mask2_out[n, best_j]: overlap_mask2_out[n, best_j] = True overlap_indices2_to_1_out[n, best_j] = i def _to_tensor_batch( value, N: int, dtype: torch.dtype, device: torch.device, name: str ) -> torch.Tensor: """ Convert a value (scalar, tensor, or numpy array) to a batched tensor. Parameters ---------- value : scalar, torch.Tensor, or numpy.ndarray Input value to convert N : int Batch size dtype : torch.dtype Target dtype device : torch.device Target device name : str Name of parameter (for error messages) Returns ------- torch.Tensor Tensor of shape (N,) on the specified device """ if isinstance(value, torch.Tensor): value = value.to(device) if value.shape[0] != N: raise ValueError( f"{name} must have shape ({N},) or be scalar, got {value.shape}" ) return value # Handle numpy arrays or scalars try: import numpy as np if isinstance(value, np.ndarray): value = torch.from_numpy(value).to(device) if value.shape[0] != N: raise ValueError( f"{name} must have shape ({N},) or be scalar, got {value.shape}" ) return value except ImportError: pass # Scalar: broadcast to batch return torch.tensor([value] * N, dtype=dtype, device=device)
[docs] def build_sequence_pair( patch1: torch.Tensor, patch2: torch.Tensor, dx: Union[float, torch.Tensor], dy: Union[float, torch.Tensor], rot_k90: Union[int, torch.Tensor], patch_size: int, stride: int = None, ) -> Dict[str, torch.Tensor]: """ Build sequence pair from two patches with overlap information. This function pre-tokenizes both patches (splits them into subpatches), finds overlapping subpatches, and returns sequences with absolute coordinates suitable for downstream tokenization and embedding methods. Supports both single patches and batches: - Single: patch1/patch2 shape (C, H, W), dx/dy/rot_k90 are scalars - Batch: patch1/patch2 shape (N, C, H, W), dx/dy/rot_k90 are (N,) tensors Parameters ---------- patch1 : torch.Tensor First patch of shape (C, H, W) or batch of shape (N, C, H, W) patch2 : torch.Tensor Second patch of shape (C, H, W) or batch of shape (N, C, H, W) dx : float or torch.Tensor Translation in pixels along x-axis. Scalar for single patch, (N,) tensor for batch. dy : float or torch.Tensor Translation in pixels along y-axis. Scalar for single patch, (N,) tensor for batch. rot_k90 : int or torch.Tensor Rotation applied to patch2 in 90-degree increments (0, 1, 2, or 3). Scalar for single patch, (N,) tensor for batch. patch_size : int Size of each token in pixels stride : int, optional Stride for sliding window token extraction. Defaults to patch_size // 2. Must be positive and <= patch_size. Returns ------- Dict[str, torch.Tensor] Dictionary containing: For single patch: - "tokens1": Token vectors from patch1, shape (T1, D) - "tokens2": Token vectors from patch2, shape (T2, D) - "coords1": Absolute pixel coordinates (y, x) for patch1 tokens, shape (T1, 2) - "coords2": Absolute pixel coordinates (y, x) for patch2 tokens, shape (T2, 2) - "overlap_mask1": Boolean mask indicating which patch1 tokens overlap, shape (T1,) - "overlap_mask2": Boolean mask indicating which patch2 tokens overlap, shape (T2,) - "overlap_indices1_to_2": Mapping from patch1 to patch2 tokens, shape (T1,) - "overlap_indices2_to_1": Mapping from patch2 to patch1 tokens, shape (T2,) - "overlap_fractions": Fraction of overlap for each patch1 token (0.0 to 1.0), shape (T1,) - "overlap_pairs": Tensor of shape (N_overlaps, 2) containing (i, j) pairs For batch (all tensors are padded to max length): - "tokens1": Token vectors from patch1, shape (N, T_max, D) - "tokens2": Token vectors from patch2, shape (N, T_max, D) - "coords1": Absolute pixel coordinates for patch1 tokens, shape (N, T_max, 2) - "coords2": Absolute pixel coordinates for patch2 tokens, shape (N, T_max, 2) - "overlap_mask1": Boolean mask for patch1 tokens, shape (N, T_max) - "overlap_mask2": Boolean mask for patch2 tokens, shape (N, T_max) - "overlap_indices1_to_2": Mapping from patch1 to patch2, shape (N, T_max), -1 for padding - "overlap_indices2_to_1": Mapping from patch2 to patch1, shape (N, T_max), -1 for padding - "overlap_fractions": Fraction of overlap, shape (N, T_max) - "overlap_pairs": Overlap pairs, shape (N, max_pairs, 2), -1 for padding - "sequence_lengths": Actual sequence lengths (same for both patches), shape (N,) - "overlap_pair_counts": Number of overlap pairs per sample, shape (N,) """ # Check if inputs are batched is_batched = len(patch1.shape) == 4 and len(patch2.shape) == 4 if is_batched: # Batch processing N, C1, H1, W1 = patch1.shape N2, C2, H2, W2 = patch2.shape if N != N2: raise ValueError( f"Batch sizes must match: patch1 has {N} patches, patch2 has {N2}" ) if C1 != C2 or H1 != H2 or W1 != W2: raise ValueError( f"Patches must have same shape, got {patch1.shape} and {patch2.shape}" ) # Convert dx, dy, rot_k90 to tensors if needed dx = _to_tensor_batch(dx, N, torch.float32, patch1.device, "dx") dy = _to_tensor_batch(dy, N, torch.float32, patch1.device, "dy") rot_k90 = _to_tensor_batch(rot_k90, N, torch.int64, patch1.device, "rot_k90") # OPTIMIZATION: Batch tokenize all patches at once # Since all patches have the same shape, we can use a single quilt object C, H, W = patch1.shape[1:] # Determine stride (same logic as tokenize_patch) if stride is None: stride_val = patch_size // 2 else: stride_val = stride # Create quilt object once (same for all patches) quilt = NCYXQuilt( Y=H, X=W, window=(patch_size, patch_size), step=(stride_val, stride_val), border=None, ) # Batch tokenize: (N, C, H, W) -> (N*T, C, patch_size, patch_size) patches1_flat = quilt.unstitch(patch1) # (N*T, C, patch_size, patch_size) patches2_flat = quilt.unstitch(patch2) # (N*T, C, patch_size, patch_size) # Get number of tokens per patch nY, nX = quilt.get_times() T = nY * nX # Same for all patches # Compute coordinates once (same for all patches) # Use the same logic as NCYXQuilt.unstitch() to ensure consistency coords_list = [] for yy in range(nY): for xx in range(nX): start_y = min(yy * stride_val, H - patch_size) start_x = min(xx * stride_val, W - patch_size) coords_list.append([start_y, start_x]) coords = torch.tensor( coords_list, dtype=torch.long, device=patch1.device ) # (T, 2) # Flatten patches to tokens: (N*T, C, patch_size, patch_size) -> (N*T, C*patch_size*patch_size) D = C * patch_size * patch_size tokens1_flat = patches1_flat.view(N * T, D) # (N*T, D) tokens2_flat = patches2_flat.view(N * T, D) # (N*T, D) # Reshape to (N, T, D) tokens1_batch = tokens1_flat.view(N, T, D) tokens2_batch = tokens2_flat.view(N, T, D) # Expand coordinates for all patches: (T, 2) -> (N, T, 2) coords1_batch = coords.unsqueeze(0).expand(N, -1, -1) coords2_batch = coords.unsqueeze(0).expand(N, -1, -1) # Initialize overlap tensors overlap_mask1_batch = torch.zeros( (N, T), dtype=torch.bool, device=patch1.device ) overlap_mask2_batch = torch.zeros( (N, T), dtype=torch.bool, device=patch1.device ) overlap_indices1_to_2_batch = torch.full( (N, T), -1, dtype=torch.long, device=patch1.device ) overlap_indices2_to_1_batch = torch.full( (N, T), -1, dtype=torch.long, device=patch1.device ) overlap_fractions_batch = torch.zeros( (N, T), dtype=torch.float32, device=patch1.device ) # Process overlaps - use numba-accelerated batch processing if available # Otherwise fall back to sequential or threading use_numba = HAS_NUMBA and N > 5 # Use numba for batches larger than 5 if use_numba: # Use numba-accelerated batch processing with parallel execution # Convert tensors to numpy for numba coords_np = coords.cpu().numpy().astype(np.float64) dx_np = dx.cpu().numpy().astype(np.float32) dy_np = dy.cpu().numpy().astype(np.float32) rot_k90_np = rot_k90.cpu().numpy().astype(np.int64) # Initialize output arrays overlap_mask1_np = np.zeros((N, T), dtype=np.bool_) overlap_indices1_to_2_np = np.full((N, T), -1, dtype=np.int64) overlap_mask2_np = np.zeros((N, T), dtype=np.bool_) overlap_indices2_to_1_np = np.full((N, T), -1, dtype=np.int64) overlap_fractions_np = np.zeros((N, T), dtype=np.float32) # Run numba-accelerated batch computation _numba_find_overlaps_batch( coords_np, dx_np, dy_np, rot_k90_np, patch_size, H, W, overlap_mask1_np, overlap_indices1_to_2_np, overlap_mask2_np, overlap_indices2_to_1_np, overlap_fractions_np, ) # Convert back to PyTorch tensors on the original device overlap_mask1_batch = torch.from_numpy(overlap_mask1_np).to(patch1.device) overlap_mask2_batch = torch.from_numpy(overlap_mask2_np).to(patch1.device) overlap_indices1_to_2_batch = torch.from_numpy(overlap_indices1_to_2_np).to( patch1.device ) overlap_indices2_to_1_batch = torch.from_numpy(overlap_indices2_to_1_np).to( patch1.device ) overlap_fractions_batch = torch.from_numpy(overlap_fractions_np).to( patch1.device ) # Build overlap pairs (vectorized for all pairs) overlap_pairs_all = [] for i in range(N): mask = overlap_mask1_batch[i] if mask.any(): indices1 = torch.arange(T, device=patch1.device)[mask] indices2 = overlap_indices1_to_2_batch[i][mask] pairs = torch.stack([indices1, indices2], dim=1) overlap_pairs_all.append(pairs) else: overlap_pairs_all.append( torch.empty((0, 2), dtype=torch.long, device=patch1.device) ) else: # Sequential processing (for small batches or when numba unavailable) overlap_pairs_all = [] for i in range(N): ( overlap_mask1, overlap_indices1_to_2, overlap_mask2, overlap_indices2_to_1, overlap_fractions, ) = _find_overlapping_tokens( coords, coords, dx[i].item(), dy[i].item(), rot_k90[i].item(), patch_size, (H, W), ) overlap_mask1_batch[i] = overlap_mask1 overlap_mask2_batch[i] = overlap_mask2 overlap_indices1_to_2_batch[i] = overlap_indices1_to_2 overlap_indices2_to_1_batch[i] = overlap_indices2_to_1 overlap_fractions_batch[i] = overlap_fractions # Build overlap pairs (vectorized) mask = overlap_mask1 if mask.any(): indices1 = torch.arange(T, device=patch1.device)[mask] indices2 = overlap_indices1_to_2[mask] pairs = torch.stack([indices1, indices2], dim=1) overlap_pairs_all.append(pairs) else: overlap_pairs_all.append( torch.empty((0, 2), dtype=torch.long, device=patch1.device) ) # Find maximum number of overlap pairs max_pairs = ( max(pairs.shape[0] for pairs in overlap_pairs_all) if overlap_pairs_all else 0 ) # Create overlap_pairs_batch tensor if max_pairs == 0: overlap_pairs_batch = torch.empty( (N, 0, 2), dtype=torch.long, device=patch1.device ) else: overlap_pairs_batch = torch.full( (N, max_pairs, 2), -1, dtype=torch.long, device=patch1.device ) for i, pairs in enumerate(overlap_pairs_all): num_pairs = pairs.shape[0] if num_pairs > 0: overlap_pairs_batch[i, :num_pairs] = pairs # Create sequence lengths and pair counts sequence_lengths = torch.full((N,), T, dtype=torch.long, device=patch1.device) overlap_pair_counts = torch.tensor( [pairs.shape[0] for pairs in overlap_pairs_all], dtype=torch.long, device=patch1.device, ) return { "tokens1": tokens1_batch, # (N, T_max, D) "tokens2": tokens2_batch, # (N, T_max, D) "coords1": coords1_batch, # (N, T_max, 2) "coords2": coords2_batch, # (N, T_max, 2) "overlap_mask1": overlap_mask1_batch, # (N, T_max) "overlap_mask2": overlap_mask2_batch, # (N, T_max) "overlap_indices1_to_2": overlap_indices1_to_2_batch, # (N, T_max) "overlap_indices2_to_1": overlap_indices2_to_1_batch, # (N, T_max) "overlap_fractions": overlap_fractions_batch, # (N, T_max) "overlap_pairs": overlap_pairs_batch, # (N, max_pairs, 2) "sequence_lengths": sequence_lengths, # (N,) - actual sequence lengths (same for both patches) "overlap_pair_counts": overlap_pair_counts, # (N,) - number of overlap pairs per sample } else: # Single patch processing (original behavior) if len(patch1.shape) != 3 or len(patch2.shape) != 3: raise ValueError( f"Both patches must be 3D (C, H, W) or 4D (N, C, H, W), " f"got shapes {patch1.shape} and {patch2.shape}" ) C1, H1, W1 = patch1.shape C2, H2, W2 = patch2.shape if C1 != C2 or H1 != H2 or W1 != W2: raise ValueError( f"Patches must have same shape, got {patch1.shape} and {patch2.shape}" ) # Convert scalars to floats/ints if they're tensors if isinstance(dx, torch.Tensor): dx = dx.item() if isinstance(dy, torch.Tensor): dy = dy.item() if isinstance(rot_k90, torch.Tensor): rot_k90 = rot_k90.item() # Tokenize both patches tokens1, coords1 = tokenize_patch(patch1, patch_size, stride=stride) tokens2, coords2 = tokenize_patch(patch2, patch_size, stride=stride) # Find overlapping tokens ( overlap_mask1, overlap_indices1_to_2, overlap_mask2, overlap_indices2_to_1, overlap_fractions, ) = _find_overlapping_tokens( coords1, coords2, dx, dy, rot_k90, patch_size, (H1, W1) ) # Build list of overlap pairs: [(i, j), ...] where token i in patch1 overlaps with token j in patch2 overlap_pairs = [] for i in range(overlap_mask1.shape[0]): if overlap_mask1[i]: j = overlap_indices1_to_2[i].item() overlap_pairs.append((i, j)) overlap_pairs_tensor = ( torch.tensor(overlap_pairs, dtype=torch.long, device=tokens1.device) if overlap_pairs else torch.empty((0, 2), dtype=torch.long, device=tokens1.device) ) return { "tokens1": tokens1, "tokens2": tokens2, "coords1": coords1, "coords2": coords2, "overlap_mask1": overlap_mask1, "overlap_mask2": overlap_mask2, "overlap_indices1_to_2": overlap_indices1_to_2, "overlap_indices2_to_1": overlap_indices2_to_1, "overlap_fractions": overlap_fractions, # Fraction of overlap for each patch1 token (0.0 to 1.0) "overlap_pairs": overlap_pairs_tensor, # Shape (N_overlaps, 2) with (i, j) pairs }