# experiment_metrics_dim_first.py
import os, csv
import numpy as np
import torch

from sw import (
    Wasserstein_Distance,
    Sliced_Wasserstein_Distance,
    Projected_Wasserstein_Distance,
    Energy_based_Sliced_Wasserstein,
    Max_Sliced_Wasserstein_Distance,
    Min_SWGG,
    Expected_Sliced_Transport,
)
from utils import generate_uniform_unit_sphere_projections

# Device every tensor in this script is created on. Prefer the GPU, but fall
# back to the CPU so the experiment still runs on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Floating-point dtype used for samples, projections and saved metric tensors.
DTYPE  = torch.float32
# Number of projection directions drawn per dimension.
NUM_PROJ = 200
# Root directory for the per-run CSV files and the per-dimension artefacts.
OUT_ROOT = "saved_simulations_mog"

def random_gaussian_mixture(n_components, n_samples, dim,
                            mean_range=(-5.,5.), std_range=(0.2,1.5),
                            device="cpu", dtype=torch.float32):
    """Draw a labelled sample from a randomly parameterised Gaussian mixture.

    Each component receives a mean vector and a per-coordinate standard
    deviation, both drawn uniformly from ``mean_range`` and ``std_range``
    respectively (so every component has a diagonal covariance). Every
    component contributes exactly ``n_samples`` points.

    Args:
        n_components: Number of mixture components.
        n_samples: Number of points drawn from *each* component.
        dim: Dimension of the sampled points.
        mean_range: Bounds passed to ``uniform_`` for the component means.
        std_range: Bounds passed to ``uniform_`` for the component stds.
        device: Device on which all tensors are created.
        dtype: Floating-point dtype of the sampled points.

    Returns:
        A pair ``(points, labels)``: ``points`` has
        ``n_components * n_samples`` rows and ``dim`` columns, grouped by
        component; ``labels`` is a ``torch.long`` vector holding the index of
        the component each row was drawn from.
    """
    # Draw order (means, then scales, then the noise of each component in
    # turn) is what makes a seeded call reproducible, so it must not change.
    centers = torch.empty(n_components, dim, device=device, dtype=dtype)
    centers.uniform_(*mean_range)
    scales = torch.empty(n_components, dim, device=device, dtype=dtype)
    scales.uniform_(*std_range)

    points = []
    labels = []
    for idx in range(n_components):
        noise = torch.randn(n_samples, dim, device=device, dtype=dtype)
        points.append(centers[idx] + noise * scales[idx])
        labels.append(
            torch.full((n_samples,), idx, device=device, dtype=torch.long)
        )
    return torch.cat(points, 0), torch.cat(labels, 0)


# Metric names in the order they appear in the CSV. Each name is also the stem
# of the per-dimension "<name>.pt" file, so header, rows and filenames cannot
# drift apart.
_METRIC_COLUMNS = ("W", "sw", "pwd", "ebsw", "est", "max_sw", "min_swgg")
_CSV_HEADER = ["dim", *_METRIC_COLUMNS, "seed",
               "mean1_lo", "mean1_hi", "mean2_lo", "mean2_hi"]


def _compute_metrics(X1, X2, P):
    """Evaluate every distance between X1 and X2; return ``{name: float}``.

    ``P`` is the shared projection matrix handed to the projection-based
    metrics. The calls are made in a fixed order (Min_SWGG before
    Max_Sliced_Wasserstein_Distance).
    NOTE(review): the optimisation-based metrics presumably consume random
    numbers, in which case this order matters for seeded runs — confirm.
    """
    metrics = {}
    metrics["W"] = float(Wasserstein_Distance(X1, X2, numItermax=10000, device=DEVICE))
    metrics["sw"] = float(Sliced_Wasserstein_Distance(
        X1, X2, projection_matrix=P, device=DEVICE, dtype=DTYPE))
    metrics["pwd"] = float(Projected_Wasserstein_Distance(
        X1, X2, projection_matrix=P, device=DEVICE, dtype=DTYPE))
    metrics["ebsw"] = float(Energy_based_Sliced_Wasserstein(
        X1, X2, projection_matrix=P, device=DEVICE, dtype=DTYPE))
    metrics["est"] = float(Expected_Sliced_Transport(
        X1, X2, projection_matrix=P, device=DEVICE, dtype=DTYPE))
    # Both optimisation-based metrics return a sequence; element 0 is the value.
    metrics["min_swgg"] = float(Min_SWGG(
        X1, X2, lr=5e-2, num_iter=20, s=20, std=0.5,
        device=DEVICE, dtype=DTYPE)[0])
    metrics["max_sw"] = float(Max_Sliced_Wasserstein_Distance(
        X1, X2, require_optimize=True, lr=1e-1, num_iter=20,
        device=DEVICE, dtype=DTYPE)[0])
    return metrics


def _save_metrics(run_dir, metrics):
    """Write each metric as a scalar tensor to ``<run_dir>/<name>.pt``."""
    for name in _METRIC_COLUMNS:
        torch.save(torch.tensor(metrics[name], dtype=DTYPE),
                   os.path.join(run_dir, f"{name}.pt"))


