import argparse
import math
import os
import sys
from typing import Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt

# Absolute path of the parent of the directory that contains this script.
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# Append BASE_DIR to sys.path (only once) so the `train` import below resolves;
# presumably train.py lives directly in BASE_DIR — confirm if the layout changes.
if BASE_DIR not in sys.path:
    sys.path.append(BASE_DIR)

from train import JointDynamicsModel, resolve_device


def rk4_step(model, z, dt):
    """Advance ``z`` by one classical fourth-order Runge-Kutta step of size ``dt``.

    Args:
        model: Callable vector field evaluated as ``model(state)``; its output is
            combined with ``z`` by elementwise arithmetic, so it must broadcast with ``z``.
        z: Current state tensor.
        dt: Step size (Python number or tensor); converted to ``z``'s device and dtype.

    Returns:
        The state after one RK4 step.
    """
    h = torch.as_tensor(dt, device=z.device, dtype=z.dtype)
    half_h = 0.5 * h

    slope_a = model(z)
    slope_b = model(z + half_h * slope_a)
    slope_c = model(z + half_h * slope_b)
    slope_d = model(z + h * slope_c)

    weighted = slope_a + 2.0 * slope_b + 2.0 * slope_c + slope_d
    return z + (h / 6.0) * weighted


def load_test_inputs(
    trajectories_path: str | None,
    macro_feature_path: str | None,
) -> Tuple[np.ndarray, np.ndarray]:
    
    

    trajectories = np.load(trajectories_path, allow_pickle=True)
    macro_feature = np.load(macro_feature_path, allow_pickle=True)

    if trajectories.ndim != 4:
        raise ValueError(f"trajectories must be 4D, got shape={trajectories.shape}")
    if macro_feature.ndim == 2:
        macro_feature = macro_feature[..., None]
    if macro_feature.ndim != 3:
        raise ValueError(f"macro_feature must be 3D, got shape={macro_feature.shape}")

    if trajectories.shape[0] != macro_feature.shape[0]:
        raise ValueError("trajectories and macro_feature must have same n_traj")
    if trajectories.shape[1] != macro_feature.shape[1]:
        raise ValueError("trajectories and macro_feature must have same T")

    return trajectories.astype(np.float32), macro_feature.astype(np.float32)


def normalize_trajectories(
    trajectories: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
) -> np.ndarray:
    """Standardize ``trajectories`` as ``(trajectories - mean) / std``.

    ``mean`` and ``std`` are combined with ``trajectories`` by ordinary broadcasting.
    There is no guard against zero entries in ``std``.
    """
    shifted = trajectories - mean
    scaled = shifted / std
    return scaled


def compute_z_over_time(
    encoder: torch.nn.Module,
    trajectories: np.ndarray,
    device: torch.device,
    batch_traj: int,
) -> np.ndarray:
    """Encode every time step of every trajectory, processing ``batch_traj`` trajectories at a time.

    The encoder is switched to eval mode and run under ``torch.no_grad``. Each chunk
    of shape (b, T, n_particles, dim) is flattened to (b*T, n_particles, dim) for the
    encoder. The output is reshaped back to (b, T, -1).

    Returns:
        A NumPy array of shape (n_traj, T, latent_size), where latent_size is whatever
        the encoder produces per sample.
    """
    encoder.eval()
    n_traj, T, n_particles, dim = trajectories.shape

    def _encode_block(block: np.ndarray) -> np.ndarray:
        block_t = torch.from_numpy(block).float().to(device)
        with torch.no_grad():
            latent = encoder(block_t.reshape(-1, n_particles, dim))
        return latent.reshape(block_t.shape[0], T, -1).cpu().numpy()

    latents = [
        _encode_block(trajectories[lo:lo + batch_traj])
        for lo in range(0, n_traj, batch_traj)
    ]
    return np.concatenate(latents, axis=0)


