from scipy.stats import gmean
import numpy as np
from magnipy import Magnipy
from magnipy.distances import get_dist
from vendi_score import vendi
import pandas as pd
import pickle
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import numpy as np
from sklearn.model_selection import RepeatedKFold, cross_validate
from sklearn.ensemble import RandomForestRegressor
from sklearn.isotonic import IsotonicRegression
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import pairwise_distances
from n_gram_sim import *
from diversity_metrics import *
import sentence_transformers
import pickle
import numpy as np
from prediction_models import *

# Root directory of the decoding-experiment data. Functions below concatenate
# sub-paths onto it, e.g. "data/raw/decTest/..." for the input CSVs and
# "data/with_metrics/decTest/..." for the per-dataset embedding directories;
# it is also the path used by the scale sweeps and the __main__ entry point.
path="./data/dec/"

def embedding_function(resp_set, model_name='bert-large-nli-stsb-mean-tokens'):
    """Embed a collection of texts with the given sentence-transformers model.

    Thin wrapper kept as the single entry point for embedding so the backend
    can be swapped in one place; it delegates to ``resp_to_emb``.

    Args:
        resp_set: sequence of texts to embed.
        model_name: sentence-transformers model identifier.

    Returns:
        Whatever ``resp_to_emb`` returns for these inputs.
    """
    return resp_to_emb(resp_set, model_name)

# Most recently loaded SentenceTransformer, keyed by model name. A single
# entry is kept: callers embed many response sets with the same model in a
# row, so this avoids reloading the model from disk for every set while never
# holding more than one (potentially large) model in memory.
_SENTENCE_MODEL_CACHE = {}


def _get_sentence_model(model_name):
    """Return a SentenceTransformer for ``model_name``, reusing the last one loaded."""
    model = _SENTENCE_MODEL_CACHE.get(model_name)
    if model is None:
        _SENTENCE_MODEL_CACHE.clear()
        model = sentence_transformers.SentenceTransformer(model_name)
        _SENTENCE_MODEL_CACHE[model_name] = model
    return model


def resp_to_emb(resp_list, model_name='bert-large-nli-stsb-mean-tokens'):
    """Encode every text in ``resp_list`` into one embedding row.

    Args:
        resp_list: sequence of texts to encode.
        model_name: sentence-transformers model identifier.

    Returns:
        numpy array with one row per element of ``resp_list``.

    Raises:
        ValueError: if the encoder output is not a 2-D array with
            ``len(resp_list)`` rows.
    """
    n = len(resp_list)
    embeds = np.array(_get_sentence_model(model_name).encode(resp_list))
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if embeds.ndim != 2 or embeds.shape[0] != n:
        raise ValueError(
            "Expected a 2-D embedding array with %d rows, got shape %s" % (n, embeds.shape)
        )
    return embeds

def create_embeddings(resps, model_name, out_path, idx):
    """Embed ``resps`` and pickle the result next to the other embeddings.

    The pickle is written to ``out_path + model_name + str(round(idx)) + '.pkl'``
    (``out_path`` is used as a plain string prefix, not joined as a directory).

    Returns:
        The embeddings produced by ``embedding_function``.
    """
    embeddings = embedding_function(resps, model_name=model_name)
    target_file = out_path + model_name + str(round(idx)) + '.pkl'
    with open(target_file, 'wb') as handle:
        pickle.dump(embeddings, handle)
    return embeddings

def load_dec_data(idx, input_path, model_name, data_in=None, precomputed=True):
    """Return the embeddings of response set ``idx``.

    Args:
        idx: row index of the response set; rounded to build the file name.
        input_path: directory holding the precomputed ``<model_name><idx>.pkl`` files.
        model_name: sentence-transformers model identifier.
        data_in: DataFrame of responses; required when ``precomputed`` is False.
        precomputed: load the pickled embedding instead of computing it.

    Returns:
        The embedding object; if it is a list, its first element is returned.

    Raises:
        ValueError: if ``precomputed`` is False and ``data_in`` is None.
    """
    if precomputed:
        file_name = input_path + "/" + model_name + str(round(idx)) + '.pkl'
        # NOTE: pickle.load executes arbitrary code; only use on trusted local files.
        with open(file_name, 'rb') as fp:
            emb = pickle.load(fp)
    else:
        if data_in is None:
            raise ValueError("Input data must be provided if embedding isn't precomputed!")
        # input_data_to_line needs the DataFrame to read row ``idx`` from.
        emb = embedding_function(input_data_to_line(idx, data_in), model_name)
    if isinstance(emb, list):
        return emb[0]
    return emb


def all_data_to_embeddings(path, human=False, datasets=("prompt", "resp", "story")):
    """Embed the generated responses of every decoding test set.

    For each dataset, reads
    ``<path>data/raw/decTest/dec_test_<variant><dataset>_gen.csv`` and writes
    the pickled embeddings into the directory
    ``<path>data/with_metrics/decTest/dec_test_<variant><dataset>_gen``, which
    is where ``run_dec_metrics`` loads them from. ``<variant>`` is
    ``200_with_hds_`` when ``human`` is True, else ``1000_no_hds_``.

    Args:
        path: data root (string prefix, must end with a separator).
        human: use the test set with human diversity scores.
        datasets: dataset names to process.
    """
    import os

    variant = "dec_test_200_with_hds_" if human else "dec_test_1000_no_hds_"
    for dataset in datasets:
        csv_path = path + "data/raw/decTest/" + variant + dataset + "_gen.csv"
        data_in = pd.read_csv(csv_path)
        # One directory per dataset: writing all datasets into ``path`` itself
        # would make them overwrite each other's <model><idx>.pkl files.
        embedding_dir = path + "data/with_metrics/decTest/" + variant + dataset + "_gen"
        os.makedirs(embedding_dir, exist_ok=True)
        data_to_embeddings(data_in, embedding_dir)

