import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


class CutFillConfig:
    """
    Settings read by `apply_cut_and_fill`.

    Attributes:
      min_frac, max_frac: bounds on the masked span length, expressed as
        fractions of the time length T (stored as float).
      fill_value: value written into the masked span. Python scalars and
        single-element tensors are stored as float. A multi-element Tensor
        (e.g. one value per channel) is stored unchanged so it can be
        broadcast per channel.
      loss_on_full: stored as bool. NOTE(review): presumably selects whether
        the loss is computed on the full signal rather than only the masked
        span; it is not read in this module chunk, so confirm with callers.
    """
    __slots__ = ("min_frac", "max_frac", "fill_value", "loss_on_full")

    def __init__(self, min_frac=0.1, max_frac=0.25, fill_value=0.0, loss_on_full=False):
        self.min_frac = float(min_frac)
        self.max_frac = float(max_frac)
        # float() only works on single-element tensors. Keep multi-element
        # tensors (per-channel fills) intact so apply_cut_and_fill's Tensor
        # branch is reachable through this config.
        if isinstance(fill_value, torch.Tensor) and fill_value.numel() != 1:
            self.fill_value = fill_value
        else:
            self.fill_value = float(fill_value)
        self.loss_on_full = bool(loss_on_full)

    def __repr__(self):
        return (f"CutFillConfig(min_frac={self.min_frac}, max_frac={self.max_frac}, "
                f"fill_value={self.fill_value}, loss_on_full={self.loss_on_full})")


def _rand_span(T, min_frac, max_frac):
    min_len = max(1, int(T * min_frac))
    max_len = max(min_len, int(T * max_frac))
    span_len = torch.randint(min_len, max_len + 1, ()).item()
    start = torch.randint(0, max(1, T - span_len + 1), ()).item()
    return start, start + span_len


def apply_cut_and_fill(x: torch.Tensor, cfg: CutFillConfig):
    """
    EEG cut-and-fill for x shaped (B, 1, C, T).
      - Masks one contiguous time span per batch item, across all channels.
      - Fills with cfg.fill_value: a Python scalar, or a Tensor holding either
        1 value or C values (one per channel), in any memory layout.

    Args:
      x:   input of shape (B, 1, C, T). It is cloned, not modified in place.
      cfg: supplies min_frac / max_frac (span-length bounds passed to
           `_rand_span`) and fill_value.

    Returns:
      x_masked: (B, 1, C, T) copy of x with each item's span set to the filler
      mask:     (B, 1, C, T) boolean (True where masked)
      spans:    list of (start, end) for each item (end exclusive)

    Raises:
      ValueError: if x is not shaped (B, 1, C, T), or if a Tensor fill_value
        does not hold exactly 1 or C values.
    """
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    # It also matches the ValueError used for a bad fill_value below.
    if x.ndim != 4 or x.size(1) != 1:
        raise ValueError(f"Expected x with shape (B, 1, C, T), got {tuple(x.shape)}.")
    B, _, C, T = x.shape
    device, dtype = x.device, x.dtype

    # Prepare a broadcastable filler of shape (1, 1, C, 1) or (1, 1, 1, 1).
    # `reshape` (unlike `view`) also accepts non-contiguous fill tensors.
    fv = cfg.fill_value
    if isinstance(fv, torch.Tensor):
        filler = fv.to(device=device, dtype=dtype).reshape(1, 1, -1, 1)
        if filler.size(2) not in (1, C):
            raise ValueError(f"fill_value Tensor must have C={C} or 1 channel, got {filler.size(2)}")
    else:
        filler = torch.as_tensor(fv, device=device, dtype=dtype).reshape(1, 1, 1, 1)

    x_masked = x.clone()
    mask = torch.zeros((B, 1, C, T), dtype=torch.bool, device=device)
    spans = []

    for b in range(B):
        s, e = _rand_span(T, cfg.min_frac, cfg.max_frac)
        spans.append((s, e))
        x_masked[b, :, :, s:e] = filler
        mask[b, :, :, s:e] = True

    return x_masked, mask, spans


def variance_floor_loss_eeg(y_hat: torch.Tensor, floor: float = 0.5):
    """
    Squared-hinge penalty when a sample's std across (C, T) falls below `floor`.

    Computes relu(floor - std) ** 2 per sample. Samples whose std is at or
    above `floor` contribute zero penalty and zero gradient.

    Args:
      y_hat: (B, 1, C, T); the std is taken over the last two dims.
      floor: std threshold below which the penalty becomes positive.

    Returns:
      Unreduced per-sample penalties, shape (B, 1) for a (B, 1, C, T) input.
    """
    std = y_hat.std(dim=(-2, -1), unbiased=False)     # population std, (B, 1)
    # One-sided: only a spread below the floor is penalised; a larger spread
    # is never pushed back down toward `floor`.
    return F.relu(floor - std) ** 2


def reconstruction_loss(x: torch.Tensor, x_hat: torch.Tensor,
                            mask: torch.Tensor | None = None, reduction: str = "mean"):
    """
    Mask-aware MSE for arbitrary shapes. If `mask` is provided, it must be
    broadcastable to `x` (e.g., mask same shape as x or with singleton dims).

    - reduction='mean': sum(diff*mask) / (#masked elements)
    - reduction='sum':  sum(diff*mask)
    - reduction='none': elementwise (same shape as x)

    Works for:
      (B, T, D), (B, 1, C, T), etc.
    """

    # var_reg = variance_floor_loss_eeg(x_hat, floor=0.5)

    if mask is None:
        return F.mse_loss(x_hat, x, reduction=reduction)  # + 5*var_reg.sum()

    diff2 = (x_hat - x) ** 2
    m = mask.to(dtype=diff2.dtype)
    diff2 = diff2 * m

    if reduction == "none":
        return diff2
    elif reduction == "sum":
        return diff2.sum()
    elif reduction == "mean":
        denom = m.sum()
        # avoid div-by-zero if mask is empty (shouldn't happen)
        return diff2.sum() / torch.clamp(denom, min=1.0)
    else:
        raise ValueError(f"Unknown reduction: {reduction}")
