import numpy as np
import gym
from collections import deque
import matplotlib.pyplot as plt
from gym.envs.registration import register
import os
import tqdm
np.set_printoptions(precision=4,suppress=True)
import sys
sys.path.append('..')


# ---------------------------------------------------------------------------
# Experiment settings
# ---------------------------------------------------------------------------
seed = 0
np.random.seed(seed)
goal_r = 20.          # forwarded to the environment as ``goal_reward``
stoc_trans = False    # forwarded to the environment as ``stochastic_trans``

# Register the maze under a gym id so it can be built with gym.make below.
register(
    id='GuardedMazeEnv-v0',
    entry_point='GuardedMaze_Discrete:GuardedMaze',
    kwargs={
        'mode': 1,
        'max_steps': 100,
        'guard_prob': 1.0,
        'goal_reward': goal_r,
        'stochastic_trans': stoc_trans,
    },
)

# Two separate instances: one for training, one for evaluation, seeded differently.
env = gym.make('GuardedMazeEnv-v0')
eval_env = gym.make('GuardedMazeEnv-v0')
env.seed(seed=seed)
eval_env.seed(seed=2**31 - 1 - seed)

GAMMA = 0.999  # discount factor used for returns and TD targets
print('goal reward:', goal_r, 'gamma:', GAMMA)

# Rolling window of the last 1000 training rewards (appended to by
# q_learning), primed with a single -1 entry.
online_reward = deque([-1], maxlen=1000)

def test(env, q_table, n_episodes=10, gamma=None):
    """Roll out the greedy policy of ``q_table`` on ``env`` and score it.

    Args:
        env: environment with the old-style gym API: ``reset()`` returns the
            state and ``step(action)`` returns
            ``(next_state, reward, done, info)``. ``state[0]`` and
            ``state[1]`` are used as indices into ``q_table``.
        q_table: table indexed as ``q_table[state[0]][state[1]]`` yielding one
            value per action; the arg-max action is taken at every step.
        n_episodes: number of episodes to run; must be >= 1.
        gamma: discount factor for the per-episode return. Defaults to the
            module-level ``GAMMA``.

    Returns:
        Tuple ``(J, total_r, mean_r)``:
        ``J`` is the mean discounted return over the episodes,
        ``total_r`` is the undiscounted sum of every reward observed, and
        ``mean_r`` is the mean per-step reward over all episodes.

    Raises:
        ValueError: if ``n_episodes < 1``, because the averages would be
            undefined.
    """
    if n_episodes < 1:
        raise ValueError('n_episodes must be >= 1, got %r' % (n_episodes,))
    if gamma is None:
        gamma = GAMMA

    ep_return = []
    all_r_list = []
    for _ in range(n_episodes):
        r_list = []
        state = env.reset()
        done = False
        while not done:
            action = np.argmax(q_table[state[0]][state[1]])
            state, r, done, _ = env.step(action)
            r_list.append(r)
        all_r_list.extend(r_list)
        # Discounted return, summed in step order (same order as a running +=).
        ep_return.append(sum(r * pow(gamma, i) for i, r in enumerate(r_list)))

    J = np.array(ep_return).mean()
    return J, sum(all_r_list), np.mean(all_r_list)


