import argparse
import json
import os
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import yaml
from tqdm import tqdm

# SCRIPT_DIR is the directory holding this file; MIX_DIR is its parent, which
# is also the root used for the default dataset/config paths in ``main``.
SCRIPT_DIR = Path(__file__).resolve().parent
MIX_DIR = SCRIPT_DIR.parent
# Prepend MIX_DIR to sys.path (only once) before the project imports that
# follow -- presumably so ``models`` and ``learn_sde`` resolve when this file
# is run directly as a script.
if str(MIX_DIR) not in sys.path:
    sys.path.insert(0, str(MIX_DIR))

from models import SetEncoderND  # noqa: E402
from learn_sde.learn_mixing_dynamics import (  # noqa: E402
    build_model,
    load_config,
    mle_loss,
)


def _load_trajectories(npz_path: str) -> tuple[np.ndarray, np.ndarray]:
    """Read the ``trajectories`` and ``types`` arrays from an ``.npz`` archive.

    Args:
        npz_path: Path of the archive to open.

    Returns:
        ``(trajectories, types)`` exactly as stored under those two keys.

    Raises:
        KeyError: If either key is absent from the archive.

    Note:
        The archive is opened with ``allow_pickle=True``, which lets NumPy
        unpickle object arrays; only point this at files you trust.
    """
    with np.load(npz_path, allow_pickle=True) as archive:
        # Both arrays are read before the context manager closes the archive.
        return archive["trajectories"], archive["types"]


def _resolve_n_traj(requested: int | None, available: int) -> int:
    if requested is None or requested <= 0 or requested > available:
        return available
    return requested


def _load_macro_feature(npy_path: str) -> np.ndarray:
    """Load the macro-feature array and verify its layout.

    Args:
        npy_path: Path handed to ``np.load`` (opened with ``allow_pickle=True``,
            so only use trusted files).

    Returns:
        The loaded array, unchanged.

    Raises:
        ValueError: If the array is not 3-D or its last axis is not of size 2.
    """
    loaded = np.load(npy_path, allow_pickle=True)
    # The size of axis 2 is only inspected once the rank is known to be 3.
    layout_ok = loaded.ndim == 3 and loaded.shape[2] == 2
    if not layout_ok:
        raise ValueError(
            f"macro_feature must have shape (n_traj, T, 2), got {loaded.shape}"
        )
    return loaded


