import imageio.v2 as imageio
import matplotlib.pyplot as plt
from IPython.display import Image, display
import torch
import numpy as np
import os
import matplotlib.cm as cm
from typing import Any, Dict, Optional, List, Tuple

def plot_holdout_scatter(x_pred, y_true, holdout_idx, title_suffix=""):
    """
    Overlay predicted and true holdout samples on a single 2D scatter plot.

    Only the first two columns of each input are drawn; the axes are
    labelled "PC 1" and "PC 2". Any input exposing a ``detach`` attribute is
    detached, moved to CPU and converted to numpy first; other inputs are
    sliced as they are.

    Args:
        x_pred: Predicted samples, indexable as ``[:, :2]``.
        y_true: Ground-truth samples, indexable as ``[:, :2]``.
        holdout_idx: Value shown as ``t=...`` in the title.
        title_suffix: Extra text appended to the title.

    Returns:
        The matplotlib Figure that was drawn on.
    """
    def _as_array(samples):
        # Tensor-like objects are converted; everything else passes through.
        if hasattr(samples, "detach"):
            return samples.detach().cpu().numpy()
        return samples

    pred_xy = _as_array(x_pred)[:, :2]
    true_xy = _as_array(y_true)[:, :2]

    fig = plt.figure(figsize=(6, 6))

    # Ground truth goes down first so the prediction crosses sit on top.
    series = (
        (true_xy, "True", "tab:blue", {}),
        (pred_xy, "Pred", "tab:orange", {"marker": "x"}),
    )
    for pts, name, colour, extra in series:
        plt.scatter(pts[:, 0], pts[:, 1], s=10, alpha=0.5, label=name, c=colour, **extra)

    plt.title(f"EB Holdout (t={holdout_idx}) {title_suffix}")
    plt.xlabel("PC 1")
    plt.ylabel("PC 2")
    plt.legend()
    plt.axis("equal")
    plt.tight_layout()

    return fig


def _holdout_points(samples, max_points=None):
    """Return ``samples`` as a numpy array, randomly keeping at most ``max_points`` rows."""
    # detach()/cpu() so tensors that require grad or live on an accelerator
    # convert cleanly, matching plot_holdout_scatter.
    if hasattr(samples, "detach"):
        samples = samples.detach().cpu().numpy()
    pts = np.asarray(samples)
    if max_points is not None and pts.shape[0] > max_points:
        keep = np.random.choice(pts.shape[0], max_points, replace=False)
        pts = pts[keep]
    return pts


