import argparse
import logging
import os
import sys
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Repository root: the directory one level above the folder holding this script.
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# The learn_ode directory under the repository root.
LEARN_ODE_DIR = os.path.join(BASE_DIR, "learn_ode")
# Append both directories (root first) to sys.path, skipping any already
# present, so the project-local imports that follow can be resolved.
for _import_root in (BASE_DIR, LEARN_ODE_DIR):
    if _import_root not in sys.path:
        sys.path.append(_import_root)
del _import_root

from utils.model import DeepSetEncoder
from utils.utils import set_seed
from learn_ode.models import drift_MLP


def resolve_device(device_index: int) -> torch.device:
    """Map a command-line device index to a ``torch.device``.

    A negative index, or a machine where CUDA is unavailable, yields the CPU
    device. Otherwise ``cuda:<device_index>`` is returned; the index is not
    checked against the number of visible GPUs here.
    """
    wants_gpu = device_index >= 0
    if wants_gpu and torch.cuda.is_available():
        return torch.device(f"cuda:{device_index}")
    return torch.device("cpu")


def load_inputs(
    data_dir: str,
    trajectories_path: str | None,
    macro_feature_path: str | None,
    n_traj: int | None,
) -> Tuple[np.ndarray, np.ndarray]:
    if trajectories_path is None:
        trajectories_path = os.path.join(data_dir, "trajectories.npy")
        # trajectories_path = os.path.join(data_dir, "trajectories_n-traj10k_3gmm.npy")
    if macro_feature_path is None:
        macro_feature_path = os.path.join(data_dir, "macro_feature.npy")
        # macro_feature_path = os.path.join(data_dir, "macro_feature_n-traj10k_3gmm.npy")

    trajectories = np.load(trajectories_path, allow_pickle=True)
    macro_feature = np.load(macro_feature_path, allow_pickle=True)

    if trajectories.ndim != 4:
        raise ValueError(f"trajectories must be 4D, got shape={trajectories.shape}")
    if macro_feature.ndim == 2:
        macro_feature = macro_feature[..., None]
    if macro_feature.ndim != 3:
        raise ValueError(f"macro_feature must be 3D, got shape={macro_feature.shape}")

    if n_traj is not None:
        trajectories = trajectories[:n_traj]
        macro_feature = macro_feature[:n_traj]

    if trajectories.shape[0] != macro_feature.shape[0]:
        raise ValueError("trajectories and macro_feature must have same n_traj")
    if trajectories.shape[1] != macro_feature.shape[1]:
        raise ValueError("trajectories and macro_feature must have same T")

    return trajectories.astype(np.float32), macro_feature.astype(np.float32)


