import argparse
import math
import os
import sys
from pathlib import Path

import numpy as np
import torch

# Absolute directory of this script and its parent directory.
SCRIPT_DIR = Path(__file__).resolve().parent
MIX_DIR = SCRIPT_DIR.parent
# Put MIX_DIR at the front of sys.path (once) before the project imports
# that follow; presumably `models` and `learn_sde` live there -- confirm.
if str(MIX_DIR) not in sys.path:
    sys.path.insert(0, str(MIX_DIR))

from models import SetEncoderND  # noqa: E402
from learn_sde.learn_mixing_dynamics import build_model, load_config  # noqa: E402
from learn_sde.mmd import compute_mmd_gpu  # noqa: E402


def _load_trajectories(npz_path: str) -> tuple[np.ndarray, np.ndarray]:
    """Read the ``trajectories`` and ``types`` arrays from an ``.npz`` archive.

    Args:
        npz_path: Path of the archive to open.

    Returns:
        ``(trajectories, types)`` exactly as stored under those two keys.

    Raises:
        KeyError: If either key is missing from the archive.
    """
    # NOTE: allow_pickle=True can execute arbitrary code while unpickling
    # object arrays; only point this at trusted files.
    archive = np.load(npz_path, allow_pickle=True)
    with archive:
        # Indexing materializes each array, so both remain valid after the
        # archive is closed.
        loaded = (archive["trajectories"], archive["types"])
    return loaded


def _load_macro_feature(npy_path: str) -> np.ndarray:
    """Load a macro-feature array from ``.npy`` and validate its shape.

    Args:
        npy_path: Path of the ``.npy`` file.

    Returns:
        The loaded array, which is 3-D with a last axis of length 2.

    Raises:
        ValueError: If the array is not 3-D or its last axis is not 2.
    """
    # NOTE: allow_pickle=True can execute arbitrary code while unpickling
    # object arrays; only point this at trusted files.
    features = np.load(npy_path, allow_pickle=True)
    shape_ok = features.ndim == 3 and features.shape[2] == 2
    if shape_ok:
        return features
    raise ValueError(
        f"macro_feature must have shape (n_traj, T, 2), got {features.shape}"
    )


def _resolve_n_traj(requested: int | None, available: int) -> int:
    if requested is None or requested <= 0 or requested > available:
        return available
    return requested


def _normalize_macro_feature(
    macro_feature: np.ndarray, macro_min: np.ndarray, macro_max: np.ndarray
) -> np.ndarray:
    """Affinely map macro features so ``macro_min -> -1`` and ``macro_max -> +1``.

    Args:
        macro_feature: Values to normalize.
        macro_min: Lower bound(s); an array broadcastable against
            ``macro_feature`` or a plain scalar.
        macro_max: Upper bound(s); same conventions as ``macro_min``.

    Returns:
        ``2 * (macro_feature - macro_min) / scale - 1`` where
        ``scale = macro_max - macro_min`` with zero entries replaced by 1.
    """
    # np.asarray + np.where (instead of in-place item assignment) lets scalar
    # and 0-d bounds work too; ones_like keeps the dtype of the range.
    scale = np.asarray(macro_max - macro_min)
    # A zero range would divide by zero; use 1 so such features map to -1.
    scale = np.where(scale == 0, np.ones_like(scale), scale)
    return 2.0 * (macro_feature - macro_min) / scale - 1.0


def _denormalize_macro_feature(
    macro_feature: np.ndarray, macro_min: np.ndarray, macro_max: np.ndarray
) -> np.ndarray:
    """Map normalized macro features back so ``-1 -> macro_min`` and ``+1 -> macro_max``.

    Args:
        macro_feature: Normalized values.
        macro_min: Lower bound(s); an array broadcastable against
            ``macro_feature`` or a plain scalar.
        macro_max: Upper bound(s); same conventions as ``macro_min``.

    Returns:
        ``0.5 * (macro_feature + 1) * scale + macro_min`` where
        ``scale = macro_max - macro_min`` with zero entries replaced by 1.
    """
    # np.asarray + np.where (instead of in-place item assignment) lets scalar
    # and 0-d bounds work too; ones_like keeps the dtype of the range.
    scale = np.asarray(macro_max - macro_min)
    # Zero ranges are replaced by 1, the same rule the normalizers in this
    # file apply, so a normalize/denormalize round trip stays consistent.
    scale = np.where(scale == 0, np.ones_like(scale), scale)
    return 0.5 * (macro_feature + 1.0) * scale + macro_min