def data_to_embeddings(data_in, input_path,
    model_names=(
    "bert-large-nli-stsb-mean-tokens",
    "all-mpnet-base-v2",
    "all-distilroberta-v1",
    "roberta-base-nli-mean-tokens",
    "all-MiniLM-L12-v2",
    )):
    """Embed every row of ``data_in`` with every model and pickle the results.

    Each row's responses are read from columns ``"0"`` .. ``"9"``, converted to
    ``str``, embedded, and written to ``<input_path>/<model_name><row>.pkl``.
    Progress (model name, then row index) is printed.

    NOTE(review): ``input_data_to_line`` reads columns ``resp_0`` .. ``resp_9``
    when ``new_data`` is False; confirm which schema the CSVs actually use.

    Args:
        data_in: DataFrame with one response set per row.
        input_path: output directory; created if it does not exist.
        model_names: sentence-transformers model identifiers.
    """
    import os

    if input_path:
        os.makedirs(input_path, exist_ok=True)
    # Loop-invariant: select the ten response columns once.
    responses = data_in.loc[:, [str(j) for j in range(10)]]
    for model_name in model_names:
        print(model_name)
        for i in range(responses.shape[0]):
            print(i)
            resp = [str(r) for r in responses.iloc[i, :]]
            emb = embedding_function(resp, model_name=model_name)
            with open(input_path + "/" + model_name + str(i) + '.pkl', 'wb') as fp:
                pickle.dump(emb, fp)

def read_files(dataset, model_name, reference_scale=None, target_scale=None):
    """Load the pickled magnitude results of a run without human scores.

    The file read is
    ``./dec_text/<dataset><model_name>[<reference_scale>][<target_scale>]_magnitude_results.pkl``.
    This mirrors the names written by ``run_dec_metrics`` (human=False): a
    scale is appended only when it is given and differs from the default 0.5.

    Returns:
        The unpickled results object (only use on trusted local files).
    """
    output_path = "./dec_text/" + dataset + model_name
    if reference_scale is not None and reference_scale != 0.5:
        output_path += str(reference_scale)
    # The writer adds no suffix for target_scale == 0.5, so the reader must not either.
    if target_scale is not None and target_scale != 0.5:
        output_path += str(target_scale)
    with open(output_path + "_magnitude_results" + '.pkl', 'rb') as fp:
        return pickle.load(fp)

def _summarise_scale_setting(conv, human, model_names, datasets, reference_scale, target_scale):
    """Summarise the stored results of one scale setting for each dataset/model pair.

    Loads, via ``get_prediction_results_hds``, the pickled results written by
    ``run_dec_metrics`` for ``reference_scale``/``target_scale``. It then calls,
    restricted to the ``mag_area_cosine`` measure:

    * ``r2_subsampling_summary`` on the stored prediction results, and
    * ``corr_subsampling_summary`` (corr="spearman", 1000 resamples of full
      size, with replacement) on the summary statistics joined with the input data.

    Every returned frame is tagged with ``dataset``, ``model`` and ``conv``.

    Returns:
        (r2_frames, corr_frames): two lists of DataFrames.
    """
    label_value = "metric_abs_hds_mean" if human else "label_value"
    # Load each model's results once, instead of three times per dataset.
    loaded = {
        m: {
            name: get_prediction_results_hds(m, name=name, human=human,
                                             reference_scale=reference_scale,
                                             target_scale=target_scale)
            for name in ("prediction_results", "summary_statistics", "data_in")
        }
        for m in model_names
    }
    r2_frames, corr_frames = [], []
    for d in datasets:
        for m in model_names:
            print(d)
            results = loaded[m]

            sums = results["prediction_results"][d]
            sums["conv"] = conv
            r2s = r2_subsampling_summary(sums, name="", corr="spearman",
                                         cats=['mag_area_cosine'])
            r2s["dataset"] = d
            r2s["model"] = m
            r2s["conv"] = conv
            r2_frames.append(r2s)

            data_in = results["data_in"][d]
            combined = pd.concat([results["summary_statistics"][d], data_in], axis=1)
            corrs, _ = corr_subsampling_summary(combined, label_value=label_value,
                                                name="", corr="spearman", cats=['mag_area_cosine'],
                                                n_resample=1000, size_resample=data_in.shape[0],
                                                replace=True)
            corrs["dataset"] = d
            corrs["model"] = m
            corrs["conv"] = conv
            corr_frames.append(corrs)
    return r2_frames, corr_frames


def _concat_or_empty(frames):
    """Concatenate ``frames`` row-wise; an empty list yields an empty DataFrame."""
    return pd.concat(frames, axis=0) if frames else pd.DataFrame()


