import numpy as np
import torch

from utils import optimal_alpha_simplex, optimal_alpha_general, generate_uniform_unit_sphere_projections
from sw import (
    Wasserstein_Distance, Sliced_Wasserstein_Distance, Projected_Wasserstein_Distance,
    Energy_based_Sliced_Wasserstein, Max_Sliced_Wasserstein_Distance,
    Min_SWGG, Expected_Sliced_Transport
)

saved_path = "saved_knn"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32
# Experiment settings: repeated runs, base RNG seed, number of training pairs
# used per run to fit the combination weights, largest k evaluated, and the
# ridge penalty passed to the unconstrained weight fit.
N_RUNS, BASE_SEED, NUM_PAIRS, MAX_K, RIDGE = 10, 42, 10, 25, 0.0

# BASE_SEED otherwise only reaches the per-run randperm generator. Seed the
# global torch RNG too so anything drawing from it (possibly the projection
# sampler below) is reproducible across invocations.
torch.manual_seed(BASE_SEED)


def _load(filename, dtype=DTYPE):
    """Load ``saved_path/filename`` onto the CPU, then move it to DEVICE as ``dtype``."""
    return torch.load(f"{saved_path}/{filename}", map_location="cpu").to(DEVICE).to(dtype)


# Precomputed distance matrices; as consumed by knn(), rows index test samples
# and columns index training samples.
WD      = _load("WD_dist.pt")
SWD     = _load("SWD_dist.pt")
PWD     = _load("PWD_dist.pt")
EBSW    = _load("EBSW_dist.pt")
EST     = _load("EST_dist.pt")
MaxSW   = _load("MaxSW_dist.pt")
MinSWGG = _load("MinSW_dist.pt")  # NOTE: on-disk name is MinSW_dist.pt

# Integer class labels (int64) and the training inputs used to sample pairs.
y_tr = _load("y_train.pt", torch.long)
y_te = _load("y_test.pt", torch.long)
X_tr = _load("X_train.pt")

# Projection set shared by all sliced distances. The literal 3 is presumably
# the point dimension of X_tr and 100 the number of directions — TODO confirm
# against utils.generate_uniform_unit_sphere_projections.
proj = generate_uniform_unit_sphere_projections(3, False, 100, DTYPE, DEVICE)

# (group name, keys of MAT_POOL whose matrices are linearly combined).
GROUPS = [
    ("SWD_PWD", ["sw", "pwd"]),
    ("EBSW_EST", ["ebsw", "est"]),
    ("MaxSW_MinSWGG", ["maxsw", "minswgg"]),
    ("OPT4", ["sw", "pwd", "ebsw", "est"]),
    ("OPT6", ["sw", "pwd", "ebsw", "est", "maxsw", "minswgg"]),
]
MAT_POOL = {"sw": SWD, "pwd": PWD, "ebsw": EBSW, "est": EST, "maxsw": MaxSW, "minswgg": MinSWGG}
# Result rows: a constrained and an unconstrained combination per group,
# followed by every single-distance baseline.
names = [f"{g}_{t}" for g, _ in GROUPS for t in ["constr", "unconstr"]] + ["WD", "SWD", "PWD", "EBSW", "EST", "MaxSW", "MinSWGG"]
# acc[name][run, k - 1] = kNN accuracy of `name` in run `run` with k neighbours.
acc = {n: np.zeros((N_RUNS, MAX_K)) for n in names}

def knn(M, k, y_train=None, y_test=None):
    """Return the k-nearest-neighbour classification accuracy of a distance matrix.

    Args:
        M: distance tensor indexed ``[test, train]`` (smaller means closer);
            neighbours are selected along ``dim=1``.
        k: number of nearest training samples that vote.
        y_train: integer labels of the training samples (columns of ``M``).
            Defaults to the module-level ``y_tr``.
        y_test: integer labels of the test samples (rows of ``M``).
            Defaults to the module-level ``y_te``.

    Returns:
        Fraction (Python float) of test rows whose majority-vote label equals
        the true label.

    Note:
        The vote is ``torch.mode``; on a tie the winner is whichever label
        ``torch.mode`` returns (appears to be the smallest tied label on CPU —
        verify for CUDA), not the nearest neighbour's label.
    """
    labels_train = y_tr if y_train is None else y_train
    labels_test = y_te if y_test is None else y_test
    # largest=False -> indices of the k smallest distances, i.e. nearest neighbours.
    nearest = torch.topk(M, k, dim=1, largest=False).indices
    pred = torch.mode(labels_train[nearest], dim=1).values
    return (pred == labels_test).float().mean().item()

