import os
import numpy as np
import multiprocessing as mp
from tqdm import tqdm
from copy import deepcopy

from env.create_dataset import create_dataset_transitions
from train import train_LSTD_offline
from utils import load_dict, get_parent_directory, get_path_without_extension

################################################################

def compute_errors_params(path: str,
                          path_mrp: str,
                          path_dataset: str,
                          params_var: np.ndarray,
                          params_lstd: dict,
                          num_important_visited_states: int,
                          num_training: int = 10,
                          num_jobs: int = 1,
                          verbose: bool = False,
                          save_models: bool = True) -> None:
    """Run ``train_LSTD_offline`` once per value of a swept hyper-parameter.

    The swept hyper-parameter is deduced from which keys ``params_lstd``
    contains (first matching rule wins):

    * ``"ratio"`` and ``"weight_decay"``  -> ``"gamma"`` is swept;
    * ``"gamma"`` and ``"weight_decay"``  -> ``"num_features"`` is swept, each
      entry of ``params_var`` being a ratio that is multiplied by
      ``num_important_visited_states`` and rounded to an integer;
    * ``"ratio"`` and ``"gamma"``         -> ``"weight_decay"`` is swept.

    Whenever ``"ratio"`` is a fixed entry, it is replaced in the working copy
    by ``"num_features" = round(ratio * num_important_visited_states)``.

    Args:
        path: Prefix of the per-run path; ``str(params_var[i]) + "/"`` is
            appended for the i-th run and forwarded as ``path``.
        path_mrp: Forwarded unchanged to ``train_LSTD_offline``.
        path_dataset: Forwarded unchanged to ``train_LSTD_offline``.
        params_var: Values of the swept hyper-parameter (ratios when
            ``"num_features"`` is swept).
        params_lstd: Fixed LSTD parameters. The caller's dict is never
            modified; a deep copy is edited instead.
        num_important_visited_states: Scale converting a ratio into a number
            of features.
        num_training: Forwarded unchanged to ``train_LSTD_offline``.
        num_jobs: Forwarded unchanged to ``train_LSTD_offline``.
        verbose: When true, the sweep is wrapped in a ``tqdm`` progress bar.
        save_models: Forwarded unchanged to ``train_LSTD_offline``.

    Raises:
        ValueError: If ``params_lstd`` matches none of the three key
            combinations listed above.
    """
    # Work on copies so that neither caller-owned object is altered.
    run_params = deepcopy(params_lstd)
    swept_values = np.copy(params_var)

    has_ratio = "ratio" in params_lstd
    has_gamma = "gamma" in params_lstd
    has_decay = "weight_decay" in params_lstd

    # The rule order below defines precedence when all three keys are present.
    if has_ratio and has_decay:
        swept_key = "gamma"
    elif has_gamma and has_decay:
        swept_key = "num_features"
    elif has_ratio and has_gamma:
        swept_key = "weight_decay"
    else:
        raise ValueError("No parameters tested in compute_errors_params")

    if swept_key == "num_features":
        # The swept values are ratios: turn them into integer feature counts.
        scaled = swept_values * num_important_visited_states
        swept_values = np.round(scaled).astype(int)
    else:
        # The ratio is fixed: turn it into a single integer feature count.
        feature_count = np.round(params_lstd["ratio"] * num_important_visited_states)
        run_params["num_features"] = feature_count.astype(int)
        del run_params["ratio"]

    run_indices = range(swept_values.shape[0])
    if verbose:
        run_indices = tqdm(run_indices, desc="Processing training")

    for idx in run_indices:
        run_params[swept_key] = swept_values[idx]
        # The sub-path is labelled with the *original* value from params_var
        # (e.g. the ratio), not with the converted one.
        run_path = path + str(params_var[idx]) + "/"
        train_LSTD_offline(
            path=run_path,
            path_dataset=path_dataset,
            path_mrp=path_mrp,
            params_lstd=run_params,
            num_training=num_training,
            compute_V=False,
            save_models=save_models,
            num_jobs=num_jobs,
        )

################################################################

from post_training import get_num_important_visited_states, compute_msbe_train_from_logs
from plot_msbe import collect_msbe_dataset_size_params, plot_msbe_test_and_train_dataset_size_params
from MSBE_th import predict_msbe_th_params
    