def _normalize_macro_feature(
    macro_feature: np.ndarray,
    split_idx: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Min-max scale the macro feature using training-split statistics.

    The minimum and maximum are taken over axes 0 and 1 of
    ``macro_feature[:split_idx]`` and then applied to the whole array, so the
    training portion lands in ``[-1, 1]``.

    Args:
        macro_feature: Array to scale; it is not modified.
        split_idx: Number of leading entries that form the training split.

    Returns:
        ``(scaled, min_val, max_val)``; the two statistics keep the reduced
        axes as size-1 dimensions.
    """
    reduce_axes = (0, 1)
    train_part = macro_feature[:split_idx]
    lo = train_part.min(axis=reduce_axes, keepdims=True)
    hi = train_part.max(axis=reduce_axes, keepdims=True)

    span = hi - lo
    # A feature that is constant over the training split would divide by zero.
    span[span == 0] = 1.0

    return 2.0 * (macro_feature - lo) / span - 1.0, lo, hi


def _normalize_trajectories(
    trajectories: np.ndarray,
    split_idx: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Min-max scale trajectories using training-split statistics.

    The minimum and maximum are taken over axes 0, 1 and 2 of
    ``trajectories[:split_idx]`` and then applied to the whole array, so the
    training portion lands in ``[-1, 1]``.

    Args:
        trajectories: Array to scale; it is not modified.
        split_idx: Number of leading entries that form the training split.

    Returns:
        ``(scaled, min_val, max_val)``; the two statistics keep the reduced
        axes as size-1 dimensions.
    """
    reduce_axes = (0, 1, 2)
    train_part = trajectories[:split_idx]
    lo = train_part.min(axis=reduce_axes, keepdims=True)
    hi = train_part.max(axis=reduce_axes, keepdims=True)

    span = hi - lo
    # A coordinate that is constant over the training split would divide by zero.
    span[span == 0] = 1.0

    return 2.0 * (trajectories - lo) / span - 1.0, lo, hi


def _as_float(value, name: str) -> float:
    """Coerce a config value to ``float``, unwrapping one-element sequences.

    Args:
        value: A scalar (anything ``float()`` accepts) or a list/tuple that
            holds exactly one such scalar.
        name: Label used in the error message.

    Returns:
        The value converted with ``float()``.

    Raises:
        ValueError: If ``value`` is a list/tuple whose length is not 1.
    """
    if not isinstance(value, (list, tuple)):
        return float(value)
    if len(value) != 1:
        raise ValueError(f"{name} must be a scalar, got {value}")
    return float(value[0])


def _compute_z_per_type(
    set_encoder: SetEncoderND,
    anchors: torch.Tensor,
    types: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode every time step twice: once masked to type 1, once to type 2.

    Batch and time axes are merged so the encoder sees ``B * T`` point sets of
    ``N`` points each; the per-trajectory ``types`` row is repeated for every
    time step to build the boolean masks ``types == 1`` and ``types == 2``.

    Args:
        set_encoder: Called as ``set_encoder(points, mask=...)`` with
            ``(B * T, N, D)`` points and a ``(B * T, N)`` boolean mask.
        anchors: Tensor of shape ``(B, T, N, D)``.
        types: Tensor of shape ``(B, N)``.

    Returns:
        ``(z_type1, z_type2)``, each reshaped to ``(B, T, -1)``.

    Raises:
        ValueError: If ``anchors`` is not 4-D or ``types`` is not 2-D.
    """
    if anchors.dim() != 4:
        raise ValueError("anchors must have shape (B, T, N, D)")
    if types.dim() != 2:
        raise ValueError("types must have shape (B, N)")

    batch, steps, n_points, dim = anchors.shape
    n_sets = batch * steps
    points = anchors.reshape(n_sets, n_points, dim)
    # Broadcast the (B, N) labels along a new time axis, then merge B and T.
    labels = types.unsqueeze(1).expand(batch, steps, n_points).reshape(
        n_sets, n_points
    )

    def _encode(type_id: int) -> torch.Tensor:
        """Run the encoder on the points whose label equals ``type_id``."""
        return set_encoder(points, mask=labels == type_id).reshape(batch, steps, -1)

    # Type 1 is encoded before type 2, matching the returned order.
    z_first = _encode(1)
    z_second = _encode(2)
    return z_first, z_second


def _train_joint(
    train_loader: DataLoader,
    val_loader: DataLoader,
    set_encoder: SetEncoderND,
    dynamics_model: torch.nn.Module,
    config: dict,
    output_dir: str,
    device: torch.device,
    norm_stats: dict[str, np.ndarray] | None,
) -> None:
    """Jointly optimise the set encoder and the dynamics model.

    For every batch the state ``[z_type1, z_type2, macro]`` is built per time
    step, consecutive steps are paired as ``(x_t, x_{t+1})`` and ``mle_loss``
    is evaluated on those pairs with a constant time increment ``config["dt"]``.
    Both networks share one Adam optimiser; a ``ReduceLROnPlateau`` scheduler
    is stepped on the mean validation loss after each epoch.

    Config keys read: ``train.opt.learning_rate``, ``train.rop.{factor,
    patience, cooldown, rtol}`` (plus optional ``train.rop.min_scale``),
    ``train.num_epochs``, ``train.checkpoint_every`` and ``dt``.

    Files written to ``output_dir``:
        * ``loss_log.csv`` -- one row per epoch, flushed as it is written.
        * ``best_model.pt`` -- rewritten whenever validation loss improves.
        * ``checkpoint_epoch_<n>.pt`` -- every ``checkpoint_every`` epochs
          (skipped when that setting is falsy).
        * ``model.pt`` -- weights after the final epoch.
        * ``loss_history.json`` -- per-epoch train/val losses.

    Args:
        train_loader: Yields ``(anchors, macro, types)`` training batches.
        val_loader: Yields ``(anchors, macro, types)`` validation batches.
        set_encoder: Encoder producing the per-type latent variables.
        dynamics_model: Model handed to ``mle_loss``.
        config: Training configuration (see keys above); stored in checkpoints.
        output_dir: Directory for all artefacts; created if missing.
        device: Device that each batch is moved to.
        norm_stats: Extra entries merged into every checkpoint, or ``None``.
    """
    train_cfg = config["train"]
    lr = _as_float(train_cfg["opt"]["learning_rate"], "learning_rate")
    trainable = [*set_encoder.parameters(), *dynamics_model.parameters()]
    optimizer = torch.optim.Adam(trainable, lr=lr)

    rop = train_cfg["rop"]
    # The learning-rate floor is expressed as a fraction of the initial rate.
    min_lr = lr * _as_float(rop.get("min_scale", 0.0), "min_scale")
    plateau_options = {
        "factor": _as_float(rop["factor"], "factor"),
        "patience": int(rop["patience"]),
        "cooldown": int(rop["cooldown"]),
        "threshold": _as_float(rop["rtol"], "rtol"),
        "min_lr": min_lr,
    }
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, **plateau_options
    )

    os.makedirs(output_dir, exist_ok=True)

    dt_value = float(config["dt"])
    num_epochs = train_cfg["num_epochs"]
    checkpoint_every = train_cfg["checkpoint_every"]

    extra_payload = norm_stats or {}
    history = {"train": [], "val": []}
    best_val = float("inf")

    def _write_checkpoint(filename: str) -> None:
        """Save both state dicts, the config and ``norm_stats`` under ``output_dir``."""
        payload = {
            "encoder_state_dict": set_encoder.state_dict(),
            "dynamics_state_dict": dynamics_model.state_dict(),
            # Relies on the encoder exposing ``rho`` whose last element has
            # ``out_features`` -- defined in ``models``, not visible here.
            "z_dim": set_encoder.rho[-1].out_features,
            "config": config,
            **extra_payload,
        }
        torch.save(payload, os.path.join(output_dir, filename))

    def _one_step_pairs(anchors, macro, types):
        """Move a batch to ``device`` and return flattened ``(x_t, x_{t+1}, dt)``."""
        anchors = anchors.to(device=device)
        macro = macro.to(device=device)
        types = types.to(device=device)

        z_first, z_second = _compute_z_per_type(set_encoder, anchors, types)
        state = torch.cat([z_first, z_second, macro], dim=-1)

        width = state.shape[-1]
        # Pair each time step with its successor and merge batch/time axes.
        current = state[:, :-1, :].reshape(-1, width)
        following = state[:, 1:, :].reshape(-1, width)
        step = torch.full(
            (current.shape[0], 1), dt_value, device=device, dtype=current.dtype
        )
        return current, following, step

    def _mean_or_nan(values: list) -> float:
        """Average the batch losses; an empty loader yields NaN."""
        return float(np.mean(values)) if values else float("nan")

    with open(
        os.path.join(output_dir, "loss_log.csv"), "w", encoding="utf-8"
    ) as log_file:
        log_file.write("epoch,train_loss,val_loss\n")
        log_file.flush()

        for epoch in tqdm(range(1, num_epochs + 1), desc="Training"):
            # ---- optimisation pass ----
            set_encoder.train()
            dynamics_model.train()
            train_losses = []
            for anchors, macro, types in train_loader:
                x0_flat, x1_flat, dt = _one_step_pairs(anchors, macro, types)
                optimizer.zero_grad(set_to_none=True)
                loss = mle_loss(dynamics_model, x0_flat, x1_flat, dt)
                loss.backward()
                optimizer.step()
                train_losses.append(loss.item())
            train_loss = _mean_or_nan(train_losses)
            history["train"].append(train_loss)

            # ---- validation pass (no parameter updates) ----
            set_encoder.eval()
            dynamics_model.eval()
            val_losses = []
            with torch.no_grad():
                for anchors, macro, types in val_loader:
                    x0_flat, x1_flat, dt = _one_step_pairs(anchors, macro, types)
                    loss = mle_loss(
                        dynamics_model, x0_flat, x1_flat, dt, create_graph=False
                    )
                    val_losses.append(loss.item())
            val_loss = _mean_or_nan(val_losses)
            history["val"].append(val_loss)

            # Skip the scheduler when there was nothing to validate on.
            if val_losses:
                scheduler.step(val_loss)

            if val_loss < best_val:
                best_val = val_loss
                _write_checkpoint("best_model.pt")

            if checkpoint_every and epoch % checkpoint_every == 0:
                _write_checkpoint(f"checkpoint_epoch_{epoch}.pt")

            # Per-epoch metrics go to the CSV log rather than to stdout.
            log_file.write(f"{epoch},{train_loss:.6f},{val_loss:.6f}\n")
            log_file.flush()

    _write_checkpoint("model.pt")
    with open(
        os.path.join(output_dir, "loss_history.json"), "w", encoding="utf-8"
    ) as f:
        json.dump(history, f, indent=2)


def main() -> None:
    """Command-line entry point: load data, build models, run joint training.

    Steps, in order:

    1. Parse CLI options; select the torch device, RNG seeds and default dtype.
    2. Load trajectories/types (``.npz``) and the macro feature (``.npy``),
       validate their shapes, and truncate them to a common trajectory count
       and number of time steps.
    3. Load the config with ``load_config`` and override drift/diffusion type,
       seed, epoch count and batch size from the CLI.
    4. Split the trajectories 80/20 in file order into train/validation and
       min-max normalise both arrays with training-split statistics.
    5. Build the data loaders, the set encoder and the dynamics model.
    6. Create ``<output_dir>/Z<z_dim>/<timestamp>/``, dump the data shapes and
       the effective config there, and call ``_train_joint``.

    Raises:
        ValueError: If the requested CUDA index is not available, if an input
            array has an unexpected shape, or if the train/validation split
            would leave one side empty.
    """
    dataset_dir = MIX_DIR / "generate_data" / "dataset"
    default_data_path = dataset_dir / "trajectories_large.npz"
    default_macro_path = dataset_dir / "macro_feature_large.npy"
    default_config = MIX_DIR / "learn_sde" / "config" / "polymer_dynamics.yaml"

    parser = argparse.ArgumentParser(
        description="Train DeepSet closure variables + dynamics on macro features."
    )
    parser.add_argument("--data_path", type=str, default=str(default_data_path))
    parser.add_argument("--macro_path", type=str, default=str(default_macro_path))
    parser.add_argument("--n_traj", type=int, default=None)

    parser.add_argument("--z_dim", type=int, default=1)
    parser.add_argument("--hidden_dim", type=int, default=128)
    parser.add_argument(
        "--deepset_pool", type=str, default="mean", choices=["mean", "sum", "max"]
    )

    parser.add_argument("--config", type=str, default=str(default_config))
    parser.add_argument(
        "--output_dir", type=str, default=str(SCRIPT_DIR / "trained_models")
    )
    parser.add_argument("--device", type=int, default=3)
    parser.add_argument(
        "--dtype", type=str, default="float64", choices=["float64", "float32"]
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=15)
    parser.add_argument(
        "--drift_type",
        type=str,
        default="mlp",
        choices=["mlp", "onsager"],
    )
    parser.add_argument(
        "--diffusion_type",
        type=str,
        default="diagonal",
        choices=["constant", "diagonal"],
    )
    parser.add_argument("--num_epochs", type=int, default=100)

    args = parser.parse_args()

    # A negative index, or a machine without CUDA, selects the CPU.
    use_cuda = args.device >= 0 and torch.cuda.is_available()
    if use_cuda:
        n_gpus = torch.cuda.device_count()
        # Fail early with a readable message instead of a late CUDA
        # "invalid device ordinal" error when the index does not exist.
        if args.device >= n_gpus:
            raise ValueError(
                f"--device {args.device} is out of range: only {n_gpus} CUDA "
                "device(s) are visible. Pass a valid index, or a negative "
                "value to run on the CPU."
            )
    device = torch.device(f"cuda:{args.device}" if use_cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # argparse ``choices`` already restricts --dtype to these two names.
    torch.set_default_dtype(
        torch.float64 if args.dtype == "float64" else torch.float32
    )

    trajectories, types = _load_trajectories(args.data_path)
    if trajectories.ndim != 4:
        raise ValueError("trajectories must have shape (n_traj, T, N, D)")
    if types.ndim != 2:
        raise ValueError("types must have shape (n_traj, N)")
    # A mismatched particle axis would otherwise only surface inside the
    # encoder (or, for a size-1 axis, be silently broadcast).
    if types.shape[1] != trajectories.shape[2]:
        raise ValueError(
            f"types has {types.shape[1]} particles per trajectory, "
            f"but trajectories has {trajectories.shape[2]}"
        )

    n_traj = _resolve_n_traj(args.n_traj, trajectories.shape[0])
    if types.shape[0] < n_traj:
        raise ValueError(
            f"types has {types.shape[0]} trajectories, expected {n_traj}"
        )
    trajectories = trajectories[:n_traj]
    types = types[:n_traj]

    macro_feature = _load_macro_feature(args.macro_path)
    if macro_feature.shape[0] < n_traj:
        raise ValueError(
            f"macro_feature has {macro_feature.shape[0]} trajectories, expected {n_traj}"
        )
    if macro_feature.shape[1] < trajectories.shape[1]:
        raise ValueError("macro_feature has fewer timesteps than trajectories.")
    # Trim the macro feature to the trajectories' count and length.
    macro_feature = macro_feature[:n_traj, : trajectories.shape[1], :]

    config = load_config(args.config)
    config.setdefault("model", {})
    config["model"].setdefault("drift", {})
    config["model"].setdefault("diffusion", {})
    # CLI values take precedence over whatever the config file specifies.
    config["model"]["drift"]["type"] = args.drift_type
    config["model"]["diffusion"]["type"] = args.diffusion_type
    config["model"]["seed"] = args.seed
    if args.num_epochs is not None:
        config["train"]["num_epochs"] = args.num_epochs

    # NOTE: optional truncation to ``train.train_traj_len`` is currently disabled.
    # train_traj_len = config["train"].get("train_traj_len")
    # if train_traj_len is not None:
    #     trajectories = trajectories[:, :train_traj_len, :, :]
    #     macro_feature = macro_feature[:, :train_traj_len, :]

    print(f"trajectories shape: {trajectories.shape}, macro_feature shape: {macro_feature.shape}")
    # First 80% of the trajectories (in file order) train; the rest validate.
    split_idx = int(0.8 * n_traj)
    if split_idx <= 0 or split_idx >= n_traj:
        raise ValueError(f"Need at least 2 trajectories, got {n_traj}")

    macro_norm, macro_min, macro_max = _normalize_macro_feature(
        macro_feature, split_idx
    )
    traj_norm, traj_min, traj_max = _normalize_trajectories(
        trajectories, split_idx
    )

    default_dtype = torch.get_default_dtype()
    train_anchors = torch.as_tensor(traj_norm[:split_idx], dtype=default_dtype)
    val_anchors = torch.as_tensor(traj_norm[split_idx:], dtype=default_dtype)
    train_macro = torch.as_tensor(macro_norm[:split_idx], dtype=default_dtype)
    val_macro = torch.as_tensor(macro_norm[split_idx:], dtype=default_dtype)
    train_types = torch.as_tensor(types[:split_idx], dtype=torch.long)
    val_types = torch.as_tensor(types[split_idx:], dtype=torch.long)

    if args.batch_size is not None:
        config["train"]["batch_size"] = args.batch_size
    batch_size = config["train"]["batch_size"]
    # Pinned host memory only matters for host-to-GPU copies; skip it on CPU.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        TensorDataset(train_anchors, train_macro, train_types),
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=True,
    )
    val_loader = DataLoader(
        TensorDataset(val_anchors, val_macro, val_types),
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=False,
    )

    data_dim = trajectories.shape[-1]
    set_encoder = SetEncoderND(
        in_dim=data_dim,
        hidden_dim=args.hidden_dim,
        z_dim=args.z_dim,
        pool=args.deepset_pool,
    ).to(device=device)

    # Dynamics state = one z vector per particle type + the macro features.
    reduced_dim = args.z_dim * 2 + macro_feature.shape[2]
    config["reduced_dim"] = reduced_dim
    dynamics_model = build_model(config, reduced_dim, device=device)

    timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
    output_dir = os.path.join(args.output_dir, f"Z{args.z_dim}", timestamp)
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, "data_shapes.json"), "w", encoding="utf-8") as f:
        json.dump(
            {
                "trajectories_shape": list(trajectories.shape),
                "macro_feature_shape": list(macro_feature.shape),
            },
            f,
            indent=2,
        )

    # Persist the effective config (file values plus CLI overrides).
    config_out = os.path.join(output_dir, "config.yaml")
    if not os.path.exists(config_out):
        with open(config_out, "w", encoding="utf-8") as f:
            yaml.safe_dump(config, f, sort_keys=False)

    _train_joint(
        train_loader=train_loader,
        val_loader=val_loader,
        set_encoder=set_encoder,
        dynamics_model=dynamics_model,
        config=config,
        output_dir=output_dir,
        device=device,
        norm_stats={
            "normalization": {
                "macro_min": macro_min,
                "macro_max": macro_max,
                "traj_min": traj_min,
                "traj_max": traj_max,
            }
        },
    )


# Run the training CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
