import numpy as np
import copy
import random
import gym
from collections import deque
from gym.envs.registration import register
import os
import sys
sys.path.append('..')
np.set_printoptions(precision=4,suppress=True)

import argparse

# Command-line interface. Both options are mandatory:
#   --lam   float, exposed as args.lam
#   --seed  int,   exposed as args.seed
parser = argparse.ArgumentParser(description='lam seed')
parser.add_argument('--lam', required=True, type=float, help='lambda')
parser.add_argument('--seed', required=True, type=int, help='seed')
# Parsed once at module level from the process command line.
args = parser.parse_args()

# Row-major lookup table for the 6x6 grid: state_idx[x][y] == 6 * x + y.
# Kept as a float array (same dtype as a table filled into np.zeros((6, 6)));
# readers convert the entry to int before using it as an index.
state_idx = np.arange(36, dtype=float).reshape(6, 6)

class Policy:
    """Tabular softmax policy with one update per episode.

    ``theta`` holds one logit per (state, action) pair.  While an episode
    runs, the agent accumulates the total reward (``post_step``) and the sum
    of ``grad_theta log pi(a_t | s_t)`` over the visited steps
    (``pre_step``).  ``learn`` then applies a single update to the scalar
    ``y`` and to ``theta`` and clears both accumulators.

    NOTE(review): the update in ``learn`` looks like a mean-variance policy
    gradient with auxiliary variable ``y`` and trade-off ``lam`` -- confirm
    against the method this script reproduces.
    """

    def __init__(self, n_s, n_a):
        """Create a policy over ``n_s`` states and ``n_a`` actions.

        ``theta`` is drawn with ``np.random.rand``, so seed NumPy's global
        RNG before constructing the policy if reproducibility matters.
        """
        self.n_state = n_s
        self.n_action = n_a
        self.theta = np.random.rand(n_s, n_a)
        # Not read by any method of this class; kept for external readers.
        self.weight = np.zeros(n_s)
        self.actions = np.arange(n_a)

        # Auxiliary scalar updated in learn(); persists across episodes.
        self.y = 0

        self.reset()

    def reset(self):
        """Clear the per-episode accumulators (return and score sum)."""
        self.total_rewards = 0
        self.der_log_pi = np.zeros((self.n_state, self.n_action))

    def pre_step(self, s):
        """Sample an action for state ``s`` and accumulate its score.

        Returns the sampled action.
        """
        a, _ = self.get_action(s)
        self.der_log_pi = self.der_log_pi + self.get_derivative(s, a)
        return a

    def post_step(self, r):
        """Add the step reward ``r`` to the running episode return."""
        self.total_rewards += r

    def get_state_index(self, s):
        """Map a state ``s`` to its row in ``theta``.

        ``s[0]`` and ``s[1]`` are truncated to int and shifted down by one
        (presumably 1-based grid coordinates -- confirm against the
        environment), then looked up in the module-level ``state_idx`` table.
        """
        x = int(s[0]) - 1
        y = int(s[1]) - 1
        return int(state_idx[x][y])

    @staticmethod
    def _softmax(logits):
        """Return softmax(logits), computed without overflow."""
        # Subtracting the maximum leaves the result mathematically unchanged
        # but keeps every exponent <= 0, so large logits cannot turn the
        # probabilities into inf/inf = nan.
        weights = np.exp(logits - np.max(logits))
        return weights / np.sum(weights)

    @staticmethod
    def _sample_index(prob, u):
        """Return the first index whose cumulative probability is >= ``u``.

        If rounding leaves the final cumulative sum slightly below ``u``,
        the last index is returned instead of falling through.
        """
        cdf = np.cumsum(prob)
        i = int(np.searchsorted(cdf, u, side='left'))
        return min(i, len(prob) - 1)

    def get_action_prob(self, s):
        """Return the action-probability vector pi(. | s)."""
        s_idx = self.get_state_index(s)
        return self._softmax(self.theta[s_idx])

    def get_action(self, s):
        """Sample an action from pi(. | s).

        Consumes exactly one ``np.random.rand()`` draw.

        Returns:
            (action, probability of that action).
        """
        prob = self.get_action_prob(s)
        i = self._sample_index(prob, np.random.rand())
        return self.actions[i], prob[i]

    def eval_action(self, s):
        """Return the greedy (most probable) action index for state ``s``."""
        prob = self.get_action_prob(s)
        return np.argmax(prob)

    def get_derivative(self, s, a):
        """Return grad_theta log pi(a | s), shaped like ``theta``.

        Only the row of state ``s`` is non-zero: one-hot(a) - pi(. | s).
        """
        der_theta = np.zeros((self.n_state, self.n_action))

        prob = self.get_action_prob(s)
        s_idx = self.get_state_index(s)
        der_theta[s_idx][a] = 1.
        der_theta[s_idx] -= prob
        return der_theta

    def learn(self, lr_policy, lr_value, lam):
        """Apply the end-of-episode update and reset the accumulators.

        With R the accumulated episode return::

            y     <- y + lr_value * (2 R + 1 / lam - 2 y)
            theta <- theta + lr_policy * (2 y R - R^2) * sum_t grad log pi

        ``y`` is updated first and the new value is used for ``theta``.
        ``lam`` must be non-zero because of the ``1 / lam`` term.
        """
        R = self.total_rewards

        # update y first
        self.y = self.y + lr_value * (2 * R + 1 / lam - 2 * self.y)

        # update theta then
        self.theta += lr_policy * (2 * self.y * R - R ** 2) * self.der_log_pi

        # reset after update at the end of the episode
        self.reset()


