import os
import logging
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

import matplotlib.pyplot as plt

from utils.utils import set_seed
from utils.model import MLPDecoder
from models import SetEncoderND

# python -m pip install chamferdist
from chamferdist import ChamferDistance


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

def _should_record(time_step: int) -> bool:
    """
    Return True if the 0-based ``time_step`` should be kept when subsampling.

    Using the 1-based counter ``n = time_step + 1``: every step with
    ``n <= 300`` is kept; beyond that, only steps where ``n`` is a multiple
    of 20 are kept.
    """
    n = time_step + 1
    return n <= 300 or n % 20 == 0


def _load_input_arrays(data_path: str) -> tuple[np.ndarray, np.ndarray]:
    """
    Load the 'trajectories' and 'types' arrays from an NPZ archive.

    Args:
        data_path: Path to a ``.npz`` archive.

    Returns:
        ``(trajectories, types)`` as in-memory NumPy arrays.

    Raises:
        ValueError: if ``np.load`` does not return an NPZ archive (e.g. a ``.npy`` file).
        KeyError: if 'trajectories' or 'types' is missing from the archive.
    """
    # NOTE(security): allow_pickle=True permits unpickling object arrays, which can
    # execute arbitrary code; only load files from trusted sources.
    data = np.load(data_path, allow_pickle=True)
    if not isinstance(data, np.lib.npyio.NpzFile):
        raise ValueError("Expected an NPZ file with 'trajectories' and 'types'.")
    # The context manager closes the underlying zip file on every exit path,
    # including the KeyError branches (previously the handle leaked there).
    with data:
        if "trajectories" not in data.files:
            raise KeyError("NPZ file must contain 'trajectories'.")
        if "types" not in data.files:
            raise KeyError("NPZ file must contain 'types'.")
        # Indexing an NpzFile reads the member into memory, so the arrays stay
        # valid after the archive is closed.
        trajectories = data["trajectories"]
        types = data["types"]
    return trajectories, types


