
"""
Input: the output files from running run_exp.sh

Output: Tables with metrics
"""

from enum import Enum
import os
import pandas as pd
import argparse
from prettytable import PrettyTable
from src.priors.gaussian_prior import TruncatedGaussian
from src.priors.kde_prior import KDEPrior
from src.priors.lc_prior import LCPrior
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os, matplotlib.pyplot as plt, matplotlib.patches as mpatches
# Global matplotlib text styling, set once at import time for every figure.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 28,  # Adjust the font size as needed
    'font.weight': 'bold',
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
})

class Method(Enum):
    """Prior methods that this script can evaluate.

    Each member's value is a string identifier for the method;
    ``read_output_file`` uses it as the prior component of the results
    file name.
    """
    # No prior model object is built for this method
    # (``get_prior_obejct`` returns None for it).
    Uniform = "uniform_prior"
    # Backed by ``TruncatedGaussian`` in ``get_prior_obejct``.
    Gaussian = "gaussian_prior"
    # Backed by ``KDEPrior`` in ``get_prior_obejct``.
    KDE = "kde_prior"
    # Backed by ``LCPrior`` in ``get_prior_obejct``.
    LCP = "logit_based_calibrated_prior"


def read_output_file(output_folder, prior: Method, model_type: str, input_file: str, iter=0, filter_zero_corr=False, understood_only=False, not_defauling_zero_only=False):
    """Load the results CSV for one prior / model / iteration combination.

    The file is looked up at
    ``<output_folder>/<input stem>_<model>_<prior value>_iter_<iter>.csv``,
    with spaces in the model and prior names replaced by underscores.
    ``Method.LCP`` and ``Method.Uniform`` are redirected to the KDE file.

    Args:
        output_folder: Directory that contains the results CSV files.
        prior: Method whose results should be loaded.
        model_type: Model name used in the file name.
        input_file: Path of the benchmark file; only its base name without
            extension is used.
        iter: Iteration index used in the file name.
        filter_zero_corr: Accepted but not used by this function.
        understood_only: Accepted but not used by this function.
        not_defauling_zero_only: Accepted but not used by this function.

    Returns:
        The CSV contents as a ``pandas.DataFrame``.
    """
    # LCP and Uniform share the KDE output dataframe.
    source_prior = Method.KDE if prior in (Method.LCP, Method.Uniform) else prior

    stem, _extension = os.path.splitext(os.path.basename(input_file))
    model_tag = model_type.replace(" ", "_")
    prior_tag = source_prior.value.replace(" ", "_")

    output_file = os.path.join(
        output_folder,
        f"{stem}_{model_tag}_{prior_tag}_iter_{iter}.csv",
    )
    print("reading", output_file)
    return pd.read_csv(output_file)

def get_sign_accuracy_and_MSE(df: pd.DataFrame, prior: Method, column='est_mean'):
    """Compute error and sign-agreement metrics of ``df[column]`` vs ``df['r_obs']``.

    The frame is modified in place: ``abs_diff``, ``sign_accuracy`` and
    ``squared_err`` columns are added (plus ``random_sign`` for the Uniform
    prior).

    For ``Method.Uniform`` the absolute difference is ``|r_obs|`` and the
    sign is guessed at random from {-1, 1}; the resulting sign accuracy is
    also printed. For every other prior both are derived from
    ``df[column]``. The squared error always uses ``df[column]``.

    Returns:
        Tuple ``(mean abs diff, std abs diff, sign accuracy, MSE)``, each
        rounded to 3 decimals.
    """
    observed = df['r_obs']

    if prior == Method.Uniform:
        df['abs_diff'] = observed.abs()
        # Random sign guess; uses NumPy's global random state.
        df['random_sign'] = np.random.choice([-1, 1], size=len(df))
        df['sign_accuracy'] = (df['random_sign'] * observed) > 0
        print(df['sign_accuracy'].mean())
    else:
        estimate = df[column]
        df['abs_diff'] = (observed - estimate).abs()
        # Same sign <=> strictly positive product (a zero counts as a miss).
        df['sign_accuracy'] = (observed * estimate) > 0

    sign_accuracy = df['sign_accuracy'].mean()
    df['squared_err'] = (observed - df[column]) ** 2

    abs_diff = df['abs_diff']
    return (
        round(abs_diff.mean(), 3),
        round(abs_diff.std(), 3),
        round(sign_accuracy, 3),
        round(df['squared_err'].mean(), 3),
    )

