from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any, Callable

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from ott.geometry import pointcloud, costs
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

from pipeline.dataloader.dataset import DynamicalSystemDataset
from pipeline.dataloader.transforms import add_noise, crop, normalize, compose
from pipeline.eval.eval_utils import _load_checkpoint, _pick_checkpoint_dir, init_state
from pipeline.training.runtime_resolve import resolve_runtime
from pipeline.training.train_configs import TrainConfig


def _load_train_config(path: Path) -> TrainConfig:
    """Build a ``TrainConfig`` from a JSON file, ignoring unknown keys.

    Args:
        path: Location of the JSON training config.

    Returns:
        A ``TrainConfig`` constructed from the JSON entries whose keys match
        a name in ``TrainConfig.__dataclass_fields__``; every other entry in
        the file is dropped.
    """
    with open(path, "r", encoding="utf-8") as handle:
        payload: dict[str, Any] = json.load(handle)
    # ``__dataclass_fields__`` is keyed by field name, so its key set is the
    # whitelist of accepted entries.
    known_names = set(TrainConfig.__dataclass_fields__)
    kwargs = {key: value for key, value in payload.items() if key in known_names}
    return TrainConfig(**kwargs)


def _pick_ckpt_dir(
    experiment_dir: Path | None,
    checkpoint_dir: Path | None,
    checkpoint_name: str | None,
) -> Path:
    if checkpoint_dir is not None:
        ckpt_dir = checkpoint_dir
    else:
        if experiment_dir is None:
            raise ValueError("Provide --experiment_dir or --checkpoint_dir.")
        ckpt_root = experiment_dir / "checkpoints"
        ckpt_dir = _pick_checkpoint_dir(ckpt_root, checkpoint_name)
    if not (ckpt_dir / "checkpoint.pkl").is_file():
        raise FileNotFoundError(f"checkpoint.pkl not found in: {ckpt_dir}")
    return ckpt_dir


def _build_transforms(
    *,
    noise_level: float,
    crop_window_size: int,
    do_normalize: bool,
) -> Callable[..., Any] | None:
    transforms = []
    if noise_level > 0.0:
        transforms.append(add_noise(noise_level=noise_level))
    if crop_window_size and crop_window_size > 0:
        transforms.append(crop(window_size=crop_window_size))
    if do_normalize:
        transforms.append(normalize())
    if not transforms:
        return None
    return compose(transforms)


def _normalize_embeddings(z: jnp.ndarray, eps: float) -> jnp.ndarray:
    mean = jnp.mean(z, axis=(0, 1), keepdims=True)
    std = jnp.std(z, axis=(0, 1), keepdims=True)
    return (z - mean) / (std + eps)


def _compute_transport_matrix(
    z_hat: jnp.ndarray,
    z_true: jnp.ndarray,
    *,
    epsilon: float,
    max_iters: int,
    threshold: float,
    normalize_embeddings: bool,
    normalize_eps: float,
) -> tuple[jnp.ndarray, Any]:
    """Solve an entropic OT problem between two point clouds with Sinkhorn.

    Both clouds receive uniform marginals and are compared with the squared
    Euclidean cost.

    Args:
        z_hat: Source points, one per row; a 1-D array is treated as a single
            column.
        z_true: Target points, one per row; a 1-D array is treated as a
            single column.
        epsilon: Entropic regularization passed to the point-cloud geometry.
        max_iters: Maximum number of Sinkhorn iterations.
        threshold: Convergence threshold passed to the solver.
        normalize_embeddings: When true, each cloud is standardized
            independently with ``_normalize_embeddings`` before solving.
        normalize_eps: Epsilon forwarded to ``_normalize_embeddings``.

    Returns:
        A pair ``(matrix, out)`` where ``out`` is the solver output and
        ``matrix`` is ``out.matrix``; rows correspond to ``z_hat`` and
        columns to ``z_true``.
    """
    if z_hat.ndim == 1:
        z_hat = z_hat[:, None]
    if z_true.ndim == 1:
        z_true = z_true[:, None]

    if normalize_embeddings:
        z_hat = _normalize_embeddings(z_hat, eps=normalize_eps)
        z_true = _normalize_embeddings(z_true, eps=normalize_eps)

    # Each marginal must match the length of its own cloud. Sizing both from
    # ``z_hat`` only works when the two clouds happen to have the same number
    # of points.
    n_source = z_hat.shape[0]
    n_target = z_true.shape[0]
    a = jnp.full((n_source,), 1.0 / n_source, dtype=z_hat.dtype)
    b = jnp.full((n_target,), 1.0 / n_target, dtype=z_true.dtype)

    geom = pointcloud.PointCloud(
        z_hat,
        z_true,
        epsilon=epsilon,
        cost_fn=costs.SqEuclidean(),
    )
    prob = linear_problem.LinearProblem(geom, a=a, b=b)
    solver = sinkhorn.Sinkhorn(
        lse_mode=True,
        threshold=threshold,
        max_iterations=max_iters,
    )
    out = solver(prob)
    return out.matrix, out