def make_dataloaders(
    data_path: str,
    bs: int,
    n_traj: int,
    device: str | torch.device = "cpu",
):
    """
    Load one NPZ file, normalize it, split it 80/20 by trajectory into
    train/validation sets, and wrap each split in a TensorDataset/DataLoader.

    The NPZ must contain 'trajectories' of shape [E, T, M, D] and 'types' of
    shape [E, M] with values in {1, 2}. Only the first ``n_traj`` trajectories
    and the timesteps accepted by ``_should_record`` are used. Each
    (trajectory, selected timestep) pair becomes one sample of shape (M, D),
    paired with that trajectory's (M,) type vector, i.e. the data is reshaped
    to (E*T_sel, M, D) and the types are repeated per timestep.

    Args:
        data_path: Path to the NPZ file.
        bs: Batch size for both loaders.
        n_traj: Number of leading trajectories to use.
        device: If not a CPU device, all tensors are moved to it up front.

    Returns:
        (train_loader, val_loader, dim, n_anchors, input_dim, normalization_info)
        where ``normalization_info`` holds the per-dimension ``min``/``max``
        NumPy arrays (shape (1, 1, 1, D)) used to map the data into [-1, 1].

    Raises:
        ValueError: on inconsistent shapes, unexpected type labels, or when
            ``n_traj`` is too small to give non-empty train and val splits.
    """
    if n_traj <= 0:
        raise ValueError(f"n_traj must be positive, got {n_traj}")

    input_data, types = _load_input_arrays(data_path)
    # Explicit exceptions rather than `assert`, which `python -O` strips.
    if input_data.ndim != 4:
        raise ValueError(
            f"trajectories must be 4D (num_traj, T, N, D), got shape {input_data.shape}"
        )
    if input_data.shape[0] < n_traj:
        raise ValueError(f"data has {input_data.shape[0]} trajectories, but n_traj={n_traj}")
    input_data = input_data[:n_traj]  # (E, T, M, D)
    if types.ndim != 2:
        raise ValueError(f"types must be 2D (num_traj, N), got shape {types.shape}")
    if types.shape[0] < n_traj:
        raise ValueError(f"types has {types.shape[0]} trajectories, but n_traj={n_traj}")
    types = types[:n_traj]
    if types.shape[1] != input_data.shape[2]:
        raise ValueError(
            f"types has N={types.shape[1]}, but data has N={input_data.shape[2]}"
        )
    unique_types = np.unique(types)
    if not np.all(np.isin(unique_types, [1, 2])):
        raise ValueError(f"types must contain only 1 and 2, got {unique_types}")
    # Labels may be stored as floats or objects (the file is loaded with
    # allow_pickle); torch.from_numpy below needs a plain numeric dtype.
    types = types.astype(np.int64, copy=False)

    print(f"input_data shape: {input_data.shape}")
    print(f"types shape: {types.shape}")

    # Select timesteps according to _should_record
    selected_timesteps = [t for t in range(input_data.shape[1]) if _should_record(t)]
    if not selected_timesteps:
        raise ValueError("trajectories contain no timesteps")
    input_data = input_data[:, selected_timesteps, :, :]  # (E, T_sel, M, D)
    print(f"Selected {len(selected_timesteps)} timesteps: {selected_timesteps}")
    print(f"input_data shape after timestep selection: {input_data.shape}")

    # Normalize each coordinate dimension to [-1, 1]; constant dimensions get
    # scale 1 to avoid division by zero.
    # NOTE(review): min/max are computed over train AND validation trajectories,
    # so validation statistics leak into the scaling. Kept as-is to preserve the
    # normalization stored in checkpoints.
    data_min = input_data.min(axis=(0, 1, 2), keepdims=True)
    data_max = input_data.max(axis=(0, 1, 2), keepdims=True)
    scale = data_max - data_min
    scale[scale == 0] = 1.0
    norm_input_data = 2.0 * (input_data - data_min) / scale - 1.0

    normalization_info = {
        "min": data_min,
        "max": data_max,
    }
    print(f"Data normalization info: {normalization_info}")

    # Split by trajectory so no trajectory contributes to both sets.
    n_exp = norm_input_data.shape[0]
    train_num = int(0.8 * n_exp)
    # Both splits must be non-empty; an empty loader makes the epoch loss NaN.
    if train_num == 0 or train_num == n_exp:
        raise ValueError(
            f"n_traj={n_exp} is too small to give non-empty 80/20 train/val splits"
        )
    train_pos = norm_input_data[:train_num]   # (E_train, T, M, D)
    val_pos = norm_input_data[train_num:]     # (E_val, T, M, D)
    train_types = types[:train_num]
    val_types = types[train_num:]

    print(
        f"train x,y max: {train_pos.max(axis=(0, 1, 2))}, "
        f"min: {train_pos.min(axis=(0, 1, 2))}"
    )
    print(
        f"val   x,y max: {val_pos.max(axis=(0, 1, 2))}, "
        f"min: {val_pos.min(axis=(0, 1, 2))}"
    )

    # Reshape: (E,T_sel,M,D) -> (E*T_sel, M, D)
    train_anchors = torch.from_numpy(
        train_pos.reshape(-1, train_pos.shape[2], train_pos.shape[3])
    ).float()
    val_anchors = torch.from_numpy(
        val_pos.reshape(-1, val_pos.shape[2], val_pos.shape[3])
    ).float()

    # Repeat types per timestep: (E, M) -> (E, T_sel, M) -> (E*T_sel, M)
    train_types = np.repeat(train_types[:, None, :], train_pos.shape[1], axis=1)
    val_types = np.repeat(val_types[:, None, :], val_pos.shape[1], axis=1)
    train_types = torch.from_numpy(train_types.reshape(-1, train_types.shape[2])).long()
    val_types = torch.from_numpy(val_types.reshape(-1, val_types.shape[2])).long()

    dim = train_pos.shape[3]
    n_anchors = train_pos.shape[2]

    # For autoencoder, we also need the flattened input dimension
    input_dim = n_anchors * dim

    # Free large numpy arrays once tensors are created
    del input_data, norm_input_data, train_pos, val_pos

    device = torch.device(device)
    if device.type != "cpu":
        # Preload everything onto the accelerator so batches need no transfer.
        print(f"Preloading data to {device} for faster training...")
        train_anchors = train_anchors.to(device)
        val_anchors = val_anchors.to(device)
        train_types = train_types.to(device)
        val_types = val_types.to(device)
        use_pin_memory = False
    else:
        # Pinned host memory only speeds up host->GPU copies (e.g. a caller that
        # moves CPU batches to a GPU); without CUDA it just triggers a warning.
        use_pin_memory = torch.cuda.is_available()

    # Wrap in datasets/loaders; each item is one (M, D) anchor set plus its (M,) types.
    train_ds = TensorDataset(train_anchors, train_types)
    val_ds = TensorDataset(val_anchors, val_types)
    print(f"train_anchors shape: {train_anchors.shape}, val_anchors shape: {val_anchors.shape}")
    print(f"train_types shape: {train_types.shape}, val_types shape: {val_types.shape}")

    train_loader = DataLoader(train_ds, batch_size=bs, pin_memory=use_pin_memory, shuffle=True, drop_last=False)
    # A fixed validation order makes the per-epoch validation loss reproducible.
    val_loader = DataLoader(val_ds, batch_size=bs, pin_memory=use_pin_memory, shuffle=False, drop_last=False)

    return train_loader, val_loader, dim, n_anchors, input_dim, normalization_info


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------