def get_density_and_NLL(df: pd.DataFrame):
    """Summarise the density values and their negative log-likelihood (NLL).

    The frame is modified in place: ``density_value`` is clipped from below
    at 1e-9 so the logarithm is finite, and an ``NLL`` column holding
    ``-log(density_value)`` is added.

    Args:
        df: Frame with a numeric ``density_value`` column.

    Returns:
        Tuple ``(mean density, std density, mean NLL, std NLL)``, each
        rounded to 2 decimals.
    """
    epsilon = 1e-9  # floor that keeps log() away from zero / negative input
    df['density_value'] = df['density_value'].clip(lower=epsilon)
    df['NLL'] = -np.log(df['density_value'])

    density = df['density_value']
    nll = df['NLL']
    return (
        round(density.mean(), 2),
        round(density.std(), 2),
        round(nll.mean(), 2),
        round(nll.std(), 2),
    )

def get_ci_coverage(df: pd.DataFrame):
    """Return the fraction of rows whose ``r_obs`` lies inside the interval.

    A row is covered when ``est_lower <= r_obs <= est_upper`` (both bounds
    inclusive). The per-row flag is stored in a new ``in_ci`` column of
    ``df``; the returned coverage is rounded to 3 decimals.
    """
    # between() is inclusive on both ends by default.
    inside = df['r_obs'].between(df['est_lower'], df['est_upper'])
    df['in_ci'] = inside
    return round(df['in_ci'].mean(), 3)


def plot_multi_violin_box(
    name_to_df: dict,
    columns: list[str] | None = None,
    *,
    palette_name: str | list = "colorblind",
    out_dir: str = "figures",
    figsize: tuple[int, int] = (24, 6),
):
    """Draw a single row of violin + box plots, one panel per column.

    Each method (key of ``name_to_df``) gets its own colour. X-tick labels
    are suppressed; colours are explained in a legend above the panels.
    The Uniform method is only drawn in the ``abs_diff`` panel. The median
    of every group is written next to its box.

    Args:
        name_to_df: Mapping from ``Method`` member to a DataFrame that
            contains every column listed in ``columns``.
        columns: Columns to plot; defaults to
            ``["abs_diff", "density_value", "NLL"]``.
        palette_name: Seaborn palette name, or an explicit list of colours.
        out_dir: Directory the figure files are written to (created if
            missing).
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        Path of the saved PNG file. A PDF with the same stem is written
        alongside it.
    """
    os.makedirs(out_dir, exist_ok=True)

    if columns is None:
        columns = ["abs_diff", "density_value", "NLL"]

    # ── colour palette: one entry per method present, in Enum order ─────────
    present = {key.name for key in name_to_df}
    method_names = [
        m.name.replace("_", " ") for m in Method if m.name in present
    ]

    palette_vals = (
        sns.color_palette(palette_name, n_colors=len(method_names))
        if isinstance(palette_name, str)
        else sns.color_palette(palette_name)
    )
    palette = dict(zip(method_names, palette_vals))

    # patches for legend (one per method)
    legend_patches = [
        mpatches.Patch(color=clr, label=lbl) for lbl, clr in palette.items()
    ]

    y_labels = {
        "mode": r"Mode $\hat{r}$",
        "est_mean": "Distribution Mean  E[r]",
        "abs_diff": "Absolute Difference",
        "NLL": "Information Content",
    }

    # squeeze=False always yields a 2-D array of Axes, so a single column
    # works too (otherwise a bare Axes is returned and .flatten() fails).
    fig, axes = plt.subplots(
        1, len(columns), figsize=figsize, sharex=False, squeeze=False
    )
    axes = axes.flatten()

    for ax, col in zip(axes, columns):
        # long form ↓  (group_order records the x position of each group)
        group_order = []
        records = []
        for lbl_enum, df in name_to_df.items():
            if lbl_enum is Method.Uniform and col != 'abs_diff':
                continue
            label_str = lbl_enum.name.replace("_", " ")
            group_order.append(label_str)
            records.extend((label_str, v) for v in df[col])
        df_long = pd.DataFrame(records, columns=["Group", "Value"])

        # violin – coloured by Group
        sns.violinplot(
            x="Group",
            y="Value",
            data=df_long,
            order=group_order,
            palette=palette,
            inner=None,
            linewidth=1,
            cut=0,
            alpha=0.5,
            ax=ax,
        )

        # transparent box (same palette)
        sns.boxplot(
            x="Group",
            y="Value",
            data=df_long,
            order=group_order,
            palette=palette,
            width=0.25,
            showcaps=True,
            boxprops=dict(facecolor="none", edgecolor="black", linewidth=1.2),
            whiskerprops=dict(color="black", linewidth=1.2),
            capprops=dict(color="black", linewidth=1.2),
            medianprops=dict(color="darkorange", linewidth=2),
            showfliers=False,
            ax=ax,
        )

        # axis cosmetics – the scale is set before annotating so that the
        # label offset below sees the final y-scale
        if col == "density_value":
            ax.set_yscale("log")
            ax.set_ylabel("Density p(r)  (log scale)")
        else:
            ax.set_ylabel(y_labels.get(col, col))
        ax.set_xlabel("")
        ax.set_xticks([])           # ← remove the ticks entirely

        # annotate medians at the x position the group is actually drawn at
        offset = 1.03 if ax.get_yscale() == "linear" else 1.1
        for grp, med in df_long.groupby("Group")["Value"].median().items():
            ax.text(
                group_order.index(grp),
                med * offset,
                f"{med:.2f}",
                ha="center",
                va="bottom",
                fontsize=23,
                fontweight="bold",
            )

    # panel letters (a), (b), ... below each subplot
    for idx, ax in enumerate(axes):
        letter = chr(ord('a') + idx)
        ax.text(
            0.5, -0.05, f"({letter})", transform=ax.transAxes,
            ha='center', va='top', fontweight='bold'
        )

    # legend (centred above all subplots)
    fig.legend(
        handles=legend_patches,
        loc="upper center",
        ncol=len(legend_patches),
        frameon=True,
        bbox_to_anchor=(0.5, 1.1),
    )

    fig.tight_layout()
    path = os.path.join(out_dir, "multi_violin_box.png")
    path2 = os.path.join(out_dir, "multi_violin_box.pdf")
    fig.savefig(path, dpi=300, bbox_inches="tight")
    fig.savefig(path2, bbox_inches="tight")
    plt.close(fig)
    return path