def plot_multi_holdout_scatter(
        all_true: List[Tuple[torch.Tensor, int]],
        all_pred: List[Tuple[torch.Tensor, int]],
        time_grid: torch.Tensor,
        title: str = "Multi-Holdout Interpolation",
        max_points: int = 500,
):
    """
    Plot all holdout marginals: true (o) vs predicted (x), colored by time.

    The layout depends on the last-axis size ``d`` of the first ground-truth
    entry (``d`` defaults to 2 when ``all_true`` is empty):

      - ``d == 2``: one scatter plot of both coordinates.
      - ``d >= 3``: two side-by-side scatter plots of dims (0, 1) and (0, 2).
        Each marginal is subsampled once so both panels show the same points.
      - otherwise: overlaid histograms of the flattened values (no
        subsampling).

    ``all_true`` and ``all_pred`` are paired with ``zip``; surplus entries in
    the longer list are ignored. The holdout index stored with each
    ground-truth entry selects the time label from ``time_grid``; the index
    stored with the prediction is not used.

    Args:
        all_true: List of (positions, holdout_idx) for ground truth
        all_pred: List of (positions, holdout_idx) for predictions
        time_grid: Full time grid, indexed by holdout_idx
        title: Plot title
        max_points: Max points to plot per marginal in the scatter layouts

    Returns:
        matplotlib Figure
    """
    n_holdouts = len(all_true)

    # Get data dimensionality
    d = all_true[0][0].shape[-1] if n_holdouts > 0 else 2

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # still accepts the (name, lut) form. Ask for at least one entry so an
    # empty input does not request a zero-length lookup table.
    cmap = plt.get_cmap('viridis', max(1, n_holdouts))
    colors = [cmap(i) for i in range(n_holdouts)]

    pairs = list(zip(all_true, all_pred))
    time_labels = [f"t={float(time_grid[h_idx]):.1f}" for (_, h_idx), _ in pairs]

    if d == 2:
        fig, ax = plt.subplots(figsize=(8, 8))

        for color, t_label, ((y_true, _), (y_pred, _)) in zip(colors, time_labels, pairs):
            true_pts = _holdout_points(y_true, max_points)
            pred_pts = _holdout_points(y_pred, max_points)

            # True: circles (o)
            ax.scatter(true_pts[:, 0], true_pts[:, 1],
                       c=[color], marker='o', s=20, alpha=0.6,
                       label=f"{t_label} true", edgecolors='none')
            # Predicted: x markers
            ax.scatter(pred_pts[:, 0], pred_pts[:, 1],
                       c=[color], marker='x', s=20, alpha=0.6,
                       label=f"{t_label} pred")

        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_title(title)
        if pairs:
            # Skip the legend when nothing was drawn (avoids a "no handles" warning).
            ax.legend(loc='upper right', fontsize=8, ncol=2)
        ax.set_aspect('equal', adjustable='box')

    elif d >= 3:
        from matplotlib.lines import Line2D

        fig, axes = plt.subplots(1, 2, figsize=(14, 6))

        # Subsample once, outside the panel loop, so both projections show
        # the same particles.
        clouds = [
            (_holdout_points(y_true, max_points), _holdout_points(y_pred, max_points))
            for (y_true, _), (y_pred, _) in pairs
        ]

        for ax, (dim1, dim2) in zip(axes, ((0, 1), (0, 2))):
            for color, (true_pts, pred_pts) in zip(colors, clouds):
                ax.scatter(true_pts[:, dim1], true_pts[:, dim2],
                           c=[color], marker='o', s=15, alpha=0.5)
                ax.scatter(pred_pts[:, dim1], pred_pts[:, dim2],
                           c=[color], marker='x', s=15, alpha=0.5)

            ax.set_xlabel(f'dim {dim1}')
            ax.set_ylabel(f'dim {dim2}')
            ax.set_title(f'{title} (dims {dim1}-{dim2})')

        # The scatters carry no labels, so build proxy handles and attach a
        # single legend to the first subplot.
        legend_elements = []
        for color, t_label in zip(colors, time_labels):
            legend_elements.append(Line2D([0], [0], marker='o', color='w',
                                          markerfacecolor=color, markersize=8,
                                          label=f'{t_label} true'))
            legend_elements.append(Line2D([0], [0], marker='x', color=color,
                                          markersize=8, linestyle='None',
                                          label=f'{t_label} pred'))
        axes[0].legend(handles=legend_elements, loc='upper right', fontsize=7, ncol=2)

    else:
        # 1D case
        fig, ax = plt.subplots(figsize=(10, 4))
        for color, t_label, ((y_true, _), (y_pred, _)) in zip(colors, time_labels, pairs):
            ax.hist(_holdout_points(y_true).flatten(), bins=50, alpha=0.5,
                    color=color, label=f'{t_label} true')
            ax.hist(_holdout_points(y_pred).flatten(), bins=50, alpha=0.3,
                    color=color, histtype='step', linewidth=2,
                    label=f'{t_label} pred')
        ax.legend()
        ax.set_title(title)

    fig.tight_layout()
    return fig


def _canon_ntd(A):
    """Convert ``A`` to a numpy array laid out as (N, T, d)."""
    if torch.is_tensor(A):
        A = A.detach().cpu().numpy()
    A = np.asarray(A)
    if A.ndim != 3:
        raise ValueError(f"Expected 3D array, got shape {A.shape}")
    # Heuristic: time dimension is typically much smaller than N.
    # If the first axis looks like time (and the second is large), transpose.
    if A.shape[0] < A.shape[1] and A.shape[0] <= 1024 and A.shape[1] >= 256:
        return np.transpose(A, (1, 0, 2))
    return A