class TypeSplitDeepSetAutoencoder(nn.Module):
    """
    Autoencoder that runs one weight-shared DeepSet encoder twice per sample:
    once restricted (via mask) to type-1 particles and once to type-2 particles.
    The two per-type codes are concatenated and an MLP decodes them into all
    ``n_points`` particle positions; the decoder output carries no type labels.

    Args:
        in_dim: Coordinate dimension D of each particle.
        n_points: Number of particles reconstructed by the decoder.
        hidden_dim: Hidden width passed to both encoder and decoder.
        z_dim: Latent size requested from the encoder for each type.
        pool: Pooling mode forwarded to ``SetEncoderND``.
    """

    def __init__(
        self,
        in_dim: int,
        n_points: int,
        hidden_dim: int = 256,
        z_dim: int = 16,
        pool: str = "mean",
    ):
        super().__init__()
        # Plain settings; forward() uses n_points/in_dim to reshape the output.
        self.in_dim, self.n_points, self.z_dim = in_dim, n_points, z_dim
        # The encoder is registered before the decoder, keeping state_dict key order.
        self.encoder = SetEncoderND(in_dim=in_dim, hidden_dim=hidden_dim, z_dim=z_dim, pool=pool)
        concat_dim = 2 * z_dim           # two per-type codes side by side
        flat_out_dim = n_points * in_dim  # every particle's coordinates, flattened
        self.decoder = MLPDecoder(concat_dim, hidden_dim, flat_out_dim)

    def forward(self, x: torch.Tensor, types: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode each particle type separately, then decode the joint code.

        Args:
            x: Particle positions, (batch_size, M, D).
            types: Per-particle labels with values {1, 2}, (batch_size, M).

        Returns:
            (z, x_recon): ``z`` is the type-1 code concatenated with the type-2
            code along the last dimension; ``x_recon`` has shape
            (batch_size, n_points, in_dim).

        Raises:
            ValueError: if ``types`` is not 2-D matching ``x``'s first two dims,
                or if any sample lacks particles of either type.
        """
        shapes_match = types.dim() == 2 and types.shape[:2] == x.shape[:2]
        if not shapes_match:
            raise ValueError(f"types must have shape {x.shape[:2]}, got {tuple(types.shape)}")

        per_type_masks = [types == label for label in (1, 2)]

        # Reject samples where a type is absent (would leave an empty set to pool).
        if any((mask.sum(dim=1) == 0).any() for mask in per_type_masks):
            raise ValueError("Each sample must contain at least one particle of each type.")

        codes = [self.encoder(x, mask=mask) for mask in per_type_masks]
        z = torch.cat(codes, dim=-1)

        batch_size = x.shape[0]
        x_recon = self.decoder(z).view(batch_size, self.n_points, self.in_dim)
        return z, x_recon


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------

def _make_file_logger(log_path: str) -> tuple[logging.Logger, logging.FileHandler]:
    """Return a logger writing INFO+ records to ``log_path`` (truncated) and its file handler."""
    logger = logging.getLogger(f"TrainDeepSetChamferAutoencoder:{log_path}")
    logger.setLevel(logging.INFO)
    # getLogger returns the same object for the same name, so remove AND close any
    # handlers left from an earlier call (avoids duplicate lines and leaked fds).
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()
    fh = logging.FileHandler(log_path, mode="w")
    fh.setLevel(logging.INFO)
    fh.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
    logger.addHandler(fh)
    return logger, fh


def _to_device(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Move ``tensor`` to ``device`` unless it is already there (data may be preloaded)."""
    if tensor.device == device:
        return tensor
    return tensor.to(device, non_blocking=True)


def _train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: optim.Optimizer,
    chamfer,
    device: torch.device,
) -> float:
    """
    Run one optimisation pass over ``loader`` and return the per-sample mean loss.

    Each batch loss is weighted by its batch size so a smaller final batch does
    not skew the epoch average; this assumes ``chamfer`` averages over the batch
    (chamferdist's default batch reduction) — TODO confirm for the installed
    version. Returns NaN if ``loader`` is empty.
    """
    model.train()
    loss_sum, n_samples = 0.0, 0
    for anchors_batch, types_batch in loader:
        anchors_batch = _to_device(anchors_batch, device)
        types_batch = _to_device(types_batch, device)

        _, x_recon = model(anchors_batch, types_batch)
        loss = chamfer(x_recon, anchors_batch, bidirectional=True)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = anchors_batch.shape[0]
        loss_sum += loss.item() * batch_size
        n_samples += batch_size
    return loss_sum / n_samples if n_samples else float("nan")