def build_model_from_checkpoint(
    checkpoint_path: str,
    data_dim: int,
    macro_dim: int,
    device: torch.device,
    fallback_args: dict,
) -> Tuple[JointDynamicsModel, dict, dict | None]:
    """Rebuild a ``JointDynamicsModel`` from a checkpoint file and load its weights.

    The hyper-parameters ``z_dim``, ``encoder_hidden_dim``, ``drift_hidden_dim`` and
    ``pool`` come from the checkpoint's ``"args"`` entry when present. Otherwise they
    come from ``fallback_args``, which must contain all four keys.

    Args:
        checkpoint_path: Path passed to ``torch.load``.
        data_dim: Passed as ``in_dim`` to the model constructor.
        macro_dim: Passed as ``macro_dim`` to the model constructor.
        device: Map location for loading and target device for the model.
        fallback_args: Hyper-parameter defaults (see above).

    Returns:
        ``(model, ckpt_args, normalization_info)``. ``model`` is in eval mode on
        ``device``. ``ckpt_args`` is the checkpoint's args as a dict (empty if absent).
        ``normalization_info`` is the checkpoint's ``"normalization_info"`` entry, or
        ``None`` if missing.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects and can execute code;
    # only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Tolerate an absent/None "args" entry and an argparse.Namespace (both previously
    # crashed on `.get`); a plain dict is used unchanged.
    ckpt_args = ckpt.get("args") or {}
    if isinstance(ckpt_args, argparse.Namespace):
        ckpt_args = dict(vars(ckpt_args))

    def _hparam(name: str):
        return ckpt_args.get(name, fallback_args[name])

    z_dim = _hparam("z_dim")
    encoder_hidden_dim = _hparam("encoder_hidden_dim")
    drift_hidden_dim = _hparam("drift_hidden_dim")
    pool = _hparam("pool")

    model = JointDynamicsModel(
        in_dim=data_dim,
        z_dim=z_dim,
        macro_dim=macro_dim,
        encoder_hidden_dim=encoder_hidden_dim,
        drift_hidden_dim=drift_hidden_dim,
        pool=pool,
    ).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    normalization_info = ckpt.get("normalization_info", None)
    return model, ckpt_args, normalization_info


def load_macro_normalization(macro_norm_path: str | None) -> tuple[np.ndarray | None, np.ndarray | None]:
    if macro_norm_path is None:
        return None, None
    if not os.path.exists(macro_norm_path):
        print(f"Macro normalization file not found: {macro_norm_path}")
        return None, None
    with np.load(macro_norm_path) as norm_data:
        if "Z_min" in norm_data and "Z_max" in norm_data:
            return norm_data["Z_min"], norm_data["Z_max"]
    print(f"Macro normalization file missing Z_min/Z_max: {macro_norm_path}")
    return None, None


def denormalize_macro(values: np.ndarray, macro_min: np.ndarray | None, macro_max: np.ndarray | None) -> np.ndarray:
    if macro_min is None or macro_max is None:
        return values
    macro_min_val = float(np.asarray(macro_min).reshape(-1)[0])
    macro_max_val = float(np.asarray(macro_max).reshape(-1)[0])
    scale = 0.5 * (macro_max_val - macro_min_val)
    return (values + 1.0) * scale + macro_min_val


def plot_true_vs_pred(
    true_series: np.ndarray,
    pred_series: np.ndarray,
    dt: float,
    output_path: str,
    file_name: str,
    ylabel_prefix: str,
) -> None:
    """Plot true (solid) vs. predicted (dashed) series, one panel per coordinate, and save the figure.

    Args:
        true_series: 3D array indexed ``[exp, step, coord]``.
        pred_series: Predictions indexed the same way as ``true_series``.
        dt: Step spacing; the x axis is ``arange(steps) * dt``.
        output_path: Directory the PNG is written to.
        file_name: PNG file name (saved at 300 dpi).
        ylabel_prefix: Panel y-labels are ``f"{ylabel_prefix}[{d}]"``.

    Panels are laid out in a 3-column grid; unused trailing panels are hidden. Each
    experiment's prediction reuses the colour of its true curve.
    """
    n_exp, steps, n_dim = true_series.shape
    t = np.arange(steps) * dt
    n_row = math.ceil(n_dim / 3)
    fig, axes = plt.subplots(n_row, 3, figsize=(12, 4 * n_row))
    # Close the figure even if drawing or saving fails, so it does not stay
    # registered in pyplot (previously leaked on a savefig error).
    try:
        axes = np.atleast_1d(axes).ravel()
        for d in range(n_dim):
            ax = axes[d]
            for exp_idx in range(n_exp):
                (true_line,) = ax.plot(t, true_series[exp_idx, :, d], label=f"true exp{exp_idx}")
                ax.plot(
                    t,
                    pred_series[exp_idx, :, d],
                    linestyle="--",
                    color=true_line.get_color(),
                    label=f"pred exp{exp_idx}",
                )
            ax.set_xlabel("time")
            ax.set_ylabel(f"{ylabel_prefix}[{d}]")

        # Every panel carries the same labels, so one legend on the first panel suffices.
        if n_exp > 0 and n_dim > 0:
            axes[0].legend(fontsize="small")

        for ax in axes[n_dim:]:
            ax.axis("off")

        fig.tight_layout()
        out_file = os.path.join(output_path, file_name)
        fig.savefig(out_file, dpi=300)
    finally:
        plt.close(fig)
    print(f"Saved comparison plot to {out_file}")


def plot_macro(
    true_series: np.ndarray,
    pred_series: np.ndarray,
    dt: float,
    output_path: str,
    file_name: str,
    macro_min: np.ndarray | None = None,
    macro_max: np.ndarray | None = None,
    max_steps: int = 100,
) -> None:
    """Plot the last coordinate of true vs. predicted series, and save the plot plus the plotted data.

    Args:
        true_series: 3D array indexed ``[exp, step, coord]``; only the last coordinate is plotted.
        pred_series: Predictions indexed the same way as ``true_series``.
        dt: Accepted for signature parity with the other plot helpers but not used.
            The x axis is the integer step index (NOTE: it is labelled "time").
        output_path: Directory for both output files.
        file_name: PNG file name (300 dpi). The plotted arrays are also saved next to it
            with an ``.npz`` suffix, under the keys ``true_macro``, ``pred_macro`` and ``t``.
        macro_min, macro_max: Optional bounds forwarded to ``denormalize_macro``; applied
            only when both are given.
        max_steps: Plot at most this many leading steps (default 100).
    """
    n_exp, steps, _ = true_series.shape
    # Clamp so `t` and the sliced series always have equal length; with the old fixed
    # 100, a series shorter than 100 steps made ax.plot reject mismatched x/y.
    time_steps = max(0, min(steps, max_steps))
    t = np.arange(time_steps)
    true_macro = true_series[:, :time_steps, -1]
    pred_macro = pred_series[:, :time_steps, -1]
    if macro_min is not None and macro_max is not None:
        true_macro = denormalize_macro(true_macro, macro_min, macro_max)
        pred_macro = denormalize_macro(pred_macro, macro_min, macro_max)

    out_file = os.path.join(output_path, file_name)
    npz_path = os.path.join(output_path, os.path.splitext(file_name)[0] + ".npz")

    fig, ax = plt.subplots(figsize=(6, 4))
    # Close the figure even if drawing or saving fails.
    try:
        for exp_idx in range(n_exp):
            (true_line,) = ax.plot(t, true_macro[exp_idx], label=f"true exp{exp_idx}")
            ax.plot(
                t,
                pred_macro[exp_idx],
                linestyle="--",
                color=true_line.get_color(),
                label=f"pred exp{exp_idx}",
            )
        ax.set_xlabel("time")
        ax.set_ylabel("macro")
        if n_exp > 0:
            ax.legend(fontsize="small")
        fig.tight_layout()
        fig.savefig(out_file, dpi=300)
    finally:
        plt.close(fig)

    np.savez(npz_path, true_macro=true_macro, pred_macro=pred_macro, t=t)
    print(f"Saved macro comparison plot to {out_file}")
    print(f"Saved macro trajectories to {npz_path}")


def plot_true_vs_pred_dzdt(
    true_series: np.ndarray,
    pred_dzdt: np.ndarray,
    output_path: str,
    file_name: str,
    ylabel_prefix: str,
) -> None:
    """Plot true (solid) vs. predicted (dashed) time derivatives, one panel per coordinate, and save the figure.

    Args:
        true_series: 3D array of reference derivatives indexed ``[exp, step, coord]``.
        pred_dzdt: Predicted derivatives indexed the same way as ``true_series``.
        output_path: Directory the PNG is written to.
        file_name: PNG file name (saved at 200 dpi).
        ylabel_prefix: Panel y-labels are ``f"{ylabel_prefix}[{d}]"``.

    The x axis is the integer step index. Panels use a 3-column grid and unused
    trailing panels are hidden.
    """
    n_exp, steps, n_dim = true_series.shape
    n_row = math.ceil(n_dim / 3)
    fig, axes = plt.subplots(n_row, 3, figsize=(12, 4 * n_row))
    # Close the figure even if drawing or saving fails.
    try:
        axes = np.atleast_1d(axes).ravel()
        t_idx = np.arange(steps)
        for d in range(n_dim):
            ax = axes[d]
            for exp_idx in range(n_exp):
                (true_line,) = ax.plot(t_idx, true_series[exp_idx, :, d], label=f"true dzdt exp{exp_idx}")
                ax.plot(
                    t_idx,
                    pred_dzdt[exp_idx, :, d],
                    linestyle="--",
                    color=true_line.get_color(),
                    label=f"pred dzdt exp{exp_idx}",
                )
            ax.set_xlabel("time step")
            ax.set_ylabel(f"{ylabel_prefix}[{d}]")

        # Every panel carries the same labels, so one legend on the first panel suffices.
        if n_exp > 0 and n_dim > 0:
            axes[0].legend(fontsize="small")

        for ax in axes[n_dim:]:
            ax.axis("off")

        fig.tight_layout()
        out_file = os.path.join(output_path, file_name)
        fig.savefig(out_file, dpi=200)
    finally:
        plt.close(fig)
    print(f"Saved dzdt comparison plot to {out_file}")


def compute_rollout_loss(
    model: JointDynamicsModel,
    trajectories: np.ndarray,
    macro_feature: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    dt: float,
    batch_traj: int,
    device: torch.device,
    eps: float = 1e-12,
) -> float | None:
    """Mean relative squared error of an RK4 rollout, measured on the last (macro) coordinate.

    Trajectories are normalized with ``mean`` / ``std`` and encoded with
    ``model.encoder``. The latents are concatenated with ``macro_feature`` to form the
    true state sequence. ``model.drift`` is then integrated from each trajectory's
    first state for T-1 RK4 steps of size ``dt``.

    For each trajectory the error is ``sum_t (pred - true)^2 / (sum_t true^2 + eps)``
    over the last coordinate. The result is the mean of this over all trajectories.

    Both encoding and rollout run in chunks of ``batch_traj`` trajectories, so device
    memory does not grow with the dataset size.

    Returns:
        The mean relative error, or ``None`` if there are no trajectories.
    """
    _, T, _, _ = trajectories.shape  # also enforces a 4D input

    if trajectories.shape[0] == 0:
        print("No validation trajectories available for rollout loss.")
        return None

    normed = normalize_trajectories(trajectories, mean, std)
    z_val = compute_z_over_time(
        encoder=model.encoder,
        trajectories=normed,
        device=device,
        batch_traj=batch_traj,
    )
    z_macro_true = np.concatenate([z_val, macro_feature], axis=-1)
    n_total = z_macro_true.shape[0]
    print(f"z_macro_true shape for rollout loss: {z_macro_true.shape}")

    rel_error_sum = 0.0
    with torch.no_grad():
        for start in range(0, n_total, batch_traj):
            chunk = z_macro_true[start:start + batch_traj]
            z_t = torch.from_numpy(chunk[:, 0, :]).float().to(device)
            # Keep only the last coordinate of each step; clone so the full state
            # tensor of earlier steps is not kept alive by a view.
            pred_steps = [z_t[:, -1].clone()]
            for _ in range(T - 1):
                z_t = rk4_step(model.drift, z_t, dt)
                pred_steps.append(z_t[:, -1].clone())
            pred_last = torch.stack(pred_steps, dim=1)

            # Macro features are unnormalized for baseline training, so no denorm here.
            true_last = torch.from_numpy(chunk[:, :, -1]).float().to(device)
            numerator = (pred_last - true_last).pow(2).sum(dim=1)
            denominator = true_last.pow(2).sum(dim=1)
            # Sum per chunk and divide by the total count at the end, so the result is
            # the mean over all trajectories, not a mean of chunk means.
            rel_error_sum += (numerator / (denominator + eps)).sum().item()

    rollout_loss = rel_error_sum / n_total
    print(f"Rollout mean relative error (last dim): {rollout_loss:.6e}")
    return rollout_loss


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate joint DeepSet encoder + drift MLP on validation data.")
    parser.add_argument("--data_dir", type=str, default=os.path.join(BASE_DIR, "generate_dataset", "data"))
    parser.add_argument(
        "--split",
        type=str,
        default="inDistribution",
        help=(
            "Dataset suffix: reads trajectories_<split>.npy and macro_feature_<split>.npy from --data_dir "
            "(e.g. inDistribution, outDistribution_3gmm, outDistribution_400N)."
        ),
    )
    parser.add_argument("--checkpoint_dir", type=str, default="trained_models_gmm2/exp1_seed123/")
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--first_K", type=int, default=5)
    parser.add_argument("--dt", type=float, default=0.002)
    parser.add_argument("--batch_traj", type=int, default=32)
    parser.add_argument("--device", type=int, default=3)
    # Previously store_true combined with set_defaults(True), so the flag could never be
    # False. It is now opt-in, which keeps the old effective default (no rollout loss).
    parser.add_argument(
        "--rollout_loss",
        action="store_true",
        help="Compute rollout mean relative error on validation trajectories.",
    )
    parser.add_argument(
        "--plot_all_dims",
        action="store_true",
        help="Also plot true vs. predicted trajectories for every latent/macro dimension.",
    )
    parser.add_argument(
        "--plot_dzdt",
        action="store_true",
        help="Also plot finite-difference dz/dt against the drift network's prediction.",
    )

    parser.add_argument("--z_dim", type=int, default=8)
    parser.add_argument("--encoder_hidden_dim", type=int, default=256)
    parser.add_argument("--drift_hidden_dim", type=int, default=64)
    parser.add_argument("--pool", type=str, default="mean", choices=["mean", "sum", "max"])
    return parser


# Rank fractions into the trajectories, sorted by initial macro value, that are plotted.
_PLOT_RANK_FRACTIONS = (0.2, 0.7, 0.98)


def _select_by_macro_rank(
    trajectories: np.ndarray,
    macro_feature: np.ndarray,
    fractions: tuple[float, ...],
) -> tuple[np.ndarray, np.ndarray, list[int]]:
    """Pick trajectories at the given rank fractions after sorting by ``macro_feature[:, 0, 0]``.

    Returns the selected trajectories, their macro features, and the selected positions
    within the sorted order. Fractions are expected to lie in [0, 1).
    """
    order = np.argsort(macro_feature[:, 0, 0])
    n_traj = order.shape[0]
    ranks = [int(n_traj * frac) for frac in fractions]
    chosen = order[ranks]
    return trajectories[chosen], macro_feature[chosen], ranks


def _rollout_rk4(drift, z0: np.ndarray, n_steps: int, dt: float, device: torch.device) -> np.ndarray:
    """Integrate ``drift`` from ``z0`` with ``n_steps - 1`` RK4 steps.

    Returns all states stacked along axis 1, starting with ``z0``.
    """
    with torch.no_grad():
        z_t = torch.from_numpy(z0).float().to(device)
        states = [z_t]
        for _ in range(n_steps - 1):
            z_t = rk4_step(drift, z_t, dt)
            states.append(z_t)
        return torch.stack(states, dim=1).cpu().numpy()


def main() -> None:
    """Evaluate a trained joint model on one dataset split and save comparison plots.

    Steps:
      1. Load the trajectory and macro-feature arrays for ``--split``.
      2. Rebuild the model from ``<checkpoint_dir>/best_joint_model.pth``.
      3. Select a few trajectories by rank of their initial macro value.
      4. Encode them and roll the drift forward with RK4.
      5. Plot the macro coordinate, plus optional extra plots and the rollout loss
         behind their flags.
    """
    args = _build_arg_parser().parse_args()
    device = resolve_device(args.device)
    print(f"Using device: {device}")

    trajectories_path = os.path.join(args.data_dir, f"trajectories_{args.split}.npy")
    macro_feature_path = os.path.join(args.data_dir, f"macro_feature_{args.split}.npy")
    trajectories, macro_feature = load_test_inputs(
        trajectories_path=trajectories_path,
        macro_feature_path=macro_feature_path,
    )
    n_traj, T, _, data_dim = trajectories.shape
    macro_dim = macro_feature.shape[-1]

    checkpoint_path = os.path.join(args.checkpoint_dir, "best_joint_model.pth")
    if args.output_dir is None:
        # normpath strips a trailing separator, so the default is always the checkpoint
        # directory itself. os.path.dirname gave the parent when the separator was missing.
        args.output_dir = os.path.normpath(args.checkpoint_dir)
    os.makedirs(args.output_dir, exist_ok=True)

    fallback_args = {
        "z_dim": args.z_dim,
        "encoder_hidden_dim": args.encoder_hidden_dim,
        "drift_hidden_dim": args.drift_hidden_dim,
        "pool": args.pool,
    }
    model, _, normalization_info = build_model_from_checkpoint(
        checkpoint_path=checkpoint_path,
        data_dim=data_dim,
        macro_dim=macro_dim,
        device=device,
        fallback_args=fallback_args,
    )

    macro_norm_path = os.path.join(args.data_dir, "macro_feature_normalization.npz")
    macro_min, macro_max = load_macro_normalization(macro_norm_path)

    if normalization_info is None:
        raise ValueError(
            f"Checkpoint {checkpoint_path} has no 'normalization_info'; cannot normalize trajectories."
        )
    mean = normalization_info["mean"]
    std = normalization_info["std"]

    # --first_K only guards against an empty selection; which trajectories are plotted
    # is fixed by _PLOT_RANK_FRACTIONS.
    if min(args.first_K, n_traj) == 0:
        print("No validation trajectories available for plotting.")
        return

    # Select first, then normalize only the few chosen trajectories.
    val_trajectories, val_macro, selected_idx = _select_by_macro_rank(
        trajectories, macro_feature, _PLOT_RANK_FRACTIONS
    )
    print(f"Selected trajectory indices for plotting: {selected_idx}")
    val_trajectories = normalize_trajectories(val_trajectories, mean, std)

    z_val = compute_z_over_time(
        encoder=model.encoder,
        trajectories=val_trajectories,
        device=device,
        batch_traj=args.batch_traj,
    )
    z_macro_true = np.concatenate([z_val, val_macro], axis=-1)
    pred_z_macro = _rollout_rk4(model.drift, z_macro_true[:, 0, :], n_steps=T, dt=args.dt, device=device)

    plot_macro(
        true_series=z_macro_true,
        pred_series=pred_z_macro,
        dt=args.dt,
        output_path=args.output_dir,
        file_name="TEST-true_vs_pred_macro.png",
        macro_min=macro_min,
        macro_max=macro_max,
    )

    if args.plot_all_dims:
        plot_true_vs_pred(
            true_series=z_macro_true,
            pred_series=pred_z_macro,
            dt=args.dt,
            output_path=args.output_dir,
            file_name="TEST-true_vs_pred_all_dims.png",
            ylabel_prefix="z_macro",
        )

    if args.plot_dzdt:
        # Use the actual number of selected trajectories. The earlier (commented-out)
        # version reshaped with min(first_K, n_traj), which did not match the selection.
        n_sel, _, n_dim = z_macro_true.shape
        z_in = torch.from_numpy(z_macro_true[:, :-1, :].reshape(-1, n_dim)).float().to(device)
        with torch.no_grad():
            dzdt_pred = model.drift(z_in).reshape(n_sel, T - 1, n_dim).cpu().numpy()
        dzdt_true = (z_macro_true[:, 1:, :] - z_macro_true[:, :-1, :]) / args.dt
        plot_true_vs_pred_dzdt(
            true_series=dzdt_true,
            pred_dzdt=dzdt_pred,
            output_path=args.output_dir,
            file_name="TEST-true_vs_pred_dzdt_all_dims.png",
            ylabel_prefix="dzdt",
        )

    if args.rollout_loss:
        compute_rollout_loss(
            model=model,
            trajectories=trajectories,
            macro_feature=macro_feature,
            mean=mean,
            std=std,
            dt=args.dt,
            batch_traj=args.batch_traj,
            device=device,
        )


# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
