import os 
import torch 
import pandas as pd
import numpy as np
import gym
import tqdm
import pickle
import sys
import time
import random
import pygad
import pygad.torchga as torchga
from itertools import product
from tqdm import tqdm
from catch22 import catch22_all


from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from torchmetrics.classification import AveragePrecision

from src.features_dataset import TimeseriesFeaturesDataset
from src.create_split import SplitTrainTest 
from src.timeseries_dataset import TimeseriesDataset

from src.models.mlp_default import MLPDefault
from src.models.cnn_128 import ConvNet
from src.utils.utils_ga import EvaluatorGA, apply_wts_batch


from src.utils.early_stopping import EarlyStopper, early_stop_check_all
from stable_baselines3.common.callbacks import CallbackList


from itertools import product

from config import *

# Make the current working directory importable. This runs after this
# module's own imports above, so it only affects imports performed later.
path = os.getcwd()
print(path)
if not any(entry == path for entry in sys.path):
    sys.path.append(path)
    

class Runner:
    """Train model weights with a genetic algorithm (pygad) and collect metrics.

    Depending on ``data_type`` the runner builds either a feature-based setup
    (``"features"``: ``TimeseriesFeaturesDataset`` + ``MLPDefault``) or a raw
    time-series setup (``"raw"``: ``TimeseriesDataset`` + ``ConvNet``). It
    wires the chosen datasets and model into an ``EvaluatorGA`` and creates
    the ``pygad.GA`` instance that :meth:`run` executes.

    Several module-level names used here (``metrics``, ``path_data``,
    ``path_scores``, ``path_raw``, ``path_features``, ``path_splits``,
    ``path_splits_ood``, ``path_feat_*``) come from ``from config import *``.

    Args:
        path_save_dir: Directory forwarded to ``EvaluatorGA``.
        model_name: Stored on the instance. The model actually built is
            chosen by ``data_type``, not by this value.
        data_type: ``"raw"`` or ``"features"``. Anything else raises
            ``ValueError``.
        val_methods: Validation strategies. In ``"raw"`` mode one metric key
            is created per (metric, method) pair. The list is copied.
        datasets_names: If it contains ``"supervised"``, the supervised
            paths/splits are used. Otherwise the OOD ones, formatted with
            ``datasets_names[0]``, are used. The list is copied.
        train_ratio: Forwarded to ``SplitTrainTest`` (raw mode).
        num_generations: Number of GA generations. Also the initial value
            stored under every ``best_model_*`` key.
        num_parents_mating: pygad ``num_parents_mating``. Also forwarded to
            ``EvaluatorGA``.
        num_solutions: Forwarded to ``EvaluatorGA``.
        sample_size: Forwarded to ``EvaluatorGA``.
        patience: ``EarlyStopper`` patience, one stopper per validation metric key.
        delta: ``EarlyStopper`` delta.
        seed: Forwarded to datasets, splitter and evaluator.
        device: Torch device for datasets and model.
        testing: In raw mode, truncates the train/val filename lists to their
            first 10 entries. Also forwarded to ``EvaluatorGA`` and, in
            features mode, to the datasets.
    """

    def __init__(self,
                 path_save_dir: str,
                 model_name: str = "cnn_128",
                 data_type: str = "raw",
                 val_methods: list = ["first_window", "shift_windows", "pw_shift_windows"],
                 datasets_names: list = ['supervised'],
                 train_ratio: float = 0.7,
                 num_generations: int = 1,
                 num_parents_mating: int = 5,
                 num_solutions: int = 5,
                 sample_size: int = 32,
                 patience: int = 10,
                 delta: float = 0.01,
                 seed: int = 42,
                 device: torch.device = torch.device("cpu"),
                 testing: bool = False):
        """Store the configuration, then build data, model, evaluator and the GA.

        Raises:
            ValueError: if ``data_type`` is neither ``"raw"`` nor ``"features"``.
        """
        self.device = device
        self.model_name = model_name
        # Copy list arguments so the shared default lists are never aliased
        # by this instance (or by the objects it hands them to).
        self.val_methods = list(val_methods)
        self.data_type = data_type
        self.datasets_names = list(datasets_names)
        self.train_ratio = train_ratio
        self.num_generations = num_generations
        self.num_parents_mating = num_parents_mating
        self.sample_size = sample_size
        self.patience = patience
        self.delta = delta
        self.seed = seed
        self.testing = testing
        self.path_save_dir = path_save_dir

        self.num_solutions = num_solutions
        self.window_size = 128

        # Filename splits. They are only populated in "raw" mode but always
        # exist, so callers can inspect them without AttributeError.
        self.fnames_train = None
        self.fnames_val = None
        self.fnames_test = None

        # Datasets, populated by whichever setup branch matches data_type.
        self.train = None
        self.val = None

        print(f"val_methods : {self.val_methods}")
        print(f"metrics : {metrics}")

        if self.data_type == "features":
            evaluator, fitness_func_train = self._setup_features()
        elif self.data_type == "raw":
            evaluator, fitness_func_train = self._setup_raw()
        else:
            raise ValueError(f"Unknown data_type '{self.data_type}' spotted, try with 'raw' or 'features'")

        on_generation = evaluator.get_on_generation()
        self.initial_population = evaluator.initial_population
        # NOTE(review): self.seed is not passed to pygad.GA here. Unless it is
        # seeded elsewhere, GA selection/crossover/mutation may not be
        # reproducible. Consider random_seed=self.seed if the pygad version
        # supports it.
        self.ga_instance = pygad.GA(num_generations=self.num_generations,
                                    num_parents_mating=self.num_parents_mating,
                                    initial_population=self.initial_population,
                                    fitness_func=fitness_func_train,
                                    on_generation=on_generation)

    @staticmethod
    def _models_dict(strategies, num_generations):
        """Return the initial best-model bookkeeping dict for ``strategies``.

        Each strategy gets ``best_model_<strategy>`` set to ``num_generations``
        and ``params_best_model_<strategy>`` set to a fresh empty dict.
        All ``best_model_*`` keys come first, then all ``params_best_model_*`` keys.
        """
        strategies = list(strategies)  # allow one-shot iterables; iterated twice
        return ({f"best_model_{strategy}": num_generations for strategy in strategies}
                | {f"params_best_model_{strategy}": {} for strategy in strategies})

    def _init_tracking(self, res_dict_stopping):
        """Create one EarlyStopper and the best-model entries per stopping metric key."""
        # One early-stopping tracker per strategy (metric, and val_method in raw mode).
        self.stoppers = {strategy: EarlyStopper(self.patience, self.delta)
                         for strategy in res_dict_stopping}
        # Dictionary including all best-model results.
        self.res_dict_models = self._models_dict(res_dict_stopping, self.num_generations)

    def _build_evaluator(self):
        """Build the EvaluatorGA shared by both data types from the current state."""
        return EvaluatorGA(model=self.model,
                           data_type=self.data_type,
                           res_dict_plots=self.res_dict_plots,
                           res_dict_models=self.res_dict_models,
                           path_save_dir=self.path_save_dir,
                           stoppers=self.stoppers,
                           fnames_val=self.fnames_val,
                           fnames_train=self.fnames_train,
                           train_dataset=self.train,
                           val_dataset=self.val,
                           feat_train=self.train,
                           feat_val=self.val,
                           num_solutions=self.num_solutions,
                           sample_size=self.sample_size,
                           val_methods=self.val_methods,
                           num_parents_mating=self.num_parents_mating,
                           num_generations=self.num_generations,
                           device=self.device,
                           seed=self.seed,
                           testing=self.testing)

    def _setup_features(self):
        """Configure metrics, feature datasets, MLP model and evaluator.

        Returns:
            ``(evaluator, fitness_func_train)`` for the ``"features"`` data type.
        """
        res_dict_stopping = {f"val_{metric}_box": [] for metric in metrics}
        self.res_dict_plots = (
            res_dict_stopping
            | {f"train_{metric}_box": [] for metric in metrics}
            | {f"{eval_set}_{metric}_box_weights": []
               for eval_set, metric in product(["train", "val"], metrics)}
        )
        self._init_tracking(res_dict_stopping)

        # Feature file paths are imported from config.py.
        if "supervised" in self.datasets_names:
            path_feat_train = path_feat_train_supervised
            path_feat_val = path_feat_val_supervised
        else:
            path_feat_train = path_feat_train_ood.format(self.datasets_names[0])
            path_feat_val = path_feat_val_ood.format(self.datasets_names[0])

        self.train = TimeseriesFeaturesDataset(path_features=path_feat_train, path_data=path_data, path_scores=path_scores, seed=self.seed, device=self.device, testing=self.testing)
        self.val = TimeseriesFeaturesDataset(path_features=path_feat_val, path_data=path_data, path_scores=path_scores, seed=self.seed, device=self.device, testing=self.testing)

        # Default MLP of the RL setup. Placed on the same device as the
        # datasets, consistent with the ConvNet in raw mode.
        self.model = MLPDefault().to(self.device)

        evaluator = self._build_evaluator()
        return evaluator, evaluator.get_fitness_func_feat(eval_set="train", is_best_solution=False)

    def _setup_raw(self):
        """Configure metrics, filename splits, raw datasets, CNN model and evaluator.

        Returns:
            ``(evaluator, fitness_func_train)`` for the ``"raw"`` data type.
        """
        res_dict_stopping = {f"val_{metric}_box_{method}": []
                             for metric, method in product(metrics, self.val_methods)}
        self.res_dict_plots = res_dict_stopping | {f"train_{metric}_box_{method}": []
                                                   for metric, method in product(metrics, self.val_methods)}
        self._init_tracking(res_dict_stopping)

        # Lists of names of .csv files, each containing all windows for one time series.
        if "supervised" in self.datasets_names:
            splits_path = path_splits
        else:
            splits_path = path_splits_ood.format(self.datasets_names[0])
        splitter = SplitTrainTest(path_features=path_features, datasets_names=self.datasets_names, train_ratio=self.train_ratio, seed=self.seed)
        # The test split is always kept, not only in testing mode.
        self.fnames_train, self.fnames_val, self.fnames_test = splitter.create_splits_raw(read_from_file=splits_path)

        print(f"len self.fnames_train : {len(self.fnames_train)}\n")
        print(f"len self.fnames_val : {len(self.fnames_val)}\n")

        if self.testing:
            self.fnames_train = self.fnames_train[:10]
            self.fnames_val = self.fnames_val[:10]

        self.train = TimeseriesDataset(fnames_ts=self.fnames_train, path_data=path_data, path_raw=path_raw, path_scores=path_scores, window_size=self.window_size, seed=self.seed, device=self.device)
        self.val = TimeseriesDataset(fnames_ts=self.fnames_val, path_data=path_data, path_raw=path_raw, path_scores=path_scores, window_size=self.window_size, seed=self.seed, device=self.device)

        self.model = ConvNet().to(self.device)  # cnn default of RL

        evaluator = self._build_evaluator()
        return evaluator, evaluator.get_fitness_func_raw(eval_set="train", is_best_solution=False)

    def run(self):
        """Run the genetic algorithm (training + validation) and report wall time.

        Returns:
            Tuple ``(res_dict_plots, res_dict_models, stoppers)``: the objects
            built in ``__init__``. They were handed to EvaluatorGA and are
            presumably updated in place during the run (confirm in EvaluatorGA).
        """
        tic_run = time.perf_counter()

        # Training + Validating
        self.ga_instance.run()

        print(f"Total running time: {time.perf_counter() - tic_run} seconds")
        return self.res_dict_plots, self.res_dict_models, self.stoppers

if __name__ == "__main__":
    # This module only defines `Runner`; running it directly does nothing.
    pass