# losses.py
import torch
from typing import Literal, Optional

# Short alias for torch.Tensor, used in the type annotations below.
Tensor = torch.Tensor

# =============================================================
# 0) GeomLoss import (required: imported unconditionally below)
# =============================================================
from geomloss import SamplesLoss
import torch.nn as nn
from typing import Any, Dict, Optional, Callable, List, Union, Tuple
from mechanics import pick_integrator
from potential_energy_models import make_accel_from_potential


# =============================================================
# 2) GeomLoss (cached objects)
# =============================================================
# Module-level caches of constructed SamplesLoss objects, one dict per loss
# family. Each is filled lazily by the matching _get_geomloss_* helper and is
# keyed by the tuple of settings that helper builds. Nothing in this section
# removes entries once they are stored.
_SINKHORN_CACHE = {}
_GAUSSIAN_CACHE = {}
_ENERGY_CACHE = {}


def _get_geomloss_sinkhorn(
    *,
    p: int,
    blur: float,
    scaling: float,
    debias: bool,
    backend: str,
):
    """Return a memoized ``SamplesLoss(loss="sinkhorn", ...)`` object.

    The settings are normalised (``float``/``bool``/``str``) to form the cache
    key, so e.g. ``blur=1`` and ``blur=1.0`` share one entry. On a cache miss
    the arguments are forwarded to ``SamplesLoss`` unchanged and the new
    object is stored in ``_SINKHORN_CACHE``.
    """
    cache_key = ("sinkhorn", p, float(blur), float(scaling), bool(debias), str(backend))

    cached = _SINKHORN_CACHE.get(cache_key, None)
    if cached is not None:
        return cached

    loss_obj = SamplesLoss(
        loss="sinkhorn",
        p=p,
        blur=blur,
        scaling=scaling,
        debias=debias,
        backend=backend,
    )
    _SINKHORN_CACHE[cache_key] = loss_obj
    return loss_obj


def _get_geomloss_gaussian(*, blur: float, backend: str):
    """Return a memoized ``SamplesLoss(loss="gaussian", ...)`` object.

    Objects are cached in ``_GAUSSIAN_CACHE`` under ``(blur, backend)``
    (normalised to ``float`` / ``str``); a new one is built only on a miss.
    """
    cache_key = ("gaussian", float(blur), str(backend))
    if _GAUSSIAN_CACHE.get(cache_key, None) is None:
        _GAUSSIAN_CACHE[cache_key] = SamplesLoss(
            loss="gaussian",
            blur=blur,
            backend=backend,
        )
    return _GAUSSIAN_CACHE[cache_key]


def _get_geomloss_energy(*, p: int, backend: str):
    """Return a memoized ``SamplesLoss(loss="energy", ...)`` object.

    No ``blur`` is passed for this loss; only ``p`` (coerced to ``int``) and
    ``backend`` are forwarded to ``SamplesLoss``. Objects are cached in
    ``_ENERGY_CACHE`` under ``(int(p), str(backend))``.
    """
    order = int(p)
    cache_key = ("energy", order, str(backend))

    cached = _ENERGY_CACHE.get(cache_key, None)
    if cached is not None:
        return cached

    loss_obj = SamplesLoss(
        loss="energy",
        p=order,
        backend=backend,
    )
    _ENERGY_CACHE[cache_key] = loss_obj
    return loss_obj


def _default_geomloss_backend(x: Tensor, y: Tensor) -> str:
    d = x.shape[1]
    nmax = max(x.shape[0], y.shape[0])

    # RELAXED LIMIT: Allow d <= 10 or even higher
    if d <= 20 and nmax <= 6000:
        return "tensorized"

    return "online"


# =============================================================
# 3) Unified interface (ONE compute_loss)
# =============================================================
# Names of known loss types.
# NOTE(review): compute_loss in this section rejects every name that does not
# start with "geom_", so "mmd", "sw2" and "sinkhorn" are listed here but not
# handled by it. compute_loss also annotates loss_type as plain str rather
# than LossType.
LossType = Literal[
    "mmd", "sw2", "sinkhorn",
    "geom_sinkhorn", "geom_gaussian", "geom_energy"
]




