import pygad.torchga as torchga
import pygad
import torch
import pandas as pd
import numpy as np
import os
import time
import random
from tqdm import tqdm
from torchmetrics.classification import AveragePrecision

from sklearn.metrics import average_precision_score

from src.utils.data_loader import DataLoader
from src.utils.scores_loader import ScoresLoader
from src.timeseries_dataset import TimeseriesDataset 
from src.utils.early_stopping import early_stop_check_all

from config import path_data, path_raw, path_scores, metrics


def apply_wts_batch(scores, wts_pred, device):
    """
    Combine detector anomaly scores with the weights predicted by the model.

    Each detector's score column is scaled by that detector's weight, then the
    weighted scores are averaged over dim 2 (the detector axis) and the result
    is moved to `device`.

    Shapes noted for raw data:
        scores   : torch.Size([1, ts_len, n_detectors])
        wts_pred : torch.Size([1, n_detectors])
        returned : torch.Size([1, ts_len])
    (a 2-D `scores` of shape (ts_len, n_detectors) broadcasts to the same result)
    """
    # Insert a length axis so one weight vector multiplies every timestep.
    per_step_weights = wts_pred.unsqueeze(1)
    weighted_scores = scores * per_step_weights
    return torch.mean(weighted_scores, dim=2).to(device)

def fix_length(wts_pred, scor, window_size):
    """
    Trim point-wise weights built from shifting windows to the length of the scores.

    Args:
        wts_pred : (>= len_ts, n_detectors) tensor : weights predicted by the model,
            concatenated window after window (one row per timestep of each window).
        scor : (len_ts, n_detectors) tensor : scores of the detectors.
        window_size : int : size of the windows used to create the dataset.

    Returns:
        (len_ts, n_detectors) tensor : the weights with the overlapping rows of the
        first and second windows averaged, so its length matches `scor`.

    Raises:
        ValueError: if `wts_pred` is shorter than `scor` or longer by more than
            one window (the overlap cannot then be confined to the first two windows).

    The concatenation of the windows is longer than the series because the first
    and second windows overlap. The overlap is exactly the length difference
    between `wts_pred` and `scor`. When the series length is a multiple of
    `window_size` there is no overlap and `wts_pred` is returned unchanged in
    content.
    """
    # The output has len(wts_pred) - overlap_len rows, so this is the only overlap
    # that makes it line up with `scor`. The former `128 - len % 128` ignored
    # `window_size` and gave a full window (instead of 0) for exact multiples.
    overlap_len = wts_pred.shape[0] - scor.shape[0]
    if not 0 <= overlap_len <= window_size:
        raise ValueError(
            f"cannot align weights of length {wts_pred.shape[0]} with scores of "
            f"length {scor.shape[0]} using window_size={window_size}"
        )

    overlap_start = window_size - overlap_len  # overlapped tail of the first window starts here
    overlap_end = window_size + overlap_len    # overlapped head of the second window ends here

    # Rows covering the same timesteps in both windows are averaged.
    wts_mean = (wts_pred[overlap_start:window_size, :] + wts_pred[window_size:overlap_end, :]) / 2
    wts_before_overlap = wts_pred[:overlap_start, :]
    wts_after_overlap = wts_pred[overlap_end:, :]
    return torch.cat((wts_before_overlap, wts_mean, wts_after_overlap), dim=0)


