# estimate_alpha_per_dim.py
import os, json, csv
import numpy as np
import torch
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from utils import optimal_alpha_simplex, optimal_alpha_general

# Root folder scanned for ``dim_<d>`` sub-folders, one per dimension.
OUT_ROOT = "saved_simulations_mog"
# Folder where the fitted weights CSV is written.
ALPHA_DIR = os.path.join(OUT_ROOT, "alpha_simplex")

# Metric groups to fit: each entry is (metric names used as regressors for
# the "W" value, label used as the CSV column prefix for that group).
GROUPS = [
    (["sw", "pwd"], "fast-SP"),
    (["ebsw", "est"], "fast-E"),
    (["max_sw", "min_swgg"], "fast-Op"),
    # Union of the first two groups.
    (["sw", "pwd", "ebsw", "est"], "fast4"),
    # Union of all three pairs.
    (["sw", "pwd", "ebsw", "est", "max_sw", "min_swgg"], "fast6"),
]

def collect_runs_for_dim(dim_dir):
    """Read every run sub-folder of a ``dim_xxx`` folder into a list of dicts.

    Each immediate sub-directory of ``dim_dir`` is treated as one run; other
    entries (plain files) are ignored. For every metric name in
    ``("W", "sw", "pwd", "ebsw", "est", "max_sw", "min_swgg")`` the file
    ``<run>/<metric>.pt`` is loaded with ``torch.load`` and reduced to a
    Python float. A missing file yields ``nan`` for that metric.

    Args:
        dim_dir: Path of the folder whose sub-directories are runs.

    Returns:
        One dict per kept run, in sorted run-name order, mapping every metric
        name to a float. Runs whose ``W`` value is missing or non-finite are
        dropped.

    Raises:
        Whatever ``.item()`` raises if a stored value has more than one
        element (the value is expected to be a scalar).
    """
    metrics = ("W", "sw", "pwd", "ebsw", "est", "max_sw", "min_swgg")
    rows = []
    for run_name in sorted(os.listdir(dim_dir)):
        run_path = os.path.join(dim_dir, run_name)
        if not os.path.isdir(run_path):
            continue
        row = {}
        for m in metrics:
            fpath = os.path.join(run_path, f"{m}.pt")
            if os.path.exists(fpath):
                # map_location="cpu" lets tensors saved on a GPU load on a
                # CPU-only machine; as_tensor also accepts values that were
                # saved as plain Python numbers rather than tensors.
                value = torch.load(fpath, map_location="cpu")
                row[m] = float(torch.as_tensor(value).item())
            else:
                row[m] = np.nan
        # W is the regression target downstream; a run without it is useless.
        if np.isfinite(row["W"]):
            rows.append(row)
    return rows