############### functional ##############
def play_episode(env, agent):
    """Run one full episode of ``env`` driven by ``agent``.

    For every step the agent picks an action via ``agent.pre_step(state)``,
    the action is applied with ``env.step``, and the resulting reward is
    handed to ``agent.post_step``.  The loop ends as soon as the environment
    reports ``done``.  Nothing is returned; the agent is expected to keep
    whatever per-episode statistics it needs.
    """
    obs = env.reset()
    finished = False
    while not finished:
        action = agent.pre_step(obs)
        obs, reward, finished, _ = env.step(action)
        agent.post_step(reward)

def eval_model(env, agent, n_episodes):
    """Roll out ``n_episodes`` episodes using ``agent.eval_action``.

    Each episode starts from ``env.reset()`` and runs until the environment
    reports ``done``; the step rewards are summed into that episode's return.

    Returns:
        A NumPy array holding one summed return per episode, in order.
    """
    returns = []
    for _episode in range(n_episodes):
        obs = env.reset()
        finished = False
        total = 0
        while not finished:
            action = agent.eval_action(obs)
            obs, reward, finished, _info = env.step(action)
            total += reward
        returns.append(total)
    return np.array(returns)

def get_save_dir(stoc_trans, goal, lam, seed, lr_policy, lr_value):
    """Build the output directory path for one experiment configuration.

    The path nests, in order: transition type (``stoc_trans`` or
    ``deter_trans``), goal reward, lambda, seed, policy learning rate and
    value learning rate.  Every value is rendered with ``str`` and the
    result ends with a trailing ``/`` so a file name can be appended
    directly.
    """
    trans_kind = 'stoc_trans' if stoc_trans else 'deter_trans'
    components = [
        './save',
        trans_kind,
        'goal_' + str(goal),
        'lam_' + str(lam),
        'seed_' + str(seed),
        'lr_p_' + str(lr_policy),
        'lr_v_' + str(lr_value),
    ]
    return '/'.join(components) + '/'

############ setting ##############
# Experiment configuration. `seed` and `lam` come from the command line;
# everything else is fixed here.
seed = args.seed
lr_policy = 3e-5
# The step size for the auxiliary variable y is tied to the policy step size.
lr_value = lr_policy
lam = args.lam
# 200000 training episodes in total.
train_episodes = 4000*50
# Evaluate after every `test_intvl` training episodes, using
# `eval_episodes` rollouts each time.
test_intvl = 1000
eval_episodes = 10

# Forwarded to the environment as goal_reward / stochastic_trans below and
# also encoded in the save path.
goal_r = 20.
stoc_trans = False

# Register the maze environment with gym. The entry point names
# `GuardedMaze` in module `GuardedMaze_Discrete`, which is not defined in
# this file (presumably it is found through the sys.path.append('..') at the
# top -- TODO confirm). Must run before the gym.make calls that follow.
register(
    id='GuardedMazeEnv-v0',
    entry_point='GuardedMaze_Discrete:GuardedMaze',
    kwargs=dict(
        mode=1,
        max_steps=100,
        guard_prob=1.0,
        goal_reward=goal_r,
        stochastic_trans=stoc_trans,
    )
)    

# Separate environment instances for training and evaluation; the evaluation
# environment is seeded with 2**31-1-seed rather than `seed`.
env = gym.make('GuardedMazeEnv-v0')
eval_env = gym.make('GuardedMazeEnv-v0')
env.seed(seed)
eval_env.seed(2**31-1-seed)
# NumPy's global RNG drives both the initial theta (np.random.rand in
# Policy.__init__) and action sampling, so it has to be seeded before the
# agent is created.
np.random.seed(seed)
# The `random` module is not used elsewhere in this file; it is seeded too
# (presumably for code outside this file -- TODO confirm).
random.seed(seed)     

# 36 states matches the 6x6 `state_idx` table; 4 actions.
state_space = 36
action_space = 4

# create agent
agent = Policy(state_space, action_space)
print('lam:', lam, 'seed:', seed)
print('lr_p:', lr_policy, 'lr_v:', lr_value)


# One array of `eval_episodes` returns is appended per evaluation round.
eval_r_lst = []

# Training loop: one sampled episode followed by one update per iteration.
for ep in range(train_episodes):
    play_episode(env, agent)
    agent.learn(lr_policy, lr_value, lam)

    if (ep+1) % test_intvl == 0:
        # test deterministic policy
        # (eval_model acts through agent.eval_action, i.e. the argmax action)
        eval_r = eval_model(eval_env, agent, eval_episodes)
        eval_r_lst.append(eval_r)

# create save dir
# Results are written only once, after all training episodes have finished.
save_dir = get_save_dir(stoc_trans, goal_r, lam, seed, lr_policy, lr_value)
os.makedirs(save_dir, exist_ok=True)

# save the evaluation returns, one row per evaluation round
# (save_dir already ends with '/', so the file name is appended directly)
with open(save_dir + 'eval_r.npy', 'wb') as f:
    np.save(f, np.array(eval_r_lst))
    
# save theta
with open(save_dir + 'theta.npy', 'wb') as f:
    np.save(f, agent.theta)

  