from evaluation import UnsupervisedEvaluator, SupervisedEvaluator, SemiSupervisedEvaluator
from dataset import Dataset
from pipeline import MultiPipeline
import os
import matplotlib.pyplot as plt
import seaborn as sns
from utils import run_parallel
from datetime import datetime
import numpy as np
import json
from loguru import logger
import shutil
import pandas as pd
from jinja2 import Environment, FileSystemLoader
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from scipy.spatial import cKDTree
from utils import get_best_device
from transformers import DistilBertForSequenceClassification
import finetune
import pickle
from utils import unnest_dictionary
from model_loader import load_model, load_tokenizer
import torch
from transformers import set_seed


# Apply seaborn's "whitegrid" theme globally at import time so that every
# figure produced by this module starts from the same look.
sns.set_theme(style="whitegrid")

class MetaEvaluator:
    def __init__(self, pipeline_datasets, evaluation_dataset=None, y="linear_fake_to_real.accuracy") -> None:
        """Remember the datasets to compare and the name of the target metric.

        Args:
            pipeline_datasets: Sequence of ``(pipeline, dataset)`` pairs; the
                other methods of this class unpack every entry that way.
            evaluation_dataset: Reference ("real") dataset the generated
                datasets are compared against. May be ``None``.
            y: Key of the metric that is looked up in the metric dictionaries
                and used as the performance value in plots and rankings.
        """
        self.y = y
        self.evaluation_dataset = evaluation_dataset
        self.pipeline_datasets = pipeline_datasets
    
    def generate_metric_plot(self, x, y, x_real, y_real, save_loc, metric_name):
        """Scatter one metric against performance and write the figure to disk.

        Args:
            x: Metric values of the generated datasets.
            y: Performance values of the generated datasets.
            x_real: Metric value of the real dataset.
            y_real: Performance value of the real dataset.
            save_loc: Path the figure is saved to.
            metric_name: Used as the plot title.
        """
        figure, axes = plt.subplots(dpi=120)
        # The freshly created axes are the current axes, so both scatter
        # calls draw onto them.
        series = (
            (x, y, "Generated"),
            ([x_real], [y_real], "Real"),
        )
        for xs, ys, series_name in series:
            sns.scatterplot(x=xs, y=ys, label=series_name)
        axes.set_title(metric_name)
        axes.set_xlabel('metric value')
        axes.set_ylabel("performance")
        figure.tight_layout()
        figure.savefig(save_loc)
        plt.close()

    def word_distribution_plot(self, dataset_id, label, word_distrib, save_loc, real_word_distrib=None, n_words=30):
        """Plot the most frequent words of a dataset as a horizontal bar chart.

        Args:
            dataset_id: Identifier shown in the title.
            label: Label shown in the title.
            word_distrib: Mapping ``word -> frequency`` of the plotted dataset.
            save_loc: Path the figure is saved to.
            real_word_distrib: Optional mapping ``word -> frequency`` of the
                real dataset. When given, its frequencies for the same words
                are drawn next to the plotted dataset's ("fake" vs "real").
                When omitted, the plotted bars are tagged "real".
            n_words: Number of top words to show.
        """
        figure, axes = plt.subplots(figsize=(8, 8))

        ranked = sorted(word_distrib.items(), key=lambda item: item[1], reverse=True)[:n_words]
        words = [item[0] for item in ranked]
        frequencies = [item[1] for item in ranked]

        if real_word_distrib is None:
            origins = ["real"] * len(words)
        else:
            reference = [real_word_distrib.get(word, 0) for word in words]
            origins = ["fake"] * len(words) + ["real"] * len(reference)
            frequencies = frequencies + reference
            # Every word appears twice: once per source.
            words = words + words

        frame = pd.DataFrame(
            {"word": words, "frequency": frequencies, "source": origins}
        )

        sns.barplot(x="frequency", y="word", hue='source', data=frame, ax=axes)
        axes.set_title(f"{dataset_id}, {label}: Top {n_words} Words")
        figure.tight_layout()
        figure.savefig(save_loc)
        plt.close()

    def generate_generation_plot(self, y_real, y, ids, save_loc):
        """Draw a bar chart ranking the generators against the real dataset.

        Bars are sorted by descending performance. The bar of the real dataset
        is drawn in red, all generator bars in blue.

        Args:
            y_real: Performance of the real dataset.
            y: Performance per generator.
            ids: Generator ids, parallel to ``y``; shown as bar names.
            save_loc: Path the figure is saved to.
        """
        real_colour, generated_colour = "#c44e52", "#4c72b0"

        bar_names = ["real"] + [str(generator_id) for generator_id in ids]
        bar_heights = [y_real] + list(y)
        ranking = np.argsort(bar_heights)[::-1]
        ordered_names = [bar_names[position] for position in ranking]

        # seaborn hands out palette colours following `order`, so the palette
        # has to be built in that same order; otherwise the red colour ends up
        # on whichever bar ranks first instead of on the real dataset.
        palette = sns.color_palette(
            [real_colour if name == "real" else generated_colour for name in ordered_names]
        )

        fig, ax = plt.subplots(dpi=120)
        sns.barplot(x=bar_names, y=bar_heights, palette=palette, ax=ax, order=ordered_names)
        ax.set_title('Generator performance')
        ax.set_xlabel('id')
        ax.set_ylabel('performance')
        fig.savefig(save_loc)
        plt.close(fig)

    def deduplicate_vectors(self, array, dataset_indices, labels_indices, epsilon=1e-8):
        kdtree = cKDTree(array)
        dedup_mask = np.ones(array.shape[0], dtype=bool)
        new_dataset_indices = []
        
        for i, vector in enumerate(array):
            datasets_index = [dataset_indices[i]]
            if dedup_mask[i]:
                neighbors = kdtree.query_ball_point(vector, r=epsilon)
                for neighbor in neighbors[1:]:
                    dedup_mask[neighbor] = False
                    datasets_index.append(dataset_indices[neighbor])
            new_dataset_indices.append(datasets_index)

        unique_vectors = array[dedup_mask]
        unique_labels = labels_indices[dedup_mask]
        real_new_dataset_indices = [index for i, index in enumerate(new_dataset_indices) if dedup_mask[i]]
        return unique_vectors, real_new_dataset_indices, unique_labels

    def generate_vector_plot(self, base_path, sorted_indices):
        """Plot a 2-D embedding of every generated dataset next to the real one.

        The vectors of the evaluation dataset and of all datasets in
        ``sorted_indices`` are de-duplicated, reduced with PCA and embedded
        with t-SNE in one joint fit. For every dataset one scatter plot is
        saved for all labels together ("general") and one per label.

        Args:
            base_path: Root folder; figures are written below
                ``<base_path>/plots/<label>/vectors``.
            sorted_indices: Indices into ``self.pipeline_datasets`` in the
                order the plots should be generated.

        Returns:
            Dict mapping ``"general"`` and every label to the list of figure
            paths (relative to ``base_path``), in ``sorted_indices`` order.
        """
        dataset_labels = list(self.evaluation_dataset.get_labels())
        assert "general" not in dataset_labels # just a precaution
        label_descriptions = ["general"] + dataset_labels

        vector_plots_locations = dict()
        for description in label_descriptions:
            os.makedirs(os.path.join(base_path, "plots", description, "vectors"), exist_ok=True)
            vector_plots_locations[description] = []

        # Dataset id 0 is reserved for the evaluation dataset; pipeline
        # dataset `index` is stored as `index + 1`.
        all_vectors = list(self.evaluation_dataset.get_vectors())
        dataset_indices = [0] * self.evaluation_dataset.size()
        labels_indices = list(self.evaluation_dataset.get_label_col())

        for index in sorted_indices:
            dataset = self.pipeline_datasets[index][1]
            all_vectors.extend(dataset.get_vectors())
            dataset_indices.extend([index + 1] * dataset.size())
            labels_indices.extend(list(dataset.get_label_col()))

        all_vectors = np.array(all_vectors)
        dataset_indices = np.array(dataset_indices)
        labels_indices = np.array(labels_indices)

        logger.debug("Fitting TSNE + PCA")
        train_vectors, dataset_dedup, labels_indices = self.deduplicate_vectors(all_vectors, dataset_indices, labels_indices)
        n_samples, n_features = train_vectors.shape
        # PCA cannot produce more components than samples/features, and t-SNE
        # needs a perplexity below the sample count; clamp both so small
        # inputs do not crash. Large inputs keep the 50 / 30 settings.
        pca = PCA(n_components=min(50, n_samples, n_features))
        tsne = TSNE(n_components=2, perplexity=min(30.0, max(n_samples - 1, 1)))
        pca_vectors = pca.fit_transform(train_vectors)
        tsne_vectors = tsne.fit_transform(pca_vectors)
        logger.debug("Done fitting TSNE + PCA")

        legend_labels = ["Generated", "Real"]
        legend_colors = sns.color_palette()[:len(legend_labels)]
        # Explicit hue -> colour mapping (0 = generated, 1 = real). Without it
        # seaborn treats the numeric hue as continuous and picks colours that
        # do not match the hand-built legend.
        hue_palette = dict(enumerate(legend_colors))

        eval_indices = np.array([0 in x for x in dataset_dedup])
        for i, index in enumerate(sorted_indices):
            dataset_indi = np.array([index + 1 in x for x in dataset_dedup])
            real_or_dataset = np.logical_or(eval_indices, dataset_indi)
            for label in label_descriptions:
                locs = real_or_dataset
                if label != "general":
                    locs = np.logical_and(real_or_dataset, labels_indices == label)

                hues = eval_indices[locs].astype(int)
                fig, ax = plt.subplots(dpi=120)
                sns.scatterplot(x=tsne_vectors[locs, 0], y=tsne_vectors[locs, 1], hue=hues, palette=hue_palette,
                                ax=ax, s=25, alpha=0.4, legend=False)
                # Create a custom legend
                custom_legend = [plt.Line2D([0], [0], marker='o', color='w', label=legend_label,
                                            markerfacecolor=color, markersize=8)
                                 for legend_label, color in zip(legend_labels, legend_colors)]

                # Add the custom legend to the plot
                ax.legend(handles=custom_legend)
                ax.set_title(f'Vector space dataset {i}')
                total_extension_path = os.path.join("plots", label, "vectors", "vector_space_" + str(index) + ".png")
                vector_plots_locations[label].append(total_extension_path)
                fig.savefig(os.path.join(base_path, total_extension_path))
                plt.close(fig)

        return vector_plots_locations


    def generate_converter_plot(self, y, ids, save_loc):
        """Draw a bar chart of the performance gain of every converter.

        Bars are ordered by descending gain. With no converters an empty,
        titled figure is saved.

        Args:
            y: Performance gain per converter.
            ids: Converter ids, parallel to ``y``.
            save_loc: Path the figure is saved to.
        """
        figure, axes = plt.subplots(dpi=120)
        n_converters = len(y)
        if n_converters > 0:
            descending = np.argsort(y)[::-1]
            ordered_ids = np.array(ids)[descending]
            single_colour = sns.color_palette(["#4c72b0"] * n_converters)
            sns.barplot(x=ids, y=y, palette=single_colour, ax=axes, order=ordered_ids)
        axes.set_title('Converter performance gain')
        axes.set_xlabel('id')
        axes.set_ylabel('performance gain')
        figure.savefig(save_loc)
        plt.close()

    def correlation_plot(self, list_dict_metrics, dict_real_metrics, save_loc):
        """Save a heat map of the pairwise correlation between all metrics.

        Args:
            list_dict_metrics: One metric dictionary per generated dataset;
                each becomes a row of the table the correlation is computed on.
            dict_real_metrics: Not used by this method; kept in the signature
                for callers.
            save_loc: Path the figure is saved to.
        """
        # Missing values count as 0; only numeric (and boolean) columns can be
        # correlated.
        numeric_metrics = pd.DataFrame(list_dict_metrics).fillna(0).select_dtypes(include=["number", "bool"])
        corr_matrix = numeric_metrics.corr()

        # Scale the fonts for this figure only. A global `sns.set(...)` would
        # also replace the theme chosen at module level for every later plot.
        with sns.plotting_context("notebook", font_scale=1.4):
            fig, ax = plt.subplots(figsize=(20, 20))
            # Hide the upper triangle (diagonal included): it mirrors the lower one.
            mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
            cmap = sns.diverging_palette(220, 10, as_cmap=True)
            sns.heatmap(corr_matrix, cmap=cmap, mask=mask, linewidths=0.5, square=True, ax=ax, cbar_kws={"shrink": 0.5})
            ax.tick_params(axis="x", labelrotation=90)
            ax.tick_params(axis="y", labelrotation=0)
            ax.set_title("Correlation matrix (yes it looks ugly)")
            fig.tight_layout()
            fig.savefig(save_loc)
        plt.close(fig)

    def calculate_global_evaluation_metrics(self, dict_real_metrics, list_dict_metrics, n_cores, dtype=torch.bfloat16, reload_=True,
                                            batch_size=16):
        """Evaluate the models trained on the real data on every generated dataset.

        Every key of ``dict_real_metrics`` that starts with ``"model."`` is
        treated as the location of a stored model. The model is evaluated on
        all pipeline datasets and the outcome is stored per dataset under
        ``"<metric>_real_to_fake"``. Datasets that already carry that key are
        not evaluated again.

        Args:
            dict_real_metrics: Metrics of the real dataset; extended in place.
            list_dict_metrics: One metric dictionary per pipeline dataset;
                updated in place.
            n_cores: Forwarded to ``run_parallel``.
            dtype: Forwarded to ``SupervisedEvaluator.get_perplexity``.
            reload_: Not used by this method; kept for callers.
            batch_size: Forwarded to ``SupervisedEvaluator.get_perplexity``.

        Returns:
            Tuple ``(dict_real_metrics, list_dict_metrics)``.
        """
        n_datasets = len(self.pipeline_datasets)

        # Iterate over a snapshot of the keys: the dictionary grows below.
        for model_folders in list(dict_real_metrics.keys()):
            if not model_folders.startswith("model."):
                continue

            metric_name = model_folders.split(".")[1]
            model_location = dict_real_metrics[model_folders]
            if not os.path.exists(model_location):
                logger.error(f"Model does not exist {model_location}")
                continue

            result_key = metric_name + "_real_to_fake"
            is_causal = "finetune_causal_metrics" in model_folders
            device = get_best_device()
            eval_funcs = []

            if not is_causal:
                # One classifier, evaluated once per dataset.
                tokenizer = finetune.get_tokenizer()
                model = DistilBertForSequenceClassification.from_pretrained(model_location)
                # NOTE(review): unpickling executes arbitrary code; the label
                # encoder file must come from a trusted source.
                with open(dict_real_metrics["label_encoders." + metric_name], "rb") as encoder_file:
                    label_encoder = pickle.load(encoder_file)

                model = model.to(device)
                for i, (_, dataset) in enumerate(self.pipeline_datasets):
                    if result_key not in list_dict_metrics[i]:
                        eval_funcs.append(lambda dataset=dataset: finetune.evaluate_metrics(model, tokenizer, label_encoder, dataset.df, device))
                    else:
                        eval_funcs.append(lambda dataset=dataset: None)
            else:
                # One model per label; jobs are laid out label-major, i.e.
                # job `label_position * n_datasets + i` belongs to dataset `i`.
                for label in self.evaluation_dataset.get_labels():
                    model = os.path.join(model_location, label)
                    tokenizer = load_tokenizer(os.path.join(model_location, label))

                    for i, (_, dataset) in enumerate(self.pipeline_datasets):
                        label_dataset = dataset.df[dataset.df["label"] == label].copy()
                        label_dataset.drop(["label"], axis=1, inplace=True)
                        if result_key not in list_dict_metrics[i]:
                            eval_funcs.append(lambda dataset=label_dataset, model_name=model, tokenizer_=tokenizer, device_=device:
                                                SupervisedEvaluator.get_perplexity(dataset, model_name, tokenizer_, device_,
                                                                                   batch_size=batch_size, dtype=dtype))
                        else:
                            eval_funcs.append(lambda dataset=label_dataset: None)

            results = run_parallel(eval_funcs, n_cores)

            if is_causal:
                # Average the per-label results of every dataset.
                real_results = []
                for i in range(n_datasets):
                    if result_key not in list_dict_metrics[i]:
                        real_results.append(np.mean([results[i + j] for j in range(0, len(results), n_datasets)]))
                    else:
                        real_results.append(list_dict_metrics[i].get(result_key))
                results = real_results[:]

            for i, result in enumerate(results):
                if result is not None:
                    list_dict_metrics[i][result_key] = result
                    list_dict_metrics[i] = unnest_dictionary(list_dict_metrics[i])

            # For the real dataset the "real to fake" value is a copy of the
            # corresponding existing entry.
            for metric in list(dict_real_metrics.keys()):
                if metric.startswith(metric_name + "_fake_to_fake") or (metric.startswith(metric_name) and is_causal):
                    dict_real_metrics[metric.replace("_fake_to_fake", "_real_to_fake")] = dict_real_metrics[metric]

        return dict_real_metrics, list_dict_metrics

    def calculate_dataset_info(self, eval1, eval2, folder, n_cores, dataset_range, include_eval=True):
        """Save the datasets to ``<folder>/temp`` in parallel and reload them.

        Every pipeline dataset inside ``dataset_range`` is saved to
        ``<folder>/temp/<absolute index>`` unless that path already exists.
        Afterwards every pipeline dataset for which such a path exists is
        replaced by the reloaded copy, keeping its original ``save_path``.

        Args:
            eval1: First evaluation dataset.
            eval2: Second evaluation dataset.
            folder: Working folder; a ``temp`` sub folder is created in it.
            n_cores: Forwarded to ``run_parallel``.
            dataset_range: ``(start, stop)`` slice bounds into
                ``self.pipeline_datasets``.
            include_eval: Also save and reload ``eval1`` / ``eval2``.

        Returns:
            Tuple ``(eval1, eval2)``: the reloaded copies when
            ``include_eval`` is true, otherwise the objects passed in.
        """
        temp_folder = os.path.join(folder, "temp")
        os.makedirs(temp_folder, exist_ok=True)
        start, stop = dataset_range[0], dataset_range[1]

        all_calculate_funcs = []
        for offset, (_, dataset) in enumerate(self.pipeline_datasets[start:stop]):
            path = os.path.join(temp_folder, str(offset + start))
            if not os.path.exists(path):
                all_calculate_funcs.append(lambda dataset=dataset, path=path: dataset.save(path, False, True))

        if include_eval:
            for name, eval_dataset in (("eval1", eval1), ("eval2", eval2)):
                path = os.path.join(temp_folder, name)
                if not os.path.exists(path):
                    all_calculate_funcs.append(lambda eval_dataset=eval_dataset, path=path: eval_dataset.save(path, False, True))

        all_calculate_funcs.append(self.evaluation_dataset.calculate_all)

        run_parallel(all_calculate_funcs, n_cores=n_cores)

        # The temp entries are named after the absolute dataset index, so the
        # same index must be used for the existence check, the load and the
        # slot that is replaced.
        for i, (pipeline, dataset) in enumerate(self.pipeline_datasets):
            path = os.path.join(temp_folder, str(i))
            if os.path.exists(path):
                new_dataset = Dataset.load(path)
                new_dataset.save_path = dataset.save_path
                self.pipeline_datasets[i] = (pipeline, new_dataset)

        if include_eval:
            eval1 = Dataset.load(os.path.join(temp_folder, "eval1"))
            eval2 = Dataset.load(os.path.join(temp_folder, "eval2"))

        return eval1, eval2

    def calculate_evaluation_metrics(self, eval1, eval2, folder, first_phase_n_cores, second_phase_n_cores, exclude, dataset_range, n_runs=5, include_eval=True,
                                     reload_=True, force=None, **kwargs):
        """Run the evaluators for the real data and the selected datasets.

        The evaluators write their results to ``<folder>/results`` through the
        ``save_path`` they are given; this method returns nothing.

        Args:
            eval1: First evaluation dataset for the real-data evaluator.
            eval2: Second evaluation dataset for the real-data evaluator.
            folder: Working folder containing the ``results`` sub folder.
            first_phase_n_cores: Not used by this method; kept for callers.
            second_phase_n_cores: Number of cores for ``run_parallel``. With a
                value of 1 (or a single job) the jobs run sequentially.
            exclude: Names excluded from ``calculate_all``.
            dataset_range: ``(start, stop)`` slice bounds into
                ``self.pipeline_datasets``.
            n_runs: Forwarded to ``calculate_all``.
            include_eval: Also evaluate the real data (``eval1`` vs ``eval2``).
            reload_: When false, previously stored results are not handed to
                the evaluators.
            force: Forwarded to ``calculate_all``; ``None`` means an empty list.
            **kwargs: Forwarded to every ``calculate_all`` call.
        """
        force = [] if force is None else force

        list_dict_metr, dict_eval_metr = self.load_all_results(folder)
        if not reload_:
            dict_eval_metr = None
            list_dict_metr = [None for _ in range(len(self.pipeline_datasets))]

        start = dataset_range[0]
        real_evaluator = SupervisedEvaluator(eval1, eval2)
        all_funcs = []
        if include_eval:
            eval_metrics_path = os.path.join(folder, "results", "dict_real_metrics.json")
            # First job: only the "finetune*" members that are not excluded,
            # run once and keeping the temp folder. Second job: the regular
            # evaluation.
            models_unnecessary_functions = [func for func in dir(real_evaluator)
                                            if not func.startswith("finetune") or func in exclude]
            all_funcs.append(lambda save_path=eval_metrics_path: real_evaluator.calculate_all(
                exclude=models_unnecessary_functions, n_runs=1, persist_temp_folder=True, save_path=save_path,
                results=dict_eval_metr, **kwargs))
            all_funcs.append(lambda save_path=eval_metrics_path: real_evaluator.calculate_all(
                exclude=exclude, n_runs=n_runs, save_path=save_path, results=dict_eval_metr, force=force, **kwargs))

        for i, (_, dataset) in enumerate(self.pipeline_datasets[start:dataset_range[1]]):
            list_metrics_path = os.path.join(folder, "results", f"list_dict_metrics_{i + start}-{i + start + 1}.json")
            if self.evaluation_dataset is None:
                evaluator = UnsupervisedEvaluator(dataset)
            elif len(self.evaluation_dataset.get_labels()) <= 1:
                evaluator = SemiSupervisedEvaluator(dataset, self.evaluation_dataset)
            else:
                evaluator = SupervisedEvaluator(dataset, self.evaluation_dataset)

            all_funcs.append(lambda evaluator=evaluator, save_path=list_metrics_path, i=i: evaluator.calculate_all(
                exclude=exclude, n_runs=n_runs, save_path=save_path, results=list_dict_metr[i + start], force=force, **kwargs))

        if second_phase_n_cores > 1 and len(all_funcs) > 1:
            run_parallel(all_funcs, n_cores=second_phase_n_cores)
        else:
            for func in all_funcs:
                func()
    
    def generate_eval_file(self, list_dict_metrics, dict_real_metrics, folder, title, template_folder, run_vector_plots=False, run_word_distrib_plots=False):
        """Create all plots and render ``<folder>/evaluation.html``.

        Args:
            list_dict_metrics: One metric dictionary per pipeline dataset.
                Missing metric keys are filled in with 0 in place.
            dict_real_metrics: Metric dictionary of the real dataset.
            folder: Output folder for the plots and the html file.
            title: Title shown in the report.
            template_folder: Folder containing ``evaluation_template.html``.
            run_vector_plots: Also create the vector space plots.
            run_word_distrib_plots: Also create the word distribution plots.

        Raises:
            KeyError: If the target metric ``self.y`` is missing from the
                metric dictionaries.
        """
        def is_plain_number(value):
            # numpy regards booleans and strings as scalars too; those cannot
            # be placed on a numeric axis.
            return np.isscalar(value) and not isinstance(value, (bool, str))

        def register(component, known, entries, scores, prompt_key, score):
            # Record a generator/converter once for the template and keep a
            # running [sum of scores, count] per known component.
            if component in known:
                position = known.index(component)
                scores[position][0] += score
                scores[position][1] += 1
                return
            settings = component.generate_settings()
            if settings.get(prompt_key) is None:
                settings[prompt_key] = []
            settings_no_prompt = settings.copy()
            del settings_no_prompt[prompt_key]
            del settings_no_prompt["class"]
            entries.append({
                "id": len(known),
                "class": settings["class"],
                "kwargs": settings_no_prompt,
                "prompts": settings[prompt_key],
            })
            known.append(component)
            scores[len(known) - 1] = [score, 1]

        # Give every dataset the same set of metric keys.
        all_keys_in_metrics = set()
        for dict_metrics in list_dict_metrics:
            all_keys_in_metrics.update(dict_metrics.keys())
        for dict_metrics in list_dict_metrics:
            for key in all_keys_in_metrics:
                dict_metrics.setdefault(key, 0)

        y = [dict_metrics[self.y] for dict_metrics in list_dict_metrics]
        y_real = dict_real_metrics[self.y]

        metric_names = sorted(set(dict_real_metrics) | all_keys_in_metrics)

        # The plot helpers write straight to these folders.
        os.makedirs(os.path.join(folder, "plots", "metrics"), exist_ok=True)

        logger.info("Generating metric plots")
        correlation_plot_loc = os.path.join("plots", "correlation_metrics.png")
        self.correlation_plot(list_dict_metrics, dict_real_metrics, os.path.join(folder, correlation_plot_loc))

        metrics = []
        metrics_good = []
        for metric in metric_names:
            x = [dict_metrics.get(metric, 0) for dict_metrics in list_dict_metrics]
            x_real = dict_real_metrics.get(metric, 0)

            if is_plain_number(x_real) and all(is_plain_number(x_i) for x_i in x):
                plot_loc = os.path.join("plots", "metrics", metric + ".png")
                self.generate_metric_plot(x, y, x_real, y_real, os.path.join(folder, plot_loc), metric)
                metrics.append({
                    "name": metric,
                    "plot": plot_loc
                })
                metrics_good.append(metric)
            else:
                logger.debug(f"Skipping metric {metric} because it is not a scalar")

        generators_list = []
        generators_jinja = []
        average_score_per_generator = dict()

        converters_list = []
        converters_jinja = []
        average_score_per_converter = dict()

        logger.info("Generating generator and converter jinja templates.")
        for i, (pipeline, _) in enumerate(self.pipeline_datasets):
            register(pipeline.generator, generators_list, generators_jinja, average_score_per_generator, "prompts", y[i])
            for converter in pipeline.converters:
                register(converter, converters_list, converters_jinja, average_score_per_converter, "prompts_per_level", y[i])

        logger.info("Generating general plots")

        average_score_per_generator = {key: val[0] / val[1] for key, val in average_score_per_generator.items()}
        # Mean score of the datasets that did NOT use the converter.
        average_score_not_converter = {key: 1 / max(len(y) - val[1], 1) * (np.sum(y) - val[0]) for key, val in average_score_per_converter.items()}
        average_score_per_converter = {key: val[0] / val[1] for key, val in average_score_per_converter.items()}
        gain_per_converter = {key: average_score_per_converter[key] - average_score_not_converter[key] for key in average_score_per_converter}

        generation_plot = os.path.join("plots", "generators.png")
        self.generate_generation_plot(y_real, list(average_score_per_generator.values()), list(average_score_per_generator.keys()),
                                      os.path.join(folder, generation_plot))

        conversion_plot = os.path.join("plots", "converters.png")
        self.generate_converter_plot(list(gain_per_converter.values()), list(gain_per_converter.keys()),
                                     os.path.join(folder, conversion_plot))

        # Best performing dataset first.
        sorted_indices = np.argsort(y)[::-1]

        all_locations = [
            self.pipeline_datasets[i][1].save_path for i in range(len(sorted_indices))
        ]
        existing_locations = [loc for loc in all_locations if loc is not None]
        # Character-wise common prefix ("" without locations), cut back to the
        # last "/" so that only whole folders are stripped.
        common_prefix = os.path.commonprefix(existing_locations)
        common_prefix = common_prefix[:common_prefix.rfind("/") + 1]

        all_locations = [
            loc if loc is not None else "-" for loc in all_locations
        ]

        logger.debug(f"Removing common prefix {common_prefix} from locations")

        datasets = []

        datasets.append({
            "id": "real",
            "y": y_real,
            "size": self.evaluation_dataset.size(),
            "generator": "-",
            "converters": "-",
            "location": "-",
            "metrics": {
                key: dict_real_metrics.get(key, 0) for key in metrics_good
            }
        })

        logger.info("Generating word distribution plots")
        word_distribution_plots = dict()

        if run_word_distrib_plots:
            for label in self.evaluation_dataset.get_labels():
                loc_base = os.path.join("plots", label, "word_distributions", "real.png")
                loc = os.path.join(folder, loc_base)
                os.makedirs(os.path.dirname(loc), exist_ok=True)
                self.word_distribution_plot("real", label, self.evaluation_dataset.get_word_distribution(label), loc)
                word_distribution_plots[label] = [loc_base]

        for i, index in enumerate(sorted_indices):
            pipeline, dataset = self.pipeline_datasets[index]
            logger.debug(f"Generating plots for dataset {str(dataset)}")
            metrics_index = {
                key: list_dict_metrics[index].get(key, 0) for key in metrics_good
            }
            datasets.append({
                "id": i,
                "y": list_dict_metrics[index][self.y],
                "size": dataset.size(),
                "generator": generators_list.index(pipeline.generator),
                "converters": [converters_list.index(converter) for converter in pipeline.converters],
                "location": all_locations[index][len(common_prefix):],
                "metrics": metrics_index
            })
            if run_word_distrib_plots:
                for label in self.evaluation_dataset.get_labels():
                    loc_base = os.path.join("plots", label, "word_distributions", f"{i}.png")
                    loc = os.path.join(folder, loc_base)
                    if label in dataset.get_labels():
                        self.word_distribution_plot(i, label, dataset.get_word_distribution(label),
                                                    loc, self.evaluation_dataset.get_word_distribution(label))
                        word_distribution_plots[label].append(loc_base)

        info = {
            "title": title,
            "date": datetime.now().strftime("%d/%m/%Y, %H:%M:%S"),
            "folder": common_prefix,
            "n_datasets": len(self.pipeline_datasets),
        }

        real = {
            "y": y_real,
            "size": self.evaluation_dataset.size()
        }

        if run_vector_plots:
            logger.info("Generating vector plots")
            vector_plots_locations = self.generate_vector_plot(folder, sorted_indices)
        else:
            vector_plots_locations = dict()

        env = Environment(
            loader=FileSystemLoader(template_folder),
        )

        logger.info("Generating final evaluation html file")
        template = env.get_template("evaluation_template.html")
        rendered = template.render(info=info, datasets=datasets, generators=generators_jinja, converters=converters_jinja,
                                   metrics=metrics, real=real, generator_plot=generation_plot, conversion_plot=conversion_plot, k=3,
                                   k_word=2, k_vector=3, word_distribution_plots=word_distribution_plots, vector_plots_locations=vector_plots_locations,
                                   metric_correlation_plot=correlation_plot_loc, metric_names=metrics_good)
        with open(os.path.join(folder, "evaluation.html"), "w", encoding="utf-8") as f:
            f.write(rendered)

    def load_all_results(self, folder):
        """Load every stored metric file from ``<folder>/results``.

        ``dict_real_metrics.json`` holds the metrics of the real dataset.
        ``list_dict_metrics_0-<n-1>.json`` (``n`` = number of pipeline
        datasets) holds the metrics of all datasets at once. Any other
        ``list_dict_metrics_<start>-...`` file holds one dictionary, or a list
        of dictionaries, placed at position ``<start>`` onwards; these partial
        files override the combined file.

        Args:
            folder: Working folder containing the ``results`` sub folder.

        Returns:
            Tuple ``(list_dict_metrics, dict_real_metrics)``. Datasets without
            stored metrics are ``None`` in the list; ``dict_real_metrics`` is
            ``None`` when its file does not exist.
        """
        results_folder = os.path.join(folder, "results")
        n_datasets = len(self.pipeline_datasets)
        list_dict_metrics = [None for _ in range(n_datasets)]
        dict_real_metrics = None

        # Nothing stored yet: report "no results" instead of crashing in listdir.
        if not os.path.isdir(results_folder):
            return list_dict_metrics, dict_real_metrics

        real_metrics_path = os.path.join(results_folder, "dict_real_metrics.json")
        if os.path.exists(real_metrics_path):
            with open(real_metrics_path, "r") as f:
                dict_real_metrics = json.load(f)

        all_result_path = f"list_dict_metrics_0-{n_datasets - 1}.json"
        if os.path.exists(os.path.join(results_folder, all_result_path)):
            with open(os.path.join(results_folder, all_result_path), "r") as f:
                list_dict_metrics = json.load(f)
            if isinstance(list_dict_metrics, dict):
                list_dict_metrics = [list_dict_metrics]

        prefix = "list_dict_metrics_"
        # Sorted so that the outcome does not depend on directory order.
        for file in sorted(os.listdir(results_folder)):
            if not file.startswith(prefix) or file == all_result_path:
                continue

            try:
                start = int(file[len(prefix):].partition("-")[0])
            except ValueError:
                logger.warning(f"Ignoring result file with unexpected name {file}")
                continue

            with open(os.path.join(results_folder, file), "r") as f:
                partial_metrics = json.load(f)
            if isinstance(partial_metrics, dict):
                partial_metrics = [partial_metrics]

            for offset, dataset_metrics in enumerate(partial_metrics):
                position = start + offset
                if position >= len(list_dict_metrics):
                    # E.g. a stale file written for a longer dataset list.
                    logger.warning(f"Result file {file} contains entries beyond the last dataset, ignoring them")
                    break
                list_dict_metrics[position] = dataset_metrics

        return list_dict_metrics, dict_real_metrics
    
    def store_all_results(self, folder, list_dict_metrics, dict_real_metrics):
        """Replace the contents of ``<folder>/results`` with the given metrics.

        Writes ``dict_real_metrics.json`` and
        ``list_dict_metrics_0-<n-1>.json`` (``n`` = ``len(list_dict_metrics)``).
        Every other file in the results folder is removed.

        Args:
            folder: Working folder; the ``results`` sub folder is recreated.
            list_dict_metrics: One metric dictionary per dataset.
            dict_real_metrics: Metric dictionary of the real dataset.

        Raises:
            TypeError: If the metrics are not JSON serialisable. In that case
                the existing results folder is left untouched.
        """
        results_folder = os.path.join(folder, "results")

        # Serialise first: a failure must not cost us the previous results.
        real_payload = json.dumps(dict_real_metrics, indent=2, sort_keys=True)
        list_payload = json.dumps(list_dict_metrics, indent=2, sort_keys=True)

        if os.path.isdir(results_folder):
            shutil.rmtree(results_folder)
        os.makedirs(results_folder, exist_ok=True)

        with open(os.path.join(results_folder, "dict_real_metrics.json"), "w") as f:
            f.write(real_payload)

        with open(os.path.join(results_folder, f"list_dict_metrics_0-{len(list_dict_metrics) - 1}.json"), "w") as f:
            f.write(list_payload)
    
    def check_results_present(self, folder, dataset_range=None, include_eval=True):
        """Tell whether stored results exist for the requested datasets.

        Args:
            folder: Working folder containing the ``results`` sub folder.
            dataset_range: Arguments for ``range`` selecting the dataset
                indices to check; defaults to all pipeline datasets. Indices
                beyond the loaded result list are not checked.
            include_eval: Also require the real-data metrics to be present.

        Returns:
            ``False`` if the real-data metrics are required but missing, or if
            any checked dataset has no stored metrics; ``True`` otherwise.
        """
        stored_metrics, real_metrics = self.load_all_results(folder)
        if include_eval and real_metrics is None:
            return False

        if dataset_range is None:
            dataset_range = (0, len(self.pipeline_datasets))

        n_stored = len(stored_metrics)
        missing = any(
            stored_metrics[position] is None
            for position in range(*dataset_range)
            if position < n_stored
        )
        return not missing

    def run(self, template_folder, folder="./metaevaluation/", n_runs=5, first_phase_cores=1, second_phase_cores=1,
            exclude=None, force=None, title="Emotions", dataset_range=None, include_eval=True, reload_=True, real_metrics_loc=None, run_vector_plots=False, run_word_distrib_plots=False,
            **kwargs):
        """Run the complete meta evaluation and write the html report.

        Steps: create the output folders, save the pipelines, calculate the
        metrics of the datasets in ``dataset_range`` and, once results for the
        real data and all datasets are available, calculate the final
        metrics, store everything and generate the evaluation file.

        Args:
            template_folder: Folder containing the html template.
            folder: Output folder.
            n_runs: Forwarded to the evaluators.
            first_phase_cores: Forwarded to ``calculate_evaluation_metrics``.
            second_phase_cores: Number of cores for the metric calculation.
            exclude: Names excluded from the evaluation; ``None`` means none.
            force: Forwarded to the evaluators; ``None`` means an empty list.
            title: Title of the report.
            dataset_range: ``(start, stop)`` of the datasets to evaluate;
                defaults to all of them.
            include_eval: Also evaluate the real data.
            reload_: Reuse stored results.
            real_metrics_loc: Optional json file with precomputed real-data
                metrics; copied into the results folder.
            run_vector_plots: Create the vector space plots in the report.
            run_word_distrib_plots: Create the word distribution plots in the
                report.
            **kwargs: Forwarded to the evaluators.
        """
        exclude = [] if exclude is None else exclude
        force = [] if force is None else force

        set_seed(42)
        if dataset_range is None:
            dataset_range = (0, len(self.pipeline_datasets))

        os.makedirs(os.path.join(folder, "plots", "general", "word_distributions"), exist_ok=True)
        for label in self.evaluation_dataset.get_labels():
            os.makedirs(os.path.join(folder, "plots", label, "word_distributions"), exist_ok=True)
        os.makedirs(os.path.join(folder, "plots", "metrics"), exist_ok=True)
        os.makedirs(os.path.join(folder, "results"), exist_ok=True)
        os.makedirs(os.path.join(folder, "pipelines"), exist_ok=True)

        for i, (pipeline, _) in enumerate(self.pipeline_datasets):
            pipeline.save(os.path.join(folder, "pipelines", f"pipeline_{i}.json"))

        if real_metrics_loc is not None:
            with open(real_metrics_loc, "r") as f:
                dict_real_metrics = json.load(f)

            with open(os.path.join(folder, "results", "dict_real_metrics.json"), "w") as f:
                json.dump(dict_real_metrics, f, indent=2, sort_keys=True)

        # Split the real data in two halves that are evaluated against each other.
        sentences, labels = self.evaluation_dataset.get_all()
        size = len(sentences) // 2
        eval1 = Dataset(sentences[:size], labels[:size])
        eval2 = Dataset(sentences[size:], labels[size:])
        logger.info("Starting calculation evaluators")
        self.calculate_evaluation_metrics(eval1, eval2, folder, first_phase_cores, second_phase_cores, exclude, dataset_range, include_eval=include_eval, reload_=reload_, force=force, n_runs=n_runs, **kwargs)
        logger.success("Done calculating metrics.")

        list_dict_metrics, dict_real_metrics = self.load_all_results(folder)

        # The report needs the results of the real data and of every dataset.
        if dict_real_metrics is not None and all(metrics is not None for metrics in list_dict_metrics):
            logger.info("Calculation final metrics...")
            dict_real_metrics, list_dict_metrics = self.calculate_global_evaluation_metrics(dict_real_metrics, list_dict_metrics,
                                                                                            second_phase_cores, reload_=reload_)
            self.store_all_results(folder, list_dict_metrics, dict_real_metrics)
            logger.success("Done calculating final metrics.")
            logger.info("Generating evaluation file")
            # Pass the plot switches on; otherwise they have no effect.
            self.generate_eval_file(list_dict_metrics, dict_real_metrics, folder, title, template_folder,
                                    run_vector_plots=run_vector_plots, run_word_distrib_plots=run_word_distrib_plots)
            logger.success("Done generating evaluation file.")
            logger.info("Removing temporary files")
            if os.path.exists(os.path.join(folder, "temp")):
                shutil.rmtree(os.path.join(folder, "temp"))

    @staticmethod
    def load_from_multipipeline(file, evaluation_dataset, y=None):
        """Build a ``MetaEvaluator`` from a stored multi pipeline.

        Args:
            file: Passed to ``MultiPipeline.load_datasets``.
            evaluation_dataset: Reference dataset for the evaluator.
            y: Target metric name; ``None`` keeps the constructor default.

        Returns:
            The new ``MetaEvaluator``.
        """
        loaded_datasets = MultiPipeline.load_datasets(file)
        # Only pass `y` on when it was given so the constructor default applies.
        target_metric = () if y is None else (y,)
        return MetaEvaluator(loaded_datasets, evaluation_dataset, *target_metric)
