import gymnasium as gym
import numpy as np
from tqdm import tqdm

import sys
from pathlib import Path
parent_folder = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_folder))

from env.mrp import MRP, Finite_MRP
from utils import load_dict, get_path_without_extension

sys.path.remove(str(parent_folder))

##############################################################################

def create_dataset_transitions(path_mrp:str="env/config/env.json", 
                               num_transitions:int=10_000, 
                               get_state_id:bool=False,
                               T:int=None,
                               sigma_reward:float=None,
                               save:bool=True,
                               verbose:bool=True, 
                               seed:int=None)->dict:
    """Build an MRP from a config file and sample a dataset of transitions.

    The config is read with ``load_dict``; its ``"V"`` entry (if any) is
    discarded and every remaining entry is converted to a NumPy array before
    being passed as keyword arguments to the MRP constructor.

    Args:
        path_mrp: Path of the config file describing the MRP. It must contain
            a ``"P"`` entry, whose first dimension is taken as the number of
            states.
        num_transitions: Forwarded unchanged to ``create_datatset``.
        get_state_id: When true, ``"state_labels"`` is set to a column vector
            ``[[0], [1], ..., [num_states - 1]]``, replacing any value loaded
            from the config.
        T: When ``None`` an ``MRP`` is built; otherwise ``T`` is added to the
            constructor arguments and a ``Finite_MRP`` is built instead.
        sigma_reward: Forwarded unchanged to ``create_datatset``.
        save: When true, the output path passed to ``create_datatset`` is
            ``path_mrp`` without its extension plus ``"_dataset"``; otherwise
            ``None`` is passed.
        verbose: Forwarded unchanged to ``create_datatset``.
        seed: Forwarded unchanged to ``create_datatset``.

    Returns:
        Whatever ``create_datatset`` returns for the constructed MRP.
    """
    mrp_kwargs = load_dict(path_mrp)

    # "V" is not a constructor argument here, so it is dropped when present.
    mrp_kwargs.pop("V", None)

    # Convert every loaded entry to an ndarray, keeping the same dict object.
    mrp_kwargs.update(
        {name: np.array(value) for name, value in mrp_kwargs.items()}
    )

    # Read unconditionally so a config without "P" fails here in all cases.
    n_states = mrp_kwargs["P"].shape[0]
    if get_state_id:
        # Shape (n_states, 1): one single-element label per state.
        mrp_kwargs["state_labels"] = np.arange(n_states)[:, np.newaxis]

    if T is not None:
        mrp_kwargs["T"] = T
        process = Finite_MRP(**mrp_kwargs)
    else:
        process = MRP(**mrp_kwargs)

    output_path = (
        get_path_without_extension(path_mrp) + "_dataset" if save else None
    )

    # NOTE: "create_datatset" (sic) is the method name as called in this file.
    return process.create_datatset(
        path=output_path,
        num_transitions=num_transitions,
        sigma_reward=sigma_reward,
        verbose=verbose,
        seed=seed,
    )
    
##############################################################################

if __name__ == "__main__":
    # Manual smoke run: sample a small dataset from the MRP config below,
    # without saving it to disk, and print the sampled states.
    path = "res/new_taxi/env.json"
    seed = 42
    data = create_dataset_transitions(
        path_mrp=path,
        num_transitions=500,
        verbose=True,
        seed=seed,  # pass the variable so editing it above actually takes effect
        get_state_id=True,
        save=False,
    )

    # The returned dataset is also expected to carry "next_state" and "done"
    # entries; only "state" is inspected here.
    dataset_states = np.array(data["state"])

    # With get_state_id=True the state labels are single-element rows, so
    # squeeze() presumably yields one id per transition -- TODO confirm
    # against the MRP implementation.
    print(dataset_states.squeeze())

    