def _normalize_trajectories(
    trajectories: np.ndarray, traj_min: np.ndarray, traj_max: np.ndarray
) -> np.ndarray:
    """Affinely map trajectories so ``traj_min -> -1`` and ``traj_max -> +1``.

    Args:
        trajectories: Values to normalize.
        traj_min: Lower bound(s); an array broadcastable against
            ``trajectories`` or a plain scalar.
        traj_max: Upper bound(s); same conventions as ``traj_min``.

    Returns:
        ``2 * (trajectories - traj_min) / scale - 1`` where
        ``scale = traj_max - traj_min`` with zero entries replaced by 1.
    """
    # np.asarray + np.where (instead of in-place item assignment) lets scalar
    # and 0-d bounds work too; ones_like keeps the dtype of the range.
    scale = np.asarray(traj_max - traj_min)
    # A zero range would divide by zero; use 1 so such coordinates map to -1.
    scale = np.where(scale == 0, np.ones_like(scale), scale)
    return 2.0 * (trajectories - traj_min) / scale - 1.0


def _to_numpy(value) -> np.ndarray:
    """Convert a tensor or array-like into a ``float64`` NumPy array.

    Torch tensors are detached and moved to CPU first; anything else is
    handed straight to ``np.asarray``.
    """
    raw = value.detach().cpu().numpy() if isinstance(value, torch.Tensor) else value
    return np.asarray(raw, dtype=np.float64)


def _load_norm_from_state(
    state: dict,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] | None:
    """Pull normalization bounds out of a checkpoint dictionary.

    Args:
        state: Checkpoint mapping; its ``"normalization"`` entry is expected
            to be a dict holding ``macro_min``, ``macro_max``, ``traj_min``
            and ``traj_max``.

    Returns:
        ``(macro_min, macro_max, traj_min, traj_max)`` as float64 arrays,
        the macro bounds reshaped to ``(1, 1, -1)`` and the trajectory bounds
        to ``(1, 1, 1, -1)``; ``None`` if the entry is absent, not a dict, or
        lacks any of the four keys.
    """
    stats = state.get("normalization")
    if not isinstance(stats, dict):
        return None
    if any(
        key not in stats
        for key in ("macro_min", "macro_max", "traj_min", "traj_max")
    ):
        return None

    # Leading singleton axes let the bounds broadcast against the arrays
    # they are applied to.
    macro_shape = (1, 1, -1)
    traj_shape = (1, 1, 1, -1)
    return (
        _to_numpy(stats["macro_min"]).reshape(macro_shape),
        _to_numpy(stats["macro_max"]).reshape(macro_shape),
        _to_numpy(stats["traj_min"]).reshape(traj_shape),
        _to_numpy(stats["traj_max"]).reshape(traj_shape),
    )


