import numpy as np
import copy
import random
import gym
from gym.envs.registration import register
import os
# Print NumPy arrays with 4 decimal places, in fixed-point rather than
# scientific notation.
np.set_printoptions(precision=4,suppress=True)
import sys
# Add the parent of the current working directory to the module search path.
# NOTE(review): presumably this is what lets the 'GuardedMaze_Discrete' entry
# point registered further down be imported — confirm against the repo layout.
sys.path.append('..')

import argparse

# Command-line interface.  Both options are mandatory; argparse exits with a
# usage message when either one is missing.
parser = argparse.ArgumentParser(description='lam seed')
parser.add_argument(
    '--lam',
    required=True,
    type=float,
    help='lambda',
)
parser.add_argument(
    '--seed',
    required=True,
    type=int,
    help='seed',
)
args = parser.parse_args()

# Lookup table mapping a 0-based grid cell (x, y) of the 6x6 maze to a flat
# state index, numbered row by row: state_idx[x][y] == 6 * x + y.
state_idx = np.zeros((6, 6))
idx = 0
for x_, y_ in np.ndindex(*state_idx.shape):
    # np.ndindex walks the cells in row-major order, i.e. y_ varies fastest.
    state_idx[x_, y_] = idx
    idx += 1

class Policy:
    """Tabular softmax policy with a per-state value baseline.

    ``theta`` stores one row of action preferences per state; the action
    distribution of a state is the softmax of its row.  ``weight`` stores one
    scalar value estimate per state and is used as the baseline in the policy
    gradient.  States are indexable ``(x, y)`` pairs with 1-based coordinates;
    they are mapped to a row index through the module-level ``state_idx``
    table.

    Trajectories handed to :meth:`put_data` are sequences of
    ``[state, action, reward, pi_a]`` items in time order, where ``pi_a`` is
    the probability the policy assigned to ``action`` when it was sampled.
    """

    # Trajectories whose importance-sampling ratio lies outside this closed
    # interval are excluded from the policy update.
    IS_RATIO_LOW = 0.7
    IS_RATIO_HIGH = 1.3
    # With fewer surviving trajectories than this, the update reports "stop".
    MIN_CHOOSE_SIZE = 30
    # Upper bound on the number of transitions used for the entropy gradient.
    ENTROPY_SAMPLE_CAP = 400

    def __init__(self, n_s, n_a, bs):
        """Create a policy for ``n_s`` states and ``n_a`` actions.

        Args:
            n_s: number of states (rows of ``theta``).
            n_a: number of actions (columns of ``theta``).
            bs: number of buffered trajectories consumed by one update.
        """
        self.n_state = n_s
        self.n_action = n_a
        self.batch_size = bs
        # Preferences start uniformly random in [0, 1); values start at zero.
        self.theta = np.random.rand(n_s, n_a)
        self.weight = np.zeros(n_s)
        self.actions = np.arange(n_a)

        self.traj_buf = []

    def put_data(self, traj):
        """Append one trajectory to the update buffer."""
        self.traj_buf.append(traj)

    def get_state_index(self, s):
        """Return the flat row index of state ``s`` (1-based ``(x, y)``)."""
        x = int(s[0]) - 1
        y = int(s[1]) - 1
        return int(state_idx[x][y])

    def get_action_prob(self, s):
        """Return the softmax action distribution for state ``s``."""
        logits = self.theta[self.get_state_index(s)]
        # Shifting by the maximum leaves the softmax unchanged mathematically
        # but keeps np.exp from overflowing to inf (which would yield NaN).
        res_exp = np.exp(logits - np.max(logits))
        return res_exp / np.sum(res_exp)

    def get_action(self, s):
        """Sample an action for state ``s``.

        Returns:
            ``(action, prob)``: the sampled action and the probability the
            current policy assigns to it.
        """
        prob = self.get_action_prob(s)

        random_num = np.random.rand()
        sum_prob = 0.0

        # Inverse-CDF sampling: first action whose cumulative mass covers the draw.
        for i in range(len(self.actions)):
            sum_prob += prob[i]
            if sum_prob >= random_num:
                return self.actions[i], prob[i]

        # Rounding can leave the cumulative sum marginally below the draw;
        # fall back to the last action instead of implicitly returning None.
        last = len(self.actions) - 1
        return self.actions[last], prob[last]

    def eval_action(self, s):
        """Return the greedy (most probable) action for state ``s``."""
        prob = self.get_action_prob(s)
        action = np.argmax(prob)
        return action

    def get_derivative(self, s, a):
        """Return d log pi(a|s) / d theta as an ``(n_state, n_action)`` array.

        Only the row of state ``s`` is non-zero: ``one_hot(a) - pi(.|s)``.
        """
        der_theta = np.zeros((self.n_state, self.n_action))

        prob = self.get_action_prob(s)
        s_idx = self.get_state_index(s)
        der_theta[s_idx][a] = 1.
        der_theta[s_idx] -= prob
        return der_theta

    # reinforce + variance, sample n traj, update k times
    def reinforce_baseline_variance_k(self, lr_policy, lr_value, gamma, lam, ent_coef, k):
        """Run up to ``k`` updates on the buffered trajectories, then clear the buffer.

        The loop ends early as soon as :meth:`reinforce_baseline_variance`
        returns True (too few trajectories passed the importance-ratio filter).
        """
        for _ in range(k):
            if self.reinforce_baseline_variance(lr_policy, lr_value, gamma, lam, ent_coef):
                break

        # The batch has been consumed; the next call needs fresh trajectories.
        self.traj_buf = []

    # Original (misspelled) public name, kept so existing callers keep working.
    reinforce_baseline_varaince_k = reinforce_baseline_variance_k

    def reinforce_baseline_variance(self, lr_policy, lr_value, gamma, lam, ent_coef):
        """Apply one importance-weighted policy update from the buffer.

        For each of the first ``batch_size`` buffered trajectories this
        computes the discounted return, the summed score function, the
        baseline-corrected REINFORCE gradient and the importance-sampling
        ratio between the current policy and the recorded ``pi_a`` values.
        The per-state value estimates in ``weight`` are updated along the way.
        Trajectories whose ratio lies in ``[IS_RATIO_LOW, IS_RATIO_HIGH]``
        then drive ``theta += lr_policy * (dJ - lam * (dM - 2 * J * dJ2))``,
        optionally followed by an entropy-gradient step.

        Args:
            lr_policy: step size for ``theta``.
            lr_value: step size for ``weight`` (divided by ``batch_size``).
            gamma: discount factor.
            lam: multiplier of the variance-gradient term.
            ent_coef: entropy coefficient; ``None`` or ``<= 0`` disables it.

        Returns:
            True when fewer than ``MIN_CHOOSE_SIZE`` trajectories passed the
            ratio filter (signal to stop re-using this batch), else False.
            With fewer than two surviving trajectories ``theta`` is left
            untouched, because the estimator needs two non-empty halves.
        """
        ret_lst = []        # discounted return of each trajectory
        sum_der_lst = []    # sum over steps of d log pi / d theta
        rf_der_lst = []     # REINFORCE-with-baseline gradient
        is_ratio_lst = []   # importance-sampling ratio

        for traj in self.traj_buf[:self.batch_size]:
            ret = 0.0
            sum_der = np.zeros((self.n_state, self.n_action))
            der_reinforce = np.zeros((self.n_state, self.n_action))
            old_pi_lst, current_pi_lst = [], []

            # Walk the trajectory backwards in time so that `ret` is the
            # discounted return-to-go G_t of the step being processed.  (The
            # previous code built a reversed copy but then read the forward
            # list, which paired each step with the wrong return.)
            for t in range(len(traj) - 1, -1, -1):
                item = traj[t]
                state, action = item[0], item[1]
                ret = item[2] + gamma * ret

                old_pi_lst.append(item[3])
                current_pi_lst.append(self.get_action_prob(state)[action])

                der_theta = self.get_derivative(state, action)

                s_idx = self.get_state_index(state)
                delta = ret - self.weight[s_idx]

                # gamma**t discounts the contribution of time step t.
                der_reinforce += (gamma ** t) * der_theta * delta
                sum_der += der_theta

                # update value function step-wise
                self.weight[s_idx] += (lr_value / self.batch_size) * delta

            # After the backward pass `ret` is the return from the first step.
            ret_lst.append(ret)
            sum_der_lst.append(sum_der)
            rf_der_lst.append(der_reinforce)

            # Trajectory-level ratio, accumulated in log space.
            log_ratio = np.log(np.array(current_pi_lst)) - np.log(np.array(old_pi_lst))
            is_ratio_lst.append(np.exp(log_ratio.sum()))

        # Keep only trajectories that are still close to on-policy.
        is_ratio = np.array(is_ratio_lst)
        is_idx = np.where((is_ratio <= self.IS_RATIO_HIGH) & (is_ratio >= self.IS_RATIO_LOW))
        is_ratio_choose = is_ratio[is_idx]
        choose_size = len(is_ratio_choose)

        if choose_size < 2:
            # The estimator below averages over two halves of the selection;
            # with 0 or 1 samples it would divide by zero or write NaN.
            return True

        ret_choose = np.array(ret_lst)[is_idx]
        rf_der_choose = np.array(rf_der_lst)[is_idx]
        sum_der_choose = np.array(sum_der_lst)[is_idx]
        ret2_choose = np.square(ret_choose)

        w = is_ratio_choose[:, None, None]
        weighted_pg = w * rf_der_choose
        # Gradient of the expected return.
        der_J = weighted_pg.mean(axis=0)
        # Score-function estimate of the gradient of E[return^2].
        der_M = (w * ret2_choose[:, None, None] * sum_der_choose).mean(axis=0)

        # J and its gradient are estimated on disjoint halves of the selection.
        # NOTE(review): presumably so that the product J * dJ is not biased by
        # shared samples — confirm against the method's derivation.
        half = choose_size // 2
        J = np.mean(ret_choose[:half] * is_ratio_choose[:half])
        der_J2 = weighted_pg[half:].mean(axis=0)

        self.theta += lr_policy * (der_J - lam * (der_M - 2 * J * der_J2))

        # Optional entropy term.
        if (ent_coef is not None) and (ent_coef > 0):
            # Pool the transitions of the selected trajectories.
            all_sample = []
            for idx in is_idx[0]:
                all_sample += self.traj_buf[idx]
            random.shuffle(all_sample)

            sample_size = min(self.ENTROPY_SAMPLE_CAP, len(all_sample))
            der_ent = np.zeros((self.n_state, self.n_action))

            for item in all_sample[:sample_size]:
                pi = self.get_action_prob(item[0])
                for a in self.actions:
                    a = int(a)
                    der_theta_ = self.get_derivative(item[0], a)
                    der_ent -= (1 + np.log(pi[a])) * pi[a] * der_theta_

            self.theta += lr_policy * ent_coef * der_ent

        # Too few usable trajectories: tell the caller to stop re-using the batch.
        return choose_size < self.MIN_CHOOSE_SIZE

