# %%
import json
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from eliciting_contexts.utils.constants import BASE_DIR


def create_results_dataframe(results_dict: Dict) -> pd.DataFrame:
    """Flatten the nested results mapping into one DataFrame row per answer.

    ``results_dict`` is walked as ``model name -> split -> metrics``, where
    each metrics mapping holds ``"normal_answers"`` and
    ``"sandbagging_answers"`` lists. Every answer contributes a row with the
    columns ``model`` (text after the last ``/`` of the model name),
    ``split``, ``condition`` (capitalized), ``correct`` and ``gave_response``
    (both cast to ``int``).
    """
    rows = []

    for full_name, per_split in results_dict.items():
        # Keep only the trailing path component for display purposes.
        short_name = full_name.split("/")[-1]

        for split_name, split_metrics in per_split.items():
            for condition in ("normal", "sandbagging"):
                label = condition.capitalize()
                rows.extend(
                    {
                        "model": short_name,
                        "split": split_name,
                        "condition": label,
                        "correct": int(answer["is_correct"]),
                        "gave_response": int(answer["is_an_answer"]),
                    }
                    for answer in split_metrics[f"{condition}_answers"]
                )

    return pd.DataFrame(rows)


def plot_results(df: pd.DataFrame, output_prefix: str = "figures/results"):
    """Create bar plots with bootstrap error bars for different metrics.

    Three figures are drawn, one per metric: accuracy over all rows, response
    rate, and accuracy restricted to rows where ``gave_response == 1``. Each
    figure groups bars by ``model`` with ``condition`` as hue, is saved as
    ``{output_prefix}_{plot name in snake case}.png`` and is then shown.

    Args:
        df: Long-form results with ``model``, ``condition``, ``correct`` and
            ``gave_response`` columns.
        output_prefix: Path prefix for the saved PNG files. Its parent
            directory is created if it does not exist yet.
    """
    from pathlib import Path

    metrics_to_plot = [
        ("All Responses", "correct", "Accuracy (All)"),
        ("Response Rate", "gave_response", "Response Rate"),
        ("Valid Responses Only", "correct", "Accuracy (Excluding Non-responses)"),
    ]

    # savefig does not create missing directories, so make sure the target
    # folder exists before the first figure is written.
    Path(output_prefix).parent.mkdir(parents=True, exist_ok=True)

    # The style is the same for every figure; set it once up front.
    sns.set_style("whitegrid")

    for plot_name, metric, ylabel in metrics_to_plot:
        plt.figure(figsize=(10, 6))

        # Only the "valid responses" plot drops rows without an answer.
        responders_only = plot_name == "Valid Responses Only"
        plot_data = df[df["gave_response"] == 1] if responders_only else df

        # Create grouped bar plot
        ax = sns.barplot(
            data=plot_data,
            x="model",
            y=metric,
            hue="condition",
            errorbar="ci",
            n_boot=1000,
            capsize=0.1,
            alpha=0.8,
            palette=["#2ecc71", "#e74c3c"],  # Green for normal, red for sandbagging
        )

        # Customize plot
        plt.title(f"Gemma Model Performance: {plot_name}", pad=20, fontsize=14)
        plt.xlabel("Model", fontsize=12)
        plt.ylabel(ylabel, fontsize=12)

        # Format y-axis as percentage
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: "{:.0%}".format(y)))

        # Add grid lines
        plt.grid(True, axis="y", alpha=0.3)

        # Rotate x-axis labels
        plt.xticks(rotation=-20, ha="left")

        # Place the legend outside the axes so it does not cover the bars
        plt.legend(title="Condition", bbox_to_anchor=(1.05, 1), loc="upper left")

        # Adjust layout and save
        plt.tight_layout()
        filename = f"{output_prefix}_{plot_name.lower().replace(' ', '_')}.png"
        plt.savefig(filename, dpi=300, bbox_inches="tight")
        plt.show()
        plt.close()


def _bootstrap_ci(values, n_boot: int = 1000):
    """Return the 2.5th/97.5th percentiles of bootstrap-resampled means of ``values``."""
    values = np.asarray(values)
    boot_means = [
        np.mean(np.random.choice(values, size=len(values), replace=True))
        for _ in range(n_boot)
    ]
    lower, upper = np.percentile(boot_means, [2.5, 97.5])
    return lower, upper


def print_summary_statistics(df: pd.DataFrame):
    """Print summary statistics with bootstrap confidence intervals.

    For every model in ``df`` and each of the "Normal" and "Sandbagging"
    conditions, prints accuracy over all rows, response rate, and accuracy
    over rows with ``gave_response == 1``. Each value is followed by a 95%
    percentile-bootstrap interval computed from that metric's own sample.

    Args:
        df: Long-form results with ``model``, ``condition``, ``correct`` and
            ``gave_response`` columns.
    """
    print("\nSummary Statistics (with 95% confidence intervals):")

    for model in df["model"].unique():
        print(f"\nModel: {model}")
        for condition in ["Normal", "Sandbagging"]:
            model_condition_data = df[
                (df["model"] == model) & (df["condition"] == condition)
            ]
            responded = model_condition_data[
                model_condition_data["gave_response"] == 1
            ]

            # Each metric is the mean of its own sample, so the interval must
            # be resampled from that same sample rather than always from the
            # "correct" column of the full subset.
            samples = {
                "Accuracy (All)": model_condition_data["correct"],
                "Response Rate": model_condition_data["gave_response"],
                "Accuracy (Excluding Non-responses)": responded["correct"],
            }

            print(f"\n{condition} Condition:")
            for metric_name, sample in samples.items():
                if sample.empty:
                    # Nothing to average or resample (e.g. no valid responses).
                    print(f"{metric_name}: n/a (no data)")
                    continue

                value = sample.mean()
                ci_lower, ci_upper = _bootstrap_ci(sample.to_numpy())
                print(f"{metric_name}: {value:.1%} ({ci_lower:.1%} - {ci_upper:.1%})")


def main():
    """Load the evaluation results, print summary statistics and save the plots.

    Reads ``sandbagging_evaluation.json`` from ``BASE_DIR``, converts it to a
    long-form DataFrame, prints the summary and writes the figures under the
    ``figures/results`` prefix.
    """
    # Load results
    with open(BASE_DIR / "sandbagging_evaluation.json", "r", encoding="utf-8") as f:
        results = json.load(f)

    # Create DataFrame
    df = create_results_dataframe(results)

    # Print summary statistics
    print_summary_statistics(df)

    # Create plots; the same prefix drives both the files and the message.
    prefix = "figures/results"
    plot_results(df, output_prefix=prefix)
    print(f"\nPlots saved with prefix '{prefix}'")


# Run the full analysis only when this file is executed as the main module,
# not when it is imported.
if __name__ == "__main__":
    main()

# %%
