Source code for qlty.qlty2D

import math
from typing import Optional, Tuple, Union

import einops
import numpy as np
import torch
from numba import njit, prange

from qlty.base import (
    compute_border_tensor_torch,
    compute_chunk_times,
    compute_weight_matrix_torch,
    normalize_border,
    validate_border_weight,
)


@njit(fastmath=True)  # pragma: no cover
def numba_njit_stitch(
    ml_tensor, result, norma, weight, window, step, Y, X, nX, times, m
):
    # NOTE:
    # We intentionally avoid `parallel=True` because concurrent updates to
    # shared output slices (`result` and `norma`) introduce race conditions
    # that break test expectations. Keeping the loop serial preserves correctness.
    for i in range(times):
        yy = i // nX
        xx = i % nX
        here_and_now = times * m + yy * nX + xx
        start_y = min(yy * step[0], Y - window[0])
        start_x = min(xx * step[1], X - window[1])
        stop_y = start_y + window[0]
        stop_x = start_x + window[1]
        for j in range(ml_tensor.shape[1]):
            tmp = ml_tensor[here_and_now, j, ...]
            result[m, j, start_y:stop_y, start_x:stop_x] += tmp * weight
        # get the weight matrix, only compute once
        if m == 0:
            norma[start_y:stop_y, start_x:stop_x] += weight
    return result, norma


@njit(fastmath=True, parallel=True)  # pragma: no cover
def numba_njit_stitch_color(
    ml_tensor,
    result,
    norma,
    weight,
    window,
    step,
    Y,
    X,
    nX,
    times,
    m,
    color_y_mod,
    color_x_mod,
    color_y_idx,
    color_x_idx,
):
    for i in prange(times):
        yy = i // nX
        xx = i % nX
        if yy % color_y_mod != color_y_idx or xx % color_x_mod != color_x_idx:
            continue
        here_and_now = times * m + yy * nX + xx
        start_y = min(yy * step[0], Y - window[0])
        start_x = min(xx * step[1], X - window[1])
        stop_y = start_y + window[0]
        stop_x = start_x + window[1]
        for j in range(ml_tensor.shape[1]):
            tmp = ml_tensor[here_and_now, j, ...]
            result[m, j, start_y:stop_y, start_x:stop_x] += tmp * weight
        if m == 0:
            norma[start_y:stop_y, start_x:stop_x] += weight
    return result, norma


def _ensure_numpy(array):
    if isinstance(array, torch.Tensor):
        return array.detach().cpu().contiguous().numpy()
    return array


def stitch_serial_numba(
    ml_tensor: torch.Tensor,
    weight: torch.Tensor,
    window: Tuple[int, int],
    step: Tuple[int, int],
    Y: int,
    X: int,
    nY: int,
    nX: int,
):
    times = nY * nX
    ml_tensor_np = _ensure_numpy(ml_tensor)
    weight_np = _ensure_numpy(weight)

    M_images = ml_tensor_np.shape[0] // times
    assert ml_tensor_np.shape[0] % times == 0

    result_np = np.zeros(
        (M_images, ml_tensor_np.shape[1], Y, X), dtype=ml_tensor_np.dtype
    )
    norma_np = np.zeros((Y, X), dtype=weight_np.dtype)

    for m in range(M_images):
        result_np, norma_np = numba_njit_stitch(
            ml_tensor_np,
            result_np,
            norma_np,
            weight_np,
            window,
            step,
            Y,
            X,
            nX,
            times,
            m,
        )

    result = torch.from_numpy(result_np)
    norma = torch.from_numpy(norma_np)
    result = result / norma
    return result, norma


def stitch_parallel_colored(
    ml_tensor: torch.Tensor,
    weight: torch.Tensor,
    window: Tuple[int, int],
    step: Tuple[int, int],
    Y: int,
    X: int,
    nY: int,
    nX: int,
):
    times = nY * nX
    ml_tensor_np = _ensure_numpy(ml_tensor)
    weight_np = _ensure_numpy(weight)

    M_images = ml_tensor_np.shape[0] // times
    assert ml_tensor_np.shape[0] % times == 0

    result_np = np.zeros(
        (M_images, ml_tensor_np.shape[1], Y, X), dtype=ml_tensor_np.dtype
    )
    norma_np = np.zeros((Y, X), dtype=weight_np.dtype)

    color_y_mod = 1
    color_x_mod = 1
    if step[0] > 0:
        color_y_mod = max(1, math.ceil(window[0] / step[0]))
    if step[1] > 0:
        color_x_mod = max(1, math.ceil(window[1] / step[1]))
    if nY > 0:
        color_y_mod = min(color_y_mod, nY)
    if nX > 0:
        color_x_mod = min(color_x_mod, nX)

    for m in range(M_images):
        for color_y_idx in range(color_y_mod):
            for color_x_idx in range(color_x_mod):
                result_np, norma_np = numba_njit_stitch_color(
                    ml_tensor_np,
                    result_np,
                    norma_np,
                    weight_np,
                    window,
                    step,
                    Y,
                    X,
                    nX,
                    times,
                    m,
                    color_y_mod,
                    color_x_mod,
                    color_y_idx,
                    color_x_idx,
                )

    result = torch.from_numpy(result_np)
    norma = torch.from_numpy(norma_np)
    result = result / norma
    return result, norma


