import os
import logging
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

import matplotlib.pyplot as plt
import sys
sys.path.append('..')
from utils.utils import set_seed
from utils.model import Autoencoder

# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def _should_record(time_step: int) -> bool:
    i = time_step + 1
    if i <= 100:
        return True
    elif i <= 150:
        return (i % 2) == 0
    elif i <= 200:
        return (i % 5) == 0
    else:
        return (i % 10) == 0


def make_dataloaders(
    data_path: str,
    bs: int,
    n_traj: int,
    shuffle_points: bool = True,
    device: str | torch.device = "cpu",
):
    """
    Load trajectory data, standardize it, and build train/validation DataLoaders.

    The file at ``data_path`` is assumed to contain a numpy array of shape
    [E, T, M, D], where:
      - E: number of experiments (trajectories)
      - T: number of timesteps
      - M: number of anchors per experiment
      - D: data dimension

    Processing steps:
      1. keep the first ``n_traj`` experiments;
      2. keep only the timesteps accepted by ``_should_record`` (T -> T_sel);
      3. standardize each of the D coordinates to zero mean / unit std;
      4. split along the experiment axis: first 80% train, remainder validation;
      5. reshape each split to (E*T_sel, M, D), so one dataset item is the
         anchor set (M, D) of a single (experiment, timestep) pair.

    Args:
        data_path: Path of the ``.npy`` file to load.
        bs: Batch size used for both loaders.
        n_traj: Number of leading experiments to use (must be >= 2 so that
            neither split is empty).
        shuffle_points: Accepted for interface compatibility; it is not used
            in this function. Point-order shuffling is applied by the training
            loop, not by the loaders.
        device: If this is a non-CPU device, both tensors are moved there once
            up front; otherwise the data stays on the host.

    Returns:
        Tuple ``(train_loader, val_loader, dim, n_anchors, input_dim,
        normalization_info)`` where ``dim`` is D, ``n_anchors`` is M,
        ``input_dim`` is M*D and ``normalization_info`` holds the ``"mean"``
        and ``"std"`` arrays (shape (1, 1, 1, D)) used for standardization.

    Raises:
        ValueError: If the file holds fewer than ``n_traj`` experiments, or if
            fewer than 2 experiments remain (empty training split).
    """
    # SECURITY NOTE(review): allow_pickle=True lets np.load unpickle objects,
    # which can execute arbitrary code. Only load files from trusted sources.
    input_data = np.load(data_path, allow_pickle=True)
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if input_data.shape[0] < n_traj:
        raise ValueError(
            f"data has {input_data.shape[0]} trajectories, but n_traj={n_traj}"
        )
    input_data = input_data[:n_traj]  # (E, T, M, D)
    print(f"input_data shape: {input_data.shape}")

    # 80/20 split along the experiment axis; with fewer than 2 experiments the
    # training split would be empty and fail later with an opaque error.
    n_experiments = input_data.shape[0]
    train_num = int(0.8 * n_experiments)
    if train_num < 1:
        raise ValueError(
            f"need at least 2 trajectories for an 80/20 train/val split, got {n_experiments}"
        )

    # Select timesteps according to _should_record
    selected_timesteps = [t for t in range(input_data.shape[1]) if _should_record(t)]
    input_data = input_data[:, selected_timesteps, :, :]  # (E, T_sel, M, D)
    print(f"Selected {len(selected_timesteps)} timesteps: {selected_timesteps}")
    print(f"input_data shape after timestep selection: {input_data.shape}")

    # Normalize data to zero mean, unit std in each of the D coordinates.
    # NOTE(review): the statistics are computed over train AND validation
    # experiments together; kept as-is so saved normalization_info values stay
    # comparable with existing checkpoints.
    mean = input_data.mean(axis=(0, 1, 2), keepdims=True)
    std = input_data.std(axis=(0, 1, 2), keepdims=True)
    # A constant coordinate has std == 0; dividing by it would fill the inputs
    # with NaN/inf. Using 1 leaves that (already centred) coordinate at zero.
    std = np.where(std == 0, 1.0, std)
    norm_input_data = (input_data - mean) / std

    normalization_info = {
        "mean": mean,
        "std": std,
    }
    print(f"Data normalization info: {normalization_info}")

    train_pos = norm_input_data[:train_num]   # (E_train, T_sel, M, D)
    val_pos = norm_input_data[train_num:]     # (E_val, T_sel, M, D)

    print(
        f"train x,y max: {train_pos.max(axis=(0, 1, 2))}, "
        f"min: {train_pos.min(axis=(0, 1, 2))}"
    )
    print(
        f"val   x,y max: {val_pos.max(axis=(0, 1, 2))}, "
        f"min: {val_pos.min(axis=(0, 1, 2))}"
    )

    n_anchors = train_pos.shape[2]
    dim = train_pos.shape[3]

    # Reshape: (E, T_sel, M, D) -> (E*T_sel, M, D)
    train_anchors = torch.from_numpy(train_pos.reshape(-1, n_anchors, dim)).float()
    val_anchors = torch.from_numpy(val_pos.reshape(-1, n_anchors, dim)).float()

    if torch.device(device).type != "cpu":
        # Preload everything onto the accelerator once; pinning is meaningless
        # for tensors that already live on the device.
        print(f"Preloading data to {device} for faster training...")
        train_anchors = train_anchors.to(device)
        val_anchors = val_anchors.to(device)
        use_pin_memory = False
    else:
        # Pinned host memory only helps (and is only supported) when CUDA is
        # present; requesting it on a CPU-only machine just triggers a warning.
        use_pin_memory = torch.cuda.is_available()

    # Wrap in datasets/loaders; each item is the anchor set (M, D) of one
    # (experiment, timestep) pair.
    train_ds = TensorDataset(train_anchors)
    val_ds = TensorDataset(val_anchors)
    print(f"train_anchors shape: {train_anchors.shape}, val_anchors shape: {val_anchors.shape}")

    train_loader = DataLoader(
        train_ds, batch_size=bs, pin_memory=use_pin_memory, shuffle=True, drop_last=False
    )
    # Validation is evaluated in a fixed order so the metric is reproducible
    # from epoch to epoch and does not consume the global RNG stream.
    val_loader = DataLoader(
        val_ds, batch_size=bs, pin_memory=use_pin_memory, shuffle=False, drop_last=False
    )

    # For the autoencoder we also report the flattened input dimension.
    input_dim = n_anchors * dim

    return train_loader, val_loader, dim, n_anchors, input_dim, normalization_info



# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------

def _open_file_logger(log_path: str) -> tuple[logging.Logger, logging.FileHandler]:
    """Return a logger writing INFO+ records to ``log_path`` (truncated), plus its file handler."""
    logger = logging.getLogger(f"TrainAutoencoder:{log_path}")
    logger.setLevel(logging.INFO)
    # Loggers are process-wide singletons: detach AND close handlers left over
    # from an earlier call so records are not duplicated and files not leaked.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()
    file_handler = logging.FileHandler(log_path, mode="w")
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
    logger.addHandler(file_handler)
    return logger, file_handler


def _save_loss_curves(train_history: list[float], val_history: list[float], out_path: str) -> None:
    """Plot per-epoch train/validation losses and write the figure to ``out_path``."""
    epochs = range(1, len(train_history) + 1)
    fig, ax = plt.subplots()
    try:
        ax.plot(epochs, train_history, label="Train Loss")
        ax.plot(epochs, val_history, label="Validation Loss")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("MSE Loss")
        ax.set_title("Autoencoder Training and Validation Loss")
        ax.legend()
        ax.grid(True, alpha=0.3)
        fig.savefig(out_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def train_autoencoder(
    data_path: str,
    n_traj: int,
    dim: int | None = None,
    bs: int = 32,
    z_dim: int = 16,
    hidden_dim: int = 256,
    lr: float = 1e-3,
    num_epochs: int = 100,
    shuffle_points: bool = True,
    seed: int = 42,
    device: str | torch.device | None = None,
    save_path: str = "autoencoder.pth",
    args_dict: dict | None = None,
):
    """
    Train an autoencoder on trajectory data with optional random point shuffling.

    Each epoch runs one pass over the training loader and one over the
    validation loader, logs both mean MSE losses to
    ``train_autoencoder_Z{z_dim}.log`` and, whenever the validation loss
    improves, writes a checkpoint to ``save_path``. Every 5 epochs the loss
    curves are re-plotted to ``loss_curve_Z{z_dim}.png``. The log and the plot
    live in the directory of ``save_path`` (the current directory if
    ``save_path`` has no directory part or is empty/None).

    Args:
        data_path: Path of the ``.npy`` trajectory file (see ``make_dataloaders``).
        n_traj: Number of trajectories to use.
        dim: Expected data dimension D. The actual value is always taken from
            the data; a mismatch is only logged as a warning.
        bs: Batch size.
        z_dim: Latent dimension passed to ``Autoencoder``.
        hidden_dim: Hidden layer width passed to ``Autoencoder``.
        lr: Adam learning rate.
        num_epochs: Number of training epochs.
        shuffle_points: If True, randomly permute the order of points (axis 1)
            of each training batch; one permutation is shared by the batch.
        seed: Random seed for reproducibility.
        device: Training device; defaults to ``cuda:0`` when CUDA is available,
            otherwise CPU.
        save_path: Checkpoint file path. If None or empty, no checkpoint is
            written, but the best validation loss is still tracked and logged.
        args_dict: Optional extra metadata stored under ``"args"`` in the
            checkpoint.
    """
    # Set seed for reproducibility
    set_seed(seed)

    # Device
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Data (preloaded to the device when it is not the CPU)
    train_loader, val_loader, data_dim, n_anchors, input_dim, normalization_info = make_dataloaders(
        data_path=data_path,
        bs=bs,
        n_traj=n_traj,
        shuffle_points=shuffle_points,
        device=device,
    )

    # os.path.dirname("autoencoder.pth") is "", and os.makedirs("") raises
    # FileNotFoundError, so a bare file name must fall back to ".".
    log_dir = (os.path.dirname(save_path) if save_path else "") or "."
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, f"train_autoencoder_Z{z_dim}.log")
    logger, file_handler = _open_file_logger(log_path)

    try:
        print(f"Logging to {log_path}")
        logger.info("Starting autoencoder training")
        logger.info(f"Shuffle points during training: {shuffle_points}")
        logger.info(f"Input dimension: {input_dim} (n_anchors={n_anchors}, dim={data_dim})")
        logger.info(f"z_dim: {z_dim}")
        if dim is not None and dim != data_dim:
            logger.warning(
                f"Requested dim={dim} differs from data dimension {data_dim}; using {data_dim}"
            )

        # Build model
        # NOTE(review): batches are fed as (B, M, D) while input_dim is M*D;
        # this presumes Autoencoder flattens internally and returns
        # (z, x_recon) with x_recon shaped like its input -- confirm in utils.model.
        model = Autoencoder(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            z_dim=z_dim,
        ).to(device)

        # Count parameters
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"Model parameters: {num_params:,}")

        optimizer = optim.Adam(model.parameters(), lr=lr)
        criterion = nn.MSELoss()

        best_val_loss = float("inf")
        train_loss_history = []
        val_loss_history = []
        plot_path = os.path.join(log_dir, f"loss_curve_Z{z_dim}.png")

        # ------------------------------------------------------------------
        # Training loop
        # ------------------------------------------------------------------
        for epoch in tqdm(range(1, num_epochs + 1)):
            model.train()
            train_loss_sum = 0.0
            train_samples = 0

            for (x_batch,) in train_loader:
                # Move to device only if not already there (data may be preloaded)
                if x_batch.device != device:
                    x_batch = x_batch.to(device, non_blocking=True)

                # Randomly shuffle the order of points (same permutation for
                # all samples in the batch)
                if shuffle_points:
                    perm = torch.randperm(x_batch.shape[1], device=device)
                    x_batch = x_batch[:, perm, :]

                # Forward pass and reconstruction loss
                _, x_recon = model(x_batch)
                loss = criterion(x_recon, x_batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Weight by batch size so the (usually smaller) last batch does
                # not count as much as a full one in the epoch mean.
                n_in_batch = x_batch.shape[0]
                train_loss_sum += loss.item() * n_in_batch
                train_samples += n_in_batch

            mean_train_loss = train_loss_sum / train_samples if train_samples else float("nan")

            # --------------------------------------------------------------
            # Validation
            # --------------------------------------------------------------
            model.eval()
            val_loss_sum = 0.0
            val_samples = 0
            with torch.no_grad():
                for (x_batch_val,) in val_loader:
                    if x_batch_val.device != device:
                        x_batch_val = x_batch_val.to(device, non_blocking=True)

                    # No point shuffling during validation to get consistent metrics
                    _, x_recon_val = model(x_batch_val)
                    val_loss = criterion(x_recon_val, x_batch_val)

                    # Sample-weighted sum: the epoch metric is then independent
                    # of how samples happen to be grouped into batches.
                    n_in_batch = x_batch_val.shape[0]
                    val_loss_sum += val_loss.item() * n_in_batch
                    val_samples += n_in_batch

            mean_val_loss = val_loss_sum / val_samples if val_samples else float("nan")

            msg = (
                f"Epoch {epoch:4d} | "
                f"Train Loss: {mean_train_loss:8.6f} | "
                f"Val Loss: {mean_val_loss:8.6f}"
            )
            train_loss_history.append(mean_train_loss)
            val_loss_history.append(mean_val_loss)
            logger.info(msg)

            # Track the best validation loss even when no checkpoint is
            # written, so the final summary is meaningful without save_path.
            if mean_val_loss < best_val_loss:
                best_val_loss = mean_val_loss
                if save_path:
                    checkpoint = {
                        "model_state_dict": model.state_dict(),
                        "input_dim": input_dim,
                        "n_anchors": n_anchors,
                        "data_dim": data_dim,
                        "z_dim": z_dim,
                        "hidden_dim": hidden_dim,
                        "lr": lr,
                        "bs": bs,
                        "num_epochs": num_epochs,
                        "shuffle_points": shuffle_points,
                        "normalization_info": normalization_info,
                        "n_traj": n_traj,
                        "best_val_loss": best_val_loss,
                    }
                    if args_dict is not None:
                        checkpoint["args"] = args_dict
                    torch.save(checkpoint, save_path)
                    logger.info(f"Saved checkpoint to {save_path}")

            if epoch % 5 == 0:
                # Refresh the training/validation loss curves on disk
                _save_loss_curves(train_loss_history, val_loss_history, plot_path)

        logger.info("Training finished")
        logger.info(f"Best validation loss: {best_val_loss:8.6f}")
    finally:
        # Flush and release the log file whether training finished or raised.
        logger.removeHandler(file_handler)
        file_handler.close()


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Script entry point: read the command line, derive a fresh run directory
    # under --save_dir, and launch training with the parsed settings.
    cli = argparse.ArgumentParser(
        description="Train a baseline MLP autoencoder with optional point shuffling"
    )

    # Data
    cli.add_argument("--data_path", type=str, default="generate_dataset/data/trajectories.npy")
    cli.add_argument("--n_traj", type=int, default=10000, help="number of trajectories to use")
    cli.add_argument("--dim", type=int, default=2, help="data dimension (D)")
    cli.add_argument("--bs", type=int, default=1024, help="batch size")

    # Model size
    cli.add_argument("--z_dim", type=int, default=8, help="latent dimension")
    cli.add_argument("--hidden_dim", type=int, default=256, help="hidden layer dimension")

    # Paired on/off switches sharing one destination; shuffling is on unless
    # --no_shuffle_points is given.
    cli.add_argument(
        "--shuffle_points",
        action="store_true",
        help="randomly shuffle point order during training to test permutation invariance",
    )
    cli.add_argument(
        "--no_shuffle_points",
        dest="shuffle_points",
        action="store_false",
        help="do NOT shuffle points (baseline without permutation invariance test)",
    )
    cli.set_defaults(shuffle_points=True)

    # Optimisation
    cli.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    cli.add_argument("--num_epochs", type=int, default=200, help="number of training epochs")
    cli.add_argument("--seed", type=int, default=42, help="random seed for reproducibility")

    # Runtime / output
    cli.add_argument("--device", type=int, default=0, help="GPU device index (-1 for CPU)")
    cli.add_argument(
        "--save_dir",
        type=str,
        default="trained_autoencoder/",
        help="directory to save checkpoints",
    )

    opts = cli.parse_args()

    # A GPU is used only if CUDA is available and a non-negative index was given.
    wants_gpu = torch.cuda.is_available() and opts.device >= 0
    if wants_gpu:
        target_device = f"cuda:{opts.device}"
    else:
        target_device = "cpu"

    # Each run gets its own "exp<N>" folder, N being one more than the number
    # of entries already present in the root save directory.
    os.makedirs(opts.save_dir, exist_ok=True)
    run_number = len(os.listdir(opts.save_dir)) + 1
    run_dir_parts = [
        opts.save_dir,
        f"exp{run_number}",
        "shuffle" if opts.shuffle_points else "no_shuffle",
        f"z_dim_{opts.z_dim}",
        f"hidden_{opts.hidden_dim}",
        f"traj_{opts.n_traj}",
    ]
    run_dir = os.path.join(*run_dir_parts)
    os.makedirs(run_dir, exist_ok=True)
    checkpoint_path = os.path.join(run_dir, f"best_autoencoder_Z{opts.z_dim}.pth")

    train_autoencoder(
        data_path=opts.data_path,
        n_traj=opts.n_traj,
        dim=opts.dim,
        bs=opts.bs,
        z_dim=opts.z_dim,
        hidden_dim=opts.hidden_dim,
        lr=opts.lr,
        num_epochs=opts.num_epochs,
        shuffle_points=opts.shuffle_points,
        seed=opts.seed,
        device=target_device,
        save_path=checkpoint_path,
        args_dict=vars(opts),
    )