def _pca_project_2d(arrays):
    """Project same-width (N, T, d) arrays onto their two leading joint principal axes.

    Returns ``None`` when no fully finite point exists to fit the projection.
    """
    d = arrays[0].shape[-1]
    Z = np.concatenate([A.reshape(-1, d) for A in arrays], axis=0)
    Z = Z[np.isfinite(Z).all(axis=1)]
    if Z.shape[0] == 0:
        return None
    mu = Z.mean(axis=0, keepdims=True)
    Zc = Z - mu
    cov = (Zc.T @ Zc) / max(1, Zc.shape[0])
    _, evecs = np.linalg.eigh(cov)
    # eigh returns eigenvalues in ascending order; reverse the columns so the
    # first output coordinate is the leading component.
    W = evecs[:, ::-1][:, :2]
    return [(A - mu) @ W for A in arrays]


def _axis_limits(arrays, d, grid_size, filter_outliers):
    """Return one (lo, hi) pair per coordinate, shared by every frame."""
    if grid_size is not None:
        gs = float(grid_size)
        return [(-gs, gs)] * d
    pts = np.concatenate([A.reshape(-1, d) for A in arrays], axis=0)
    pts = pts[np.isfinite(pts).all(axis=1)]
    if pts.shape[0] == 0:
        return [(-1.0, 1.0)] * d
    if filter_outliers:
        # Clip the view to the central 99% of points per coordinate.
        q = 0.995
        lo = np.quantile(pts, 1 - q, axis=0)
        hi = np.quantile(pts, q, axis=0)
    else:
        lo = np.min(pts, axis=0)
        hi = np.max(pts, axis=0)
    pad = 0.05 * np.maximum(1e-12, hi - lo)
    return [(float(lo[k] - pad[k]), float(hi[k] + pad[k])) for k in range(d)]


def _grab_frame(fig):
    """Render ``fig`` and return a copy of its RGBA pixel buffer."""
    fig.canvas.draw()
    return np.asarray(fig.canvas.buffer_rgba()).copy()


