from tqdm import tqdm
import torch
import time
import numpy as np
import pandas as pd
import os

from torchmetrics.classification import AveragePrecision
from src.utils.timeseries_dataset import make_df_slide

from src.utils.data_loader import DataLoader
from src.utils.scores_loader import ScoresLoader
from src.utils.create_split import SplitTrainTest
from src.utils.model_loader import ModelLoader
from src.utils.loss_loader import LossLoader
from src.utils.params_loader import ParamsLoader

from config import *

torch.set_printoptions(threshold=200)  # Printed tensors are summarized with "..." once they exceed 200 elements (torch's documented default is 1000)


def apply_wts_batch(scores, wts_pred, device):
    """
    Combine the detectors' anomaly scores using the weights predicted by the model.

    A singleton axis is inserted at position 1 of ``wts_pred`` so that the same weights
    are broadcast over axis 1 of ``scores``; the weighted scores are then averaged over
    axis 2 and the result is moved to ``device``.

    Example shapes for raw data:
        scores   : torch.Size([1, 256, 12])
        wts_pred : torch.Size([1, 12])
        returned : torch.Size([1, 256])
    """
    # [:, None] is the indexing spelling of unsqueeze(1).
    weighted_scores = wts_pred[:, None] * scores
    return weighted_scores.mean(dim=2).to(device)