def try_different_scales(human=True, save=False,
                         reference_scales=(0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1),
                         model_names=("all-mpnet-base-v2",),
                         datasets=("prompt", "resp", "story")):
    """Sweep the reference scale of the magnitude computation.

    For every reference scale except 0.5, the cosine measures are recomputed
    with ``run_dec_metrics`` (n_ts=10, no target scale). The 0.5 results are
    presumably already present from the default run. The stored results for
    every scale are then summarised with ``_summarise_scale_setting`` (the
    ``conv`` column holds the reference scale).

    Args:
        human: use the test set with human diversity scores.
        save: write the tables to ``./dec_text/scale_{corrs,results}[_hds]_30.csv``.
        reference_scales: reference scales to sweep.
        model_names: embedding models to evaluate.
        datasets: dataset names to evaluate.

    Returns:
        (all_results, all_corrs) DataFrames.
    """
    r2_frames, corr_frames = [], []
    for r in reference_scales:
        if r != 0.5:
            run_dec_metrics(metrics=["cosine"], model_names=list(model_names),
                            datasets=list(datasets), new_data=False, reference_scale=r,
                            n_ts=10, human=human, path=path, target_scale=None)
        r2s, corrs = _summarise_scale_setting(r, human, model_names, datasets,
                                              reference_scale=r, target_scale=None)
        r2_frames.extend(r2s)
        corr_frames.extend(corrs)
    all_results = _concat_or_empty(r2_frames)
    all_corrs = _concat_or_empty(corr_frames)
    if save:
        if human:
            all_corrs.to_csv("./dec_text/scale_corrs_hds_30.csv")
            all_results.to_csv("./dec_text/scale_results_hds_30.csv")
        else:
            all_corrs.to_csv("./dec_text/scale_corrs_30.csv")
            all_results.to_csv("./dec_text/scale_results_30.csv")
    return all_results, all_corrs


def try_different_conv_scales(human=True, save=True,
                              target_scales=(0.6, 0.7, 0.8, 0.9, 0.95, 0.99),
                              model_names=("all-mpnet-base-v2",),
                              datasets=("prompt", "resp", "story")):
    """Sweep the target scale of the magnitude computation at reference scale 0.5.

    For every target scale, the cosine measures are recomputed with
    ``run_dec_metrics`` (n_ts=20, reference_scale=0.5). The stored results are
    then summarised with ``_summarise_scale_setting`` (the ``conv`` column
    holds the target scale).

    Args:
        human: use the test set with human diversity scores.
        save: write the tables to ``./dec_text/scale_{corrs,results}[_hds]_conv10.csv``.
        target_scales: target scales to sweep.
        model_names: embedding models to evaluate.
        datasets: dataset names to evaluate.

    Returns:
        (all_results, all_corrs) DataFrames.
    """
    r2_frames, corr_frames = [], []
    for r in target_scales:
        run_dec_metrics(metrics=["cosine"], model_names=list(model_names),
                        datasets=list(datasets), new_data=False, reference_scale=0.5,
                        n_ts=20, human=human, path=path, target_scale=r)
        r2s, corrs = _summarise_scale_setting(r, human, model_names, datasets,
                                              reference_scale=0.5, target_scale=r)
        r2_frames.extend(r2s)
        corr_frames.extend(corrs)
    all_results = _concat_or_empty(r2_frames)
    all_corrs = _concat_or_empty(corr_frames)
    if save:
        if human:
            all_corrs.to_csv("./dec_text/scale_corrs_hds_conv10.csv")
            all_results.to_csv("./dec_text/scale_results_hds_conv10.csv")
        else:
            all_corrs.to_csv("./dec_text/scale_corrs_conv10.csv")
            all_results.to_csv("./dec_text/scale_results_conv10.csv")
    return all_results, all_corrs


