import os
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from ast import literal_eval

# Base data folder (the current working directory).
data_folder = "./"

# Dataset switch: True selects the CLEVR metric set, False the other one.
is_clevr = False

# Metric columns that are aggregated over seeds and plotted.
# Other column names that were listed in an earlier configuration:
# yacc, yf2, cacc, cf1, collapse, betaf1, betaacc, nll.
metrics = (
    ["yf1", "yacc", "betaf1", "cf1", "collapse"]
    if is_clevr
    else ["cf1", "betaf1", "collapse"]
)

def parse_supervision_type(file_name):
    """Return the first supervision tag contained in ``file_name``.

    The tags are tested in a fixed priority order ("entropy", then "rec",
    then "k_sup") using plain substring matching; "unknown" is returned when
    none of them occurs.
    """
    known_tags = ("entropy", "rec", "k_sup")
    return next((tag for tag in known_tags if tag in file_name), "unknown")

def load_data_v2(model_name, is_clevr=False):
    """Aggregate per-seed results for the knowledge-supervision (k_sup) sweep.

    For every k_sup value in 0.0, 0.2, ..., 1.0 and every strategy, the first
    data row of each seed's CSV file is collected. The mean and standard
    deviation over the seeds are then stored for every name in the
    module-level ``metrics`` list.

    Args:
        model_name: Only "cbm" is supported.
        is_clevr: When true, read from ``best_models_clevr/`` and replace
            "mnist" with "clevr" in the model tag; otherwise read from
            ``best_models_addmnist/``.

    Returns:
        A DataFrame with one row per (k_sup value, strategy) that had at
        least one result file. Columns are ``which_k_count`` (k_sup * 10),
        ``supervision_type`` and ``<metric>_mean`` / ``<metric>_std``.
        The frame is empty when no file was found.

    Raises:
        NotImplementedError: If ``model_name`` is not "cbm".
    """
    # Fail fast instead of re-checking inside the sweep loop.
    if model_name != "cbm":
        raise NotImplementedError

    # NOTE(review): 3738 breaks the otherwise regular seed pattern (3637
    # would be expected) - confirm against the result file names.
    seeds = (
        2021, 2223, 2425, 2627, 2829, 3031, 3233, 3435,
        3738, 3839, 4041, 4243, 4445, 4647, 4849,
    )

    rows = []
    for perc_k in (0.0, 0.2, 0.4, 0.6, 0.8, 1.0):
        # (model tag, file-name suffix, legend label) for this k_sup value.
        strategies = [
            ("mnistcbm", f"which_c_[-2]_c_sup_1.0_k_sup_{perc_k}_entropy_1.0_multi_linear_optimal", "entropy"),
            ("mnistcbmrec", f"which_c_[-2]_c_sup_1_k_sup_{perc_k}_rec_0.01_multi_linear_optimal", "reconstruction"),
            ("mnistcbm", f"which_c_[-2]_c_sup_1.0_k_sup_{perc_k}_contrastive_0.1_multi_linear_optimal", "contrastive"),
            ("mnistcbmrec", f"which_c_[-2]_c_sup_1_k_sup_{perc_k}_entropy_1.0_rec_0.01_contrastive_0.1_multi_linear_optimal", "H+R+C"),
            ("mnistcbm", f"which_c_[-2]_c_sup_1.0_k_sup_{perc_k}_multi_linear_optimal", "0% c-sup"),
            ("mnistcbm", f"which_c_[3, 4]_c_sup_1.0_k_sup_{perc_k}_multi_linear_optimal", "20% c-sup"),
            ("mnistcbm", f"which_c_[3, 4, 1, 0]_c_sup_1.0_k_sup_{perc_k}_multi_linear_optimal", "40% c-sup"),
            ("mnistcbm", f"which_c_[3, 4, 1, 0, 2, 8]_c_sup_1.0_k_sup_{perc_k}_multi_linear_optimal", "60% c-sup"),
            ("mnistcbm", f"which_c_[3, 4, 1, 0, 2, 8, 9, 5]_c_sup_1.0_k_sup_{perc_k}_multi_linear_optimal", "80% c-sup"),
            ("mnistcbm", f"which_c_[3, 4, 1, 0, 2, 8, 9, 5, 6, 7]_c_sup_1.0_k_sup_{perc_k}_multi_linear_optimal", "100% c-sup"),
        ]

        for model, strategy, supervision_type in strategies:
            if is_clevr:
                model = model.replace("mnist", "clevr")

            first_rows = []
            for seed in seeds:
                if is_clevr:
                    base_filename = f"best_models_clevr/clevr_{model}_clevr_{seed}_{strategy}.csv"
                else:
                    base_filename = f"best_models_addmnist/addmnist_{model}_sumparity_{seed}_{strategy}.csv"

                if os.path.exists(base_filename):
                    # Only the first data row of each run is aggregated.
                    first_rows.append(pd.read_csv(base_filename).head(1))
                else:
                    print(f"File not found: {base_filename}")

            if not first_rows:
                continue

            combined_df = pd.concat(first_rows, ignore_index=True)
            # numeric_only keeps the aggregation working when the CSVs also
            # carry non-numeric columns (pandas >= 2 raises otherwise).
            mean_values = combined_df.mean(numeric_only=True)
            std_values = combined_df.std(numeric_only=True)

            row = {
                # round() guards against float products such as 5.999...
                "which_k_count": int(round(perc_k * 10)),
                "supervision_type": supervision_type,
            }
            for metric in metrics:
                row[f"{metric}_mean"] = mean_values[metric]
                row[f"{metric}_std"] = std_values[metric]
            rows.append(row)

    return pd.DataFrame(rows)