def r2_score(y_true, y_pred):
    """Return the coefficient of determination R^2 of ``y_pred`` vs ``y_true``.

    This module-level definition replaces the ``r2_score`` imported from
    ``sklearn.metrics`` at the top of the file, so code in this module calls
    this version.

    Both inputs are converted to flat 1-D float arrays before comparison, so
    e.g. an ``(n, 1)`` prediction is compared element-wise against an ``(n,)``
    target instead of being broadcast to an ``(n, n)`` residual.

    Args:
        y_true: Array-like of observed values.
        y_pred: Array-like of predicted values, same number of elements.

    Returns:
        ``1 - SS_res / SS_tot`` as a float, or ``nan`` when ``y_true`` is empty
        or its total sum of squares is <= 1e-15 (R^2 is undefined there).

    Raises:
        ValueError: if ``y_true`` is non-empty and the two inputs have a
            different number of elements.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.size == 0:
        return np.nan
    if y_true.size != y_pred.size:
        raise ValueError(
            "y_true and y_pred must have the same number of elements "
            f"(got {y_true.size} and {y_pred.size})"
        )
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    # A (numerically) constant target has no variance to explain.
    if ss_tot <= 1e-15:
        return np.nan
    return float(1.0 - ss_res / ss_tot)

def _numbered_dim_dirs(root):
    """Return ``(dim, path)`` for each ``dim_<int>`` sub-directory of *root*, sorted by dim."""
    found = []
    for dim_name in os.listdir(root):
        if not dim_name.startswith("dim_"):
            continue
        dim_dir = os.path.join(root, dim_name)
        if not os.path.isdir(dim_dir):
            continue
        try:
            dim = int(dim_name.split("_")[1])
        except ValueError:
            # Not a numbered folder (e.g. "dim_tmp"); skip it.
            continue
        found.append((dim, dim_dir))
    # Numeric order: a plain string sort would put dim_10 before dim_2.
    found.sort()
    return found


def _fit_group(rows, cols, ridge, solver):
    """Fit weights of metrics *cols* against W over *rows*; return the 5 CSV cells."""
    X_list, y_list = [], []
    for r in rows:
        x = [r.get(c, np.nan) for c in cols]
        # Only runs where every metric of this group is available take part.
        if all(np.isfinite(x)):
            X_list.append(x)
            y_list.append(r["W"])
    if not y_list:
        return [json.dumps({}), np.nan, np.nan, np.nan, 0]
    X = np.array(X_list, dtype=float)
    y = np.array(y_list, dtype=float)
    # NOTE(review): the helper is expected to return one weight per column,
    # shape (d,); flatten defensively so a (d, 1) result cannot broadcast
    # y_hat - y into an (n, n) matrix.
    # utils also provides optimal_alpha_general(X, y) as an alternative fitter.
    a = np.asarray(
        optimal_alpha_simplex(X, y, ridge=ridge, solver=solver), dtype=float
    ).ravel()
    y_hat = X @ a
    mse = float(np.mean((y_hat - y) ** 2))
    mae = float(np.mean(np.abs(y_hat - y)))
    r2 = float(r2_score(y, y_hat))
    weights = {k: float(v) for k, v in zip(cols, a)}
    return [json.dumps(weights), mse, mae, r2, int(len(y))]


def estimate_alpha_across_dims(ridge=0.0, solver="OSQP"):
    """Fit per-dimension weights for every metric group and write them to CSV.

    For every ``dim_<int>`` folder under ``OUT_ROOT`` (in increasing numeric
    order), the runs are loaded with ``collect_runs_for_dim``. Then, for each
    group in ``GROUPS``, the run's ``W`` value is regressed on the group's
    metrics with ``optimal_alpha_simplex``. The result is written to
    ``ALPHA_DIR/alphas_by_dim.csv``: one row per dim with, for each group, the
    fitted weights as JSON plus the MSE, MAE, R^2 and number of runs used.

    Dims with no usable runs are skipped. A group for which no run has all of
    its metrics finite gets ``{}`` weights, NaN metrics and ``n = 0``.

    Args:
        ridge: Passed through unchanged to ``optimal_alpha_simplex``.
        solver: Passed through unchanged to ``optimal_alpha_simplex``.
    """
    os.makedirs(ALPHA_DIR, exist_ok=True)
    out_csv = os.path.join(ALPHA_DIR, "alphas_by_dim.csv")
    header = ["dim"]
    for _, name in GROUPS:
        header += [
            f"{name}_weights_json",
            f"{name}_mse",
            f"{name}_mae",
            f"{name}_r2",
            f"{name}_n",
        ]
    with open(out_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        for dim, dim_dir in _numbered_dim_dirs(OUT_ROOT):
            rows = collect_runs_for_dim(dim_dir)
            if not rows:
                continue
            row_out = [dim]
            for cols, _name in GROUPS:
                row_out += _fit_group(rows, cols, ridge, solver)
            writer.writerow(row_out)
            # Persist each finished dim, so an interrupted run keeps the
            # rows already computed.
            f.flush()

if __name__ == "__main__":
    # Script entry point: run the fit for every dimension with ridge=0.0 and
    # solver="OSQP", both forwarded to the weight-fitting helper.
    estimate_alpha_across_dims(solver="OSQP", ridge=0.0)