def compute_loss(
    x: Tensor,
    y: Tensor,
    loss_type: str = "geom_sinkhorn",
    *,
    p: int = 2,
    blur: float = 0.2,
    scaling: float = 0.9,
    debias: bool = True,
    backend: Optional[str] = None,
    eps: float = 1e-12,
) -> Tensor:
    """Evaluate a GeomLoss sample loss between two point clouds.

    Args:
        x: Tensor of shape (N, d).
        y: Tensor of shape (M, d) with the same ``d``.
        loss_type: One of ``"geom_sinkhorn"``, ``"geom_gaussian"``,
            ``"geom_energy"``.
        p: Forwarded to the sinkhorn and energy losses.
        blur: Forwarded to the sinkhorn and gaussian losses.
        scaling: Forwarded to the sinkhorn loss.
        debias: Forwarded to the sinkhorn loss.
        backend: GeomLoss backend name. When ``None`` it is chosen by
            ``_default_geomloss_backend`` from the shapes of ``x`` and ``y``
            as passed in (i.e. before non-finite rows are dropped).
        eps: Added to the row count in the denominator of the uniform weights.

    Returns:
        The value returned by the cached loss object called as
        ``loss_fn(a, x, b, y)`` with uniform weights ``a`` and ``b``.

    Raises:
        ValueError: if the inputs are not 2-D, their feature dimensions
            differ, ``loss_type`` is not supported, or one of the clouds has
            no row made entirely of finite values.
    """
    if not (x.ndim == 2 and y.ndim == 2):
        raise ValueError(f"x,y must be (N,d) and (M,d). Got {x.shape=} {y.shape=}")
    if x.shape[1] != y.shape[1]:
        raise ValueError(f"Dim mismatch: {x.shape[1]} vs {y.shape[1]}")

    if not loss_type.startswith("geom_"):
        raise ValueError(f"Unknown or unsupported loss_type for this patch: {loss_type}")

    chosen_backend = _default_geomloss_backend(x, y) if backend is None else backend

    # Lazy builders: only the selected loss object is looked up / constructed.
    builders = {
        "geom_sinkhorn": lambda: _get_geomloss_sinkhorn(
            p=p, blur=blur, scaling=scaling, debias=debias, backend=chosen_backend
        ),
        "geom_gaussian": lambda: _get_geomloss_gaussian(blur=blur, backend=chosen_backend),
        "geom_energy": lambda: _get_geomloss_energy(p=p, backend=chosen_backend),
    }
    build = builders.get(loss_type)
    if build is None:
        raise ValueError(f"Unknown loss_type: {loss_type}")
    loss_fn = build()

    # Keep only rows whose every coordinate is finite; each cloud is filtered
    # independently of the other.
    finite_x = torch.isfinite(x).all(dim=1)
    finite_y = torch.isfinite(y).all(dim=1)

    if not finite_x.any():
        raise ValueError("compute_loss: all rows in x are NaN/Inf.")
    if not finite_y.any():
        raise ValueError("compute_loss: all rows in y are NaN/Inf.")

    x_kept = x[finite_x]
    y_kept = y[finite_y]
    n_x = x_kept.shape[0]
    n_y = y_kept.shape[0]

    # Uniform weights over the surviving rows.
    weights_x = torch.full((n_x,), 1.0 / (n_x + eps), dtype=x_kept.dtype, device=x_kept.device)
    weights_y = torch.full((n_y,), 1.0 / (n_y + eps), dtype=y_kept.dtype, device=y_kept.device)

    return loss_fn(weights_x, x_kept, weights_y, y_kept)

import ot
import numpy as np


def get_w1(M, w_x=None, w_y=None):
    """Return the exact optimal-transport cost for a cost matrix via ``ot.emd2``.

    Args:
        M: 2-D cost matrix of shape (n, m); array-like or ``torch.Tensor``
            (tensors are detached and moved to CPU first).
        w_x: Optional non-normalised weights of length n for the rows of
            ``M``. ``None`` means uniform weights.
        w_y: Optional non-normalised weights of length m for the columns of
            ``M``. ``None`` means uniform weights.

    Returns:
        The value returned by ``ot.emd2`` for the normalised weights and the
        float64 cost matrix.

    Raises:
        ValueError: if ``M`` is not 2-D, a weight vector does not have the
            expected length, or a weight vector does not sum to a finite
            positive number (normalising it would otherwise silently yield
            NaN/Inf weights).
    """
    def get_w(w, n, label):
        if w is None:
            w = np.ones(n)
        # Tensor weights may live on an accelerator; bring them to CPU first.
        if isinstance(w, torch.Tensor):
            w = w.detach().cpu()
        # np.array makes a copy, so the caller's weights are never modified.
        w = np.array(w).astype(np.float64)
        if w.shape != (n,):
            raise ValueError(
                f"get_w1: {label} must be 1-D with length {n}, got shape {w.shape}"
            )
        total = w.sum()
        if not np.isfinite(total) or total <= 0:
            raise ValueError(
                f"get_w1: {label} must sum to a finite positive value, got {total}"
            )
        return w / total

    # Tensor cost matrices are likewise moved to CPU before conversion.
    if isinstance(M, torch.Tensor):
        M = M.detach().cpu()

    M = np.array(M).astype(np.float64)
    if M.ndim != 2:
        raise ValueError(f"get_w1: M must be a 2-D cost matrix, got shape {M.shape}")

    w_x = get_w(w_x, M.shape[0], "w_x")
    w_y = get_w(w_y, M.shape[1], "w_y")

    # numItermax is an iteration count; pass a real int rather than the float
    # literal 1e7.
    return ot.emd2(w_x, w_y, M, numItermax=10_000_000)