def load_data(model_name, is_clevr=False):
    """Aggregate per-seed results for the concept-supervision (which_c) sweep.

    For every supervised-concept set and every strategy, the first data row
    of each seed's CSV file is collected. The mean and standard deviation
    over the seeds are then stored for every name in the module-level
    ``metrics`` list.

    Args:
        model_name: Only "cbm" is supported.
        is_clevr: When true, read from ``best_models_clevr/`` and replace
            "mnist" with "clevr" in the model tag; otherwise read from
            ``best_models_addmnist/``.

    Returns:
        A ``(data, data_ood)`` pair of DataFrames with the columns
        ``which_c_count``, ``supervision_type`` and ``<metric>_mean`` /
        ``<metric>_std``. ``data`` has one row per (concept set, strategy)
        that had at least one result file. ``data_ood`` is currently always
        empty because no out-of-distribution results are loaded.

    Raises:
        NotImplementedError: If ``model_name`` is not "cbm".
    """
    # Fail fast instead of re-checking inside the sweep loop.
    if model_name != "cbm":
        raise NotImplementedError

    # NOTE(review): 3738 breaks the otherwise regular seed pattern (3637
    # would be expected) - confirm against the result file names.
    seeds = (
        2021, 2223, 2425, 2627, 2829, 3031, 3233, 3435,
        3738, 3839, 4041, 4243, 4445, 4647, 4849,
    )

    # Kept as lists: their str() form ("[3, 4]") is part of the file name.
    concept_sets = [
        [-2],
        [3, 4],
        [3, 4, 1, 0],
        [3, 4, 1, 0, 2, 8],
        [3, 4, 1, 0, 2, 8, 9, 5],
        [3, 4, 1, 0, 2, 8, 9, 5, 6, 7],
    ]

    # (model tag, file-name suffix, legend label)
    strategies = [
        ("mnistcbm", "c_sup_1_entropy_1.0_multi_linear_optimal", "entropy"),
        ("mnistcbmrec", "c_sup_1_rec_0.01_multi_linear_optimal", "reconstruction"),
        ("mnistcbm", "c_sup_1.0_contrastive_0.1_multi_linear_optimal", "contrastive"),
        ("mnistcbmrec", "c_sup_1_entropy_1.0_rec_0.01_contrastive_0.1_multi_linear_optimal", "H+R+C"),
        ("mnistcbm", "c_sup_1.0_k_sup_0.0_multi_linear_optimal", "0% k-sup"),
        ("mnistcbm", "c_sup_1.0_k_sup_0.2_multi_linear_optimal", "20% k-sup"),
        ("mnistcbm", "c_sup_1.0_k_sup_0.4_multi_linear_optimal", "40% k-sup"),
        ("mnistcbm", "c_sup_1.0_k_sup_0.6_multi_linear_optimal", "60% k-sup"),
        ("mnistcbm", "c_sup_1.0_k_sup_0.8_multi_linear_optimal", "80% k-sup"),
        ("mnistcbm", "c_sup_1.0_k_sup_1.0_multi_linear_optimal", "100% k-sup"),
    ]

    def summarize(frames, which_c_count, supervision_type):
        # Collapse the per-seed rows into a single mean/std record.
        combined = pd.concat(frames, ignore_index=True)
        # numeric_only keeps the aggregation working when the CSVs also
        # carry non-numeric columns (pandas >= 2 raises otherwise).
        mean_values = combined.mean(numeric_only=True)
        std_values = combined.std(numeric_only=True)
        row = {
            "which_c_count": which_c_count,
            "supervision_type": supervision_type,
        }
        for metric in metrics:
            row[f"{metric}_mean"] = mean_values[metric]
            row[f"{metric}_std"] = std_values[metric]
        return row

    data, data_ood = [], []

    for which_c in concept_sets:
        # [-2] is the marker for "no concept supervised".
        which_c_count = 0 if which_c == [-2] else len(which_c)

        for model, strategy, supervision_type in strategies:
            if is_clevr:
                model = model.replace("mnist", "clevr")

            # ood_metrics is never filled: no OOD result files are read here.
            metric_results, ood_metrics = [], []

            for seed in seeds:
                if is_clevr:
                    base_filename = f"best_models_clevr/clevr_{model}_clevr_{seed}_which_c_{which_c}_{strategy}.csv"
                else:
                    base_filename = f"best_models_addmnist/addmnist_{model}_sumparity_{seed}_which_c_{which_c}_{strategy}.csv"

                if os.path.exists(base_filename):
                    # Only the first data row of each run is aggregated.
                    metric_results.append(pd.read_csv(base_filename).head(1))
                else:
                    print(f"File not found: {base_filename}")

            if metric_results:
                data.append(summarize(metric_results, which_c_count, supervision_type))

            if ood_metrics:
                # The OOD rows are summarised from the OOD frames and stored
                # in the OOD result list.
                data_ood.append(summarize(ood_metrics, which_c_count, supervision_type))

    return pd.DataFrame(data), pd.DataFrame(data_ood)