########## functional ############
def play_episode(env, model):
    """Roll out one full episode of `model` acting in `env`.

    `env` must offer `reset()` and a `step(a)` that returns a 4-tuple
    `(next_state, reward, done, info)`; `model.get_action(s)` must return
    `(action, probability_of_that_action)`.

    Returns a list with one `[state, action, reward, pi_a]` entry per step,
    in time order.
    """
    trajectory = []
    state = env.reset()

    while True:
        action, pi_a = model.get_action(state)
        next_state, reward, done, _info = env.step(action)
        trajectory.append([state, action, reward, pi_a])

        if done:
            break
        state = next_state

    return trajectory

def eval_model(env, agent, n_episodes):
    """Run `n_episodes` episodes with the agent's `eval_action` and total the rewards.

    `env.step(a)` must return a 4-tuple `(next_state, reward, done, info)`.

    Returns a NumPy array holding the undiscounted reward sum of each
    episode, in the order the episodes were played.
    """
    episode_totals = []

    for _ in range(n_episodes):
        state = env.reset()
        finished = False
        total = 0

        while not finished:
            action = agent.eval_action(state)
            state, reward, finished, _info = env.step(action)
            total += reward

        episode_totals.append(total)

    return np.array(episode_totals)

def get_save_dir(bs, goal_r, stoc_trans, use_reinforce, use_baseline, lam, seed, lr_policy, lr_value=None, ent_coef=None):
    """Build the relative output directory (with trailing '/') for one run.

    Layout: ./<test_stoc|test_deter>/bs_<bs>/goal_<goal_r>/<method>/...,
    followed by lam, seed and lr_p components.  <method> is 'mean_var'
    without REINFORCE, 'rf_var' with REINFORCE but no baseline, and
    'rf_bl_var' or 'rf_bl_var_ent/ent_coef_<ent_coef>' with a baseline; the
    baseline variants also append an lr_v component and require `lr_value`.
    """
    parts = [
        '.',
        'test_stoc' if stoc_trans else 'test_deter',
        'bs_' + str(bs),
        'goal_' + str(goal_r),
    ]
    run_parts = ['lam=' + str(lam), 'seed_' + str(seed), 'lr_p=' + str(lr_policy)]

    if not use_reinforce:
        method_parts = ['mean_var']
    elif not use_baseline:
        method_parts = ['rf_var']
    else:
        assert lr_value is not None
        run_parts.append('lr_v=' + str(lr_value))
        if ent_coef is None:
            method_parts = ['rf_bl_var']
        else:
            method_parts = ['rf_bl_var_ent', 'ent_coef_' + str(ent_coef)]

    return '/'.join(parts + method_parts + run_parts) + '/'

