import os, glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# === Configuration Section ===
# Directory whose top-level "*.csv" files hold the per-sample predictions.
FOLDER = "/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/exp_0918_mix_metric/lll/02_sample_predictions"
# Predictor columns to evaluate; a column missing from a CSV is filled with NaN when loading.
PRED_COLS = ["global_mean_pred","model_mean_pred","question_mean_pred","1pl_irt_pred","mixed_metric_irt_pred"]

# Column holding the ground-truth value each prediction is compared against.
TRUE_COL = "true_value"
# Column holding the train ratio; metrics are grouped by this column.
RATIO_COL = "train_ratio"
METRIC_NAME_COL = "metric_name"  # Column used to filter rows by metric name
TARGET_METRIC = "response_matrix__bertscore_F1"  # Only rows with this metric name are kept
# Metric CSV tables go to OUTPUT_DIR and figures to PLOTS_DIR (currently the same directory).
OUTPUT_DIR = '/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/exp_0918_mix_metric/lll/04_metrics/more'
PLOTS_DIR = '/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/exp_0918_mix_metric/lll/04_metrics/more'
# Create both output directories up front so later writes do not fail on a missing path.
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(PLOTS_DIR, exist_ok=True)

# Names of the metrics whose plots may use a logarithmic y-axis.
LOG_SCALE_FOR = {"MSE","RMSE","MAE","MAPE(%)"}  # To disable logarithmic scale, set to empty set()

# === Reading and Merging ===
files = sorted(glob.glob(os.path.join(FOLDER, "*.csv")))
if len(files) == 0:
    raise FileNotFoundError(f"CSV not found: {FOLDER}")

# Columns carried forward from every file, always in this order.
keep = [RATIO_COL, TRUE_COL] + PRED_COLS

dfs = []
for f in files:
    df = pd.read_csv(f)
    # Keep only rows of the target metric; copy so new columns can be added safely.
    is_target = df[METRIC_NAME_COL] == TARGET_METRIC
    df = df[is_target].copy()
    # A predictor that this file does not provide becomes an all-NaN column.
    for c in PRED_COLS:
        if c in df.columns:
            continue
        df[c] = np.nan
    dfs.append(df[keep])

full = pd.concat(dfs, ignore_index=True)
# Coerce every kept column to numeric; unparseable entries turn into NaN.
for c in keep:
    full[c] = pd.to_numeric(full[c], errors="coerce")
# Rows without a usable ratio or ground-truth value cannot be scored.
full = full.dropna(subset=[RATIO_COL, TRUE_COL])

# === Metric Functions ===
def _safe_arrays(y_true, y_pred):
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    m = ~np.isnan(y_true) & ~np.isnan(y_pred)
    return y_true[m], y_pred[m]

def mse(y_true, y_pred):
    """Mean squared error over the NaN-free pairs; NaN when no pair remains."""
    actual, predicted = _safe_arrays(y_true, y_pred)
    if actual.size == 0:
        return np.nan
    residual = actual - predicted
    return np.mean(residual ** 2)

def rmse(y_true, y_pred):
    """Root mean squared error; NaN when the underlying MSE is not finite."""
    mean_sq = mse(y_true, y_pred)
    if not np.isfinite(mean_sq):
        return np.nan
    return np.sqrt(mean_sq)

def mae(y_true, y_pred):
    """Mean absolute error over the NaN-free pairs; NaN when no pair remains."""
    actual, predicted = _safe_arrays(y_true, y_pred)
    if actual.size == 0:
        return np.nan
    abs_err = np.abs(actual - predicted)
    return np.mean(abs_err)

def bias(y_true, y_pred):
    """Mean signed error (prediction minus truth); NaN when no pair remains.

    A positive value means the predictions are too high on average.
    """
    actual, predicted = _safe_arrays(y_true, y_pred)
    if actual.size == 0:
        return np.nan
    signed_err = predicted - actual
    return np.mean(signed_err)