@torch.no_grad()
def rollout_pred_macro(
        *,
        model: torch.nn.Module,
        X_em: torch.Tensor,
        time_grid: torch.Tensor,
        dt_base: float,
        substeps_per_dt: int,
        integrator_name: str,
        friction: Any,
        vel_provider,
        vel_mode: str,
        V_em: Optional[torch.Tensor],
        max_force: Optional[float],
        p0_idx: int,
        idx: Optional[torch.Tensor],
        # NEW ARGUMENT
        dt_integration: Optional[float] = None
) -> torch.Tensor:
    """Integrate one initial condition forward and sample it on the macro grid.

    Args:
        model: Passed to ``make_accel_from_potential`` to build the
            acceleration function.
        X_em: Ground-truth positions; ``X_em[p0_idx]`` is indexed as
            (particle, time, dim).
        time_grid: 1-D tensor of times; ``time_grid[0]`` is the start time.
        dt_base: Time between consecutive macro frames.
        substeps_per_dt: Integration steps per macro interval, used when
            ``dt_integration`` is not given (or is not positive).
        integrator_name: Name forwarded to ``pick_integrator``.
        friction: Forwarded to the integrator.
        vel_provider: Callable ``(p0_idx, x0, t0) -> v0`` used when
            ``vel_mode`` is neither ``"zero"`` nor ``"bundle"``.
        vel_mode: ``"zero"``, ``"bundle"`` (read ``v0`` from ``V_em``), or
            anything else to call ``vel_provider`` (case-insensitive).
        V_em: Velocities indexed as ``V_em[p0_idx, particle, time, dim]``;
            required for ``vel_mode="bundle"``.
        max_force: Forwarded to ``make_accel_from_potential``.
        p0_idx: Index of the initial condition in ``X_em`` / ``V_em``.
        idx: Optional particle indices selecting a subset of particles.
        dt_integration: Explicit integration step. When positive it overrides
            ``substeps_per_dt``; the number of steps per macro interval is
            then ``round(dt_base / dt_integration)``.

    Returns:
        Predicted positions of shape (n_eval, T+1, d), one frame per macro
        time of ``X_em[p0_idx]``.

    Raises:
        ValueError: if ``vel_mode="bundle"`` without ``V_em``, or if the
            resolved number of integration steps per macro interval is less
            than one (``dt_integration`` much larger than ``dt_base``, or a
            non-positive ``substeps_per_dt``).
    """
    X_gt = X_em[p0_idx]
    if idx is not None:
        X_gt = X_gt[idx]
    x0 = X_gt[:, 0, :].detach()
    t0 = float(time_grid[0].item())

    # --- v0 resolution ---
    # NOTE(review): when vel_provider is None the velocity is zero even if
    # vel_mode == "bundle" and V_em is supplied — confirm this is intended.
    vel_mode_l = str(vel_mode).lower()
    if vel_provider is None or vel_mode_l == "zero":
        v0 = torch.zeros_like(x0)
    elif vel_mode_l == "bundle":
        if V_em is None:
            raise ValueError("rollout_pred_macro: vel_mode='bundle' requires V_em.")
        # Time index of the grid point closest to t0 (argmin is always in range).
        m = int(torch.argmin((time_grid - float(t0)).abs()).item())
        v0_full = V_em[int(p0_idx), :, m, :]
        v0 = v0_full if idx is None else v0_full[idx]
        v0 = v0.to(device=x0.device, dtype=x0.dtype)
    else:
        v0 = vel_provider(p0_idx, x0, t0).detach()

    # --- Determine DT and Substeps ---
    if dt_integration is not None and dt_integration > 0:
        dt_micro = float(dt_integration)
        # How many micro steps per macro (base) step?
        ratio = float(dt_base) / dt_micro
        steps_per_macro = int(round(ratio))
        # A ratio that rounds to 0 would produce a zero-step rollout whose
        # frames are all x0 (the index clamp below would hide it).
        if steps_per_macro < 1:
            raise ValueError(
                f"rollout_pred_macro: dt_integration ({dt_micro}) is too large for "
                f"dt_base ({dt_base}); need at least one integration step per macro step."
            )
        if abs(ratio - steps_per_macro) > 1e-3:
            print(f"Warning: dt_base ({dt_base}) is not a multiple of dt_sim ({dt_micro}). Drift may occur.")
    else:
        steps_per_macro = int(substeps_per_dt)
        if steps_per_macro < 1:
            raise ValueError(
                f"rollout_pred_macro: substeps_per_dt must be >= 1, got {substeps_per_dt}"
            )
        dt_micro = float(dt_base) / steps_per_macro

    steps_macro = X_gt.shape[1] - 1
    total_micro = steps_macro * steps_per_macro

    integrator, _ = pick_integrator(str(integrator_name))
    accel_eval = make_accel_from_potential(model, create_graph=False, max_force=max_force)

    # --- Integrate ---
    X_pred = integrator(
        x0=x0, v0=v0, accel=accel_eval, dt=dt_micro, steps=int(total_micro),
        friction=friction, return_all=True, t_start=float(t0),
    )

    # --- Subsample to Macro Time Grid ---
    # Gather indices: 0, 1*SPM, 2*SPM ...
    macro_idx = (torch.arange(0, steps_macro + 1, device=x0.device) * steps_per_macro).long()

    # Safety clamp in case the integrator returned fewer frames than expected.
    macro_idx = torch.clamp(macro_idx, max=X_pred.shape[1] - 1)

    return X_pred[:, macro_idx, :]

