import os, time
import numpy as np
import pandas as pd
import torch

from sw import (
    Wasserstein_Distance,
    Sliced_Wasserstein_Distance,
    Projected_Wasserstein_Distance,
    Energy_based_Sliced_Wasserstein,
    Max_Sliced_Wasserstein_Distance,
    Min_SWGG,
    Expected_Sliced_Transport,
)
from utils import (
    generate_uniform_unit_sphere_projections,
    optimal_alpha,
    optimal_alpha_simplex,
    optimal_alpha_general,
)


# Dataset selection: point dataset_root and result_root at the dataset you want to
# experiment on. Dataset names: [pcmnist, pointcloud, merfish, scrna].
dataset_root = "preprocessed_dataset/point_cloud"
result_root = "saved_compare_wormhole/pointcloud/rg"

# NUM_PROJ (number of projections) and DIMS (projection dimension) are
# dataset-specific; change them together with the two paths above.
NUM_PROJ, DIMS = 100, 3

# Ridge strength forwarded to optimal_alpha_general.
RIDGE = 0.0
# Sample sizes to sweep; each selects the folder train/num_samples_<k>.
num_samples_list = [10, 50, 100, 200]

DTYPE = torch.float32
# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Make sure the output root exists before any per-run results are written.
os.makedirs(result_root, exist_ok=True)

