from email.policy import default
import d3rlpy
import argparse
from tqdm import tqdm
from gym.spaces import Box
import numpy as np
import policy
import numpy as np
import gym

import random
import os

import warnings 
warnings.filterwarnings("ignore")

def gen(args):
    """Roll out a policy in a gym environment and save the transitions as ``.npz``.

    For every episode a fresh seed is drawn from ``random`` (itself seeded with
    ``args.seed``) and applied to the environment, its action space and the
    policy. All transitions of all episodes are concatenated and written to
    ``data_new/<policy>/<env_type>/data.npz`` with the keys ``observations``,
    ``actions``, ``rewards``, ``terminals`` and ``timeouts``. ``env_type`` is
    the second dash-separated field of ``args.dataset_name``, with ``-dense``
    appended when the dataset name contains ``dense``.

    Args:
        args: Parsed command-line namespace. Reads ``dataset_name``,
            ``policy``, ``policy_file``, ``seed``, ``episodes`` and ``render``.
    """
    print('env_name:', args.dataset_name)
    print('policy : ', args.policy)
    env = gym.make(args.dataset_name)
    env.seed(args.seed)

    print(env.env.str_maze_spec)

    # Seeds the generator that hands out the per-episode seeds below.
    random.seed(args.seed)

    robot = policy.get_policy(args.policy_file, env, args.policy)

    state, action, reward, terminal, timeout = [], [], [], [], []
    total_reward = 0.0
    seed_list = []

    # NOTE(review): render is called once before the first reset and never
    # inside the step loop -- confirm that this is the intended behaviour.
    if args.render:
        env.render()

    for _ in tqdm(range(args.episodes), desc='生成进度'):
        random_seed = random.randint(2000, 19260817)
        seed_list.append(random_seed)

        env.seed(random_seed)
        env.action_space.seed(random_seed)

        s, done = env.reset(), False

        robot.set_seed(random_seed)
        robot.reset()

        while not done:
            a = robot.sample(s, env.get_target())
            s_, r, done, _ = env.step(a)
            state.append(tuple(s))
            action.append(tuple(a))
            reward.append(r)
            # The env's done flag is stored as a timeout; terminals stay 0.
            # Presumably episodes here only end on the time limit -- confirm.
            terminal.append(0)
            timeout.append(done)
            s = s_
            total_reward += r

    observations = np.array(state)
    actions = np.array(action)
    rewards = np.array(reward)
    terminals = np.array(terminal)
    timeouts = np.array(timeout)

    env_type = args.dataset_name.split("-")[1]
    if "dense" in args.dataset_name:
        env_type = env_type + "-" + "dense"

    # makedirs creates every missing intermediate directory, and exist_ok
    # avoids the check-then-create race of a separate existence test.
    dir_name = os.path.join("data_new", args.policy, env_type)
    os.makedirs(dir_name, exist_ok=True)

    file_name = os.path.join(dir_name, "data.npz")

    np.savez(file_name,
             observations=observations,
             actions=actions,
             rewards=rewards,
             terminals=terminals,
             timeouts=timeouts)

    # Guard the average so that running with zero episodes does not crash
    # after the file has already been written.
    mean_reward = total_reward / args.episodes if args.episodes else 0.0
    print('mean reward :', mean_reward)

    print("save     >>>     ", file_name)

    print("seed list: ", seed_list)



