Source code for qlty.patch_pairs_2d
"""
Extract pairs of patches from 2D image tensors with controlled displacement.
This module provides functionality to extract pairs of patches from 2D tensors
where the displacement between patch centers follows specified constraints.
"""
from typing import Optional, Sequence, Tuple
import torch
[docs]
def extract_patch_pairs(
    tensor: torch.Tensor,
    window: Tuple[int, int],
    num_patches: int,
    delta_range: Tuple[float, float],
    random_seed: Optional[int] = None,
    rotation_choices: Optional[Sequence[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Extract pairs of patches from 2D image tensors with controlled displacement.

    For each image in the input tensor, this function extracts P pairs of patches.
    Each pair consists of two patches: one whose top-left corner is at (x_i, y_i)
    and another at (x_i + dx_i, y_i + dy_i), where the Euclidean length of the
    displacement (dx_i, dy_i) is constrained to lie within ``delta_range``.
    The corner (x_i, y_i) is drawn uniformly from every position at which both
    patches of the pair fit entirely inside the image.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor of shape (N, C, Y, X) where:
        - N: Number of images
        - C: Number of channels
        - Y: Height of images
        - X: Width of images
    window : Tuple[int, int]
        Window shape (U, V) where:
        - U: Height of patches
        - V: Width of patches
    num_patches : int
        Number of patch pairs P to extract per image
    delta_range : Tuple[float, float]
        Range (low, high) for the Euclidean distance of displacement vectors.
        The constraint is: low <= sqrt(dx_i² + dy_i²) <= high
        Additionally, low and high must satisfy: window//4 <= low <= high <= 3*window//4
        where window is the maximum of U and V.
    random_seed : Optional[int], optional
        Random seed for reproducibility. If None, uses current random state.
        Default is None.
    rotation_choices : Optional[Sequence[int]], optional
        Allowed quarter-turn rotations (0 = 0°, 1 = 90°, 2 = 180°, 3 = 270°) to apply
        to the second patch in each pair. Values are taken modulo 4. If provided,
        a rotation from this set is sampled uniformly per pair and tracked in the
        returned `rotations` tensor. When None or empty (default), no rotations
        are applied. Rotations by 90° or 270° turn a (U, V) patch into a (V, U)
        patch, so they require a square window (U == V).

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        A tuple containing:
        - patches1: Tensor of shape (N*P, C, U, V) containing patches at (x_i, y_i)
        - patches2: Tensor of shape (N*P, C, U, V) containing patches at (x_i + dx_i, y_i + dy_i)
        - deltas: float32 tensor of shape (N*P, 2) containing (dx_i, dy_i) displacement vectors
        - rotations: int64 tensor of shape (N*P,) containing quarter-turn rotations applied to patches2

    Raises
    ------
    ValueError
        If the input is not 4D, delta_range constraints are violated, image
        dimensions are too small for the specified window and delta range,
        90°/270° rotations are requested with a non-square window, or no valid
        placement is found for a sampled displacement after 11 draws.

    Examples
    --------
    >>> tensor = torch.randn(5, 3, 128, 128)  # 5 images, 3 channels, 128x128
    >>> window = (32, 32)  # 32x32 patches
    >>> num_patches = 10  # 10 patch pairs per image
    >>> delta_range = (8.0, 16.0)  # Euclidean distance between 8 and 16 pixels
    >>> patches1, patches2, deltas, rotations = extract_patch_pairs(
    ...     tensor, window, num_patches, delta_range
    ... )
    >>> print(patches1.shape)  # (50, 3, 32, 32)
    >>> print(patches2.shape)  # (50, 3, 32, 32)
    >>> print(deltas.shape)  # (50, 2)
    >>> print(rotations.shape)  # (50,)
    """
    # Validate input tensor shape
    if len(tensor.shape) != 4:
        raise ValueError(
            f"Input tensor must be 4D (N, C, Y, X), got shape {tensor.shape}"
        )
    N, C, Y, X = tensor.shape
    U, V = window

    # Validate delta_range constraints
    max_window = max(U, V)
    window_quarter = max_window // 4
    window_three_quarters = 3 * max_window // 4
    low, high = delta_range
    if low < window_quarter or high > window_three_quarters:
        raise ValueError(
            f"delta_range must satisfy: {window_quarter} <= low <= high <= {window_three_quarters}, "
            f"got ({low}, {high})"
        )
    if low > high:
        raise ValueError(f"delta_range low ({low}) must be <= high ({high})")

    # Sampled displacements normally satisfy |dx|, |dy| <= int(high), so an image
    # of at least (U + int(high), V + int(high)) leaves room for both patches.
    min_y = U + int(high)
    min_x = V + int(high)
    if Y < min_y or X < min_x:
        raise ValueError(
            f"Image dimensions ({Y}, {X}) are too small for window ({U}, {V}) "
            f"and delta_range ({low}, {high}). Minimum required: ({min_y}, {min_x})"
        )

    # Normalise the allowed quarter-turns to values in {0, 1, 2, 3}.
    if rotation_choices is None:
        rotation_choices = (0,)
    else:
        rotation_choices = tuple(int(choice) % 4 for choice in rotation_choices)
        if len(rotation_choices) == 0:
            rotation_choices = (0,)
    # Odd quarter-turns swap patch height and width; the result would not fit
    # into the (U, V) output buffer unless the window is square.
    if U != V and any(choice % 2 == 1 for choice in rotation_choices):
        raise ValueError(
            f"Rotations by 90 or 270 degrees require a square window, "
            f"got window ({U}, {V}) with rotation_choices {rotation_choices}"
        )
    allow_rotations = any(choice != 0 for choice in rotation_choices)

    # A dedicated, seeded generator makes the draws reproducible; None falls back
    # to torch's global RNG state.
    if random_seed is not None:
        generator = torch.Generator(device=tensor.device)
        generator.manual_seed(random_seed)
    else:
        generator = None

    # Pre-allocate output tensors
    total_patches = N * num_patches
    patches1 = torch.empty(
        (total_patches, C, U, V), dtype=tensor.dtype, device=tensor.device
    )
    patches2 = torch.empty(
        (total_patches, C, U, V), dtype=tensor.dtype, device=tensor.device
    )
    deltas_tensor = torch.empty(
        (total_patches, 2), dtype=torch.float32, device=tensor.device
    )
    rotations_tensor = torch.zeros(
        total_patches, dtype=torch.int64, device=tensor.device
    )

    patch_idx = 0
    for n in range(N):
        image = tensor[n]  # Shape: (C, Y, X)
        for _ in range(num_patches):
            # Draw a displacement for which both patches fit: one initial draw
            # plus up to 10 retries.
            for _attempt in range(11):
                dx, dy = _sample_displacement_vector(
                    low, high, generator, device=tensor.device
                )
                # The top-left corner (x, y) of the first patch must satisfy
                #   0 <= x <= X - V  and  0 <= x + dx <= X - V   (likewise y with U, dy)
                # i.e. x lies in the CLOSED interval [x_min, x_max].
                x_min, x_max = max(0, -dx), min(X - V, X - V - dx)
                y_min, y_max = max(0, -dy), min(Y - U, Y - U - dy)
                if x_min <= x_max and y_min <= y_max:
                    break
            else:
                raise ValueError(
                    f"Could not find valid patch locations for displacement ({dx}, {dy}) "
                    f"in image of size ({Y}, {X}) with window ({U}, {V})"
                )

            # torch.randint's upper bound is exclusive; +1 keeps the last valid
            # position (patch touching the right/bottom edge) reachable.
            x = int(
                torch.randint(
                    x_min, x_max + 1, (1,), generator=generator, device=tensor.device
                )[0]
            )
            y = int(
                torch.randint(
                    y_min, y_max + 1, (1,), generator=generator, device=tensor.device
                )[0]
            )

            patch1 = image[:, y : y + U, x : x + V]  # Shape: (C, U, V)
            patch2 = image[
                :, y + dy : y + dy + U, x + dx : x + dx + V
            ]  # Shape: (C, U, V)

            rotation = 0
            if allow_rotations:
                choice_idx = int(
                    torch.randint(
                        0,
                        len(rotation_choices),
                        (1,),
                        generator=generator,
                        device=tensor.device,
                    )[0]
                )
                rotation = rotation_choices[choice_idx]
                if rotation != 0:
                    patch2 = torch.rot90(patch2, k=rotation, dims=(-2, -1))

            patches1[patch_idx] = patch1
            patches2[patch_idx] = patch2
            deltas_tensor[patch_idx, 0] = float(dx)
            deltas_tensor[patch_idx, 1] = float(dy)
            rotations_tensor[patch_idx] = rotation
            patch_idx += 1

    return patches1, patches2, deltas_tensor, rotations_tensor
def _sample_displacement_vector(
low: float,
high: float,
generator: Optional[torch.Generator] = None,
device: Optional[torch.device] = None,
) -> Tuple[int, int]:
"""
Sample a displacement vector (dx, dy) such that low <= sqrt(dx² + dy²) <= high.
Uses rejection sampling to ensure the Euclidean distance constraint is satisfied.
Parameters
----------
low : float
Minimum Euclidean distance
high : float
Maximum Euclidean distance
generator : Optional[torch.Generator]
Random number generator for reproducibility
Returns
-------
Tuple[int, int]
Displacement vector (dx, dy) as integers
"""
max_attempts = 1000
for _ in range(max_attempts):
# Sample dx and dy in a range that could potentially satisfy the constraint
# We sample from a larger range to ensure we can find valid vectors
max_delta = int(high) + 1
if device is None:
device = torch.device("cpu")
if generator is not None:
dx_tensor = torch.randint(
-max_delta, max_delta + 1, (1,), generator=generator, device=device
)
dy_tensor = torch.randint(
-max_delta, max_delta + 1, (1,), generator=generator, device=device
)
else:
dx_tensor = torch.randint(-max_delta, max_delta + 1, (1,), device=device)
dy_tensor = torch.randint(-max_delta, max_delta + 1, (1,), device=device)
dx = int(dx_tensor[0])
dy = int(dy_tensor[0])
# Check Euclidean distance constraint
distance = (dx**2 + dy**2) ** 0.5
if low <= distance <= high:
return dx, dy
# If we couldn't find a valid vector after many attempts, use a fallback
# Sample angle uniformly and distance uniformly in [low, high]
if generator is not None:
angle_tensor = (
torch.rand(1, generator=generator, device=device) * 2 * 3.141592653589793
)
distance_tensor = low + (high - low) * torch.rand(
1, generator=generator, device=device
)
else:
angle_tensor = torch.rand(1, device=device) * 2 * 3.141592653589793
distance_tensor = low + (high - low) * torch.rand(1, device=device)
distance = float(distance_tensor[0])
# Compute cos and sin on GPU if device is GPU
cos_val = torch.cos(angle_tensor)[0]
sin_val = torch.sin(angle_tensor)[0]
dx = int(round(distance * float(cos_val)))
dy = int(round(distance * float(sin_val)))
# Ensure distance is still in range (may have been affected by rounding)
actual_distance = (dx**2 + dy**2) ** 0.5
if actual_distance < low:
# Scale up to meet minimum
scale = low / actual_distance
dx = int(round(dx * scale))
dy = int(round(dy * scale))
elif actual_distance > high:
# Scale down to meet maximum
scale = high / actual_distance
dx = int(round(dx * scale))
dy = int(round(dy * scale))
return dx, dy
[docs]
def extract_overlapping_pixels(
    patches1: torch.Tensor,
    patches2: torch.Tensor,
    deltas: torch.Tensor,
    rotations: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Extract overlapping pixels from patch pairs based on displacement vectors.

    For each patch pair, this function finds pixels that have valid correspondences
    between the two patches (i.e., pixels that represent the same spatial location
    in the original image). Only overlapping pixels are returned. A pixel at
    (u, v) in patch1 corresponds to (u - dy, v - dx) in patch2.

    Parameters
    ----------
    patches1 : torch.Tensor
        First set of patches, shape (N*P, C, U, V) where:
        - N*P: Total number of patch pairs
        - C: Number of channels
        - U: Patch height
        - V: Patch width
    patches2 : torch.Tensor
        Second set of patches, shape (N*P, C, U, V), corresponding patches
        extracted at displaced locations
    deltas : torch.Tensor
        Displacement vectors, shape (N*P, 2) containing (dx, dy) for each pair.
        Values are converted with ``.int()``, i.e. truncated toward zero.
    rotations : Optional[torch.Tensor], optional
        Quarter-turn rotations (0 = 0°, 1 = 90°, 2 = 180°, 3 = 270°) that were
        applied to `patches2`, one value per pair (taken modulo 4). When provided,
        each value is used to undo the rotation before extracting overlaps so that
        corresponding pixels align spatially.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        A tuple containing:
        - overlapping1: Overlapping pixel values from patches1, shape (K, C)
        - overlapping2: Overlapping pixel values from patches2, shape (K, C)
        where K is the total number of overlapping pixels across all patch pairs,
        and corresponding pixels are at the same index in both tensors. Pixels
        are ordered pair by pair, row-major within each pair's overlap.

    Raises
    ------
    ValueError
        If shapes are inconsistent, the number of deltas or rotations does not
        match the number of pairs, or an odd quarter-turn is given for
        non-square patches (undoing it would misalign the two patches).

    Examples
    --------
    >>> patches1 = torch.randn(10, 3, 32, 32)
    >>> patches2 = torch.randn(10, 3, 32, 32)
    >>> deltas = torch.tensor([[5, 3], [-2, 4], ...])  # shape (10, 2)
    >>> overlapping1, overlapping2 = extract_overlapping_pixels(patches1, patches2, deltas)
    >>> print(overlapping1.shape)  # (K, 3) where K depends on overlap
    >>> print(overlapping2.shape)  # (K, 3)
    >>> # overlapping1[i] and overlapping2[i] correspond to the same spatial location
    """
    # Validate input shapes
    if len(patches1.shape) != 4 or len(patches2.shape) != 4:
        raise ValueError(
            f"Both patches1 and patches2 must be 4D tensors (N*P, C, U, V), "
            f"got shapes {patches1.shape} and {patches2.shape}"
        )
    if patches1.shape != patches2.shape:
        raise ValueError(
            f"patches1 and patches2 must have the same shape, "
            f"got {patches1.shape} and {patches2.shape}"
        )
    if len(deltas.shape) != 2 or deltas.shape[1] != 2:
        raise ValueError(
            f"deltas must be 2D tensor of shape (N*P, 2), got {deltas.shape}"
        )
    num_pairs, C, U, V = patches1.shape
    if deltas.shape[0] != num_pairs:
        raise ValueError(
            f"Number of deltas ({deltas.shape[0]}) must match number of patch pairs ({num_pairs})"
        )

    # Move the per-pair integers to the host once instead of synchronising with
    # the device for every pair inside the loop.
    delta_list = deltas.int().tolist()
    if rotations is not None:
        if rotations.shape[0] != num_pairs:
            raise ValueError(
                f"Number of rotations ({rotations.shape[0]}) must match number of patch pairs ({num_pairs})"
            )
        rotation_list = [value % 4 for value in rotations.int().flatten().tolist()]
        if len(rotation_list) != num_pairs:
            raise ValueError(
                f"rotations must hold exactly one value per patch pair, got shape {tuple(rotations.shape)}"
            )
        # Undoing an odd quarter-turn changes (U, V) into (V, U), so the two
        # patches could not be aligned pixel-for-pixel.
        if U != V and any(value % 2 == 1 for value in rotation_list):
            raise ValueError(
                f"Odd quarter-turn rotations require square patches, got patch size ({U}, {V})"
            )
    else:
        rotation_list = [0] * num_pairs

    overlapping_pixels1 = []
    overlapping_pixels2 = []
    for i, ((dx, dy), rotation) in enumerate(zip(delta_list, rotation_list)):
        patch1 = patches1[i]  # Shape: (C, U, V)
        patch2 = patches2[i]  # Shape: (C, U, V)
        if rotation != 0:
            patch2 = torch.rot90(patch2, k=-rotation, dims=(-2, -1))

        # (u1, v1) in patch1 matches (u1 - dy, v1 - dx) in patch2; both must lie
        # inside [0, U) x [0, V), giving the overlap window below.
        u_min, u_max = max(0, dy), min(U, U + dy)
        v_min, v_max = max(0, dx), min(V, V + dx)
        if u_min >= u_max or v_min >= v_max:
            # No overlap for this patch pair
            continue

        region1 = patch1[:, u_min:u_max, v_min:v_max]
        region2 = patch2[:, u_min - dy : u_max - dy, v_min - dx : v_max - dx]

        # Explicit pixel count (not -1) so reshape also works when C == 0.
        k_prime = (u_max - u_min) * (v_max - v_min)
        overlapping_pixels1.append(region1.reshape(C, k_prime).t())  # (K', C)
        overlapping_pixels2.append(region2.reshape(C, k_prime).t())  # (K', C)

    if not overlapping_pixels1:
        # No overlapping pixels: return two distinct empty (0, C) tensors.
        return (
            torch.empty((0, C), dtype=patches1.dtype, device=patches1.device),
            torch.empty((0, C), dtype=patches2.dtype, device=patches2.device),
        )

    return torch.cat(overlapping_pixels1, dim=0), torch.cat(overlapping_pixels2, dim=0)