def make_compare_gif(
        X_true, X_learned, dt,
        true_label="true", est_label="est",
        grid_size=None,
        save_path="temp.gif",
        always_show=True,
        X_null=None,
        show_null=True,
        frame_skip=5,
        fps=5,
        filter_outliers=False,
        times=None,
        projection="auto",
        render="auto",
        bins=80,
        subsample=None,
):
    """
    Compare X_true vs X_learned (and optional X_null) as a GIF.

    Layout:
      - 2 coordinates (natively, after PCA, or after taking the first two
        coordinates of d>3 data): a single plot with overlapping scatters
        and a legend.
      - 3 coordinates: a grid of subplots, one row per dataset and one
        column per projection (x,y), (x,z), (y,z).

    Args:
        X_true, X_learned: Trajectories shaped (N, T, d) or (T, N, d), as
            torch tensors or array-likes. The (T, N, d) layout is detected
            heuristically and transposed. Shapes must match.
        dt: Time step used for the frame title ``t = k * dt`` when ``times``
            is not given.
        true_label, est_label: Legend / row labels for the two datasets.
        grid_size: If given, every axis spans ``[-grid_size, grid_size]``;
            otherwise limits are derived from the finite data plus 5% padding.
        save_path: Output GIF path; its parent directory is created.
        always_show: Display the GIF through IPython when IPython is
            importable.
        X_null: Optional third dataset with the same shape as ``X_true``.
        show_null: Whether to draw ``X_null`` when it is provided.
        frame_skip: Render every ``frame_skip``-th time index (must be >= 1).
        fps: Frames per second passed to the GIF writer.
        filter_outliers: Derive axis limits from the 0.5%-99.5% quantiles
            instead of min/max. Points themselves are not removed.
        times: Optional per-frame time stamps of length T or T+1.
        projection: ``"pca"`` projects all datasets onto the two leading
            joint principal axes, computed from the full-dimensional data.
            Any other value keeps raw coordinates; data with d>3 is then
            reduced to its first two coordinates.
        render: 3-coordinate layout only; ``"auto"``/``"hist2d"`` draw 2D
            histograms, anything else draws scatter points.
        bins: Histogram bins per axis for ``hist2d`` rendering.
        subsample: If given and smaller than N, a fixed-seed random subset of
            that many particles is plotted.

    Returns:
        ``save_path`` as a string.

    Raises:
        ValueError: On non-3D input, mismatched shapes, fewer than two
            coordinates, non-positive ``frame_skip``, a ``times`` length other
            than T or T+1, or when there are no frames to render.
    """
    step = int(frame_skip)
    if step < 1:
        raise ValueError(f"frame_skip must be >= 1, got {frame_skip}")

    # --- Setup Directories ---
    save_path = str(save_path)
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # --- Data Loading & Shape Checking ---
    true_arr = _canon_ntd(X_true)
    learned_arr = _canon_ntd(X_learned)
    null_arr = _canon_ntd(X_null) if X_null is not None else None
    can_show_null = bool(show_null) and (null_arr is not None)

    if true_arr.shape != learned_arr.shape:
        raise ValueError(f"Shape mismatch true={true_arr.shape} learned={learned_arr.shape}")
    if can_show_null and null_arr.shape != true_arr.shape:
        raise ValueError(f"X_null shape {null_arr.shape} must match {true_arr.shape}")

    N, T, d = true_arr.shape
    if d < 2:
        raise ValueError(f"Need at least 2 coordinates per point to plot, got d={d}")

    # Parallel lists: dataset label and its (N, T, d) array.
    labels = [true_label, est_label]
    arrays = [true_arr, learned_arr]
    if can_show_null:
        labels.append("Null")
        arrays.append(null_arr)

    # --- Subsampling (fixed seed so every dataset keeps the same particles) ---
    if subsample is not None and int(subsample) < N:
        keep = np.random.default_rng(0).choice(N, size=int(subsample), replace=False)
        arrays = [A[keep] for A in arrays]

    # --- Dimensionality reduction ---
    # PCA runs before any truncation so it sees every coordinate.
    if projection == "pca":
        projected = _pca_project_2d(arrays)
        if projected is not None:
            arrays = projected
            d = 2
    # EB-friendly fallback: if still d>3, show the first two coordinates (PC1/PC2).
    if d > 3:
        arrays = [A[..., :2] for A in arrays]
        d = 2

    # --- Axis limits (shared across frames so the view does not jump) ---
    lims = _axis_limits(arrays, d, grid_size, filter_outliers)

    # --- Time Labels ---
    if times is not None:
        times = np.asarray(times).reshape(-1)
        # Allow slight mismatches (T vs T+1) commonly found in ODE solvers
        if len(times) not in (T, T + 1):
            raise ValueError(f"`times` length {len(times)} mismatch with T={T}")

    def frame_time(k):
        return float(times[k]) if times is not None else k * float(dt)

    frames = []

    if d == 2:
        # Single plot, overlapping scatters, legend.
        styles = (("tab:blue", 0.4), ("tab:orange", 0.4), ("grey", 0.12))
        fig, ax = plt.subplots(figsize=(5, 5), dpi=110)
        try:
            for k in range(0, T, step):
                ax.clear()
                ax.set_xlim(*lims[0])
                ax.set_ylim(*lims[1])
                ax.set_aspect("equal")

                for label, A, (colour, alpha) in zip(labels, arrays, styles):
                    ax.scatter(A[:, k, 0], A[:, k, 1],
                               s=5, alpha=alpha, label=label, c=colour)

                ax.set_title(f"t = {frame_time(k):.3f}")
                ax.grid(alpha=0.3)
                # ax.clear() drops the legend, so it is re-added every frame.
                ax.legend(loc="upper right", fontsize=8, frameon=True,
                          facecolor='white', framealpha=0.8)

                frames.append(_grab_frame(fig))
        finally:
            # Close even if rendering raises so the figure is not leaked.
            plt.close(fig)

    else:
        # Grid of subplots: rows = datasets, columns = coordinate pairs.
        panels = [(0, 1), (0, 2), (1, 2)]
        col_titles = ["(x,y)", "(x,z)", "(y,z)"]
        nrows = len(arrays)
        ncols = len(panels)
        use_hist = render in ("auto", "hist2d")

        # squeeze=False guarantees a 2D axes array for any grid shape.
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                                 figsize=(4.2 * ncols, 4.0 * nrows),
                                 dpi=110, squeeze=False)
        try:
            for k in range(0, T, step):
                for r, (rlabel, Xr) in enumerate(zip(labels, arrays)):
                    for c, (i, j) in enumerate(panels):
                        ax = axes[r, c]
                        ax.clear()
                        ax.set_aspect("equal", adjustable="box")
                        ax.set_xlim(*lims[i])
                        ax.set_ylim(*lims[j])
                        ax.grid(alpha=0.25)

                        x = Xr[:, k, i]
                        y = Xr[:, k, j]

                        # Drop points with a non-finite coordinate in this panel.
                        finite = np.isfinite(x) & np.isfinite(y)
                        x = x[finite]
                        y = y[finite]

                        if use_hist:
                            ax.hist2d(x, y, bins=int(bins), range=[lims[i], lims[j]])
                        else:
                            ax.scatter(x, y, s=4, alpha=0.35)

                        if r == 0:
                            ax.set_title(col_titles[c], fontsize=10)
                        if c == 0:
                            ax.set_ylabel(rlabel)

                fig.suptitle(f"t = {frame_time(k):.3f}", y=0.98, fontsize=12)
                frames.append(_grab_frame(fig))
        finally:
            plt.close(fig)

    if not frames:
        raise ValueError(f"No frames to render (T={T})")

    # --- Save and Display ---
    # Uses the module-level `imageio` name (the imageio.v2 alias).
    imageio.mimsave(save_path, frames, fps=int(fps), loop=0)

    if always_show:
        try:
            from IPython.display import Image, display
        except ImportError:
            pass
        else:
            display(Image(filename=save_path))

    return save_path