def r2(y_true, y_pred):
    """Coefficient of determination, ``1 - SS_res / SS_tot``.

    Returns NaN when no NaN-free pair remains or when the total sum of
    squares of the ground truth is not strictly positive.
    """
    actual, predicted = _safe_arrays(y_true, y_pred)
    if actual.size == 0:
        return np.nan
    residual_ss = np.sum((actual - predicted) ** 2)
    total_ss = np.sum((actual - np.mean(actual)) ** 2)
    # "not > 0" also rejects a NaN total, exactly like the positive check.
    if not total_ss > 0:
        return np.nan
    return 1 - residual_ss / total_ss

def pearson_r(y_true, y_pred):
    """Pearson correlation between truth and prediction over the NaN-free pairs.

    Returns NaN when fewer than two pairs remain or when either side is
    constant (the correlation is undefined for zero variance).
    """
    y_true, y_pred = _safe_arrays(y_true, y_pred)
    if y_true.size < 2:
        return np.nan
    # Zero variance makes np.corrcoef divide by zero (RuntimeWarning, or a
    # rounding-noise result); min == max is an exact test for a constant array.
    if y_true.min() == y_true.max() or y_pred.min() == y_pred.max():
        return np.nan
    return np.corrcoef(y_true, y_pred)[0, 1]

def spearman_rho(y_true, y_pred):
    """Spearman rank correlation over the NaN-free pairs (ties get average ranks).

    Returns NaN when fewer than two pairs remain or when either side is
    constant (all ranks tie, so the correlation is undefined).
    """
    y_true, y_pred = _safe_arrays(y_true, y_pred)
    if y_true.size < 2:
        return np.nan
    # Constant values produce constant ranks; skip np.corrcoef, which would
    # divide by zero and emit a RuntimeWarning.
    if y_true.min() == y_true.max() or y_pred.min() == y_pred.max():
        return np.nan
    rt = pd.Series(y_true).rank(method="average").to_numpy()
    rp = pd.Series(y_pred).rank(method="average").to_numpy()
    return np.corrcoef(rt, rp)[0, 1]

def mape(y_true, y_pred):
    """Mean absolute percentage error, in percent.

    Pairs whose ground truth is exactly zero are excluded because the
    relative error is undefined there; NaN is returned when nothing is left.
    """
    actual, predicted = _safe_arrays(y_true, y_pred)
    nonzero = actual != 0
    actual = actual[nonzero]
    predicted = predicted[nonzero]
    if actual.size == 0:
        return np.nan
    relative_err = np.abs((actual - predicted) / actual)
    return np.mean(relative_err) * 100.0

# Registry mapping a metric name to its function(y_true, y_pred).
# The names are reused for the CSV file names, the plot titles/labels and
# the LOG_SCALE_FOR lookup; iteration follows the insertion order below.
METRICS = {
    "MSE": mse,
    "RMSE": rmse,
    "MAE": mae,
    "Bias": bias,
    "R2": r2,
    "PearsonR": pearson_r,
    "SpearmanRho": spearman_rho,
    "MAPE(%)": mape,
}

# === Generate "row=ratio × column=predictor" metric matrix ===
def metric_by_ratio(df, metric_func, *, ratio_col=None, true_col=None, pred_cols=None):
    """Build a "row = train ratio, column = predictor" table for one metric.

    Args:
        df: DataFrame containing the ratio column, the ground-truth column
            and (some of) the predictor columns.
        metric_func: Callable ``f(y_true, y_pred)`` returning a scalar.
        ratio_col: Grouping column; defaults to the module-level ``RATIO_COL``.
        true_col: Ground-truth column; defaults to ``TRUE_COL``.
        pred_cols: Predictor columns, in output order; defaults to ``PRED_COLS``.

    Returns:
        DataFrame indexed by ratio (ascending) with one column per predictor.
        A predictor missing from ``df`` yields an all-NaN column. An empty
        ``df`` yields an empty table with the expected columns.
    """
    ratio_col = RATIO_COL if ratio_col is None else ratio_col
    true_col = TRUE_COL if true_col is None else true_col
    pred_cols = list(PRED_COLS if pred_cols is None else pred_cols)

    rows = []
    for ratio, group in df.groupby(ratio_col):
        truth = group[true_col]
        row = {ratio_col: ratio}
        for col in pred_cols:
            row[col] = metric_func(truth, group[col]) if col in group.columns else np.nan
        rows.append(row)

    # Naming the columns explicitly fixes their order and keeps the frame
    # well-formed when there are no groups (a bare DataFrame([]) has no
    # ratio column, so sorting on it would raise KeyError).
    out = pd.DataFrame(rows, columns=[ratio_col] + pred_cols)
    return out.sort_values(ratio_col).set_index(ratio_col)