def run_dec_metrics(path, metrics=("cosine_ngrams", "cosine"),
                    model_names=(
                        "bert-large-nli-stsb-mean-tokens",
                        "all-mpnet-base-v2",
                        "all-distilroberta-v1",
                        "roberta-base-nli-mean-tokens",
                        "all-MiniLM-L12-v2",
                    ),
                    datasets=("prompt", "resp", "story"),
                    new_data=False, reference_scale=0.5, n_ts=20, scoring="r2", human=False,
                    n_samples=None, target_scale=0.95):
    """Compute diversity measures for the decoding test sets and evaluate them as predictors.

    For every dataset, reads ``<path>data/raw/decTest/dec_test_<variant><dataset>_gen.csv``.
    ``<variant>`` is ``200_with_hds_`` (label ``metric_abs_hds_mean``) when
    ``human`` is True, else ``1000_no_hds_`` (label ``label_value``). Then:

    * ``"cosine_ngrams"`` (if requested): computes n-gram measures from the raw
      responses, runs the prediction tasks and pickles the results under
      ``./dec_text/<dataset>_ngrams...``.
    * remaining metrics: for every embedding model, loads the embeddings from
      ``<path>data/with_metrics/decTest/dec_test_<variant><dataset>_gen/`` and
      delegates to ``run_embedding_eval`` (results under
      ``./dec_text/<dataset><model_name>...``).

    Output names get ``str(reference_scale)`` appended when it differs from 0.5.
    They get ``str(target_scale)`` appended when it is not None and differs
    from 0.5, matching ``read_files``/``read_files_hds``.

    Args:
        path: data root (string prefix).
        metrics: metric families to compute.
        model_names: sentence-transformers models whose embeddings are evaluated.
        datasets: dataset names.
        new_data: response-column layout passed to ``input_data_to_line``.
        reference_scale, n_ts, target_scale: magnitude computation settings.
        scoring: cross-validation scoring name.
        human: use the test set with human diversity scores.
        n_samples: currently ignored; every row of the CSV is evaluated.
    """
    if human:
        start_path1 = path + 'data/raw/decTest/dec_test_200_with_hds_'
        start_path2 = path + "data/with_metrics/decTest/dec_test_200_with_hds_"
        lable_name = "metric_abs_hds_mean"
    else:
        start_path1 = path + "data/raw/decTest/dec_test_1000_no_hds_"
        start_path2 = path + "data/with_metrics/decTest/dec_test_1000_no_hds_"
        lable_name = "label_value"
    end_path = "_gen"
    # Readers add no suffix for a missing or default (0.5) target scale;
    # str(None) must not leak into file names.
    if target_scale is None or target_scale == 0.5:
        target_suffix = ""
    else:
        target_suffix = str(target_scale)

    for dataset in datasets:
        these_metrics = list(metrics)
        print(dataset)

        input_path_csv = start_path1 + dataset + end_path  # where the CSVs are stored
        data_in = pd.read_csv(input_path_csv + ".csv")
        print(data_in.shape[0])
        n_samples = data_in.shape[0]
        input_path = start_path2 + dataset + end_path  # where the embeddings are stored

        if "cosine_ngrams" in these_metrics:
            def load_ngrams(idx):
                try:
                    return input_data_to_line(idx, data_in, new_data=new_data)
                except (IndexError, KeyError):
                    # Print context for debugging, then propagate the original error.
                    print(idx)
                    print(dataset)
                    print(data_in.shape)
                    raise

            if human:
                output_path = "./dec_text/" + dataset + "_ngrams_hds"
            else:
                output_path = "./dec_text/" + dataset + "_ngrams"
            if reference_scale != 0.5:
                if human:
                    output_path = "./dec_text/" + dataset + "_ngrams_hds_" + str(reference_scale)
                else:
                    output_path = "./dec_text/" + dataset + "_ngrams_" + str(reference_scale)
            # Same suffix rule as the embedding results; keeps human and
            # non-human n-gram runs in separate files.
            output_path += target_suffix

            ngram_results = calc_metrics_from_responses(
                load_ngrams, n_samples=n_samples, n_ts=n_ts, metrics=["cosine_ngrams"],
                reference_scale=reference_scale, scale=True, absolute_area=True,
                target_scale=target_scale)
            ngram_results["data_in"] = data_in
            save_magnitude_results(ngram_results, output_path)
            df_scores = ngram_results["summary_statistics"]
            linear_results = run_prediction_task_con("cosine_ngrams", df_scores, data_in, dataset, "ngrams",
                                                     lable_name=lable_name, scoring=scoring)
            for met in ["cosine_ngrams3"]:
                ts_results = run_prediction_task_con_ts(
                    met, ngram_results["magnitude_function_dfs"][met],
                    ngram_results["magnitude_differences"][met],
                    data_in, dataset, "ngrams", lable_name=lable_name, scoring=scoring,
                    n_splits=5, n_repeats=10,
                    model=IsotonicRegression(out_of_bounds="clip", increasing="auto"))
                linear_results = pd.concat([linear_results, ts_results], axis=0)
            ngram_results["prediction_results"] = linear_results
            save_magnitude_results(ngram_results, output_path)
            these_metrics.remove("cosine_ngrams")

        if these_metrics:
            for model_name in model_names:
                print(model_name)
                if human:
                    output_path = "./dec_text/" + dataset + model_name + "_hds"
                else:
                    output_path = "./dec_text/" + dataset + model_name
                if reference_scale != 0.5:
                    if human:
                        output_path = "./dec_text/" + dataset + model_name + "_hds_" + str(reference_scale)
                    else:
                        output_path = "./dec_text/" + dataset + model_name + str(reference_scale)
                output_path += target_suffix

                def load_data(idx):
                    return load_dec_data(idx, input_path=input_path, model_name=model_name)

                run_embedding_eval(load_data, output_path, n_samples, n_ts, these_metrics,
                                   reference_scale, data_in, dataset, model_name, lable_name,
                                   scoring, target_scale=target_scale)

def run_embedding_eval(load_data, output_path, n_samples, n_ts, these_metrics, reference_scale, data_in, dataset, model_name, lable_name, scoring, target_scale=None):
    """Compute embedding-based diversity measures and evaluate them as predictors.

    Calls ``calc_metrics_from_embeddings`` with ``load_data`` as loader. It then
    runs ``run_prediction_task_con`` (5 splits x 10 repeats, scored with
    ``scoring``) for each metric family, outer-merges the prediction tables,
    and pickles everything to ``output_path``.

    Returns:
        The results dict with added ``"data_in"`` and ``"prediction_results"`` entries.
    """
    mag_results = calc_metrics_from_embeddings(
        load_data, n_samples=n_samples, n_ts=n_ts, metrics=these_metrics,
        reference_summaries=False, reference_scale=reference_scale, scale=True,
        absolute_area=True, nearest_k=10, target_scale=target_scale)
    mag_results["data_in"] = data_in
    df_scores = mag_results["summary_statistics"]
    # Saved once before the prediction step (so the metric computation is not
    # lost if prediction fails — presumably the intent) and again afterwards.
    save_magnitude_results(mag_results, output_path)
    combined = None
    for met in these_metrics:
        linear_results = run_prediction_task_con(met, df_scores, data_in, dataset, model_name,
                                                 lable_name=lable_name, scoring=scoring,
                                                 n_splits=5, n_repeats=10)
        combined = linear_results if combined is None else combined.merge(linear_results, how='outer')
    mag_results["prediction_results"] = combined
    save_magnitude_results(mag_results, output_path)
    return mag_results