from typing import Any, Optional
from pathlib import Path
from mechanics import pick_integrator
from potential_energy_models import make_accel_from_potential
from matplotlib.animation import FuncAnimation
import wandb

@torch.no_grad()
def maybe_gif(
    step_idx: int,
    *,
    gif_every: int,
    gif_p0_idx: int,
    particles_gif: Optional[int],
    gif_frame_skip: int,
    gif_fps: int,
    substeps_per_dt: int,
    integrator_name: str,
    max_force: Optional[float],
    model: torch.nn.Module,
    X_em: torch.Tensor,               # (num_p0,N,T+1,d)
    time_grid: torch.Tensor,          # (T+1,)
    dt_base: float,
    vel_provider,
    vel_mode: str,
    V_em: Optional[torch.Tensor],
    friction: Any,
    outdir: Path,
    device: torch.device,
    wb_run: Optional[Any] = None,
) -> None:
    """
    Every ``gif_every`` steps, roll the model out and save a true-vs-learned GIF.

    Does nothing when ``gif_every <= 0`` or ``step_idx`` is not a multiple of
    ``gif_every``. Otherwise the model is put in eval mode, trajectories are
    integrated from the ground-truth initial positions of ensemble
    ``gif_p0_idx``, the result is written to
    ``outdir / compare_step{step_idx:07d}.gif`` via ``make_compare_gif``, and
    the model's previous train/eval mode is restored (also on error).

    Args:
        step_idx: Current training step; selects whether to render and names
            the output file.
        gif_every: Rendering period in steps; non-positive disables rendering.
        gif_p0_idx: Index into the first axis of ``X_em`` / ``V_em``.
        particles_gif: If set and smaller than the number of valid particles,
            that many distinct particles are drawn at random.
        gif_frame_skip, gif_fps: Forwarded to ``make_compare_gif``.
        substeps_per_dt: Integrator micro-steps per ``dt_base``.
        integrator_name: Passed to ``pick_integrator``.
        max_force: Passed to ``make_accel_from_potential``.
        model: Potential model; evaluated in eval mode.
        X_em: Ground-truth positions, indexed as ``X_em[p0][:, t, :]``.
        time_grid: Time stamps; ``time_grid[0]`` is the rollout start time.
        dt_base: Time between consecutive ground-truth frames.
        vel_provider: Callable ``(p0_idx, x0, t0) -> v0`` used unless
            ``vel_mode`` is ``"zero"`` or ``"bundle"``; ``None`` means zero
            initial velocity.
        vel_mode: ``"zero"``, ``"bundle"`` (read ``V_em``), or anything else
            to call ``vel_provider``. Case-insensitive.
        V_em: Velocities indexed like ``X_em``; required for ``"bundle"``.
        friction: Passed through to the integrator.
        outdir: Directory for the GIF.
        device: Kept for interface compatibility; index tensors are created
            on the device of the tensor they index.
        wb_run: If not ``None``, the GIF is also logged with ``wandb.log``.

    Raises:
        ValueError: If ``vel_mode`` is ``"bundle"`` and ``V_em`` is ``None``.
    """
    if int(gif_every) <= 0:
        return
    if int(step_idx) % int(gif_every) != 0:
        return

    # Remember the caller's mode so it can be restored, not forced to train.
    was_training = model.training
    model.eval()
    try:
        p0_idx = int(gif_p0_idx)
        X_gt = X_em[p0_idx]  # (N,T+1,d)
        x0 = X_gt[:, 0, :].detach()

        # Drop particles whose initial position is not finite; the original
        # author notes that NaNs here make the model crash or emit all-NaN.
        valid_mask = torch.isfinite(x0).all(dim=-1)
        X_gt = X_gt[valid_mask]
        x0 = x0[valid_mask]

        # Subsample from the valid set without replacement, so no particle is
        # plotted twice. Indices live on the device of the tensor they index.
        idx: Optional[torch.Tensor] = None
        if particles_gif is not None and int(particles_gif) < x0.shape[0]:
            idx = torch.randperm(x0.shape[0], device=x0.device)[: int(particles_gif)]
            X_gt = X_gt[idx]
            x0 = x0[idx]

        t0 = float(time_grid[0].item())
        vel_mode_l = str(vel_mode).lower()

        if vel_provider is None or vel_mode_l == "zero":
            v0 = torch.zeros_like(x0)
        elif vel_mode_l == "bundle":
            if V_em is None:
                raise ValueError("maybe_gif: vel_mode='bundle' requires V_em loaded from bundle.")
            # Time index closest to t0 (argmin is always a valid index).
            m = int(torch.argmin((time_grid - t0).abs()).item())
            v0 = V_em[p0_idx, :, m, :]
            # Apply the same NaN mask and subsample as for the positions.
            v0 = v0[valid_mask.to(v0.device)]
            if idx is not None:
                v0 = v0[idx.to(v0.device)]
            v0 = v0.to(device=x0.device, dtype=x0.dtype)
        else:
            v0 = vel_provider(p0_idx, x0, t0).detach()

        substeps = int(substeps_per_dt)
        steps_macro = X_gt.shape[1] - 1

        integrator, _ = pick_integrator(str(integrator_name))
        # NOTE(review): this runs under torch.no_grad(); presumably
        # make_accel_from_potential re-enables grad internally to
        # differentiate the potential — confirm.
        accel_eval = make_accel_from_potential(model, create_graph=False, max_force=max_force)

        X_pred = integrator(
            x0=x0,
            v0=v0,
            accel=accel_eval,
            dt=float(dt_base) / substeps,
            steps=steps_macro * substeps,
            friction=friction,
            return_all=True,
            t_start=t0,
        )

        # Keep every `substeps`-th state so predictions line up with X_gt.
        # Assumes X_pred is (N, steps + 1, d) — TODO confirm against integrator.
        macro_idx = torch.arange(0, steps_macro + 1, device=X_pred.device) * substeps
        X_pred_macro = X_pred[:, macro_idx, :]

        gif_path = str(Path(outdir) / f"compare_step{int(step_idx):07d}.gif")
        make_compare_gif(
            X_true=X_gt.detach().cpu(),
            X_learned=X_pred_macro.detach().cpu(),
            dt=float(dt_base),
            times=time_grid.detach().cpu().numpy(),
            save_path=gif_path,
            frame_skip=int(gif_frame_skip),
            fps=int(gif_fps),
            always_show=False,
            projection="auto",
            render="auto",
            subsample=int(particles_gif) if particles_gif is not None else None,
        )

        if wb_run is not None:
            import wandb
            wandb.log({"gif": wandb.Video(gif_path, fps=int(gif_fps), format="gif")}, step=int(step_idx))
    finally:
        model.train(was_training)