def get_prior_obejct(prior: Method, ignore_zero: bool):
    """Instantiate the prior model that corresponds to ``prior``.

    Args:
        prior: Method to build a model for.
        ignore_zero: Forwarded to ``KDEPrior`` and ``LCPrior``; not used
            for the Gaussian or Uniform methods.

    Returns:
        A ``TruncatedGaussian``, ``KDEPrior`` or ``LCPrior`` instance
        (each created with ``agent=None``), or ``None`` for
        ``Method.Uniform``.

    Raises:
        ValueError: If ``prior`` is not one of the known methods.
    """
    if prior == Method.Uniform:
        return None
    if prior == Method.Gaussian:
        return TruncatedGaussian(agent=None)
    if prior == Method.KDE:
        return KDEPrior(agent=None, ignore_zero=ignore_zero)
    if prior == Method.LCP:
        return LCPrior(agent=None, ignore_zero=ignore_zero)
    # Previously an unmatched value fell through and hit an unbound local.
    raise ValueError(f"Unknown prior method: {prior!r}")

def parse_args():
    """Define the command-line options and parse the process arguments.

    Returns:
        The ``argparse.Namespace`` with ``benchmark_name``, ``output_dir``,
        ``num_iter``, ``model_type`` and ``priors``.
    """
    method_options = ','.join([m.name for m in Method])

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ('--benchmark_name', dict(
            type=str,
            default='benchmark_2096',
            help='Name of the benchmark (without .csv extension).',
        )),
        ('--output_dir', dict(
            type=str,
            default=None,
            help='Path to the directory for storing outputs. Defaults to outputs/<benchmark_name>/. ',
        )),
        ('--num_iter', dict(
            type=int,
            default=1,
            help='Number of iterations to run.',
        )),
        ('--model_type', dict(
            type=str,
            default='gpt-4o',
            help='LLM model to use (e.g., gpt-4o, gpt-4o-mini).',
        )),
        ('--priors', dict(
            type=str,
            default='Uniform,Gaussian,KDE,LCP',
            help=(
                'Comma-separated list of prior methods to apply. '
                'Options: ' + method_options
            ),
        )),
    )

    cli = argparse.ArgumentParser(
        description="Evaluate the performance of different priors on a benchmark dataset."
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

if __name__ == "__main__":
    # Script entry point: load the per-prior results, derive summary
    # statistics, print/save a metrics table and save the violin/box figure.
    args = parse_args()
    benchmark_name = args.benchmark_name
    output_folder = args.output_dir or f"outputs/{benchmark_name}/"
    os.makedirs(output_folder, exist_ok=True)

    input_file = f"benchmark/{benchmark_name}.csv"

    # Resolve the comma-separated --priors option into Method members.
    prior_names = [name.strip() for name in args.priors.split(',') if name.strip()]
    priors = []
    for name in prior_names:
        try:
            priors.append(Method[name])
        except KeyError:
            # The KeyError adds nothing beyond the message below.
            raise ValueError(f"Unknown prior method: {name}") from None

    prior_data_map = {}
    prior_object_map = {}
    for prior in priors:
        print(prior)
        df = read_output_file(output_folder, prior, args.model_type, input_file, filter_zero_corr=False, understood_only=False, not_defauling_zero_only=False)
        print(len(df))
        prior_object = get_prior_obejct(prior, ignore_zero=False)
        prior_object_map[prior] = prior_object
        if prior == Method.Uniform:
            # Constant baseline: point estimate 0, interval [-0.95, 0.95]
            # and density 0.5 (consistent with a uniform density on [-1, 1]).
            df['est_mean'] = 0
            df['mode'] = 0
            df['est_lower'] = -1 + 0.025*2
            df['est_upper'] = 1 - 0.025*2
            df['density_value'] = 0.5
        elif prior == Method.LCP or prior == Method.KDE:
            bw = 0.4
            # Per-row summary of the 'distribution' column: mode, mean and
            # the 2.5% / 97.5% quantiles.
            summary_stats = df['distribution'].apply(lambda x: pd.Series(
                prior_object.get_summary_stats(x, bw=bw, q_vals=[0.025, 0.975])))
            df['density_value'] = df.apply(lambda row: prior_object.get_density_at(row['r_obs'], row['distribution'], bw=bw), axis=1)
            if prior == Method.KDE:
                df['bw'] = df.apply(lambda row: prior_object.get_bandwidth(row['distribution']), axis=1)
            summary_stats.columns = ['mode', 'est_mean', 'est_lower', 'est_upper']
            df = pd.concat([df, summary_stats], axis=1)
        elif prior == Method.Gaussian:
            summary_stats = df.apply(lambda row: pd.Series(
                prior_object.get_summary_stats(row['predicted_coef'], row['predicted_std'], q_vals=[0.025, 0.975])), axis=1)
            df['density_value'] = df.apply(lambda row: prior_object.get_density_at(row['r_obs'], row['predicted_coef'], row['predicted_std']), axis=1)
            summary_stats.columns = ['est_mean', 'est_lower', 'est_upper']
            df = pd.concat([df, summary_stats], axis=1)
            # No separate mode is returned here; the mean is used as the mode.
            df['mode'] = df['est_mean']
        prior_data_map[prior] = df

    # One table row per prior; metrics are computed against the 'mode' column.
    table1 = PrettyTable()
    table1.field_names = [
        "Method",
        "Sign Accuracy",
        "Error",
        "p(r)",
        "Information Content",
        "95% coverage"
    ]
    for prior, df in prior_data_map.items():
        mean_abs_diff, std_abs_diff, sign_accuracy, mse = get_sign_accuracy_and_MSE(df, prior, 'mode')
        pr_mean, pr_std, avg_NLL, std_NLL = get_density_and_NLL(df)
        ci_coverage = get_ci_coverage(df)
        # format them into "mean ± std"
        diff_str = f"{mean_abs_diff:.2f}±{std_abs_diff:.2f}"
        pr_str   = f"{pr_mean:.2f}±{pr_std:.2f}"
        nll_str  = f"{avg_NLL:.2f}±{std_NLL:.2f}"

        table1.add_row([
            prior.name,
            f"{sign_accuracy:.3f}",
            diff_str,
            pr_str,
            nll_str,
            f"{ci_coverage:.1%}"
        ])

    # The metric helpers above added abs_diff / NLL columns to each frame,
    # which is what the figure plots.
    plot_multi_violin_box(prior_data_map, palette_name='Set2')

    data = table1.get_csv_string()
    # The table contains "±", so the encoding is fixed explicitly instead of
    # depending on the platform default.
    with open(f'table_{benchmark_name}.csv', 'w', encoding='utf-8') as f:
        f.write(data)
    print(table1)
    



