# rollout.py

import torch
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, Callable, Literal, Union, Dict
import torch.nn as nn
from mechanics import leapfrog_auto, x_verlet_auto, resolve_gamma
from losses import compute_loss  # assumed to exist
import numpy as np

# ============================================================
# 1. Utility: Marginal extractor
# ============================================================
from typing import Optional, Tuple
import torch
from torch import Tensor


def get_margs(
    X_em_torch: Tensor,
    time_grid: Tensor,
    p0_idx: int,
    t_idx: int,
    n_sub: Optional[int] = None,
    n_ahead: int = 1,
) -> Tuple[Tensor, Tensor, Tensor, int]:
    """
    Extract the marginal at ``t_idx`` and up to ``n_ahead`` future marginals
    for population ``p0_idx``.

    Rows containing any non-finite value are dropped independently at every
    time index, so particle identity is NOT preserved across times: each
    returned marginal is an unpaired sample cloud. All clouds are then
    subsampled without replacement to a common size ``n_eff`` (the smallest
    finite-row count, optionally capped by ``n_sub``) so they can be stacked.

    Args:
        X_em_torch: 4-D tensor indexed as ``[population, particle, time, dim]``,
            i.e. shape ``(num_p0, N, T+1, d)``.
        time_grid: 1-D tensor of physical times, indexed by time index.
        p0_idx: population index.
        t_idx: starting time index; must lie in ``0..T-1``.
        n_sub: optional cap on particles per marginal. ``None`` or a value
            ``<= 0`` means "no cap" (in every branch).
        n_ahead: number of future marginals requested; truncated to the
            number available after ``t_idx``.

    Returns:
        ``(x_t, Y_future, t0, L)``: ``x_t`` is ``(n, d)``, ``Y_future`` is
        ``(L, n, d)``, ``t0 = time_grid[t_idx]`` and ``L`` is the number of
        future marginals returned. If any requested marginal has no finite
        rows, empty tensors and ``L = 0`` are returned.

    Raises:
        ValueError: if ``t_idx`` is outside ``0..T-1``.
    """
    device = X_em_torch.device
    T = X_em_torch.shape[2] - 1

    t_i = int(t_idx)
    if not (0 <= t_i <= T - 1):
        raise ValueError(f"t_idx out of range: {t_idx} (valid: 0..{T-1})")

    # A non-positive n_sub means "no cap", consistently for every branch.
    cap = None if n_sub is None or int(n_sub) <= 0 else int(n_sub)

    def _finite_rows(a: Tensor) -> Tensor:
        # keep only particles whose coordinates are all finite
        return a[torch.isfinite(a).all(dim=1)]

    def _take(a: Tensor, n: int) -> Tensor:
        # random subset of n rows without replacement (no-op if n >= rows)
        if n < a.shape[0]:
            return a[torch.randperm(a.shape[0], device=device)[:n]]
        return a

    L = int(min(int(n_ahead), T - t_i))
    t0 = time_grid[t_idx]

    # No future steps requested/available: only the current marginal.
    if L <= 0:
        x = _finite_rows(X_em_torch[p0_idx, :, t_i, :])
        if cap is not None:
            x = _take(x, cap)
        return x, x.new_empty((0, x.shape[0], x.shape[1])), t0, 0

    # Filter each time independently (NO cross-time identity).
    x = _finite_rows(X_em_torch[p0_idx, :, t_i, :])  # (Nx, d)
    Ys = [_finite_rows(X_em_torch[p0_idx, :, t_i + j, :]) for j in range(1, L + 1)]

    counts = [int(x.shape[0])] + [int(y.shape[0]) for y in Ys]
    n_eff = min(counts) if cap is None else min(cap, *counts)
    if n_eff <= 0:
        return x.new_empty((0, x.shape[1])), x.new_empty((0, 0, x.shape[1])), t0, 0

    # Subsample every marginal to the common size (x first, then Y_1..Y_L).
    x_t = _take(x, n_eff)
    Y_future = torch.stack([_take(y, n_eff) for y in Ys], dim=0)  # (L, n_eff, d)

    return x_t, Y_future, t0, L

