import os
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import time
import math
from sw2 import Wasserstein_Distance
from wassersteinwormhole_pytorch.transformer import TransformerAutoencoder
from wassersteinwormhole_pytorch.wormhole import Wormhole
from wassersteinwormhole_pytorch.default_config import DefaultConfig

def train_wormhole(
    num_samples,
    config,
    save_prefix="wormhole_run",
    dataset_root="compare_wormhole",
    *,
    target_steps=10000,
    min_epochs=2000,
    final_lr_factor=0.1,
    save_every=50,
):
    """Train a Wormhole transformer autoencoder on a saved point-cloud tensor.

    The training tensor is read from
    ``<dataset_root>/train/num_samples_<N>/samples.pt`` and every artifact of
    the run is written under ``<save_prefix>/num_train_<N>``, where ``<N>`` is
    ``"all"`` or the integer form of ``num_samples``.

    Note that ``config`` is mutated in place: ``n_points``, ``n_samples``,
    ``input_dim`` (taken from the tensor shape) and ``epochs`` (derived from
    ``target_steps`` / ``min_epochs``) are overwritten on every call.

    Args:
        num_samples: Training-set size used to locate the dataset, or the
            string ``"all"``.
        config: Configuration object; this function reads ``device``,
            ``dtype``, ``batch_size`` and ``lr`` and passes the object on to
            ``TransformerAutoencoder`` and ``Wormhole``.
        save_prefix: Root directory for run outputs.
        dataset_root: Root directory of the preprocessed datasets.
        target_steps: Number of optimizer steps the epoch count aims for.
        min_epochs: Lower bound on the number of epochs, applied even when
            ``target_steps`` would be reached earlier.
        final_lr_factor: Ratio between the final and the initial learning
            rate after ``config.epochs`` scheduler steps.
        save_every: Forwarded to ``Wormhole.train_model`` as ``save_every``.

    Returns:
        Path of the final checkpoint file.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
        ValueError: If the loaded object is not a 3-D tensor or holds no
            samples.
    """
    num_samples_str = "all" if num_samples == "all" else str(int(num_samples))
    run_dir = os.path.join(save_prefix, f"num_train_{num_samples_str}")

    dataset_path = os.path.join(dataset_root, "train", f"num_samples_{num_samples_str}", "samples.pt")
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(f"Dataset not found: {dataset_path}")

    # map_location lets a tensor that was saved on another device (e.g. a GPU
    # that is absent on this host) be restored directly onto the target device.
    samples = torch.load(dataset_path, map_location=config.device)
    if not torch.is_tensor(samples) or samples.dim() != 3:
        found = tuple(samples.shape) if torch.is_tensor(samples) else type(samples).__name__
        raise ValueError(
            f"Expected a 3-D tensor (n_samples, n_points, input_dim) in {dataset_path}, got {found}"
        )
    if samples.shape[0] == 0:
        # Without this guard steps_per_epoch would be 0 and the epoch
        # computation below would fail with an opaque ZeroDivisionError.
        raise ValueError(f"Dataset is empty: {dataset_path}")
    samples = samples.to(device=config.device, dtype=config.dtype)

    # Create the output directory only once the dataset is known to be usable,
    # so failed runs do not leave empty directories behind.
    os.makedirs(run_dir, exist_ok=True)

    config.n_points = samples.shape[1]
    config.n_samples = samples.shape[0]
    config.input_dim = samples.shape[2]
    print(f"=> Loaded {samples.shape[0]} point clouds with {config.n_points} points each from {dataset_path}")
    dataset = TensorDataset(samples)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

    # Enough epochs to reach target_steps optimizer steps, never fewer than
    # min_epochs.
    steps_per_epoch = math.ceil(len(dataset) / config.batch_size)
    config.epochs = max(min_epochs, math.ceil(target_steps / steps_per_epoch))
    print(f"=> steps/epoch ~ {steps_per_epoch}, target steps={target_steps}, epochs set to {config.epochs}")

    model = TransformerAutoencoder(config=config, seq_len=config.n_points, inp_dim=config.input_dim).to(config.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    # gamma ** config.epochs == final_lr_factor.
    # NOTE(review): assumes Wormhole.train_model steps the scheduler once per
    # epoch -- confirm against its implementation.
    gamma = final_lr_factor ** (1.0 / max(1, config.epochs))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    wormhole = Wormhole(transformer=model, config=config, run_dir=run_dir)

    print(f"--- Start training with num_samples = {num_samples_str} ---")
    start_time = time.time()
    wormhole.train_model(
        dataloader=dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
        epochs=config.epochs,
        save_every=save_every,
        verbose=True,
    )
    elapsed = time.time() - start_time
    print(f"--- Finished training in {elapsed/60:.2f} minutes ---")

    ckpt_name = f"num{num_samples_str}_lr{config.lr}_epoch{config.epochs}.pth"
    ckpt_path = os.path.join(run_dir, ckpt_name)
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Stored next to the optimizer state so a resumed run can continue
        # the learning-rate decay instead of restarting it.
        'scheduler_state_dict': scheduler.state_dict(),
        'epoch': config.epochs,
        'train_time_sec': elapsed,
    }, ckpt_path)
    print(f"=> Saved checkpoint at {ckpt_path}")
    return ckpt_path

if __name__ == "__main__":
    # Run on the GPU when CUDA is available, otherwise on the CPU.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # epochs=0 is only a placeholder here: train_wormhole overwrites
    # config.epochs before training starts.
    config = DefaultConfig(
        device=run_device,
        batch_size=15,
        dtype=torch.float32,
        coeff_dec=0.1,
        emb_dim=128,
        num_heads=4,
        num_layers=3,
        mlp_dim=512,
        attention_dropout_rate=0.1,
        lr=1e-4,
        epochs=0,
        decay_steps=0,
    )
    config.summary()

    # One training run per training-set size; the same config object is
    # passed to every run.
    output_prefix = "saved_compare_wormhole/pointcloud/wormhole"
    data_root = "preprocessed_dataset/point_cloud"
    for sample_count in (10, 50, 100, 200):
        print(f"\n===== RUNNING EXPERIMENT WITH num_samples = {sample_count} =====")
        train_wormhole(
            num_samples=sample_count,
            config=config,
            save_prefix=output_prefix,
            dataset_root=data_root,
        )