def run_prediction_task_con(met, df_scores, data_in, dataset, model_name, lable_name="label_value",
                            scoring="r2", n_splits=5, n_repeats=10, model=None):
    """Cross-validate how well each diversity measure of a metric family predicts the label.

    Args:
        met: metric family, one of "cosine", "cosine_ngrams", "L2", "L1"; it
            selects which ``df_scores`` columns are used as predictors.
        df_scores: per-sample diversity scores (one column per measure).
        data_in: input table containing the ``lable_name`` column (aligned on index).
        dataset, model_name: tags written into the result's ``experiment``/``model`` columns.
        lable_name: name of the target column.
        scoring, n_splits, n_repeats: passed to ``regression_cross_validation``.
        model: regressor; a fresh ``IsotonicRegression(out_of_bounds="clip",
            increasing="auto")`` is created per call when None.

    Returns:
        DataFrame of cross-validation scores tagged with experiment/model/metric.

    Raises:
        ValueError: for an unknown ``met``.
    """
    feature_columns = {
        "cosine": ["stds_div", "stds_div_zero", "neg_mean_cosine", "vendi_cosine", "mag_area_cosine",
                   "mag_t_min_cosine", "mag_t_med_cosine", "mag_t_max_cosine"],
        "cosine_ngrams": ["unique_ngrams3", "neg_mean_cosine_ngrams3", "vendi_cosine_ngrams3",
                          "mag_area_cosine_ngrams3"],
        "L2": ["stds_div", "stds_div_zero", "neg_mean_exp", "vendi_exp", "mag_area_L2",
               "mag_t_min_L2", "mag_t_med_L2", "mag_t_max_L2"],
        "L1": ["stds_div", "stds_div_zero", "neg_mean_expl1", "vendi_expl1", "mag_area_L1",
               "mag_t_min_L1", "mag_t_med_L1", "mag_t_max_L1"],
    }
    if met not in feature_columns:
        raise ValueError("Unknown metric %r; expected one of %s" % (met, sorted(feature_columns)))
    if model is None:
        # Fresh estimator per call instead of one instance shared via the default argument.
        model = IsotonicRegression(out_of_bounds="clip", increasing="auto")

    # Explicit copy so adding the label column does not write into a slice of df_scores.
    df_sub = df_scores[feature_columns[met]].copy()
    df_sub[lable_name] = data_in[lable_name]

    linear_results = regression_cross_validation(df_sub, lable_name, scoring=scoring, model=model,
                                                 n_splits=n_splits, get_scores=True, n_repeats=n_repeats)

    linear_results = pd.DataFrame(linear_results)
    linear_results["experiment"] = dataset
    linear_results["model"] = model_name
    linear_results["metric"] = met
    return linear_results

def read_files_hds(dataset, model_name, reference_scale=None, target_scale=None):
    """Load the pickled magnitude results of a run with human diversity scores.

    The file read is
    ``./dec_text/<dataset><model_name>_hds[_<reference_scale>][<target_scale>]_magnitude_results.pkl``.
    This mirrors the names written by ``run_dec_metrics`` (human=True): a
    scale is appended only when it is given and differs from the default 0.5.

    Returns:
        The unpickled results object (only use on trusted local files).
    """
    output_path = "./dec_text/" + dataset + model_name + "_hds"
    if reference_scale is not None and reference_scale != 0.5:
        output_path += "_" + str(reference_scale)
    # The writer adds no suffix for target_scale == 0.5, so the reader must not either.
    if target_scale is not None and target_scale != 0.5:
        output_path += str(target_scale)
    with open(output_path + "_magnitude_results" + '.pkl', 'rb') as fp:
        return pickle.load(fp)

def get_prediction_results_hds(model_name, name="prediction_results", human=True, reference_scale=None, target_scale=None):
    """Collect one entry of the stored results for the three decoding datasets.

    Args:
        model_name: model identifier used in the result file names.
        name: key to extract from each results dict (e.g. "prediction_results",
            "summary_statistics", "data_in").
        human: read the human-score results (``read_files_hds``) instead of
            the label-value results (``read_files``).
        reference_scale, target_scale: forwarded to the reader.

    Returns:
        dict mapping "prompt", "resp" and "story" to the extracted entry.
    """
    reader = read_files_hds if human else read_files
    return {
        dataset: reader(dataset, model_name, reference_scale=reference_scale,
                        target_scale=target_scale)[name]
        for dataset in ("prompt", "resp", "story")
    }


def input_data_to_line(idx, data_in, new_data=False):
    """Return the responses stored in row ``idx`` (positional) of ``data_in`` as a list.

    With ``new_data`` the responses are read from columns "0" .. "39";
    otherwise from columns "resp_0" .. "resp_9".
    """
    if new_data:
        wanted = [str(k) for k in range(40)]
    else:
        wanted = ["resp_" + str(k) for k in range(10)]
    row = data_in.iloc[idx, :]
    return list(row.loc[wanted])



