import os 
import torch 
import pandas as pd
import numpy as np
import gym
import tqdm
import pickle
import sys
import time
import random
from itertools import product
from tqdm import tqdm

from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

from src.features_dataset import TimeseriesFeaturesDataset
from src.create_split import SplitTrainTest
from src.env import EnvNewAPI_raw, EnvNewAPI_feat 
from src.timeseries_dataset import TimeseriesDataset
from src.validation import Validation

from src.policies.cnn_128 import ConvNet

from src.utils.early_stopping import EarlyStopper, early_stop_check_all
from src.utils.utils_runner import NamedEvalCallback
from stable_baselines3.common.callbacks import CallbackList

from src.utils.data_loader import DataLoader
from src.utils.scores_loader import ScoresLoader


from itertools import product

from config import *

# Echo the current working directory and register it on sys.path (only once)
# so modules living there can be imported afterwards.
# NOTE(review): this runs *after* the `src.*` and `config` imports above, so it
# cannot be what makes those imports resolve; move it above them if that was
# the intent.
print(path := os.getcwd())
if path not in sys.path:
    sys.path.append(path)
    
class Runner:
    """Train a PPO agent on time-series data and track validation metrics.

    ``data_type`` selects the input representation:

    * ``"features"`` - precomputed features loaded with
      ``TimeseriesFeaturesDataset``, trained in ``EnvNewAPI_feat`` with an
      ``"MlpPolicy"``;
    * ``"raw"`` - raw windows of length ``window_size`` loaded with
      ``TimeseriesDataset``, trained in ``EnvNewAPI_raw`` with a
      ``"CnnPolicy"`` whose feature extractor is ``ConvNet``.

    Dataset/split paths and the ``metrics`` list are read from ``config``
    (star-imported at module level).
    """

    # Used when the corresponding constructor argument is left as None
    # (avoids sharing one mutable default list between instances).
    DEFAULT_VAL_METHODS = ("first_window", "shift_windows", "pw_shift_windows", "pw_slide_windows")
    DEFAULT_DATASETS_NAMES = ("supervised",)

    def __init__(self,
                 path_save_dir: str,
                 model_name: str = "cnn_128",
                 data_type: str = "raw",
                 val_methods: "list[str] | None" = None,
                 learning_rate: float = 0.0003,
                 datasets_names: "list[str] | None" = None,
                 train_ratio: float = 0.7,
                 n_steps: int = 64,
                 epochs: int = 1,
                 total_timesteps: int = 1000,
                 n_val: int = 10,
                 batch_size: int = 8,
                 patience: int = 10,
                 delta: float = 0.01,
                 seed: int = 42,
                 device: torch.device = torch.device("cpu"),
                 testing: bool = False):
        """Load the data, build the training environment and the PPO model.

        Args:
            path_save_dir: directory handed to ``early_stop_check_all``.
            model_name: name used in the TensorBoard run name.
            data_type: ``"features"`` or ``"raw"``.
            val_methods: validation methods (raw data only); ``None`` means
                ``DEFAULT_VAL_METHODS``.
            learning_rate, n_steps, batch_size, epochs: PPO hyper-parameters
                (``epochs`` is PPO's ``n_epochs``).
            datasets_names: ``None`` means ``["supervised"]``. If
                ``"supervised"`` is absent, ``datasets_names[0]`` selects the
                out-of-distribution paths.
            train_ratio: forwarded to ``SplitTrainTest`` (raw data only).
            total_timesteps: total PPO timesteps, spread over ``n_val`` rounds.
            n_val: number of train/validate rounds performed by ``run``.
            patience, delta: early-stopping parameters.
            seed: random seed forwarded to the datasets and the split.
            device: device for the datasets / raw environment (not PPO).
            testing: smoke-test mode (raw data keeps 10 series per split).

        Raises:
            ValueError: if ``data_type`` is neither ``"raw"`` nor ``"features"``.

        TODO: avoid rebuilding the split CSV on every run - pass the split
        parameters as arguments, reuse the file when it exists, create it
        otherwise.
        """
        self.device = device
        self.model_name = model_name
        self.val_methods = list(self.DEFAULT_VAL_METHODS) if val_methods is None else val_methods
        self.data_type = data_type
        self.learning_rate = learning_rate
        self.datasets_names = list(self.DEFAULT_DATASETS_NAMES) if datasets_names is None else datasets_names
        self.train_ratio = train_ratio
        self.epochs = epochs
        self.n_steps = n_steps
        self.total_timesteps = total_timesteps
        self.n_val = n_val
        self.batch_size = batch_size
        self.patience = patience
        self.delta = delta
        self.seed = seed
        self.window_size = 128  # length of a raw window
        self.testing = testing
        self.path_save_dir = path_save_dir

        print(f"val_methods : {self.val_methods}")
        print(f"metrics : {metrics}")

        # One empty result list per tracked key; rejects unknown data types.
        res_dict_stopping, self.res_dict_plots = self._build_result_dicts(
            self.data_type, metrics, self.val_methods)

        if self.data_type == "features":
            policy, policy_kwargs = self._setup_features()
        else:  # "raw" - any other value was rejected above
            policy, policy_kwargs = self._setup_raw()

        # One early stopper per stopping strategy (one per validation key).
        self.stoppers = {strategy: EarlyStopper(patience, delta) for strategy in res_dict_stopping}

        # Best-model bookkeeping per strategy. The "best_model_*" entries start
        # at n_val - presumably meaning "no better round yet"; TODO confirm
        # against early_stop_check_all.
        self.res_dict_models = (
            {f"best_model_{strategy}": self.n_val for strategy in res_dict_stopping}
            | {f"params_best_model_{strategy}": {} for strategy in res_dict_stopping}
        )

        log_dir = "./ppo_tensorboard/"
        os.makedirs(log_dir, exist_ok=True)

        # Validate the env up front: errors here are far easier to debug than
        # failures in the middle of training.
        check_env(self.env_train)

        self.env_train = Monitor(self.env_train, log_dir)
        self.validation = Validation(device=self.device, testing=self.testing)

        # NOTE(review): PPO is pinned to CPU regardless of `device`; pass
        # `device=self.device` instead if GPU training is wanted.
        self.model = PPO(
            policy,
            self.env_train,
            verbose=1,
            learning_rate=learning_rate,
            n_steps=n_steps,           # rollout length per update
            batch_size=batch_size,
            n_epochs=epochs,           # passes over each rollout
            gamma=0.99,
            gae_lambda=0.95,           # GAE
            clip_range=0.2,
            ent_coef=0.01,             # encourage exploration
            vf_coef=0.5,               # critic loss weight
            max_grad_norm=0.5,
            tensorboard_log=log_dir,
            policy_kwargs=policy_kwargs,
            device='cpu')

        # Optionally override the bias of the layer that outputs the action means:
        # with torch.no_grad():
        #     self.model.policy.action_net.bias.data.fill_(0.5)

    @staticmethod
    def _build_result_dicts(data_type, metrics, val_methods):
        """Return ``(res_dict_stopping, res_dict_plots)`` for ``data_type``.

        Every value is a fresh empty list. ``res_dict_plots`` contains all keys
        of ``res_dict_stopping`` (validation keys, used for early stopping)
        followed by the train keys and, for features, the ``*_box_weights`` keys.

        Raises:
            ValueError: for a ``data_type`` other than ``"features"``/``"raw"``.
        """
        if data_type == "features":
            stopping = {f"val_{metric}_box": [] for metric in metrics}
            plots = (
                stopping
                | {f"train_{metric}_box": [] for metric in metrics}
                | {f"{eval_set}_{metric}_box_weights": []
                   for eval_set, metric in product(["train", "val"], metrics)}
            )
        elif data_type == "raw":
            stopping = {f"val_{metric}_box_{method}": []
                        for metric, method in product(metrics, val_methods)}
            plots = stopping | {f"train_{metric}_box_{method}": []
                                for metric, method in product(metrics, val_methods)}
        else:
            raise ValueError(f"Unknown data_type '{data_type}' spotted, try with 'raw' or 'features'")
        return stopping, plots

    def _setup_features(self):
        """Load the feature datasets and build the features environment.

        Sets ``self.train``, ``self.val`` and ``self.env_train``; returns the
        ``(policy, policy_kwargs)`` pair for PPO.
        """
        # Paths come from config.py.
        if "supervised" in self.datasets_names:
            path_feat_train = path_feat_train_supervised
            path_feat_val = path_feat_val_supervised
        else:
            path_feat_train = path_feat_train_ood.format(self.datasets_names[0])
            path_feat_val = path_feat_val_ood.format(self.datasets_names[0])

        self.train = TimeseriesFeaturesDataset(path_features=path_feat_train, path_data=path_data, path_scores=path_scores,
                                               seed=self.seed, device=self.device, testing=self.testing)
        self.val = TimeseriesFeaturesDataset(path_features=path_feat_val, path_data=path_data, path_scores=path_scores,
                                             seed=self.seed, device=self.device, testing=self.testing)

        print(f"len train {len(self.train)}")
        print(f"len val {len(self.val)}")

        self.env_train = EnvNewAPI_feat(data=self.train)
        # Alternative kept for reference: policy_kwargs = dict(log_std_init=np.log(0.25))
        return "MlpPolicy", None

    def _setup_raw(self):
        """Split the raw series, load the window datasets and build the raw env.

        Sets ``self.fnames_train/val/test``, ``self.train``, ``self.val`` and
        ``self.env_train``; returns the ``(policy, policy_kwargs)`` pair for PPO.
        """
        # Lists of names of the .csv files holding all windows of one series.
        if "supervised" in self.datasets_names:
            split_file = path_splits
        else:
            split_file = path_splits_ood.format(self.datasets_names[0])
        splitter = SplitTrainTest(path_features=path_features, datasets_names=self.datasets_names,
                                  train_ratio=self.train_ratio, seed=self.seed)
        self.fnames_train, self.fnames_val, self.fnames_test = splitter.create_splits_raw(read_from_file=split_file)

        print(f"self.fnames_train : {len(self.fnames_train)}")
        print(f"self.fnames_val : {len(self.fnames_val)}")

        if self.testing:
            # Smoke-test mode: keep only the first 10 series of train and val.
            self.fnames_train = self.fnames_train[:10]
            self.fnames_val = self.fnames_val[:10]

        self.train = TimeseriesDataset(fnames_ts=self.fnames_train, path_data=path_data, path_raw=path_raw, path_scores=path_scores,
                                       window_size=self.window_size, seed=self.seed, device=self.device)
        self.val = TimeseriesDataset(fnames_ts=self.fnames_val, path_data=path_data, path_raw=path_raw, path_scores=path_scores,
                                     window_size=self.window_size, seed=self.seed, device=self.device)

        policy_kwargs = dict(
            features_extractor_class=ConvNet,
            features_extractor_kwargs=dict(features_dim=12),
            # log_std_init=np.log(0.25)
        )
        self.env_train = EnvNewAPI_raw(window_size=self.window_size, data=self.train, device=self.device)
        return "CnnPolicy", policy_kwargs

    def run(self):
        """Alternate PPO training and evaluation for ``n_val`` rounds.

        Each round trains for ``total_timesteps // n_val`` timesteps (any
        remainder is not trained), evaluates the model on the train and
        validation sets, then calls ``early_stop_check_all`` with the
        accumulated results.

        Returns:
            tuple: ``(res_dict_plots, res_dict_models, stoppers)``.

        Raises:
            ValueError: if ``n_val`` is not positive.
        """
        if self.n_val <= 0:
            raise ValueError(f"n_val must be a positive integer, got {self.n_val}")

        tic_eval = time.perf_counter()
        self.timesteps_learning = self.total_timesteps // self.n_val
        # The TensorBoard run name is the same for every round; build it once.
        name_save = (
            f"ppo_{self.model_name}_nstep{self.n_steps}_epochs{self.epochs}"
            f"_bs{self.batch_size}_seed{self.seed}_nsteptot{self.total_timesteps}"
        )

        for i in range(self.n_val):
            # Training (reset_num_timesteps=False keeps the timestep counter going across rounds)
            tic_train = time.perf_counter()
            self.model.learn(total_timesteps=self.timesteps_learning,
                             tb_log_name=name_save,
                             reset_num_timesteps=False)

            print(f"Training for 1 validation took {time.perf_counter()-tic_train} seconds")

            # Validation on train and val
            tic_val = time.perf_counter()
            if self.data_type == "features":
                self.validation.evaluate_feat(model=self.model, eval_data=self.train, res_dict_plots=self.res_dict_plots, eval_set="train")
                self.validation.evaluate_feat(model=self.model, eval_data=self.val, res_dict_plots=self.res_dict_plots, eval_set="val")
            elif self.data_type == "raw":
                self.validation.evaluate_raw(model=self.model, fnames_eval=self.fnames_train, window_size=self.window_size,
                                             res_dict=self.res_dict_plots, val_methods=self.val_methods, eval_set="train")
                self.validation.evaluate_raw(model=self.model, fnames_eval=self.fnames_val, window_size=self.window_size,
                                             res_dict=self.res_dict_plots, val_methods=self.val_methods, eval_set="val")

            # Stop the stopper associated with a metric once early stopping is required.
            early_stop_check_all(model=self.model,
                                 epoch=i + 1,
                                 stoppers=self.stoppers,
                                 res_dict_plots=self.res_dict_plots,
                                 res_dict_models=self.res_dict_models,
                                 path_save_dir=self.path_save_dir)

            print(f"Validating 1 time took {time.perf_counter() - tic_val} seconds")

        if self.data_type == "raw":
            # self.env_train is the Monitor wrapper; .env is the raw environment.
            print(f" weights summed to zero this many times : {self.env_train.env.n_average}")

        print(f"Total evaluation time: {time.perf_counter() - tic_eval} seconds")
        return self.res_dict_plots, self.res_dict_models, self.stoppers

if __name__ == "__main__":
    # Runner requires `path_save_dir` (it has no default). Placeholder location:
    # adjust to wherever early-stopping artefacts should be written.
    path_save_dir = os.path.join(os.getcwd(), "results")
    os.makedirs(path_save_dir, exist_ok=True)

    runner = Runner(path_save_dir=path_save_dir,
                    model_name="polynomial_deg_1_bias_0",
                    data_type="features",
                    val_methods=["first_window", "shift_windows", "pw_shift_windows", "pw_slide_windows"],
                    learning_rate=0.01,
                    datasets_names=['supervised'],
                    train_ratio=0.7,
                    # NOTE(review): 1014 is not a multiple of batch_size=128, so every
                    # PPO epoch ends with a truncated mini-batch (SB3 warns about it);
                    # 1024 may have been intended - confirm before changing.
                    n_steps=1014,
                    epochs=5,
                    total_timesteps=800000,
                    batch_size=128,
                    patience=10,
                    delta=0.001,
                    seed=42,
                    device=torch.device("cpu"),
                    testing=False)

    runner.run()

    # print(runner.res_dict_plots[f"AUC-PR_box_shift_windows"])
