import os
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from ast import literal_eval

# NOTE(review): this module-level value is not read by any code visible in
# this file; plot_metrics() takes its own `data_folder` argument instead.
data_folder = "./"

# Metric names: each one is looked up as a column in the per-seed result CSVs
# by load_data() and handed to plot_metrics() by the script entry point.
metrics = ["cf1", "betaf1", "collapse"]


def _summarize_runs(frames, which_c_count, supervision_type):
    """Aggregate per-seed result rows into one mean/std summary row.

    Args:
        frames: Non-empty list of DataFrames, one per seed.
        which_c_count: Value stored under ``"which_c_count"`` in the result.
        supervision_type: Label stored under ``"supervision_type"``.

    Returns:
        A dict holding the two identifying keys plus ``"{metric}_mean"`` and
        ``"{metric}_std"`` for every name in the module-level ``metrics``.

    Raises:
        KeyError: If a metric column is missing from the concatenated frames.
    """
    combined = pd.concat(frames, ignore_index=True)
    # Aggregate only the metric columns: any extra (possibly non-numeric)
    # column in the CSVs must not break or pollute the statistics.
    values = combined[metrics]
    mean_values = values.mean()
    std_values = values.std()  # pandas default: sample std (ddof=1)

    row = {
        "which_c_count": which_c_count,
        "supervision_type": supervision_type,
    }
    for metric in metrics:
        row[f"{metric}_mean"] = mean_values[metric]
        row[f"{metric}_std"] = std_values[metric]
    return row


def load_data(model_name, is_clevr=False):
    """Load the per-seed result CSVs and summarise them per configuration.

    For every combination of concept-supervision setting and strategy, the
    first row of each seed's CSV under ``best_models_boia/`` (relative to the
    current working directory) is collected, and the mean and standard
    deviation of every metric in the module-level ``metrics`` are computed
    across seeds.

    Args:
        model_name: Only ``"cbm"`` is supported.
        is_clevr: Unused; kept so existing callers keep working.

    Returns:
        A ``(data, data_ood)`` tuple of DataFrames with one row per
        (which_c_count, supervision_type) pair. No OOD files are read by this
        function, so ``data_ood`` is currently always empty.

    Raises:
        NotImplementedError: If ``model_name`` is not ``"cbm"``.
        SystemExit: If an expected CSV file does not exist. The message names
            the missing file and the process exits with a non-zero status.
    """
    if model_name != "cbm":
        raise NotImplementedError(f"Unsupported model: {model_name!r}")

    which_c_settings = [[-2], [1], [2], [3], [4]]
    seeds = [1011, 1213, 1415, 1617, 1819]
    # (model tag in the file name, strategy tag in the file name, plot label)
    strategies = [
        ("boiacbm", "c_sup_1_entropy_1.0_multi_linear", "entropy"),
        ("boiacbm", "c_sup_1.0_contrastive_1.0_multi_linear", "contrastive"),
        ("boiacbm", "c_sup_1.0_k_sup_0.0_multi_linear", "0% k-sup"),
        ("boiacbm", "c_sup_1.0_k_sup_0.2_multi_linear", "20% k-sup"),
        ("boiacbm", "c_sup_1.0_k_sup_0.4_multi_linear", "40% k-sup"),
        ("boiacbm", "c_sup_1.0_k_sup_0.6_multi_linear", "60% k-sup"),
        ("boiacbm", "c_sup_1.0_k_sup_0.8_multi_linear", "80% k-sup"),
        ("boiacbm", "c_sup_1.0_k_sup_1.0_multi_linear", "100% k-sup"),
    ]

    data, data_ood = [], []

    for which_c in which_c_settings:
        # [-2] is reported as 0 supervised concepts; otherwise use the value.
        which_c_count = which_c[0] if which_c != [-2] else 0

        for model, strategy, supervision_type in strategies:
            metric_results, ood_metrics = [], []

            for seed in seeds:
                # The list itself is formatted into the name, e.g. "which_c_[1]".
                base_filename = f"best_models_boia/boia_{model}_boia_{seed}_which_c_{which_c}_{strategy}.csv"

                if not os.path.exists(base_filename):
                    # Abort the whole run: a partial average over fewer seeds
                    # would silently distort the plots.
                    raise SystemExit(f"File not found: {base_filename}")

                df = pd.read_csv(base_filename)
                # Only the first row of each CSV is used.
                metric_results.append(df.head(1))

            data.append(
                _summarize_runs(metric_results, which_c_count, supervision_type)
            )

            # NOTE(review): nothing populates ood_metrics yet, so this branch
            # never runs; it is kept so OOD loading can be added later.
            if ood_metrics:
                data_ood.append(
                    _summarize_runs(ood_metrics, which_c_count, supervision_type)
                )

    return pd.DataFrame(data), pd.DataFrame(data_ood)


def _metric_title(metric):
    """Return the panel title (matplotlib mathtext) for a metric name.

    Known metrics map to a fixed mathtext label; any other name falls back to
    its upper-cased form (with "BETA" and "CF1" expanded).
    """
    pretty = metric.upper().replace("BETA", r"$\beta$ ").replace("CF1", r"C F1")
    titles = {
        "COLLAPSE": r"$\mathsf{Cls}(C)$",
        "$\\beta$ F1": r"$\mathrm{F_1}(\beta)$",
        "C F1": r"$\mathrm{F_1}(C)$",
    }
    return titles.get(pretty, pretty)


