import os, time, json, csv
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

from wassersteinwormhole_pytorch.wormhole import Wormhole
from wassersteinwormhole_pytorch.transformer import TransformerAutoencoder
from wassersteinwormhole_pytorch.default_config import DefaultConfig


# Benchmark the wall-clock training time of the baseline Wormhole model over a grid of
# (training-set size, epoch count, batch size). For every (size, epochs) pair this writes
# one time_stats.json per batch size plus aggregate JSON/CSV summaries under
# saved_time_training/num_samples_<size>/epochs_<epochs>/.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

NUM_SAMPLES_GRID = [10, 50, 100, 200]
NUM_EPOCHS_GRID = [10, 100, 1000, 2000]
BATCH_SIZES = [4, 6, 8, 10, 12, 14, 16, 18, 20]
GROUP_NAME = "wormhole"

# Column order of the aggregate CSV; identical to the keys of every per-run record.
CSV_FIELDS = [
    "group", "batch_size", "estimate_alpha_sec", "train_time_sec", "total_time_sec",
    "epochs", "lr", "n_samples", "n_points", "input_dim", "device",
]


def _cuda_sync():
    """Wait for queued CUDA work to finish; a no-op when CUDA is unavailable.

    CUDA kernels launch asynchronously, so a timestamp taken without syncing
    could miss GPU work that is still in flight.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _time_one_run(X_train, bs, num_epochs, prefix_saved_result):
    """Train a fresh baseline Wormhole with batch size ``bs`` and record its timing.

    Args:
        X_train: training tensor whose shape unpacks into
            (n_samples, n_points, input_dim); expected to already be on ``device``.
        bs: mini-batch size for the DataLoader and the config.
        num_epochs: number of epochs passed to ``Wormhole.train_model``.
        prefix_saved_result: directory under which this run's outputs are saved.

    Returns:
        A dict keyed by ``CSV_FIELDS``. The same dict is also written to
        ``<prefix_saved_result>/wormhole/time_wormhole_bs<bs>/time_stats.json``.
    """
    print(f"\n=== Baseline Wormhole | batch_size={bs} ===")
    config = DefaultConfig(
        device=device, batch_size=bs, dtype=dtype,
        coeff_dec=0.1, emb_dim=128, num_heads=4, num_layers=3, mlp_dim=512,
        attention_dropout_rate=0.1, lr=1e-4, epochs=num_epochs, decay_steps=200
    )
    config.n_samples, config.n_points, config.input_dim = X_train.shape

    dataloader = DataLoader(
        TensorDataset(X_train),
        batch_size=config.batch_size, shuffle=True, drop_last=False
    )

    group_name = GROUP_NAME
    run_dir = os.path.join(prefix_saved_result, group_name)
    os.makedirs(run_dir, exist_ok=True)

    model = TransformerAutoencoder(
        config=config, seq_len=config.n_points, inp_dim=config.input_dim
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.decay_steps, gamma=0.75)

    wormhole = Wormhole(transformer=model, config=config, run_dir=run_dir, compute_stats=False, save_best=False)

    # Only the train_model call is timed; model/optimizer construction is excluded.
    _cuda_sync()
    t0 = time.perf_counter()  # monotonic, high-resolution clock meant for benchmarking
    wormhole.train_model(dataloader=dataloader, optimizer=optimizer, scheduler=scheduler, epochs=config.epochs)
    _cuda_sync()
    train_sec = time.perf_counter() - t0
    total_sec = train_sec  # the baseline times nothing besides training

    print(f"[{group_name}] batch={bs} | train={train_sec:.2f}s")

    time_run_dir = os.path.join(run_dir, f"time_{group_name}_bs{bs}")
    os.makedirs(time_run_dir, exist_ok=True)

    time_stats = {
        "group": group_name,
        "batch_size": int(config.batch_size),
        # Always 0.0 here; presumably kept for schema parity with runs that
        # estimate alpha — confirm against the comparison scripts.
        "estimate_alpha_sec": 0.0,
        "train_time_sec": round(train_sec, 3),
        "total_time_sec": round(total_sec, 3),
        "epochs": int(config.epochs),
        "lr": float(config.lr),
        "n_samples": int(config.n_samples),
        "n_points": int(config.n_points),
        "input_dim": int(config.input_dim),
        "device": str(device),
    }
    stats_path = os.path.join(time_run_dir, "time_stats.json")
    with open(stats_path, "w") as f:
        json.dump(time_stats, f, indent=2)
    print(f"=> [{group_name}] Saved time stats to {stats_path}")

    # Release this run's references first: empty_cache() can only return blocks
    # that no live tensor still occupies.
    del wormhole, scheduler, optimizer, model, dataloader
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return time_stats


def _save_aggregate(results, prefix_saved_result):
    """Write every per-batch-size record of one (size, epochs) cell as JSON and CSV."""
    agg_json = os.path.join(prefix_saved_result, "batchsize_time_baseline.json")
    with open(agg_json, "w") as f:
        json.dump(results, f, indent=2)

    agg_csv = os.path.join(prefix_saved_result, "batchsize_time_baseline.csv")
    with open(agg_csv, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=CSV_FIELDS)
        writer.writeheader()
        writer.writerows(results)

    print(f"=> Saved {agg_json} and {agg_csv}")


for num_samples_training in NUM_SAMPLES_GRID:
    dataset_path = f"compare_wormhole_pc/train/num_samples_{num_samples_training}"
    # The dataset depends only on its size, so load it once per size. map_location
    # puts the tensor on `device`, so this works on CPU-only hosts as well as GPUs.
    X_train = torch.load(os.path.join(dataset_path, "samples.pt"), map_location=device)

    for num_epochs in NUM_EPOCHS_GRID:
        prefix_saved_result = f"saved_time_training/num_samples_{num_samples_training}/epochs_{num_epochs}"
        os.makedirs(prefix_saved_result, exist_ok=True)

        results = []
        for bs in BATCH_SIZES:
            results.append(_time_one_run(X_train, bs, num_epochs, prefix_saved_result))

        _save_aggregate(results, prefix_saved_result)
