from wassersteinwormhole_pytorch.wormhole import fast_Wormhole
from wassersteinwormhole_pytorch.transformer import TransformerAutoencoder
from wassersteinwormhole_pytorch.default_config import DefaultConfig
import os
import torch
from torch.utils.data import TensorDataset, DataLoader
from sw2 import (
    Wasserstein_Distance,
    Sliced_Wasserstein_Distance,
    Projected_Wasserstein_Distance,
    Energy_based_Sliced_Wasserstein,
    Max_Sliced_Wasserstein_Distance,
    Min_SWGG,
    Expected_Sliced_Transport,
)
from utils import generate_uniform_unit_sphere_projections
import time
import json
import numpy as np
import math

# Alpha-estimation mode. The flag is forwarded to fast_Wormhole and also picks
# the output sub-directory ("optimal_alpha_general" vs "optimal_alpha_simplex").
estimate_alpha_general = True
num_samples_training = 200
dataset_path = f"preprocessed_dataset/pointcloud/train/num_samples_{num_samples_training}"

# Root directory for every group's run artifacts; created up front.
prefix_run_dir = (
    f"saved_embeddings/rg_wormhole/optimal_alpha_"
    f"{'general' if estimate_alpha_general else 'simplex'}"
    f"/num_{num_samples_training}"
)
os.makedirs(prefix_run_dir, exist_ok=True)

# Model / optimisation hyper-parameters. `epochs` starts at 0 and is
# recomputed below once the dataset size is known.
config = DefaultConfig(
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    batch_size=10,
    dtype=torch.float32,
    coeff_dec=0.1,
    emb_dim=128,
    num_heads=4,
    num_layers=3,
    mlp_dim=512,
    attention_dropout_rate=0.0,
    lr=1e-4,
    epochs=0,
    decay_steps=0,
)

# Load the pre-processed point clouds onto the training device, then cast.
# NOTE(review): assumes samples.pt holds a single tensor shaped
# (n_clouds, n_points, point_dim); the three shape reads below rely on it.
X_train = (
    torch.load(os.path.join(dataset_path, "samples.pt"))
    .to(config.device)
    .to(config.dtype)
)
print(f"=> Loaded {X_train.shape[0]} point clouds with {X_train.shape[1]} points each from {dataset_path}")

# Record data dimensions on the config for the model / Wormhole to read.
config.n_samples, config.n_points, config.input_dim = (
    X_train.shape[0],
    X_train.shape[1],
    X_train.shape[2],
)

dataset = TensorDataset(X_train)
dataloader = DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    drop_last=False,
)

# Epoch budget: enough epochs to reach TARGET_STEPS optimizer steps, but never
# fewer than 2000. With drop_last=False, batches per epoch = ceil(N / batch).
# NOTE(review): the 2000-epoch floor dominates for small datasets (e.g. 200
# clouds with batch_size=10 -> 20 steps/epoch -> 2000 epochs = 40000 steps,
# not TARGET_STEPS); confirm the floor is intended.
TARGET_STEPS = 10000
steps_per_epoch = math.ceil(len(dataset) / config.batch_size)
config.epochs = max(
    2000,
    math.ceil(TARGET_STEPS / steps_per_epoch),
)
print(f"=> steps/epoch ~ {steps_per_epoch}, target steps={TARGET_STEPS}, epochs set to {config.epochs}")

# Fixed set of 100 projection directions, drawn once and shared by the
# SW / PWD / EBSW / EST wrappers.
projection_matrix = generate_uniform_unit_sphere_projections(
    dim=config.input_dim,
    requires_grad=False,
    num_projections=100,
    dtype=config.dtype,
    device=config.device,
)

def W(x, y):
    """Return the Wasserstein distance between ``x`` and ``y`` via sw2.

    Thin wrapper over ``sw2.Wasserstein_Distance`` with ``numItermax=10000``
    (presumably the solver's iteration cap; see sw2). The training loop uses it
    as the ground-truth distance. Its ``__name__`` is used as a metric label,
    so do not rename it.
    """
    distance = Wasserstein_Distance(
        x,
        y,
        numItermax=10000,
        device=config.device,
    )
    return distance

