import os
import gym
from gym.envs.registration import register
import numpy as np
import random
from policy import Policy

import sys
sys.path.append('..')

import argparse
parser = argparse.ArgumentParser(description='seed lambda')
parser.add_argument('--seed', required=True, type=int, help='seed')
parser.add_argument('--lam', required=True, type=float, help='lambda')
args = parser.parse_args()

########### settings ############
# Values supplied on the command line.
seed, lam = args.seed, args.lam

# Update hyper-parameters; gamma, the learning rates, lam, ent_coef and
# inner_update are all forwarded to model.reinforce_baseline_gini_k.
gamma = 0.999                      # presumably the discount factor — confirm in Policy
lr_policy, lr_value = 1e-4, 0.01
ent_coef = None                    # None -> 'rf_bl_gn' output subtree (no entropy tag)
inner_update = 10                  # NOTE(review): meaning defined inside Policy — confirm

# Run schedule.
episodes = 50                      # trajectories collected per training epoch
train_epochs = 4000
test_intvl = 20                    # evaluate every `test_intvl` epochs

# Environment options, passed to GuardedMaze through the gym registration kwargs.
goal_r = 20.0
stoc_trans = False

# Only affects the printed banner and the output directory name.
use_baseline = True

####### register environment ########
# Expose the guarded maze to `gym.make` under the id below. The entry point uses
# gym's "module:attribute" form (module GuardedMaze_Discrete, attribute GuardedMaze),
# and the kwargs are passed to that constructor.
register(
    id='GuardedMazeEnv-v0',
    entry_point='GuardedMaze_Discrete:GuardedMaze',
    kwargs={
        'mode': 1,
        'max_steps': 100,
        'guard_prob': 1.0,
        'goal_reward': goal_r,
        'stochastic_trans': stoc_trans,
    },
)

######## interact function #######
def eval_policy(env, model, n_episodes=10):
    """Roll out ``model.eval_action`` for several episodes and collect returns.

    Args:
        env: environment whose ``step`` returns ``(next_state, reward, done, info)``.
        model: object exposing ``eval_action(state) -> action``.
        n_episodes: number of evaluation episodes to run.

    Returns:
        numpy array of length ``n_episodes`` holding each episode's
        undiscounted sum of rewards.
    """
    returns = []
    for _ in range(n_episodes):
        state = env.reset()
        total, done = 0, False
        while not done:
            state, reward, done, _info = env.step(model.eval_action(state))
            total += reward
        returns.append(total)
    return np.array(returns)

def play_episode(env, model):
    """Run one episode with the model's sampling policy and record every step.

    Args:
        env: environment whose ``step`` returns ``(next_state, reward, done, info)``.
        model: object whose ``get_action(state)`` returns ``(action, pi_a)``;
            ``pi_a`` is presumably the probability of the sampled action —
            confirm against Policy.

    Returns:
        list of ``[state, action, reward, pi_a]`` lists, one per step, in time order.
    """
    traj = []

    s = env.reset()
    done = False
    while not done:
        a, pi_a = model.get_action(s)
        s_next, r, done, _info = env.step(a)

        # Record the state the action was taken in, not the successor state.
        traj.append([s, a, r, pi_a])

        s = s_next

    return traj

############ main #################
env = gym.make('GuardedMazeEnv-v0')
eval_env = gym.make('GuardedMazeEnv-v0')
# Training and evaluation environments get different seeds (eval uses
# 2**31-1-seed), so the two environments are seeded independently.
env.seed(seed=seed)
eval_env.seed(seed=2**31-1-seed)
np.random.seed(seed)
random.seed(seed)
# NOTE(review): if Policy relies on a framework with its own RNG (e.g. torch),
# that RNG is not seeded here — confirm whether runs must be reproducible.

state_space = 36   # NOTE(review): presumably the number of maze cells — confirm against GuardedMaze
action_space = 4

model = Policy(state_space, action_space, episodes)
print('--- seed:', seed, ', sample size:', episodes, ', lambda:',lam, ', baseline:', use_baseline, ' ent_coef:', ent_coef)

# One entry per evaluation point: the array of eval-episode returns from eval_policy.
testR_lst = []
for epi in range(train_epochs):
    # Collect `episodes` trajectories with the current policy, then update once.
    for _ in range(episodes):
        traj = play_episode(env, model)
        model.put_data(traj)
    # NOTE(review): `use_baseline` is not consulted here; this call is made
    # regardless of its value, so use_baseline only changes the save path.
    model.reinforce_baseline_gini_k(lr_policy, lr_value, gamma, lam, ent_coef, inner_update)

    # Periodic evaluation on the separately seeded environment.
    if (epi+1) % test_intvl == 0:
        test_r = eval_policy(eval_env, model)
        testR_lst.append(test_r)


# save test reward
def get_save_dir(bs, goal_r, stoc_trans, use_baseline, lam, seed, lr_policy, lr_value=None, ent_coef=None):
    """Build the output directory path that encodes one experiment's settings.

    Example result:
    ``./test_deter/bs_50/goal_20.0/rf_bl_gn/lam=0.5/seed_1/lr_p=0.0001/lr_v=0.01/``

    Args:
        bs: sample size per update (the script passes ``episodes``).
        goal_r: goal reward used for the environment.
        stoc_trans: selects the ``test_stoc`` (True) or ``test_deter`` (False) root.
        use_baseline: selects the ``rf_bl_gn*`` (True) or ``rf_gn`` (False) subtree.
        lam: lambda value, recorded as ``lam=<lam>``.
        seed: random seed, recorded as ``seed_<seed>``.
        lr_policy: policy learning rate, recorded as ``lr_p=<lr_policy>``.
        lr_value: value learning rate; required when ``use_baseline`` is true.
        ent_coef: entropy coefficient, or None to omit the entropy component.

    Returns:
        The directory path as a string, always ending with ``'/'``.

    Raises:
        ValueError: if ``use_baseline`` is true and ``lr_value`` is None.
    """
    root = 'test_stoc' if stoc_trans else 'test_deter'
    save_dir = f'./{root}/bs_{bs}/goal_{goal_r}/'
    # Fragment shared by every algorithm variant.
    common = f'lam={lam}/seed_{seed}/lr_p={lr_policy}/'

    if not use_baseline:
        return save_dir + 'rf_gn/' + common

    # Raise instead of assert: asserts are stripped under `python -O`, which
    # would silently yield a misleading 'lr_v=None/' directory.
    if lr_value is None:
        raise ValueError('lr_value is required when use_baseline is True')

    if ent_coef is None:
        algo = 'rf_bl_gn/'
    else:
        algo = f'rf_bl_gn_ent/ent_coef_{ent_coef}/'
    return save_dir + algo + common + f'lr_v={lr_value}/'

# Persist the evaluation returns gathered during training as <save_dir>/r.npy.
save_dir = get_save_dir(episodes, goal_r, stoc_trans, use_baseline, lam, seed,
                        lr_policy, lr_value, ent_coef)
os.makedirs(save_dir, exist_ok=True)
np.save(os.path.join(save_dir, 'r.npy'), np.array(testR_lst))