def plot_metrics(data, metrics, data_folder, model_name, is_clevr, prefix=""):
    """Plot every metric against the number of supervised concepts.

    One figure per metric is drawn, with one line (mean) and one shaded band
    (mean +/- std) per ``supervision_type``, and saved as a PDF.

    Args:
        data: DataFrame with the columns ``which_c_count``,
            ``supervision_type`` and ``<metric>_mean`` / ``<metric>_std``
            for every requested metric.
        metrics: Names of the metrics to plot.
        data_folder: Output directory; it is created when missing.
        model_name: Model tag used in the output file name.
        is_clevr: When true, the file name gets a ``clevr_`` prefix.
        prefix: Optional tag prepended to the plot title and the file name.

    Raises:
        KeyError: If a ``supervision_type`` value has no entry in the colour
            table or a required column is missing.
    """
    sns.set(style="whitegrid", palette="muted", font_scale=1.2)

    color_map = {
        "entropy": "#E74C3C",  # Red
        "reconstruction": "#2CA02C",  # Green
        "contrastive": "black",
        "H+R+C": "#8B4513",  # Brown (Saddle Brown)
        "0% k-sup": "#D4E6F1",  # Light blue
        "20% k-sup": "#A9CCE3",  # Sky blue
        "40% k-sup": "#7FB3D5",  # Steel blue
        "60% k-sup": "#5DADE2",  # Royal blue
        "80% k-sup": "#3498DB",  # Medium blue
        "100% k-sup": "#1F77B4",  # Dark blue
    }

    # The colour table is already written in the desired legend order.
    legend_order = list(color_map)

    def display_label(name):
        # Upper-case the first character only: str.capitalize() would also
        # lower-case the rest and turn "H+R+C" into "H+r+c".
        return name[:1].upper() + name[1:]

    # Ranks are keyed by the displayed label so that they match the legend
    # entries that matplotlib reports.
    legend_rank = {display_label(name): pos for pos, name in enumerate(legend_order)}

    # savefig does not create missing directories.
    if data_folder:
        os.makedirs(data_folder, exist_ok=True)

    title_prefix = f"{prefix} " if prefix else ""
    file_prefix = f"{prefix}_" if prefix else ""

    for metric in metrics:
        mean = data[f"{metric}_mean"]
        std = data[f"{metric}_std"]
        plot_data = data.assign(y_value=mean, y_lower=mean - std, y_upper=mean + std)

        plt.figure(figsize=(16, 9))
        try:
            for supervision_type, subset in plot_data.groupby("supervision_type"):
                color = color_map[supervision_type]
                plt.plot(
                    subset["which_c_count"],
                    subset["y_value"],
                    marker="o",
                    label=display_label(supervision_type),
                    linewidth=2.5,
                    color=color,
                )
                plt.fill_between(
                    subset["which_c_count"],
                    subset["y_lower"],
                    subset["y_upper"],
                    alpha=0.1,
                    color=color,
                )

            name_metric = metric.upper() if metric != "yf2" else "YF1"
            name_metric = name_metric.replace("BETA", r"$\beta$ ")
            name_metric = name_metric.replace("CF1", r"C F1")

            plt.title(
                f"{title_prefix}{name_metric} vs % of Concept Supervision",
                fontsize=30, fontweight="bold"
            )
            plt.xlabel("% of Supervised Concepts", fontsize=22)
            plt.ylabel(name_metric, fontsize=22, fontweight="bold")
            plt.ylim(0, 1)

            plt.xticks([0, 2, 4, 6, 8, 10], ["0", "20", "40", "60", "80", "100"], fontsize=22)
            plt.yticks(fontsize=22)

            # Build the legend once, in the configured order and with its
            # title; labels without a rank go last, alphabetically.
            ax = plt.gca()
            handles, labels = ax.get_legend_handles_labels()
            entries = sorted(
                zip(labels, handles),
                key=lambda entry: (legend_rank.get(entry[0], len(legend_rank)), entry[0]),
            )
            if entries:
                ax.legend(
                    [handle for _, handle in entries],
                    [label for label, _ in entries],
                    title="Strategy",
                    fontsize=22,
                    title_fontsize=22,
                    loc="lower right",
                )

            plt.grid(visible=True, linestyle="--", alpha=0.4)
            plt.tight_layout()

            if is_clevr:
                path = f"clevr_{file_prefix}{model_name}_{metric}_plot.pdf"
            else:
                path = f"{file_prefix}{model_name}_{metric}_plot.pdf"
            plot_path = os.path.join(data_folder, path)
            plt.savefig(plot_path, dpi=300)
            print(f"Saved plot to {plot_path}")
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close()