def SW(x, y):
    """Return the sliced Wasserstein distance between ``x`` and ``y``.

    Delegates to ``sw2.Sliced_Wasserstein_Distance`` with the shared
    module-level ``projection_matrix``. The function name is recorded as the
    metric label.
    """
    distance = Sliced_Wasserstein_Distance(
        x,
        y,
        projection_matrix=projection_matrix,
        device=config.device,
        dtype=config.dtype,
    )
    return distance

def PWD(x, y):
    """Return the projected Wasserstein distance between ``x`` and ``y``.

    Delegates to ``sw2.Projected_Wasserstein_Distance`` with the shared
    module-level ``projection_matrix``. The function name is recorded as the
    metric label.
    """
    distance = Projected_Wasserstein_Distance(
        x,
        y,
        projection_matrix=projection_matrix,
        device=config.device,
        dtype=config.dtype,
    )
    return distance

def EBSW(x, y):
    """Return the energy-based sliced Wasserstein distance between ``x`` and ``y``.

    Delegates to ``sw2.Energy_based_Sliced_Wasserstein`` with the shared
    module-level ``projection_matrix``. The function name is recorded as the
    metric label.
    """
    distance = Energy_based_Sliced_Wasserstein(
        x,
        y,
        projection_matrix=projection_matrix,
        device=config.device,
        dtype=config.dtype,
    )
    return distance

def EST(x, y):
    """Return the expected sliced transport distance between ``x`` and ``y``.

    Delegates to ``sw2.Expected_Sliced_Transport`` with the shared
    module-level ``projection_matrix``. The function name is recorded as the
    metric label.
    """
    distance = Expected_Sliced_Transport(
        x,
        y,
        projection_matrix=projection_matrix,
        device=config.device,
        dtype=config.dtype,
    )
    return distance

def MinSWGG(x, y):
    """Return the first element of ``sw2.Min_SWGG(x, y, ...)``.

    Uses lr=5e-2, num_iter=5, s=10, std=0.5. Their meaning is defined by sw2.
    The first element is presumably the distance value; confirm against sw2.
    Does not use ``projection_matrix``.
    """
    result = Min_SWGG(
        x,
        y,
        lr=5e-2,
        num_iter=5,
        s=10,
        std=0.5,
        device=config.device,
        dtype=config.dtype,
    )
    return result[0]

def MaxSW(x, y):
    """Return the first element of ``sw2.Max_Sliced_Wasserstein_Distance(x, y, ...)``.

    Runs with ``require_optimize=True`` for num_iter=5 at lr=1e-2. Presumably
    this optimises the projection direction; confirm in sw2. The first element
    is presumably the distance. Does not use ``projection_matrix``.
    """
    result = Max_Sliced_Wasserstein_Distance(
        x,
        y,
        require_optimize=True,
        lr=1e-2,
        num_iter=5,
        device=config.device,
        dtype=config.dtype,
    )
    return result[0]

# Alias for the ground-truth distance W.
w_fn = W

# (run name, metric functions) pairs. One model is trained per entry, and each
# function's __name__ becomes its label in the saved alphas / checkpoint.
GROUPS = [
    ("sw_pwd", [SW, PWD]),
    ("ebsw_est", [EBSW, EST]),
    ("maxsw_minswgg", [MaxSW, MinSWGG]),
    ("sw_pwd_ebsw_est", [SW, PWD, EBSW, EST]),
    (
        "sw_pwd_ebsw_est_maxsw_minswgg",
        [SW, PWD, EBSW, EST, MaxSW, MinSWGG],
    ),
]

# Alpha estimation uses the first ALPHA_SAMPLES clouds (capped at 10).
ALPHA_SAMPLES = min(X_train.shape[0], 10)

def _synced_clock():
    """Return ``time.perf_counter()`` after waiting for queued CUDA work.

    CUDA kernels run asynchronously. A plain clock read can therefore stop a
    timer before the GPU work it is meant to measure has finished. On CPU this
    is just a monotonic clock read.
    """
    if torch.device(config.device).type == "cuda":
        torch.cuda.synchronize(config.device)
    return time.perf_counter()


