import numpy as np
from tqdm import tqdm

from utils import get_dir, is_integer, is_float, collect_files, get_parent_directory
from utils import load_dict, save_dict
from env.create_dataset import create_dataset_transitions
from env.mrp import MRP

################################################################

def compute_msbe_train_from_logs(path_mrp:str, 
                                 discount_factor:float, 
                                 seed:int=42, 
                                 verbose:bool=True):
    """Add the training mean-squared Bellman error to every logged run.

    The directory that contains ``path_mrp`` is expected to be laid out as
    ``dataset_<N>/<weight_decay>/<ratio>/`` where each run directory holds a
    ``models/`` folder of ``lstd_model_<k>_V_pred.npz`` files and an
    ``errors.json`` file. For every run, the value predictions are loaded, the
    quantity ``mean((r + discount_factor * V[s'] - V[s]) ** 2)`` is evaluated
    on the regenerated dataset of ``N`` transitions, and the resulting vector
    (one entry per model) is stored under the key ``"MSBE_train"`` of
    ``errors.json``, which is rewritten in place.

    Args:
        path_mrp: Path to the MRP description file; its parent directory is
            the root of the logged experiments.
        discount_factor: Discount factor used in the Bellman residual.
        seed: Seed forwarded to ``create_dataset_transitions``. Presumably it
            must match the seed used at training time so that the same
            transitions are regenerated -- TODO confirm.
        verbose: If true, print the discovered weight-decay directories and
            show a progress bar over dataset sizes.

    Returns:
        None. The function only updates the ``errors.json`` files on disk.
    """
    path = get_parent_directory(path_mrp)
    dataset_dirs = [d for d in get_dir(path) if "dataset_" in d]
    if not dataset_dirs:
        # Nothing has been logged next to the MRP file: nothing to update.
        return

    # NOTE: as in the original implementation, the list of weight-decay
    # directories is discovered from the first dataset directory only and is
    # assumed to be shared by all dataset sizes.
    weight_decay_dirs = [d for d in get_dir(path + dataset_dirs[0] + "/")
                         if is_integer(d) or is_float(d)]
    if verbose:
        print(weight_decay_dirs)

    progress = dataset_dirs
    if verbose:
        progress = tqdm(progress, desc="Processing dataset size")

    for dataset_dir in progress:
        # Directory names look like "dataset_<num_transitions>".
        num_transitions = int(dataset_dir.rsplit("_", 1)[-1])
        dataset = create_dataset_transitions(path_mrp=path_mrp,
                                             num_transitions=num_transitions,
                                             verbose=False,
                                             seed=seed,
                                             get_state_id=True,
                                             save=False)
        # State ids are used below as integer indices into V_pred.
        dataset_state = np.array(dataset["state"]).squeeze()
        dataset_next_state = np.array(dataset["next_state"]).squeeze()
        dataset_reward = np.array(dataset["reward"]).squeeze()

        for weight_decay in weight_decay_dirs:
            path_weight_decay = path + dataset_dir + "/" + weight_decay + "/"
            ratio_dirs = [d for d in get_dir(path_weight_decay) if is_float(d)]
            for ratio in ratio_dirs:
                path_run = path_weight_decay + ratio
                path_models = path_run + "/models/"
                path_errors = path_run + "/errors.json"

                # Models are addressed by index, so every .npz file in the
                # folder is assumed to follow the lstd_model_<k>_V_pred naming.
                num_models = len(collect_files(path=path_models, extension=".npz"))
                msbe_train = np.zeros((num_models,))
                for k in range(num_models):
                    filename = "lstd_model_" + str(k) + "_V_pred.npz"
                    # Close the archive as soon as the array has been read so
                    # that file handles do not accumulate across models.
                    with np.load(path_models + filename) as archive:
                        V_pred = archive["V_pred"]
                    bellman_residual = (dataset_reward
                                        + discount_factor * V_pred[dataset_next_state]
                                        - V_pred[dataset_state])
                    msbe_train[k] = np.mean(bellman_residual ** 2)

                errors = {key: np.array(value)
                          for key, value in load_dict(path_errors).items()}
                errors["MSBE_train"] = msbe_train
                save_dict(path=path_errors, data=errors)
                                       
################################################################
                
def get_num_important_visited_states(path_mrp:str, 
                                     num_transitions:int, 
                                     verbose:bool=True,
                                     seed:int=42)->int:
    """Count the "important" states that appear in a generated dataset.

    A state is treated as important when its entry in the stationary
    distribution returned by ``MRP.get_stationary_distribution`` is strictly
    positive. A state counts as visited when it occurs as either the
    ``"state"`` or the ``"next_state"`` of at least one transition produced by
    ``create_dataset_transitions``.

    Args:
        path_mrp: Path to the file holding the MRP parameters. Every entry is
            converted to a NumPy array and passed to ``MRP``; an optional
            ``"V"`` entry is dropped first.
        num_transitions: Number of transitions to generate.
        verbose: If true, print how many states (and how many important
            states) were never visited.
        seed: Seed forwarded to ``create_dataset_transitions``.

    Returns:
        The number of important states that were visited at least once
        (a NumPy integer, as produced by ``np.sum``).
    """
    mrp_kwargs = {name: np.array(value)
                  for name, value in load_dict(path_mrp).items()}
    # "V" is stored alongside the MRP parameters but is not passed to MRP.
    mrp_kwargs.pop("V", None)
    n_states = mrp_kwargs["P"].shape[0]

    stationary = MRP(**mrp_kwargs).get_stationary_distribution()
    important = np.argwhere(stationary > 0)[:, 0]
    n_important = important.shape[0]

    transitions = create_dataset_transitions(path_mrp=path_mrp,
                                             num_transitions=num_transitions,
                                             verbose=False,
                                             seed=seed,
                                             get_state_id=True,
                                             save=False)
    states = np.array(transitions["state"]).squeeze()
    next_states = np.array(transitions["next_state"]).squeeze()

    # 0/1 indicator per state id: 1 once the state shows up on either side
    # of a transition.
    seen = np.zeros(n_states, dtype=int)
    for state_ids in (states, next_states):
        seen[state_ids] = 1

    n_unseen = n_states - np.sum(seen)
    n_important_seen = np.sum(seen[important] == 1)
    n_important_unseen = n_important - n_important_seen

    if verbose:
        message = "".join(["For ", str(num_transitions),
                           " samples, number of unvisited states = ", str(n_unseen),
                           " and number of important unvisited states = ", str(n_important_unseen)])
        print(message)
    return n_important_seen

   
################################################################

# if __name__ == "__main__":
    # path_mrp="res/gridworld/env.json"
    # discount_factor=0.95
    # seed=42
    # verbose=True
    
    # compute_msbe_train_from_logs(path_mrp=path_mrp, 
    #                              discount_factor=discount_factor, 
    #                              seed=seed, 
    #                              verbose=verbose)