def update(iter_num, env, eval_env):
    """Run ``iter_num`` rounds of evaluate-then-learn, then show the result.

    Each round does two things:
    (1) It evaluates the current greedy policy on ``env`` over 10 episodes
        and sets ``y = (1 - GAMMA) * G`` from their mean discounted return
        ``G``.
    (2) It learns a fresh Q-table with ``q_learning(env, eval_env, y)``.

    After the last round, one greedy episode is run on ``eval_env``. The
    function then calls ``eval_env.show()`` and ``plt.show()``.

    Args:
        iter_num: number of evaluate/learn rounds.
        env: environment used for policy evaluation and for Q-learning.
        eval_env: environment used for the periodic tests inside
            ``q_learning`` and for the final displayed episode.

    Returns:
        The Q-table from the last round. If ``iter_num`` is 0, this is the
        initial random table.
    """
    # Reward-modification weight; keep in sync with LAMBDA in q_learning.
    # Only used for the diagnostic print below.
    lam = 0.2
    Q_table = np.random.rand(8, 8, 4)

    for _ in range(iter_num):
        # Step 1: evaluate the current greedy policy.
        # (An alternative choice of y is the per-step mean reward, mean_r.)
        G, _, mean_r = test(env, Q_table, 10)
        y = (1 - GAMMA) * G
        # Reward a goal transition receives after q_learning's modification
        # r - lam*r**2 + 2*lam*r*y, evaluated at r = goal_r. The original
        # hard-coded 20 here in place of goal_r.
        modified_goal_r = goal_r - lam * goal_r**2 + 2 * lam * goal_r * y
        print('G:', G, 'mean_r:', mean_r, 'y:', y, 'goal_r:', modified_goal_r)

        # Step 2: learn a new policy under the modified reward.
        Q_table = q_learning(env, eval_env, y)

    # Final test: one greedy episode on eval_env, then display it.
    test(eval_env, Q_table, 1)
    eval_env.show()
    plt.show()
    return Q_table


def q_learning(env, eval_env, y, *, episodes=20000, starting_epsilon=1.0,
               epsilon_end=0.1, steps_max=50000, lr=5e-3, test_intvl=10,
               lam=0.2):
    """Train a fresh tabular Q-function with epsilon-greedy Q-learning.

    Before each TD update, every environment reward ``r`` is replaced by
    ``r - lam * r**2 + 2 * lam * r * y``. The scalar ``y`` therefore couples
    this inner learning problem to the caller.

    Args:
        env: training environment with the old-style gym API
            (``reset()`` -> state, ``step(a)`` -> ``(next_state, r, done,
            info)``).
        eval_env: environment used for the periodic greedy test episodes.
        y: scalar used in the reward modification above.
        episodes: number of training episodes.
        starting_epsilon: initial exploration rate.
        epsilon_end: lower bound on the exploration rate.
        steps_max: epsilon decreases by ``1 / steps_max`` per environment
            step, counted across episodes, until it reaches ``epsilon_end``.
        lr: Q-learning step size.
        test_intvl: every ``test_intvl`` training episodes, one greedy test
            episode is run on ``eval_env``. Its total reward is shown on the
            progress bar.
        lam: weight of the reward modification. The default of 0.2 is also
            the value ``update`` assumes in its diagnostic print.

    Returns:
        The learned ``q_table``, a NumPy array of shape ``(8, 8, 4)`` indexed
        by ``(state[0], state[1], action)``.

    Side effects:
        Appends every training reward to the module-level ``online_reward``
        deque and draws from NumPy's global RNG.
    """
    q_table = np.random.rand(8, 8, 4)
    epsilon = starting_epsilon
    eps_decay = 1.0 / steps_max  # loop-invariant

    pbar = tqdm.trange(episodes)
    for epi in pbar:
        done = False
        state = env.reset()

        while not done:
            s0, s1 = state[0], state[1]
            # Epsilon-greedy action selection.
            if np.random.rand() < epsilon:
                action = np.random.randint(4)
            else:
                action = np.argmax(q_table[s0, s1])

            next_state, r, done, info = env.step(action)
            online_reward.append(r)

            epsilon = max(epsilon_end, epsilon - eps_decay)

            # Bootstrapping is cut only when the episode ends with
            # info['goal'] truthy. Other terminations (presumably the step
            # limit; verify against the env) still bootstrap. Because of
            # short-circuiting, info['goal'] is read only when done is True.
            d = (done and info['goal'])

            # Modified reward.
            modify_r = r - lam * pow(r, 2) + 2 * lam * r * y

            curr_q = q_table[s0, s1, action]
            next_max = np.max(q_table[next_state[0], next_state[1]])
            targ_q = modify_r + GAMMA * (1 - d) * next_max
            q_table[s0, s1, action] = curr_q + lr * (targ_q - curr_q)

            state = next_state

        if (epi + 1) % test_intvl == 0:
            _, total_r, _ = test(eval_env, q_table, 1)
            pbar.set_description("testR(%.2f)" % total_r)

    return q_table

# Entry point: four rounds of evaluate-then-learn on the environments above.
update(iter_num=4, env=env, eval_env=eval_env)