import os
import argparse
import re
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from config.config import (
    BASE_DIR,
    RESULTS_TEST_INTERNAL,
    PREDICTIONS_TEST_INTERNAL,
    RESULTS_TEST_EXTERNAL,
    PREDICTIONS_TEST_EXTERNAL,
    COMBINATION_INTERNAL_TEST,
    COMBINATION_EXTERNAL_TEST,
)

# ===== Helpers =====

def safe_filename(s: str) -> str:
    """Turn an arbitrary label into a string usable as a file name.

    Every run of characters other than word characters, ``-``, ``_`` and ``.``
    is collapsed into a single underscore; leading/trailing dots and
    underscores are removed; the result is capped at 200 characters.
    If nothing is left, the fallback name ``"plot"`` is returned.
    """
    collapsed = re.sub(r"[^\w\-_.]+", "_", str(s), flags=re.UNICODE)
    candidate = collapsed.strip("._")[:200]
    if not candidate:
        return "plot"
    return candidate

def load_benchmarks(benchmark_csv):
    """Read the benchmark CSV and map each model name to its accuracy.

    The CSV must provide the columns ``model_name`` and ``accuracy``.
    Names are converted to ``str`` and accuracies to ``float``; if a name
    occurs more than once, the last row wins.
    """
    table = pd.read_csv(benchmark_csv)
    names = table["model_name"].astype(str)
    scores = table["accuracy"].astype(float)
    return {name: score for name, score in zip(names, scores)}

def prepare_xy(sub_df):
    """Extract a sorted, duplicate-free (coverage, accuracy) curve from a frame.

    Parameters:
      sub_df: DataFrame with the columns ``coverage`` and ``accuracy``.

    Returns:
      x_unique, y_unique: coverage values in ascending order without
      duplicates, and the accuracy paired with each of them. When a coverage
      value occurs several times, the accuracy of its FIRST row (in the
      frame's row order) is kept.
    """
    x = sub_df["coverage"].to_numpy(float)
    y = sub_df["accuracy"].to_numpy(float)
    # np.unique already returns the values sorted and, with return_index=True,
    # the index of each value's first occurrence in the input. Using it
    # directly (instead of pre-sorting with a non-stable argsort) makes the
    # choice among duplicated coverage values deterministic.
    x_unique, first_idx = np.unique(x, return_index=True)
    return x_unique, y[first_idx]

def _linear_extrapolate(x0, y0, x1, y1, x_new):
    """Evaluate at ``x_new`` the straight line through (x0, y0) and (x1, y1).

    If both points share the same x, the slope is taken as zero, so the
    result is simply ``y0``.
    """
    if x1 == x0:
        slope = 0.0
    else:
        slope = (y1 - y0) / (x1 - x0)
    return y0 + slope * (x_new - x0)

def extend_to_range(x, y, xmin=0.1, xmax=1.0):
    """
    Pad the curve (x, y) so that its knots reach xmin on the left and xmax
    on the right.

    A missing end is filled by extending the nearest segment linearly; when
    only a single point is available its value is reused unchanged.
    Expects x to be non-empty and strictly increasing.
    Returns:
      x_ext, y_ext, x_min_orig, x_max_orig
    where the last two are the first and last x of the input.
    """
    first_x, last_x = float(x[0]), float(x[-1])
    knots = list(x)
    values = list(y)

    # Left end: prepend a knot at xmin when the data starts to the right of it.
    if knots[0] > xmin:
        if len(knots) == 1:
            # a lone point has no slope; carry its value over to xmin
            head = values[0]
        else:
            head = _linear_extrapolate(knots[0], values[0], knots[1], values[1], xmin)
        knots.insert(0, xmin)
        values.insert(0, head)

    # Right end: append a knot at xmax when the data stops short of it.
    if knots[-1] < xmax:
        if len(knots) < 2:
            tail = values[-1]
        else:
            tail = _linear_extrapolate(knots[-2], values[-2], knots[-1], values[-1], xmax)
        knots.append(xmax)
        values.append(tail)

    x_ext = np.asarray(knots, dtype=float)
    y_ext = np.asarray(values, dtype=float)
    return x_ext, y_ext, first_x, last_x