# === Calculate and save each metric matrix ===
# One ratio-by-predictor table per registered metric, kept in memory for
# plotting and also written next to the other outputs as CSV.
results = {}
for name, func in METRICS.items():
    csv_path = os.path.join(OUTPUT_DIR, f"{name}_by_train_ratio.csv")
    df_metric = metric_by_ratio(full, func)
    df_metric.to_csv(csv_path)
    results[name] = df_metric

# === Visualization: One line chart per metric ===
def format_ratio_ticks(ax):
    """Format the x-axis tick labels of a train-ratio plot without moving ticks.

    Tick values in [0, 1] are rendered with two decimals (``0.50``); any
    other value uses the compact ``g`` format.

    A formatter callback is installed instead of a frozen label list: labels
    captured from ``ax.get_xticks()`` can stop matching the ticks chosen at
    draw time (for example once ``tight_layout`` resizes the axes).

    Args:
        ax: Matplotlib ``Axes`` whose x-axis labels should be formatted.
    """
    def _label(value, _pos=None):
        if 0 <= value <= 1:
            return f"{value:.2f}"
        return f"{value:g}"

    # NOTE: passing a plain callable relies on Matplotlib >= 3.3, which wraps
    # it in a FuncFormatter.
    ax.xaxis.set_major_formatter(_label)
    ax.tick_params(axis="x", labelrotation=0)

def plot_metric_df(name, df_metric):
    """Plot one line per predictor for a metric table and save PNG and SVG files.

    Args:
        name: Metric name; used for the title, the y label, the
            ``LOG_SCALE_FOR`` lookup and (sanitized) the output file name.
        df_metric: DataFrame indexed by train ratio with one column per
            predictor, as produced by ``metric_by_ratio``.

    Returns:
        None. Files are written to ``PLOTS_DIR`` as
        ``{name}_by_train_ratio.png`` / ``.svg``; nothing is written when
        every column is entirely NaN.
    """
    # Skip columns that are all NaN
    cols = [c for c in df_metric.columns if not df_metric[c].isna().all()]
    if not cols:
        print(f"[Skip] {name}: All columns are NaN")
        return

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        # One line per predictor; NaN entries simply leave gaps in the line.
        for col in cols:
            series = df_metric[col]
            ax.plot(series.index.to_numpy(), series.to_numpy(), marker="o", label=col)

        ax.set_title(f"{name} by train_ratio")
        ax.set_xlabel("train_ratio")
        ax.set_ylabel(name)
        ax.grid(True, linestyle="--", alpha=0.4)
        ax.legend(loc="best", frameon=True)

        if name in LOG_SCALE_FOR:
            # Use a log axis only when every finite value is strictly
            # positive; a zero or negative point cannot be drawn on it.
            vals = df_metric[cols].to_numpy(dtype=float)
            finite_vals = vals[np.isfinite(vals)]
            if finite_vals.size and (finite_vals > 0).all():
                ax.set_yscale("log")

        format_ratio_ticks(ax)
        fig.tight_layout()

        # "/" and "%" are replaced to keep the file name filesystem-friendly.
        base = name.replace("/", "_").replace("%", "pct")
        png_path = os.path.join(PLOTS_DIR, f"{base}_by_train_ratio.png")
        svg_path = os.path.join(PLOTS_DIR, f"{base}_by_train_ratio.svg")
        fig.savefig(png_path, dpi=200)
        fig.savefig(svg_path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    print(f"Image saved: {png_path}")

# Batch plotting: one figure per computed metric table
for metric_name, metric_table in results.items():
    plot_metric_df(metric_name, metric_table)

print(f"All images output to: {PLOTS_DIR}")

# === Quick console preview of the first rows of a few metric tables ===
for key in ("MSE", "RMSE", "R2"):
    preview = results.get(key)
    if preview is None:
        continue
    print(f"\n[{key}] Preview:")
    print(preview.head())