def plot_metrics(data, metrics, data_folder, model_name, is_clevr=False, prefix=""):
    """Plot one panel per metric and save the figure as a PDF.

    Each panel draws ``{metric}_mean`` against ``which_c_count``, with one
    line per ``supervision_type`` found in ``data``. A single legend shared by
    all panels is placed to the right of the figure.

    Args:
        data: DataFrame with ``which_c_count``, ``supervision_type`` and a
            ``{metric}_mean`` column for every entry of ``metrics``.
        metrics: Metric names; one subplot is drawn per name.
        data_folder: Directory the PDF is written to; created if missing.
        model_name: Used in the output file name.
        is_clevr: When truthy, the file name is prefixed with ``clevr_``.
        prefix: Optional extra file-name prefix (followed by ``_``).

    Raises:
        KeyError: If ``data`` lacks a required column or contains a
            supervision type that has no style entry.
        ValueError: If a plotted label is not listed in the legend order.
    """
    sns.set(style="whitegrid", palette="muted", font_scale=1.5)  # Increase text size

    fig, axes = plt.subplots(1, len(metrics), figsize=(6 * len(metrics), 4), sharey=True)

    # With a single panel plt.subplots returns a bare Axes, not an array.
    if len(metrics) == 1:
        axes = [axes]

    color_map = {
        "entropy": "#b50d12",  # Red
        "reconstruction": "#009E73",  # Green
        "contrastive": "#9966cc",  # Violet
        "H+R+C": "#ff8c00",  # Orange
        "0% k-sup": "#D4E6F1",  # Light Blue
        "20% k-sup": "#A9CCE3",
        "40% k-sup": "#7FB3D5",
        "60% k-sup": "#5DADE2",
        "80% k-sup": "#3498DB",
        "100% k-sup": "#1F77B4",  # Dark Blue
    }

    line_styles = {
        "entropy": "-",  # Solid
        "reconstruction": "--",  # Dashed
        "contrastive": "-.",  # Dash-dot
        "H+R+C": ":",  # Dotted
        "0% k-sup": "-",
        "20% k-sup": "--",
        "40% k-sup": "-.",
        "60% k-sup": ":",
        "80% k-sup": "-",
        "100% k-sup": "--",
    }

    markers = {
        "entropy": "o",
        "reconstruction": "s",
        "contrastive": "^",
        "H+R+C": "D",
        "0% k-sup": "o",
        "20% k-sup": "s",
        "40% k-sup": "^",
        "60% k-sup": "D",
        "80% k-sup": "o",
        "100% k-sup": "s",
    }

    # Legend entries are matched against this list by lower-cased label.
    legend_order = [
        "0% k-sup", "20% k-sup", "40% k-sup", "60% k-sup", "80% k-sup", "100% k-sup",
        "entropy", "reconstruction", "contrastive", "h+r+c"
    ]

    for ax, metric in zip(axes, metrics):
        for supervision_type, subset in data.groupby("supervision_type"):
            # Sort by x so the line is drawn left to right whatever the row
            # order of the input frame.
            subset = subset.sort_values("which_c_count")
            ax.plot(
                subset["which_c_count"],
                subset[f"{metric}_mean"],
                linestyle=line_styles[supervision_type],
                marker=markers[supervision_type],
                linewidth=2.5,
                markersize=8,
                color=color_map[supervision_type],
                label=supervision_type.capitalize(),
            )

        # Panel title, placed inside the axes near the top-left corner.
        ax.text(
            0.05, 0.95,  # Axes-fraction coordinates
            _metric_title(metric),
            fontsize=22,
            fontweight="bold",
            transform=ax.transAxes,
            ha="left",  # Anchor the text by its left edge
            va="top",  # Anchor the text by its top edge
        )

        # which_c_count values 0..4 are shown as percentages of supervision.
        ax.set_xticks([0, 1, 2, 3, 4])
        ax.set_xlabel("C-SUP (%)", fontsize=20, labelpad=10)
        ax.set_xticklabels(["0", "25", "50", "75", "100"], fontsize=20)
        ax.tick_params(axis='y', labelsize=20)
        ax.grid(visible=True, linestyle="--", alpha=0.4)
        ax.set_ylim(0, 1)

    # Every panel plots the same series, so the last axes' handles are enough
    # to build the single shared legend.
    handles, labels = axes[-1].get_legend_handles_labels()
    ordered_handles_labels = sorted(
        zip(handles, labels), key=lambda x: legend_order.index(x[1].lower())
    )
    if ordered_handles_labels:
        handles, labels = zip(*ordered_handles_labels)
        # str.capitalize() turns "H+R+C" into "H+r+c"; use the display form.
        new_labels = [
            "H+C+R" if label.lower() == "h+r+c" else label for label in labels
        ]
        fig.legend(handles, new_labels, loc="center right", fontsize=18)
    # Leave the right 18% of the figure free for the legend.
    fig.tight_layout(rect=[0, 0, 0.82, 1])

    dataset_tag = "clevr_" if is_clevr else ""
    file_prefix = "" if prefix == "" else prefix + "_"
    path = f"{dataset_tag}{file_prefix}{model_name}_metrics_plot.pdf"

    # savefig does not create missing directories.
    if data_folder:
        os.makedirs(data_folder, exist_ok=True)
    plot_path = os.path.join(data_folder, path)
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f"Saved plot to {plot_path}")
    plt.close(fig)


if __name__ == "__main__":
    # Load the per-seed results and aggregate them per configuration.
    model_name = "cbm"

    data, data_ood = load_data(model_name)

    if data.empty:
        print("No valid data files found. Ensure the folder contains valid CSVs.")
    else:
        # Generate and save the plots
        plot_metrics(data, metrics, "plots", model_name)
        if not data_ood.empty:
            # "OOD" must be passed by keyword: positionally it would bind to
            # is_clevr and the file would be named "clevr_..." with no prefix.
            plot_metrics(data_ood, metrics, "plots", model_name, prefix="OOD")