[docs] class NCYXQuilt: """ This class allows one to split larger tensors into smaller ones that perhaps do fit into memory. This class is aimed at handling tensors of type (N,C,Y,X) """ def __init__( self, Y: int, X: int, window: Tuple[int, int], step: Tuple[int, int], border: Optional[Union[int, Tuple[int, int]]], border_weight: float = 1.0, ) -> None: """ This class allows one to split larger tensors into smaller ones that perhaps do fit into memory. This class is aimed at handling tensors of type (N,C,Y,X). Parameters ---------- Y : number of elements in the Y direction X : number of elements in the X direction window: The size of the sliding window, a tuple (Ysub, Xsub) step: The step size at which we want to sample the sliding window (Ystep,Xstep) border: Border pixels of the window we want to 'ignore' or down weight when stitching things back border_weight: The weight for the border pixels, should be between 0 and 1. The default of 0.1 should be fine """ self.Y = Y self.X = X self.window = window self.step = step # Normalize and validate border self.border = normalize_border(border, ndim=2) self.border_weight = validate_border_weight(border_weight) # Compute chunk times self.nY, self.nX = compute_chunk_times( dimension_sizes=(Y, X), window=window, step=step ) # Compute weight matrix self.weight = compute_weight_matrix_torch( window=window, border=self.border, border_weight=self.border_weight )
[docs] def border_tensor(self) -> torch.Tensor: """Compute border tensor indicating valid (non-border) regions.""" return compute_border_tensor_torch(window=self.window, border=self.border)
[docs] def get_times(self) -> Tuple[int, int]: """ Compute the number of patches along each spatial dimension. This method calculates how many patches will be created in the Y and X dimensions, ensuring the last patch always fits within the image bounds. Returns ------- Tuple[int, int] A tuple (nY, nX) where: - nY: Number of patches in the Y (height) dimension - nX: Number of patches in the X (width) dimension The total number of patches per image is nY * nX. Examples -------- >>> quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16)) >>> nY, nX = quilt.get_times() >>> print(f"Patches per image: {nY * nX}") >>> print(f"Total patches for 10 images: {10 * nY * nX}") """ return compute_chunk_times( dimension_sizes=(self.Y, self.X), window=self.window, step=self.step )
[docs] def unstitch_data_pair( self, tensor_in: torch.Tensor, tensor_out: torch.Tensor, missing_label: Optional[Union[int, float]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Split input and output tensors into smaller overlapping patches. This method is useful for training neural networks where you need to process input-output pairs together. The output tensor can optionally have missing labels that will be masked in border regions. Parameters ---------- tensor_in : torch.Tensor Input tensor of shape (N, C, Y, X). The tensor going into the network. tensor_out : torch.Tensor Output tensor of shape (N, C, Y, X) or (N, Y, X). The target tensor. If 3D, will be automatically expanded to 4D. missing_label : Optional[Union[int, float]], optional Label value that indicates missing/invalid data. If provided, pixels in the border region will be set to this value in the output patches. Default is None (no masking). Returns ------- Tuple[torch.Tensor, torch.Tensor] A tuple of (input_patches, output_patches) where: - input_patches: Shape (M, C, window[0], window[1]) - output_patches: Shape (M, C, window[0], window[1]) or (M, window[0], window[1]) where M = N * nY * nX Examples -------- >>> quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16), border=(5, 5)) >>> input_data = torch.randn(10, 3, 128, 128) >>> target_data = torch.randn(10, 128, 128) >>> inp_patches, tgt_patches = quilt.unstitch_data_pair(input_data, target_data) >>> print(inp_patches.shape) # (M, 3, 32, 32) >>> print(tgt_patches.shape) # (M, 32, 32) """ modsel = None if missing_label is not None: modsel = self.border_tensor() < 0.5 rearranged = False if len(tensor_out.shape) == 3: tensor_out = einops.rearrange(tensor_out, "N Y X -> N () Y X") rearranged = True assert len(tensor_out.shape) == 4 assert len(tensor_in.shape) == 4 assert tensor_in.shape[0] == tensor_out.shape[0] unstitched_in = self.unstitch(tensor_in) unstitched_out = self.unstitch(tensor_out) if modsel is not None: unstitched_out[:, :, modsel] = missing_label if rearranged: assert unstitched_out.shape[1] == 1 unstitched_out = unstitched_out.squeeze(dim=1) return unstitched_in, unstitched_out
[docs] def unstitch(self, tensor: torch.Tensor) -> torch.Tensor: """ Split a tensor into smaller overlapping patches. Parameters ---------- tensor : torch.Tensor Input tensor of shape (N, C, Y, X) where: - N: Number of images - C: Number of channels - Y: Height (must match self.Y) - X: Width (must match self.X) Returns ------- torch.Tensor Patches tensor of shape (M, C, window[0], window[1]) where: - M = N * nY * nX (total number of patches) - window[0], window[1]: Patch dimensions Examples -------- >>> quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16)) >>> data = torch.randn(10, 3, 128, 128) >>> patches = quilt.unstitch(data) >>> print(patches.shape) # (M, 3, 32, 32) """ N, C, Y, X = tensor.shape result = [] for n in range(N): tmp = tensor[n, ...] for yy in range(self.nY): for xx in range(self.nX): start_y = min(yy * self.step[0], self.Y - self.window[0]) start_x = min(xx * self.step[1], self.X - self.window[1]) stop_y = start_y + self.window[0] stop_x = start_x + self.window[1] patch = tmp[:, start_y:stop_y, start_x:stop_x] result.append(patch) result = einops.rearrange(result, "M C Y X -> M C Y X") return result
[docs] def stitch( self, ml_tensor: torch.Tensor, use_numba: bool = True ) -> Tuple[torch.Tensor, torch.Tensor]: """ Reassemble patches back into full-size tensors. This method takes patches produced by `unstitch()` and stitches them back together, averaging overlapping regions using a weight matrix. Border regions are downweighted according to `border_weight`. Typical workflow: 1. Unstitch the data:: patches = quilt.unstitch(input_images) 2. Process patches with your model:: output_patches = model(patches) 3. Stitch back together:: reconstructed, weights = quilt.stitch(output_patches) Parameters ---------- ml_tensor : torch.Tensor Patches tensor of shape (M, C, window[0], window[1]) where: - M must equal N * nY * nX (number of patches) - C: Number of channels - window: Patch dimensions use_numba : bool, optional Whether to use Numba JIT compilation for faster stitching. Default is True (recommended for performance). Returns ------- Tuple[torch.Tensor, torch.Tensor] A tuple of (reconstructed, weights) where: - reconstructed: Shape (N, C, Y, X) - the stitched result - weights: Shape (Y, X) - normalization weights (number of contributors per pixel) Notes ----- **Important**: When working with classification outputs: - Apply softmax AFTER stitching, not before - Averaging softmaxed tensors ≠ softmax of averaged tensors - Process logits, stitch them, then apply softmax to the final result Example:: # CORRECT: logits = model(patches) stitched_logits, _ = quilt.stitch(logits) probabilities = F.softmax(stitched_logits, dim=1) # WRONG: probs = F.softmax(model(patches), dim=1) result, _ = quilt.stitch(probs) # This is incorrect! Examples -------- >>> quilt = NCYXQuilt(Y=128, X=128, window=(32, 32), step=(16, 16)) >>> data = torch.randn(10, 3, 128, 128) >>> patches = quilt.unstitch(data) >>> processed = model(patches) >>> reconstructed, weights = quilt.stitch(processed) >>> print(reconstructed.shape) # (10, C, 128, 128) """ N, C, Y, X = ml_tensor.shape # we now need to figure out how to stitch this back into what dimension times = self.nY * self.nX M_images = N // times assert N % times == 0 if use_numba: return stitch_parallel_colored( ml_tensor, self.weight, self.window, self.step, self.Y, self.X, self.nY, self.nX, ) result = torch.zeros((M_images, C, self.Y, self.X)) norma = torch.zeros((self.Y, self.X)) for m in range(M_images): for yy in range(self.nY): for xx in range(self.nX): here_and_now = times * m + yy * self.nX + xx start_y = min(yy * self.step[0], self.Y - self.window[0]) start_x = min(xx * self.step[1], self.X - self.window[1]) stop_y = start_y + self.window[0] stop_x = start_x + self.window[1] tmp = ml_tensor[here_and_now, ...] result[m, :, start_y:stop_y, start_x:stop_x] += tmp * self.weight if m == 0: norma[start_y:stop_y, start_x:stop_x] += self.weight result = result / norma return result, norma