# Train one Wormhole model per metric group. For each group: estimate the
# metric-mixing alphas, train, then save the alphas, a checkpoint and timing
# statistics under <prefix_run_dir>/<group_name>.
for group_name, metric_list in GROUPS:
    print(f"\n========== TRAIN GROUP: {group_name} ==========")

    # Release the previous group's objects *before* building new ones.
    # `model_g` is only rebound after the new constructor returns, so without
    # this the old model, its Adam state and the old Wormhole wrapper would
    # still be alive while the new model is allocated. Rebinding instead of
    # `del` keeps the last group's objects reachable after the loop.
    fast_wh = scheduler_g = optimizer_g = model_g = None

    model_g = TransformerAutoencoder(
        config=config, seq_len=config.n_points, inp_dim=config.input_dim
    ).to(config.device)

    run_dir = os.path.join(prefix_run_dir, group_name)
    os.makedirs(run_dir, exist_ok=True)

    # The wrapper functions' names double as metric labels in saved artifacts.
    metric_names = [fn.__name__ for fn in metric_list]

    fast_wh = fast_Wormhole(
        estimate_alpha_general=estimate_alpha_general,
        transformer=model_g,
        config=config,
        metric_funcs=metric_list,
        ground_truth_func=w_fn,
        run_dir=run_dir,
        metric_names=metric_names,
        compute_stats=True,
        save_best=True,
    )

    optimizer_g = torch.optim.Adam(model_g.parameters(), lr=config.lr)
    # Decay factor chosen so that config.epochs scheduler steps shrink the LR
    # by 10x. This assumes train_model steps the scheduler once per epoch;
    # TODO confirm.
    gamma = 0.1 ** (1.0 / config.epochs)
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=gamma)

    t0 = _synced_clock()
    fast_wh.estimate_alpha(samples=X_train[:ALPHA_SAMPLES])
    est_alpha_sec = _synced_clock() - t0
    print(f"[{group_name}] estimate_alpha: {est_alpha_sec:.2f} sec")

    t1 = _synced_clock()
    fast_wh.train_model(
        dataloader=dataloader,
        optimizer=optimizer_g,
        scheduler=scheduler_g,
        epochs=config.epochs,
        save_every=50,
        verbose=True,
    )
    train_sec = _synced_clock() - t1
    total_sec = est_alpha_sec + train_sec
    print(f"[{group_name}] Finished training in {train_sec/60:.2f} minutes")

    # Persist the learned alphas, labelled by metric name. This assumes
    # fast_wh.alphas has one entry per metric, in metric_list order.
    alphas_np = fast_wh.alphas.detach().cpu().numpy()
    alpha_by_name = {name: float(val) for name, val in zip(metric_names, alphas_np)}
    np.save(os.path.join(run_dir, "alphas.npy"), alphas_np)
    with open(os.path.join(run_dir, "alphas.json"), "w", encoding="utf-8") as f:
        json.dump(alpha_by_name, f, indent=2)
    print(f"=> [{group_name}] Saved alphas to alphas.npy / alphas.json")

    ckpt_name = f"{group_name}_lr{config.lr}_epoch{config.epochs}.pth"
    ckpt_path = os.path.join(run_dir, ckpt_name)
    torch.save({
        'model_state_dict': model_g.state_dict(),
        'optimizer_state_dict': optimizer_g.state_dict(),
        'epoch': config.epochs,
        'estimate_alpha_sec': est_alpha_sec,
        'train_time_sec': train_sec,
        'total_time_sec': total_sec,
        'alphas': alphas_np.tolist(),
        # NOTE(review): a NumPy array is not loadable with torch.load's
        # weights_only=True default (PyTorch >= 2.6). Readers need
        # weights_only=False, or should use the plain-list 'alphas' entry.
        'optimal_alpha': alphas_np,
        'metric_names': metric_names,
    }, ckpt_path)
    print(f"=> [{group_name}] Saved checkpoint at {ckpt_path}")

    time_stats = {
        "group": group_name,
        "estimate_alpha_sec": round(est_alpha_sec, 3),
        "train_time_sec": round(train_sec, 3),
        "total_time_sec": round(total_sec, 3),
        "estimate_alpha_min": round(est_alpha_sec / 60.0, 3),
        "train_time_min": round(train_sec / 60.0, 3),
        "total_time_min": round(total_sec / 60.0, 3),
        "epochs": config.epochs,
        "lr": float(config.lr),
        "batch_size": int(config.batch_size),
        "alphas": alpha_by_name,
    }
    with open(os.path.join(run_dir, "time_stats.json"), "w", encoding="utf-8") as f:
        json.dump(time_stats, f, indent=2)
    print(f"=> [{group_name}] Saved time stats to time_stats.json")

    # Return cached-but-free allocator blocks (e.g. activations) to the driver.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

