from typing import Iterator, Literal, Optional, Tuple

import torch
import torch.nn.functional as F


def timestep_to_array(timestep, array):
    """
    Pick the element of ``array`` that corresponds to ``timestep``.

    The timestep axis is treated as the range 0..999 (the constants are
    hard-coded) and split into ``len(array)`` equally sized buckets. The
    mapping is reversed: the highest timesteps select ``array[0]`` and the
    lowest select ``array[-1]``.

    Args:
        timestep: Integer-like timestep. Values outside 0..999 are clamped to
            the nearest bucket instead of wrapping around via negative
            indexing.
        array: Non-empty indexable sequence of per-level values.

    Returns:
        The element of ``array`` for the bucket containing ``timestep``.

    Raises:
        ValueError: If ``array`` is empty.
    """
    num_levels = len(array)
    if num_levels == 0:
        raise ValueError("array must not be empty")

    # With more than 1000 levels the integer division would give 0 and the
    # next line would divide by zero; fall back to one timestep per bucket.
    step_size = max(1, 1000 // num_levels)
    level_idx = (999 - timestep) // step_size

    # Upper clamp: when 1000 is not a multiple of num_levels the lowest
    # timesteps would otherwise land one past the end.
    # Lower clamp: timesteps above 999 would otherwise produce a negative
    # index and silently select from the wrong end of the array.
    level_idx = max(0, min(level_idx, num_levels - 1))
    return array[level_idx]


class SoftmaxWeightedAverage:
    """
    Streaming softmax-weighted average of samples that arrive in chunks.

    Each call to :meth:`add` supplies ``k`` samples together with per-batch
    logits. The final average weights every sample by the softmax of its logit
    taken over *all* samples added so far, so the result does not depend on
    how the samples were split across calls.

    Numerical stability is kept by tracking the running maximum logit per
    ``(b, n)`` entry and storing the accumulated sums relative to that
    maximum; whenever the maximum grows, the stored sums are rescaled.

    Attributes:
        mode: ``"images"`` or ``"covariance"``; any value other than
            ``"images"`` is handled as covariance mode.
        device, dtype: Where and in which precision inputs are accumulated.
        sum_weighted: Sum of ``exp(logit - max_logit) * sample``;
            ``[b, n]`` for images, ``[b, n, n]`` for covariance.
        sum_weights: Sum of ``exp(logit - max_logit)``; ``[b, n]``.
        max_logit: Running maximum logit; ``[b, n]``.
    """

    def __init__(
        self,
        mode: Literal["images", "covariance"] = "images",
        device=None,
        dtype=torch.float32,
    ):
        """
        Initialize the weighted average calculator.

        Args:
            mode: Either "images" for averaging image tensors or "covariance" for averaging covariance matrices
            device: Device to store tensors on
            dtype: Data type for computations
        """
        self.mode = mode
        self.device = device
        self.dtype = dtype
        self.sum_weighted = None  # [b, n] for images, [b, n, n] for covariance
        self.sum_weights = None  # [b, n]
        self.max_logit = None  # [b, n] running max used as the softmax shift

    def add(self, x0b: torch.Tensor, logits: torch.Tensor):
        """
        Add a chunk of samples and their logits to the running average.

        Args:
            x0b:    [k, n]        - shared across batch for images
                  or [k, n, n]    - shared across batch for covariance
            logits: [b, k, n]     - per-batch logits

        Raises:
            AssertionError: If ``x0b`` does not have the shape implied by
                ``logits`` and the mode.
            ValueError: If ``(b, n)`` differs from the chunks already added.
        """
        b, k, n = logits.shape

        if self.mode == "images":
            assert x0b.shape == (
                k,
                n,
            ), f"Expected x0b of shape ({k}, {n}), got {x0b.shape}"
        else:  # covariance mode
            assert x0b.shape == (
                k,
                n,
                n,
            ), f"Expected x0b of shape ({k}, {n}, {n}), got {x0b.shape}"

        # Mixing chunks with different (b, n) would broadcast silently or fail
        # with an obscure error deep inside the accumulation.
        if self.sum_weights is not None and tuple(self.sum_weights.shape) != (b, n):
            raise ValueError(
                f"Expected logits with (b, n) = {tuple(self.sum_weights.shape)}, "
                f"got ({b}, {n})"
            )

        x0b = x0b.to(dtype=self.dtype, device=self.device)
        logits = logits.to(dtype=self.dtype, device=self.device)

        # Running maximum over every sample seen so far (not just this chunk),
        # so weights from different chunks share a single reference point.
        chunk_max = logits.max(dim=1).values  # [b, n]
        if self.max_logit is None:
            new_max = chunk_max
        else:
            new_max = torch.maximum(self.max_logit, chunk_max)

        # If every logit seen so far is -inf the max is -inf and subtracting it
        # would give NaN; shift by 0 instead so those weights become exactly 0.
        shift = torch.where(
            torch.isfinite(new_max), new_max, torch.zeros_like(new_max)
        )  # [b, n]

        # Unnormalised softmax weights relative to the running max.
        weights = torch.exp(logits - shift.unsqueeze(1))  # [b, k, n]

        if self.mode == "images":
            # Weighted sum over k: einsum gives [b, n]
            weighted_sum = torch.einsum("bkn,kn->bn", weights, x0b)  # [b, n]
        else:  # covariance mode
            # Weighted sum over k: einsum gives [b, n, n]; the trailing size-1
            # dim of the weights broadcasts across the last axis of x0b.
            weighted_sum = torch.einsum(
                "bknm,knm->bnm", weights.unsqueeze(-1), x0b
            )  # [b, n, n]

        weight_sum = weights.sum(dim=1)  # [b, n]

        # Initialize or accumulate
        if self.sum_weighted is None:
            self.sum_weighted = weighted_sum
            self.sum_weights = weight_sum
        else:
            # Bring the stored sums onto the new reference point. The exponent
            # is <= 0 because new_max >= max_logit, so this cannot overflow.
            rescale = torch.exp(self.max_logit - shift)  # [b, n]
            if self.mode == "images":
                self.sum_weighted = self.sum_weighted * rescale + weighted_sum
            else:
                self.sum_weighted = (
                    self.sum_weighted * rescale.unsqueeze(-1) + weighted_sum
                )
            self.sum_weights = self.sum_weights * rescale + weight_sum

        self.max_logit = new_max

    def get_average(self) -> Optional[torch.Tensor]:
        """
        Get the weighted average.

        Returns:
            For images: Tensor of shape [b, n]
            For covariance: Tensor of shape [b, n, n]
            None if no samples have been added
        """
        if self.sum_weighted is None:
            return None

        # The epsilon only matters for entries whose logits were all -inf
        # (total weight 0); otherwise the total weight is at least 1.
        denom = self.sum_weights + 1e-8
        if self.mode != "images":
            # Reshape weights to [b, n, 1] for proper broadcasting with [b, n, n]
            denom = denom.unsqueeze(-1)
        return self.sum_weighted / denom


def all_translations(x: torch.Tensor, a: int) -> Iterator[torch.Tensor]:
    """
    Lazily generate all translations of x with offsets in [-a, a] along both
    spatial axes, using mirror (reflect) padding to fill uncovered borders.

    All (2a + 1)**2 offset pairs are produced, including the identity offset
    (0, 0). Offsets are visited with dy (third axis) in the outer loop and
    dx (last axis) in the inner loop, both ascending from -a to a.

    Args:
        x: Tensor of shape [k, c, w, h]
        a: max shift in both directions; must be non-negative. Reflect
            padding in ``F.pad`` additionally requires ``a`` to be smaller
            than both spatial sizes.

    Yields:
        Tensors of shape [k, c, w, h] for each translation. Each one is a
        slice of the shared padded tensor, not an independent copy.

    Raises:
        ValueError: If ``a`` is negative (raised when iteration starts).
    """
    if a < 0:
        raise ValueError(f"a must be non-negative, got {a}")

    _, _, w, h = x.shape
    # F.pad lists pads for the last dim first, then the second-to-last; the
    # padding is symmetric, so the result is [k, c, w + 2a, h + 2a].
    padded = F.pad(x, (a, a, a, a), mode="reflect")

    for dy in range(-a, a + 1):
        rows = slice(a + dy, a + dy + w)
        for dx in range(-a, a + 1):
            # A window of the original size, offset by (dy, dx) from the
            # centred position inside the padded tensor.
            yield padded[:, :, rows, a + dx : a + dx + h]