def collect_prediction_results(path, ngrams=False, rerun=False, scoring="r2", n_repeats=10, n_splits=5, human=False, reference_scale=None, model_names=None):
    """Gather the prediction results of all datasets and models into one table.

    Args:
        path: data root (string prefix), used only when ``rerun`` is True.
        ngrams: collect the n-gram results (model name "_ngrams", metric
            "cosine_ngrams") instead of the embedding results (metric "cosine").
        rerun: recompute the prediction task from the stored summary statistics
            instead of reusing the stored ``prediction_results``.
        scoring, n_repeats, n_splits: cross-validation settings for ``rerun``.
        human: use the results with human diversity scores.
        reference_scale: forwarded to the result readers.
        model_names: embedding models; defaults to the five standard models.

    Returns:
        Outer merge of all per-(dataset, model) prediction tables, or None if
        there is nothing to collect.
    """
    if ngrams:
        model_names = ["_ngrams"]
        met = "cosine_ngrams"
    else:
        if model_names is None:
            model_names = [
                "bert-large-nli-stsb-mean-tokens",
                "all-mpnet-base-v2",
                "all-distilroberta-v1",
                "roberta-base-nli-mean-tokens",
                "all-MiniLM-L12-v2",
            ]
        met = "cosine"

    if human:
        start_path1 = path + 'data/raw/decTest/dec_test_200_with_hds_'
        lable_name = "metric_abs_hds_mean"
        reader = read_files_hds
    else:
        start_path1 = path + "data/raw/decTest/dec_test_1000_no_hds_"
        lable_name = "label_value"
        reader = read_files

    all_dfs_lin = None
    for dataset in ["prompt", "resp", "story"]:
        # The raw CSV is only needed to recompute the prediction task.
        data_in = pd.read_csv(start_path1 + dataset + "_gen" + ".csv") if rerun else None
        for model_name in model_names:
            results = reader(dataset, model_name, reference_scale=reference_scale)
            if rerun:
                linear_results = run_prediction_task_con(
                    met=met, df_scores=results["summary_statistics"], data_in=data_in,
                    dataset=dataset, model_name=model_name, scoring=scoring,
                    n_repeats=n_repeats, n_splits=n_splits, lable_name=lable_name)
            else:
                linear_results = results["prediction_results"]
            if all_dfs_lin is None:
                all_dfs_lin = linear_results
            else:
                all_dfs_lin = all_dfs_lin.merge(linear_results, how='outer')
    return all_dfs_lin