def run_one_experiment(run_id,
                       dim_start=2, dim_end=200,
                       mix1_components=3, mix1_examples=500, std_mix1=(0.5,1.0),
                       mix2_components=3, mix2_examples=500, std_mix2=(0.5,1.0),
                       mean_mix1=(0.0,2.0), mean_mix2=(4.0,6.0),
                       base_seed=20250819):
    """Run one sweep over dimensions and record every metric per dimension.

    For each ``dim`` in ``dim_start..dim_end`` (inclusive) the function seeds
    the RNGs, draws a projection matrix and two Gaussian-mixture samples,
    evaluates all metrics and stores the results.

    Outputs, all below ``OUT_ROOT``:
        * ``run{run_id}_metrics_by_dim.csv`` - one row per dimension. The file
          is truncated at the start of the call and appended to after each
          dimension, so completed rows are already on disk if a later
          dimension fails.
        * ``dim_XXX/run{run_id}/`` - ``projection.pt``, ``data.pt`` and one
          ``<metric>.pt`` file per metric.

    Args:
        run_id: Integer run index; used in file names and in the seed.
        dim_start: First dimension of the sweep.
        dim_end: Last dimension of the sweep (inclusive).
        mix1_components, mix2_components: Number of components per mixture.
        mix1_examples, mix2_examples: Samples drawn from *each* component.
        std_mix1, std_mix2: Uniform ranges for the component stds.
        mean_mix1, mean_mix2: Uniform ranges for the component means; both
            bounds are also written to the CSV.
        base_seed: Seed offset; the seed used for a dimension is
            ``base_seed + run_id*10000 + dim``.
    """
    os.makedirs(OUT_ROOT, exist_ok=True)
    run_csv = os.path.join(OUT_ROOT, f"run{run_id}_metrics_by_dim.csv")
    # Mode "w" deliberately starts a fresh file for this run.
    with open(run_csv, "w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow(_CSV_HEADER)

    for dim in range(dim_start, dim_end+1):
        # One distinct seed per (run, dim) pair.
        seed_here = base_seed + run_id*10000 + dim
        torch.manual_seed(seed_here)
        np.random.seed(seed_here)

        run_dir = os.path.join(OUT_ROOT, f"dim_{dim:03d}", f"run{run_id}")
        os.makedirs(run_dir, exist_ok=True)

        # The same projection matrix is shared by all projection-based metrics.
        P = generate_uniform_unit_sphere_projections(
            dim=dim, requires_grad=False, num_projections=NUM_PROJ,
            dtype=DTYPE, device=DEVICE
        )
        torch.save({"projection_matrix": P}, os.path.join(run_dir, "projection.pt"))

        X1, _ = random_gaussian_mixture(mix1_components, mix1_examples, dim,
                                        mean_range=mean_mix1, std_range=std_mix1,
                                        device=DEVICE, dtype=DTYPE)
        X2, _ = random_gaussian_mixture(mix2_components, mix2_examples, dim,
                                        mean_range=mean_mix2, std_range=std_mix2,
                                        device=DEVICE, dtype=DTYPE)
        torch.save({"dim":dim, "X1":X1, "X2":X2, "seed":seed_here,
                    "mean_mix1":mean_mix1, "mean_mix2":mean_mix2},
                   os.path.join(run_dir, "data.pt"))

        # Compute all metrics, then store each one as its own .pt file.
        metrics = _compute_metrics(X1, X2, P)
        _save_metrics(run_dir, metrics)

        # Append this dimension's row to the run's CSV.
        with open(run_csv, "a", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(
                [dim, *(metrics[name] for name in _METRIC_COLUMNS), seed_here,
                 mean_mix1[0], mean_mix1[1], mean_mix2[0], mean_mix2[1]]
            )

        Wv, SWv, PWDv = metrics["W"], metrics["sw"], metrics["pwd"]
        EBSWv, ESTv = metrics["ebsw"], metrics["est"]
        Maxv, Minv = metrics["max_sw"], metrics["min_swgg"]
        print(f"[run{run_id}] dim={dim:3d} | W={Wv:.4f} | SW={SWv:.4f} | PWD={PWDv:.4f} | "
              f"EBSW={EBSWv:.4f} | EST={ESTv:.4f} | Max={Maxv:.4f} | Min={Minv:.4f}")

if __name__ == "__main__":
    os.makedirs(OUT_ROOT, exist_ok=True)

    # Settings shared by all runs; only run_id (and hence the seed derived
    # from it) differs between runs.
    sweep_settings = dict(
        dim_start=1, dim_end=100,
        mix1_components=3, mix1_examples=200, std_mix1=(0.5, 1),
        mix2_components=3, mix2_examples=200, std_mix2=(0.5, 1),
        mean_mix1=(-2, -1), mean_mix2=(1, 2),
        base_seed=28032003,
    )

    # Ten independent runs, numbered 1..10.
    for run_id in range(1, 11):
        run_one_experiment(run_id=run_id, **sweep_settings)