class EvaluaTorch:
    def __init__(self, model, train_loader, val_loader, optimizer, loss_fn, loss_name, device):
        """
        Store everything the training / validation loops need.

        Args:
            model: network called on each batch to predict the detectors' weights.
            train_loader: iterable of training batches.
            val_loader: iterable of validation batches.
            optimizer: optimizer stepped after each training batch.
            loss_fn: callable ``loss_fn(prediction, target)``.
            loss_name: name of the loss; the loops compare it to "bce" and "rmse".
            device: device on which tensors are placed.
        """
        # Network and what updates it.
        self.model, self.optimizer = model, optimizer
        # Data sources.
        self.train_loader, self.val_loader = train_loader, val_loader
        # Loss callable and the name used to select loss-specific handling.
        self.loss_fn, self.loss_name = loss_fn, loss_name
        self.device = device


    def fix_length(self, wts_pred, scor, window_size):
        """
        New args : 
            - wts_pred : (>len_ts, 12) : weights predicted by the model
            -scor : (len_ts, 12) : scores of the detectors
            -window_size : int : size of the window used to create the dataset
        Return :
            - wts_with_overlap : (len_ts, 12) : concatenation per windows of weights predicted by the model with the same length as the scores

        The definition of shifting windows induce a length difference between wts_pred and scor.
        This function fix it. wts_pred is always longer because the concatenations of windows is longer than the score. This is due to the fact that the first and second windows overlap.
        """

        overlap_len = 128 - scor.shape[0]%128
        wts_mean = (wts_pred[(window_size-overlap_len):window_size, :] + wts_pred[window_size:window_size+overlap_len, :])/ 2
        wts_before_overlap = wts_pred[:(window_size - overlap_len), :]
        wts_after_overlap = wts_pred[window_size+overlap_len:, :]
        wts_with_overlap = torch.cat((wts_before_overlap, wts_mean, wts_after_overlap), dim=0)
        return wts_with_overlap


    def train_feat(self):
        """
        Train the torch model for one epoch on ``self.train_loader``.

        Each batch is a ``(feat, scor, lab)`` triple: the model predicts weights from
        ``feat``, ``apply_wts_batch`` combines ``scor`` with these weights, and the loss
        between the combined score and ``lab`` is back-propagated. Shapes, gradient
        norms, losses and timings are printed for every batch.

        Returns:
            float: the training loss averaged over the batches of the epoch.

        Raises:
            ValueError: if ``self.train_loader`` yields no batch.
        """
        self.model.train()  # training mode (affects layers such as nn.Dropout and nn.BatchNorm)
        train_loss = 0
        n_batches = 0

        for batch_idx, (feat, scor, lab) in tqdm(enumerate(self.train_loader), desc="Training the model", total=len(self.train_loader)):
            tic = time.time()
            n_batches = batch_idx + 1
            # Put the batch on the target device, as validation_feat does; a no-op when
            # the tensors are already there.
            feat, scor, lab = feat.to(self.device), scor.to(self.device), lab.to(self.device)
            batch_size = feat.size(0)
            print(f"batch_size : {batch_size}\n")

            self.optimizer.zero_grad()

            # Compute the weights
            wts_pred = self.model(feat)
            print(f"scor shape : {scor.shape}\n")
            print(f"wts_pred shape : {wts_pred.shape}\n")

            # Apply the weights to the scores
            score_pred = apply_wts_batch(scor, wts_pred, self.device)
            print(f"score_pred shape : {score_pred.shape}\n")

            if self.loss_name == "bce":
                # Per-element weights: 0.1 for class 0 and 0.9 for class 1, created on
                # the same device as the labels.
                weights = torch.where(lab == 0, torch.tensor(0.1, device=self.device), torch.tensor(0.9, device=self.device))
                self.loss_fn.weight = weights

            # Compute the loss
            print(f" lab shape : {lab.shape}\n")
            loss = self.loss_fn(score_pred, lab.float())

            loss.backward()

            for name, param in self.model.named_parameters():
                # Parameters that did not take part in the backward pass (frozen or
                # unused) have no gradient; torch.norm(None) would raise.
                if param.grad is None:
                    continue
                print(f"Gradients for {name}: {torch.norm(param.grad)}")

            self.optimizer.step()
            batch_loss = loss.item()
            train_loss += batch_loss

            print("\n")
            print(f"bacth Loss : {batch_loss}\n")
            print(f"Average Loss over passed data : {train_loss/n_batches}\n")
            print(f"Training on 1 batch took : {time.time()-tic} seconds\n")

        if n_batches == 0:
            raise ValueError("train_loader is empty: cannot compute an average training loss")

        print(f"Train Loss:{train_loss/n_batches:>8f} (Average Loss over batches) \n")
        return train_loss/n_batches


    def validation_feat(self):
        """
        Evaluate the model on ``self.val_loader`` without computing gradients.

        Each batch is a ``(feat, scor, lab)`` triple. The model predicts weights from
        ``feat``, ``apply_wts_batch`` combines ``scor`` with them, and both the loss and
        one binary average precision per batch item are computed.

        Returns:
            tuple: ``(average_loss, auc_pr_list)`` where ``average_loss`` is the loss
            averaged over the batches and ``auc_pr_list`` holds one average-precision
            value per validation item.

        Raises:
            ValueError: if ``self.val_loader`` yields no batch.
        """
        # torchmetrics metrics are nn.Modules; keep the metric on the device of its inputs.
        auc_pr_metric = AveragePrecision(task="binary").to(self.device)
        self.model.eval()
        val_loss = 0
        n_batches = 0
        auc_pr_list = []
        with torch.no_grad():
            for batch_idx, (feat, scor, lab) in tqdm(enumerate(self.val_loader), desc="Validating the model", total=len(self.val_loader)):
                feat, scor, lab = feat.to(self.device), scor.to(self.device), lab.to(self.device)

                n_batches = batch_idx + 1

                # Predictions
                wts_pred = self.model(feat)
                score_pred = apply_wts_batch(scor, wts_pred, self.device)

                if self.loss_name == "bce":
                    # Per-element weights: 0.1 for class 0 and 0.9 for class 1, created
                    # on the same device as the labels.
                    weights = torch.where(lab == 0, torch.tensor(0.1, device=self.device), torch.tensor(0.9, device=self.device))
                    self.loss_fn.weight = weights

                # Compute AUC-PR for every batch item
                for i in range(lab.shape[0]):
                    auc_pr_list.append(auc_pr_metric(score_pred[i], lab[i].long()))

                val_loss += self.loss_fn(score_pred, lab.float()).item()

        if n_batches == 0:
            raise ValueError("val_loader is empty: cannot compute an average validation loss")

        print(f"Validation Loss: {val_loss/n_batches:.6f} (Average Loss)\n")  # normalised by the number of batches

        return val_loss/n_batches, auc_pr_list
    

    def train_cnn(self, res_dict):
        """
        Train the CNN model for one epoch on ``self.train_loader``.

        Each batch is a ``(ts_window, scor, lab)`` triple: the model predicts weights
        from ``ts_window``, ``apply_wts_batch`` combines ``scor`` with them, and the loss
        between the combined score and ``lab`` is back-propagated.

        Args:
            res_dict: dict of lists holding the experiment results. The average loss of
                the epoch is appended to ``res_dict[f"train_loss_{self.loss_name}"]``,
                which must already exist.

        Returns:
            float: the training loss averaged over the batches of the epoch.

        Raises:
            ValueError: if ``self.train_loader`` yields no batch.
        """
        self.model.train()  # training mode (affects layers such as nn.Dropout and nn.BatchNorm)
        train_loss = 0
        n_batches = 0

        for batch_idx, (ts_window, scor, lab) in tqdm(enumerate(self.train_loader), desc="Training the model", total=len(self.train_loader)):
            n_batches = batch_idx + 1
            # The BCE weights below are created on self.device, so the batch must live
            # there too; a no-op when the loader already yields tensors on that device.
            ts_window, scor, lab = ts_window.to(self.device), scor.to(self.device), lab.to(self.device)

            self.optimizer.zero_grad()

            wts_pred = self.model(ts_window)
            score_pred = apply_wts_batch(scor, wts_pred, self.device)

            if self.loss_name == "bce":
                # Per-element weights: 0.1 for class 0 and 0.9 for class 1
                weights = torch.where(lab == 0, torch.tensor(0.1, device=self.device), torch.tensor(0.9, device=self.device))
                self.loss_fn.weight = weights

            loss = self.loss_fn(score_pred, lab.float())

            if self.loss_name == "rmse":
                # presumably loss_fn is an MSE in this case, so the root gives the RMSE
                # -- verify against LossLoader
                loss = torch.sqrt(loss)

            loss.backward()

            self.optimizer.step()
            batch_loss = loss.item()
            train_loss += batch_loss

            print("\n")
            print(f"bacth Loss : {batch_loss}\n")

        if n_batches == 0:
            raise ValueError("train_loader is empty: cannot compute an average training loss")

        print(f"Train Loss:{train_loss/n_batches:>8f} (Average Loss over batches) \n")

        # Update the dictionary with the results of the experiments
        res_dict[f"train_loss_{self.loss_name}"].append(train_loss/n_batches)

        return train_loss/n_batches


    def validation_cnn(self, fnames_val, window_size, res_dict, val_methods):
        """
        Validate the CNN model on whole time series, one file at a time.

        For every file, the CSV of windows is read, the labels and the detectors' scores
        of the time series are loaded, and the detectors' weights are predicted with each
        strategy listed in ``val_methods``:

            - "first_window": weights predicted from the first window only.
            - "shift_windows": mean of the weights predicted on all the windows.
            - "pw_shift_windows": point-wise weights. The weights of a window are
              repeated for each of its points, then ``fix_length`` brings the result
              back to the length of the scores.

        Args:
            fnames_val: sequence of CSV file names (one per time series).
            window_size: int, window size; also used to format the windows' directory.
            res_dict: dict of lists updated in place. For each method the average loss
                is appended to ``res_dict[f"val_loss_{self.loss_name}_{method}"]`` and,
                when "AUC-PR" is in ``metrics``, the list of per-series values is
                appended to ``res_dict[f"AUC-PR_box_{method}"]``. Keys must exist.
            val_methods: iterable of the strategy names to evaluate.

        Raises:
            ValueError: if ``fnames_val`` is empty (no average can be computed).

        Notes:
            - For the "bce" loss, ``loss_fn.weight`` is built from ``lab``; the original
              author states ``lab`` has size [1, len_ts], so every call to ``loss_fn``
              must receive a target of that size and not [len_ts].
            - ``path_raw``, ``path_data``, ``path_scores``, ``metrics`` and ``VUSTorch``
              are not defined in the visible part of this file; presumably they come
              from ``from config import *`` -- verify.
            - NOTE(review): the FFVUS values are collected in a local dict but never
              written to ``res_dict``.
        """
        n_ts = len(fnames_val)
        if n_ts == 0:
            raise ValueError("fnames_val is empty: cannot compute average validation losses")

        # torchmetrics metrics are nn.Modules; keep the metric on the device of its inputs.
        auc_pr_metric = AveragePrecision(task="binary").to(self.device)
        self.model.eval()
        val_losses = {method: 0.0 for method in val_methods}
        aucpr_lists = {method: [] for method in val_methods}
        ffvus_lists = {method: [] for method in val_methods}

        def record(method, score_pred, lab, with_ffvus=True):
            """Accumulate the loss and the requested metrics of one series for `method`."""
            loss = self.loss_fn(score_pred, lab.float())
            if self.loss_name == "rmse":
                loss = torch.sqrt(loss)
            val_losses[method] += loss.item()
            if "AUC-PR" in metrics:
                aucpr_lists[method].append(auc_pr_metric(score_pred, lab.long()))
            if with_ffvus and "FFVUS" in metrics:
                ffvus_lists[method].append(VUSTorch(device=self.device).compute(label=lab[0, :], score=score_pred[0, :])[0])

        with torch.no_grad():
            for fname_csv in tqdm(fnames_val, desc="Validating the model", total=len(fnames_val)):
                # Load the windows associated with the time series (one window per row;
                # the first two columns are skipped).
                df_shift = pd.read_csv(os.path.join(path_raw.format(window_size), fname_csv), index_col=False)
                shift_windows = torch.tensor(df_shift.iloc[:, 2:].values, dtype=torch.float32, requires_grad=False).to(self.device)
                # (n_windows, 1, window_len): one singleton axis inserted per window.
                windows_batch = shift_windows.unsqueeze(1)

                # Load the labels and the scores. The loaders want a list of file names
                # without the extension, e.g. "dataset/ts_name.csv" -> "dataset/ts_name".
                fname_ts = [os.path.splitext(fname_csv)[0]]
                _, label, _ = DataLoader(path_data).load_timeseries(fname_ts)
                scores, _ = ScoresLoader(path_scores).load(fname_ts)

                scor = torch.tensor(scores, dtype=torch.float32, requires_grad=False).to(self.device)  # [1, len_ts, n_detectors] per the original author
                lab = torch.tensor(label, dtype=torch.float32, requires_grad=False).to(self.device)

                if self.loss_name == "bce":
                    # Per-element weights: 0.1 for class 0 and 0.9 for class 1
                    weights = torch.where(lab == 0, torch.tensor(0.1, device=self.device), torch.tensor(0.9, device=self.device))
                    self.loss_fn.weight = weights

                # METHOD 1 : first_window
                if "first_window" in val_methods:
                    wts_pred_first = self.model(windows_batch[:1])
                    score_pred_first = apply_wts_batch(scor, wts_pred_first, self.device)
                    record("first_window", score_pred_first, lab)

                # Methods 2 and 3 both need the weights of every window: run the model
                # once on the whole batch instead of once per window.
                if "shift_windows" in val_methods or "pw_shift_windows" in val_methods:
                    wts_per_window = self.model(windows_batch)  # assumed (n_windows, n_detectors)

                # METHOD 2 : average over the shift windows
                if "shift_windows" in val_methods:
                    wts_pred_shift = wts_per_window.mean(dim=0, keepdim=True)  # (1, n_detectors)
                    score_pred_shift = apply_wts_batch(scor, wts_pred_shift, self.device)
                    record("shift_windows", score_pred_shift, lab)

                # METHOD 3 : pw_shift_windows
                if "pw_shift_windows" in val_methods:
                    # Repeat the weights of each window for each of its points.
                    wts_pred_pw_shift = wts_per_window.repeat_interleave(shift_windows.size(1), dim=0)

                    # The concatenated windows can be longer than the series because the
                    # first and second windows overlap: align the lengths.
                    scor_2d = torch.reshape(scor, (scor.shape[1], scor.shape[2]))  # (len_ts, n_detectors)
                    wts_pred_pw_shift = self.fix_length(wts_pred_pw_shift, scor_2d, window_size)
                    score_pred_pw_shift = torch.mean(scor_2d * wts_pred_pw_shift, dim=1).to(self.device)
                    score_pred_pw_shift = score_pred_pw_shift.unsqueeze(0)  # (1, len_ts), like lab

                    # FFVUS was not computed for this method in the original code.
                    record("pw_shift_windows", score_pred_pw_shift, lab, with_ffvus=False)

        # Update results dict
        print(f"res_dict keys : {res_dict.keys()}\n")

        for method in val_methods:
            res_dict[f"val_loss_{self.loss_name}_{method}"].append(val_losses[method] / n_ts)
            if "AUC-PR" in metrics:
                res_dict[f"AUC-PR_box_{method}"].append(aucpr_lists[method])
               

    