def plot_prediction_results(summary_scores_lin, order, rows=None, metric="cosine", scoring="r2", human=False, title=True):
    """Plot per-task prediction scores of several diversity measures as point plots.

    Draws a 1x3 figure with one panel per task ("prompt", "resp", "story").
    Each panel shows, for every model on the y axis (ordered by ``order``), a
    point with standard-deviation error bars for each diversity measure in
    ``rows``.

    Args:
        summary_scores_lin: long-form results table filtered on its
            ``experiment`` and ``metric`` columns and containing ``model`` plus
            every column in ``rows``. Presumably the output of
            ``collect_prediction_results`` — confirm with callers.
        order: order of the models on the y axis (passed to ``sns.pointplot``).
        rows: columns to plot, including "model". When None a default list is
            chosen from ``metric`` (an unrecognised ``metric`` leaves it None).
        metric: metric family to plot; "all_cosine" combines the rows whose
            ``metric`` is "cosine" or "cosine_ngrams".
        scoring: name of the value column after melting (used as x axis).
        human: selects the y label of the first panel when ``title`` is True.
        title: whether to put a y label on the first panel.

    Side effects: sets ``plt.rcParams['text.usetex'] = True`` globally and
    draws a new figure; nothing is returned or saved.
    """
    ### TODO Adjust!
    plt.rcParams['text.usetex'] = True
    # Figure height depends on how many measures are plotted per model.
    if metric == "cosine_ngrams":
        fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(13, 6))
    elif metric == "all_cosine":
        fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(13, 4))
    else:
        fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(13, 2.8))

    if rows is None:
        if metric == "cosine":
            rows=["mag_area_cosine", "vendi_cosine", "stds_div", "stds_div_zero", "neg_mean_cosine", "model"]
        # NOTE(review): the L2/L1 defaults use "neg_mean_l2"/"mag_area_l2" (and
        # the l1 equivalents), whereas run_prediction_task_con uses
        # "mag_area_L2"/"mag_area_L1" and no "neg_mean_l2"/"neg_mean_l1";
        # confirm these columns exist in summary_scores_lin.
        elif metric == "L2":
            rows=["stds_div", "stds_div_zero","neg_mean_l2", "neg_mean_exp", "vendi_exp", "mag_area_l2", "model"]
        elif metric == "L1":
            rows=["stds_div", "stds_div_zero","neg_mean_l1", "neg_mean_expl1", "vendi_expl1", "mag_area_l1", "model"]
        elif metric =="cosine_ngrams":
            rows = [#"unique_ngrams1", "neg_mean_cosine_ngrams1", "vendi_cosine_ngrams1","mag_area_cosine_ngrams1",
                    #"unique_ngrams2", "neg_mean_cosine_ngrams2", "vendi_cosine_ngrams2","mag_area_cosine_ngrams2",
                    "unique_ngrams3", "neg_mean_cosine_ngrams3", "vendi_cosine_ngrams3", "mag_area_cosine_ngrams3", "model"]
        elif metric == "all_cosine":
            rows = ["stds_div_zero", "neg_mean_cosine", "vendi_cosine", "mag_area_cosine",
                    "unique_ngrams3", "neg_mean_cosine_ngrams3", "vendi_cosine_ngrams3", "mag_area_cosine_ngrams3", "model"]

    # NOTE(review): custom_palette is computed but not used below (the point
    # plots use ``colors``).
    custom_palette = sns.color_palette("tab10", n_colors=6)
    custom_palette = [color for i, color in enumerate(custom_palette) if i != 2]#[color for color in custom_palette if color != (0.12156862745098039, 0.4666666666666667, 0.7058823529411765)]

    for i, name in enumerate(["prompt", "resp","story"]):
        data_plot = summary_scores_lin[summary_scores_lin.experiment==name]
        if metric == "all_cosine":
            data_plot = data_plot[(data_plot.metric=="cosine") | (data_plot.metric=="cosine_ngrams")]
        else:
            data_plot= data_plot[data_plot.metric==metric]
        data_plot= data_plot.loc[:,rows]

        # Long format: one row per (model, measure) with the score in column ``scoring``.
        # NOTE(review): value_vars=rows also contains "model", the id variable;
        # confirm the installed pandas version handles this overlap as intended.
        data_plot = data_plot.melt(id_vars=["model"], value_vars=rows, value_name=scoring)
        #print(data_plot)

        # Four hand-picked colours; fall back to tab10 when more measures are plotted.
        colors=[(0.09019607843137255, 0.7450980392156863, 0.8117647058823529, 1.0), #plt.cm.tab20(0), 
                plt.cm.tab20(12), plt.cm.tab20(6), plt.cm.tab20(2)]
        if len(colors) < len(data_plot.variable.unique()):
            colors = sns.color_palette("tab10", n_colors=len(data_plot.variable.unique()))

        sns.pointplot(data=data_plot, y="model", hue="variable", x=scoring, ax=axes[i], 
                    errorbar="sd", 
                    #errorbar=("pi",90), 
                    palette=colors,#palette="tab10",#custom_palette,#sns.diverging_palette(20, 220, s=60, sep=0, as_cmap=True),
                    #sns.color_palette("tab10"),
                    join=False, dodge=0.5, scale = 0.7, errwidth=1.5, order=order)#, inner="quart")
        axes[i].tick_params(axis='x', rotation=0, labelsize=8)
        axes[i].set_title(name)
        # i only takes the values 0..2, so this is always true: every per-panel
        # legend is removed and one shared legend is added after the loop.
        if i<3:
            axes[i].get_legend().remove()
            axes[i].set_ylabel('')
        # Only the first panel keeps the model names on the y axis.
        if i>0:
            axes[i].set_yticklabels([])
            axes[i].set_yticks([])
        if i==0:
            if title:
                if human:
                    axes[i].set_ylabel('(b) predicting\nhuman scores')
                else:
                    axes[i].set_ylabel('(a) predicting\ndecoding parameters')

        # Horizontal separator lines between consecutive model rows.
        for j in range(1, len(data_plot.model.unique())):
            axes[i].axhline(y=j - 0.5, color='black', linestyle='-', linewidth=0.8)
            #axes[i].set_yticks(None)
        axes[i].spines['left'].set_color('white')

    # Legend handles/labels are taken from the last panel drawn (axes[i] after the loop).
    label_list = []
    for t in axes[i].get_legend_handles_labels():
        # the first result will be all handles, i.e. the dots in the legend
        # the second result will be all legend text
        label_list.append(t)

    new_list = []
    # With exactly four measures (data_plot here is the last task's data), the
    # column names are replaced by LaTeX small-caps labels. This assumes the
    # measures appear in the order MagArea, Vendi, AvgSim, GMStds —
    # NOTE(review): confirm the ``rows`` order used by callers.
    if ((metric != "cosine_ngrams") & (metric != "all_cosine")) & (len(data_plot.variable.unique())==4):
        for j, txt in enumerate(label_list[1]):
            #i#f txt[0] == '0':
            #txt = [r"MagArea($t_{cut}$)",#r"$\int_0^{t_{cut}} Mag_X$", 
            #       r"Mag($t_{min}$)",#r"$Mag_X(t_{cut})$", 
            #       "VendiScore","GMStds", "AvgSim"][j]
            txt = [r"\textsc{MagArea}", #r"$\mathrm{Mag}(t_{\mathrm{min}})$",#r"MagArea($t_{cut}$)",#r"$\int_0^{t_{cut}} Mag_X$", 
            #r"Mag($t_{min}$)",#r"$Mag_X(t_{cut})$", 
            r"\textsc{VS}", r"\textsc{AvgSim}", r"\textsc{GMStds}" #r"\textsc{GMStds}", 
            ][j]
            new_list.append(txt)
        label_list[1] = new_list

        #ax.legend(handles=label_list[0], labels=label_list[1])

        # plt.legend draws on matplotlib's current axes (presumably the last
        # panel) and places the legend to its right.
        plt.legend(handles=label_list[0], #ncol=len(label_list[0]), 
                    loc='center left', 
                bbox_to_anchor=(1, 0.5), #fancybox=True, 
                title='diversity measures', labels=label_list[1], frameon=False)
    else:
        plt.legend(bbox_to_anchor=(1, 0.5), frameon=False) #ncol=len(label_list[0]),
            #["area under the magnitude function", 
            #                                                                                "magnitude at one scale",
            #                                                                                "average similarity",
            #                                                                                "geometric mean of embedding features"])#
            #[r'$\int Mag_X', r'$\int Mag_X(t_{cut})', , r'$\int Mag_X(t_{cut})'])
    sns.despine(fig)
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.05, hspace=0)
    #plt.savefig("div_pred_dec_summary.svg")
    #plt.savefig("div_pred_dec_summary.svg")