# class LearnableFriction(nn.Module):
#     """
#     gamma(raw) = eps + softplus(raw), always > eps.
#     If init is too small, initialize at min_gamma to avoid raw -> -inf and dead grads.
#     """
#     def __init__(self, init: float, eps: float = 1e-6, min_gamma: float = 1e-2):
#         super().__init__()
#         self.eps = float(eps)
#         self.min_gamma = float(min_gamma)
#
#         init = float(init)
#
#         gamma0 = max(init, self.min_gamma)
#         y = max(gamma0 - self.eps, 1e-12)
#
#         # Inverse softplus, stable
#         if y > 20.0:
#             raw = y
#         else:
#             raw = float(np.log(np.expm1(y) + 1e-12))
#
#         self.raw = nn.Parameter(torch.tensor(raw, dtype=torch.float32))
#
#     def friction_tensor(self) -> torch.Tensor:
#         return self.eps + F.softplus(self.raw)
#
#     def value(self) -> float:
#         return float(self.friction_tensor().detach().cpu().item())
class LearnableFriction(nn.Module):
    """
    Friction coefficient parameterized in log-space.

    The trainable parameter is ``theta`` and ``gamma = exp(theta)``, which keeps
    gamma strictly positive without any clipping. A non-positive ``init_gamma``
    is replaced by ``gamma_init_floor`` at construction time only.
    """

    def __init__(self, init_gamma: float, gamma_init_floor: float = 1e-2):
        super().__init__()
        requested = float(init_gamma)
        if requested > 0.0:
            start_gamma = requested
        else:
            # log(gamma) is undefined for gamma <= 0, so start from the floor
            start_gamma = float(gamma_init_floor)
        log_gamma = np.log(start_gamma)
        self.theta = nn.Parameter(torch.tensor(log_gamma, dtype=torch.float32))

    def friction_tensor(self) -> torch.Tensor:
        """Return gamma = exp(theta) as a differentiable 0-d tensor."""
        return self.theta.exp()

    def value(self) -> float:
        """Return the current gamma as a detached Python float."""
        gamma = self.friction_tensor()
        return float(gamma.detach().cpu().item())

def _microsteps_between(t0, t1, dt_base: float, substeps_per_dt: int) -> int:
    # number of micro-steps to advance from physical time t0 -> t1
    delta = float((t1 - t0).detach().cpu().item())
    if delta < 0:
        raise ValueError(f"Non-monotone time_grid: {t0} -> {t1}")
    return int(round(delta / float(dt_base) * int(substeps_per_dt)))


from typing import Optional, Callable, Dict, Literal, Union
import torch
from torch import Tensor