def integrate_piecewise_linear(x_ext, y_ext, a=0.1, b=1.0):
    """
    Exact area under the piecewise-linear curve (x_ext, y_ext) restricted to
    the interval [a, b], computed with the trapezoid rule.

    Parameters:
      x_ext: knot positions, expected in increasing order.
      y_ext: curve values at the knots (same length as x_ext).
      a, b:  integration limits.

    Returns:
      The integral over [a, b] as a Python float. Knots lying outside
      [a, b] do not contribute; the inputs are not modified.

    Raises:
      ValueError: if x_ext is empty or does not span [a, b]
                  (within a tolerance of 1e-12).
    """
    tol = 1e-12
    x = np.asarray(x_ext, dtype=float)
    y = np.asarray(y_ext, dtype=float)

    if x.size == 0 or a < x[0] - tol or b > x[-1] + tol:
        raise ValueError("x_ext must cover [a,b].")

    # Curve values exactly at the limits. np.interp clamps to the end values,
    # which also absorbs limits that miss the outer knots by less than `tol`.
    y_a, y_b = np.interp([a, b], x, y)

    # Keep only the knots strictly inside (a, b); the limits themselves are
    # added explicitly so the integration range is exactly [a, b].
    inside = (x > a) & (x < b)
    xs = np.concatenate(([a], x[inside], [b]))
    ys = np.concatenate(([y_a], y[inside], [y_b]))

    # Trapezoid rule written out explicitly (np.trapz is deprecated in
    # NumPy 2.x), so the result does not depend on the NumPy version.
    return float(np.sum(np.diff(xs) * (ys[1:] + ys[:-1]) / 2.0))

def normalize_integral(integral, benchmark):
    """Rescale an integral relative to a constant benchmark accuracy.

    The benchmark's own area (``0.9 * benchmark``) is subtracted and the
    remainder is divided by ``(1.0 - benchmark) * 0.9``. When that divisor
    is zero or negative the result is NaN.
    """
    headroom = (1.0 - benchmark) * 0.9
    baseline_area = 0.9 * benchmark
    return np.nan if headroom <= 0 else (integral - baseline_area) / headroom

def plot_linear(save_dir, model, subset, x_unique, y_unique, x_ext, y_ext, x_min_orig, x_max_orig):
    """Save a PNG showing one model/subset curve and its extrapolated ends.

    Parameters:
      save_dir:            output directory (created if it does not exist).
      model, subset:       labels used in the title and in the file name
                           ``<model>__<subset>.png`` (sanitised via
                           ``safe_filename``).
      x_unique, y_unique:  the de-duplicated data points (drawn as scatter).
      x_ext, y_ext:        the extended polyline (drawn as a line).
      x_min_orig,
      x_max_orig:          first/last x of the original data; parts of the
                           polyline outside this range are re-drawn dashed.

    The figure is always closed, even if drawing or saving raises, so that
    repeated calls do not accumulate open figures.
    """
    os.makedirs(save_dir, exist_ok=True)
    fname = f"{safe_filename(model)}__{safe_filename(subset)}.png"
    out_path = os.path.join(save_dir, fname)

    fig, ax = plt.subplots(figsize=(6.5, 4.5))
    try:
        # scatter of unique points
        ax.scatter(x_unique, y_unique, s=18, label="unique points")

        # full polyline
        ax.plot(x_ext, y_ext, linewidth=1.8, label="piecewise-linear")

        # highlight extrapolated segments (left/right)
        if x_ext[0] < x_min_orig - 1e-12:
            mask_left = x_ext <= x_min_orig
            ax.plot(x_ext[mask_left], y_ext[mask_left], linestyle="--", linewidth=2.0, label="left extrapolation")
        if x_ext[-1] > x_max_orig + 1e-12:
            mask_right = x_ext >= x_max_orig
            ax.plot(x_ext[mask_right], y_ext[mask_right], linestyle="--", linewidth=2.0, label="right extrapolation")

        ax.set_title(f"Model: {model}\nSubset: {subset}")
        ax.set_xlabel("Coverage")
        ax.set_ylabel("Accuracy")
        ax.set_xlim(0.0, 1.0)
        ax.set_ylim(0.0, 1.0)
        ax.grid(True, alpha=0.3)
        ax.legend(loc="lower right", fontsize=8)

        fig.tight_layout()
        fig.savefig(out_path, dpi=160)
    finally:
        # Release the figure on every path; the caller keeps looping after
        # a failure, so a leaked figure would pile up across iterations.
        plt.close(fig)