def get_differences(all_results_hds, experiment="prompt", this_row="mag_area_cosine", cols=['vendi_cosine', 'neg_mean_cosine', 'stds_div_zero']):
    print(experiment)
    rows={}
    #subbb = all_results_hds[].dropna()
    subbb= all_results_hds[all_results_hds["experiment"]==experiment]#.dropna()
    for col in cols:
        sub = subbb[[this_row, col]]#.dropna()
        diffs=(sub[this_row]-sub[col]).dropna()
        #print(diffs)/
        #p_val = wilcoxon(diffs, alternative="greater", nan_policy='omit', method = 'approx').pvalue
        rows[col] = [diffs.mean().round(2), diffs.std().round(2), diffs.quantile(0.025).round(2), diffs.quantile(0.975).round(2),# p_val, 
                     experiment]
    df = pd.DataFrame(rows).T
    df.columns = ["mean difference", "standard deviation", "lower 95% percentile interval", "upper 95% percentile interval", #"Wilcoxon p-value", 
                  "task"]
    all_ranks = pd.DataFrame()
    for i, name in enumerate([experiment]):
        #,"story"]):
        data_plot = subbb
        data_plot = data_plot[~(data_plot.model == "ngrams")]
        data_plot= data_plot.loc[:,["stds_div_zero", "neg_mean_cosine", "vendi_cosine", #19, 
                                    "mag_area_cosine",
                        #"unique_ngrams3", "neg_mean_cosine_ngrams3", "vendi_cosine_ngrams3", "mag_area_cosine_ngrams3", 
                        "model"]].dropna()
        #.drop("stds_geom_nonzero_", axis=1)
        #data_plot= data_plot.drop('mag_at_tmin_cosine_', axis=1)

        ranks= data_plot.iloc[:,:(len(data_plot.columns)-1)].rank(axis=1, ascending=False, method='min')
        ranks["model"] = data_plot["model"]
        ranks["experiment"] = name
        all_ranks=pd.concat([all_ranks, ranks], axis=0)
    all_ranks = all_ranks.melt(id_vars=["model"], value_vars=["mag_area_cosine", #19,
                                                            "vendi_cosine","neg_mean_cosine", "stds_div_zero",  
                        #"unique_ngrams3", "neg_mean_cosine_ngrams3", "vendi_cosine_ngrams3", "mag_area_cosine_ngrams3"
                        ], value_name="rank")
    for m in ["mag_area_cosine", #19,
                                                            "vendi_cosine","neg_mean_cosine", "stds_div_zero",  
                        #"unique_ngrams3", "neg_mean_cosine_ngrams3", "vendi_cosine_ngrams3", "mag_area_cosine_ngrams3"
                        ]:
        print(all_ranks[all_ranks["variable"]==m]["rank"].median())
    return df #rows#pd.DataFrame(rows, columns=["mean", "std", "low", "high"])

def get_means(all_results_hds, experiment="prompt", cols=("mag_area_cosine", 'vendi_cosine', 'neg_mean_cosine', 'stds_div_zero')):
    """Summarise the scores of each measure on one task.

    Prints ``experiment``, then, for the rows whose ``experiment`` column equals
    ``experiment``, summarises each column in ``cols`` (NaNs dropped) by its
    mean, sample standard deviation and 2.5% / 97.5% quantiles, rounded to two
    decimals.

    Returns:
        DataFrame indexed by ``cols`` with the four summary columns plus "task".
    """
    print(experiment)
    on_task = all_results_hds[all_results_hds["experiment"] == experiment]
    rows = {}
    for col in cols:
        values = on_task[col].dropna()
        # Work on the Series directly: indexing the per-column result with [0]
        # relied on pandas' deprecated positional fallback for label indexes.
        stats = pd.Series([values.mean(), values.std(),
                           values.quantile(0.025), values.quantile(0.975)]).round(2)
        rows[col] = stats.tolist() + [experiment]
    df = pd.DataFrame(rows).T
    df.columns = ["mean", "standard deviation", "lower 95% percentile interval", "upper 95% percentile interval",
                  "task"]
    return df

def get_global_differences(all_results_hds, this_row="mag_area_cosine", cols=['vendi_cosine', 'neg_mean_cosine', 'stds_div_zero']):
    rows={}
    #subbb = all_results_hds[].dropna()
    #subbb= all_results_hds[all_results_hds["experiment"]==experiment]#.dropna()
    for col in cols:
        sub = all_results_hds[[this_row, col]]#.dropna()
        diffs=(sub[this_row]-sub[col]).dropna()
        #print(diffs)
        #p_val = wilcoxon(diffs, alternative="greater", nan_policy='omit', method = 'approx').pvalue
        rows[col] = [diffs.mean().round(2), diffs.std().round(2), diffs.quantile(0.025).round(2), diffs.quantile(0.975).round(2)]
    df = pd.DataFrame(rows).T
    df.columns = ["mean difference", "standard deviation", "lower 95% percentile interval", "upper 95% percentile interval"]
    return df #rows#pd.DataFrame(rows, columns=["mean", "std", "low", "high"])


# poetry run python3 dec_experiment.py

if __name__ == "__main__":
    # Entry point of the decoding experiment:
    # 1. embed the generated responses of the test set without human scores,
    #    then of the set with human diversity scores;
    # 2. compute the cosine diversity measures for every embedding model and
    #    evaluate them as predictors, again first without, then with human scores.
    # The scale sweeps (try_different_scales / try_different_conv_scales) are
    # run separately.
    embedding_models = [
        "bert-large-nli-stsb-mean-tokens",
        "all-mpnet-base-v2",
        "all-distilroberta-v1",
        "roberta-base-nli-mean-tokens",
        "all-MiniLM-L12-v2",
    ]

    for with_human_scores in (False, True):
        all_data_to_embeddings(path, human=with_human_scores)

    for with_human_scores in (False, True):
        run_dec_metrics(
            metrics=["cosine"],
            model_names=embedding_models,
            datasets=["prompt", "resp", "story"],
            new_data=False,
            reference_scale=0.5,
            n_ts=20,
            human=with_human_scores,
            scoring='r2',
            path=path,
            target_scale=0.99,
        )
