from tqdm import tqdm
import numpy as np
import torch
import pandas as pd
import os
import time
from pycatch22 import catch22_all

from config import *
from torchmetrics.classification import AveragePrecision    
from src.utils.data_loader import DataLoader
from src.utils.scores_loader import ScoresLoader

def apply_wts_batch(scores, wts_pred, device):
    """
    Combine the detectors' anomaly scores using per-detector weights.

    Each detector's score is multiplied by that detector's weight (the weight is
    broadcast over the time axis), then the weighted scores are averaged across detectors.

    For raw data the expected shapes are:
        scores   : torch.Size([1, ts_len, 12])
        wts_pred : torch.Size([1, 12])
        returned : torch.Size([1, ts_len])
    """
    # Insert a singleton time axis so the weights broadcast over every time step.
    weights_over_time = wts_pred.unsqueeze(1)
    weighted_scores = scores * weights_over_time
    return weighted_scores.mean(dim=2).to(device)

class Validation:
    """Evaluate a weight-predicting model by the AUC-PR of the weighted detector scores."""

    def __init__(self, device, testing):
        # testing=True restricts evaluate_feat to (at most) the first 5 series.
        self.testing = testing
        self.device = device

    @staticmethod
    def _predict_unit_action(model, obs):
        """Return the model's deterministic action for ``obs``, scaled to unit L2 norm."""
        action, _ = model.predict(obs, deterministic=True)
        return action / np.linalg.norm(action)

    @staticmethod
    def _aucpr(metric, preds, target):
        """
        Compute the metric on a single (preds, target) pair and return it as a Python float.

        Calling a torchmetrics Metric runs ``forward``, which also adds the batch to the metric's
        accumulated state; the metric is reset afterwards so that state does not keep growing
        across evaluated series.
        """
        value = metric(preds, target)
        metric.reset()
        return value.item()

    def fix_length(self, wts_pred, scor, window_size):
        """
        Trim the per-point weights built from consecutive windows to the length of the scores.

        Args :
            - wts_pred : (n_windows * window_size, 12) : weights predicted by the model, one row per point
            - scor : (len_ts, 12) : scores of the detectors
            - window_size : int : size of the windows used to create the dataset
        Return :
            - (len_ts, 12) : weights with the same length as the scores

        When len_ts is not a multiple of window_size, the concatenated windows are longer than the
        series because the first and second windows overlap by ``window_size - len_ts % window_size``
        points. On that overlap the weights of the two windows are averaged.
        """
        remainder = scor.shape[0] % window_size
        if remainder == 0:
            return wts_pred
        overlap_len = window_size - remainder
        wts_mean = (wts_pred[(window_size - overlap_len):window_size, :]
                    + wts_pred[window_size:window_size + overlap_len, :]) / 2
        wts_before_overlap = wts_pred[:(window_size - overlap_len), :]
        wts_after_overlap = wts_pred[window_size + overlap_len:, :]
        return torch.cat((wts_before_overlap, wts_mean, wts_after_overlap), dim=0)

    def evaluate_raw(self, model, fnames_eval, window_size, res_dict, val_methods, eval_set):
        """
        Evaluate the model on raw time-series windows and append per-series AUC-PR values to ``res_dict``.

        Args :
            - model : exposes ``predict(obs, deterministic=...) -> (action, state)``
              (presumably a Stable-Baselines3 model — TODO confirm)
            - fnames_eval : CSV file names, relative to ``path_raw.format(window_size)``
            - window_size : int : size of the windows stored in the CSV files
            - res_dict : for each method, ``res_dict[f"{eval_set}_AUC-PR_box_{method}"]`` must be a list;
              the list of per-series AUC-PR values (floats) is appended to it
            - val_methods : names among "first_window", "shift_windows", "pw_shift_windows"
            - eval_set : str : name of the evaluated split (used in keys and messages)

        Notes :
            - Labels are expected to be of shape [1, len_ts], so predictions are shaped [1, len_ts].
            - pw_shift_windows : each window's action is repeated for every point of the window, so the
              concatenation has n_windows * window_size rows and is trimmed with ``fix_length``.
        """
        auc_pr_metric = AveragePrecision(task="binary")
        aucpr_lists = {method: [] for method in val_methods}

        with torch.no_grad():
            for fname_csv in tqdm(fnames_eval, desc=f"Evaluating the model on the {eval_set}", total=len(fnames_eval)):
                # Load the windows associated with the ts (one window per row, first 2 columns skipped).
                df_shift = pd.read_csv(os.path.join(path_raw.format(window_size), fname_csv), index_col=False)
                shift_windows = torch.tensor(df_shift.iloc[:, 2:].values, dtype=torch.float32).to(self.device)

                # The loaders want a list of names without the .csv extension,
                # e.g. "dataset/ts_name.csv" -> ["dataset/ts_name"].
                fname_ts = [fname_csv[:-4]]

                tic = time.time()

                _, label, _ = DataLoader(path_data).load_timeseries(fname_ts)
                scores, _ = ScoresLoader(path_scores).load(fname_ts)

                scor = torch.tensor(scores, dtype=torch.float32).to(self.device)  # expected [1, len_ts, 12]
                lab = torch.tensor(label, dtype=torch.float32).to(self.device)  # expected [1, len_ts]

                tac = time.time()
                print(f"RL : TIME EVAL IMPORT DATA : {tac - tic}")

                # METHOD 1 : weights predicted from the first window only
                if "first_window" in val_methods:
                    first_window = shift_windows[0][None, None, :]
                    action_first = self._predict_unit_action(model, first_window)
                    action_first = torch.tensor(action_first, dtype=torch.float32, device=self.device)
                    score_pred_first = apply_wts_batch(scor, action_first[None, :], self.device)

                    if "AUC-PR" in metrics:
                        aucpr_lists["first_window"].append(
                            self._aucpr(auc_pr_metric, score_pred_first, lab.long()))

                # METHOD 2 : weights averaged over all shifted windows
                if "shift_windows" in val_methods:
                    tic = time.time()

                    tic_predict = time.time()
                    actions = [self._predict_unit_action(model, window[None, None, :]) for window in shift_windows]
                    tac_predict = time.time()
                    print(f" RL : TIME PREDICT ACTIONS SHIFT_WINDOW : {tac_predict - tic_predict}")

                    actions = torch.tensor(np.stack(actions), dtype=torch.float32, device=self.device)
                    action_shift = torch.mean(actions, dim=0, keepdim=True)  # (1, 12)
                    score_pred_shift = apply_wts_batch(scor, action_shift, self.device)

                    if "AUC-PR" in metrics:
                        aucpr_lists["shift_windows"].append(
                            self._aucpr(auc_pr_metric, score_pred_shift, lab.long()))
                    tac = time.time()
                    print(f" RL : TIME EVAL SHIFT_WINDOW : {tac-tic}")

                # METHOD 3 : point-wise weights, each point uses the weights of its window
                if "pw_shift_windows" in val_methods:
                    window_len = shift_windows.size(1)
                    per_window_actions = []
                    for window in shift_windows:
                        action = self._predict_unit_action(model, window[None, None, :])
                        action = torch.tensor(action[None, :], dtype=torch.float32, device=self.device)  # (1, 12)
                        per_window_actions.append(action.repeat(window_len, 1))
                    action_pw_shift = torch.cat(per_window_actions, dim=0)  # (n_windows * window_len, 12)

                    # Separate local so `scor` keeps its [1, len_ts, 12] shape.
                    scor_2d = torch.reshape(scor, (scor.shape[1], scor.shape[2]))  # (len_ts, 12)
                    action_pw_shift = self.fix_length(action_pw_shift, scor_2d, window_size)  # (len_ts, 12)
                    score_pred_pw_shift = torch.mean(scor_2d * action_pw_shift, dim=1).to(self.device)
                    score_pred_pw_shift = score_pred_pw_shift.reshape(1, -1)  # (1, len_ts), like lab

                    if "AUC-PR" in metrics:
                        aucpr_lists["pw_shift_windows"].append(
                            self._aucpr(auc_pr_metric, score_pred_pw_shift, lab.long()))

            # Update results in dict
            print(f"res_dict keys : {res_dict.keys()}\n")

            for method in val_methods:
                if "AUC-PR" in metrics:
                    res_dict[f"{eval_set}_AUC-PR_box_{method}"].append(aucpr_lists[method])
                    print(f"mean AUC-PR for the {method} method : {np.mean(aucpr_lists[method])}")

    def evaluate_feat(self, model, eval_data, res_dict_plots, eval_set):
        """
        Evaluate the model on feature-based inputs.

        For each series, the weights predicted from ``features`` are applied to the detector scores,
        and the AUC-PR of the result against the labels is computed.

        The list of AUC-PR values (floats) is appended to ``res_dict_plots[f"{eval_set}_AUC-PR_box"]``.
        The list of predicted unit-norm weights is appended to ``res_dict_plots[f"{eval_set}_AUC-PR_box_weights"]``.

        Returns the list of predicted unit-norm weights.
        """
        auc_pr_metric = AveragePrecision(task="binary")
        auc_pr_list = []
        weights_list = []

        n_ts = len(eval_data)
        if self.testing:
            # Never index past a dataset that holds fewer than 5 series.
            n_ts = min(5, n_ts)

        for ts_idx in tqdm(range(n_ts), desc=f"Evaluating the model on the {eval_set}", total=n_ts):
            features, all_scores, label = eval_data[ts_idx]
            action = self._predict_unit_action(model, features)
            weights_list.append(action)
            action = torch.tensor(action, dtype=torch.float32, device=self.device)

            # Add a leading batch axis to match apply_wts_batch's expected shapes.
            score_pred = apply_wts_batch(all_scores[None, :, :], action[None, :], self.device)
            auc_pr_list.append(self._aucpr(auc_pr_metric, score_pred, label[None, :].long()))

        res_dict_plots[f"{eval_set}_AUC-PR_box"].append(auc_pr_list)
        res_dict_plots[f"{eval_set}_AUC-PR_box_weights"].append(weights_list)

        print(f"mean AUC-PR (Validation on {eval_set}): {np.mean(auc_pr_list)}")

        return weights_list

