import argparse
import math
import os
import sys
from pathlib import Path

import numpy as np
import torch
import yaml
import matplotlib.pyplot as plt

# Directory holding this script, and its parent directory. The parent is put
# on sys.path before the project-local imports that follow (presumably so
# `models` and `learn_sde` resolve when the script is run directly).
SCRIPT_DIR = Path(__file__).resolve().parent
MIX_DIR = SCRIPT_DIR.parent
if str(MIX_DIR) not in sys.path:
    # Insert at the front so this directory is searched first.
    sys.path.insert(0, str(MIX_DIR))

from models import SetEncoderND  # noqa: E402
from learn_sde.learn_mixing_dynamics import build_model, load_config  # noqa: E402


def _load_trajectories(npz_path: str) -> tuple[np.ndarray, np.ndarray]:
    """Read the ``trajectories`` and ``types`` arrays from an ``.npz`` archive.

    Args:
        npz_path: Path of the archive to open.

    Returns:
        ``(trajectories, types)`` as stored under those two keys.

    Raises:
        KeyError: If either key is absent from the archive.
    """
    # NOTE(review): allow_pickle=True permits unpickling object arrays, which
    # can run code embedded in the file; only load archives you trust.
    with np.load(npz_path, allow_pickle=True) as archive:
        # Indexing materialises each array, so both stay valid once the
        # archive has been closed.
        positions = archive["trajectories"]
        labels = archive["types"]
        return positions, labels


def _load_macro_feature(npy_path: str) -> np.ndarray:
    """Load the macro-feature array and verify it is 3-D with last axis 2.

    Args:
        npy_path: Path of the ``.npy`` file to read.

    Returns:
        The loaded array, unchanged.

    Raises:
        ValueError: If the array is not 3-D or its last axis is not length 2.
    """
    # NOTE(review): allow_pickle=True trusts the file contents.
    loaded = np.load(npy_path, allow_pickle=True)
    shape_ok = loaded.ndim == 3 and loaded.shape[2] == 2
    if shape_ok:
        return loaded
    raise ValueError(
        f"macro_feature must have shape (n_traj, T, 2), got {loaded.shape}"
    )


def _resolve_n_traj(requested: int | None, available: int) -> int:
    if requested is None or requested <= 0 or requested > available:
        return available
    return requested


def _normalize_macro_feature(
    macro_feature: np.ndarray, macro_min: np.ndarray, macro_max: np.ndarray
) -> np.ndarray:
    """Linearly map macro features from ``[macro_min, macro_max]`` to ``[-1, 1]``.

    Args:
        macro_feature: Values to rescale.
        macro_min: Lower bound(s); must broadcast against ``macro_feature``.
        macro_max: Upper bound(s); must broadcast against ``macro_feature``.

    Returns:
        ``2 * (macro_feature - macro_min) / (macro_max - macro_min) - 1``, with
        any zero range replaced by 1 so constant features map to ``-1``
        instead of dividing by zero.
    """
    scale = macro_max - macro_min
    # np.where (rather than in-place masked assignment) also works when the
    # bounds are Python floats, NumPy scalars or 0-d arrays.
    scale = np.where(scale == 0, 1.0, scale)
    return 2.0 * (macro_feature - macro_min) / scale - 1.0


def _denormalize_macro_feature(
    macro_feature: np.ndarray, macro_min: np.ndarray, macro_max: np.ndarray
) -> np.ndarray:
    """Map macro features from ``[-1, 1]`` back to ``[macro_min, macro_max]``.

    Args:
        macro_feature: Normalized values.
        macro_min: Lower bound(s); must broadcast against ``macro_feature``.
        macro_max: Upper bound(s); must broadcast against ``macro_feature``.

    Returns:
        ``0.5 * (macro_feature + 1) * (macro_max - macro_min) + macro_min``,
        with any zero range replaced by 1 (the same substitution the
        normalizers in this file make).
    """
    scale = macro_max - macro_min
    # np.where (rather than in-place masked assignment) also works when the
    # bounds are Python floats, NumPy scalars or 0-d arrays.
    scale = np.where(scale == 0, 1.0, scale)
    return 0.5 * (macro_feature + 1.0) * scale + macro_min


