import matplotlib.pyplot as plt
import numpy as np
import os 
import pickle
import torch
import json
import pathlib

from itertools import product
from config import metrics

class SaveData:
    """Persist the artefacts of one training experiment: parameters, curves, best epochs and plots.

    The result dictionaries are not constructor arguments because they only exist once
    training has run. Pass them to :meth:`save_results` (or to the individual ``save_*``
    methods) instead.

    Files written under ``path_save_dir``:
        - ``experiment_params.pkl``: constructor hyper-parameters (:meth:`save_experiment_params`).
        - ``plots/to_plot.pkl``: CPU copy of the curve dictionary (:meth:`save_to_plot`).
        - ``best_models/epochs_best_models.json``: best-model entries per stopper (:meth:`save_models`).
        - ``plots/<metric>/plot.png`` (non-raw data) or ``plots/<metric>/plot_<method>.png``
          (raw data): one figure per metric / validation method.
    """

    # Immutable defaults; a fresh list is built per instance so instances never share state.
    _DEFAULT_VAL_METHODS = ("first_window", "shift_windows", "pw_shift_windows", "pw_slide_windows")
    _DEFAULT_DATASETS_NAMES = ("supervised",)

    def __init__(self,
                 path_save_dir: str,
                 model_name: str = "linear",
                 data_type: str = "raw",
                 val_methods: "list | None" = None,
                 datasets_names: "list | None" = None,
                 num_generations: int = 10,
                 num_parents_mating: int = 5,
                 num_solutions: int = 5,
                 sample_size: int = 8,
                 train_ratio: float = 0.7,
                 patience: int = 10,
                 delta: float = 0.01,
                 seed: int = 42,
                 testing: bool = False):
        """Store the experiment configuration; no file is touched here.

        Args:
            path_save_dir: root directory for every artefact written by this object.
            model_name: used in plot titles and saved with the parameters.
            data_type: ``"raw"`` selects per-validation-method result keys
                (``<set>_<metric>_box_<method>``); any other value selects ``<set>_<metric>_box``.
            val_methods: validation methods; defaults to a fresh copy of ``_DEFAULT_VAL_METHODS``.
            datasets_names: dataset names; defaults to a fresh copy of ``_DEFAULT_DATASETS_NAMES``.
            num_generations, num_parents_mating, num_solutions, sample_size, train_ratio,
            patience, delta, seed, testing: stored as-is and written by
                :meth:`save_experiment_params`.
        """
        self.model_name = model_name
        self.data_type = data_type
        # `None` sentinels replace the former shared mutable list defaults.
        self.val_methods = val_methods if val_methods is not None else list(self._DEFAULT_VAL_METHODS)
        self.datasets_names = (datasets_names if datasets_names is not None
                               else list(self._DEFAULT_DATASETS_NAMES))
        # 128 for raw inputs, 0 otherwise (presumably the window length; not used elsewhere in this class).
        self.window_size = 128 if data_type == "raw" else 0
        self.train_ratio = train_ratio
        self.num_generations = num_generations
        self.num_parents_mating = num_parents_mating
        self.num_solutions = num_solutions
        self.sample_size = sample_size
        self.patience = patience
        self.delta = delta
        self.seed = seed
        self.testing = testing
        self.path_save_dir = path_save_dir

    def create_save_dir(self):
        """Create ``path_save_dir`` (and parents) if it does not exist yet."""
        os.makedirs(self.path_save_dir, exist_ok=True)

    def res_to_cpu(self, res_dict_plots):
        """Replace metric-box tensors in `res_dict_plots` by Python scalars so the dict can be pickled.

        The dict is modified in place and also returned. Each converted value is a list of
        per-evaluation lists. Empty values are left untouched. Entries that are already
        plain numbers are kept, so calling this twice is harmless.

        Raises:
            KeyError: if an expected ``train_*``/``val_*`` metric-box key is missing.
        """
        if self.data_type == "raw":
            metric_boxes = [f"{eval_set}_{metric}_box_{method}"
                            for eval_set, metric, method in product(["train", "val"], metrics, self.val_methods)]
        else:
            metric_boxes = [f"{eval_set}_{metric}_box"
                            for eval_set, metric in product(["train", "val"], metrics)]

        for metric_box in metric_boxes:
            epoch_lists = res_dict_plots[metric_box]
            if not epoch_lists:
                continue
            res_dict_plots[metric_box] = [
                [value.item() if torch.is_tensor(value) else value for value in epoch_list]
                for epoch_list in epoch_lists
            ]
        return res_dict_plots

    def save_experiment_params(self):
        """Pickle the constructor hyper-parameters to ``<path_save_dir>/experiment_params.pkl``."""
        params = {
            "model_name": self.model_name,
            "data_type": self.data_type,
            "val_methods": self.val_methods,
            "datasets_names": self.datasets_names,
            "train_ratio": self.train_ratio,
            "num_generations": self.num_generations,
            "num_parents_mating": self.num_parents_mating,
            "num_solutions": self.num_solutions,
            "sample_size": self.sample_size,
            "patience": self.patience,
            "delta": self.delta,
            "seed": self.seed,
            "testing": self.testing,
        }
        with open(os.path.join(self.path_save_dir, "experiment_params.pkl"), "wb") as f:
            pickle.dump(params, f)

    def save_to_plot(self, res_dict_cpu):
        """Pickle the CPU curve dictionary to ``<path_save_dir>/plots/to_plot.pkl``."""
        plots_dir = os.path.join(self.path_save_dir, "plots")
        os.makedirs(plots_dir, exist_ok=True)
        with open(os.path.join(plots_dir, "to_plot.pkl"), "wb") as f:
            pickle.dump(res_dict_cpu, f)

    def _plot_metric_evolution(self, train_box, val_box, best_epoch, metric, out_file, box_label):
        """Draw validation boxes, mean train/val markers and the early-stopping line into `out_file`.

        `train_box` / `val_box` are lists of per-evaluation lists of run values. The original
        code notes they become (num_runs, num_evaluations) once transposed. X positions are
        1-based evaluation indices derived from the data, so curves shorter than
        `num_generations` still plot.
        """
        if not val_box:
            print(f"No validation {metric} values to plot; skipping {out_file}")
            return

        val_runs = np.array(val_box).T  # (num_runs, num_evaluations): one box per column
        val_mean = np.mean(val_runs, axis=0)
        val_epochs = np.arange(1, val_mean.shape[0] + 1)

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            box = ax.boxplot(val_runs, positions=val_epochs, patch_artist=True, showfliers=False, zorder=1)
            handles = [box["boxes"][0]]
            labels = [box_label]

            if train_box:
                train_mean = np.mean(np.array(train_box).T, axis=0)
                train_epochs = np.arange(1, train_mean.shape[0] + 1)
                handles.append(ax.scatter(train_epochs, train_mean, color="red", edgecolors="black",
                                          linewidths=0.5, s=30, zorder=3))
                labels.append(f"Mean {metric} on Train")

            handles.append(ax.scatter(val_epochs, val_mean, color="white", edgecolors="black",
                                      linewidths=0.5, s=30, zorder=3))
            labels.append(f"Mean {metric} on Val")

            handles.append(ax.axvline(x=best_epoch, color="orange", linestyle=":", linewidth=3))
            labels.append(f"Early Stopping on {metric}")

            ax.set_title(self.model_name + " model")
            ax.set_xlabel("Evaluations")
            ax.set_ylabel(f"{metric} score")
            ax.set_ylim(0, 1)
            ax.grid(True, axis="y", linestyle=":", linewidth=1, color="gray")
            ax.legend(handles, labels, loc="upper right")

            fig.savefig(out_file, bbox_inches="tight")
        finally:
            # Always release the figure, even if plotting fails, so figures don't accumulate.
            plt.close(fig)

    def save_plot_feat(self, res_dict_plots_cpu, res_dict_models):
        """Write ``plots/<metric>/plot.png`` for every metric (non-raw data layout).

        Reads the ``train_<metric>_box`` and ``val_<metric>_box`` keys from `res_dict_plots_cpu`,
        and the ``best_model_val_<metric>_box`` key from `res_dict_models`.
        """
        plots_dir = os.path.join(self.path_save_dir, "plots")
        os.makedirs(plots_dir, exist_ok=True)
        for metric in metrics:
            metric_dir = os.path.join(plots_dir, metric)
            os.makedirs(metric_dir, exist_ok=True)
            self._plot_metric_evolution(
                train_box=res_dict_plots_cpu[f"train_{metric}_box"],
                val_box=res_dict_plots_cpu[f"val_{metric}_box"],
                best_epoch=res_dict_models[f"best_model_val_{metric}_box"],
                metric=metric,
                out_file=os.path.join(metric_dir, "plot.png"),
                box_label=f"{metric} score",
            )

    def save_plot_cnn(self, res_dict_plots_cpu, res_dict_models):
        """Write ``plots/<metric>/plot_<method>.png`` for every metric and validation method (raw data).

        Reads the ``train_<metric>_box_<method>`` and ``val_<metric>_box_<method>`` keys from
        `res_dict_plots_cpu`, and the ``best_model_val_<metric>_box_<method>`` key from
        `res_dict_models`.
        """
        plots_dir = os.path.join(self.path_save_dir, "plots")
        os.makedirs(plots_dir, exist_ok=True)
        for metric in metrics:
            metric_dir = os.path.join(plots_dir, metric)
            os.makedirs(metric_dir, exist_ok=True)
            for method in self.val_methods:
                self._plot_metric_evolution(
                    train_box=res_dict_plots_cpu[f"train_{metric}_box_{method}"],
                    val_box=res_dict_plots_cpu[f"val_{metric}_box_{method}"],
                    best_epoch=res_dict_models[f"best_model_val_{metric}_box_{method}"],
                    metric=metric,
                    out_file=os.path.join(metric_dir, f"plot_{method}.png"),
                    box_label=f"{metric} on Val",
                )

    def save_models(self, res_dict_models, stoppers):
        """Write ``best_models/epochs_best_models.json`` with one entry per stopper.

        For each key `strategy` of `stoppers`, the ``best_model_<strategy>`` entry of
        `res_dict_models` is copied. Values must be JSON-serialisable; otherwise `json.dump`
        raises TypeError.
        """
        save_path = os.path.join(self.path_save_dir, "best_models")
        os.makedirs(save_path, exist_ok=True)
        dict_best_models = {f"best_model_{strategy}": res_dict_models[f"best_model_{strategy}"]
                            for strategy in stoppers}
        with open(os.path.join(save_path, "epochs_best_models.json"), "w") as f:
            json.dump(dict_best_models, f)

    def save_results(self, res_dict_plots, res_dict_models, stoppers):
        """Save every artefact of the experiment.

        Steps: create the directory, convert the curves to CPU scalars (in place), then save
        the parameters, the curve pickle, the best-model JSON and the plots.
        """
        self.create_save_dir()
        res_dict_plots_cpu = self.res_to_cpu(res_dict_plots)
        self.save_experiment_params()
        self.save_to_plot(res_dict_plots_cpu)
        self.save_models(res_dict_models=res_dict_models, stoppers=stoppers)
        if self.data_type == "raw":
            self.save_plot_cnn(res_dict_plots_cpu=res_dict_plots_cpu, res_dict_models=res_dict_models)
        else:
            self.save_plot_feat(res_dict_plots_cpu=res_dict_plots_cpu, res_dict_models=res_dict_models)

    def from_pkl_to_plot(self, res_dict_models=None):
        """Re-draw the per-metric plots from the curves previously pickled by :meth:`save_to_plot`.

        Args:
            res_dict_models: best-epoch dictionary, as given to :meth:`save_results`. When None,
                it is read back from ``best_models/epochs_best_models.json`` (written by
                :meth:`save_models`). That file must then contain the
                ``best_model_val_<metric>_box[_<method>]`` keys the plotting methods look up.
        """
        plots_pkl = os.path.join(self.path_save_dir, "plots", "to_plot.pkl")
        legacy_pkl = os.path.join(self.path_save_dir, "to_plot.pkl")
        # `save_to_plot` writes under plots/. Fall back to the top-level path only when it
        # is the one that exists.
        use_legacy = not os.path.exists(plots_pkl) and os.path.exists(legacy_pkl)
        pkl_path = legacy_pkl if use_legacy else plots_pkl
        # NOTE: pickle can execute arbitrary code on load; only use on result dirs you produced.
        with open(pkl_path, "rb") as f:
            res_dict = pickle.load(f)

        if res_dict_models is None:
            with open(os.path.join(self.path_save_dir, "best_models", "epochs_best_models.json")) as f:
                res_dict_models = json.load(f)

        if self.data_type == "raw":
            self.save_plot_cnn(res_dict, res_dict_models)
        else:
            self.save_plot_feat(res_dict, res_dict_models)

    def save_plot_analysis(self, res_dict_models, res_dict_plots_cpu, target_metric, loss_name=None):
        """Compare early-stopping choices on `target_metric`; one figure per metric folder in ``plots/``.

        For every validation method, the reference epoch is the one maximising the run-averaged
        `target_metric`. The figure shows per-run differences of `target_metric` between that
        epoch and two others, with one box per validation method:

        - left panel: the epoch stored for `metric` in `res_dict_models` (only drawn when
          `metric` differs from `target_metric`);
        - right panel: the argmin of the validation-loss curve.

        Args:
            res_dict_models: best-epoch dictionary, read with keys ``best_model_<metric>_box_<method>``.
            res_dict_plots_cpu: CPU curve dictionary (see :meth:`res_to_cpu`).
            target_metric: key prefix of the metric of interest, read as ``<target_metric>_box_<method>``.
            loss_name: suffix of the validation-loss keys ``val_loss_<loss_name>_<method>``.
                Defaults to ``self.loss_name`` when that attribute has been set on the instance.

        Raises:
            ValueError: if no loss name is available.
        """
        if loss_name is None:
            # `loss_name` is not a constructor argument; honour callers that set it on the instance.
            loss_name = getattr(self, "loss_name", None)
        if loss_name is None:
            raise ValueError("save_plot_analysis needs the validation-loss name: "
                             "pass loss_name=... or set self.loss_name")

        save_path = os.path.join(self.path_save_dir, "plots")
        metrics_plotted = [name for name in os.listdir(save_path)
                           if os.path.isdir(os.path.join(save_path, name))]
        positions = np.arange(1, len(self.val_methods) + 1)

        for metric in metrics_plotted:
            # A fresh figure and fresh accumulators per metric. Each iteration saves and closes
            # its own file, so sharing them across iterations mixed columns from different
            # metrics and drew later metrics after the figure had been closed.
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
            box_plots_metric = []
            box_plots_valloss = []
            stop_val = []
            stop_metric = []
            for method in self.val_methods:
                # Per the original note: shape (num_epochs, num_runs).
                target_data = np.array(res_dict_plots_cpu[f"{target_metric}_box_{method}"])
                valloss_data = np.array(res_dict_plots_cpu[f"val_loss_{loss_name}_{method}"])
                print(f"target_data shape {target_data.shape}")

                best_target_metric = np.argmax(np.mean(target_data, axis=1))
                # NOTE(review): argmin over the flattened array; this is an epoch index only if
                # the loss curve is 1-D (one value per epoch). Confirm.
                best_valloss = np.argmin(valloss_data)
                print(f"best_target_metric {best_target_metric};  best_valloss {best_valloss}")

                stop_val.append(best_valloss)
                stop_metric.append(best_target_metric)

                box_plots_valloss.append(target_data[best_target_metric, :] - target_data[best_valloss, :])

                if metric != target_metric:
                    # NOTE(review): other methods read `best_model_val_<metric>_box_<method>`;
                    # confirm this key (without "val_") exists.
                    best_metric = res_dict_models[f"best_model_{metric}_box_{method}"]
                    box_plots_metric.append(target_data[best_target_metric, :] - target_data[best_metric, :])

            print(f"stop_val mean {np.mean(stop_val)}")
            print(f"stop_metric mean {np.mean(stop_metric)}")

            array_box_plots = np.array(box_plots_metric).T  # (num_runs, num_methods)
            array_box_plots_valloss = np.array(box_plots_valloss).T  # (num_runs, num_methods)

            print(f"array_box_plots_vallos shape {array_box_plots_valloss.shape}")
            print(f"mean array_box_plots_vallos fait main {np.mean(array_box_plots_valloss[:, 0])}")
            print(f"mean array_box_plots_vallos code {np.mean(array_box_plots_valloss, axis =0)}")

            if box_plots_metric:
                ax1.boxplot(array_box_plots, positions=positions, patch_artist=True, showfliers=False)
                ax1.set_xticks(positions, self.val_methods)
                # NOTE(review): per-run differences can be <= 0, which a log axis cannot show;
                # consider "symlog".
                ax1.set_yscale("log")
                ax1.set_xlabel("Methods")
                ax1.set_ylabel(f"{target_metric} - {metric}")
                ax1.set_title(f"{target_metric} - {metric} with {self.model_name}")

            ax2.boxplot(array_box_plots_valloss, positions=positions, patch_artist=True, showfliers=False)
            ax2.set_xticks(positions, self.val_methods)
            ax2.set_ylim(-0.2, 0.2)
            ax2.set_xlabel("Methods")
            ax2.set_ylabel(f"{target_metric} difference")
            ax2.set_title(f"{target_metric} of '{target_metric} best model - val_loss best model'")

            # Twin axes share ax2's x axis and use the same y-range to overlay the mean and zero line.
            ax3 = ax2.twinx()
            ax3.set_ylim(-0.2, 0.2)
            ax3.scatter(positions, np.mean(array_box_plots_valloss, axis=0), color='green',
                        edgecolors='black', linewidths=0.5, s=30, label="Mean")

            ax4 = ax2.twinx()
            ax4.set_ylim(-0.2, 0.2)
            ax4.axhline(y=0, color='black', linestyle='--', linewidth=1)

            out_file = f"{save_path}/{target_metric}_minus_{metric}_and_val_loss.png"
            print(f" Analysis plotted at : {out_file}")

            fig.savefig(out_file, bbox_inches='tight')
            plt.close(fig)


if __name__ == "__main__":
    # No command-line entry point: this module only defines SaveData, so running it directly does nothing.
    pass
    