def train_rollout_anchor_p0_randk(
        X_em_torch: Tensor,
        time_grid: Tensor,
        accel_train: Callable[[Tensor, float], Tensor],
        optimizer: torch.optim.Optimizer,
        dt_base: float,
        *,
        num_epochs: int,
        max_train_steps: int,
        substeps_per_dt: int,
        kernel_bw2: float,
        loss_type: Literal["geom_sinkhorn", "geom_gaussian"] = "geom_sinkhorn",
        particles_per_batch: Optional[int] = None,
        vel_provider: Optional[Callable[[int, Tensor, float], Tensor]] = None,
        friction: Union[float, Tensor] = 0.0,
        debug: bool = False,
        name: str = "",
        verlet: str = "v",
        geom_p: int = 2,
        geom_scaling: float = 0.9,
        geom_debias: bool = True,
        geom_backend: Optional[str] = None,
        epoch_callback: Optional[Callable[[int, int, float, float], None]] = None,
        # epoch_callback(epoch, k, loss, friction_value)
):
    """
    Train with random-horizon rollouts that always start at time index 0.

    Every epoch draws a horizon k uniformly from {1, ..., K_max}, where
    K_max = min(max_train_steps, T), and a population index. It then
    integrates from time_grid[0] to time_grid[k] and takes one optimizer step
    on the mean of the per-marginal losses at m = 1..k. The mean keeps the
    loss scale comparable across horizons. Epochs where ``get_margs`` yields
    no usable particles are skipped without an optimizer step.

    Returns:
        dict with ``train_loss_avg`` (mean over effective epochs),
        ``train_loss_last`` (NaN if no epoch ran) and ``num_effective_epochs``.

    Raises:
        ValueError: if K_max is not positive.
    """
    def _to_float(g) -> float:
        # resolved friction is either a tensor (possibly learnable) or a number
        return float(g.detach().item()) if isinstance(g, torch.Tensor) else float(g)

    device = X_em_torch.device
    num_p0 = X_em_torch.shape[0]
    steps = X_em_torch.shape[2] - 1  # macro steps in the bundle
    dt_train = float(dt_base) / int(substeps_per_dt)

    K_max = int(min(max_train_steps, steps))
    if K_max <= 0:
        raise ValueError(f"max_train_steps too small: max_train_steps={max_train_steps}, steps={steps}")

    gamma0 = resolve_gamma(friction)
    is_learnable = bool(isinstance(gamma0, torch.Tensor) and gamma0.requires_grad)
    integrator = leapfrog_auto if verlet != "x" else x_verlet_auto
    init_val = _to_float(gamma0)
    print(
        f"[{name}] Anchor p0 + random k | dt={dt_train:.6f}, "
        f"Friction={init_val:.4g} (Learnable={is_learnable}), verlet={verlet}"
    )

    running_sum = 0.0
    n_used = 0
    last_loss_val = None
    t0_val = float(time_grid[0].item())
    log_every = max(1, num_epochs // 10)

    for epoch in range(num_epochs):
        optimizer.zero_grad(set_to_none=True)

        # random horizon first, then population (fixed RNG order)
        k = int(torch.randint(1, K_max + 1, (1,), device=device).item())
        p0_idx = int(torch.randint(0, num_p0, (1,), device=device).item())

        x0, Y_future, _, L = get_margs(
            X_em_torch, time_grid,
            p0_idx=p0_idx,
            t_idx=0,
            n_sub=particles_per_batch,
            n_ahead=k,
        )
        if x0.numel() == 0 or L <= 0:
            # no finite particles for some marginal of this population
            continue

        v0 = torch.zeros_like(x0) if vel_provider is None else vel_provider(p0_idx, x0, t0_val)

        t_first = time_grid[0]
        total_micro_steps = _microsteps_between(t_first, time_grid[L], dt_base, substeps_per_dt)
        X_all_pred = integrator(
            x0=x0,
            v0=v0,
            accel=accel_train,
            dt=dt_train,
            steps=total_micro_steps,
            friction=resolve_gamma(friction),
            return_all=True,
            t_start=t0_val,
        )

        # micro-step index of each observed marginal m = 1..L (clamped to the rollout)
        micro_indices = [
            min(_microsteps_between(t_first, time_grid[m], dt_base, substeps_per_dt), total_micro_steps)
            for m in range(1, L + 1)
        ]
        per_marginal = []
        for micro_idx, y_obs in zip(micro_indices, Y_future):
            per_marginal.append(
                compute_loss(
                    X_all_pred[:, micro_idx, :], y_obs,
                    loss_type=loss_type,
                    blur=float(kernel_bw2),
                    p=geom_p,
                    scaling=geom_scaling,
                    debias=geom_debias,
                    backend=geom_backend,
                )
            )
        loss = torch.stack(per_marginal, dim=0).mean()

        last_loss_val = float(loss.detach().item())
        running_sum += last_loss_val
        n_used += 1

        loss.backward()
        optimizer.step()

        if epoch_callback is not None:
            epoch_callback(epoch, int(L), last_loss_val, _to_float(resolve_gamma(friction)))

        if debug or (epoch + 1) % log_every == 0:
            fric_val = _to_float(resolve_gamma(friction))
            print(f"[{name}] Ep {epoch+1:04d}/{num_epochs} | k={int(L):02d} | loss={last_loss_val:.4e} | fric={fric_val:.4g}")

        del X_all_pred
        if torch.cuda.is_available() and (epoch + 1) % 25 == 0:
            torch.cuda.empty_cache()

    return {
        "train_loss_avg": running_sum / max(1, n_used),
        "train_loss_last": float(last_loss_val) if last_loss_val is not None else float("nan"),
        "num_effective_epochs": int(n_used),
    }

# ============================================================
# 2. Uniform time-sampling rollout trainer
# ============================================================
def train_rollout_uniform(
        X_em_torch: Tensor,
        time_grid: Tensor,
        accel_train: Callable[[Tensor, float], Tensor],
        optimizer: torch.optim.Optimizer,
        dt_base: float,
        *,
        num_epochs: int,
        max_train_steps: int,
        substeps_per_dt: int,
        kernel_bw2: float,
        loss_type: Literal["geom_sinkhorn", "geom_gaussian"] = "geom_sinkhorn",
        particles_per_batch: Optional[int] = None,
        vel_provider: Optional[Callable[[int, Tensor, float], Tensor]] = None,
        friction: Union[float, Tensor] = 0.0,
        debug: bool = False,
        name: str = "",
        verlet: str = "v",
        geom_p: int = 2,
        geom_scaling: float = 0.9,
        geom_debias: bool = True,
        geom_backend: Optional[str] = None,
        epoch_callback: Optional[Callable[[int, float, float], None]] = None,  # NEW
):
    """
    Train with rollouts whose start time index is sampled uniformly.

    Each epoch:
      - draws a start index t_idx uniformly from {0, ..., T-1};
      - uses the horizon K = min(max_train_steps, T - t_idx);
      - samples one population;
      - rolls it out from time_grid[t_idx] to time_grid[t_idx + L];
      - takes one optimizer step on the mean per-marginal loss.

    Epochs where no sampled population yields usable (finite) particles are
    skipped without an optimizer step.

    Args:
        epoch_callback: Optional callback(epoch, loss, friction), called after
            each epoch that performed an optimizer step.

    Returns:
        dict with ``train_loss_avg`` (mean over stepped epochs) and
        ``train_loss_last`` (NaN if no epoch performed a step).
    """
    device = X_em_torch.device
    num_p0, N, T_plus_1, d = X_em_torch.shape
    steps = T_plus_1 - 1
    dt_train = dt_base / substeps_per_dt

    friction_use = friction
    gamma0 = resolve_gamma(friction_use)
    is_learnable = gamma0.requires_grad if isinstance(gamma0, torch.Tensor) else False

    # choose integrator
    if verlet == "x":
        integrator = x_verlet_auto
    else:
        integrator = leapfrog_auto

    # initial friction value for printing
    if isinstance(gamma0, torch.Tensor):
        init_val = float(gamma0.detach().item())
    else:
        init_val = float(gamma0)

    print(
        f"[{name}] dt={dt_train:.4f}, Friction={init_val:.2f} (Learnable={is_learnable}), "
        f"verlet={verlet}"
    )

    # chunk-level tracking
    chunk_loss_sum = 0.0
    chunk_loss_count = 0
    last_loss_val = None

    for epoch in range(num_epochs):
        optimizer.zero_grad()

        # 1) sample a starting time t_idx
        t_idx = int(torch.randint(0, steps, (1,), device=device).item())

        # effective horizon (macro steps)
        K_eff = min(max_train_steps, steps - t_idx)
        if K_eff <= 0:
            continue

        total_loss = 0.0
        num_used_p0 = 0

        # 2) choose populations
        p0_indices = [int(torch.randint(0, num_p0, (1,), device=device).item())]

        for p0_idx in p0_indices:
            x_t, Y_future, t0_val, L = get_margs(
                X_em_torch, time_grid,
                p0_idx=p0_idx,
                t_idx=t_idx,
                n_sub=particles_per_batch,
                n_ahead=K_eff,
            )
            # get_margs signals "no finite particles for some marginal in this
            # window" with L == 0 / empty x_t. Rolling out anyway would end in
            # torch.stack([]) on an empty loss list, so skip this population.
            if L <= 0 or x_t.numel() == 0:
                continue

            # The integrator and vel_provider both get the start time as a float.
            t_start = float(t0_val)

            # v0 at the starting time
            if vel_provider is not None:
                v0 = vel_provider(p0_idx, x_t, t_start)
            else:
                v0 = torch.zeros_like(x_t)

            t0 = time_grid[t_idx]
            t_end = time_grid[t_idx + L]
            total_micro_steps = _microsteps_between(t0, t_end, dt_base, substeps_per_dt)

            X_all_pred = integrator(
                x0=x_t,
                v0=v0,
                accel=accel_train,
                dt=dt_train,
                steps=total_micro_steps,
                friction=resolve_gamma(friction_use),
                return_all=True,
                t_start=t_start,
            )

            per_step_losses = []
            for m in range(1, L + 1):
                t_m = time_grid[t_idx + m]
                micro_idx = _microsteps_between(t0, t_m, dt_base, substeps_per_dt)

                # safety: never index past the last stored rollout state
                micro_idx = min(micro_idx, total_micro_steps)

                x_pred_m = X_all_pred[:, micro_idx, :]
                x_gt_m = Y_future[m - 1]
                loss_m = compute_loss(
                    x_pred_m, x_gt_m,
                    loss_type=loss_type,
                    blur=float(kernel_bw2),
                    p=geom_p,
                    scaling=geom_scaling,
                    debias=geom_debias,
                    backend=geom_backend,
                )
                per_step_losses.append(loss_m)

            loss_p0 = torch.stack(per_step_losses).mean()
            total_loss += loss_p0
            num_used_p0 += 1

            del X_all_pred

        if num_used_p0 == 0:
            # nothing valid this epoch: no graph to backprop, avoid 0/0
            continue

        loss = total_loss / float(num_used_p0)

        # accumulate
        last_loss_val = float(loss.detach().item())
        chunk_loss_sum += last_loss_val
        chunk_loss_count += 1

        loss.backward()
        optimizer.step()

        # Per-epoch callback for fine-grained logging
        if epoch_callback is not None:
            gamma_log = resolve_gamma(friction_use)
            cur_fric = float(gamma_log.detach().item()) if isinstance(gamma_log, torch.Tensor) else float(gamma_log)

            epoch_callback(epoch, last_loss_val, cur_fric)

        if (epoch + 1) % max(1, num_epochs // 10) == 0 or debug:
            gamma_log = resolve_gamma(friction_use)
            cur_fric = float(gamma_log.detach().item()) if isinstance(gamma_log, torch.Tensor) else float(gamma_log)

            print(
                f"[{name}] Ep {epoch + 1:03d} | t={t_idx}, K={K_eff} "
                f"| loss={last_loss_val:.4e}, fric={cur_fric:.3f}"
            )

        if torch.cuda.is_available() and (epoch + 1) % 25 == 0:
            torch.cuda.empty_cache()

    return {
        "train_loss_avg": chunk_loss_sum / max(1, chunk_loss_count),
        "train_loss_last": float(last_loss_val) if last_loss_val is not None else float("nan"),
    }


def train_rollout_incremental(
        X_em_torch: Tensor,
        time_grid: Tensor,
        accel_train: Callable[[Tensor, float], Tensor],
        optimizer: torch.optim.Optimizer,
        dt_base: float,
        *,
        epochs_per_step: int,
        max_train_steps: int,
        substeps_per_dt: int,
        kernel_bw2: float,
        loss_type: Literal["mmd", "sw2", "sinkhorn", "geom_sinkhorn", "geom_gaussian"] = "geom_sinkhorn",
        particles_per_batch: Optional[int] = None,
        vel_provider: Optional[Callable[[int, Tensor, float], Tensor]] = None,
        friction: Union[float, Tensor] = 0.0,
        debug: bool = False,
        name: str = "",
        verlet: str = "x",
        save_every_k: int = 0,
        gif_every_k: int = 0,
        save_callback: Optional[Callable[[int], None]] = None,
        gif_callback: Optional[Callable[[int], None]] = None,
        stage_end_callback: Optional[Callable[[int, Dict[str, float]], None]] = None,
        epoch_callback: Optional[Callable[[int, int, float, float], None]] = None,  # NEW
):
    """
    Incremental (curriculum) rollout training from p0.

    Stages run for k = 1..max_horizon, where max_horizon = min(max_train_steps, T).
    Each stage performs ``epochs_per_step`` epochs. In each epoch:
      - one population is sampled;
      - it is rolled out from time_grid[0] to time_grid[k];
      - one optimizer step is taken on the mean loss over marginals m = 1..k.

    Marginals are extracted with ``get_margs``: rows with non-finite values
    are dropped at every time independently, and all marginals are subsampled
    without replacement to a common size. NaNs therefore never reach
    ``compute_loss``. Epochs with no usable particles are skipped without a step.

    Args:
        save_every_k / gif_every_k: call save_callback / gif_callback(k) at
            the end of every stage k that is a multiple of the value (0 = never).
        stage_end_callback: callback(k, {"train_loss_stage_avg", "train_loss_stage_last"})
            called at the end of every stage.
        epoch_callback: Optional callback(k, epoch, loss, friction) called after
            each epoch that performed an optimizer step.
    """
    device = X_em_torch.device
    num_p0, N, T_plus_1, d = X_em_torch.shape
    steps = T_plus_1 - 1
    dt_train = dt_base / substeps_per_dt

    max_horizon = min(max_train_steps, steps)
    t0_val = time_grid[0]
    # Integrator and vel_provider both receive the start time as a Python float.
    t_start = float(t0_val)

    friction_use = friction
    gamma0 = resolve_gamma(friction_use)
    is_learnable = gamma0.requires_grad if isinstance(gamma0, torch.Tensor) else False

    if isinstance(gamma0, torch.Tensor):
        init_val = float(gamma0.detach().item())
    else:
        init_val = float(gamma0)

    # choose integrator
    if verlet == "x":
        integrator = x_verlet_auto
        integrator_name = "x_verlet_auto"
    else:
        integrator = leapfrog_auto
        integrator_name = "leapfrog_auto"

    print(
        f"[{name}] Incremental from p0 | dt={dt_train:.4f}, "
        f"Friction={init_val:.2f} (Learnable={is_learnable}), "
        f"verlet={verlet} ({integrator_name})"
    )

    log_every = max(1, epochs_per_step // 5)

    for k in range(1, max_horizon + 1):
        total_micro_steps = _microsteps_between(time_grid[0], time_grid[k], dt_base, substeps_per_dt)
        print(f"\n[{name}] === Stage k={k}/{max_horizon} (rollout p0 → p{k}) ===")

        # stage-level tracking
        stage_loss_sum = 0.0
        stage_loss_count = 0
        stage_loss_last = None

        for epoch in range(epochs_per_step):
            optimizer.zero_grad(set_to_none=True)

            total_loss = 0.0
            num_used_p0 = 0

            rand_idx = int(torch.randint(0, num_p0, (1,), device=device).item())
            p0_indices = [rand_idx]

            for p0_idx in p0_indices:
                # NaN-filtered x0 and future marginals Y_1..Y_k, equal-size,
                # subsampled without replacement.
                x0, Y_future, _, L = get_margs(
                    X_em_torch, time_grid,
                    p0_idx=p0_idx,
                    t_idx=0,
                    n_sub=particles_per_batch,
                    n_ahead=k,
                )
                if L <= 0 or x0.numel() == 0:
                    # nothing valid to roll out this epoch; skip cleanly
                    continue

                # initial velocity at t=0 (computed once, after filtering)
                if vel_provider is not None:
                    v0 = vel_provider(p0_idx, x0, t_start)
                else:
                    v0 = torch.zeros_like(x0)

                # rollout from p0 for k macro steps
                X_all_pred = integrator(
                    x0=x0,
                    v0=v0,
                    accel=accel_train,
                    dt=dt_train,
                    steps=total_micro_steps,
                    friction=resolve_gamma(friction_use),
                    return_all=True,
                    t_start=t_start,
                )

                # loss over marginals m = 1..k
                per_step_losses = []
                for m in range(1, L + 1):
                    micro_idx = _microsteps_between(time_grid[0], time_grid[m], dt_base, substeps_per_dt)
                    # safety: never index past the last stored rollout state
                    micro_idx = min(micro_idx, total_micro_steps)
                    x_pred_m = X_all_pred[:, micro_idx, :]
                    x_gt_m = Y_future[m - 1]

                    loss_m = compute_loss(
                        x_pred_m,
                        x_gt_m,
                        loss_type=loss_type,
                        blur=float(kernel_bw2),
                    )
                    per_step_losses.append(loss_m)

                loss_p0 = torch.stack(per_step_losses, dim=0).mean()
                total_loss += loss_p0
                num_used_p0 += 1

                del X_all_pred

            if num_used_p0 == 0:
                # no usable population this epoch: no graph to backprop, avoid 0/0
                continue

            loss = total_loss / float(num_used_p0)

            # accumulate
            stage_loss_last = float(loss.detach().item())
            stage_loss_sum += stage_loss_last
            stage_loss_count += 1

            loss.backward()
            optimizer.step()

            # Per-epoch callback for fine-grained logging
            if epoch_callback is not None:
                gamma_log = resolve_gamma(friction_use)
                cur_fric = float(gamma_log.detach().item()) if isinstance(gamma_log, torch.Tensor) else float(gamma_log)
                epoch_callback(k, epoch, stage_loss_last, cur_fric)

            # prints
            if debug or ((epoch + 1) % log_every == 0):
                gamma_log = resolve_gamma(friction_use)
                cur_fric = float(gamma_log.detach().item()) if isinstance(gamma_log, torch.Tensor) else float(gamma_log)

                print(
                    f"[{name}] k={k:02d} | Ep {epoch + 1:03d}/{epochs_per_step} "
                    f"| loss={stage_loss_last:.4e}, fric={cur_fric:.3f}"
                )

        # stage-end callback
        if stage_end_callback is not None:
            stage_end_callback(k, {
                "train_loss_stage_avg": stage_loss_sum / max(1, stage_loss_count),
                "train_loss_stage_last": float(stage_loss_last) if stage_loss_last is not None else float("nan"),
            })

        # stage-end save/gif callbacks
        if save_callback is not None and save_every_k and (k % int(save_every_k) == 0):
            save_callback(k)
        if gif_callback is not None and gif_every_k and (k % int(gif_every_k) == 0):
            gif_callback(k)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()