def _normalize_trajectories(
    trajectories: np.ndarray, traj_min: np.ndarray, traj_max: np.ndarray
) -> np.ndarray:
    """Linearly map trajectory values from ``[traj_min, traj_max]`` to ``[-1, 1]``.

    Args:
        trajectories: Values to rescale.
        traj_min: Lower bound(s); must broadcast against ``trajectories``.
        traj_max: Upper bound(s); must broadcast against ``trajectories``.

    Returns:
        ``2 * (trajectories - traj_min) / (traj_max - traj_min) - 1``, with any
        zero range replaced by 1 so constant coordinates map to ``-1`` instead
        of dividing by zero.
    """
    scale = traj_max - traj_min
    # np.where (rather than in-place masked assignment) also works when the
    # bounds are Python floats, NumPy scalars or 0-d arrays.
    scale = np.where(scale == 0, 1.0, scale)
    return 2.0 * (trajectories - traj_min) / scale - 1.0


def _to_numpy(value) -> np.ndarray:
    """Convert a tensor or array-like into a ``float64`` NumPy array.

    Args:
        value: A ``torch.Tensor`` or anything ``np.asarray`` accepts.

    Returns:
        The data as an ``np.ndarray`` with dtype ``float64``.
    """
    # Tensors are detached and moved to host memory before conversion.
    raw = value.detach().cpu().numpy() if isinstance(value, torch.Tensor) else value
    return np.asarray(raw, dtype=np.float64)


def _load_norm_from_state(
    state: dict,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] | None:
    norm = state.get("normalization")
    if not isinstance(norm, dict):
        return None
    required = ("macro_min", "macro_max", "traj_min", "traj_max")
    if not all(key in norm for key in required):
        return None
    macro_min = _to_numpy(norm["macro_min"]).reshape(1, 1, -1)
    macro_max = _to_numpy(norm["macro_max"]).reshape(1, 1, -1)
    traj_min = _to_numpy(norm["traj_min"]).reshape(1, 1, 1, -1)
    traj_max = _to_numpy(norm["traj_max"]).reshape(1, 1, 1, -1)
    return macro_min, macro_max, traj_min, traj_max


