import os
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

# num_samples selects the fitted-weights folder (num<N>) and names the summary CSV;
# num_pairs_test selects which metrics folder (num_pairs_<N>) is read.
num_samples = 50
num_pairs_test = 10000

# Both input tables sit under the same test directory; only the sub-folder differs.
_test_root = "preprocessed_dataset/point_cloud/test"
metrics_path = f"{_test_root}/num_pairs_{num_pairs_test}/all_metrics_test.csv"
alpha_path = f"{_test_root}/num{num_samples}/optimal_alpha_all.csv"

# Read in the same order as before: metrics first, then the fitted weights.
df_metrics, df_alpha = pd.read_csv(metrics_path), pd.read_csv(alpha_path)

def _weights_for(rows):
    """Decode one method's fitted weights from its rows of the alpha table.

    A method with exactly one ``alpha`` row maps to ``{"alpha": float}``.
    Any other method maps to ``{"w": ndarray}``, holding the values of its
    ``w<i>`` params ordered by the numeric suffix (w2 before w10).
    """
    param_col = rows["param"]
    alpha_values = rows.loc[param_col == "alpha", "value"].to_numpy()
    if len(alpha_values) == 1:
        return {"alpha": float(alpha_values[0])}

    w_params = sorted(
        (p for p in param_col.values if str(p).startswith("w")),
        key=lambda p: int(str(p)[1:]),
    )
    # For each w<i>, take the first matching row's value.
    coeffs = [float(rows.loc[param_col == p, "value"].to_numpy()[0]) for p in w_params]
    return {"w": np.asarray(coeffs, dtype=float)}


# method name -> {"alpha": a} or {"w": weight vector}, in first-appearance order.
weights = {
    method: _weights_for(df_alpha[df_alpha["method"] == method])
    for method in df_alpha["method"].unique()
}

# Feature columns per regression base, in the order the fitted weight vector expects:
# predictions are X @ w, with X's columns taken in exactly this order.
# Values must stay lists, because a tuple passed to df[...] is treated as a single key.
PAIR_ORDERS = dict(
    rg_s=["SW", "PWD"],
    rg_e=["EBSW", "EST"],
    rg_o=["MaxSW", "MinSWGG"],
)
MULTI_ORDERS = dict(
    rg_se=["SW", "EBSW", "PWD", "EST"],
    rg_seo=["SW", "EBSW", "PWD", "EST", "MaxSW", "MinSWGG"],
)

def build_X(df, cols):
    """Return the columns *cols* of *df* as a float ndarray, in the order given."""
    selected = df.loc[:, cols]
    return selected.to_numpy(dtype=float)

def predict_for_method(df, method_key):
    """Return the linear-combination prediction ``X @ w`` for every row of *df*.

    ``method_key`` must look like ``"<base>_constr"`` or ``"<base>_unconstr"``.
    ``<base>`` is a key of ``PAIR_ORDERS`` or ``MULTI_ORDERS`` and selects the
    feature columns. ``method_key`` itself is the key into ``weights``.

    For the ``constr`` variant, a stored ``alpha`` becomes ``[alpha, 1 - alpha]``.
    Otherwise the stored ``w`` vector is used. Constrained weights are then
    clipped to be non-negative.

    Raises:
        KeyError: unknown suffix or base, no fitted weights for ``method_key``,
            or a required column missing from *df*. The driver loop treats
            this as "method not available".
        ValueError: the number of weights differs from the number of columns.
    """
    for suffix, variant in (("_constr", "constr"), ("_unconstr", "unconstr")):
        if method_key.endswith(suffix):
            base = method_key[: -len(suffix)]
            break
    else:
        # Without a recognised suffix there is no weight entry to look up.
        raise KeyError(method_key)

    if base in PAIR_ORDERS:
        cols = PAIR_ORDERS[base]
    elif base in MULTI_ORDERS:
        cols = MULTI_ORDERS[base]
    else:
        raise KeyError(method_key)

    X = build_X(df, cols)
    params = weights[method_key]

    if variant == "constr" and "alpha" in params:
        a = float(params["alpha"])
        w = np.array([a, 1.0 - a], dtype=float)
    else:
        w = np.asarray(params["w"], dtype=float).reshape(-1)

    if variant == "constr":
        # np.clip returns a new array. Clipping in place (w[w < 0] = 0) would write
        # through the asarray/reshape view into the shared `weights` dict.
        w = np.clip(w, 0.0, None)

    if w.shape[0] != X.shape[1]:
        raise ValueError(
            f"{method_key}: {w.shape[0]} weights for {X.shape[1]} columns {cols}"
        )
    return X @ w

def eval_method(df, method_key):
    """Score one method's predictions against the ``Wasserstein`` column of *df*.

    A (target, prediction) pair is dropped when:
      * both values are exactly zero,
      * either value is negative, or
      * either value is not finite. NaN passes a ``< 0`` test and would
        otherwise make the sklearn metrics raise.

    The remaining pairs are scored with R², MSE and MAE, and a one-line summary
    is printed.

    Returns:
        Tuple ``(r2, mse, mae)``.

    Raises:
        KeyError: propagated from ``predict_for_method`` or a missing column.
        ValueError: no pairs remain after filtering.
    """
    y_true = df["Wasserstein"].to_numpy(dtype=float)
    y_pred = predict_for_method(df, method_key)

    keep = (
        np.isfinite(y_true)
        & np.isfinite(y_pred)
        & ~((y_true == 0) & (y_pred == 0))
        & (y_true >= 0)
        & (y_pred >= 0)
    )
    y_true, y_pred = y_true[keep], y_pred[keep]
    if y_true.size == 0:
        raise ValueError(
            f"{method_key}: no valid (target, prediction) pairs left after filtering"
        )

    r2 = r2_score(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    print(f"{method_key:>18s} | R²={r2:.4f}  MSE={mse:.6f}  MAE={mae:.6f}")
    return r2, mse, mae

# Each name is "<base>_<variant>" as parsed by predict_for_method.
methods_to_run = [
    "rg_s_constr", "rg_s_unconstr",
    "rg_e_constr", "rg_e_unconstr",
    "rg_o_constr", "rg_o_unconstr",
    "rg_se_constr", "rg_se_unconstr",
    "rg_seo_constr", "rg_seo_unconstr",
]

summary_rows = []
for m in methods_to_run:
    try:
        r2, mse, mae = eval_method(df_metrics, m)
    except KeyError as err:
        # Best-effort: skip a method whose weights or input columns are missing,
        # but report it instead of dropping it silently.
        print(f"{m:>18s} | skipped (missing key: {err})")
        continue
    summary_rows.append({"method": m, "R2": r2, "MSE": mse, "MAE": mae})

# Write a summary only when at least one method was scored.
if summary_rows:
    pd.DataFrame(summary_rows).to_csv(f"summary_{num_samples}.csv", index=False)