def main(dataset_type):
    """Compute normalized curve integrals per model/subset and save table + plots.

    Parameters:
      dataset_type: ``"EXTERNAL"`` selects the external-test directories;
                    any other value selects the internal-test directories.

    Reads ``modelwise_subset_coverage_accuracy.csv`` (columns ``model``,
    ``subset``, ``coverage``, ``accuracy``) from the combination directory and
    the benchmark accuracies from ``test_prediction.csv`` in the predictions
    directory. For every (model, subset) group the curve is extended to
    [0.1, 1.0], integrated, normalized against the model's benchmark and
    plotted. One row per model (mean, sample std and per-subset values) is
    written to ``normalized_integrals_by_subset_variation.csv``.

    Processing is best-effort: a failing subset is recorded as NaN and a
    failing plot does not discard the computed value; both are reported on
    stderr instead of being silently ignored.
    """
    import sys  # local import: only needed for the warning output below

    # Select dirs
    if dataset_type == "EXTERNAL":
        predictions_dir = PREDICTIONS_TEST_EXTERNAL
        comb_dir = COMBINATION_EXTERNAL_TEST
    else:
        predictions_dir = PREDICTIONS_TEST_INTERNAL
        comb_dir = COMBINATION_INTERNAL_TEST

    # Paths
    coverage_csv = os.path.join(comb_dir, "modelwise_subset_coverage_accuracy.csv")
    bench_csv = os.path.join(predictions_dir, "test_prediction.csv")
    output_csv = os.path.join(comb_dir, "normalized_integrals_by_subset_variation.csv")
    plots_dir = os.path.join(comb_dir, "plots_interpolation")

    # Load data
    df = pd.read_csv(coverage_csv)
    df["coverage"] = df["coverage"].astype(float)
    df["accuracy"] = df["accuracy"].astype(float)
    df["model"] = df["model"].astype(str)
    df["subset"] = df["subset"].astype(str)

    benchmarks = load_benchmarks(bench_csv)

    results = []
    for model, dfg in df.groupby("model"):
        benchmark = benchmarks.get(model)
        if benchmark is None:
            print(
                f"WARNING: no benchmark accuracy for model {model!r}; "
                "its normalized integrals are NaN.",
                file=sys.stderr,
            )
        subset_norms = {}

        for subset, dfs in dfg.groupby("subset"):
            # Numeric part: a failure here means there is no value for this subset.
            try:
                x_unique, y_unique = prepare_xy(dfs)

                # Ensure coverage over [0.1, 1.0] with linear extrapolation at ends
                x_ext, y_ext, x_min_orig, x_max_orig = extend_to_range(
                    x_unique, y_unique, xmin=0.1, xmax=1.0
                )

                # Exact integral for piecewise-linear curve on [0.1, 1.0]
                integral = integrate_piecewise_linear(x_ext, y_ext, a=0.1, b=1.0)

                norm_val = normalize_integral(integral, benchmark) if benchmark is not None else np.nan
            except Exception as exc:
                print(
                    f"WARNING: could not compute integral for model {model!r}, "
                    f"subset {subset!r}: {exc!r}",
                    file=sys.stderr,
                )
                subset_norms[subset] = np.nan
                continue

            subset_norms[subset] = norm_val

            # Plotting is handled separately so that a rendering or disk error
            # cannot overwrite the value computed above.
            try:
                plot_linear(
                    save_dir=plots_dir,
                    model=model,
                    subset=subset,
                    x_unique=x_unique,
                    y_unique=y_unique,
                    x_ext=x_ext,
                    y_ext=y_ext,
                    x_min_orig=x_min_orig,
                    x_max_orig=x_max_orig,
                )
            except Exception as exc:
                print(
                    f"WARNING: could not save plot for model {model!r}, "
                    f"subset {subset!r}: {exc!r}",
                    file=sys.stderr,
                )

        # Summary statistics over the subsets that produced a finite value.
        vals = [v for v in subset_norms.values() if np.isfinite(v)]
        mean_val = np.mean(vals) if vals else np.nan
        std_val = np.std(vals, ddof=1) if len(vals) >= 2 else np.nan

        row = {
            "model": model,
            "mean_normalized_integral": mean_val,
            "std_normalized_integral": std_val,
        }
        for s, v in subset_norms.items():
            row[f"subset_{s}"] = v
        results.append(row)

    out_df = pd.DataFrame(results)
    out_df.to_csv(output_csv, index=False)
    print(out_df)
    print(f"Saved table to: {output_csv}")
    print(f"Saved plots to: {plots_dir}")

if __name__ == "__main__":
    # Command line: a single optional positional argument choosing which
    # test split to process; without it the internal split is used.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "dataset_type",
        nargs="?",
        choices=["INTERNAL", "EXTERNAL"],
        default="INTERNAL",
        help="Dataset type: INTERNAL or EXTERNAL (default: INTERNAL)",
    )
    main(cli.parse_args().dataset_type)