import os 
import torch 
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle
import sys
import time
import random
from itertools import product

# Make the current working directory importable so that the project-local
# `src.*` and `config` imports further down can be resolved.
# NOTE(review): this presumably assumes the script is launched from the
# repository root — confirm against how the project is run.
path = os.getcwd()

# Only append once, so repeated imports of this module do not grow sys.path.
if path not in sys.path:
    sys.path.append(path)
    
from torch.utils.data import DataLoader, Dataset

print(os.getcwd())

from src.utils.features_dataset import TimeseriesFeaturesDataset
from src.utils.timeseries_dataset import TimeseriesDataset
from src.utils.create_split import SplitTrainTest
from src.utils.params_loader import ParamsLoader
from src.utils.model_loader import ModelLoader
from src.utils.loss_loader import LossLoader
from src.utils.early_stopping import EarlyStopper, early_stop_check_all
from src.utils.evaluate_torch import EvaluaTorch

from config import *

class Runner:
    """Build everything needed to train/validate one model and run the epochs.

    Two modes are supported, selected by ``data_type``:

    * ``"features"``: training and validation use ``TimeseriesFeaturesDataset``
      objects wrapped in ``DataLoader`` instances.
    * ``"raw"``: training uses a ``TimeseriesDataset`` built from the file
      names returned by ``SplitTrainTest``; validation is driven directly
      from the validation file names (``val_loader`` is ``None``).

    The constructor also creates the result dictionaries, one
    ``EarlyStopper`` per tracked metric, the model, the Adam optimizer, the
    loss function and the ``EvaluaTorch`` helper. Paths (``path_data``,
    ``path_scores``, ...) and the ``metrics`` iterable are taken from the
    star import of ``config``.
    """

    # Effective defaults for the list parameters of ``__init__``. They are kept
    # as tuples and copied per instance so that no list is shared between
    # instances (the classic mutable-default-argument pitfall).
    _DEFAULT_VAL_METHODS = ("first_window", "shift_windows", "pw_shift_windows", "pw_slide_windows")
    _DEFAULT_DATASETS_NAMES = ("supervised",)

    def __init__(self,
                 path_save_dir: str,
                 model_name: str = "mlp3",
                 data_type: str = "features",
                 val_methods: list = None,
                 learning_rate: float = 0.01,
                 loss_name: str = "mse",
                 datasets_names: list = None,
                 train_ratio: float = 0.7,
                 epochs: int = 50,
                 batch_size: int = 8,
                 patience: int = 10,
                 delta: float = 0.01,
                 seed: int = 42,
                 device: torch.device = torch.device("cpu"),
                 testing: bool = False):
        """Store the configuration and build datasets, model and optimizer.

        Args:
            path_save_dir: Directory handed to ``early_stop_check_all``.
            model_name: Name passed to ``ParamsLoader`` / ``ModelLoader``. In
                ``"raw"`` mode the third ``_``-separated token must be the
                window size (e.g. ``"cnn_len_256"`` -> 256).
            data_type: ``"features"`` or ``"raw"``.
            val_methods: Validation strategies. ``None`` selects
                ``["first_window", "shift_windows", "pw_shift_windows",
                "pw_slide_windows"]``.
            learning_rate: Learning rate of the Adam optimizer.
            loss_name: Name passed to ``LossLoader``; also used in result keys.
            datasets_names: Dataset names. ``None`` selects
                ``["supervised"]``. When ``"supervised"`` is absent, the first
                name is used to format the out-of-distribution paths.
            train_ratio: Forwarded to ``SplitTrainTest`` (raw mode).
            epochs: Number of epochs run by ``evaluate``.
            batch_size: Batch size of the training loader.
            patience: Forwarded to every ``EarlyStopper``.
            delta: Forwarded to every ``EarlyStopper``.
            seed: Forwarded to ``SplitTrainTest`` and ``ModelLoader``.
            device: Device handed to datasets, model loader and evaluator.
            testing: In raw mode, keep only the first two train/validation
                files; in features mode, forwarded to the datasets.

        Raises:
            ValueError: If ``data_type`` is neither ``"raw"`` nor ``"features"``.
        """
        self.device = device
        self.model_name = model_name
        # A fresh list per instance when the caller relies on the defaults.
        self.val_methods = list(self._DEFAULT_VAL_METHODS) if val_methods is None else val_methods
        self.data_type = data_type
        self.learning_rate = learning_rate
        self.loss_name = loss_name
        self.datasets_names = list(self._DEFAULT_DATASETS_NAMES) if datasets_names is None else datasets_names
        self.train_ratio = train_ratio
        self.epochs = epochs
        self.batch_size = batch_size
        self.patience = patience
        self.delta = delta
        self.seed = seed
        self.testing = testing
        self.path_save_dir = path_save_dir

        if self.data_type == "features":
            self._setup_features()
        elif self.data_type == "raw":
            self._setup_raw()
        else:
            raise ValueError(f"Unknown data_type spotted : {self.data_type}, try with 'raw' or 'features'")

        self.model_params = ParamsLoader(self.model_name).load_params()

        # Load the model of interest
        self.model, self.model_type = ModelLoader(self.model_name, self.model_params, seed=self.seed, device=self.device).load_model()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, betas=(0.9, 0.98), eps=1e-9)
        self.loss_fn = LossLoader(loss_name=self.loss_name).get_loss()
        self.evaluator = EvaluaTorch(self.model, self.train_loader, self.val_loader, self.optimizer, self.loss_fn, self.loss_name, self.device)

    @staticmethod
    def _median(scores):
        """Return the median of ``scores`` for logging purposes.

        The implementation is chosen from the type of the stored values
        rather than from the configured device: comparing a ``torch.device``
        with the string ``'cpu'`` is always unequal, so a device-based switch
        cannot be relied upon. Tensors go through ``torch.median``; anything
        else goes through ``np.median``.
        """
        if torch.is_tensor(scores):
            return torch.median(scores)
        scores = list(scores)
        if scores and all(torch.is_tensor(score) for score in scores):
            return torch.median(torch.stack(scores))
        return np.median(scores)

    def _init_tracking(self, res_dict_stopping: dict, store_params: bool) -> None:
        """Create stoppers and result dictionaries for the tracked metrics.

        Args:
            res_dict_stopping: Maps each tracked metric name to its (empty)
                history list.
            store_params: When true, ``res_dict_models`` also receives a
                ``params_best_model_<metric>`` entry per tracked metric.
        """
        self.stoppers = {key: EarlyStopper(self.patience, self.delta) for key in res_dict_stopping}

        # The plots dictionary tracks the same histories plus the train loss.
        self.res_dict_plots = res_dict_stopping | {f"train_loss_{self.loss_name}": []}
        self.res_dict_models = {f"best_model_{key}": self.epochs for key in res_dict_stopping}
        if store_params:
            self.res_dict_models |= {f"params_best_model_{key}": {} for key in res_dict_stopping}

    def _setup_features(self) -> None:
        """Build result dictionaries, datasets and loaders for features mode."""
        # A single validation loss is tracked in this mode; it is created
        # unconditionally because evaluate_feat always appends to it (it used
        # to be missing when val_methods was empty).
        tracked = {f"val_loss_{self.loss_name}": []} | {f"{metric}_box": [] for metric in metrics}
        self._init_tracking(tracked, store_params=True)

        # paths imported from config.py
        if "supervised" in self.datasets_names:
            path_feat_train = path_feat_train_supervised
            path_feat_val = path_feat_val_supervised
        else:
            path_feat_train = path_feat_train_ood.format(self.datasets_names[0])
            path_feat_val = path_feat_val_ood.format(self.datasets_names[0])

        # Builds a Pytorch Dataset object for batches purpose
        self.train = TimeseriesFeaturesDataset(model=self.model_name, path_features=path_feat_train, path_data=path_data, path_scores=path_scores, device=self.device, testing=self.testing)
        self.val = TimeseriesFeaturesDataset(model=self.model_name, path_features=path_feat_val, path_data=path_data, path_scores=path_scores, device=self.device, testing=self.testing)

        # Builds a Pytorch Dataloader object to load batches from the Dataset objects
        self.train_loader = DataLoader(self.train, batch_size=self.batch_size, shuffle=True, collate_fn=TimeseriesFeaturesDataset.collate_fn)
        self.val_loader = DataLoader(self.val, batch_size=1, shuffle=False, collate_fn=TimeseriesFeaturesDataset.collate_fn)

        print(f"len train : {len(self.train)}")
        print(f"len val : {len(self.val)}")

    def _setup_raw(self) -> None:
        """Build result dictionaries, file splits and train loader for raw mode."""
        self.window_size = int(self.model_name.split("_")[2])  # ex : "cnn_len_256" -> window_size = 256

        # One validation loss per method and one box per (metric, method) pair.
        tracked = {f"val_loss_{self.loss_name}_{method}": [] for method in self.val_methods} | {f"{metric}_box_{method}": [] for (metric, method) in product(metrics, self.val_methods)}
        self._init_tracking(tracked, store_params=False)

        # Pick the splits file (paths imported from config.py).
        if "supervised" in self.datasets_names:
            splits_file = path_splits
        else:
            splits_file = path_splits_ood.format(self.datasets_names[0])

        # lists of names of .csv files containing all windows for 1 ts
        splitter = SplitTrainTest(path_features=path_features, datasets_names=self.datasets_names, train_ratio=self.train_ratio, seed=self.seed)
        # fnames_test is always stored so the attribute exists whether or not
        # the runner is in testing mode.
        self.fnames_train, self.fnames_val, self.fnames_test = splitter.create_splits_raw(read_from_file=splits_file)

        print(f"self.fnames_train : {len(self.fnames_train)}")
        print(f"self.fnames_val : {len(self.fnames_val)}")

        if self.testing:
            # Keep the run short: only the first two files of each split.
            self.fnames_train = self.fnames_train[:2]
            self.fnames_val = self.fnames_val[:2]
            print(f"self.fnames_train : {self.fnames_train}")
            print(f"self.fnames_val : {self.fnames_val}")
            print(f"fnames_test : {self.fnames_test}")

        # Builds a Pytorch Dataset object for batches purpose
        self.train = TimeseriesDataset(fnames_ts=self.fnames_train, path_data=path_data, path_raw=path_raw, path_scores=path_scores, window_size=self.window_size, device=self.device)

        # Builds a Pytorch Dataloader object to load batches from the Dataset objects
        self.train_loader = DataLoader(self.train, batch_size=self.batch_size, shuffle=True, collate_fn=TimeseriesDataset.collate_fn)
        # Validation in raw mode works from self.fnames_val, not from a loader.
        self.val_loader = None

    def evaluate_feat(self):
        """Run the train/validation epochs in features mode.

        Each epoch appends the train loss, validation loss and AUC-PR box to
        ``res_dict_plots`` and then calls ``early_stop_check_all``. The loop
        itself never breaks: all ``self.epochs`` epochs are executed here.

        Returns:
            The tuple ``(res_dict_plots, res_dict_models, stoppers)``.
        """
        # Keys are built once, outside the f-strings below: reusing the
        # enclosing quote character inside an f-string is a SyntaxError
        # before Python 3.12.
        train_key = f"train_loss_{self.loss_name}"
        val_key = f"val_loss_{self.loss_name}"

        tic_eval = time.time()
        for k in range(self.epochs):

            tic_epoch = time.time()
            self.res_dict_plots[train_key].append(self.evaluator.train_feat())
            print(f"Training 1 epoch took {time.time()-tic_epoch} seconds")

            tic_epoch = time.time()
            val_loss, auc_pr_box = self.evaluator.validation_feat()
            print(f"Validating 1 epoch took {time.time()-tic_epoch} seconds")

            self.res_dict_plots[val_key].append(val_loss)
            self.res_dict_plots["AUC-PR_box"].append(auc_pr_box)

            # stop the activity of the stopper associated to a metric if earlystopping is required
            early_stop_check_all(model=self.model,
                                 epoch=k+1,
                                 stoppers=self.stoppers,
                                 res_dict_plots=self.res_dict_plots,
                                 res_dict_models=self.res_dict_models,
                                 path_save_dir=self.path_save_dir)

            last_box = self.res_dict_plots['AUC-PR_box'][-1]
            print(last_box)
            train_loss = self.res_dict_plots[train_key][-1]
            last_val_loss = self.res_dict_plots[val_key][-1]
            print(f"Epoch {k+1}/{self.epochs} : Train Loss: {train_loss:>8f} | Val Loss: {last_val_loss:>8f} | AUC PR last median: {self._median(last_box):>8f}")

        print(f"Total evaluation time: {time.time()-tic_eval} seconds")
        return self.res_dict_plots, self.res_dict_models, self.stoppers

    def evaluate_cnn(self):
        """Run the train/validation epochs in raw mode.

        ``train_cnn`` and ``validation_cnn`` receive ``res_dict_plots`` and
        are expected to fill it themselves (the values are only read back
        here for logging). ``early_stop_check_all`` is called after every
        epoch; the loop itself never breaks.

        Returns:
            The tuple ``(res_dict_plots, res_dict_models, stoppers)``.
        """
        train_key = f"train_loss_{self.loss_name}"

        tic_eval = time.time()
        for k in range(self.epochs):

            tic_epoch = time.time()
            self.evaluator.train_cnn(self.res_dict_plots)
            print(f"Training 1 epoch took {time.time()-tic_epoch} seconds")

            tic_epoch = time.time()
            self.evaluator.validation_cnn(self.fnames_val, self.window_size, self.res_dict_plots, self.val_methods)
            print(f"Validating 1 epoch took {time.time()-tic_epoch} seconds")

            # stop the activity of the stopper associated to a metric if earlystopping is required
            early_stop_check_all(model=self.model,
                                 epoch=k+1,
                                 stoppers=self.stoppers,
                                 res_dict_plots=self.res_dict_plots,
                                 res_dict_models=self.res_dict_models,
                                 path_save_dir=self.path_save_dir)

            for method in self.val_methods:
                train_loss = self.res_dict_plots[train_key][-1]
                val_loss = self.res_dict_plots['val_loss_{}_{}'.format(self.loss_name, method)][-1]
                last_box = self.res_dict_plots['AUC-PR_box_{}'.format(method)][-1]
                print(f"Epoch {k+1}/{self.epochs} : Train Loss: {train_loss:>8f} ")
                # The label always names the actual method (it used to be the
                # fixed word "shift" on the CPU path).
                print(f"|Val Loss {method}: {val_loss:>8f} | AUC PR last median: {self._median(last_box):>8f}")

        print(f"Total evaluation time: {time.time()-tic_eval} seconds")
        return self.res_dict_plots, self.res_dict_models, self.stoppers

    def evaluate(self):
        """Evaluate the model according to its data type.

        Returns:
            The result of ``evaluate_feat`` (``"features"``) or
            ``evaluate_cnn`` (``"raw"``).

        Raises:
            ValueError: If ``data_type`` is neither ``"raw"`` nor ``"features"``.
        """
        if self.data_type == "features":
            evaluation = self.evaluate_feat()
        elif self.data_type == "raw":
            evaluation = self.evaluate_cnn()
        else:
            raise ValueError(f"Unknown data type: {self.data_type}. Accepted types are 'raw' and 'features'.")

        return evaluation