def _svd_coords(data: np.ndarray, n_components: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    n_components = min(n_components, data.shape[1])
    u, s, vt = np.linalg.svd(data, full_matrices=False)
    coords = u[:, :n_components] * s[:n_components][None, :]
    return coords, s, vt


def _pairwise_stats(x: np.ndarray) -> dict[str, float]:
    if x.ndim == 1:
        x = x[:, None]
    diff = x[:, None, :] - x[None, :, :]
    dists = np.linalg.norm(diff, axis=-1)
    iu = np.triu_indices(dists.shape[0], k=1)
    vals = dists[iu]
    if vals.size == 0:
        return {"mean": float("nan"), "std": float("nan"), "min": float("nan"), "max": float("nan"), "median": float("nan")}
    return {
        "mean": float(vals.mean()),
        "std": float(vals.std()),
        "min": float(vals.min()),
        "max": float(vals.max()),
        "median": float(np.median(vals)),
    }


def _matrix_stats(mat: np.ndarray) -> dict[str, float]:
    return {
        "mean": float(mat.mean()),
        "std": float(mat.std()),
        "min": float(mat.min()),
        "max": float(mat.max()),
    }


def _effective_linear_map(
    summary_type: str,
    summary_params: Any,
    input_dim: int,
    u_true: np.ndarray,
    z_true: np.ndarray,
) -> tuple[np.ndarray | None, np.ndarray | None, str | None]:
    if summary_type == "linear":
        W_eff, b_eff = summary_params[0]
        for W, b in summary_params[1:]:
            W_eff = W_eff @ W
            b_eff = b_eff @ W + b
        return np.asarray(W_eff), np.asarray(b_eff), "exact_linear"
    if summary_type == "mlp":
        # Approximate effective linear map via least squares on current trajectory.
        u2 = u_true.reshape(-1, input_dim)
        z2 = z_true.reshape(u2.shape[0], -1)
        U = np.concatenate([u2, np.ones((u2.shape[0], 1))], axis=1)
        theta, *_ = np.linalg.lstsq(U, z2, rcond=None)
        W_eff = theta[:-1, :]
        b_eff = theta[-1, :]
        return W_eff, b_eff, "approx_linear_fit"
    return None, None, None


def _plot_transport_map(
    u_true: np.ndarray,
    u_mapped: np.ndarray,
    u_hat: np.ndarray,
    out_path: Path,
) -> None:
    """Scatter the true, transported and predicted points and save the figure.

    The layout depends on the number of columns of ``u_true``: three or more
    columns give a 3-D scatter of the first three, exactly two give a 2-D
    scatter, and anything else plots the first column against the row index.

    Args:
        u_true: Reference points, one per row.
        u_mapped: Transported points, one per row.
        u_hat: Predicted points, one per row.
        out_path: Destination image path.
    """
    # Drawn in this order so layering and legend order stay fixed.
    series = (
        (u_true, "green", "u_true"),
        (u_mapped, "orange", "transported"),
        (u_hat, "purple", "u_hat"),
    )
    marker = {"s": 16, "alpha": 0.7}
    n_dims = u_true.shape[1]

    if n_dims >= 3:
        fig = plt.figure(figsize=(7, 6))
        ax = fig.add_subplot(111, projection="3d")
        for points, color, label in series:
            ax.scatter(
                points[:, 0],
                points[:, 1],
                points[:, 2],
                c=color,
                label=label,
                **marker,
            )
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("z")
    elif n_dims == 2:
        fig, ax = plt.subplots(figsize=(6, 5))
        for points, color, label in series:
            ax.scatter(points[:, 0], points[:, 1], c=color, label=label, **marker)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
    else:
        fig, ax = plt.subplots(figsize=(6, 4))
        for points, color, label in series:
            steps = np.arange(points.shape[0])
            ax.scatter(steps, points[:, 0], c=color, label=label, **marker)
        ax.set_xlabel("time")
        ax.set_ylabel("value")

    ax.set_title("Transport map: u_true vs transported")
    ax.legend(frameon=False)
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


def _write_results_txt(
    out_path: Path,
    *,
    transport_proj: np.ndarray,
    transport_v1: np.ndarray | None,
    transport_singular_values: np.ndarray | None,
    weff_dir: np.ndarray | None,
    effective_summary_line: str | None,
    summary_type: str,
) -> None:
    transport_str = np.array2string(
        transport_proj, precision=6, suppress_small=True, threshold=200
    )
    lines = []
    if effective_summary_line is not None:
        lines.append(f"{effective_summary_line}\n")
    else:
        lines.append(f"Effective linear summary: N/A (summary_type={summary_type})\n")
    if transport_v1 is not None:
        v1_str = np.array2string(transport_v1, precision=6, suppress_small=True, threshold=200)
        lines.append("transport_svd_v1:\n")
        lines.append(v1_str + "\n")
    if transport_singular_values is not None:
        sv_str = np.array2string(transport_singular_values, precision=6, suppress_small=True, threshold=200)
        lines.append("transport_svd_singular_values:\n")
        lines.append(sv_str + "\n")
    if transport_v1 is not None and weff_dir is not None:
        v1n = transport_v1 / (np.linalg.norm(transport_v1) + 1e-8)
        wn = weff_dir / (np.linalg.norm(weff_dir) + 1e-8)
        cos = float(abs(np.dot(v1n, wn)))
        lines.append(f"abs_cosine(v1, weff_dir): {cos:.6f}\n")
    lines.append("transport_svd_coords:\n")
    lines.append(transport_str + "\n")
    with open(out_path, "w", encoding="utf-8") as f:
        f.writelines(lines)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the OT transport-map script."""
    parser = argparse.ArgumentParser(
        description="Compute OT transport matrix for a checkpoint + dataset sample."
    )
    parser.add_argument(
        "--experiment_dir",
        type=Path,
        help="Experiment dir containing train_config.json and checkpoints/.",
    )
    parser.add_argument(
        "--train_config",
        type=Path,
        help="Path to train_config.json (required if --experiment_dir is omitted).",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=Path,
        help="Path to a checkpoint directory containing checkpoint.pkl.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="final",
        help="Checkpoint name under experiment_dir/checkpoints (e.g., final, epoch_0009). Use 'auto' to pick latest.",
    )
    parser.add_argument(
        "--data_path",
        type=Path,
        required=True,
        help="Path to dataset .npz file.",
    )
    parser.add_argument("--sample_index", type=int, default=0)
    parser.add_argument(
        "--ot_horizon",
        type=int,
        default=None,
        help="Optional OT horizon (time steps). Defaults to train_config.ot_horizon.",
    )
    parser.add_argument(
        "--normalize",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Apply per-trajectory normalization before model rollout.",
    )
    parser.add_argument("--noise_level", type=float, default=0.0)
    parser.add_argument("--crop_window_size", type=int, default=0)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--output_dir",
        type=Path,
        default=Path("ot_outputs"),
        help="Directory to write outputs (plot, text).",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Optional results text path (.txt). Overrides output_dir/ot_results.txt.",
    )
    parser.add_argument(
        "--plot_path",
        type=Path,
        default=None,
        help="Optional plot path (.png). Defaults to output_dir/transport_map.png.",
    )
    parser.add_argument(
        "--normalize_embeddings",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Normalize summary embeddings before OT (matches sinkhorn_distance default).",
    )
    parser.add_argument("--normalize_eps", type=float, default=1e-6)
    return parser


def _print_diagnostics(
    *,
    transport_np: np.ndarray,
    u_true_np: np.ndarray,
    u_hat_np: np.ndarray,
    u_mapped: np.ndarray,
    transport_vec: np.ndarray,
) -> None:
    """Print norm, plan-uniformity and pairwise-distance diagnostics to stdout."""
    norms_u = np.linalg.norm(u_mapped, axis=1)
    norms_true = np.linalg.norm(u_true_np, axis=1)
    norms_diff = np.linalg.norm(transport_vec, axis=1)
    print(
        "[OK] transport stats: "
        f"max|transported|={norms_u.max():.6f}, "
        f"mean|transported|={norms_u.mean():.6f}, "
        f"max|u_true|={norms_true.max():.6f}, "
        f"mean|u_true|={norms_true.mean():.6f}, "
        f"max|transported - u_true|={norms_diff.max():.6f}, "
        f"mean|transported - u_true|={norms_diff.mean():.6f}"
    )

    # Compare the plan against the constant matrix with entries 1 / t^2,
    # where t is the number of rows of the plan.
    t = transport_np.shape[0]
    uniform_val = 1.0 / (t * t)
    plan_stats = _matrix_stats(transport_np)
    uniform_l1 = float(np.mean(np.abs(transport_np - uniform_val)))
    uniform_norm = np.linalg.norm(np.full_like(transport_np, uniform_val))
    uniform_l2 = float(np.linalg.norm(transport_np - uniform_val) / (uniform_norm + 1e-8))
    row_sums = transport_np.sum(axis=1)
    col_sums = transport_np.sum(axis=0)
    print(
        "[OK] P uniformity: "
        f"mean={plan_stats['mean']:.6f}, std={plan_stats['std']:.6f}, "
        f"min={plan_stats['min']:.6f}, max={plan_stats['max']:.6f}, "
        f"L1_to_uniform={uniform_l1:.6f}, relL2_to_uniform={uniform_l2:.6f}, "
        f"row_sum_range=({row_sums.min():.6f},{row_sums.max():.6f}), "
        f"col_sum_range=({col_sums.min():.6f},{col_sums.max():.6f})"
    )

    stats_true = _pairwise_stats(u_true_np)
    stats_hat = _pairwise_stats(u_hat_np)
    cost = np.sum((u_hat_np[:, None, :] - u_true_np[None, :, :]) ** 2, axis=-1)
    stats_cost = _matrix_stats(cost)
    print(
        "[OK] state distances: "
        f"u_true pairwise mean={stats_true['mean']:.6f}, std={stats_true['std']:.6f}, "
        f"min={stats_true['min']:.6f}, max={stats_true['max']:.6f}; "
        f"u_hat pairwise mean={stats_hat['mean']:.6f}, std={stats_hat['std']:.6f}, "
        f"min={stats_hat['min']:.6f}, max={stats_hat['max']:.6f}; "
        f"cost(u_hat,u_true) mean={stats_cost['mean']:.6f}, std={stats_cost['std']:.6f}, "
        f"min={stats_cost['min']:.6f}, max={stats_cost['max']:.6f}"
    )


def main() -> None:
    """Compute, plot and report the OT transport map for one dataset sample.

    Loads the training config and checkpoint, rolls the emulator out on one
    (optionally transformed) trajectory, solves a Sinkhorn OT problem between
    the true states and the rollout, then writes a scatter plot and a text
    report and prints diagnostics.

    Raises:
        ValueError: If neither ``--experiment_dir`` nor ``--train_config`` is
            given, if ``--output`` does not end in ``.txt``, if the configured
            distance type is not ``sinkhorn``, or if the resolved horizon is
            smaller than one time step.
        IndexError: If ``--sample_index`` is outside the dataset.
    """
    args = _build_arg_parser().parse_args()

    if args.experiment_dir is None and args.train_config is None:
        raise ValueError("Provide --experiment_dir or --train_config.")
    # Reject a bad --output before any expensive work (checkpoint load,
    # rollout, Sinkhorn) is done.
    if args.output is not None and args.output.suffix != ".txt":
        raise ValueError("Output must end with .txt")

    train_cfg_path = (
        args.train_config
        if args.train_config is not None
        else args.experiment_dir / "train_config.json"
    )
    train_cfg = _load_train_config(train_cfg_path)

    if train_cfg.distance_config.get("type") != "sinkhorn":
        raise ValueError(
            "Transport matrix is only defined for sinkhorn OT. "
            f"Found distance type: {train_cfg.distance_config.get('type')}"
        )

    # 'auto' is forwarded as None so the checkpoint picker chooses for us.
    ckpt_name = None if args.checkpoint == "auto" else args.checkpoint
    ckpt_dir = _pick_ckpt_dir(args.experiment_dir, args.checkpoint_dir, ckpt_name)
    ckpt = _load_checkpoint(ckpt_dir)
    state = init_state(ckpt)

    runtime = resolve_runtime(train_cfg)

    dataset = DynamicalSystemDataset(data_path=str(args.data_path))
    if args.sample_index < 0 or args.sample_index >= len(dataset):
        raise IndexError(
            f"sample_index out of range: {args.sample_index} (len={len(dataset)})"
        )

    traj = dataset.get_traj(args.sample_index)
    transforms = _build_transforms(
        noise_level=args.noise_level,
        crop_window_size=args.crop_window_size,
        do_normalize=args.normalize,
    )
    if transforms is not None:
        rng = jax.random.PRNGKey(args.seed)
        traj = transforms(traj, rng)

    traj = jnp.asarray(traj, dtype=runtime.dtype)
    # Add a leading axis of size 1 for the runtime calls below.
    traj = traj[None, ...]

    # Horizon precedence: CLI flag, then train config, then the full
    # trajectory. A config value of None is treated like a missing attribute.
    horizon = args.ot_horizon
    if horizon is None:
        horizon = getattr(train_cfg, "ot_horizon", None)
    if horizon is None:
        horizon = traj.shape[1]
    horizon = min(int(horizon), traj.shape[1])
    if horizon < 1:
        raise ValueError(f"OT horizon must be at least 1 time step, got {horizon}.")

    u_true = traj[:, :horizon, :]
    u_hat = runtime.rollout_ot(state.emulator.params, u_true)

    z_true = runtime.summary_apply(state.summary.params, u_true)[0]

    # NOTE: this path is relative to the current working directory.
    distance_cfg_path = Path(
        f"configs/distance/{train_cfg.distance_config['type']}_config.json"
    )
    with open(distance_cfg_path, "r", encoding="utf-8") as f:
        dist_cfg = json.load(f)

    # u_true is passed as the source cloud and u_hat as the target, so rows
    # of the plan index u_true and columns index u_hat.
    # NOTE(review): the plan is solved on the raw states rather than on the
    # summary output ``z_true`` computed above — confirm this is intended
    # given the --normalize_embeddings help text.
    transport, out = _compute_transport_matrix(
        u_true[0],
        u_hat[0],
        epsilon=float(dist_cfg["epsilon"]),
        max_iters=int(dist_cfg["max_iters"]),
        threshold=float(dist_cfg["threshold"]),
        normalize_embeddings=args.normalize_embeddings,
        normalize_eps=args.normalize_eps,
    )

    transport_np = np.asarray(transport)
    u_true_np = np.asarray(u_true[0])
    u_hat_np = np.asarray(u_hat[0])
    # Plan-weighted average of u_hat for every u_true row; rows with zero
    # mass are divided by 1 instead to avoid a division by zero.
    row_sums = transport_np.sum(axis=1, keepdims=True)
    row_sums = np.where(row_sums == 0.0, 1.0, row_sums)
    u_mapped = (transport_np @ u_hat_np) / row_sums
    transport_vec = u_mapped - u_true_np
    transport_svd_coords, transport_svd_s, transport_svd_vt = _svd_coords(
        transport_vec, n_components=3
    )

    summary_type = train_cfg.summary_config.get("type", "identity")
    z_true_np = np.asarray(z_true)
    z_true_2d = z_true_np[:, None] if z_true_np.ndim == 1 else z_true_np
    weff, beff, _ = _effective_linear_map(
        summary_type=summary_type,
        summary_params=state.summary.params,
        input_dim=u_true_np.shape[1],
        u_true=u_true_np,
        z_true=z_true_2d,
    )

    effective_summary_line = None
    weff_dir = None
    if weff is not None:
        # Compare against the first output column of the effective map.
        if weff.ndim == 1:
            weff_dir = weff
        elif weff.ndim == 2 and weff.shape[1] >= 1:
            weff_dir = weff[:, 0]
    if weff is not None and beff is not None and summary_type == "linear":
        # The closed-form line is only written for a scalar summary of a
        # state with at least three components.
        if weff.ndim == 2 and weff.shape[1] == 1 and weff.shape[0] >= 3:
            w = weff[:, 0]
            b = beff[0]
            effective_summary_line = (
                f"Effective linear summary: s = {w[0]:.3f}*x + {w[1]:.3f}*y + "
                f"{w[2]:.3f}*z + {b:.3f}"
            )

    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    plot_path = (
        args.plot_path if args.plot_path is not None else output_dir / "transport_map.png"
    )
    results_path = args.output if args.output is not None else output_dir / "ot_results.txt"
    # Explicit overrides may point outside output_dir; make sure their parent
    # directories exist before writing.
    plot_path.parent.mkdir(parents=True, exist_ok=True)
    results_path.parent.mkdir(parents=True, exist_ok=True)

    _plot_transport_map(u_true=u_true_np, u_mapped=u_mapped, u_hat=u_hat_np, out_path=plot_path)

    _write_results_txt(
        results_path,
        transport_proj=transport_svd_coords,
        transport_v1=transport_svd_vt[0] if transport_svd_vt.shape[0] > 0 else None,
        transport_singular_values=transport_svd_s,
        weff_dir=weff_dir,
        effective_summary_line=effective_summary_line,
        summary_type=summary_type,
    )

    print(f"[OK] transport matrix shape: {transport_np.shape}")
    print(f"[OK] reg_ot_cost: {float(out.reg_ot_cost)}")
    print(f"[OK] transport map plot saved to: {plot_path}")
    print(f"[OK] results text saved to: {results_path}")
    _print_diagnostics(
        transport_np=transport_np,
        u_true_np=u_true_np,
        u_hat_np=u_hat_np,
        u_mapped=u_mapped,
        transport_vec=transport_vec,
    )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
