import numpy as np
import copy
import random
import gym
from collections import deque
from gym.envs.registration import register
import os
import sys
# Add the parent of the current working directory to the module search path.
# NOTE(review): presumably so project modules that live one level up can be
# imported -- confirm against the repository layout.
sys.path.append('..')
# Print numpy arrays with 4 decimals and without scientific notation.
np.set_printoptions(precision=4,suppress=True)

# Command-line interface. Both options are required and are parsed as soon as
# the module is executed, so running without them exits with a usage error.
import argparse
parser = argparse.ArgumentParser(description='lam seed')
# --lam: later passed to Policy.learn as `lam` (weight of the variance penalty).
parser.add_argument('--lam', type=float, help='lambda', required=True)
# --seed: later used to seed both environments and the numpy / stdlib RNGs.
parser.add_argument('--seed', type=int, help='seed', required=True)
args = parser.parse_args()

# Lookup table from a zero-based grid cell (x, y) to its flat state index.
# Cells are numbered row by row, so state_idx[x][y] == 6 * x + y; the table is
# stored as floats and callers convert the looked-up value back to int.
state_idx = np.arange(36, dtype=float).reshape(6, 6)

class Policy:
    """Tabular softmax policy trained with a variance-penalised policy gradient.

    One preference vector is kept per state (``theta[state]``) and actions are
    sampled from the softmax of that vector.  While an episode runs, the score
    function ``grad log pi(a|s)`` of every sampled action is accumulated; at
    the end of the episode :meth:`learn` takes one stochastic gradient step on
    ``J - lam * g(V - b)`` with ``g(x) = max(0, x) ** 2``, where ``J`` and ``V``
    are running estimates of the mean and variance of the episode return.

    Args:
        n_s: Number of discrete states (rows of ``theta``).
        n_a: Number of discrete actions (columns of ``theta``).
        b: Variance threshold; the penalty is inactive while ``V <= b``.
    """

    def __init__(self, n_s, n_a, b):
        self.n_state = n_s
        self.n_action = n_a
        # Action preferences, initialised uniformly in [0, 1) from np.random.
        self.theta = np.random.rand(n_s, n_a)
        # Not used by any method of this class; kept for compatibility.
        self.weight = np.zeros(n_s)
        self.actions = np.arange(n_a)

        self.b = b
        self.J = 0.0  # running estimate of the mean episode return
        self.V = 0.0  # running estimate of the variance of the episode return

        self.reset()

    def reset(self):
        """Clear the per-episode accumulators (return and summed score function)."""
        self.total_rewards = 0
        self.der_log_pi = np.zeros((self.n_state, self.n_action))

    def pre_step(self, s):
        """Sample an action for state ``s`` and accumulate its score function.

        Returns:
            The sampled action.
        """
        a, _ = self.get_action(s)
        self.der_log_pi = self.der_log_pi + self.get_derivative(s, a)
        return a

    def post_step(self, r):
        """Add the reward ``r`` of the latest step to the episode return."""
        self.total_rewards += r

    def get_state_index(self, s):
        """Return the flat index of state ``s`` from the module-level ``state_idx`` table.

        ``s[0]`` and ``s[1]`` are truncated to int and shifted down by one
        before the lookup, i.e. they are treated as 1-based grid coordinates.
        """
        x = int(s[0]) - 1
        y = int(s[1]) - 1
        return int(state_idx[x][y])

    def get_action_prob(self, s):
        """Return the softmax action distribution for state ``s``."""
        logits = self.theta[self.get_state_index(s)]
        # Softmax is invariant to a constant shift; subtracting the maximum
        # keeps np.exp from overflowing to inf (and the result from becoming
        # NaN) once the preferences grow large.
        exp_logits = np.exp(logits - np.max(logits))
        return exp_logits / np.sum(exp_logits)

    def get_action(self, s):
        """Sample an action for state ``s`` by inverse-CDF sampling.

        Exactly one uniform number is drawn from ``np.random``.

        Returns:
            Tuple ``(action, probability of that action)``.
        """
        prob = self.get_action_prob(s)
        draw = np.random.rand()

        cumulative = 0.0
        for action, p in zip(self.actions, prob):
            cumulative += p
            if cumulative >= draw:
                return action, p

        # Floating-point rounding can leave the cumulative sum marginally
        # below the draw; without this fallback the method would return None
        # and the caller's tuple unpacking would fail.
        return self.actions[-1], prob[-1]

    def eval_action(self, s):
        """Return the greedy (most probable) action for state ``s``."""
        return np.argmax(self.get_action_prob(s))

    def get_derivative(self, s, a):
        """Return ``grad_theta log pi(a|s)`` as an ``(n_state, n_action)`` array.

        Only the row of the visited state is non-zero: it holds the one-hot
        vector of ``a`` minus the action probabilities.
        """
        der_theta = np.zeros((self.n_state, self.n_action))

        s_idx = self.get_state_index(s)
        der_theta[s_idx][a] = 1.
        der_theta[s_idx] -= self.get_action_prob(s)
        return der_theta

    def learn(self, lr_policy, lr_value, lam):
        """Apply one end-of-episode update and reset the episode accumulators.

        Args:
            lr_policy: Step size for ``theta``.
            lr_value: Step size for the running estimates ``J`` and ``V``.
            lam: Weight of the variance penalty.
        """
        R = self.total_rewards

        # g'(V - b) for the penalty g(x) = max(0, x) ** 2.
        penalty_slope = 2 * max(self.V - self.b, 0)

        # grad Var = grad E[R^2] - 2 J grad J = E[(R^2 - 2 J R) grad log p],
        # so the sampled variance-gradient weight is R^2 - 2 J R.  The factor
        # R in the second term was previously missing.
        coef = R - lam * penalty_slope * (R**2 - 2 * self.J * R)
        # update policy parameter
        self.theta += lr_policy * coef * self.der_log_pi

        # update J and V; V must use the mean estimate from before this update
        prev_J = self.J
        self.J += lr_value * (R - prev_J)
        self.V += lr_value * (R**2 - prev_J**2 - self.V)

        # reset after update at the end of the episode
        self.reset()