if __name__ == "__main__":
    # Manual smoke run of EvaluaTorch.validation_cnn on the KDD21 validation split.
    window_size = 128
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # loss_name must name the LOSS (not the optimizer): validation_cnn uses it to build
    # the res_dict keys and to decide on the "bce" / "rmse" special cases.
    loss_name = "mse"
    val_methods = ["first_window", "shift_windows", "pw_shift_windows"]

    # Keys must match the ones validation_cnn / train_cnn append to.
    res_dict = {f"train_loss_{loss_name}": []}
    for method in val_methods:
        res_dict[f"val_loss_{loss_name}_{method}"] = []
        res_dict[f"AUC-PR_box_{method}"] = []

    fnames_train, fnames_val, fnames_test = SplitTrainTest(path_features, ['KDD21'], 0.7, 42).create_splits_raw(read_from_file=path_splits.format(window_size))
    model_params = ParamsLoader(model_name="cnn_len_128").load_params()
    model, model_type = ModelLoader("cnn_len_128", model_params, seed=42, device=device).load_model()
    loss_fn = LossLoader(loss_name=loss_name).get_loss()

    evaluator = EvaluaTorch(model=model, train_loader=None, val_loader=None, optimizer=None, loss_fn=loss_fn, loss_name=loss_name, device=device)
    print("Starting validation\n")
    evaluator.validation_cnn(fnames_val, window_size, res_dict=res_dict, val_methods=val_methods)