def replay(args):
    """Replay the actions of a saved dataset in a freshly created environment.

    Loads ``data_new/<policy>/<env_type>/data.npz``, where ``env_type`` is the
    second dash-separated field of ``args.dataset_name`` with ``-dense``
    appended when the dataset name contains ``dense``. For each of
    ``args.episodes`` passes, the environment state is set from the first
    stored observation and every stored action is stepped in order.

    Args:
        args: Parsed command-line namespace. Reads ``dataset_name``,
            ``policy``, ``episodes`` and ``render``.
    """
    env = gym.make(args.dataset_name)
    env.reset()

    # Dense datasets live in a "<maze>-dense" directory; without the suffix a
    # dense dataset name would resolve to the non-dense directory.
    env_type = args.dataset_name.split("-")[1]
    if "dense" in args.dataset_name:
        env_type = env_type + "-" + "dense"

    file_name = os.path.join("data_new", args.policy, env_type, "data.npz")

    # An .npz archive reads an array from disk on every key access, so pull
    # the arrays out once instead of indexing the archive inside the step
    # loop. The context manager closes the underlying file handle.
    with np.load(file_name) as data:
        observations = data['observations']
        actions = data['actions']

    # NOTE(review): every pass replays the whole dataset from the first
    # observation and does not re-set the state at episode boundaries --
    # confirm whether per-episode replay (using 'timeouts') was intended.
    for _ in tqdm(range(args.episodes), desc='回放进度'):
        start = observations[0]
        # Assumes the observation layout is (qpos[0:2], qvel[0:2]) -- TODO confirm.
        env.set_state(np.array([start[0], start[1]]),
                      np.array([start[2], start[3]]))
        if args.render:
            env.render()
        for step in range(len(observations)):
            env.step(actions[step])
            if args.render:
                env.render()


# Usage examples. gen() writes to data_new/<policy>/<maze>[-dense]/data.npz.
# There is no --file option (argparse rejects it); --steps is accepted but is
# not read by gen().
#
# sac
# python generator.py --dataset_name maze2d-umaze-v1 --episodes 20000 --steps 300 --seed 0 --policy sac --policy_file checkpoint/sac_umaze.pt
# python generator.py --dataset_name maze2d-medium-v1 --episodes 20000 --steps 600 --seed 0 --policy sac --policy_file checkpoint/sac_medium.pt
# python generator.py --dataset_name maze2d-large-v1 --episodes 20000 --steps 800 --seed 0 --policy sac --policy_file checkpoint/sac_large.pt
#
# expert
# python generator.py --dataset_name maze2d-umaze-v1 --episodes 20000 --steps 300 --seed 0 --policy expert --render
# python generator.py --dataset_name maze2d-medium-v1 --episodes 20000 --steps 600 --seed 0 --policy expert
# python generator.py --dataset_name maze2d-large-v1 --episodes 20000 --steps 800 --seed 0 --policy expert

# random
# python generator.py --dataset_name maze2d-umaze-v1 --episodes 20000 --steps 300 --seed 0 --policy random --render
# python generator.py --dataset_name maze2d-medium-v1 --episodes 20000 --steps 600 --seed 0 --policy random
# python generator.py --dataset_name maze2d-large-v1 --episodes 20000 --steps 800 --seed 0 --policy random
#
# replay (pass --policy so the matching data_new/<policy>/... file is loaded)
# python generator.py --replay --render --dataset_name maze2d-umaze-v1 --episodes 20000 --policy sac
# python generator.py --replay --render --dataset_name maze2d-medium-v1 --episodes 20000 --policy sac
# python generator.py --replay --render --dataset_name maze2d-large-v1 --episodes 20000 --policy sac



# python generator.py --dataset_name maze2d-umaze-v1 --policy random --episodes 12000 --steps 300
# python generator.py --dataset_name maze2d-umaze-v1 --policy expert

# python generator.py --dataset_name maze2d-medium-v1  --policy random --episodes 12000 --steps 600
# python generator.py --dataset_name maze2d-medium-v1 --policy expert

# python generator.py --dataset_name maze2d-large-v1 --policy random --episodes 12000 --steps 800
# python generator.py --dataset_name maze2d-large-v1 --policy expert




if __name__ == '__main__':
    # Command-line entry point: generate a dataset by default, or replay a
    # previously saved one when --replay is given.
    parser = argparse.ArgumentParser("A Library for myself")
    # Required: both modes pass this name straight to gym.make().
    parser.add_argument('--dataset_name', type=str, required=True)
    parser.add_argument('--episodes', type=int, default=100)
    # Accepted so existing command lines keep working; neither mode reads it.
    parser.add_argument('--steps', type=int, default=300)
    parser.add_argument('--seed', type=int, default=19260817)
    parser.add_argument('--policy', type=str, default='random')
    parser.add_argument('--policy_file', type=str, default='checkpoint/sac_umaze.pt')
    parser.add_argument('--replay', action='store_true', default=False)
    parser.add_argument('--render', action='store_true', default=False)
    args = parser.parse_args()

    if args.replay:
        replay(args)
    else:
        gen(args)
    
