import os, time, json
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

from wassersteinwormhole_pytorch.wormhole import fast_Wormhole
from wassersteinwormhole_pytorch.transformer import TransformerAutoencoder
from wassersteinwormhole_pytorch.default_config import DefaultConfig
from sw2 import (
    Wasserstein_Distance, Sliced_Wasserstein_Distance, Projected_Wasserstein_Distance,
    Energy_based_Sliced_Wasserstein, Expected_Sliced_Transport,
    Max_Sliced_Wasserstein_Distance, Min_SWGG
)
from utils import generate_uniform_unit_sphere_projections


BATCH_SIZES = [4, 6, 8, 10, 12, 14, 16, 18, 20]

# Column order of the aggregated CSV; these are exactly the keys of each
# time-stats record built by _build_time_stats.
_TIME_STATS_FIELDS = [
    "group", "batch_size", "estimate_alpha_sec", "train_time_sec", "total_time_sec",
    "epochs", "lr", "n_samples", "n_points", "input_dim",
]


def _cuda_timed(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and return ``(result, elapsed_seconds)``.

    CUDA kernels are launched asynchronously, so the device is synchronized
    before starting and before stopping the clock. Without this, the
    measurement can miss GPU work that is still queued when ``fn`` returns.
    """
    torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    torch.cuda.synchronize()
    return result, time.perf_counter() - start


def _make_metric_funcs(config, projection_matrix):
    """Build the ground-truth metric and the candidate metric groups.

    Every closure reads device/dtype from *config* and uses the shared
    *projection_matrix*. Both are passed in explicitly rather than read from
    module-level names that are assigned later.

    The inner function names (``W``, ``SW``, ``PWD``, ...) matter: they are
    read via ``__name__`` and become the metric names stored with the alphas.

    Returns:
        ``(W, groups)``: ``W`` is the ground-truth Wasserstein distance and
        ``groups`` is a list of ``(group_name, [metric_fn, ...])``.
    """
    def W(x, y):
        return Wasserstein_Distance(x, y, numItermax=10000, device=config.device)

    def SW(x, y):
        return Sliced_Wasserstein_Distance(
            x, y, projection_matrix=projection_matrix, device=config.device, dtype=config.dtype
        )

    def PWD(x, y):
        return Projected_Wasserstein_Distance(
            x, y, projection_matrix=projection_matrix, device=config.device, dtype=config.dtype
        )

    def EBSW(x, y):
        return Energy_based_Sliced_Wasserstein(
            x, y, projection_matrix=projection_matrix, device=config.device, dtype=config.dtype
        )

    def EST(x, y):
        return Expected_Sliced_Transport(
            x, y, projection_matrix=projection_matrix, device=config.device, dtype=config.dtype
        )

    def MinSWGG(x, y):
        # Only the first element of Min_SWGG's return value is used as the distance.
        return Min_SWGG(
            x, y, lr=5e-2, num_iter=5, s=1, std=0.5, device=config.device, dtype=config.dtype
        )[0]

    def MaxSW(x, y):
        # Only the first element of the returned value is used as the distance.
        return Max_Sliced_Wasserstein_Distance(
            x, y, require_optimize=True, lr=1e-1, num_iter=5, device=config.device, dtype=config.dtype
        )[0]

    groups = [
        ("sw_pwd",                        [SW, PWD]),
        ("ebsw_est",                      [EBSW, EST]),
        ("maxsw_minswgg",                 [MaxSW, MinSWGG]),
        ("sw_pwd_ebsw_est",               [SW, PWD, EBSW, EST]),
        ("sw_pwd_ebsw_est_maxsw_minswgg", [SW, PWD, EBSW, EST, MaxSW, MinSWGG]),
    ]
    return W, groups


def _save_alphas(directory, metric_names, alphas_np):
    """Write *alphas_np* to ``alphas.npy`` and a name->value ``alphas.json`` in *directory*.

    Names are paired with alphas positionally; this assumes fast_Wormhole
    orders its alphas like ``metric_funcs`` (TODO confirm).
    """
    np.save(os.path.join(directory, "alphas.npy"), alphas_np)
    with open(os.path.join(directory, "alphas.json"), "w") as f:
        json.dump({name: float(val) for name, val in zip(metric_names, alphas_np)}, f, indent=2)


def _train_group(group_name, metric_list, ground_truth, config, X_train,
                 dataloader, alpha_samples, run_dir):
    """Build, calibrate and train one Wormhole model for a metric group.

    Returns:
        ``(fast_wh, metric_names, est_alpha_sec, train_sec)``. The two times
        are GPU-synchronized wall-clock seconds for ``estimate_alpha`` and
        ``train_model`` respectively.
    """
    model_g = TransformerAutoencoder(
        config=config, seq_len=config.n_points, inp_dim=config.input_dim
    ).to(config.device)

    metric_names = [fn.__name__ for fn in metric_list]
    fast_wh = fast_Wormhole(
        estimate_alpha_general=False,
        transformer=model_g,
        config=config,
        metric_funcs=metric_list,
        ground_truth_func=ground_truth,
        run_dir=run_dir,
        metric_names=metric_names,
        compute_stats=False,
        save_best=False,
    )

    optimizer_g = torch.optim.Adam(model_g.parameters(), lr=config.lr)
    # Multiplies the learning rate by 0.75 every `decay_steps` calls to scheduler.step().
    scheduler_g = torch.optim.lr_scheduler.StepLR(optimizer_g, step_size=config.decay_steps, gamma=0.75)

    _, est_alpha_sec = _cuda_timed(fast_wh.estimate_alpha, samples=X_train[:alpha_samples])
    print(f"[{group_name}] estimate_alpha: {est_alpha_sec:.2f} sec")

    _, train_sec = _cuda_timed(
        fast_wh.train_model,
        dataloader=dataloader, optimizer=optimizer_g, scheduler=scheduler_g, epochs=config.epochs,
    )
    print(f"[{group_name}] Finished training in {train_sec/60:.2f} minutes")

    return fast_wh, metric_names, est_alpha_sec, train_sec


def _build_time_stats(group_name, config, est_alpha_sec, train_sec):
    """Return the JSON-serialisable timing record for one (group, batch size) run."""
    return {
        "group": group_name,
        "batch_size": int(config.batch_size),
        "estimate_alpha_sec": round(est_alpha_sec, 3),
        "train_time_sec": round(train_sec, 3),
        "total_time_sec": round(est_alpha_sec + train_sec, 3),
        "epochs": int(config.epochs),
        "lr": float(config.lr),
        "n_samples": int(config.n_samples),
        "n_points": int(config.n_points),
        "input_dim": int(config.input_dim),
    }


def _save_aggregate(prefix_saved_result, results):
    """Write all timing records to ``batchsize_time.json`` and ``batchsize_time.csv``."""
    import csv

    agg_json = os.path.join(prefix_saved_result, "batchsize_time.json")
    with open(agg_json, "w") as f:
        json.dump(results, f, indent=2)

    agg_csv = os.path.join(prefix_saved_result, "batchsize_time.csv")
    with open(agg_csv, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=_TIME_STATS_FIELDS)
        writer.writeheader()
        writer.writerows(results)

    print(f"Saved {agg_json} and {agg_csv}")


for num_samples_training in [10, 50, 100, 200]:
    dataset_path = f"compare_wormhole_pc/train/num_samples_{num_samples_training}"
    # The training set depends only on the sample count, so load it once per count.
    X_train = torch.load(os.path.join(dataset_path, "samples.pt")).cuda()

    for num_epochs in [10, 100, 1000, 2000]:
        prefix_saved_result = f"saved_time_training/num_samples_{num_samples_training}/epochs_{num_epochs}"
        os.makedirs(prefix_saved_result, exist_ok=True)

        projection_matrix = generate_uniform_unit_sphere_projections(
            dim=X_train.shape[2], requires_grad=False, num_projections=100,
            dtype=torch.float32, device="cuda"
        )

        results = []
        for bs in BATCH_SIZES:
            # With drop_last=False, N % bs == 1 leaves a final batch holding a
            # single sample. Presumably training needs at least two per batch,
            # so such batch sizes are skipped.
            if X_train.shape[0] % bs == 1:
                print(f"Skip bs={bs} vì N % b == 1")
                continue

            print(f"\n=== Benchmark với batch_size={bs} ===")
            config = DefaultConfig(
                device=torch.device("cuda"), batch_size=bs, dtype=torch.float32,
                coeff_dec=0.1, emb_dim=128, num_heads=4, num_layers=3, mlp_dim=512,
                attention_dropout_rate=0.1, lr=1e-4, epochs=num_epochs, decay_steps=200
            )
            config.n_samples, config.n_points, config.input_dim = X_train.shape

            W, groups = _make_metric_funcs(config, projection_matrix)
            dataloader = DataLoader(TensorDataset(X_train), batch_size=config.batch_size, shuffle=True, drop_last=False)
            alpha_samples = min(10, X_train.shape[0])

            for group_name, metric_list in groups:
                print(f"\n========== TRAIN GROUP: {group_name} ==========")

                run_dir = os.path.join(prefix_saved_result, group_name)
                os.makedirs(run_dir, exist_ok=True)

                fast_wh, metric_names, est_alpha_sec, train_sec = _train_group(
                    group_name, metric_list, W, config, X_train, dataloader, alpha_samples, run_dir
                )

                time_run_dir = os.path.join(run_dir, f"time_{group_name}_bs{bs}")
                os.makedirs(time_run_dir, exist_ok=True)

                alphas_np = fast_wh.alphas.detach().cpu().numpy()
                # run_dir is shared by all batch sizes, so this copy only keeps
                # the last batch size's alphas (kept for compatibility).
                _save_alphas(run_dir, metric_names, alphas_np)
                # A per-batch-size copy keeps every run's alphas.
                _save_alphas(time_run_dir, metric_names, alphas_np)
                print(f"[{group_name}] Saved alphas to alphas.npy / alphas.json")

                time_stats = _build_time_stats(group_name, config, est_alpha_sec, train_sec)
                time_stats_path = os.path.join(time_run_dir, "time_stats.json")
                with open(time_stats_path, "w") as f:
                    json.dump(time_stats, f, indent=2)
                print(f"=> [{group_name}] Saved time stats to {time_stats_path}")
                results.append(time_stats)

                # Drop the trained model before the next group is built.
                del fast_wh

        # After all batch sizes and groups: write the aggregated JSON and CSV summaries.
        _save_aggregate(prefix_saved_result, results)
