import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class EnhancedMixUp(nn.Module):
    """
    Batch-level MixUp with an optional "resolution" that draws multiple lambdas per pair.

    For every selected pair of samples, ``resolution`` mixing coefficients are
    drawn from Beta(alpha, alpha) and one mixed sample is produced per
    coefficient, so the output batch size is ``effective_pairs * resolution``.

    If pairs is None: use all batch elements as pairs (effective_pairs = B).
    Otherwise effective_pairs = min(pairs, B).

    If you want to keep output batch size == B, set pairs = B // resolution.

    The two members of each pair are drawn from two independent random
    permutations of the batch, so a sample may occasionally be mixed with
    itself.

    Args:
        alpha: Concentration of the symmetric Beta distribution; must be > 0.
        num_classes: Number of classes used to one-hot encode index labels.
            Required (a ValueError is raised when it is None).
        resolution: Number of lambdas drawn per pair; must be >= 1.
        pairs: Number of pairs to mix per batch, or None to use B pairs.
            Must not be negative.

    Raises:
        ValueError: If alpha <= 0, resolution <= 0, num_classes is None,
            or pairs is negative.
    """
    def __init__(
        self,
        alpha: float = 1.0,
        num_classes: Optional[int] = None,
        resolution: int = 1,
        pairs: Optional[int] = None,
    ):
        super().__init__()
        self.alpha = float(alpha)
        self.num_classes = num_classes
        self.resolution = int(resolution)
        self.pairs = pairs

        if self.alpha <= 0:
            raise ValueError("alpha must be > 0 for Beta(alpha, alpha).")
        if self.resolution <= 0:
            raise ValueError("resolution must be >= 1.")
        if self.num_classes is None:
            raise ValueError("num_classes must be provided for index labels (CIFAR-10 -> 10).")
        # A negative value would otherwise surface later as an obscure shape
        # error when sampling lambdas in forward().
        if self.pairs is not None and int(self.pairs) < 0:
            raise ValueError("pairs must be >= 0 (or None to use the whole batch).")

        self._dist = torch.distributions.Beta(self.alpha, self.alpha)

    @torch.no_grad()
    def _one_hot(self, y: torch.Tensor) -> torch.Tensor:
        """Return labels as a float (B, K) tensor.

        2-D input is treated as already one-hot/soft labels and only cast to
        float. Anything else is treated as class indices and one-hot encoded
        with ``self.num_classes`` classes.
        """
        if y.ndim == 2:
            return y.float()
        # F.one_hot only accepts int64 indices; widen other integer dtypes
        # (int32, uint8, ...) instead of failing on them.
        if not (y.is_floating_point() or y.is_complex()):
            y = y.long()
        return F.one_hot(y, num_classes=self.num_classes).float()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """
        Mix random pairs of samples and their labels.

        x: (B, C, H, W)
        y: (B,) integer labels or (B, K) one-hot/soft labels
        returns:
          x_mix: (P*R, C, H, W)
          y_mix: (P*R, K)
        where P is the effective number of pairs and R is the resolution.

        Integer-typed x (e.g. uint8 images) is mixed in the default floating
        dtype, so the returned tensors are floating point in that case.

        Raises:
            ValueError: If x is not 4-D, or y is not (B,) / (B, K) with the
                same batch size as x.
        """
        if x.ndim != 4:
            raise ValueError(f"Expected x to be (B,C,H,W), got {x.shape}")
        B = x.shape[0]
        if y.ndim not in (1, 2) or y.shape[0] != B:
            raise ValueError(
                f"Expected y to be (B,) or (B,K) with B={B}, got {tuple(y.shape)}"
            )
        device = x.device

        # Casting lambdas in (0, 1) to an integer dtype would truncate them
        # all to 0 and disable mixing, so integer inputs are mixed in float.
        if x.is_floating_point() or x.is_complex():
            mix_dtype = x.dtype
        else:
            mix_dtype = torch.get_default_dtype()

        # choose number of pairs
        effective_pairs = B if self.pairs is None else min(int(self.pairs), B)

        # random pairing: two independent permutations, truncated to P entries
        idx1 = torch.randperm(B, device=device)[:effective_pairs]
        idx2 = torch.randperm(B, device=device)[:effective_pairs]

        # sample lambdas: (P, R)
        lam = self._dist.sample((effective_pairs, self.resolution)).to(device=device, dtype=mix_dtype)
        lam_img = lam.view(effective_pairs, self.resolution, 1, 1, 1)

        # The singleton axis broadcasts each pair against its R lambdas.
        x1 = x[idx1].unsqueeze(1)  # (P,1,C,H,W)
        x2 = x[idx2].unsqueeze(1)  # (P,1,C,H,W)
        x_mix = lam_img * x1 + (1.0 - lam_img) * x2
        x_mix = x_mix.reshape(effective_pairs * self.resolution, *x.shape[1:])

        # Labels are mixed with the same lambdas as their images.
        y_oh = self._one_hot(y).to(device=device, dtype=mix_dtype)  # (B,K)
        y1 = y_oh[idx1].unsqueeze(1)  # (P,1,K)
        y2 = y_oh[idx2].unsqueeze(1)  # (P,1,K)
        lam_lab = lam.view(effective_pairs, self.resolution, 1)
        y_mix = lam_lab * y1 + (1.0 - lam_lab) * y2
        y_mix = y_mix.reshape(effective_pairs * self.resolution, y_oh.shape[-1])

        return x_mix, y_mix