from sklearn.metrics import pairwise_distances

def _summarize_w1(values, prefix):
    """Return mean / standard error / sample std / count of ``values`` under ``prefix``-ed keys."""
    count = len(values)
    avg = se = std = float("nan")
    if count > 0:
        arr = np.array(values)
        avg = float(np.mean(arr))
        # The sample standard deviation (ddof=1) is undefined for one value;
        # report NaN explicitly instead of triggering a NumPy RuntimeWarning.
        if count > 1:
            std = float(np.std(arr, ddof=1))
            se = float(np.std(arr, ddof=1) / np.sqrt(count))
    return {
        f"{prefix}_w1_avg": avg,
        f"{prefix}_w1_se": se,
        f"{prefix}_w1_std": std,
        f"{prefix}_count": count,
    }


def _marginal_w1_values(X_pred_np, X_gt_np, marginals, label):
    """Return W1 between predicted and ground-truth particles at each time index in ``marginals``."""
    n_frames = min(X_pred_np.shape[1], X_gt_np.shape[1])
    values = []
    for m in marginals:
        if m >= n_frames:
            print(f"  Warning: Skipping {label} marginal {m} (out of bounds)")
            continue
        M = pairwise_distances(X_pred_np[:, m, :], X_gt_np[:, m, :], metric='euclidean')
        values.append(get_w1(M))
    return values