# Single-distance baselines. These matrices never change between runs and
# knn() has no random component, so their accuracy is evaluated once and copied
# into every run row instead of recomputing identical numbers N_RUNS times.
BASELINES = {"WD": WD, "SWD": SWD, "PWD": PWD, "EBSW": EBSW, "EST": EST,
             "MaxSW": MaxSW, "MinSWGG": MinSWGG}
for n, M in BASELINES.items():
    for k in range(1, MAX_K + 1):
        acc[n][:, k - 1] = knn(M, k)

for run in range(N_RUNS):
    seed = BASE_SEED + run
    # randperm below uses its own generator, but randomness inside the distance
    # routines would come from the global RNG, which was never re-seeded per run.
    # Presumably the optimisation-based Min_SWGG / Max-SW calls are stochastic
    # (TODO confirm in sw.py); seeding here makes each run reproducible.
    torch.manual_seed(seed)
    g = torch.Generator(device=DEVICE).manual_seed(seed)
    # 2*NUM_PAIRS distinct training samples, split into pairs (I[i], J[i]).
    idx = torch.randperm(X_tr.shape[0], generator=g, device=DEVICE)[:2 * NUM_PAIRS]
    I, J = idx[:NUM_PAIRS], idx[NUM_PAIRS:]

    # Per-pair distances: "ws" (Wasserstein) is the regression target, the
    # remaining keys are the candidate features.
    vals = {k: [] for k in ["ws", "sw", "pwd", "ebsw", "est", "minswgg", "maxsw"]}
    for a, b in zip(I, J):
        x, y = X_tr[a], X_tr[b]
        vals["ws"].append(Wasserstein_Distance(x, y, device=DEVICE).item())
        vals["sw"].append(Sliced_Wasserstein_Distance(x, y, proj, DEVICE, DTYPE).item())
        vals["pwd"].append(Projected_Wasserstein_Distance(x, y, proj, DEVICE, DTYPE).item())
        vals["ebsw"].append(Energy_based_Sliced_Wasserstein(x, y, proj, DEVICE, DTYPE).item())
        vals["est"].append(Expected_Sliced_Transport(x, y, proj, DEVICE, DTYPE).item())
        vals["minswgg"].append(Min_SWGG(x, y, 5e-2, 5, 20, 0.5, DEVICE, DTYPE)[0].item())
        vals["maxsw"].append(Max_Sliced_Wasserstein_Distance(x, y, True, 1e-1, 5, DEVICE, DTYPE)[0].item())
    y_ws = np.array(vals["ws"])
    pool = {k: np.array(vals[k]) for k in MAT_POOL}

    # Fit weights on the sampled pairs (a1 via optimal_alpha_simplex -> "constr",
    # a2 via optimal_alpha_general with RIDGE -> "unconstr"), then apply them to
    # the full precomputed distance matrices.
    combos = {}
    for gname, feats in GROUPS:
        X_np = np.stack([pool[k] for k in feats], 1)
        a1 = optimal_alpha_simplex(X_np, y_ws)
        a2 = optimal_alpha_general(X_np, y_ws, ridge=RIDGE)
        combos[f"{gname}_constr"] = sum(float(w) * MAT_POOL[f] for w, f in zip(a1, feats))
        combos[f"{gname}_unconstr"] = sum(float(w) * MAT_POOL[f] for w, f in zip(a2, feats))
    for k in range(1, MAX_K + 1):
        for n, M in combos.items():
            acc[n][run, k - 1] = knn(M, k)

# One line per k: every method's mean ± sample std (ddof=1) of its kNN accuracy
# across the N_RUNS runs.
for k in range(1, MAX_K + 1):
    col = k - 1
    cells = (
        f"{n}: {acc[n][:, col].mean():.3f}±{acc[n][:, col].std(ddof=1):.3f}"
        for n in names
    )
    print(" | ".join([f"k={k:2d}", *cells]))
