from pathlib import Path
import os
import json
import numpy as np
import pandas as pd
import time

import matplotlib.pyplot as plt
import seaborn as sns
import copy

from PIL import Image

# I would have loved to make this one a proper task, but time is limited and refactoring is slow.

# Input and output locations, all relative to the current working directory.
base_path_association = "./data/association"
base_path_results = "./data/results"
path_output = "./data/boxplots"

# An experiment folder is only considered when it ships a full association table.
association_list = list(
    filter(
        lambda folder: (folder / "association_full.csv").exists(),
        Path(base_path_association).glob("*"),
    )
)

# Create the output folder (and any missing parents) up front.
Path(path_output).mkdir(parents=True, exist_ok=True)

# Experiments whose seed is larger than this are ignored.
max_seed = 20
# Placeholder only: rebound for every metric in the plotting loop below.
threshold = 1e6

# Model identifiers searched for in the experiment name (in this order),
# mapped to the short labels used on the plot axis.
MODEL_LABELS = {
    "resnet18": "r18",
    "resnet34": "r34",
    "densenet121": "den",
    "efficientnet_b0": "eff",
    "vgg16": "vgg",
}

# Dataset identifiers searched for in the experiment name. "ABplus" has to be
# tested before "AB" so that the longer identifier wins the substring search.
DATASETS = ["ABplus", "AB", "CO", "BigSmall", "colorGB", "isA", "leather", "metal_nut"]

# Values of the "p" column that count as important, per dataset.
IMPORTANT_PRIMITIVES = {
    "AB": [1, 2],
    "ABplus": [1, 2],
    "CO": [1, 2],
    "BigSmall": [1, 2],
    "isA": [1, 2],
    "colorGB": [1],
    "leather": [1, 2, 3, 4, 5, 6],
    "metal_nut": [1, 2, 3, 4, 5],
}

# method -> (key holding the importance score, key holding the concept paths)
# inside each entry of results.json["results"].
IMPORTANCE_KEYS = {
    "ace": ("score_mean", "concept_path"),
    "eclad": ("RI", "concept_paths"),
    "cshap": ("shap_mean", "concept_paths"),
}

# (suffix of the output columns, association column, higher value is better?)
# sym_dst is thresholded with "<" and threshold_dst, the scores with ">" and
# threshold_abs.
ALIGNMENT_METRICS = [
    ("", "sym_dst", False),
    ("_ARS", "adjusted_rand_score", True),
    ("_JS", "jaccard_score", True),
    ("_NMIS", "normalized_mutual_info_score", True),
]


def _first_match(candidates, name, what):
    """Return the first candidate that is a substring of `name`.

    Raises:
        ValueError: if no candidate occurs in `name`.
    """
    for candidate in candidates:
        if candidate in name:
            return candidate
    raise ValueError(f"no known {what} found in experiment name {name!r}")


def _parse_experiment(name):
    """Split an experiment folder name into its components.

    Returns a dict with the keys method, model, dataset, training, variant and
    seed, or None when the experiment has to be skipped (MiniBatchKMeans runs,
    methods other than ace/eclad/cshap, seeds above `max_seed`).
    """
    if "_MiniBatchKMeans_" in name:
        return None
    parts = name.split("_")
    method = parts[0]
    # The method is checked first: only these three define seed, training and
    # variant. Any other method (e.g. "space") is skipped before those values
    # are needed, so nothing is read unbound or reused from a previous run.
    if method == "ace":
        seed, training, variant = int(parts[-1]), parts[-2], "default"
    elif method in ("eclad", "cshap"):
        seed, training, variant = int(parts[-2]), parts[-3], parts[-1]
    else:
        return None
    if seed > max_seed:
        return None
    return {
        "method": method,
        "model": _first_match(MODEL_LABELS, name, "model"),
        "dataset": _first_match(DATASETS, name, "dataset"),
        "training": training,
        "variant": variant,
        "seed": seed,
    }


def _load_importance(results_file, method):
    """Read results.json and index importance and concept paths by concept name.

    When a name occurs more than once, the first entry is kept.
    """
    with open(results_file.as_posix(), "rb") as f:
        results = json.load(f)
    score_key, paths_key = IMPORTANCE_KEYS[method]
    by_name = {}
    for entry in results["results"]:
        score = entry[score_key]
        if method == "ace":
            # presumably rescales a [0, 1] score to [-1, 1] - TODO confirm
            score = score * 2 - 1
        by_name.setdefault(
            entry["name"],
            {"importance": score, "concept_paths": entry[paths_key]},
        )
    return by_name