def _calib_metrics(y_true: np.ndarray, y_pred: np.ndarray):
    """Score predictions against reference values.

    Both inputs are flattened to 1-D float arrays before any arithmetic.
    Without this, a ``(n,)`` target and a ``(n, 1)`` prediction would be
    broadcast by NumPy into an ``(n, n)`` residual matrix and every metric
    would be silently wrong.

    Args:
        y_true: Reference values; any shape, flattened internally.
        y_pred: Predicted values with the same number of elements as
            ``y_true``. A single-element prediction is treated as a constant
            prediction for every reference value.

    Returns:
        Tuple ``(r2, mse, mae)`` of Python floats. ``r2`` is NaN when
        ``y_true`` has zero variance, because the coefficient of
        determination is undefined in that case.

    Raises:
        ValueError: If the two inputs hold a different number of elements.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()

    # Keep supporting a single constant prediction (the only broadcast that
    # is unambiguous); everything else must match element for element.
    if y_pred.size == 1 and y_true.size != 1:
        y_pred = np.full(y_true.shape, y_pred[0])
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same number of elements, "
            f"got {y_true.size} and {y_pred.size}"
        )

    resid = y_true - y_pred
    sq_resid = resid * resid
    mse = float(np.mean(sq_resid))
    mae = float(np.mean(np.abs(resid)))
    ss_res = float(np.sum(sq_resid))
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    return r2, mse, mae

def _record_fit(name, kind, fit, predict, target, timing_fit, alpha_results, calib_rows):
    """Run one calibration fit and record its timing, parameters and quality.

    Args:
        name: Key under which the results are stored; also the "method" value.
        kind: Value written to the "type" column ("pair" or "multi").
        fit: Zero-argument callable returning the fitted parameter(s).
        predict: Callable mapping the fitted parameter(s) to predictions.
        target: Reference values the predictions are scored against.
        timing_fit: Dict updated in place with the fit time in seconds.
        alpha_results: Dict updated in place with the fitted parameter(s).
        calib_rows: List extended in place with one calibration-quality row.
    """
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # short intervals; time.time() can jump and may be too coarse for fits
    # that finish in microseconds.
    t0 = time.perf_counter()
    params = fit()
    elapsed = time.perf_counter() - t0

    timing_fit[name] = elapsed
    alpha_results[name] = params
    r2, mse, mae = _calib_metrics(target, predict(params))
    calib_rows.append({
        "method": name,
        "type": kind,
        "r2_cal": r2,
        "mse_cal": mse,
        "mae_cal": mae,
        "fit_time_sec": elapsed,
    })


# For every sample size: compute all distances on every training pair, fit
# weights that combine the sliced variants to match the Wasserstein distance,
# and write timings, weights and fit quality to result_root/num<k>.
for num_samples in num_samples_list:
    train_dir = os.path.join(dataset_root, "train", f"num_samples_{num_samples}")
    X_train = torch.load(os.path.join(train_dir, "samples.pt"), map_location="cpu").to(DEVICE)
    n = X_train.shape[0]
    # Every unordered pair (i < j) of training samples.
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]

    # One projection set per sample size, shared by every metric that takes one.
    proj = generate_uniform_unit_sphere_projections(
        dim=DIMS, requires_grad=False, num_projections=NUM_PROJ, dtype=DTYPE, device=DEVICE
    )
    proj_kwargs = {"projection_matrix": proj, "device": DEVICE, "dtype": DTYPE}

    vals = {k: [] for k in ["ws", "sw", "pwd", "ebsw", "est", "minswgg", "maxsw"]}
    # Accumulated time per metric group, summed over all pairs.
    timing_groups = {"Wasserstein": 0.0, "SW_PWD": 0.0, "EBSW_EST": 0.0, "MinSWGG_MaxSW": 0.0}

    for i, j in pairs:
        xi, yj = X_train[i], X_train[j]

        # .item() converts each result to a Python number before the timer
        # stops, so the conversion is part of the measured interval.
        t0 = time.perf_counter()
        vals["ws"].append(Wasserstein_Distance(xi, yj, device=DEVICE).item())
        timing_groups["Wasserstein"] += time.perf_counter() - t0

        t0 = time.perf_counter()
        vals["sw"].append(Sliced_Wasserstein_Distance(xi, yj, **proj_kwargs).item())
        vals["pwd"].append(Projected_Wasserstein_Distance(xi, yj, **proj_kwargs).item())
        timing_groups["SW_PWD"] += time.perf_counter() - t0

        t0 = time.perf_counter()
        vals["ebsw"].append(Energy_based_Sliced_Wasserstein(xi, yj, **proj_kwargs).item())
        vals["est"].append(Expected_Sliced_Transport(xi, yj, **proj_kwargs).item())
        timing_groups["EBSW_EST"] += time.perf_counter() - t0

        # These two take no projection matrix; only element [0] of what they
        # return is kept.
        t0 = time.perf_counter()
        vals["minswgg"].append(
            Min_SWGG(xi, yj, lr=5e-2, num_iter=10, s=5, std=0.5, device=DEVICE, dtype=DTYPE)[0].item()
        )
        vals["maxsw"].append(
            Max_Sliced_Wasserstein_Distance(
                xi, yj, require_optimize=True, lr=1e-1, num_iter=10, device=DEVICE, dtype=DTYPE
            )[0].item()
        )
        timing_groups["MinSWGG_MaxSW"] += time.perf_counter() - t0

    for k in vals:
        vals[k] = np.asarray(vals[k], dtype=float)
    ws, sw, pwd, ebsw, est, minswgg, maxsw = (
        vals[k] for k in ["ws", "sw", "pwd", "ebsw", "est", "minswgg", "maxsw"]
    )

    # Two-column designs (fitted with a single mixing weight and with free weights).
    pairs_defs = [
        ("rg_s", np.column_stack([sw, pwd])),
        ("rg_e", np.column_stack([ebsw, est])),
        ("rg_o", np.column_stack([maxsw, minswgg])),
    ]
    # Multi-column designs (fitted with the simplex routine and with free weights).
    X_rg_se = np.column_stack([sw, ebsw, pwd, est])
    X_rg_seo = np.column_stack([sw, ebsw, pwd, est, maxsw, minswgg])
    multis_defs = [("rg_se", X_rg_se), ("rg_seo", X_rg_seo)]

    timing_fit, alpha_results, calib_rows = {}, {}, []
    fit_log = (timing_fit, alpha_results, calib_rows)

    # The lambdas below capture loop variables, which is safe here because
    # _record_fit calls them immediately, within the same iteration.
    for base_name, X2 in pairs_defs:
        x_a, x_b = X2[:, 0], X2[:, 1]
        _record_fit(
            f"{base_name}_constr", "pair",
            lambda: optimal_alpha(x_a, x_b, ws),
            lambda a: a * x_a + (1.0 - a) * x_b,
            ws, *fit_log,
        )
        _record_fit(
            f"{base_name}_unconstr", "pair",
            lambda: optimal_alpha_general(X2, ws, ridge=RIDGE),
            lambda w: X2 @ w,
            ws, *fit_log,
        )

    for base_name, Xk in multis_defs:
        _record_fit(
            f"{base_name}_constr", "multi",
            lambda: optimal_alpha_simplex(Xk, ws),
            lambda w: Xk @ w,
            ws, *fit_log,
        )
        _record_fit(
            f"{base_name}_unconstr", "multi",
            lambda: optimal_alpha_general(Xk, ws, ridge=RIDGE),
            lambda w: Xk @ w,
            ws, *fit_log,
        )

    save_dir = os.path.join(result_root, f"num{num_samples}")
    os.makedirs(save_dir, exist_ok=True)

    # Plain-text report: metric timings, then fitted weights, then fit timings.
    with open(os.path.join(save_dir, "timing_and_alpha.txt"), "w", encoding="utf-8") as f:
        for k, v in timing_groups.items():
            f.write(f"{k}: {v:.4f}\n")
        for k, v in alpha_results.items():
            if np.isscalar(v):
                f.write(f"{k}: {float(v):.6f}\n")
            else:
                arr_str = ", ".join([f"{x:.6f}" for x in np.asarray(v).ravel()])
                f.write(f"{k}: [{arr_str}]\n")
        for k, v in timing_fit.items():
            f.write(f"{k}: {v:.6f}\n")

    pd.DataFrame([timing_groups]).to_csv(os.path.join(save_dir, "timing_group_metric.csv"), index=False)

    # Long-format table: one row per scalar alpha, or one row per weight
    # component (w1, w2, ...) for vector results.
    rows = []
    for k, v in alpha_results.items():
        fit_time = timing_fit.get(k, np.nan)
        if np.isscalar(v):
            rows.append({"method": k, "param": "alpha", "value": float(v), "fit_time_sec": fit_time})
        else:
            for idx, wv in enumerate(np.asarray(v).ravel()):
                rows.append({"method": k, "param": f"w{idx+1}", "value": float(wv), "fit_time_sec": fit_time})
    pd.DataFrame(rows).to_csv(os.path.join(save_dir, "optimal_alpha_all.csv"), index=False)

    pd.DataFrame(calib_rows).to_csv(os.path.join(save_dir, "calibration_quality.csv"), index=False)