if __name__ == "__main__":
    # End-to-end experiment script: run the LSTD sweeps, derive the empirical
    # MSBE from the logs, compute the theoretical predictions, then plot.
    print("Number of processes on the device: ", mp.cpu_count())

    # ---- Shared configuration ----
    path_mrp = "res/toy_ergodic_mrp/env.json"
    list_weights_decay = [10**(-9), 10**(-6), 10**(-3)]
    list_dataset_size = [5_000]
    params_lstd = dict(activation_fun="ReLU", gamma=0.95)
    seed = 42

    ####### Numerical results for the true MSBE with regularized recursive LSTD #######

    ratio_max = 2.0
    delta = 0.05
    # NOTE(review): a float-step arange can produce values such as
    # 0.15000000000000002; compute_errors_params puts str(value) in the run
    # path, so confirm downstream readers tolerate such names.
    ratio = np.arange(delta, ratio_max + delta, delta)
    num_training = 30
    num_jobs = 1
    path_dataset = get_path_without_extension(path_mrp) + "_dataset.npz"

    for num_transitions in list_dataset_size:
        # Both helpers receive exactly the same keyword arguments.
        sampling_kwargs = dict(path_mrp=path_mrp,
                               num_transitions=num_transitions,
                               verbose=False,
                               seed=seed)
        num_important_visited_states = get_num_important_visited_states(**sampling_kwargs)
        create_dataset_transitions(**sampling_kwargs)

        for decay in list_weights_decay:
            # The weight decay handed to LSTD is the base value scaled by the
            # dataset size and by the number of important visited states.
            params_lstd["weight_decay"] = decay*num_transitions*num_important_visited_states
            # Result paths are labelled with the unscaled decay value.
            path = f"{get_parent_directory(path_mrp)}dataset_{num_transitions}/{decay}/"
            compute_errors_params(
                path=path,
                path_mrp=path_mrp,
                path_dataset=path_dataset,
                params_var=ratio,
                params_lstd=params_lstd,
                num_important_visited_states=num_important_visited_states,
                num_training=num_training,
                num_jobs=num_jobs,
                verbose=False,
                save_models=False,
            )

    ####### Numerical results for empirical MSBE (obtained offline) with regularized recursive LSTD #######

    compute_msbe_train_from_logs(path_mrp=path_mrp,
                                 discount_factor=params_lstd["gamma"],
                                 seed=seed,
                                 verbose=False)

    ####### Theoretical results: first for the empirical MSBE, then for the true MSBE #######

    # The theoretical predictions use a finer ratio grid than the experiments.
    delta = 0.01
    ratio = np.arange(delta, ratio_max + delta, delta)
    theory_passes = (
        ("MSBE_training_th/", True),   # empirical (training) MSBE
        ("MSBE_test_th/", False),      # true (test) MSBE
    )
    for theory_subdir, is_training in theory_passes:
        path_msbe = get_parent_directory(path_mrp) + theory_subdir
        for num_transitions in list_dataset_size:
            for weight_decay in list_weights_decay:
                path_dir_logs = f"{path_msbe}dataset_{num_transitions}/{weight_decay}/"
                predict_msbe_th_params(
                    path_dir_logs=path_dir_logs,
                    path_mrp=path_mrp,
                    gamma=params_lstd["gamma"],
                    ratio=ratio,
                    weight_decay=weight_decay,
                    num_transitions=num_transitions,
                    training=is_training,
                    seed=seed,
                )

    ####### Plot the results #######

    path = get_parent_directory(path_mrp)
    params_var = dict(subfig="weight_decay", axis="ratio")
    msbe_test_dir = "MSBE_test_th/"
    msbe_train_dir = "MSBE_training_th/"
    msbe = dict(train={}, test={})

    # Collection order: train/exp, train/th, test/exp, test/th.
    splits = (
        ("train", True, msbe_train_dir),
        ("test", False, msbe_test_dir),
    )
    for split_name, is_training, theory_dir in splits:
        sources = (
            ("exp", True, path),               # experimental curves
            ("th", False, path + theory_dir),  # theoretical curves
        )
        for source_name, is_experimental, source_path in sources:
            msbe[split_name][source_name] = collect_msbe_dataset_size_params(
                path=source_path,
                params_var_str=params_var["axis"],
                msbe_exp=is_experimental,
                training=is_training,
                verbose=True,
            )

    plot_msbe_test_and_train_dataset_size_params(path=path,
                                                 msbe=msbe,
                                                 params_var=params_var,
                                                 subfig=True,
                                                 median=False,
                                                 verbose=True)
            