@torch.no_grad()
def compute_train_forecast_w1(
        *,
        model: torch.nn.Module,
        X_em: torch.Tensor,
        time_grid: torch.Tensor,
        dt_base: float,
        substeps_per_dt: int,
        integrator_name: str,
        friction: Any,
        kernel_bw2: float,  # Unused but kept for signature compatibility
        geom_scaling: float,  # Unused but kept for signature compatibility
        geom_debias: bool,  # Unused but kept for signature compatibility
        geom_backend: str,  # Unused but kept for signature compatibility
        vel_provider,
        vel_mode: str,
        V_em: Optional[torch.Tensor],
        max_force: Optional[float],
        particles_eval: Optional[int],
        T_train_plus_1: int,
        dt_integration: Optional[float] = None,
        # NEW PARAMETERS
        eval_mode: str = "forecast",
        holdout_idx: Optional[int] = None,
) -> Dict[str, float]:
    """
    Roll the model out from every initial condition and report W1 statistics.

    For each initial condition the trajectory is integrated from
    ``time_grid[0]``, sampled at the macro time indices, and compared with the
    ground truth at each selected time index using the Wasserstein-1 distance
    (Euclidean cost, uniform weights). Progress and a summary are printed.

    Args:
        model: Passed to ``make_accel_from_potential``.
        X_em: Ground-truth positions of shape (num_p0, N, T+1, d).
        time_grid: 1-D tensor of times; ``time_grid[0]`` is the start time.
        dt_base: Time between consecutive macro frames.
        substeps_per_dt: Integration steps per macro interval, used when
            ``dt_integration`` is not given (or is not positive).
        integrator_name: Name forwarded to ``pick_integrator``.
        friction: Forwarded to the integrator.
        kernel_bw2: Unused, kept for signature compatibility.
        geom_scaling: Unused, kept for signature compatibility.
        geom_debias: Unused, kept for signature compatibility.
        geom_backend: Unused, kept for signature compatibility.
        vel_provider: Callable ``(p0_idx, x0, t0) -> v0`` used when
            ``vel_mode`` is neither ``"zero"`` nor ``"bundle"``.
        vel_mode: ``"zero"``, ``"bundle"`` (read ``v0`` from ``V_em``), or
            anything else to call ``vel_provider`` (case-insensitive).
        V_em: Velocities, required when ``vel_mode="bundle"``.
        max_force: Forwarded to ``make_accel_from_potential``.
        particles_eval: If set and smaller than N, that many particle indices
            are drawn with ``torch.randint`` for each initial condition.
        T_train_plus_1: In forecast mode, time indices ``1..T_train_plus_1-1``
            are "train" and the remaining ones are "forecast".
        dt_integration: Explicit integration step; when positive it overrides
            ``substeps_per_dt``.
        eval_mode: ``"forecast"`` or ``"interpolate"`` (the single index
            ``holdout_idx`` is the test marginal, all other indices >= 1 are
            train).
        holdout_idx: Required when ``eval_mode="interpolate"``.

    Returns:
        Dict with ``train_w1_avg``, ``train_w1_se``, ``train_w1_std``,
        ``train_count`` and the same four keys prefixed ``forecast_``. The
        averages are NaN when no value was collected; the standard error and
        standard deviation are NaN when fewer than two values were collected.

    Raises:
        ValueError: for an unknown ``eval_mode``, a missing ``holdout_idx`` in
            interpolate mode, ``vel_mode="bundle"`` without ``V_em``, or when
            the resolved number of integration steps per macro interval is
            less than one.
    """
    num_p0, N, T_total_plus_1, _ = X_em.shape

    # Determine train and test marginals based on mode
    if eval_mode == "forecast":
        # Forecast mode: train on [1, T_train_plus_1), test on [T_train_plus_1, T_total)
        train_ms = list(range(1, int(T_train_plus_1)))
        test_ms = list(range(int(T_train_plus_1), int(T_total_plus_1)))

        print(f"\n[W1 Eval - Forecast Mode]")
        print(f"  Train marginals: indices {min(train_ms) if train_ms else 'N/A'} to {max(train_ms) if train_ms else 'N/A'} ({len(train_ms)} total)")
        print(f"  Forecast marginals: indices {min(test_ms) if test_ms else 'N/A'} to {max(test_ms) if test_ms else 'N/A'} ({len(test_ms)} total)")

    elif eval_mode == "interpolate":
        # Interpolate mode: test on holdout, train on everything else
        if holdout_idx is None:
            raise ValueError("holdout_idx required for interpolate mode")

        # NOTE(review): a negative holdout_idx is not normalised, so it would
        # not be excluded from train_ms — callers should pass an index >= 0.
        train_ms = [i for i in range(1, int(T_total_plus_1)) if i != holdout_idx]
        test_ms = [holdout_idx]

        print(f"\n[W1 Eval - Interpolate Mode]")
        print(f"  Holdout marginal: index {holdout_idx} (t={time_grid[holdout_idx].item():.2f})")
        print(f"  Train marginals: all others ({len(train_ms)} total)")

    else:
        raise ValueError(f"Unknown eval_mode: {eval_mode}. Must be 'forecast' or 'interpolate'")

    print(f"  Initial conditions (p0): {num_p0}")
    print(f"  Particles per trajectory: {N if particles_eval is None else f'{particles_eval} (subsampled)'}")

    # Integration parameters
    if dt_integration is not None and dt_integration > 0:
        dt_micro = float(dt_integration)
        steps_per_macro = int(round(float(dt_base) / dt_micro))
        # A ratio that rounds to 0 would yield a zero-step rollout whose
        # frames are all x0, making every reported W1 meaningless.
        if steps_per_macro < 1:
            raise ValueError(
                f"compute_train_forecast_w1: dt_integration ({dt_micro}) is too large for "
                f"dt_base ({dt_base}); need at least one integration step per macro step."
            )
        print(f"  dt_micro: {dt_micro:.4f} (explicit, steps_per_macro={steps_per_macro})")
    else:
        steps_per_macro = int(substeps_per_dt)
        if steps_per_macro < 1:
            raise ValueError(
                f"compute_train_forecast_w1: substeps_per_dt must be >= 1, got {substeps_per_dt}"
            )
        dt_micro = float(dt_base) / steps_per_macro
        print(f"  dt_micro: {dt_micro:.4f} (from substeps={substeps_per_dt})")

    # Resolved at call time, after argument validation (the module also
    # imports these names at top level).
    from mechanics import pick_integrator
    from potential_energy_models import make_accel_from_potential

    integrator, _ = pick_integrator(str(integrator_name))

    # Loop-invariant values.
    t0 = float(time_grid[0].item())
    vel_mode_l = str(vel_mode).lower()

    train_vals = []
    test_vals = []

    # Iterate over initial conditions
    for p0_idx in range(int(num_p0)):
        # Optionally subsample particles. Indices are drawn independently, so
        # the same particle may be selected more than once.
        idx = None
        if particles_eval is not None and int(particles_eval) < int(N):
            idx = torch.randint(0, int(N), (int(particles_eval),), device=X_em.device)

        X_gt = X_em[p0_idx]
        X_gt = X_gt if idx is None else X_gt[idx]

        # Initial positions
        x0 = X_gt[:, 0, :].detach()

        # Initial velocity
        if vel_provider is None or vel_mode_l == "zero":
            v0 = torch.zeros_like(x0)
        elif vel_mode_l == "bundle":
            if V_em is None:
                raise ValueError("vel_mode='bundle' requires V_em")
            # Time index of the grid point closest to t0.
            m = int(torch.argmin((time_grid - float(t0)).abs()).item())
            m = max(0, min(m, time_grid.numel() - 1))
            v0_full = V_em[int(p0_idx), :, m, :]
            v0 = v0_full if idx is None else v0_full[idx]
            v0 = v0.to(device=x0.device, dtype=x0.dtype)
        else:
            v0 = vel_provider(p0_idx, x0, t0).detach()

        # Simulate full trajectory
        steps_macro = X_gt.shape[1] - 1
        total_micro = steps_macro * steps_per_macro
        accel_eval = make_accel_from_potential(model, create_graph=False, max_force=max_force)

        X_pred = integrator(
            x0=x0, v0=v0, accel=accel_eval, dt=dt_micro, steps=int(total_micro),
            friction=friction, return_all=True, t_start=float(t0),
        )

        # Downsample to macro time grid: frames 0, SPM, 2*SPM, ...
        macro_idx = (torch.arange(0, steps_macro + 1, device=x0.device) * steps_per_macro).long()
        macro_idx = torch.clamp(macro_idx, max=X_pred.shape[1] - 1)
        X_pred_macro = X_pred[:, macro_idx, :]

        # Convert to numpy for W1 computation
        X_pred_np = X_pred_macro.detach().cpu().double().numpy()
        X_gt_np = X_gt.detach().cpu().double().numpy()

        train_vals.extend(_marginal_w1_values(X_pred_np, X_gt_np, train_ms, "train"))
        test_vals.extend(_marginal_w1_values(X_pred_np, X_gt_np, test_ms, "test"))

        # Release the (possibly large) rollout before the next iteration.
        del X_pred, X_pred_macro

    # Compute statistics
    out: Dict[str, float] = {}
    out.update(_summarize_w1(train_vals, "train"))
    out.update(_summarize_w1(test_vals, "forecast"))

    # Print summary
    print(f"  → Train W1: {out['train_w1_avg']:.6f} ± {out['train_w1_se']:.6f} (n={out['train_count']})")
    print(f"  → Forecast W1: {out['forecast_w1_avg']:.6f} ± {out['forecast_w1_se']:.6f} (n={out['forecast_count']})")

    return out