############### functional ##############
def play_episode(env, agent):
    """Run one training episode, feeding every step to the agent.

    For each step the agent picks an action via ``pre_step`` and is told the
    resulting reward via ``post_step``.  The episode ends when the
    environment reports ``done``.  Nothing is returned; the agent keeps the
    episode statistics itself.
    """
    state = env.reset()
    finished = False
    while not finished:
        action = agent.pre_step(state)
        state, reward, finished, _ = env.step(action)
        agent.post_step(reward)

def eval_model(env, agent, n_episodes):
    """Evaluate the agent's greedy policy.

    Runs ``n_episodes`` episodes in ``env``, choosing actions with
    ``agent.eval_action``, and returns a numpy array holding the undiscounted
    return of each episode in order.
    """
    def greedy_return():
        # Roll out a single episode and sum up its rewards.
        state = env.reset()
        total = 0
        finished = False
        while not finished:
            action = agent.eval_action(state)
            state, reward, finished, _ = env.step(action)
            total += reward
        return total

    return np.array([greedy_return() for _ in range(n_episodes)])

def get_save_dir(stoc_trans, goal, lam, b, seed, lr_policy, lr_value):
    """Build the results directory path for one run configuration.

    The path is rooted at ``./test``, has one level per setting and ends with
    a trailing ``/`` so that a file name can be appended directly.
    """
    labelled_settings = (
        ('goal_', goal),
        ('lam_', lam),
        ('b_', b),
        ('seed_', seed),
        ('lr_p_', lr_policy),
        ('lr_v_', lr_value),
    )
    parts = ['./test', 'stoc_trans' if stoc_trans else 'deter_trans']
    parts.extend(label + str(value) for label, value in labelled_settings)
    return '/'.join(parts) + '/'

############ setting ##############
# RNG seed, taken from the command line.
seed = args.seed
# Step size for the policy parameters (passed to agent.learn as lr_policy).
lr_policy = 1e-5
# Step size for the running mean/variance estimates: 100x the policy step size.
lr_value = lr_policy * 100
# Weight of the variance penalty in the policy update, from the command line.
lam = args.lam
# Variance threshold handed to Policy: the penalty term is zero until the
# variance estimate exceeds b.
b = 50
# Total number of training episodes (200,000).
train_episodes = 4000*50
# A greedy evaluation is run after every `test_intvl` training episodes ...
test_intvl = 1000
# ... and each evaluation consists of this many episodes.
eval_episodes = 10

# Passed to the environment as `goal_reward`; also part of the results path.
goal_r = 20.
# Passed to the environment as `stochastic_trans`; also selects the results
# sub-directory (stoc_trans vs. deter_trans).
stoc_trans = False

# Register the maze environment with gym. The entry point names the
# GuardedMaze class in the GuardedMaze_Discrete module, which is defined
# outside this file; the keyword arguments are forwarded to its constructor.
# NOTE(review): the meaning of `mode` and `guard_prob` is defined by that
# class -- check it before changing these values.
register(
    id='GuardedMazeEnv-v0',
    entry_point='GuardedMaze_Discrete:GuardedMaze',
    kwargs=dict(
        mode=1,
        max_steps=100,
        guard_prob=1.0,
        goal_reward=goal_r,
        stochastic_trans=stoc_trans,
    )
)

# Training and evaluation run on two separate environment instances.
env = gym.make('GuardedMazeEnv-v0')
eval_env = gym.make('GuardedMazeEnv-v0')
# NOTE(review): env.seed() and the 4-tuple returned by env.step() elsewhere in
# this file are the pre-0.26 gym API -- confirm the pinned gym version.
env.seed(seed)
# The evaluation environment gets a different seed derived from `seed`.
eval_env.seed(2**31-1-seed)
# Seed the global RNGs before the agent is created: Policy draws its initial
# theta and samples its actions from np.random.
np.random.seed(seed)
random.seed(seed)

# 36 states matches the 6x6 state_idx lookup table used by Policy.
state_space = 36
action_space = 4

# create agent
agent = Policy(state_space, action_space, b)
print('lam:', lam, 'b:', b, 'seed:', seed)
print('lr_p:', lr_policy, 'lr_v:', lr_value)

# One entry per evaluation round; each entry is an array with the returns of
# `eval_episodes` greedy episodes.
eval_r_lst = []

for ep in range(train_episodes):
    # Collect one episode, then apply a single update from its total reward.
    play_episode(env, agent)
    agent.learn(lr_policy, lr_value, lam)

    if (ep+1) % test_intvl == 0:
        # Periodic greedy evaluation on the separate evaluation environment.
        eval_r = eval_model(eval_env, agent, eval_episodes)
        eval_r_lst.append(eval_r)

# create save dir
# Results are written only once, after all training episodes have finished.
save_dir = get_save_dir(stoc_trans, goal_r, lam, b, seed, lr_policy, lr_value)
os.makedirs(save_dir, exist_ok=True)

# save_dir ends with '/', so plain concatenation yields a path inside it.
# The stacked array has one row per evaluation round.
with open(save_dir + 'eval_r.npy', 'wb') as f:
    np.save(f, np.array(eval_r_lst))

# save theta
with open(save_dir + 'theta.npy', 'wb') as f:
    np.save(f, agent.theta)
    

  