import json
import os
import pathlib
import pickle
from itertools import product
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

from config import metrics

class SaveData:
    """Write the artefacts of one training experiment to ``path_save_dir``.

    Artefacts: the hyper-parameters (``experiment_params.pkl``), the per-epoch
    plotting data (``plots/to_plot.pkl``), the epochs of the best models
    (``best_models/epochs_best_models.json``) and matplotlib figures under
    ``plots/``. The result dictionaries are not constructor arguments; they are
    passed to the individual ``save_*`` methods (see ``save_results``).
    """

    # Defaults are tuples copied into fresh lists in ``__init__`` so instances
    # never share (and accidentally mutate) one default list object.
    _DEFAULT_VAL_METHODS = ("first_window", "shift_windows", "pw_shift_windows", "pw_slide_windows")
    _DEFAULT_DATASETS_NAMES = ("supervised",)

    def __init__(self,
                 path_save_dir: str,
                 model_name: str = "linear",
                 data_type: str = "raw",
                 val_methods: Optional[list] = None,
                 learning_rate: float = 0.01,
                 loss_name: str = "mse",
                 datasets_names: Optional[list] = None,
                 train_ratio: float = 0.7,
                 epochs: int = 100,
                 batch_size: int = 8,
                 patience: int = 10,
                 delta: float = 0.01,
                 seed: int = 42,
                 testing: bool = False):
        """Record the experiment configuration; nothing is written to disk here.

        Args:
            path_save_dir: Root directory for every file this class writes.
            model_name: Model identifier used in plot titles. When
                ``data_type == "raw"`` its third ``_``-separated field is parsed
                as an int and stored as ``window_size``.
            data_type: ``"raw"`` selects the per-validation-method layout
                (``save_plot_cnn``); any other value selects the single-series
                layout (``save_plot_feat``).
            val_methods: Validation strategies; ``None`` means
                ``["first_window", "shift_windows", "pw_shift_windows", "pw_slide_windows"]``.
            learning_rate: Recorded in the saved parameters and in plot titles.
            loss_name: Suffix of the loss keys read from the result dicts.
            datasets_names: Dataset names; ``None`` means ``["supervised"]``.
            train_ratio, epochs, batch_size, patience, delta, seed, testing:
                Recorded as-is in ``experiment_params.pkl``.

        Raises:
            IndexError: ``data_type == "raw"`` and ``model_name`` has fewer than
                three ``_``-separated fields (this includes the default "linear").
            ValueError: ``data_type == "raw"`` and that third field is not an int.
        """
        self.model_name = model_name
        self.data_type = data_type
        self.val_methods = list(self._DEFAULT_VAL_METHODS) if val_methods is None else val_methods
        self.datasets_names = list(self._DEFAULT_DATASETS_NAMES) if datasets_names is None else datasets_names
        # NOTE(review): assumes raw-data model names look like "<x>_<y>_<window>[_...]" — confirm with callers.
        self.window_size = int(self.model_name.split("_")[2]) if data_type == "raw" else 0
        self.learning_rate = learning_rate
        self.loss_name = loss_name
        self.train_ratio = train_ratio
        self.epochs = epochs
        self.batch_size = batch_size
        self.patience = patience
        self.delta = delta
        self.seed = seed
        self.testing = testing
        self.path_save_dir = path_save_dir

    def create_save_dir(self):
        """Create ``path_save_dir`` (and parents) if it does not exist yet."""
        os.makedirs(self.path_save_dir, exist_ok=True)

    def res_to_cpu(self, res_dict_plots):
        """Replace the metric tensors in ``res_dict_plots`` by their Python values.

        This allows pickling. For every ``<metric>_box_<method>`` key (raw data)
        or ``<metric>_box`` key (other data), with metrics from ``config.metrics``,
        each non-empty list of per-epoch lists is rebuilt with ``.cpu().item()``
        applied to every element. Elements without a ``.cpu`` method (already
        converted) are kept unchanged, so calling this twice is harmless.

        Returns:
            The same dict object, modified in place.
        """
        if self.data_type == "raw":
            metric_boxes = [f"{metric}_box_{method}" for metric, method in product(metrics, self.val_methods)]
        else:
            metric_boxes = [f"{metric}_box" for metric in metrics]

        for metric_box in metric_boxes:
            epoch_lists = res_dict_plots[metric_box]
            if not epoch_lists:  # nothing recorded for this metric
                continue
            res_dict_plots[metric_box] = [
                [value.cpu().item() if hasattr(value, "cpu") else value for value in epoch_list]
                for epoch_list in epoch_lists
            ]
        return res_dict_plots

    def save_experiment_params(self):
        """Pickle the constructor configuration to ``experiment_params.pkl``."""
        with open(f"{self.path_save_dir}/experiment_params.pkl", "wb") as f:
            pickle.dump(
                {
                    "model_name": self.model_name,
                    "data_type": self.data_type,
                    "val_methods": self.val_methods,
                    "learning_rate": self.learning_rate,
                    "loss_name": self.loss_name,
                    "datasets_names": self.datasets_names,
                    "train_ratio": self.train_ratio,
                    "epochs": self.epochs,
                    "batch_size": self.batch_size,
                    "patience": self.patience,
                    "delta": self.delta,
                    "seed": self.seed,
                    "testing": self.testing,
                },
                f,
            )

    def save_to_plot(self, res_dict_cpu):
        """Pickle the (CPU-converted) plotting data to ``plots/to_plot.pkl``."""
        os.makedirs(self.path_save_dir + "/plots", exist_ok=True)
        with open(f"{self.path_save_dir}/plots/to_plot.pkl", "wb") as f:
            pickle.dump(res_dict_cpu, f)

    def save_plot_feat(self, res_dict_cpu):
        """Plot train/val loss with the per-epoch AUC-PR distribution (non-raw data).

        Reads ``train_loss_<loss>``, ``val_loss_<loss>`` and ``AUC-PR_box`` from
        ``res_dict_cpu`` and writes ``plots/plot_<loss>.png``.
        """
        os.makedirs(self.path_save_dir + "/plots", exist_ok=True)
        train_list = res_dict_cpu[f"train_loss_{self.loss_name}"]
        val_list = res_dict_cpu[f"val_loss_{self.loss_name}"]
        # One inner list per epoch; transpose so boxplot draws one box per epoch.
        metric_array = np.array(res_dict_cpu["AUC-PR_box"]).T  # shape: (num_runs, num_epochs)
        epochs = np.arange(1, len(train_list) + 1)

        fig, ax1 = plt.subplots(figsize=(10, 6))
        ax1.set_xlabel("Epochs")
        ax1.set_ylabel("Loss")
        ax1.set_title(self.model_name + " model")
        train_line, = ax1.plot(epochs, train_list, label="Train Loss", color="red")
        val_line, = ax1.plot(epochs, val_list, label="Val Loss", color="blue")
        ax1.set_yscale("log")

        # Secondary axis: scores live in [0, 1], losses on a log scale.
        ax2 = ax1.twinx()
        ax2.set_ylabel("AUC-PR score")
        ax2.set_ylim(0, 1)
        ax2.grid(True, axis='y', linestyle=':', linewidth=1, color='gray')
        box2 = ax2.boxplot(metric_array, positions=epochs, patch_artist=True, showfliers=False)

        ax2.legend(
            [box2["boxes"][0], train_line, val_line],
            ["AUC-PR score", "Train Loss", "Val Loss"],
            loc="upper right"
        )

        fig.savefig(f"{self.path_save_dir}/plots/plot_{self.loss_name}.png", bbox_inches='tight')
        plt.close(fig)

    def save_plot_cnn(self, res_dict_plots_cpu, res_dict_models, metrics):
        """Plot losses, a metric's per-epoch distribution and both early-stop epochs.

        Writes one figure per (metric, validation method) to
        ``plots/<metric>/plot_<method>.png``.

        Args:
            res_dict_plots_cpu: Plot data with keys ``train_loss_<loss>``,
                ``val_loss_<loss>_<method>`` and ``<metric>_box_<method>``.
            res_dict_models: Best epochs with keys
                ``best_model_val_loss_<loss>_<method>`` and
                ``best_model_<metric>_box_<method>``.
            metrics: Metric names to plot.
        """
        os.makedirs(self.path_save_dir + "/plots", exist_ok=True)
        for metric in metrics:
            save_path = f"{self.path_save_dir}/plots/{metric}"
            os.makedirs(save_path, exist_ok=True)

            # The train loss does not depend on the validation method.
            train_list = res_dict_plots_cpu[f"train_loss_{self.loss_name}"]
            epochs = np.arange(1, len(train_list) + 1)

            for method in self.val_methods:
                val_list = res_dict_plots_cpu[f"val_loss_{self.loss_name}_{method}"]
                # One inner list per epoch; transpose so boxplot draws one box per epoch.
                metric_array = np.array(res_dict_plots_cpu[f"{metric}_box_{method}"]).T  # (num_runs, num_epochs)

                best_epoch_val = res_dict_models[f"best_model_val_loss_{self.loss_name}_{method}"]
                best_epoch_metric = res_dict_models[f"best_model_{metric}_box_{method}"]

                fig, ax1 = plt.subplots(figsize=(10, 6))
                ax1.set_xlabel("Epochs")
                ax1.set_ylabel("Loss")
                ax1.set_title(self.model_name + " model" + f" with lr = {self.learning_rate}")
                train_line, = ax1.plot(epochs, train_list, label="Train", color="red")
                val_line, = ax1.plot(epochs, val_list, label="Val", color="blue")
                ax1.set_yscale("log")
                stop_val_line = ax1.axvline(x=best_epoch_val, color='blue', linestyle='--', linewidth=2,
                                            label="Early Stopping on Val")
                stop_metric_line = ax1.axvline(x=best_epoch_metric, color='orange', linestyle=':', linewidth=3,
                                               label=f"Early Stopping on {metric} median")

                # Secondary axis for the metric scores, which live in [0, 1].
                ax2 = ax1.twinx()
                ax2.set_ylabel(f"{metric} score")
                ax2.set_ylim(0, 1)
                ax2.grid(True, axis='y', linestyle=':', linewidth=1, color='gray')
                box2 = ax2.boxplot(metric_array, positions=epochs, patch_artist=True, showfliers=False)

                ax2.legend(
                    [box2["boxes"][0], train_line, val_line, stop_val_line, stop_metric_line],
                    [f"{metric} score", "Train Loss", "Val Loss", "Early Stopping on Val", f"Early Stopping on {metric}"],
                    loc="upper right"
                )

                # Name the file after the method directly; re-splitting
                # "val_loss_<loss>_<method>" broke when loss_name contained "_".
                fig.savefig(f"{save_path}/plot_{method}.png", bbox_inches='tight')
                plt.close(fig)

    def save_models(self, res_dict_models, stoppers):
        """Write the best epoch of every stopping strategy to JSON.

        For each key ``s`` of ``stoppers`` the entry ``best_model_<s>`` is copied
        from ``res_dict_models`` into ``best_models/epochs_best_models.json``.
        Values must be JSON-serializable.
        """
        save_path = self.path_save_dir + "/best_models"
        os.makedirs(save_path, exist_ok=True)
        dict_best_models = {f"best_model_{strategy}": res_dict_models[f"best_model_{strategy}"] for strategy in stoppers}
        with open(save_path + "/epochs_best_models.json", "w") as f:
            json.dump(dict_best_models, f)

    def save_results(self, res_dict_plots, res_dict_models, stoppers):
        """Save parameters, plot data, best epochs and the plots for one experiment.

        Note: ``res_dict_plots`` is converted in place by ``res_to_cpu``.
        """
        self.create_save_dir()
        res_dict_plots_cpu = self.res_to_cpu(res_dict_plots)
        self.save_experiment_params()
        self.save_to_plot(res_dict_plots_cpu)
        self.save_models(res_dict_models=res_dict_models, stoppers=stoppers)
        if self.data_type == "raw":
            self.save_plot_cnn(res_dict_plots_cpu=res_dict_plots_cpu,
                               res_dict_models=res_dict_models,
                               metrics=metrics)
        else:
            self.save_plot_feat(res_dict_plots_cpu)

    def from_pkl_to_plot(self, res_dict_models=None):
        """Regenerate the plots from the data pickled by ``save_to_plot``.

        Args:
            res_dict_models: Best-epoch dict needed for raw-data plots. When
                ``None`` it is loaded from ``best_models/epochs_best_models.json``
                as written by ``save_models``.
        """
        pkl_path = f"{self.path_save_dir}/plots/to_plot.pkl"
        if not os.path.exists(pkl_path):
            # Fall back to a pickle placed directly in the save directory.
            pkl_path = f"{self.path_save_dir}/to_plot.pkl"
        # NOTE: pickle.load can execute arbitrary code; only load files this class wrote.
        with open(pkl_path, "rb") as f:
            res_dict = pickle.load(f)

        if self.data_type == "raw":
            if res_dict_models is None:
                # NOTE(review): assumes the stopper names saved by save_models cover
                # the keys save_plot_cnn reads (val_loss_<loss>_<method>, <metric>_box_<method>).
                with open(f"{self.path_save_dir}/best_models/epochs_best_models.json") as f:
                    res_dict_models = json.load(f)
            self.save_plot_cnn(res_dict, res_dict_models, metrics)
        else:
            self.save_plot_feat(res_dict)

    def _analysis_differences(self, res_dict_models, res_dict_plots_cpu, target_metric, metric):
        """Compute, per validation method, the score gaps plotted by ``save_plot_analysis``.

        ``<target_metric>_box_<method>`` holds one row per epoch and one column
        per run. Best epochs are 1-based and rows 0-based, so every best epoch is
        shifted by -1; a stored epoch of 0 (no early stop) becomes row -1, the last epoch.

        Returns:
            ``(metric_diffs, valloss_diffs, stop_val, stop_metric)``:
            ``metric_diffs`` — per method, target score at the target-metric best
            epoch minus target score at the ``metric`` best epoch (empty when
            ``metric == target_metric``); ``valloss_diffs`` — same against the
            val-loss best epoch; ``stop_val``/``stop_metric`` — the 0-based rows used.
        """
        metric_diffs, valloss_diffs, stop_val, stop_metric = [], [], [], []
        for method in self.val_methods:
            target_data = np.array(res_dict_plots_cpu[f"{target_metric}_box_{method}"])  # (num_epochs, num_runs)

            best_target_metric = res_dict_models[f"best_model_{target_metric}_box_{method}"] - 1
            best_valloss = res_dict_models[f"best_model_val_loss_{self.loss_name}_{method}"] - 1
            stop_val.append(best_valloss)
            stop_metric.append(best_target_metric)

            valloss_diffs.append(target_data[best_target_metric, :] - target_data[best_valloss, :])

            if metric != target_metric:
                # Same 1-based -> 0-based shift as above (was missing originally).
                best_metric = res_dict_models[f"best_model_{metric}_box_{method}"] - 1
                metric_diffs.append(target_data[best_target_metric, :] - target_data[best_metric, :])
        return metric_diffs, valloss_diffs, stop_val, stop_metric

    def save_plot_analysis(self, res_dict_models, res_dict_plots_cpu, target_metric):
        """Compare early-stopping criteria through the lens of ``target_metric``.

        ``target_metric`` is a metric name such as VUS or AUC-PR. For every metric
        sub-directory under ``plots/`` (created by ``save_plot_cnn``) one figure is
        written to ``plots/<target>_minus_<metric>_and_val_loss.png``:
        left, the target score of the target-metric best model minus that of the
        ``metric`` best model; right, the same against the val-loss best model.
        """
        save_path = self.path_save_dir + "/plots"
        metrics_plotted = sorted(
            name for name in os.listdir(save_path) if os.path.isdir(os.path.join(save_path, name))
        )
        positions = np.arange(1, len(self.val_methods) + 1)

        for metric in metrics_plotted:
            # Fresh figure and fresh accumulators per metric: reusing them across
            # metrics produced blank figures and mismatched boxplot columns.
            metric_diffs, valloss_diffs, stop_val, stop_metric = self._analysis_differences(
                res_dict_models, res_dict_plots_cpu, target_metric, metric
            )
            print(f"stop_val mean {np.mean(stop_val)}")
            print(f"stop_metric mean {np.mean(stop_metric)}")

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

            if metric_diffs:
                array_box_plots = np.array(metric_diffs).T  # (num_runs, num_methods)
                ax1.boxplot(array_box_plots, positions=positions, patch_artist=True, showfliers=False)
                ax1.set_xticks(positions, self.val_methods)
                # Linear scale: differences can be zero or negative, which a log scale cannot show.
                ax1.axhline(y=0, color='black', linestyle='--', linewidth=1)
                ax1.set_xlabel("Methods")
                ax1.set_ylabel(f"{target_metric} - {metric}")
                ax1.set_title(f"{target_metric} - {metric} with {self.model_name}")

            array_box_plots_valloss = np.array(valloss_diffs).T  # (num_runs, num_methods)
            ax2.boxplot(array_box_plots_valloss, positions=positions, patch_artist=True, showfliers=False)
            ax2.set_xticks(positions, self.val_methods)
            ax2.set_ylim(-0.2, 0.2)
            ax2.set_xlabel("Methods")
            ax2.set_ylabel(f"{target_metric} difference")
            ax2.set_title(f"{target_metric} of '{target_metric} best model - val_loss best model'")
            ax2.scatter(positions, np.mean(array_box_plots_valloss, axis=0), color='green',
                        edgecolors='black', linewidths=0.5, s=30, label="Mean")
            ax2.axhline(y=0, color='black', linestyle='--', linewidth=1)

            out_file = f"{save_path}/{target_metric}_minus_{metric}_and_val_loss.png"
            print(f" Analysis plotted at : {out_file}")
            fig.savefig(out_file, bbox_inches='tight')
            plt.close(fig)


if __name__ == "__main__":
    # SaveData is meant to be imported; executing this file directly does nothing.
    ...

   