############## setting #############
# --- values taken from the command line ---
seed = args.seed
lam = args.lam                # multiplies the variance-gradient term of the policy update

# --- optimisation ---
gamma = 0.999                 # discount factor
lr_policy = 1e-5              # step size for the policy parameters
lr_value = lr_policy * 100    # step size for the value baseline
ent_coef = None               # entropy coefficient; None switches the entropy term off

# --- schedule ---
episodes = 50                 # trajectories collected per epoch (also the Policy batch size)
inner_update = 10             # maximum number of updates performed on each collected batch
train_epochs = 4000           # number of collect-and-update rounds
test_intvl = 20               # evaluate every this many epochs
eval_episodes = 10            # episodes per evaluation

# --- environment ---
stoc_trans = False            # forwarded to the env as `stochastic_trans`
goal_r = 20.                  # forwarded to the env as `goal_reward`

# --- only used to choose the output directory name ---
use_reinforce = True
use_baseline = True

# Make the maze available to gym.make() under 'GuardedMazeEnv-v0'; the keyword
# arguments below are handed to the environment constructor.
register(
    id='GuardedMazeEnv-v0',
    entry_point='GuardedMaze_Discrete:GuardedMaze',
    kwargs={
        'mode': 1,
        'max_steps': 100,
        'guard_prob': 1.0,
        'goal_reward': goal_r,
        'stochastic_trans': stoc_trans,
    },
)