def _compute_z_per_type(
    set_encoder: SetEncoderND,
    anchors: torch.Tensor,
    types: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode every time frame twice, once masked to each of the two types.

    Args:
        set_encoder: Callable invoked as ``set_encoder(points, mask=...)``.
        anchors: Tensor of shape ``(B, T, N, D)``.
        types: Tensor of shape ``(B, N)``; entries equal to 1 and 2 select
            the two masks.

    Returns:
        ``(z_type1, z_type2)``, each reshaped to ``(B, T, -1)``.

    Raises:
        ValueError: If ``anchors`` is not 4-D or ``types`` is not 2-D.
    """
    if anchors.dim() != 4:
        raise ValueError("anchors must have shape (B, T, N, D)")
    if types.dim() != 2:
        raise ValueError("types must have shape (B, N)")

    n_batch, n_steps, n_points, n_dim = anchors.shape
    n_frames = n_batch * n_steps

    # Fold time into the batch axis so the encoder sees one set per frame.
    frames = anchors.reshape(n_frames, n_points, n_dim)
    # The same type labels apply to every time step of a trajectory.
    frame_types = (
        types.unsqueeze(1)
        .expand(n_batch, n_steps, n_points)
        .reshape(n_frames, n_points)
    )

    def encode(label: int) -> torch.Tensor:
        pooled = set_encoder(frames, mask=frame_types == label)
        return pooled.reshape(n_batch, n_steps, -1)

    return encode(1), encode(2)


def _simulate_batch(
    model,
    x0: torch.Tensor,
    steps: int,
    dt: float,
) -> torch.Tensor:
    """Roll a batch of states forward with fixed-step stochastic updates.

    Each step computes
    ``x <- x + drift(x) * dt + (sigma(x) @ noise) * sqrt(dt)``
    with fresh standard-normal ``noise`` of shape ``(x.shape[0], x.shape[1])``.

    Args:
        model: Object providing ``drift(None, x, create_graph=False)`` and
            ``diffusion(None, x)``.
        x0: Initial states; the first two axes size the noise draw.
        steps: Number of states to return, including ``x0``.
        dt: Step size.

    Returns:
        All visited states stacked along a new axis 1.
    """
    states = [x0]
    current = x0
    for _ in range(steps - 1):
        n_batch, n_state = current.shape[0], current.shape[1]
        drift_term = model.drift(None, current, create_graph=False) * dt
        sigma = model.diffusion(None, current)
        eps = torch.randn(
            n_batch, n_state, device=current.device, dtype=current.dtype
        )
        # Batched matrix-vector product of sigma with the noise sample.
        noise_term = torch.matmul(sigma, eps.unsqueeze(-1)).squeeze(-1) * math.sqrt(dt)
        current = current + drift_term + noise_term
        states.append(current)
    return torch.stack(states, dim=1)


def _resolve_model_path(model_dir: str) -> str:
    """Locate a checkpoint file from a file path or a directory.

    Resolution order:
      1. ``model_dir`` itself when it is an existing file ending in ``.pt``.
      2. ``best_model.pt`` then ``model.pt`` directly inside ``model_dir``.
      3. The same two names inside each immediate sub-directory, visiting
         sub-directories from most to least recently modified.

    Args:
        model_dir: Checkpoint file or directory to search.

    Returns:
        Path of the first checkpoint found.

    Raises:
        FileNotFoundError: If ``model_dir`` is not a directory, or no
            checkpoint is found under it.
    """
    checkpoint_names = ("best_model.pt", "model.pt")

    if os.path.isfile(model_dir) and model_dir.endswith(".pt"):
        return model_dir

    for name in checkpoint_names:
        candidate = os.path.join(model_dir, name)
        if os.path.exists(candidate):
            return candidate

    if not os.path.isdir(model_dir):
        raise FileNotFoundError(f"Model directory not found: {model_dir}")

    child_paths = (os.path.join(model_dir, entry) for entry in os.listdir(model_dir))
    # Newest sub-directory first.
    run_dirs = sorted(
        (path for path in child_paths if os.path.isdir(path)),
        key=os.path.getmtime,
        reverse=True,
    )
    for run_dir in run_dirs:
        for name in checkpoint_names:
            candidate = os.path.join(run_dir, name)
            if os.path.exists(candidate):
                return candidate

    raise FileNotFoundError(
        f"Could not find best_model.pt or model.pt under {model_dir}."
    )


def _infer_encoder_dims(state: dict) -> tuple[int, int]:
    """Read encoder sizes from the weight shapes in a state dict.

    Args:
        state: Mapping containing ``"phi.0.weight"`` and ``"rho.2.weight"``
            entries that expose a ``.shape``.

    Returns:
        ``(hidden_dim, z_dim)``: the first dimension of ``phi.0.weight`` and
        the first dimension of ``rho.2.weight``.

    Raises:
        KeyError: If either entry is missing.
    """
    hidden_dim, z_dim = (
        state[key].shape[0] for key in ("phi.0.weight", "rho.2.weight")
    )
    return hidden_dim, z_dim


def _default_dataset_path(stem: str, exp_case: str, suffix: str) -> str:
    """Return the default test-set file for ``exp_case``.

    ``<stem>_<exp_case>_test<suffix>`` is preferred. The alternative
    ``<stem>_test_<exp_case><suffix>`` is used only when that file exists and
    the preferred one does not.
    """
    dataset_dir = MIX_DIR / "generate_data" / "dataset"
    preferred = dataset_dir / f"{stem}_{exp_case}_test{suffix}"
    alternative = dataset_dir / f"{stem}_test_{exp_case}{suffix}"
    if alternative.exists() and not preferred.exists():
        return str(alternative)
    return str(preferred)


def main() -> None:
    """Compute MMD between predicted and ground-truth macro trajectories.

    Pipeline: parse CLI arguments, seed the RNGs, load test trajectories /
    types / macro features, load the checkpoint (encoder weights, dynamics
    weights, normalization bounds, optional config), encode the first frame
    of every trajectory, simulate the dynamics model forward, denormalize the
    macro slice of the simulated state, and print the MMD statistics returned
    by ``compute_mmd_gpu``.

    Raises:
        ValueError: If the loaded arrays have unexpected ranks or mismatched
            leading sizes, or no trajectory is available to evaluate.
        KeyError: If the checkpoint carries no usable normalization data.
        FileNotFoundError: If no checkpoint can be located.
    """
    default_exp_case = "diffN"
    default_model_dir = SCRIPT_DIR / "trained_models" / "Z1" / "exp1"
    default_config = MIX_DIR / "learn_sde" / "config" / "polymer_dynamics.yaml"

    parser = argparse.ArgumentParser(
        description="Compute MMD between predicted and ground-truth macro trajectories."
    )
    parser.add_argument("--exp_case", type=str, default=default_exp_case)
    # Both paths default to None so that --exp_case actually selects the
    # dataset; a hard-coded default path would make --exp_case a no-op.
    parser.add_argument(
        "--data_path",
        type=str,
        default=None,
        help="Trajectory .npz file (default: derived from --exp_case).",
    )
    parser.add_argument(
        "--macro_path",
        type=str,
        default=None,
        help="Macro-feature .npy file (default: derived from --exp_case).",
    )
    parser.add_argument("--n_traj", type=int, default=None)
    parser.add_argument("--model_dir", type=str, default=str(default_model_dir))
    parser.add_argument("--config", type=str, default=str(default_config))
    parser.add_argument(
        "--num_samples",
        "--num_eval",
        dest="num_samples",
        type=int,
        default=None,
        help="Number of trajectories to evaluate (default: all in the test file).",
    )
    parser.add_argument("--eval_batch_size", type=int, default=10)
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--dtype", type=str, default="float64", choices=["float64", "float32"])
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--deepset_pool", type=str, default="mean", choices=["mean", "sum", "max"])

    args = parser.parse_args()

    exp_case = args.exp_case
    if args.data_path is None:
        args.data_path = _default_dataset_path("trajectories", exp_case, ".npz")
    if args.macro_path is None:
        args.macro_path = _default_dataset_path("macro_feature", exp_case, ".npy")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # argparse `choices` already restricts --dtype to these two keys.
    torch_dtypes = {"float64": torch.float64, "float32": torch.float32}
    torch.set_default_dtype(torch_dtypes[args.dtype])
    float_dtype = torch.get_default_dtype()

    # A negative --device index, or no CUDA at all, selects the CPU.
    device = torch.device(
        f"cuda:{args.device}"
        if (args.device >= 0 and torch.cuda.is_available())
        else "cpu"
    )

    trajectories, types = _load_trajectories(args.data_path)
    macro_feature = _load_macro_feature(args.macro_path)
    if trajectories.ndim != 4:
        raise ValueError("trajectories must have shape (n_traj, T, N, D)")
    if types.ndim != 2:
        raise ValueError("types must have shape (n_traj, N)")
    if macro_feature.shape[0] != trajectories.shape[0]:
        raise ValueError("macro_feature and trajectories must have same n_traj")
    if macro_feature.shape[1] < trajectories.shape[1]:
        raise ValueError("macro_feature has fewer timesteps than trajectories.")

    n_traj = _resolve_n_traj(args.n_traj, trajectories.shape[0])
    trajectories = trajectories[:n_traj]
    types = types[:n_traj]
    # Trim the macro features to the time span covered by the trajectories.
    macro_feature = macro_feature[:n_traj, : trajectories.shape[1], :]

    model_path = _resolve_model_path(args.model_dir)
    model_dir = os.path.dirname(model_path)
    # NOTE(security): weights_only=False unpickles arbitrary objects, which can
    # execute code; only load checkpoints from trusted sources.
    state = torch.load(model_path, map_location=device, weights_only=False)

    # Config precedence: checkpoint, then config.yaml beside it, then --config.
    config = state.get("config")
    if config is None:
        config_path = os.path.join(model_dir, "config.yaml")
        config = load_config(config_path if os.path.exists(config_path) else args.config)

    norm = _load_norm_from_state(state)
    if norm is None:
        raise KeyError("Missing normalization data in checkpoint.")
    macro_min, macro_max, traj_min, traj_max = norm
    macro_normed = _normalize_macro_feature(macro_feature, macro_min, macro_max)
    trajectories_normed = _normalize_trajectories(
        trajectories, traj_min, traj_max
    )

    num_samples = _resolve_n_traj(args.num_samples, trajectories.shape[0])
    if num_samples <= 0:
        raise ValueError("num_samples must be positive.")
    test_anchors = trajectories_normed[:num_samples]
    test_macro = macro_normed[:num_samples]
    test_macro_true = macro_feature[:num_samples]
    test_types = types[:num_samples]

    encoder_state = state["encoder_state_dict"]
    hidden_dim, z_dim = _infer_encoder_dims(encoder_state)
    data_dim = trajectories.shape[-1]

    set_encoder = SetEncoderND(
        in_dim=data_dim,
        hidden_dim=hidden_dim,
        z_dim=z_dim,
        pool=args.deepset_pool,
    ).to(device=device)
    set_encoder.load_state_dict(encoder_state)
    set_encoder.eval()

    # Reduced state = [z_type1, z_type2, macro features].
    macro_dim = test_macro.shape[2]
    reduced_dim = z_dim * 2 + macro_dim
    config["reduced_dim"] = reduced_dim
    dynamics_model = build_model(config, reduced_dim, device=device)
    dynamics_model.load_state_dict(state["dynamics_state_dict"])
    dynamics_model.eval()

    eval_bs = max(1, min(args.eval_batch_size, num_samples))
    steps = test_macro.shape[1]
    pred_macro = np.zeros((num_samples, steps, macro_dim), dtype=np.float64)

    dt_value = float(config["dt"])
    # The macro features occupy the trailing macro_dim entries of the state.
    macro_slice = slice(reduced_dim - macro_dim, reduced_dim)
    with torch.no_grad():
        for start in range(0, num_samples, eval_bs):
            end = min(start + eval_bs, num_samples)
            anchors_t = torch.as_tensor(
                test_anchors[start:end],
                dtype=float_dtype,
                device=device,
            )
            macro_t = torch.as_tensor(
                test_macro[start:end],
                dtype=float_dtype,
                device=device,
            )
            types_t = torch.as_tensor(
                test_types[start:end],
                dtype=torch.long,
                device=device,
            )

            z_type1, z_type2 = _compute_z_per_type(set_encoder, anchors_t, types_t)
            z_macro_true = torch.cat([z_type1, z_type2, macro_t], dim=-1)
            # Only the first frame seeds the rollout; later frames are predicted.
            x0 = z_macro_true[:, 0, :]
            z_macro_pred = _simulate_batch(dynamics_model, x0, steps, dt_value)

            pred_batch = z_macro_pred[:, :, macro_slice].cpu().numpy()
            pred_macro[start:end] = _denormalize_macro_feature(
                pred_batch, macro_min, macro_max
            )

    # The comparison is done in the original (denormalized) macro units.
    true_macro = np.asarray(test_macro_true, dtype=np.float64)
    pred_t = torch.as_tensor(pred_macro, device=device, dtype=float_dtype)
    true_t = torch.as_tensor(true_macro, device=device, dtype=float_dtype)
    mmd_per_timestep = compute_mmd_gpu(pred_t, true_t)
    mean_mmd = float(mmd_per_timestep.mean())

    print("MMD per timestep shape:")
    print(mmd_per_timestep.shape)
    print(f"Mean MMD: {mean_mmd:.6e}")

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