@torch.no_grad()
def _evaluate(model: nn.Module, loader: DataLoader, chamfer, device: torch.device) -> float:
    """Return the batch-size-weighted mean Chamfer loss on ``loader`` (NaN if empty)."""
    model.eval()
    loss_sum, n_samples = 0.0, 0
    for anchors_batch, types_batch in loader:
        anchors_batch = _to_device(anchors_batch, device)
        types_batch = _to_device(types_batch, device)
        _, x_recon = model(anchors_batch, types_batch)
        batch_size = anchors_batch.shape[0]
        loss_sum += chamfer(x_recon, anchors_batch, bidirectional=True).item() * batch_size
        n_samples += batch_size
    return loss_sum / n_samples if n_samples else float("nan")


def _save_loss_plot(train_history: list[float], val_history: list[float], plt_path: str) -> None:
    """Write the train/validation loss curves recorded so far to ``plt_path``."""
    epochs = range(1, len(train_history) + 1)
    plt.figure()
    plt.plot(epochs, train_history, label="Train Loss")
    plt.plot(epochs, val_history, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Chamfer Loss")
    plt.title("DeepSet Autoencoder Training and Validation Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(plt_path)
    plt.close()


def train_deepset_autoencoder(
    data_path: str,
    n_traj: int,
    dim: int | None = None,
    bs: int = 32,
    z_dim: int = 16,
    hidden_dim: int = 256,
    pool: str = "mean",
    lr: float = 1e-3,
    num_epochs: int = 100,
    seed: int = 42,
    device: str | torch.device | None = None,
    save_path: str = "deepset_Chamfer_autoencoder.pth",
    args_dict: dict | None = None,
):
    """
    Train a DeepSet autoencoder on trajectory data with type-split encoding,
    using a bidirectional Chamfer reconstruction loss and Adam.

    Args:
        data_path: NPZ file with 'trajectories' and 'types' (see make_dataloaders).
        n_traj: Number of leading trajectories to use (split 80/20 train/val).
        dim: Expected coordinate dimension. The actual dimension is read from
            the data; a mismatch is only logged as a warning.
        bs: Batch size.
        z_dim: Latent size per particle type (the decoder sees 2 * z_dim).
        hidden_dim: Hidden width of encoder and decoder.
        pool: Pooling type for DeepSet encoder (mean/max/sum).
        lr: Adam learning rate.
        num_epochs: Number of training epochs.
        seed: Random seed for reproducibility.
        device: Torch device; defaults to "cuda:0" when CUDA is available, else CPU.
        save_path: Where the checkpoint with the lowest validation loss is
            written. Its directory (or "." for a bare filename / None) also
            receives the log file and the loss-curve plot. Falsy disables saving.
        args_dict: Optional CLI arguments stored in the checkpoint under "args".
    """
    # Set seed for reproducibility
    set_seed(seed)

    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Data (preloaded onto the device when it is not the CPU)
    train_loader, val_loader, data_dim, n_anchors, input_dim, normalization_info = make_dataloaders(
        data_path=data_path,
        bs=bs,
        n_traj=n_traj,
        device=device,
    )

    # A bare filename (like the default save_path) has dirname "", which
    # os.makedirs rejects; fall back to the current directory.
    log_dir = (os.path.dirname(save_path) if save_path else "") or "."
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, f"train_deepset_Chamfer_autoencoder_Z{z_dim}.log")
    logger, fh = _make_file_logger(log_path)

    try:
        print(f"Logging to {log_path}")
        logger.info("Starting DeepSet Chamfer autoencoder training (type-split encoder)")
        logger.info(f"Pool type: {pool}")
        logger.info(f"Input dimension: {input_dim} (n_anchors={n_anchors}, dim={data_dim})")
        logger.info(f"z_dim (per type): {z_dim}")
        if dim is not None and dim != data_dim:
            logger.warning(f"Requested dim={dim} differs from data dimension {data_dim}; using {data_dim}")

        model = TypeSplitDeepSetAutoencoder(
            in_dim=data_dim,
            n_points=n_anchors,
            hidden_dim=hidden_dim,
            z_dim=z_dim,
            pool=pool,
        ).to(device)

        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"Model parameters: {num_params:,}")

        optimizer = optim.Adam(model.parameters(), lr=lr)
        chamfer = ChamferDistance()
        plt_path = os.path.join(log_dir, f"loss_curve_Z{z_dim}.png")

        best_val_loss = float("inf")
        train_loss_history: list[float] = []
        val_loss_history: list[float] = []

        for epoch in tqdm(range(1, num_epochs + 1)):
            mean_train_loss = _train_one_epoch(model, train_loader, optimizer, chamfer, device)
            mean_val_loss = _evaluate(model, val_loader, chamfer, device)
            train_loss_history.append(mean_train_loss)
            val_loss_history.append(mean_val_loss)
            logger.info(
                f"Epoch {epoch:4d} | "
                f"Train Loss: {mean_train_loss:8.6f} | "
                f"Val Loss: {mean_val_loss:8.6f}"
            )

            # Track the best loss even when not saving, so the final summary is meaningful.
            if mean_val_loss < best_val_loss:
                best_val_loss = mean_val_loss
                if save_path:
                    checkpoint = {
                        "model_state_dict": model.state_dict(),
                        "input_dim": input_dim,
                        "n_anchors": n_anchors,
                        "data_dim": data_dim,
                        "z_dim": z_dim,
                        "concat_z_dim": z_dim * 2,
                        "hidden_dim": hidden_dim,
                        "pool": pool,
                        "lr": lr,
                        "bs": bs,
                        "num_epochs": num_epochs,
                        # NOTE(review): these min/max values are NumPy arrays;
                        # torch.load with weights_only=True (default since torch 2.6)
                        # presumably rejects them — confirm how checkpoints are loaded.
                        "normalization_info": normalization_info,
                        "n_traj": n_traj,
                        "best_val_loss": best_val_loss,
                    }
                    if args_dict is not None:
                        checkpoint["args"] = args_dict
                    torch.save(checkpoint, save_path)
                    logger.info(f"Saved checkpoint to {save_path}")

            # Refresh the loss plot every 5 epochs and always after the last epoch.
            if epoch % 5 == 0 or epoch == num_epochs:
                _save_loss_plot(train_loss_history, val_loss_history, plt_path)

        logger.info("Training finished")
        logger.info(f"Best validation loss: {best_val_loss:8.6f}")
    finally:
        # Release the log file even if training raises.
        logger.removeHandler(fh)
        fh.close()


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _next_experiment_index(root: str) -> int:
    """
    Return 1 + the largest N among existing ``exp<N>`` subdirectories of ``root``.

    Returns 1 if there are none. Counting every directory entry instead would
    pick an index that collides with an existing run after a deletion (e.g.
    exp1 removed with exp2/exp3 left gives 2 entries, hence "exp3"), silently
    writing into that run's directory.
    """
    import re

    highest = 0
    for entry in os.listdir(root):
        found = re.fullmatch(r"exp(\d+)", entry)
        if found and os.path.isdir(os.path.join(root, entry)):
            highest = max(highest, int(found.group(1)))
    return highest + 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a DeepSet autoencoder with Chamfer loss and type-split encoder")

    parser.add_argument("--data_path", type=str, default="generate_data/dataset/trajectories_large.npz")
    parser.add_argument("--n_traj", type=int, default=10000, help="number of trajectories to use")

    parser.add_argument("--dim", type=int, default=2, help="data dimension (D)")
    parser.add_argument("--bs", type=int, default=512, help="batch size")

    parser.add_argument("--z_dim", type=int, default=1, help="latent dimension per type")
    parser.add_argument("--hidden_dim", type=int, default=256, help="hidden layer dimension")
    parser.add_argument("--pool", type=str, default="mean", choices=["mean", "max", "sum"], help="pooling type for DeepSet encoder")

    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--num_epochs", type=int, default=100, help="number of training epochs")

    parser.add_argument("--seed", type=int, default=42, help="random seed for reproducibility")

    parser.add_argument("--device", type=int, default=0, help="GPU device index (-1 for CPU)")

    parser.add_argument("--save_dir", type=str, default="trained_deepset_Chamfer_autoencoder/", help="directory to save checkpoints")

    args = parser.parse_args()

    # Resolve device: negative index or no CUDA means CPU.
    resolved_device = f"cuda:{args.device}" if torch.cuda.is_available() and args.device >= 0 else "cpu"

    # Build a fresh experiment directory: <save_dir>/exp<N>/pool_*/z_dim_*/hidden_*/traj_*
    os.makedirs(args.save_dir, exist_ok=True)
    experiment_index = _next_experiment_index(args.save_dir)
    save_dir = os.path.join(
        args.save_dir,
        f"exp{experiment_index}",
        f"pool_{args.pool}",
        f"z_dim_{args.z_dim}",
        f"hidden_{args.hidden_dim}",
        f"traj_{args.n_traj}",
    )
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"best_deepset_Chamfer_autoencoder_Z{args.z_dim}.pth")

    train_deepset_autoencoder(
        data_path=args.data_path,
        n_traj=args.n_traj,
        dim=args.dim,
        bs=args.bs,
        z_dim=args.z_dim,
        hidden_dim=args.hidden_dim,
        pool=args.pool,
        lr=args.lr,
        num_epochs=args.num_epochs,
        seed=args.seed,
        device=resolved_device,
        save_path=save_path,
        args_dict=vars(args),
    )