# Separate environment instances for training rollouts and for evaluation.
env = gym.make('GuardedMazeEnv-v0')
eval_env = gym.make('GuardedMazeEnv-v0')
env.seed(seed=seed)
# The evaluation environment gets its own seed, derived from the run seed.
eval_env.seed(seed=2**31-1-seed)

# Seed every RNG the agent draws from so a run is reproducible from --seed:
# NumPy drives parameter initialisation and action sampling, while the stdlib
# `random` module shuffles transitions when the entropy term is enabled.
np.random.seed(seed)
random.seed(seed)

state_space = 36
action_space = 4

agent = Policy(state_space, action_space, episodes)
print('--- seed:', seed, ', sample size:', episodes, ', lambda:',lam, ' ent_coef:', ent_coef)
print('lr_policy:', lr_policy, 'lr_value:', lr_value)

# One array of per-episode evaluation rewards for every evaluation round.
eval_r_lst = []

for epi in range(train_epochs):
    # Collect a fresh batch of trajectories with the current policy.
    for _ in range(episodes):
        traj = play_episode(env, agent)
        agent.put_data(traj)

    # Up to `inner_update` updates on that batch; the agent clears its buffer afterwards.
    agent.reinforce_baseline_varaince_k(lr_policy, lr_value, gamma, lam, ent_coef, inner_update)

    # Evaluate only on every `test_intvl`-th epoch.
    if (epi + 1) % test_intvl != 0:
        continue
    eval_r = eval_model(eval_env, agent, eval_episodes)
    eval_r_lst.append(eval_r)

# Resolve the output directory for this configuration and create it if needed.
save_dir = get_save_dir(episodes, goal_r, stoc_trans, use_reinforce, use_baseline, lam, seed, lr_policy, lr_value, ent_coef)
os.makedirs(save_dir, exist_ok=True)

# Write the evaluation history and the final policy parameters as .npy files.
for out_name, out_array in (
    ('eval_r.npy', np.array(eval_r_lst)),
    ('theta.npy', agent.theta),
):
    with open(save_dir + out_name, 'wb') as f:
        np.save(f, out_array)
