from typing import Tuple

import torch
import torch.nn.functional as F


def fold_sum(
    patches: torch.Tensor,
    kernel_size: int,
    stride: int,
    padding: int,
    output_size: Tuple[int, int, int, int],
) -> torch.Tensor:
    """
    Folds patches into an image, summing values where patches overlap.

    Args:
        patches: [B * num_patches, C, kernel_size, kernel_size], batch-major
            (all patches of image 0 first, then image 1, ...). Within one
            image, patches are in the block order used by ``F.fold`` (the
            order ``F.unfold`` produces).
        kernel_size: side length of each square patch.
        stride: step between neighbouring patches.
        padding: implicit padding on each side, as in ``F.fold``; values that
            land in the padded area are discarded.
        output_size: (B, C, H, W) of the image to reconstruct.

    Returns:
        [B, C, H, W] tensor where each pixel is the sum of every patch value
        landing on it (0 where no patch lands).
    """
    B, C, H, W = output_size
    # F.fold wants [B, C * kernel_size**2, num_patches], with the middle axis
    # flattened as (channel, kernel row, kernel column), which is exactly the
    # flattening of a [C, k, k] patch.
    columns = patches.reshape(B, -1, C * kernel_size * kernel_size).permute(0, 2, 1)
    return F.fold(
        columns,
        output_size=(H, W),
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )


def fold_mean(
    patches: torch.Tensor,
    kernel_size: int,
    stride: int,
    padding: int,
    output_size: Tuple[int, int, int, int],
) -> torch.Tensor:
    """
    Folds patches into an image, averaging values where patches overlap.

    Args:
        patches: [B * num_patches, C, kernel_size, kernel_size], same
            batch-major layout as ``fold_sum``.
        kernel_size: side length of each square patch.
        stride: step between neighbouring patches.
        padding: implicit padding on each side, as in ``F.fold``.
        output_size: (B, C, H, W) of the image to reconstruct.

    Returns:
        [B, C, H, W] tensor where each pixel is the mean of every patch value
        landing on it; pixels no patch reaches are 0.
    """
    B, _, H, W = output_size
    summed = fold_sum(patches, kernel_size, stride, padding, output_size)

    # The overlap count depends only on the patch geometry, not on the batch
    # index or channel. Fold a single-image, single-channel tensor of ones
    # ([1, 1, H, W]) and let it broadcast, instead of materialising a full
    # ones_like(patches) copy.
    num_patches = patches.shape[0] // B
    ones = patches.new_ones((num_patches, 1, kernel_size, kernel_size))
    divisor = fold_sum(ones, kernel_size, stride, padding, (1, 1, H, W))

    # Uncovered pixels have summed == 0 and divisor == 0. Clamping the divisor
    # to 1 keeps them at 0 without creating inf/nan intermediates.
    return summed / divisor.clamp(min=1)


