from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings

# -----------------------------
# Utilities
# -----------------------------

def _to_torch(x: Union[np.ndarray, torch.Tensor], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    if isinstance(x, torch.Tensor):
        return x.to(device=device, dtype=dtype)
    return torch.as_tensor(x, device=device, dtype=dtype)


def set_torch_seed(seed: int = 42) -> None:
    """Seed NumPy's legacy global RNG and PyTorch's CPU and all CUDA RNGs with ``seed``."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seeder(seed)


def conformal_quantile_higher(scores: np.ndarray, alpha: float) -> float:
    """
    Split-conformal quantile with the finite-sample "higher" (ceil) correction.

    Non-finite scores are discarded first. With n remaining scores the result
    is the order statistic at 0-indexed rank
        k = ceil((n + 1) * (1 - alpha)) - 1,
    clipped to [0, n - 1]. When ceil((n+1)(1-alpha)) > n the exact conformal
    quantile would be +inf; this function returns the largest finite score
    instead, so the coverage guarantee is not exact for such small n.

    Returns nan when no finite scores remain.
    """
    values = np.asarray(scores, dtype=float)
    finite_sorted = np.sort(values[np.isfinite(values)])
    count = finite_sorted.size
    if count == 0:
        return float("nan")
    rank = int(np.ceil((count + 1) * (1.0 - alpha)) - 1)
    rank = min(max(rank, 0), count - 1)
    return float(finite_sorted[rank])


def train_val_split(
    X: np.ndarray,
    Y: np.ndarray,
    val_frac: float = 0.5,
    seed: int = 42,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Deterministic split (shuffle) into validation and calibration portions.
    """
    assert 0.0 < val_frac < 1.0
    n = X.shape[0]
    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    rng.shuffle(idx)
    n_val = int(round(val_frac * n))
    val_idx = idx[:n_val]
    cal_idx = idx[n_val:]
    return X[val_idx], Y[val_idx], X[cal_idx], Y[cal_idx]


# -----------------------------
# Conditional RealNVP (minimal)
# -----------------------------

class MLP(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden: Sequence[int] = (256, 256), dropout: float = 0.0):
        super().__init__()
        layers: List[nn.Module] = []
        d = in_dim
        for h in hidden:
            layers.append(nn.Linear(d, h))
            layers.append(nn.ReLU())
            if dropout and dropout > 0:
                layers.append(nn.Dropout(dropout))
            d = h
        layers.append(nn.Linear(d, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ConditionalAffineCoupling(nn.Module):
    """
    Affine coupling layer conditioned on x.

    Dims where ``mask == 1`` pass through unchanged; together with x they feed
    two MLPs producing a log-scale ``s`` and shift ``t`` for the other dims:
        y'_free = y_free * exp(s) + t,   log|det J| = sum(s).
    ``s`` is tanh-squashed and multiplied by ``scale_clip`` so |s| <= scale_clip.
    """
    def __init__(
        self,
        x_dim: int,
        y_dim: int,
        mask: torch.Tensor,  # shape (y_dim,)
        hidden: Sequence[int] = (256, 256),
        scale_clip: float = 0.8,
    ):
        super().__init__()
        assert mask.ndim == 1 and mask.shape[0] == y_dim
        self.register_buffer("mask", mask.float())
        self.y_dim = y_dim
        self.scale_clip = float(scale_clip)

        cond_dim = x_dim + y_dim  # x concatenated with the masked copy of y
        self.net_s = MLP(cond_dim, y_dim, hidden=hidden)
        self.net_t = MLP(cond_dim, y_dim, hidden=hidden)

    def _scale_shift(self, y: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (s, t) computed from (x, y * mask), both zeroed on pass-through dims."""
        keep = self.mask
        free = 1.0 - keep
        cond = torch.cat([x, y * keep], dim=-1)
        s = torch.tanh(self.net_s(cond)) * self.scale_clip
        t = self.net_t(cond)
        return s * free, t * free

    def forward(self, y: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward map y -> y'.

        y, x: (B, y_dim), (B, x_dim)
        Returns: y_out (B, y_dim), logdet (B,)
        """
        keep = self.mask
        s, t = self._scale_shift(y, x)
        y_out = y * keep + (1.0 - keep) * (y * torch.exp(s) + t)
        return y_out, s.sum(dim=-1)

    def inverse(self, y: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Inverse map y' -> y. The pass-through dims are identical on both sides,
        so (s, t) can be recomputed from the input directly.

        Returns: y_inv, inv_logdet (B,) = - forward_logdet at y_inv
        """
        keep = self.mask
        s, t = self._scale_shift(y, x)
        y_inv = y * keep + (1.0 - keep) * ((y - t) * torch.exp(-s))
        return y_inv, -s.sum(dim=-1)


class ConditionalRealNVP(nn.Module):
    """
    Stack of conditional affine coupling layers with an optional fixed flip
    permutation applied after every layer.

    Provides:
      - forward(y,x) -> (z, logdet_fwd)
      - inverse(z,x) -> (y, logdet_inv)
    where z has same dimension as y (base = N(0,I)).

    Masking: in the ORIGINAL y coordinates, even layers keep the even dims fixed
    and transform the odd dims, and odd layers do the opposite. When flipping
    is enabled, layer i sees y after i flips, so odd layers receive their mask
    in reversed order. Without that correction, an even ``y_dim`` makes the
    flip cancel the alternation, and half of the dims are never transformed.
    """
    def __init__(
        self,
        x_dim: int,
        y_dim: int,
        n_layers: int = 8,
        hidden: Sequence[int] = (256, 256),
        scale_clip: float = 0.8,
        use_flip_permutation: bool = True,
    ):
        super().__init__()
        self.x_dim = int(x_dim)
        self.y_dim = int(y_dim)
        self.n_layers = int(n_layers)
        self.use_flip_permutation = bool(use_flip_permutation)

        layers: List[nn.Module] = []
        for i in range(n_layers):
            # Alternating binary mask expressed in the original y coordinates.
            mask = torch.zeros(y_dim)
            mask[i % 2 :: 2] = 1.0
            # Odd layers see y flipped (an odd number of flips so far), so move
            # the mask into that frame. For odd y_dim the pattern is a
            # palindrome and this is a no-op.
            if self.use_flip_permutation and i % 2 == 1:
                mask = torch.flip(mask, dims=[0])
            layers.append(ConditionalAffineCoupling(x_dim, y_dim, mask=mask, hidden=hidden, scale_clip=scale_clip))
        self.layers = nn.ModuleList(layers)

    def _permute(self, y: torch.Tensor) -> torch.Tensor:
        """Reverse the last dim when flipping is enabled (self-inverse)."""
        if not self.use_flip_permutation:
            return y
        return torch.flip(y, dims=[-1])

    def forward(self, y: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map y -> z; returns (z, sum of per-layer log|det J|) with logdet of shape (B,)."""
        z = y
        logdet = torch.zeros(z.shape[0], device=z.device, dtype=z.dtype)
        for layer in self.layers:
            z, ld = layer(z, x)
            logdet = logdet + ld
            z = self._permute(z)
        return z, logdet

    def inverse(self, z: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map z -> y by undoing the layers in reverse order; returns (y, inverse logdet)."""
        y = z
        logdet = torch.zeros(y.shape[0], device=y.device, dtype=y.dtype)
        for layer in reversed(self.layers):
            y = self._permute(y)  # invert permutation (self-inverse)
            y, ld = layer.inverse(y, x)
            logdet = logdet + ld
        return y, logdet


def standard_normal_logprob(z: torch.Tensor) -> torch.Tensor:
    """Log-density of N(0, I) for each row of ``z``: (B, d) -> (B,)."""
    log_two_pi = np.log(2.0 * np.pi)
    per_dim = z.pow(2) + log_two_pi
    return -0.5 * per_dim.sum(dim=-1)


@dataclass
class VSPSState:
    """
    Trained VSPS artifacts, as assembled at the end of ``eval_vsps_cp``.
    """
    flow: ConditionalRealNVP  # trained conditional flow
    K_star: int  # number of centers kept per x (selected on the validation split)
    gamma: float  # ball radius calibrated on the conformal-calibration split
    sort_ascending: bool  # True: centers are the samples with the smallest forward log-det
    M_samples: int  # number of flow samples drawn per x before top-K selection


# -----------------------------
# Training the flow
# -----------------------------

def train_vsps_flow(
    X_train: np.ndarray,
    Y_train: np.ndarray,
    *,
    n_layers: int = 8,
    hidden: Sequence[int] = (256, 256),
    lr: float = 2e-4,
    weight_decay: float = 1e-6,
    batch_size: int = 512,
    n_epochs: int = 50,
    device: Optional[Union[str, torch.device]] = None,
    dtype: torch.dtype = torch.float32,
    seed: int = 42,
    verbose: bool = True,
) -> ConditionalRealNVP:
    """
    Fit a ConditionalRealNVP to p(y | x) by maximum likelihood on the TRAIN split.

    Each mini-batch minimizes the mean negative log-likelihood
        -[log N(f(y, x); 0, I) + log|det df/dy|]
    with Adam and gradient-norm clipping at 5.0. Progress is printed on the
    first epoch, every 10th epoch and the last one (when ``verbose``).
    The flow is returned in eval mode.
    """
    set_torch_seed(seed)
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

    x_all = _to_torch(X_train, device=device, dtype=dtype)
    y_all = _to_torch(Y_train, device=device, dtype=dtype)

    flow = ConditionalRealNVP(
        x_dim=x_all.shape[1], y_dim=y_all.shape[1], n_layers=n_layers, hidden=hidden,
    ).to(device=device, dtype=dtype)
    optimizer = torch.optim.Adam(flow.parameters(), lr=lr, weight_decay=weight_decay)
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x_all, y_all),
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
    )

    flow.train()
    for epoch in range(1, n_epochs + 1):
        losses = []
        for x_batch, y_batch in loader:
            optimizer.zero_grad(set_to_none=True)
            z, log_det = flow.forward(y_batch, x_batch)
            log_lik = standard_normal_logprob(z) + log_det
            loss = (-log_lik).mean()
            loss.backward()
            nn.utils.clip_grad_norm_(flow.parameters(), max_norm=5.0)
            optimizer.step()
            losses.append(loss.item())
        should_report = epoch in (1, n_epochs) or epoch % 10 == 0
        if verbose and should_report:
            print(f"[VSPS][Flow] epoch {epoch:03d}/{n_epochs} | NLL {np.mean(losses):.4f}")
    flow.eval()
    return flow


# -----------------------------
# Sampling centers and scoring
# -----------------------------

@torch.no_grad()
def sample_y_and_logdet_fwd(
    flow: ConditionalRealNVP,
    X: Union[np.ndarray, torch.Tensor],
    M: int,
    *,
    device: Optional[Union[str, torch.device]] = None,
    dtype: torch.dtype = torch.float32,
    batch_size: int = 256,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    For each x, sample M points y_m ~ p(y|x) with the inverse flow and return
    the forward log|det ∂f(y,x)/∂y| evaluated at those y_m.

    The forward log-det at y_m = f^{-1}(z_m, x) equals minus the log-det that
    ``flow.inverse`` already reports at z_m (each coupling inverse returns
    -sum(s) for the same s the forward pass would compute). That value is
    reused instead of running a second full forward pass; results agree up to
    floating-point rounding.

    Args:
      flow: conditional flow exposing ``inverse(z, x)`` and ``y_dim``.
      X: conditioning inputs, (N, x_dim).
      M: number of samples per x.
      device: sampling device; defaults to the device of the flow's parameters.
      dtype: dtype for X and the latent draws.
      batch_size: number of x rows per chunk (each chunk holds B*M samples).

    Returns (both on CPU):
      y_samples: (N, M, d_y)
      logdet_fwd: (N, M)
    """
    device = torch.device(device or next(flow.parameters()).device)
    X_t = _to_torch(X, device=device, dtype=dtype)
    N = X_t.shape[0]
    d_y = flow.y_dim

    # torch.cat on an empty list would raise; return correctly-shaped empties.
    if N == 0:
        return torch.zeros((0, M, d_y), dtype=dtype), torch.zeros((0, M), dtype=dtype)

    y_chunks: List[torch.Tensor] = []
    ld_chunks: List[torch.Tensor] = []

    for start in range(0, N, batch_size):
        xb = X_t[start : start + batch_size]  # (B, x_dim)
        B, x_dim = xb.shape

        z = torch.randn((B, M, d_y), device=device, dtype=dtype)
        z_flat = z.reshape(B * M, d_y)
        xb_rep = xb.unsqueeze(1).expand(B, M, x_dim).reshape(B * M, x_dim)

        y_flat, logdet_inv = flow.inverse(z_flat, xb_rep)  # (B*M, d_y), (B*M,)
        y_chunks.append(y_flat.reshape(B, M, d_y).cpu())
        ld_chunks.append((-logdet_inv).reshape(B, M).cpu())

    return torch.cat(y_chunks, dim=0), torch.cat(ld_chunks, dim=0)


@torch.no_grad()
def select_topk_centers(
    y_samples: torch.Tensor,      # (N, M, d_y) on CPU ok
    logdet_fwd: torch.Tensor,     # (N, M)
    K: int,
    *,
    sort_ascending: bool = True,
) -> torch.Tensor:
    """
    Pick K centers per x by ranking samples on their forward log-det.

    sort_ascending=True keeps the K samples with the smallest logdet_fwd,
    otherwise the K largest. K is capped at M. Centers come back ordered by
    rank (best first).
    (Double-check with your reading of the paper; flip if needed.)

    Returns centers: (N, K, d_y)
    """
    assert y_samples.ndim == 3 and logdet_fwd.ndim == 2
    n_points, n_samples, y_dim = y_samples.shape
    k = int(min(K, n_samples))
    # topk always returns the largest entries; negating the key turns that into
    # "smallest first".
    ranking_key = -logdet_fwd if sort_ascending else logdet_fwd
    chosen = torch.topk(ranking_key, k=k, dim=1).indices  # (N, k)
    gather_index = chosen[..., None].expand(n_points, k, y_dim)
    return torch.gather(y_samples, dim=1, index=gather_index)


def min_dist_to_centers(
    y_true: np.ndarray,
    centers: np.ndarray,
) -> np.ndarray:
    """
    Euclidean distance from each target to its nearest center.

    y_true: (N, d_y); centers: (N, K, d_y). Returns (N,).
    """
    targets = np.asarray(y_true, dtype=float)[:, np.newaxis, :]  # (N, 1, d_y)
    candidates = np.asarray(centers, dtype=float)                # (N, K, d_y)
    pairwise = np.linalg.norm(candidates - targets, axis=-1)      # (N, K)
    return np.min(pairwise, axis=1)


# -----------------------------
# K selection and gamma calibration
# -----------------------------

def choose_k_star(
    flow: ConditionalRealNVP,
    X_val: np.ndarray,
    Y_val: np.ndarray,
    *,
    K_grid: Sequence[int] = (1, 2, 4, 8, 16, 32),
    alpha: float = 0.1,
    M: int = 1024,
    sort_ascending: bool = True,
    n_mc_volume: int = 2048,
    sampling_radius: Optional[Union[float, np.ndarray]] = None,
    eval_points: int = 128,
    device: Optional[Union[str, torch.device]] = None,
    dtype: torch.dtype = torch.float32,
    batch_size: int = 128,
    seed: int = 42,
    verbose: bool = True,
) -> int:
    """
    Select K* on a VALIDATION split (independent of conformal calibration).

    On a random subset of at most ``eval_points`` validation points:
      - draw M flow samples per x ONCE (samples do not depend on K, so every
        K is compared on the same draws),
      - for each K: keep the top-K samples as centers, set a proxy gamma to the
        plain (1-alpha) empirical quantile of the min-distance scores (NOT
        conformal), and estimate the average volume of the union of
        gamma-balls with ``mc_volume_union_balls``.

    Returns the K with the smallest finite average volume (``K_grid[0]`` if the
    split is empty or no volume is finite).

    ``sampling_radius``: per-dim half-width of the MC proposal box; when None
    it defaults to 3 * std(Y_val) per dim.
    """
    set_torch_seed(seed)
    N = X_val.shape[0]
    if N == 0:
        return int(K_grid[0])

    rng = np.random.default_rng(seed)
    sub = min(eval_points, N)
    idx = rng.choice(N, size=sub, replace=False)
    Xs = X_val[idx]
    Ys = Y_val[idx]

    # Default sampling radius (per-dim) only if the caller did not provide one.
    if sampling_radius is None:
        sampling_radius = 3 * np.std(Y_val, axis=0)

    # Hoisted out of the K loop: candidate samples are K-independent.
    y_samp, ld = sample_y_and_logdet_fwd(flow, Xs, M, device=device, dtype=dtype, batch_size=batch_size)

    best_K = int(K_grid[0])
    best_vol = float("inf")

    for K in K_grid:
        centers = select_topk_centers(y_samp, ld, K, sort_ascending=sort_ascending).numpy()

        scores = min_dist_to_centers(Ys, centers)
        gamma_proxy = np.quantile(scores, 1.0 - alpha)  # proxy; NOT conformal

        # Same MC seeds for every K (common random numbers across K).
        vols = [
            mc_volume_union_balls(
                centers[i],
                gamma_proxy,
                sampling_radius=sampling_radius,
                n_mc=n_mc_volume,
                seed=seed + i,
            )
            for i in range(sub)
        ]
        avg_vol = float(np.mean(vols))

        if verbose:
            print(f"[VSPS][K*] K={int(K):3d} | gamma_proxy={gamma_proxy:.4f} | avg_vol≈{avg_vol:.4e}")

        if np.isfinite(avg_vol) and avg_vol < best_vol:
            best_vol = avg_vol
            best_K = int(K)

    if verbose:
        print(f"[VSPS][K*] selected K* = {best_K} (min avg_vol≈{best_vol:.4e})")
    return best_K


def calibrate_gamma(
    flow: ConditionalRealNVP,
    X_cal: np.ndarray,
    Y_cal: np.ndarray,
    *,
    K: int,
    alpha: float = 0.1,
    M: int = 1024,
    sort_ascending: bool = True,
    device: Optional[Union[str, torch.device]] = None,
    dtype: torch.dtype = torch.float32,
    batch_size: int = 256,
    seed: int = 42,
    verbose: bool = True,
) -> float:
    """
    Split-conformal calibration of the ball radius gamma.

    For each calibration point: draw M flow samples, keep K of them as centers
    (``select_topk_centers``), and score the true y by its distance to the
    nearest center. gamma is ``conformal_quantile_higher`` of those scores.
    """
    set_torch_seed(seed)
    samples, log_dets = sample_y_and_logdet_fwd(
        flow, X_cal, M, device=device, dtype=dtype, batch_size=batch_size,
    )
    kept = select_topk_centers(samples, log_dets, K, sort_ascending=sort_ascending)
    scores = min_dist_to_centers(Y_cal, kept.numpy())
    gamma = conformal_quantile_higher(scores, alpha)

    if verbose:
        print(f"[VSPS][Cal] gamma={gamma:.4f} from n={scores.size} cal scores (alpha={alpha})")
    return float(gamma)


# -----------------------------
# Volume estimation (MC)
# -----------------------------

def mc_volume_union_balls(
    centers: np.ndarray,  # (K, d)
    gamma: float,
    *,
    sampling_radius: Union[float, np.ndarray],
    n_mc: int = 4096,
    seed: int = 0,
) -> float:
    """
    MC estimate of the volume of union_{k<=K} B(center_k, gamma).

    Uniform proposals are drawn from an axis-aligned box centered at
    y_center = mean(centers). Its per-dim half-width is ``sampling_radius``,
    enlarged where necessary so the box contains every ball. Without this,
    parts of the union lying outside the box would be silently dropped and the
    volume underestimated. When the box already covers the union, the samples
    (and therefore the estimate) are the same as with the un-enlarged box.

    Args:
      centers: (K, d) ball centers; a single (d,) center is also accepted.
      gamma: common ball radius.
      sampling_radius: scalar or (d,) per-dim half-width of the proposal box.
      n_mc: number of uniform proposals.
      seed: seed for ``np.random.default_rng``.

    Returns:
      Estimated volume (in Y-space). 0.0 if there are no centers or gamma < 0,
      nan if gamma is nan, inf if gamma is +inf.
    """
    c = np.asarray(centers, dtype=float)
    if c.ndim == 1:
        c = c[None, :]
    K, d = c.shape
    r = np.asarray(sampling_radius, dtype=float)
    if r.ndim == 0:
        r = np.full((d,), float(r), dtype=float)
    assert r.shape == (d,)

    gamma = float(gamma)
    if np.isnan(gamma):
        # An undefined radius must not be reported as a zero (i.e. "best") volume.
        return float("nan")
    if K == 0 or gamma < 0.0:
        return 0.0
    if np.isinf(gamma):
        return float("inf")

    y0 = c.mean(axis=0)
    # Smallest per-dim half-width that keeps every ball inside the proposal box.
    needed = np.abs(c - y0).max(axis=0) + gamma
    r = np.maximum(r, needed)

    rng = np.random.default_rng(seed)
    samp = y0 + rng.uniform(low=-r, high=r, size=(n_mc, d))
    # accept if within gamma of at least one center
    dists = np.linalg.norm(samp[:, None, :] - c[None, :, :], axis=-1)  # (n_mc, K)
    accept = (dists.min(axis=1) <= gamma).mean()
    ref_vol = float(np.prod(2.0 * r))
    return float(accept * ref_vol)


# -----------------------------
# End-to-end VSPS evaluation
# -----------------------------

def eval_vsps_cp(
    X_train: np.ndarray,
    Y_train: np.ndarray,
    X_cal: np.ndarray,
    Y_cal: np.ndarray,
    X_test: np.ndarray,
    Y_test: np.ndarray,
    *,
    alpha: float = 0.1,
    # Flow training
    n_layers: int = 8,
    hidden: Sequence[int] = (256, 256),
    lr: float = 2e-4,
    weight_decay: float = 1e-6,
    batch_size_train: int = 512,
    n_epochs: int = 50,
    # VSPS sampling
    M: int = 1024,
    sort_ascending: bool = True,
    # K selection
    val_frac: float = 0.5,
    K_grid: Sequence[int] = (1, 2, 4, 8, 16, 32),
    # Volume estimation
    n_mc_volume: int = 20000,
    sampling_radius: Optional[Union[float, np.ndarray]] = None,
    n_test_volume_points: int = 50,
    # Misc
    device: Optional[Union[str, torch.device]] = None,
    dtype: torch.dtype = torch.float32,
    seed: int = 42,
    verbose: bool = True,
) -> Tuple[Dict[str, float], VSPSState]:
    """
    Train flow on TRAIN; select K* on VAL; calibrate gamma on CAL; evaluate on TEST.

    The provided CAL split is itself split (``val_frac``) into a validation part
    (K* selection) and a conformal-calibration part (gamma), so the two steps
    use disjoint data.

    ``sampling_radius`` is the per-dim half-width of the Monte-Carlo proposal
    box used for volume estimates. It is forwarded to ``choose_k_star``, and
    for the test-volume estimate it defaults to 3 * std(Y_train) per dim when
    None.

    Returns:
      metrics dict with keys:
        coverage, target_coverage, average_volume, log_volume_normalized
        (log_volume_normalized = log(average_volume) / d_y; -inf for zero
        volume, nan when no volume could be estimated)
      VSPSState with trained flow and calibrated parameters.
    """
    set_torch_seed(seed)
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

    # 1) Train conditional flow on TRAIN
    flow = train_vsps_flow(
        X_train, Y_train,
        n_layers=n_layers, hidden=hidden,
        lr=lr, weight_decay=weight_decay,
        batch_size=batch_size_train, n_epochs=n_epochs,
        device=device, dtype=dtype, seed=seed, verbose=verbose,
    )

    # 2) Split provided CAL into (VAL for K*) and (CONF-CAL for gamma)
    X_val, Y_val, X_confcal, Y_confcal = train_val_split(X_cal, Y_cal, val_frac=val_frac, seed=seed)

    # Optional diagnostic: on held-out data the latent codes z = f(y, x) should
    # look roughly N(0, I) (per-dim mean ~0, std ~1) if the flow fits well.
    # Needs at least 2 rows for a std.
    if verbose and X_val.shape[0] > 1:
        X_val_torch = _to_torch(X_val, device=device, dtype=dtype)
        Y_val_torch = _to_torch(Y_val, device=device, dtype=dtype)
        with torch.no_grad():
            z_sample, _ = flow.forward(Y_val_torch, X_val_torch)
        print(f"Latent mean: {z_sample.mean(0)}")
        print(f"Latent std: {z_sample.std(0)}")

    # 3) Choose K* using validation split
    K_star = choose_k_star(
        flow, X_val, Y_val,
        K_grid=K_grid, alpha=alpha,
        M=M, sort_ascending=sort_ascending,
        n_mc_volume=max(1024, n_mc_volume // 2),
        sampling_radius=sampling_radius,
        eval_points=min(n_test_volume_points, max(32, X_val.shape[0])),
        device=device, dtype=dtype,
        seed=seed, verbose=verbose,
    )

    # 4) Calibrate gamma on confcal split
    gamma = calibrate_gamma(
        flow, X_confcal, Y_confcal,
        K=K_star, alpha=alpha, M=M,
        sort_ascending=sort_ascending,
        device=device, dtype=dtype,
        batch_size=256,
        seed=seed, verbose=verbose,
    )
    # Heuristic sanity check: flag a radius above the 90th percentile of the
    # per-dim training std of Y.
    if gamma > np.percentile(np.std(Y_train, axis=0), 90):
        warnings.warn(f"gamma={gamma:.3f} semble très grand")

    # 5) Evaluate on TEST: coverage
    y_samp_test, ld_test = sample_y_and_logdet_fwd(flow, X_test, M, device=device, dtype=dtype, batch_size=128)
    centers_test = select_topk_centers(y_samp_test, ld_test, K_star, sort_ascending=sort_ascending).numpy()
    test_scores = min_dist_to_centers(Y_test, centers_test)
    covered = (test_scores <= gamma).astype(float)
    coverage = float(np.mean(covered))

    # 6) Estimate average volume on a subset of test points
    rng = np.random.default_rng(seed)
    nT = X_test.shape[0]
    sub = min(int(n_test_volume_points), nT)
    idx = rng.choice(nT, size=sub, replace=False)

    d_y = Y_test.shape[1]
    # Honour a caller-supplied proposal half-width; default to 3 marginal stds.
    volume_radius = 3 * np.std(Y_train, axis=0) if sampling_radius is None else sampling_radius

    vols = [
        mc_volume_union_balls(
            centers_test[i],
            gamma,
            sampling_radius=volume_radius,
            n_mc=n_mc_volume,
            seed=seed + 10_000 + j,
        )
        for j, i in enumerate(idx)
    ]
    avg_vol = float(np.mean(vols)) if vols else float("nan")
    if np.isnan(avg_vol):
        log_vol_norm = float("nan")  # no estimate available (not "zero volume")
    elif avg_vol > 0.0:
        log_vol_norm = float(np.log(avg_vol) / d_y)
    else:
        log_vol_norm = float("-inf")

    metrics = {
        "coverage": coverage,
        "target_coverage": float(1.0 - alpha),
        "average_volume": avg_vol,
        "log_volume_normalized": log_vol_norm,
    }
    state = VSPSState(flow=flow, K_star=int(K_star), gamma=float(gamma), sort_ascending=bool(sort_ascending), M_samples=int(M))

    if verbose:
        print(f"[VSPS][Test] coverage={coverage:.4f} (target={1.0-alpha:.4f}) | avg_vol≈{avg_vol:.4e} | log_vol_norm={log_vol_norm:.4f}")

    return metrics, state


# -----------------------------
# Quick sanity check (optional)
# -----------------------------
if __name__ == "__main__":
    # Tiny synthetic sanity check: y | x ~ N(tanh(x W), 0.5^2 I) with random W.
    set_torch_seed(42)
    n_total, x_dim, y_dim = 2000, 5, 3
    train_end, cal_end = 1200, 1600

    X = np.random.randn(n_total, x_dim).astype(np.float32)
    W = np.random.randn(x_dim, y_dim) * 0.5
    mu = np.tanh(X @ W).astype(np.float32)
    Y = mu + 0.5 * np.random.randn(n_total, y_dim).astype(np.float32)

    train_xy = (X[:train_end], Y[:train_end])
    cal_xy = (X[train_end:cal_end], Y[train_end:cal_end])
    test_xy = (X[cal_end:], Y[cal_end:])

    metrics, _ = eval_vsps_cp(
        *train_xy, *cal_xy, *test_xy,
        alpha=0.1,
        n_epochs=20,
        M=256,
        K_grid=(1, 2, 4, 8),
        n_mc_volume=1024,
        n_test_volume_points=64,
        verbose=True,
    )
    print(metrics)
