import os
import logging
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

import matplotlib.pyplot as plt
import sys
sys.path.append('..')
from utils.utils import set_seed
from utils.model import DeepSetAutoencoder

# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def _should_record(time_step: int) -> bool:
    i = time_step + 1
    if i <= 100:
        return True
    elif i <= 150:
        return (i % 2) == 0
    elif i <= 200:
        return (i % 5) == 0
    else:
        return (i % 10) == 0


def make_dataloaders(
    data_path: str,
    bs: int,
    n_traj: int,
    device: str | torch.device = "cpu",
):
    """
    Load trajectory data, normalize it, split it 80/20 into train/validation,
    and wrap each split in a TensorDataset/DataLoader.

    The file at ``data_path`` must contain a numpy array of shape [E, T, M, D]:
      - E: number of experiments (trajectories)
      - T: number of timesteps
      - M: number of anchors per experiment
      - D: data dimension

    Only timesteps accepted by ``_should_record`` are kept. Each data dimension
    is normalized to zero mean / unit std using statistics pooled over the
    first ``n_traj`` experiments (both splits), all kept timesteps and all
    anchors. The arrays are then flattened so that every dataset item is one
    (experiment, timestep) anchor set of shape (M, D).

    Args:
        data_path: Path to the ``.npy`` file.
        bs: Batch size used by both loaders.
        n_traj: Number of leading experiments to use; must be >= 2 so that
            both the train and the validation split are non-empty.
        device: If not ``"cpu"``, both tensors are preloaded onto this device.

    Returns:
        ``(train_loader, val_loader, dim, n_anchors, input_dim,
        normalization_info)`` where ``dim`` is D, ``n_anchors`` is M,
        ``input_dim`` is M * D and ``normalization_info`` holds the ``"mean"``
        and ``"std"`` arrays (shape (1, 1, 1, D)) used for normalization.

    Raises:
        ValueError: if ``n_traj`` < 2, the array is not 4-D, it has fewer than
            ``n_traj`` experiments, or it has no timesteps.
    """
    if n_traj < 2:
        raise ValueError(
            f"n_traj={n_traj} is too small for an 80/20 train/validation split (need >= 2)"
        )

    # NOTE(review): allow_pickle=True lets np.load run arbitrary code embedded
    # in a malicious file and is only needed for object arrays; load trusted
    # data only.
    input_data = np.load(data_path, allow_pickle=True)
    if input_data.ndim != 4:
        raise ValueError(
            f"expected an array of shape [E, T, M, D], got shape {input_data.shape}"
        )
    if input_data.shape[0] < n_traj:
        raise ValueError(
            f"data has {input_data.shape[0]} trajectories, but n_traj={n_traj}"
        )
    input_data = input_data[:n_traj]  # (E, T, M, D)
    print(f"input_data shape: {input_data.shape}")

    # Select timesteps according to _should_record
    selected_timesteps = [t for t in range(input_data.shape[1]) if _should_record(t)]
    if not selected_timesteps:
        raise ValueError(
            f"no timesteps selected from {input_data.shape[1]} available timesteps"
        )
    input_data = input_data[:, selected_timesteps, :, :]  # (E, T_sel, M, D)
    print(f"Selected {len(selected_timesteps)} timesteps: {selected_timesteps}")
    print(f"input_data shape after timestep selection: {input_data.shape}")

    # Normalize every data dimension (last axis) to zero mean, unit std.
    mean = input_data.mean(axis=(0, 1, 2), keepdims=True)
    std = input_data.std(axis=(0, 1, 2), keepdims=True)
    # A constant dimension has std == 0; dividing by it would fill the data
    # with NaN/inf. Use 1 instead so that dimension is only mean-centred.
    std = np.where(std == 0, np.ones_like(std), std)
    norm_input_data = (input_data - mean) / std

    normalization_info = {
        "mean": mean,
        "std": std,
    }
    print(f"Data normalization info: {normalization_info}")

    # Deterministic 80/20 split along the experiment axis. n_traj >= 2 (checked
    # above) guarantees both splits contain at least one experiment.
    train_num = int(0.8 * norm_input_data.shape[0])
    train_pos = norm_input_data[:train_num]   # (E_train, T_sel, M, D)
    val_pos = norm_input_data[train_num:]     # (E_val, T_sel, M, D)

    print(
        f"train x,y max: {train_pos.max(axis=(0, 1, 2))}, "
        f"min: {train_pos.min(axis=(0, 1, 2))}"
    )
    print(
        f"val   x,y max: {val_pos.max(axis=(0, 1, 2))}, "
        f"min: {val_pos.min(axis=(0, 1, 2))}"
    )

    # Reshape: (E, T_sel, M, D) -> (E*T_sel, M, D)
    train_anchors = torch.from_numpy(
        train_pos.reshape(-1, train_pos.shape[2], train_pos.shape[3])
    ).float()
    val_anchors = torch.from_numpy(
        val_pos.reshape(-1, val_pos.shape[2], val_pos.shape[3])
    ).float()

    assert train_anchors.dim() == 3
    assert val_anchors.dim() == 3

    on_cpu = str(device) == "cpu"
    if not on_cpu:
        # Preload the whole dataset so batches need no per-step host->device copy.
        print(f"Preloading data to {device} for faster training...")
        train_anchors = train_anchors.to(device)
        val_anchors = val_anchors.to(device)
    # Pinned host memory only helps later host->GPU copies and requires CUDA;
    # without CUDA, DataLoader would just warn and ignore the request.
    use_pin_memory = on_cpu and torch.cuda.is_available()

    # Each dataset item is one (experiment, timestep) anchor set: (M, D)
    train_ds = TensorDataset(train_anchors)
    val_ds = TensorDataset(val_anchors)
    print(f"train_anchors shape: {train_anchors.shape}, val_anchors shape: {val_anchors.shape}")

    train_loader = DataLoader(train_ds, batch_size=bs, pin_memory=use_pin_memory, shuffle=True, drop_last=False)
    # Validation is not shuffled: a fixed order keeps the per-epoch validation
    # loss deterministic and does not consume the global RNG.
    val_loader = DataLoader(val_ds, batch_size=bs, pin_memory=use_pin_memory, shuffle=False, drop_last=False)

    dim = train_pos.shape[3]
    n_anchors = train_pos.shape[2]

    # Flattened per-item input size, recorded for the autoencoder checkpoint.
    input_dim = n_anchors * dim

    return train_loader, val_loader, dim, n_anchors, input_dim, normalization_info



# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------

def _close_logger_handlers(logger: logging.Logger) -> None:
    """Detach and close every handler attached to ``logger``."""
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()


def _setup_file_logger(log_path: str) -> logging.Logger:
    """Return an INFO-level logger that writes to ``log_path`` (truncating it)."""
    logger = logging.getLogger(f"TrainAutoencoder:{log_path}")
    logger.setLevel(logging.INFO)
    # Remove AND close handlers left by an earlier run with the same log path;
    # simply reassigning ``logger.handlers = []`` would leak their open files.
    _close_logger_handlers(logger)
    fh = logging.FileHandler(log_path, mode="w")
    fh.setLevel(logging.INFO)
    fh.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
    logger.addHandler(fh)
    return logger


def _plot_loss_curves(train_hist: list[float], val_hist: list[float], plt_path: str) -> None:
    """Save a train/validation loss-vs-epoch plot to ``plt_path``."""
    epochs = range(1, len(train_hist) + 1)
    plt.figure()
    plt.plot(epochs, train_hist, label="Train Loss")
    plt.plot(epochs, val_hist, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("MSE Loss")
    plt.title("Autoencoder Training and Validation Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(plt_path)
    plt.close()


def train_deepset_autoencoder(
    data_path: str,
    n_traj: int,
    dim: int | None = None,
    bs: int = 32,
    z_dim: int = 16,
    hidden_dim: int = 256,
    pool: str = "mean",
    lr: float = 1e-3,
    num_epochs: int = 100,
    seed: int = 42,
    device: str | torch.device | None = None,
    save_path: str = "deepset_autoencoder.pth",
    args_dict: dict | None = None,
):
    """
    Train a DeepSet autoencoder on trajectory data with an MSE reconstruction loss.

    The checkpoint with the lowest validation loss is saved to ``save_path``.
    A log file and a loss-curve PNG are written next to it (or to the current
    directory when ``save_path`` has no directory part). No anchor-order
    augmentation is applied; the DeepSet encoder is assumed to be
    permutation-invariant through its pooling.

    Args:
        data_path: Path to the ``.npy`` trajectory array (see ``make_dataloaders``).
        n_traj: Number of leading trajectories to use.
        dim: Expected data dimension D. The actual value is taken from the data;
            a mismatch is only logged as a warning.
        bs: Batch size.
        z_dim: Latent dimension.
        hidden_dim: Hidden layer width of the model.
        pool: Pooling type for the DeepSet encoder (mean/max/sum).
        lr: Adam learning rate.
        num_epochs: Number of training epochs.
        seed: Random seed for reproducibility.
        device: Torch device; defaults to ``cuda:0`` when available, else CPU.
        save_path: Checkpoint path. If falsy, no checkpoint is written and logs
            go to the current directory.
        args_dict: Optional CLI arguments stored in the checkpoint under ``"args"``.
    """
    set_seed(seed)

    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Data (preloaded to the device when it is not the CPU)
    train_loader, val_loader, data_dim, n_anchors, input_dim, normalization_info = make_dataloaders(
        data_path=data_path,
        bs=bs,
        n_traj=n_traj,
        device=device,
    )

    # os.path.dirname("model.pth") is "", and os.makedirs("") raises, so fall
    # back to the current directory when save_path has no directory part.
    log_dir = (os.path.dirname(save_path) if save_path else "") or "."
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, f"train_autoencoder_Z{z_dim}.log")
    logger = _setup_file_logger(log_path)

    try:
        print(f"Logging to {log_path}")
        logger.info("Starting DeepSet autoencoder training")
        logger.info(f"Pool type: {pool}")
        logger.info(f"Input dimension: {input_dim} (n_anchors={n_anchors}, dim={data_dim})")
        logger.info(f"z_dim: {z_dim}")
        if dim is not None and dim != data_dim:
            logger.warning(f"Ignoring dim={dim}: the loaded data has dimension {data_dim}")

        model = DeepSetAutoencoder(
            in_dim=data_dim,
            n_points=n_anchors,
            hidden_dim=hidden_dim,
            z_dim=z_dim,
            pool=pool,
        ).to(device)

        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"Model parameters: {num_params:,}")

        optimizer = optim.Adam(model.parameters(), lr=lr)
        criterion = nn.MSELoss()

        best_val_loss = float("inf")
        train_loss_history: list[float] = []
        val_loss_history: list[float] = []

        for epoch in tqdm(range(1, num_epochs + 1)):
            # ---------------- training ----------------
            model.train()
            train_loss_sum, train_count = 0.0, 0
            for (x_batch,) in train_loader:
                # Skip the copy when the data was already preloaded on `device`.
                if x_batch.device != device:
                    x_batch = x_batch.to(device, non_blocking=True)

                _, x_recon = model(x_batch)
                loss = criterion(x_recon, x_batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # MSELoss averages within a batch; weight by batch size so a
                # smaller final batch does not skew the epoch mean.
                train_loss_sum += loss.item() * x_batch.shape[0]
                train_count += x_batch.shape[0]
            mean_train_loss = train_loss_sum / train_count if train_count else float("nan")

            # ---------------- validation ----------------
            model.eval()
            val_loss_sum, val_count = 0.0, 0
            with torch.no_grad():
                for (x_batch_val,) in val_loader:
                    if x_batch_val.device != device:
                        x_batch_val = x_batch_val.to(device, non_blocking=True)
                    _, x_recon_val = model(x_batch_val)
                    val_loss = criterion(x_recon_val, x_batch_val)
                    val_loss_sum += val_loss.item() * x_batch_val.shape[0]
                    val_count += x_batch_val.shape[0]
            mean_val_loss = val_loss_sum / val_count if val_count else float("nan")

            msg = (
                f"Epoch {epoch:4d} | "
                f"Train Loss: {mean_train_loss:8.6f} | "
                f"Val Loss: {mean_val_loss:8.6f}"
            )
            train_loss_history.append(mean_train_loss)
            val_loss_history.append(mean_val_loss)
            logger.info(msg)

            # Track the best loss even when no checkpoint is written, so the
            # final summary is correct for save_path=None.
            if mean_val_loss < best_val_loss:
                best_val_loss = mean_val_loss
                if save_path:
                    checkpoint = {
                        "model_state_dict": model.state_dict(),
                        "input_dim": input_dim,
                        "n_anchors": n_anchors,
                        "data_dim": data_dim,
                        "z_dim": z_dim,
                        "hidden_dim": hidden_dim,
                        "pool": pool,
                        "lr": lr,
                        "bs": bs,
                        "num_epochs": num_epochs,
                        "normalization_info": normalization_info,
                        "n_traj": n_traj,
                        "best_val_loss": best_val_loss,
                    }
                    if args_dict is not None:
                        checkpoint["args"] = args_dict
                    torch.save(checkpoint, save_path)
                    logger.info(f"Saved checkpoint to {save_path}")

            # Refresh the loss plot every 5 epochs and always after the last one.
            if epoch % 5 == 0 or epoch == num_epochs:
                plt_path = os.path.join(log_dir, f"loss_curve_Z{z_dim}.png")
                _plot_loss_curves(train_loss_history, val_loss_history, plt_path)

        logger.info("Training finished")
        logger.info(f"Best validation loss: {best_val_loss:8.6f}")
    finally:
        # Release the log file even if training raises.
        _close_logger_handlers(logger)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a DeepSet autoencoder with permutation-invariant encoder")

    parser.add_argument("--data_path", type=str, default="generate_dataset/data/trajectories.npy")
    parser.add_argument("--n_traj", type=int, default=10000, help="number of trajectories to use")

    # --dim is forwarded as-is; the trainer takes D from the loaded data.
    parser.add_argument("--dim", type=int, default=2, help="data dimension (D)")
    parser.add_argument("--bs", type=int, default=512, help="batch size")

    parser.add_argument("--z_dim", type=int, default=8, help="latent dimension")
    parser.add_argument("--hidden_dim", type=int, default=256, help="hidden layer dimension")
    parser.add_argument("--pool", type=str, default="mean", choices=["mean", "max", "sum"], help="pooling type for DeepSet encoder")

    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--num_epochs", type=int, default=200, help="number of training epochs")

    parser.add_argument("--seed", type=int, default=42, help="random seed for reproducibility")

    parser.add_argument("--device", type=int, default=0, help="GPU device index (-1 for CPU)")

    parser.add_argument("--save_dir", type=str, default="trained_deepset_autoencoder/", help="directory to save checkpoints")

    args = parser.parse_args()

    # Resolve device: a non-negative index selects that CUDA GPU when available.
    resolved_device = f"cuda:{args.device}" if torch.cuda.is_available() and args.device >= 0 else "cpu"

    # Number this run after the highest existing "exp<N>" sub-directory.
    # Counting every entry (len(os.listdir(...))) miscounts stray files and,
    # after an experiment directory is deleted, reuses -- and overwrites -- an
    # existing exp<N>.
    os.makedirs(args.save_dir, exist_ok=True)
    last_exp = 0
    for entry in os.listdir(args.save_dir):
        suffix = entry[len("exp"):]
        if (
            entry.startswith("exp")
            and suffix.isdecimal()
            and os.path.isdir(os.path.join(args.save_dir, entry))
        ):
            last_exp = max(last_exp, int(suffix))

    save_dir = os.path.join(
        args.save_dir,
        f"exp{last_exp + 1}",
        f"pool_{args.pool}",
        f"z_dim_{args.z_dim}",
        f"hidden_{args.hidden_dim}",
        f"traj_{args.n_traj}"
    )
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"best_deepset_autoencoder_Z{args.z_dim}.pth")

    train_deepset_autoencoder(
        data_path=args.data_path,
        n_traj=args.n_traj,
        dim=args.dim,
        bs=args.bs,
        z_dim=args.z_dim,
        hidden_dim=args.hidden_dim,
        pool=args.pool,
        lr=args.lr,
        num_epochs=args.num_epochs,
        seed=args.seed,
        device=resolved_device,
        save_path=save_path,
        args_dict=vars(args),
    )