class EvaluatorGA:
    """
    Builds pygad callbacks that evolve the weights of `model` with a genetic algorithm.

    The initial population is created with `torchga.TorchGA`. A candidate solution
    (a flat vector of model weights) is scored as follows:
        1. run the model on each time series;
        2. use its L2-normalised output as one weight per detector;
        3. combine the detectors' anomaly scores with those weights (`apply_wts_batch`);
        4. average the AUC-PR of the result over the evaluated series.

    `data_type` selects the model input:
        - "raw":      windows of the raw series read from CSV files (`get_fitness_func_raw`);
        - "features": pre-computed features (`get_fitness_func_feat`).
    """

    # Immutable default for `val_methods` (a list default would be shared between instances).
    DEFAULT_VAL_METHODS = ("first_window", "shift_windows", "pw_shift_windows")

    def __init__(self,
                 model,
                 data_type: str,
                 res_dict_plots: dict,
                 res_dict_models: dict,
                 path_save_dir: str,
                 stoppers: dict,
                 fnames_val: list,
                 fnames_train: list,
                 train_dataset,
                 val_dataset,
                 feat_train,
                 feat_val,
                 num_solutions: int,
                 sample_size: int,
                 val_methods: list = DEFAULT_VAL_METHODS,
                 num_parents_mating: int = 5,
                 num_generations: int = 250,
                 device: str = "cpu",
                 seed: int = 42,
                 testing: bool = False):
        """
        Args:
            model: torch model whose weights are evolved.
            data_type: "raw" or "features"; selects the fitness function used in `on_generation`.
            res_dict_plots: dict of lists; best-solution AUC-PR lists (and, for features,
                the predicted weights) are appended to it every generation.
            res_dict_models, path_save_dir, stoppers: forwarded to `early_stop_check_all`.
            fnames_val, fnames_train: CSV file names, relative to `path_raw.format(window_size)`.
            train_dataset, val_dataset: objects exposing `dic_ts_to_labels` and `dic_ts_to_scores`.
            feat_train, feat_val: indexable datasets returning (features, all_scores, label).
            num_solutions: population size passed to `TorchGA`.
            sample_size: number of series evaluated per non-best fitness call.
            val_methods: raw-data combination methods; the first one defines the fitness.
            num_parents_mating, num_generations: stored GA settings (not used inside this class).
            device: device for the tensors built here.
            seed: stored only; not used inside this class.
            testing: limits the features evaluation to 3 series.
        """
        self.num_generations = num_generations  # Number of generations.
        self.num_parents_mating = num_parents_mating  # Solutions selected as parents in the mating pool.
        self.model = model
        self.stoppers = stoppers
        self.res_dict_plots = res_dict_plots
        self.res_dict_models = res_dict_models
        self.path_save_dir = path_save_dir
        self.fnames_val = fnames_val
        self.fnames_train = fnames_train
        self.feat_train = feat_train
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.feat_val = feat_val
        self.device = device
        self.data_type = data_type
        self.seed = seed
        self.sample_size = sample_size
        self.testing = testing

        torch_ga = torchga.TorchGA(model=self.model, num_solutions=num_solutions)
        self.initial_population = torch_ga.population_weights  # Initial population of network weights

        print(f"initial population number : {np.array(self.initial_population).shape}")

        self.val_methods = list(val_methods)
        self.window_size = 128

        # State mirrored by the get_fitness_func_* builders (the closures use their own copies).
        self.is_best_solution = False
        self.eval_set = None
        self.fnames = None  # names of raw time series
        self.eval_data = None  # features data
        self.dic_ts_to_labels = None
        self.dic_ts_to_scores = None

    def _sample_train_fnames(self):
        """Draw a random subset of `sample_size` training files (capped at the number available)."""
        k = min(self.sample_size, len(self.fnames_train))
        return random.sample(self.fnames_train, k=k)

    def _predict_normalized(self, solution, data):
        """
        Predict with the model loaded with `solution`, scale to unit L2 norm.

        Runs the model on `data` using the weight vector `solution`, scales the output
        to unit L2 norm and returns it as a float32 tensor on `self.device`.
        """
        action = pygad.torchga.predict(model=self.model, solution=solution, data=data)
        action = action / torch.norm(action)
        return torch.as_tensor(action, dtype=torch.float32, device=self.device)

    def get_fitness_func_raw(self, eval_set, is_best_solution):
        """
        Return a pygad fitness function evaluating a solution on raw time-series windows.

        Args:
            eval_set: "train" or "val"; selects the files and the label/score dictionaries.
            is_best_solution: if False and eval_set == "train", each call evaluates a fresh
                random sample of `sample_size` training files. If True, every file of the
                set is evaluated and the per-file AUC-PR lists are appended to
                `res_dict_plots[f"{eval_set}_AUC-PR_box_{method}"]`.

        Returns:
            fitness_func_raw(ga_instance, solution, sol_idx), returning the mean AUC-PR
            of the first method in `val_methods`.

        Raises:
            ValueError: if eval_set is neither "train" nor "val".
        """
        if eval_set == "train":
            dataset, all_fnames = self.train_dataset, self.fnames_train
        elif eval_set == "val":
            dataset, all_fnames = self.val_dataset, self.fnames_val
        else:
            raise ValueError(f"eval_set must be 'train' or 'val', got {eval_set!r}")

        resample_each_call = eval_set == "train" and not is_best_solution
        labels_by_ts = dataset.dic_ts_to_labels
        scores_by_ts = dataset.dic_ts_to_scores
        # The closure owns its file list: previously every call reset the shared
        # `self.fnames` to a *train* sample, even for a val closure whose dictionaries
        # are the val ones.
        state = {"fnames": self._sample_train_fnames() if resample_each_call else all_fnames}

        # Mirrored on the instance for inspection / backward compatibility.
        self.is_best_solution = is_best_solution
        self.eval_set = eval_set
        self.fnames = state["fnames"]
        self.dic_ts_to_labels = labels_by_ts
        self.dic_ts_to_scores = scores_by_ts

        def fitness_func_raw(ga_instance, solution, sol_idx):
            """
            Fitness of `solution` (a flat vector of model weights) on the selected raw files.

            For every file the model is run on the file's shifted windows. Its normalised
            output is used as detector weights in three possible ways:
                - "first_window":     weights predicted from the first window only;
                - "shift_windows":    mean of the weights predicted for every window;
                - "pw_shift_windows": each window's weights applied to the timesteps
                                      that window covers (aligned with `fix_length`).
            """
            print(f"is_best_solution START fitness_raw : {self.is_best_solution}")
            fnames = state["fnames"]
            methods = self.val_methods
            use_aucpr = "AUC-PR" in metrics
            needs_all_windows = "shift_windows" in methods or "pw_shift_windows" in methods
            auc_pr_metric = AveragePrecision(task="binary")
            aucpr_lists = {method: [] for method in methods}
            print(len(fnames), "files to evaluate")

            with torch.no_grad():
                for fname_csv in tqdm(fnames, desc=f"Evaluating the model on the {eval_set} set", total=len(fnames)):
                    # One window per row; the first two columns are skipped
                    # (presumably metadata — TODO confirm against the CSV writer).
                    df_shift = pd.read_csv(os.path.join(path_raw.format(self.window_size), fname_csv), index_col=False)
                    shift_windows = torch.tensor(df_shift.iloc[:, 2:].values, dtype=torch.float32, device=self.device)

                    tic = time.time()
                    # "dataset/ts_name.csv" -> "ts_name", the key of the label / score dictionaries.
                    ts_key = fname_csv[:-4].split("/")[-1]
                    scor = torch.as_tensor(scores_by_ts[ts_key], dtype=torch.float32, device=self.device)
                    lab = torch.as_tensor(labels_by_ts[ts_key], dtype=torch.float32, device=self.device)
                    target = lab[None, :].long()
                    print(f"GA : TIME IMPORT DATA : {time.time() - tic}")

                    # Predict each window once and share the result between the methods
                    # (window 0 is exactly the "first_window" input).
                    if needs_all_windows:
                        n_pred = shift_windows.size(0)
                    elif "first_window" in methods:
                        n_pred = 1
                    else:
                        n_pred = 0
                    window_actions = [
                        self._predict_normalized(solution, shift_windows[k][None, None, :])
                        for k in range(n_pred)
                    ]
                    stacked_actions = torch.stack(window_actions) if needs_all_windows else None

                    # METHOD 1 : first_window
                    if "first_window" in methods:
                        score_pred_first = apply_wts_batch(scor, window_actions[0], self.device)
                        if use_aucpr:
                            aucpr_lists["first_window"].append(float(auc_pr_metric(score_pred_first, target)))

                    # METHOD 2 : average of the shift_windows weights
                    if "shift_windows" in methods:
                        action_shift = torch.mean(stacked_actions, dim=0)
                        score_pred_shift = apply_wts_batch(scor, action_shift, self.device)
                        if use_aucpr:
                            aucpr_lists["shift_windows"].append(float(auc_pr_metric(score_pred_shift, target)))

                    # METHOD 3 : point-wise shift_windows
                    if "pw_shift_windows" in methods:
                        # Repeat each window's weights once per timestep of that window, then
                        # drop the singleton middle axis: (n_windows * win_len, n_detectors).
                        action_pw_shift = stacked_actions.repeat_interleave(shift_windows.size(1), dim=0)
                        action_pw_shift = action_pw_shift.reshape(action_pw_shift.shape[0], action_pw_shift.shape[-1])

                        print(f"scor.shape: {scor.shape}")

                        # The concatenated windows are longer than the series (first and second
                        # windows overlap); align them with the scores.
                        action_pw_shift = fix_length(action_pw_shift, scor, self.window_size)
                        score_pred_pw_shift = torch.mean(scor * action_pw_shift, dim=1).to(self.device)
                        score_pred_pw_shift = score_pred_pw_shift.reshape(1, -1)
                        if use_aucpr:
                            aucpr_lists["pw_shift_windows"].append(float(auc_pr_metric(score_pred_pw_shift, target)))

            if is_best_solution:
                for method in methods:
                    if use_aucpr:
                        print(f"eval_set : {eval_set}, method : {method}, AUC-PR : {np.mean(aucpr_lists[method])}")
                        self.res_dict_plots[f"{eval_set}_AUC-PR_box_{method}"].append(aucpr_lists[method])

            self.is_best_solution = False
            if resample_each_call:
                # Next call of this (GA) fitness function sees a new training sample.
                state["fnames"] = self._sample_train_fnames()
                self.fnames = state["fnames"]

            return np.mean(aucpr_lists[methods[0]])

        return fitness_func_raw

    def get_fitness_func_feat(self, eval_set, is_best_solution):
        """
        Return a pygad fitness function evaluating a solution on pre-computed features.

        Args:
            eval_set: "train" (uses `feat_train`) or "val" (uses `feat_val`).
            is_best_solution: if False, only the first `sample_size` series are evaluated.
                If True, all series are evaluated (3 in testing mode) and the AUC-PR list and
                predicted weights are appended to `res_dict_plots`.

        Returns:
            fitness_func_feat(ga_instance, solution, sol_idx), returning the mean AUC-PR.

        Raises:
            ValueError: if eval_set is neither "train" nor "val".
        """
        if eval_set == "train":
            eval_data = self.feat_train
        elif eval_set == "val":
            eval_data = self.feat_val
        else:
            raise ValueError(f"eval_set must be 'train' or 'val', got {eval_set!r}")

        # Mirrored on the instance for inspection / backward compatibility.
        self.eval_set = eval_set
        self.is_best_solution = is_best_solution
        self.eval_data = eval_data

        def fitness_func_feat(ga_instance, solution, sol_idx):
            """Evaluate `solution` (a flat vector of model weights) on the selected feature set."""
            auc_pr_metric = AveragePrecision(task="binary")
            auc_pr_list = []
            weights_list = []

            n_total = len(eval_data)
            n_ts = n_total
            if self.testing:
                n_ts = 3
            if not is_best_solution:
                # NOTE(review): always the first `sample_size` series, unlike the random
                # sampling of the raw path — confirm this is intended.
                n_ts = self.sample_size
            n_ts = min(n_ts, n_total)  # never index past the dataset

            # no_grad: otherwise the stored weights would keep their autograd graphs alive.
            with torch.no_grad():
                for ts_idx in tqdm(range(n_ts), desc=f"Evaluating the model on the {eval_set}", total=n_ts):
                    features, all_scores, label = eval_data[ts_idx]

                    weights_pred = self._predict_normalized(solution, features)
                    weights_list.append(weights_pred)

                    # Add a batch axis to scores, weights and labels.
                    score_pred = apply_wts_batch(all_scores[None, :, :], weights_pred[None, :], self.device)
                    auc_pr_list.append(float(auc_pr_metric(score_pred, label[None, :].long())))

            if is_best_solution:
                self.res_dict_plots[f"{eval_set}_AUC-PR_box"].append(auc_pr_list)
                self.res_dict_plots[f"{eval_set}_AUC-PR_box_weights"].append(weights_list)

            self.is_best_solution = False
            mean_aucpr = np.mean(auc_pr_list)
            print(f"mean AUC-PR (Validation on {eval_set}): {mean_aucpr}")
            return mean_aucpr

        return fitness_func_feat

    def get_on_generation(self):
        """
        Return the pygad `on_generation` callback.

        After each generation, the callback:
            1. re-evaluates the best solution on the full val and train sets (recording the results);
            2. loads its weights into `self.model`;
            3. runs the early-stopping checks.
        """

        def on_generation(ga_instance):
            best_solution, _best_fitness, best_match_idx = ga_instance.best_solution()
            print(f"Generation = {ga_instance.generations_completed}")
            print(f"is_best_solution START on generation:{self.is_best_solution}")

            if self.data_type == "features":
                make_fitness_func = self.get_fitness_func_feat
            elif self.data_type == "raw":
                make_fitness_func = self.get_fitness_func_raw
            else:
                raise ValueError(f"data_type must be 'features' or 'raw', got {self.data_type!r}")

            # Val first, then train (each builder is called right before its fitness function).
            fitness_func_val = make_fitness_func(eval_set="val", is_best_solution=True)
            fitness_val = fitness_func_val(ga_instance, best_solution, best_match_idx)
            fitness_func_train = make_fitness_func(eval_set="train", is_best_solution=True)
            fitness_train = fitness_func_train(ga_instance, best_solution, best_match_idx)

            print(f"is_best_solution END on generation:{self.is_best_solution}")
            print(f"Fitness (Train) = {fitness_train} , Fitness (Test) = {fitness_val}")

            model_weights_dict = pygad.torchga.model_weights_as_dict(self.model, best_solution)
            self.model.load_state_dict(model_weights_dict)

            # Stops the stopper associated with a metric when early stopping is required.
            early_stop_check_all(model=self.model, epoch=ga_instance.generations_completed, stoppers=self.stoppers, res_dict_plots=self.res_dict_plots, res_dict_models=self.res_dict_models, path_save_dir=self.path_save_dir)

        return on_generation