def _compute_z_per_type(
    set_encoder: SetEncoderND,
    anchors: torch.Tensor,
    types: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode every time step separately for type-1 and type-2 points.

    Args:
        set_encoder: Callable invoked as ``set_encoder(points, mask=...)``.
        anchors: 4-D tensor ``(B, T, N, D)``.
        types: 2-D tensor ``(B, N)``; entries equal to 1 or 2 select the points
            passed (via the mask) to the respective encoding.

    Returns:
        ``(z_type1, z_type2)``, each reshaped to ``(B, T, -1)``.

    Raises:
        ValueError: If ``anchors`` is not 4-D or ``types`` is not 2-D.
    """
    if anchors.dim() != 4:
        raise ValueError("anchors must have shape (B, T, N, D)")
    if types.dim() != 2:
        raise ValueError("types must have shape (B, N)")

    n_batch, n_steps, n_points, n_feat = anchors.shape
    n_sets = n_batch * n_steps

    # Fold time into the batch axis so the encoder sees one point set per row.
    flat_anchors = anchors.reshape(n_sets, n_points, n_feat)
    # The type labels are shared by every time step of a trajectory.
    flat_types = (
        types[:, None, :].expand(n_batch, n_steps, n_points).reshape(n_sets, n_points)
    )

    encoded = []
    for type_id in (1, 2):
        z_flat = set_encoder(flat_anchors, mask=flat_types == type_id)
        encoded.append(z_flat.reshape(n_batch, n_steps, -1))

    z_type1, z_type2 = encoded
    return z_type1, z_type2


def _simulate_batch(
    model,
    x0: torch.Tensor,
    steps: int,
    dt: float,
) -> torch.Tensor:
    """Roll a batch of states forward with Euler-Maruyama updates.

    Each update is ``x + drift * dt + (sigma @ noise) * sqrt(dt)`` with fresh
    standard-normal noise, where ``drift`` and ``sigma`` come from
    ``model.drift(None, x, create_graph=False)`` and
    ``model.diffusion(None, x)``.

    Args:
        model: Object exposing ``drift`` and ``diffusion`` as called above.
        x0: Initial states; indexed as ``(batch, dim)``.
        steps: Total number of states in the rollout, including ``x0``.
        dt: Time step.

    Returns:
        The states stacked along a new axis 1 (``x0`` first).
    """
    state = x0
    rollout = [state]
    for _step in range(1, steps):
        drift_term = model.drift(None, state, create_graph=False)
        diffusion = model.diffusion(None, state)
        n_batch, n_dim = state.shape[0], state.shape[1]
        gauss = torch.randn(n_batch, n_dim, device=state.device, dtype=state.dtype)
        # Batched matrix-vector product: one diffusion matrix per sample.
        stochastic = (diffusion @ gauss.unsqueeze(-1)).squeeze(-1)
        state = state + drift_term * dt + stochastic * math.sqrt(dt)
        rollout.append(state)
    return torch.stack(rollout, dim=1)


def _resolve_model_path(model_dir: str) -> str:
    """Locate a checkpoint file from a file path or a directory.

    Resolution order:
      1. ``model_dir`` itself when it is an existing file ending in ``.pt``.
      2. ``best_model.pt``, then ``model.pt``, directly inside ``model_dir``.
      3. The same two names inside each immediate sub-directory, visiting
         sub-directories from most to least recently modified.

    Args:
        model_dir: Checkpoint file or directory to search.

    Returns:
        Path of the first checkpoint found.

    Raises:
        FileNotFoundError: If ``model_dir`` is not a directory, or no
            checkpoint is found in it or its immediate sub-directories.
    """
    checkpoint_names = ("best_model.pt", "model.pt")

    if os.path.isfile(model_dir) and model_dir.endswith(".pt"):
        return model_dir

    for name in checkpoint_names:
        direct = os.path.join(model_dir, name)
        if os.path.exists(direct):
            return direct

    if not os.path.isdir(model_dir):
        raise FileNotFoundError(f"Model directory not found: {model_dir}")

    children = (os.path.join(model_dir, entry) for entry in os.listdir(model_dir))
    # Newest sub-directory first.
    run_dirs = sorted(
        filter(os.path.isdir, children), key=os.path.getmtime, reverse=True
    )
    for run_dir in run_dirs:
        for name in checkpoint_names:
            nested = os.path.join(run_dir, name)
            if os.path.exists(nested):
                return nested

    raise FileNotFoundError(
        f"Could not find best_model.pt or model.pt under {model_dir}."
    )


def _infer_encoder_dims(state: dict) -> tuple[int, int]:
    """Read the encoder's hidden and latent sizes from its saved weights.

    Args:
        state: Mapping that contains ``"phi.0.weight"`` and ``"rho.2.weight"``
            entries exposing a ``.shape``.

    Returns:
        ``(hidden_dim, z_dim)``: the first-axis length of ``phi.0.weight`` and
        of ``rho.2.weight`` respectively.

    Raises:
        KeyError: If either entry is missing.
    """
    hidden_dim, z_dim = (
        state[key].shape[0] for key in ("phi.0.weight", "rho.2.weight")
    )
    return hidden_dim, z_dim


def main() -> None:
    """Evaluate a trained encoder + dynamics checkpoint on test trajectories.

    Steps:
      1. Parse command-line arguments and fill in case-dependent defaults.
      2. Load test trajectories, particle types and macro features.
      3. Load the checkpoint, its config and its normalization bounds.
      4. Encode each trajectory per particle type, append the macro features,
         and roll the dynamics model out from the first encoded state.
      5. Save true/predicted macro features (``.npz``), a per-dimension
         mean/std comparison figure, and the config used.

    Raises:
        ValueError: If the loaded arrays have unexpected shapes or
            ``num_eval`` is not positive.
        KeyError: If the checkpoint carries no usable normalization data.
        FileNotFoundError: If no checkpoint can be located.
    """
    default_exp_case = "right"
    default_data_path = None
    default_macro_path = None
    default_config = MIX_DIR / "learn_sde" / "config" / "polymer_dynamics.yaml"
    default_model_dir = SCRIPT_DIR / "trained_models" / "Z1"

    parser = argparse.ArgumentParser(
        description="Evaluate baseline dynamics model against test trajectories."
    )
    parser.add_argument("--exp_case", type=str, default=default_exp_case)
    parser.add_argument("--data_path", type=str, default=default_data_path)
    parser.add_argument("--macro_path", type=str, default=default_macro_path)
    parser.add_argument("--n_traj", type=int, default=None)
    parser.add_argument("--model_dir", type=str, default=str(default_model_dir))
    parser.add_argument("--config", type=str, default=str(default_config))
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--num_eval", "--K", type=int, default=500)
    parser.add_argument("--output_fig", type=str, default=None)
    parser.add_argument("--eval_batch_size", type=int, default=10)
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--dtype", type=str, default="float64", choices=["float64", "float32"])
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--deepset_pool", type=str, default="mean", choices=["mean", "sum", "max"])

    args = parser.parse_args()

    exp_case = args.exp_case
    # Name the saved arrays after the case actually evaluated (not the
    # built-in default), matching how the figure file is named below.
    out_npz = f"pred_vs_true_{exp_case}.npz"
    if args.data_path is None:
        args.data_path = str(
            MIX_DIR / "generate_data" / "dataset" / f"trajectories_test_{exp_case}.npz"
        )
    if args.macro_path is None:
        args.macro_path = str(
            MIX_DIR / "generate_data" / "dataset" / f"macro_feature_test_{exp_case}.npy"
        )
    if args.output_fig is None:
        args.output_fig = f"mean_std_compare_test_{exp_case}.png"

    # Seed every RNG in use so the stochastic rollouts are repeatable.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # argparse `choices` restricts --dtype to exactly these two names.
    dtype_by_name = {"float64": torch.float64, "float32": torch.float32}
    torch.set_default_dtype(dtype_by_name[args.dtype])

    # A negative --device index, or no CUDA runtime, selects the CPU.
    device = torch.device(
        f"cuda:{args.device}"
        if (args.device >= 0 and torch.cuda.is_available())
        else "cpu"
    )

    trajectories, types = _load_trajectories(args.data_path)
    macro_feature = _load_macro_feature(args.macro_path)
    if trajectories.ndim != 4:
        raise ValueError("trajectories must have shape (n_traj, T, N, D)")
    if types.ndim != 2:
        raise ValueError("types must have shape (n_traj, N)")
    if macro_feature.shape[0] != trajectories.shape[0]:
        raise ValueError("macro_feature and trajectories must have same n_traj")
    if macro_feature.shape[1] < trajectories.shape[1]:
        raise ValueError("macro_feature has fewer timesteps than trajectories.")

    n_traj = _resolve_n_traj(args.n_traj, trajectories.shape[0])
    trajectories = trajectories[:n_traj]
    types = types[:n_traj]
    # Trim macro features to the trajectories' time length.
    macro_feature = macro_feature[:n_traj, : trajectories.shape[1], :]

    model_path = _resolve_model_path(args.model_dir)
    model_dir = os.path.dirname(model_path)
    # NOTE(review): weights_only=False unpickles arbitrary objects from the
    # checkpoint; only load checkpoints from trusted sources.
    state = torch.load(model_path, map_location=device, weights_only=False)

    # Prefer the config stored in the checkpoint, then a config.yaml next to
    # it, and only then the --config argument.
    config = state.get("config")
    if config is None:
        config_path = os.path.join(model_dir, "config.yaml")
        config = load_config(config_path if os.path.exists(config_path) else args.config)

    norm = _load_norm_from_state(state)
    if norm is None:
        raise KeyError("Missing normalization data in checkpoint.")
    macro_min, macro_max, traj_min, traj_max = norm
    macro_normed = _normalize_macro_feature(macro_feature, macro_min, macro_max)
    trajectories_normed = _normalize_trajectories(
        trajectories, traj_min, traj_max
    )

    num_eval = min(args.num_eval, trajectories.shape[0])
    if num_eval <= 0:
        raise ValueError("num_eval must be positive.")
    test_anchors = trajectories_normed[:num_eval]
    test_macro = macro_normed[:num_eval]
    test_types = types[:num_eval]

    encoder_state = state["encoder_state_dict"]
    hidden_dim, z_dim = _infer_encoder_dims(encoder_state)
    data_dim = trajectories.shape[-1]

    set_encoder = SetEncoderND(
        in_dim=data_dim,
        hidden_dim=hidden_dim,
        z_dim=z_dim,
        pool=args.deepset_pool,
    ).to(device=device)
    set_encoder.load_state_dict(encoder_state)
    set_encoder.eval()

    # Reduced state = [z_type1, z_type2, macro features].
    reduced_dim = z_dim * 2 + test_macro.shape[2]
    config["reduced_dim"] = reduced_dim
    dynamics_model = build_model(config, reduced_dim, device=device)
    dynamics_model.load_state_dict(state["dynamics_state_dict"])
    dynamics_model.eval()

    eval_bs = max(1, min(args.eval_batch_size, num_eval))
    steps = test_macro.shape[1]
    true_all = np.zeros((num_eval, steps, reduced_dim), dtype=np.float64)
    pred_all = np.zeros((num_eval, steps, reduced_dim), dtype=np.float64)

    dt_value = float(config["dt"])
    with torch.no_grad():
        for start in range(0, num_eval, eval_bs):
            end = min(start + eval_bs, num_eval)
            anchors_t = torch.as_tensor(
                test_anchors[start:end],
                dtype=torch.get_default_dtype(),
                device=device,
            )
            macro_t = torch.as_tensor(
                test_macro[start:end],
                dtype=torch.get_default_dtype(),
                device=device,
            )
            types_t = torch.as_tensor(
                test_types[start:end],
                dtype=torch.long,
                device=device,
            )

            z_type1, z_type2 = _compute_z_per_type(set_encoder, anchors_t, types_t)
            z_macro_true = torch.cat([z_type1, z_type2, macro_t], dim=-1)
            # Roll out from the first encoded state of each trajectory.
            x0 = z_macro_true[:, 0, :]
            z_macro_pred = _simulate_batch(dynamics_model, x0, steps, dt_value)

            true_np = z_macro_true.cpu().numpy()
            pred_np = z_macro_pred.cpu().numpy()

            true_all[start:end] = true_np
            pred_all[start:end] = pred_np

    # Only the trailing macro-feature columns are mapped back to physical
    # units; the latent columns are left as produced.
    macro_dim = test_macro.shape[2]
    macro_slice = slice(reduced_dim - macro_dim, reduced_dim)
    true_plot = true_all.copy()
    pred_plot = pred_all.copy()
    true_plot[..., macro_slice] = _denormalize_macro_feature(
        true_plot[..., macro_slice], macro_min, macro_max
    )
    pred_plot[..., macro_slice] = _denormalize_macro_feature(
        pred_plot[..., macro_slice], macro_min, macro_max
    )

    # Statistics across trajectories, per time step and dimension.
    pred_mean = pred_plot.mean(axis=0)
    true_mean = true_plot.mean(axis=0)
    pred_std = pred_plot.std(axis=0)
    true_std = true_plot.std(axis=0)

    print(f"Evaluated {num_eval} trajectories from test set.")

    output_dir = args.output_dir or os.path.join(model_dir, "eval_test")
    os.makedirs(output_dir, exist_ok=True)
    pred_macro = pred_plot[..., macro_slice]
    true_macro = true_plot[..., macro_slice]
    np.savez_compressed(
        os.path.join(output_dir, out_npz),
        pred=pred_macro,
        true=true_macro,
    )

    t = np.arange(steps)
    n_dim = reduced_dim
    n_cols = min(3, n_dim)
    n_rows = int(math.ceil(n_dim / n_cols))
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(4 * n_cols, 3.5 * n_rows),
        sharex=True,
    )
    # plt.subplots returns a bare Axes or a 1-D array for small grids;
    # normalise to 2-D so the indexing below is uniform.
    axes = np.array(axes).reshape(n_rows, n_cols)
    for dim in range(n_dim):
        ax = axes[dim // n_cols, dim % n_cols]
        ax.plot(t, true_mean[:, dim], label="true mean", color="black", linewidth=1.2)
        ax.fill_between(
            t,
            true_mean[:, dim] - true_std[:, dim],
            true_mean[:, dim] + true_std[:, dim],
            color="black",
            alpha=0.2,
            label="true std" if dim == 0 else None,
        )
        ax.plot(t, pred_mean[:, dim], label="pred mean", color="#1f77b4")
        ax.fill_between(
            t,
            pred_mean[:, dim] - pred_std[:, dim],
            pred_mean[:, dim] + pred_std[:, dim],
            color="#1f77b4",
            alpha=0.2,
            label="pred std" if dim == 0 else None,
        )
        ax.set_ylabel(f"dim {dim}")
    for ax in axes[-1, :]:
        ax.set_xlabel("time step")
    axes[0, 0].legend(loc="best")
    fig.suptitle("Test rollouts: mean/std comparison")
    fig.tight_layout()

    out_fig = os.path.join(output_dir, args.output_fig)
    fig.savefig(out_fig, dpi=200)
    plt.close(fig)
    print(f"Saved figure to {out_fig}")
    # Keep the config (including the reduced_dim set above) beside the results.
    with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.safe_dump(config, f, sort_keys=False)


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
