import os
import numpy as np
import torch
import pandas as pd
import time
from sw2 import (
    Wasserstein_Distance,
    Sliced_Wasserstein_Distance,
    Projected_Wasserstein_Distance,
    Energy_based_Sliced_Wasserstein,
    Max_Sliced_Wasserstein_Distance,
    Min_SWGG,
    Expected_Sliced_Transport,
)
from utils import generate_uniform_unit_sphere_projections


# Experiment configuration.
NUM_PROJ = 100  # number of projection directions used by the sliced metrics
NUM_DIMS = 3  # dimension of the projection directions (presumably the point-cloud coordinate dim)
# Fall back to CPU so the script still runs on machines without a CUDA device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32
num_pairs_test = 10000  # selects the preprocessed "num_pairs_<N>" test split

dataset_root = "preprocessed_dataset/point_cloud"
# Results live next to the test split they are computed from; derive the path
# from dataset_root instead of repeating it as a separate literal.
result_root = os.path.join(dataset_root, "test", f"num_pairs_{num_pairs_test}")
os.makedirs(result_root, exist_ok=True)


# Paired point clouds: pcs1[i] is compared against pcs2[i].
test_dir = os.path.join(dataset_root, "test", f"num_pairs_{num_pairs_test}")
# map_location places the tensors on DEVICE while loading, so files saved from
# a different device (e.g. a GPU not present here) still load.
pcs1 = torch.load(os.path.join(test_dir, "pcs1.pt"), map_location=DEVICE)
pcs2 = torch.load(os.path.join(test_dir, "pcs2.pt"), map_location=DEVICE)
num_pairs = pcs1.shape[0]
# Fail fast on mismatched files instead of hitting an IndexError mid-run (or
# silently ignoring trailing pcs2 entries).
if pcs2.shape[0] != num_pairs:
    raise ValueError(
        f"pcs1 and pcs2 must hold the same number of point clouds, "
        f"got {num_pairs} and {pcs2.shape[0]}"
    )


def get_pair(idx):
    """Return the ``idx``-th (pcs1, pcs2) point-cloud pair."""
    return pcs1[idx], pcs2[idx]


print(f"=> Loaded {num_pairs} test pairs (DATA) from {test_dir}")

# One fixed set of NUM_PROJ directions in dimension NUM_DIMS (presumably drawn
# uniformly on the unit sphere, per the helper's name). Every sliced metric in
# the loop below receives this same matrix, so they share identical projections.
projection_matrix = generate_uniform_unit_sphere_projections(
    num_projections=NUM_PROJ,
    dim=NUM_DIMS,
    dtype=DTYPE,
    device=DEVICE,
    requires_grad=False,
)

# Accumulated wall-clock seconds per metric; keys match the saved result columns.
timing = {"Wasserstein": 0.0, "SW": 0.0, "PWD": 0.0, "EBSW": 0.0, "EST": 0.0, "MinSWGG": 0.0, "MaxSW": 0.0}
# One list of per-pair distances per metric, filled in pair order.
list_ws, list_sw, list_pwd, list_ebsw, list_est, list_minswgg, list_maxsw = [], [], [], [], [], [], []


def _timed(key, compute):
    """Call ``compute()``, add its elapsed time to ``timing[key]``, return the result as a float.

    ``time.perf_counter`` is used because it is monotonic and high-resolution,
    unlike ``time.time``. The ``float()`` conversion happens inside the timed
    region. It turns single-element tensors (on any device) and NumPy/Python
    scalars into a plain Python float. For a CUDA tensor it waits for the value,
    so the measured time covers the GPU work, not just the kernel launch.
    """
    start = time.perf_counter()
    value = float(compute())
    timing[key] += time.perf_counter() - start
    return value


print(">> Computing all metrics on test set...")
# NOTE(review): when DEVICE is CUDA, the first pair's timings likely include
# one-off initialization cost; consider a warm-up pass if per-method timings matter.
for idx in range(num_pairs):
    x_pc, y_pc = get_pair(idx)
    # Each lambda is invoked immediately by _timed, so capturing the loop
    # variables x_pc / y_pc is safe (no late-binding issue).
    list_ws.append(_timed("Wasserstein", lambda: Wasserstein_Distance(x_pc, y_pc, device=DEVICE)))
    list_sw.append(_timed("SW", lambda: Sliced_Wasserstein_Distance(
        x_pc, y_pc, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE)))
    list_pwd.append(_timed("PWD", lambda: Projected_Wasserstein_Distance(
        x_pc, y_pc, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE)))
    list_ebsw.append(_timed("EBSW", lambda: Energy_based_Sliced_Wasserstein(
        x_pc, y_pc, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE)))
    list_est.append(_timed("EST", lambda: Expected_Sliced_Transport(
        x_pc, y_pc, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE)))
    # The optimization-based metrics return an indexable result; element [0] is
    # stored as the distance. float() normalizes it, so a CUDA tensor here cannot
    # break the later np.array conversion.
    list_minswgg.append(_timed("MinSWGG", lambda: Min_SWGG(
        x_pc, y_pc, lr=5e-2, num_iter=5, s=2, std=0.5, device=DEVICE, dtype=DTYPE)[0]))
    list_maxsw.append(_timed("MaxSW", lambda: Max_Sliced_Wasserstein_Distance(
        x_pc, y_pc, require_optimize=True, lr=1e-1, num_iter=5, device=DEVICE, dtype=DTYPE)[0]))

    # Report the number of pairs actually finished, including the final one.
    done = idx + 1
    if done % 500 == 0 or done == num_pairs:
        print(f"{done}/{num_pairs} done.")

# Column order of the saved results: each metric name paired with the list of
# per-pair values collected by the loop.
metric_columns = (
    ("Wasserstein", list_ws),
    ("SW", list_sw),
    ("PWD", list_pwd),
    ("EBSW", list_ebsw),
    ("EST", list_est),
    ("MinSWGG", list_minswgg),
    ("MaxSW", list_maxsw),
)
distances_dict = {name: np.array(values) for name, values in metric_columns}
distances_df = pd.DataFrame(distances_dict)

csv_path = os.path.join(result_root, "all_metrics_test.csv")
distances_df.to_csv(csv_path, index=False)
# np.save converts the DataFrame to a plain 2-D array (rows = pairs, columns in
# the order above); the column names are not stored in the .npy file.
np.save(os.path.join(result_root, "all_metrics_test.npy"), distances_df)
print(f"Saved test metrics to {csv_path}")

# One "<metric>: <seconds> s" line per metric, in timing's insertion order.
timing_path = os.path.join(result_root, "timing_test.txt")
with open(timing_path, "w") as fh:
    fh.writelines(f"{name}: {seconds:.4f} s\n" for name, seconds in timing.items())
print("Saved timing for test metrics.")