def _alignment_metrics(dfc, primitives, column, limit, higher_is_better):
    """Compute the correctness metrics for one association column.

    A row is aligned when its "p" value is in `primitives` and its `column`
    value is on the good side of `limit`. For every concept name only the
    best-matching row is kept, and rows with n_c == 0 are dropped.

    Returns:
        (summary, best): `summary` maps the four metric names (without
        suffix) to their values, `best` is the per-concept frame including
        the boolean "aligned" column.
    """
    values = dfc[column]
    within_limit = values > limit if higher_is_better else values < limit
    scored = dfc.assign(aligned=dfc["p"].isin(primitives) & within_limit)
    # drop_duplicates keeps the first row per name, so the best match has to
    # come first: smallest value for sym_dst, largest value for the scores.
    best = scored.sort_values(by=[column], ascending=not higher_is_better).drop_duplicates(["name"])
    best = best[best["n_c"] != 0]
    aligned = best[best["aligned"]]
    unaligned = best[~best["aligned"]]
    importance_aligned = aligned["importance"].abs().mean()
    importance_unaligned = unaligned["importance"].abs().mean()
    max_importance = best["importance"].abs().max()
    if max_importance == 0:
        # avoid a division by zero when every importance is zero
        max_importance = 1.0
    summary = {
        # NOTE(review): the value is negated for every column, including the
        # similarity scores where a higher value is better - confirm whether
        # the sign is intended for the _ARS/_JS/_NMIS variants.
        "representation correctness": -aligned[column].abs().mean(),
        "importance correctness unaligned": importance_unaligned,
        "importance correctness aligned": importance_aligned,
        "importance correctness": (importance_aligned - importance_unaligned) / max_importance,
    }
    return summary, best


def _save_boxplot(frame, metric, fname, height):
    """Draw one box plot of `metric` per dataset column and save it as jpg.

    Nothing is written when `frame` has no rows.
    """
    if frame.empty:
        return
    sns.catplot(x="model", y=metric,
                hue="method", col="dataset", hue_order=["eclad", "ace", "cshap"],
                data=frame, kind="box",
                height=height, aspect=1.4)
    plt.savefig(fname, format="jpg", dpi=600, facecolor='white', bbox_inches='tight', transparent=False)
    plt.close()


for threshold_abs, threshold_dst in [(0.8, 1.0e6), (0.8, 2.0e6)]:
    aggregated = []
    # aggregate all results into a single dataset
    for association_path in association_list:
        name = association_path.name
        experiment = _parse_experiment(name)
        if experiment is None:
            continue
        ####################### Load data
        results_file = Path(base_path_results)/name/"results.json"
        if not results_file.exists():
            continue
        importance = _load_importance(results_file, experiment["method"])
        association_file = Path(base_path_association)/name/"association.json"
        if not association_file.exists():
            continue
        with open(association_file.as_posix(), "rb") as f:
            consolidated = json.load(f)
        if not consolidated:
            # nothing was associated for this experiment
            continue
        ####################### compute correctness
        if "name" not in consolidated[0]:
            # derive the concept name from its index when it is not stored
            consolidated = [{**c, "name": f"c_{int(c['c']):02}"} for c in consolidated]
        consolidated = [{**c, **importance[c["name"]]} for c in consolidated]
        dfc = pd.DataFrame(consolidated)
        primitives = IMPORTANT_PRIMITIVES[experiment["dataset"]]

        ######### aggregate
        row = {
            "name": name,
            "method": experiment["method"],
            "model": MODEL_LABELS[experiment["model"]],
            "dataset": experiment["dataset"],
            "training": experiment["training"],
            "variant": experiment["variant"],
            "seed": experiment["seed"],
        }
        best_by_distance = None
        for suffix, column, higher_is_better in ALIGNMENT_METRICS:
            limit = threshold_abs if higher_is_better else threshold_dst
            summary, best = _alignment_metrics(dfc, primitives, column, limit, higher_is_better)
            for key, value in summary.items():
                row[key + suffix] = value
            if column == "sym_dst":
                best_by_distance = best
        # concept counts and ratio are reported for the sym_dst alignment only
        row["n_concepts_aligned"] = best_by_distance[best_by_distance["aligned"]]["n_c"].count()
        row["n_concepts_unaligned"] = best_by_distance[~best_by_distance["aligned"]]["n_c"].count()
        row["n_concepts"] = best_by_distance["n_c"].count()
        row["ratio"] = best_by_distance["ratio"].mean()
        aggregated.append(row)

    dfa = pd.DataFrame(aggregated)
    dfa.to_csv((Path(path_output)/f"aggregated_results_t{threshold_dst:.2E}_t{threshold_abs:.2E}.csv".replace(".0","0")).as_posix(), sep=";")
    if dfa.empty:
        # without any aggregated row there is no "dataset" column to plot from
        continue

    plt.rcParams.update({'font.size':40})
    for suffix, _, higher_is_better in ALIGNMENT_METRICS:
        for metric in ["representation correctness" + suffix, "importance correctness" + suffix]:
            # the file name carries the threshold that was used for this metric
            threshold = threshold_abs if higher_is_better else threshold_dst
            dfa_tmp = dfa[dfa["dataset"].isin(["AB", "BigSmall", "metal_nut", "leather"])]
            fname = (Path(path_output)/f"boxplot_{metric}_t{threshold:.2E}_full.jpg".replace(".0","0").replace("+",""))
            _save_boxplot(dfa_tmp, metric, fname, 8)
            for dataset in DATASETS:
                dfa_tmp = dfa[dfa["dataset"]==dataset]
                fname = (Path(path_output)/f"boxplot_{metric}_t{threshold:.2E}_full_{dataset}.jpg".replace(".0","0").replace("+",""))
                _save_boxplot(dfa_tmp, metric, fname, 10)