def fold_median(
    patches: torch.Tensor,
    kernel_size: int,
    stride: int,
    padding: int,
    output_size: Tuple[int, int, int, int],
) -> torch.Tensor:
    """Fold patches back into image using median for overlapping regions.

    Patches use the same layout as ``fold_sum`` / ``fold_mean``: batch-major,
    and within one image in the block order of ``F.fold`` / ``F.unfold``
    (row-major over the patch grid, honouring ``stride`` and ``padding``).

    Args:
        patches: [B*num_patches, C, kernel_size, kernel_size]
        kernel_size: int, side length of each square patch.
        stride: int, step between neighbouring patches.
        padding: int, implicit padding on each side as in ``F.fold``; values
            landing in the padded area are discarded.
        output_size: tuple (B, C, H, W)

    Returns:
        [B, C, H, W] tensor holding, per pixel, the ``torch.nanmedian`` of all
        patch values landing there (the lower median for an even count).
        Pixels no patch reaches are 0. The dtype is
        ``torch.promote_types(patches.dtype, torch.float32)``.
    """
    B, C, H, W = output_size
    k = kernel_size
    num_patches = patches.shape[0] // B
    work_dtype = torch.promote_types(patches.dtype, torch.float32)

    # [B, C, k, k, num_patches]: kernel offsets explicit, patch index last.
    cols = (
        patches.reshape(B, num_patches, C, k, k)
        .permute(0, 2, 3, 4, 1)
        .to(work_dtype)
    )
    fold_kwargs = {
        "output_size": (H, W),
        "kernel_size": k,
        "stride": stride,
        "padding": padding,
    }

    # Two different patches touching the same pixel sit at kernel offsets that
    # differ by a non-zero multiple of `stride` along at least one axis. So the
    # "slot" (i // stride, j // stride) is unique among all patches covering a
    # given pixel. Folding one slot at a time never adds two values together:
    # each fold yields at most one candidate per pixel, and there are
    # ceil(k / stride) ** 2 slots (the maximum overlap count).
    slots_per_axis = -(-k // stride)
    candidates, present = [], []
    for a in range(slots_per_axis):
        for b in range(slots_per_axis):
            in_slot = torch.zeros(k, k, dtype=torch.bool, device=patches.device)
            in_slot[a * stride : (a + 1) * stride, b * stride : (b + 1) * stride] = True

            picked = cols.masked_fill(~in_slot[:, :, None], 0)
            candidates.append(
                F.fold(picked.reshape(B, C * k * k, num_patches), **fold_kwargs)
            )

            # Geometry-only coverage for this slot: [1, 1, H, W] bool.
            hits = in_slot.to(work_dtype).reshape(1, k * k, 1).repeat(1, 1, num_patches)
            present.append(F.fold(hits, **fold_kwargs) > 0)

    stacked = torch.stack(candidates, dim=-1)  # [B, C, H, W, slots]
    valid = torch.stack(present, dim=-1)  # [1, 1, H, W, slots]

    # Empty slots become NaN so nanmedian only sees real contributions.
    result = torch.nanmedian(stacked.masked_fill(~valid, float("nan")), dim=-1).values

    # Pixels with no contribution at all would be NaN; report them as 0.
    covered = valid.any(dim=-1)
    return torch.where(covered, result, torch.zeros_like(result))


def fold_center(
    patches: torch.Tensor,
    kernel_size: int,
    stride: int,
    padding: int,
    output_size: Tuple[int, int, int, int],
    fallback_to_mean: bool = True,
) -> torch.Tensor:
    """
    Folds patches into an image by keeping only a stride x stride center crop
    of each patch. The crops tile the image without overlapping, so every
    covered pixel comes from exactly one patch.

    Some pixels are reached by no crop: the margins around the tiled interior,
    plus any remainder the patch grid does not cover. Those pixels are filled
    from ``fold_mean`` when ``fallback_to_mean`` is true, and left at 0
    otherwise.

    Args:
        patches: [B * num_patches, C, kernel_size, kernel_size], same
            batch-major layout as ``fold_sum``.
        kernel_size: side length of each square patch.
        stride: step between neighbouring patches (and crop side length).
        padding: implicit padding on each side, as in ``F.fold``.
        output_size: (B, C, H, W) of the image to reconstruct.
        fallback_to_mean: fill crop-uncovered pixels with the mean fold.

    Returns:
        [B, C, H, W] tensor with the same dtype as the folded patches.

    Raises:
        ValueError: if stride > kernel_size (no stride x stride center exists).
    """
    if stride > kernel_size:
        raise ValueError(
            f"fold_center requires stride <= kernel_size, "
            f"got stride={stride}, kernel_size={kernel_size}"
        )

    B, C, H, W = output_size

    # Margins around the center crop. When kernel_size - stride is odd, the
    # extra pixel goes to the trailing (bottom/right) side.
    lead = (kernel_size - stride) // 2
    trail = kernel_size - stride - lead
    center_region = patches[:, :, lead : lead + stride, lead : lead + stride]

    # Shifting coordinates by `lead` turns each crop into a stride x stride
    # block at the patch's own grid position. A non-overlapping fold over the
    # image shrunk by both margins therefore reproduces the same patch grid
    # (same number of blocks as the original kernel_size fold).
    inner_h = H - lead - trail
    inner_w = W - lead - trail
    folded = fold_sum(
        center_region,
        kernel_size=stride,
        stride=stride,
        padding=padding,
        output_size=(B, C, inner_h, inner_w),
    )

    # Which inner pixels a crop actually reached. This is geometry only, so
    # one image and one channel suffice; the mask broadcasts later.
    num_patches = patches.shape[0] // B
    crop_ones = patches.new_ones((num_patches, 1, stride, stride))
    inner_covered = (
        fold_sum(
            crop_ones,
            kernel_size=stride,
            stride=stride,
            padding=padding,
            output_size=(1, 1, inner_h, inner_w),
        )
        > 0
    )

    if lead == 0 and trail == 0:
        result, covered = folded, inner_covered
    else:
        result = folded.new_zeros((B, C, H, W))
        result[:, :, lead : H - trail, lead : W - trail] = folded
        covered = torch.zeros((1, 1, H, W), dtype=torch.bool, device=folded.device)
        covered[:, :, lead : H - trail, lead : W - trail] = inner_covered

    # Fall back on coverage rather than on value == 0, so genuine zero-valued
    # interior pixels keep their center-crop value.
    if fallback_to_mean and not bool(covered.all()):
        mean_folding = fold_mean(patches, kernel_size, stride, padding, output_size)
        result = torch.where(covered, result, mean_folding)
    return result