def plot_metrics_reverse(data, metrics, data_folder, model_name, is_clevr, prefix=""):
    """Plot every metric against the amount of knowledge supervision.

    One figure per metric is drawn, with one line (mean) and one shaded band
    (mean +/- std) per ``supervision_type``, and saved as a PDF whose name
    ends in ``_plot_knowledge.pdf``.

    Args:
        data: DataFrame with the columns ``which_k_count``,
            ``supervision_type`` and ``<metric>_mean`` / ``<metric>_std``
            for every requested metric.
        metrics: Names of the metrics to plot.
        data_folder: Output directory; it is created when missing.
        model_name: Model tag used in the output file name.
        is_clevr: When true, the file name gets a ``clevr_`` prefix.
        prefix: Optional tag prepended to the plot title and the file name.

    Raises:
        KeyError: If a ``supervision_type`` value has no entry in the colour
            table or a required column is missing.
    """
    sns.set(style="whitegrid", palette="muted", font_scale=1.2)

    color_map = {
        "entropy": "#E74C3C",  # Red
        "reconstruction": "#2CA02C",  # Green
        "contrastive": "black",
        "H+R+C": "#8B4513",  # Brown (Saddle Brown)
        "0% c-sup": "#D4E6F1",  # Light blue
        "20% c-sup": "#A9CCE3",  # Sky blue
        "40% c-sup": "#7FB3D5",  # Steel blue
        "60% c-sup": "#5DADE2",  # Royal blue
        "80% c-sup": "#3498DB",  # Medium blue
        "100% c-sup": "#1F77B4",  # Dark blue
    }

    # The colour table is already written in the desired legend order.
    legend_order = list(color_map)

    def display_label(name):
        # Upper-case the first character only: str.capitalize() would also
        # lower-case the rest and turn "H+R+C" into "H+r+c".
        return name[:1].upper() + name[1:]

    # Ranks are keyed by the displayed label so that they match the legend
    # entries that matplotlib reports.
    legend_rank = {display_label(name): pos for pos, name in enumerate(legend_order)}

    # savefig does not create missing directories.
    if data_folder:
        os.makedirs(data_folder, exist_ok=True)

    title_prefix = f"{prefix} " if prefix else ""
    file_prefix = f"{prefix}_" if prefix else ""

    for metric in metrics:
        mean = data[f"{metric}_mean"]
        std = data[f"{metric}_std"]
        plot_data = data.assign(y_value=mean, y_lower=mean - std, y_upper=mean + std)

        plt.figure(figsize=(16, 9))
        try:
            for supervision_type, subset in plot_data.groupby("supervision_type"):
                color = color_map[supervision_type]
                plt.plot(
                    subset["which_k_count"],
                    subset["y_value"],
                    marker="o",
                    label=display_label(supervision_type),
                    linewidth=2.5,
                    color=color,
                )
                plt.fill_between(
                    subset["which_k_count"],
                    subset["y_lower"],
                    subset["y_upper"],
                    alpha=0.1,
                    color=color,
                )

            name_metric = metric.upper() if metric != "yf2" else "YF1"
            name_metric = name_metric.replace("BETA", r"$\beta$ ")
            name_metric = name_metric.replace("CF1", r"C F1")

            # The x axis is knowledge supervision here, so the title says so.
            plt.title(
                f"{title_prefix}{name_metric} vs % of Knowledge Supervision",
                fontsize=30, fontweight="bold"
            )
            plt.xlabel("% of Supervised Knowledge", fontsize=22)
            plt.ylabel(name_metric, fontsize=22, fontweight="bold")
            plt.ylim(0, 1)

            plt.xticks([0, 2, 4, 6, 8, 10], ["0", "20", "40", "60", "80", "100"], fontsize=22)
            plt.yticks(fontsize=22)

            # Build the legend once, in the configured order and with its
            # title; labels without a rank go last, alphabetically.
            ax = plt.gca()
            handles, labels = ax.get_legend_handles_labels()
            entries = sorted(
                zip(labels, handles),
                key=lambda entry: (legend_rank.get(entry[0], len(legend_rank)), entry[0]),
            )
            if entries:
                ax.legend(
                    [handle for _, handle in entries],
                    [label for label, _ in entries],
                    title="Strategy",
                    fontsize=22,
                    title_fontsize=22,
                    loc="lower right",
                )

            plt.grid(visible=True, linestyle="--", alpha=0.4)
            plt.tight_layout()

            # Both variants carry the "_knowledge" suffix so that they never
            # collide with the "<...>_plot.pdf" concept-supervision files.
            if is_clevr:
                path = f"clevr_{file_prefix}{model_name}_{metric}_plot_knowledge.pdf"
            else:
                path = f"{file_prefix}{model_name}_{metric}_plot_knowledge.pdf"
            plot_path = os.path.join(data_folder, path)
            plt.savefig(plot_path, dpi=300)
            print(f"Saved plot to {plot_path}")
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close()


if __name__ == "__main__":
    # Aggregate both sweeps, then write one PDF per metric into "plots".
    model_name = "cbm"
    output_folder = "plots"

    data, data_ood = load_data(model_name, is_clevr)
    reversed_data = load_data_v2(model_name, is_clevr)

    # savefig does not create missing directories.
    os.makedirs(output_folder, exist_ok=True)

    if data.empty:
        print("No valid data files found. Ensure the folder contains valid CSVs.")
    else:
        plot_metrics(data, metrics, output_folder, model_name, is_clevr)
        if not data_ood.empty:
            # "OOD" is the title/file-name prefix; it must not take the place
            # of the is_clevr flag.
            plot_metrics(data_ood, metrics, output_folder, model_name, is_clevr, prefix="OOD")

    if not reversed_data.empty:
        plot_metrics_reverse(reversed_data, metrics, output_folder, model_name, is_clevr)