def compute_normalization(train_trajectories: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return per-feature ``(mean, std)`` pooled over axes 0, 1 and 2.

    Both statistics keep the reduced axes (``keepdims=True``), so they
    broadcast directly against the input. A standard deviation of exactly
    zero is replaced by 1 so that later division never hits zero.
    """
    pooled_axes = (0, 1, 2)
    centre = np.mean(train_trajectories, axis=pooled_axes, keepdims=True)
    spread = np.std(train_trajectories, axis=pooled_axes, keepdims=True)
    # Constant features would otherwise produce a division by zero.
    spread = np.where(spread == 0, 1.0, spread)
    return centre, spread


def normalize_trajectories(
    trajectories: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
) -> np.ndarray:
    """Standardise ``trajectories`` as ``(trajectories - mean) / std``.

    ``mean`` and ``std`` must broadcast against ``trajectories``. The keepdims
    statistics from ``compute_normalization`` do.
    """
    centered = trajectories - mean
    return centered / std


class TrajectoryPairDataset(Dataset):
    """Flat dataset of one-step transitions ``(x_t, x_{t+1}, m_t, m_{t+1})``.

    Item ``idx`` refers to trajectory ``idx // (T - 1)`` at step
    ``idx % (T - 1)``. All transitions of trajectory 0 come first, then
    those of trajectory 1, and so on.
    """

    def __init__(self, trajectories: np.ndarray, macro_feature: np.ndarray) -> None:
        # Both arrays are held as float32 torch tensors.
        self.trajectories = torch.from_numpy(trajectories).float()
        self.macro_feature = torch.from_numpy(macro_feature).float()
        # Unpacking four values enforces a 4-D trajectories array.
        n_traj, n_steps, _, _ = self.trajectories.shape
        if n_steps < 2:
            raise ValueError("Need at least two timesteps to build transition pairs.")
        self.n_traj = n_traj  # number of trajectories (axis 0)
        self.T = n_steps  # number of timesteps per trajectory (axis 1)

    def __len__(self) -> int:
        pairs_per_traj = self.T - 1
        return pairs_per_traj * self.n_traj

    def __getitem__(self, idx: int):
        pairs_per_traj = self.T - 1
        step = idx % pairs_per_traj
        traj = idx // pairs_per_traj
        nxt = step + 1
        return (
            self.trajectories[traj, step],
            self.trajectories[traj, nxt],
            self.macro_feature[traj, step],
            self.macro_feature[traj, nxt],
        )


def make_dataloaders(
    trajectories: np.ndarray,
    macro_feature: np.ndarray,
    train_frac: float,
    batch_size: int,
    num_workers: int,
    pin_memory: bool,
) -> Tuple[DataLoader, DataLoader, dict]:
    """Split by trajectory, normalise, and wrap both splits in DataLoaders.

    The first ``int(train_frac * n_traj)`` trajectories form the training
    split and the remainder the validation split; there is no shuffling
    before the split. Trajectories are standardised with statistics from the
    training split only. Macro features are left unchanged.

    Args:
        trajectories: Trajectory array; axis 0 indexes trajectories.
        macro_feature: Macro-feature array aligned with ``trajectories`` on axis 0.
        train_frac: Fraction of trajectories assigned to the training split.
        batch_size: Batch size for both loaders.
        num_workers: Number of DataLoader worker processes for both loaders.
        pin_memory: Forwarded to both loaders.

    Returns:
        ``(train_loader, val_loader, {"mean": mean, "std": std})``. The training
        loader shuffles; the validation loader does not.

    Raises:
        ValueError: If either split would contain no trajectories.
    """
    n_traj = trajectories.shape[0]
    n_train = int(train_frac * n_traj)
    # An empty training split yields NaN normalisation statistics. An empty
    # validation split makes every epoch's validation loss NaN, so the
    # "val_loss < best" test never fires and no checkpoint is ever written.
    if not 0 < n_train < n_traj:
        raise ValueError(
            f"train_frac={train_frac} with n_traj={n_traj} gives {n_train} train and "
            f"{n_traj - n_train} validation trajectories; both splits must be non-empty"
        )

    train_traj = trajectories[:n_train]
    val_traj = trajectories[n_train:]
    train_macro = macro_feature[:n_train]
    val_macro = macro_feature[n_train:]

    # Statistics come from the training split only, to avoid leaking
    # validation data into preprocessing.
    mean, std = compute_normalization(train_traj)
    train_traj = normalize_trajectories(train_traj, mean, std)
    val_traj = normalize_trajectories(val_traj, mean, std)

    train_ds = TrajectoryPairDataset(train_traj, train_macro)
    val_ds = TrajectoryPairDataset(val_traj, val_macro)

    # num_workers was previously accepted but never forwarded, so the CLI
    # setting had no effect.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    normalization_info = {"mean": mean, "std": std}
    return train_loader, val_loader, normalization_info


class JointDynamicsModel(nn.Module):
    """A ``DeepSetEncoder`` and a ``drift_MLP`` held together for joint training.

    ``z_macro`` encodes a state and appends its macro features. ``drift`` takes
    that concatenated vector (``z_dim + macro_dim`` wide) as input. No
    ``forward`` is defined; callers use ``z_macro`` and ``drift`` directly.
    """

    def __init__(
        self,
        in_dim: int,
        z_dim: int,
        macro_dim: int,
        encoder_hidden_dim: int,
        drift_hidden_dim: int,
        pool: str,
    ) -> None:
        super().__init__()
        # Keep the construction order (encoder, then drift). It fixes RNG
        # consumption during parameter init and the state_dict key order.
        self.encoder = DeepSetEncoder(
            in_dim=in_dim,
            hidden_dim=encoder_hidden_dim,
            z_dim=z_dim,
            pool=pool,
        )
        drift_in_dim = z_dim + macro_dim
        self.drift = drift_MLP(input_dim=drift_in_dim, hidden_dim=drift_hidden_dim)

    def z_macro(self, x: torch.Tensor, macro: torch.Tensor) -> torch.Tensor:
        """Concatenate ``encoder(x)`` with ``macro`` along the last axis."""
        latent = self.encoder(x)
        return torch.cat((latent, macro), dim=-1)


def compute_loss(
    model: "JointDynamicsModel",
    x0: torch.Tensor,
    x1: torch.Tensor,
    m0: torch.Tensor,
    m1: torch.Tensor,
    dt_scalar: float,
    weight_alpha: float,
    weight_eps: float,
) -> torch.Tensor:
    """Weighted finite-difference regression loss for the drift network.

    Both ends of a transition are embedded with ``model.z_macro``. The target
    derivative is the forward difference ``(zm1 - zm0) / dt_scalar`` and the
    prediction is ``model.drift(zm0)``. Each squared error is scaled by
    ``(|zm1 - zm0| + weight_eps) ** weight_alpha`` before averaging, so
    ``weight_alpha == 0`` gives a plain mean squared error.

    The weights are detached. They only re-weight the residuals and are not
    themselves optimised. Without the detach, any ``weight_alpha != 0`` lets
    the encoder lower the loss by shrinking (alpha > 0) or inflating
    (alpha < 0) ``|zm1 - zm0|`` purely to change the weighting.

    Args:
        model: Object exposing ``z_macro(x, macro)`` and ``drift(z)``.
        x0, x1: States at consecutive timesteps.
        m0, m1: Macro features at the same timesteps.
        dt_scalar: Time step used for the finite difference.
        weight_alpha: Exponent of the residual weights.
        weight_eps: Offset added to ``|zm1 - zm0|`` before exponentiation.

    Returns:
        Scalar loss tensor (mean over all elements).
    """
    zm0 = model.z_macro(x0, m0)
    zm1 = model.z_macro(x1, m1)
    dzdt_pred = model.drift(zm0)
    delta = zm1 - zm0
    dzdt_true = delta / dt_scalar
    weights = (delta.detach().abs() + weight_eps).pow(weight_alpha)
    return (weights * (dzdt_pred - dzdt_true).pow(2)).mean()


def run_epoch(
    model: JointDynamicsModel,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer | None,
    device: torch.device,
    dt_scalar: float,
    weight_alpha: float,
    weight_eps: float,
) -> float:
    """Run one pass over ``loader``; train if ``optimizer`` is given.

    With an optimizer, the model is put in train mode and updated after every
    batch. Without one, the model is put in eval mode and autograd is disabled
    for the whole pass, so no graph is built even if the caller forgot
    ``torch.no_grad()``.

    Returns:
        The per-sample mean loss over the epoch, or NaN if the loader yields
        no batches. Each batch's mean loss is weighted by its batch size, so
        a short final batch (``drop_last=False``) is not over-weighted.
    """
    is_train = optimizer is not None
    if is_train:
        model.train()
    else:
        model.eval()

    loss_sum = 0.0
    n_samples = 0
    with torch.set_grad_enabled(is_train):
        for x0, x1, m0, m1 in loader:
            x0 = x0.to(device, non_blocking=True)
            x1 = x1.to(device, non_blocking=True)
            m0 = m0.to(device, non_blocking=True)
            m1 = m1.to(device, non_blocking=True)

            if is_train:
                optimizer.zero_grad()
            loss = compute_loss(
                model=model,
                x0=x0,
                x1=x1,
                m0=m0,
                m1=m1,
                dt_scalar=dt_scalar,
                weight_alpha=weight_alpha,
                weight_eps=weight_eps,
            )
            if is_train:
                loss.backward()
                optimizer.step()

            batch_n = x0.shape[0]
            loss_sum += loss.item() * batch_n
            n_samples += batch_n

    return loss_sum / n_samples if n_samples else float("nan")


def _close_handlers(logger: logging.Logger) -> None:
    """Detach and close every handler currently attached to ``logger``."""
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()


def _make_file_logger(log_path: str) -> logging.Logger:
    """Return the logger named after ``log_path`` with one fresh INFO file handler.

    The file is opened in ``"w"`` mode (truncated). Handlers left on the
    same-named logger by an earlier call are closed, not just dropped, so
    their file descriptors are released.
    """
    logger = logging.getLogger(f"TrainJoint:{log_path}")
    logger.setLevel(logging.INFO)
    _close_handlers(logger)
    file_handler = logging.FileHandler(log_path, mode="w")
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
    logger.addHandler(file_handler)
    return logger


def train_joint_model(
    train_loader: DataLoader,
    val_loader: DataLoader,
    model: JointDynamicsModel,
    dt_scalar: float,
    n_epochs: int,
    lr: float,
    weight_alpha: float,
    weight_eps: float,
    save_dir: str,
    normalization_info: dict,
    device: torch.device,
    args_dict: dict,
) -> None:
    """Train ``model`` with AdamW and keep the checkpoint with the lowest val loss.

    Per-epoch train and validation losses are logged to
    ``<save_dir>/train_joint.log``. Whenever the validation loss improves,
    ``<save_dir>/best_joint_model.pth`` is overwritten with a dict containing
    ``model_state_dict``, ``normalization_info``, ``best_val_loss`` and
    ``args``. The log file is closed when training ends, even on error.
    """
    os.makedirs(save_dir, exist_ok=True)
    log_path = os.path.join(save_dir, "train_joint.log")
    logger = _make_file_logger(log_path)
    print(f"Logging to {log_path}")
    try:
        logger.info("Starting joint training")

        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        best_val_loss = float("inf")
        save_path = os.path.join(save_dir, "best_joint_model.pth")

        for epoch in tqdm(range(1, n_epochs + 1), desc="Epochs"):
            train_loss = run_epoch(
                model=model,
                loader=train_loader,
                optimizer=optimizer,
                device=device,
                dt_scalar=dt_scalar,
                weight_alpha=weight_alpha,
                weight_eps=weight_eps,
            )
            with torch.no_grad():
                val_loss = run_epoch(
                    model=model,
                    loader=val_loader,
                    optimizer=None,
                    device=device,
                    dt_scalar=dt_scalar,
                    weight_alpha=weight_alpha,
                    weight_eps=weight_eps,
                )

            msg = f"Epoch {epoch:4d} | Train Loss: {train_loss:8.6f} | Val Loss: {val_loss:8.6f}"
            logger.info(msg)

            # A NaN validation loss never compares below best_val_loss, so it
            # never triggers a save.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                checkpoint = {
                    "model_state_dict": model.state_dict(),
                    "normalization_info": normalization_info,
                    "best_val_loss": best_val_loss,
                    "args": args_dict,
                }
                torch.save(checkpoint, save_path)
                logger.info(f"Saved checkpoint to {save_path}")

        if best_val_loss == float("inf"):
            logger.warning(
                "No checkpoint was saved: validation loss never dropped below inf "
                "(NaN losses or an empty validation set?)"
            )
        logger.info("Training finished")
    finally:
        _close_handlers(logger)


def _next_experiment_dir(root: str, seed: int) -> str:
    """Return ``<root>/exp<k>_seed<seed>`` for an index ``k`` unused in ``root``.

    ``k`` starts at (number of entries in ``root``) + 1, the historical
    numbering, and is bumped past any existing ``exp<k>_*`` entry. A deleted
    run or a stray file can therefore no longer make a new run reuse, and
    overwrite, an existing experiment directory.
    """
    entries = set(os.listdir(root))
    exp_idx = len(entries) + 1
    while any(name.startswith(f"exp{exp_idx}_") for name in entries):
        exp_idx += 1
    return os.path.join(root, f"exp{exp_idx}_seed{seed}")


def main() -> None:
    """Parse CLI arguments, load data, build the model and run joint training."""
    parser = argparse.ArgumentParser(
        description="Train DeepSet encoder and drift MLP together using trajectories and macro features."
    )
    parser.add_argument("--data_dir", type=str, default=os.path.join(BASE_DIR, "generate_dataset", "data"))
    parser.add_argument("--trajectories_path", type=str, default=None)
    parser.add_argument("--macro_feature_path", type=str, default=None)
    parser.add_argument("--n_traj", type=int, default=None, help="keep only the first N trajectories")
    parser.add_argument(
        "--train_frac",
        type=float,
        default=0.8,
        help="fraction of trajectories (taken from the front) used for training",
    )

    parser.add_argument("--z_dim", type=int, default=8)
    parser.add_argument("--encoder_hidden_dim", type=int, default=128)
    parser.add_argument("--drift_hidden_dim", type=int, default=64)
    parser.add_argument("--pool", type=str, default="mean", choices=["mean", "sum", "max"])

    parser.add_argument(
        "--dt",
        type=float,
        default=0.002,
        help="time between consecutive trajectory frames (finite-difference step)",
    )
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--batch_size", type=int, default=1024)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument(
        "--weight_alpha",
        type=float,
        default=0.0,
        help="exponent of the (|dz| + weight_eps) loss weights; 0 gives plain MSE",
    )
    parser.add_argument("--weight_eps", type=float, default=1e-6)
    parser.add_argument("--save_dir", type=str, default="./trained_models_gmm2")

    parser.add_argument(
        "--device",
        type=int,
        default=3,
        help="CUDA device index; a negative value (or no CUDA) selects the CPU",
    )
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()
    set_seed(args.seed)
    device = resolve_device(args.device)
    print(f"Using device: {device}")

    trajectories, macro_feature = load_inputs(
        data_dir=args.data_dir,
        trajectories_path=args.trajectories_path,
        macro_feature_path=args.macro_feature_path,
        n_traj=args.n_traj,
    )
    n_traj, T, n_particles, data_dim = trajectories.shape
    macro_dim = macro_feature.shape[-1]
    print(f"trajectories shape: {trajectories.shape}")
    print(f"macro_feature shape: {macro_feature.shape}")

    pin_memory = device.type == "cuda"
    train_loader, val_loader, normalization_info = make_dataloaders(
        trajectories=trajectories,
        macro_feature=macro_feature,
        train_frac=args.train_frac,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )

    model = JointDynamicsModel(
        in_dim=data_dim,
        z_dim=args.z_dim,
        macro_dim=macro_dim,
        encoder_hidden_dim=args.encoder_hidden_dim,
        drift_hidden_dim=args.drift_hidden_dim,
        pool=args.pool,
    ).to(device)

    # Stored in the checkpoint so the run can be reconstructed later.
    args_dict = vars(args).copy()
    args_dict.update(
        {
            "data_dim": data_dim,
            "n_particles": n_particles,
            "macro_dim": macro_dim,
            "T": T,
            "n_traj": n_traj,
        }
    )
    os.makedirs(args.save_dir, exist_ok=True)
    save_dir = _next_experiment_dir(args.save_dir, args.seed)

    train_joint_model(
        train_loader=train_loader,
        val_loader=val_loader,
        model=model,
        dt_scalar=args.dt,
        n_epochs=args.epochs,
        lr=args.lr,
        weight_alpha=args.weight_alpha,
        weight_eps=args.weight_eps,
        save_dir=save_dir,
        normalization_info=normalization_info,
        device=device,
        args_dict=args_dict,
    )


if __name__ == "__main__":
    main()
