import gym
import numpy as np
import pickle as pkl

from maxent import *
import argparse

# --- State/action discretisation -------------------------------------------
# Each of the two observation dimensions (position, velocity) is split into
# `one_feature` bins, so the tabular state space has one_feature ** 2 entries.
one_feature = 20
n_states = one_feature ** 2  # 20 position bins x 20 velocity bins = 400
n_actions = 3

# --- Tabular structures shared by the functions below ----------------------
q_table = np.zeros((n_states, n_actions))  # shape (400, 3), updated in place
feature_matrix = np.eye(n_states)          # shape (400, 400), one-hot row per state

# --- Hyper-parameters ------------------------------------------------------
gamma = 0.99                # discount factor used in the Q-learning target
q_learning_rate = 0.03      # step size of the Q-table update
theta_learning_rate = 0.05  # step size handed to maxent_irl for theta

# Fix NumPy's global RNG so runs are reproducible.
np.random.seed(1)

def idx_demo(env, one_feature,
             demo_path='../rl_baselines_zoo/experts/MountainCar-v0_expert_demo.pkl'):
    """Load pickled expert demonstrations and discretise them.

    The pickle must contain a 5-tuple
    ``(obs, actions, rewards, next_obs, dones)``; only ``obs``, ``actions``
    and ``dones`` are used. Each observation ``(position, velocity)`` is
    mapped to a flat state index
    ``position_bin + velocity_bin * one_feature``.

    Args:
        env: Environment exposing ``observation_space.low`` and
            ``observation_space.high``; these bounds define the bin grid.
        one_feature: Number of bins per observation dimension.
        demo_path: Path of the pickle file to read. Defaults to the
            location the script originally hard-coded.

    Returns:
        A tuple ``(demonstrations, num_demo)``. ``demonstrations`` is a float
        array of shape ``(n_steps, 2)`` whose columns are the state index and
        the action. ``num_demo`` is ``sum(dones)[0]`` -- presumably the
        number of finished expert episodes (assumes each ``dones`` entry is a
        length-1 array; confirm against the file that produced the pickle).
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(demo_path, 'rb') as f:
        obs_expert, actions_expert, _rewards, _next_obs, dones_expert = pkl.load(f)

    obs = np.concatenate(obs_expert)
    actions = np.array(actions_expert).reshape(-1, 1)
    raw_demo = np.concatenate([obs, actions], axis=1)  # columns: pos, vel, action

    env_low = env.observation_space.low
    env_high = env.observation_space.high
    env_distance = (env_high - env_low) / one_feature  # width of one bin

    # Truncate to bin indices, then clamp: an observation sitting exactly on
    # the upper bound would otherwise land in bin `one_feature`, giving a
    # state index past the end of the one_feature ** 2 sized tables.
    bin_idx = ((raw_demo[:, :2] - env_low) / env_distance).astype(int)
    bin_idx = np.clip(bin_idx, 0, one_feature - 1)

    demonstrations = np.zeros((raw_demo.shape[0], 2))
    demonstrations[:, 0] = bin_idx[:, 0] + bin_idx[:, 1] * one_feature
    demonstrations[:, 1] = raw_demo[:, 2]

    return demonstrations, sum(dones_expert)[0]

def idx_state(env, state, n_bins=None):
    """Map a continuous (position, velocity) observation to a flat state index.

    Args:
        env: Environment exposing ``observation_space.low`` and
            ``observation_space.high``; these bounds define the bin grid.
        state: Indexable observation; ``state[0]`` is binned against the
            first bound and ``state[1]`` against the second.
        n_bins: Number of bins per dimension. Defaults to the module-level
            ``one_feature`` so existing two-argument calls are unchanged.

    Returns:
        int: ``position_bin + velocity_bin * n_bins``, always within
        ``[0, n_bins ** 2 - 1]``.
    """
    if n_bins is None:
        n_bins = one_feature

    env_low = env.observation_space.low
    env_high = env.observation_space.high
    env_distance = (env_high - env_low) / n_bins  # width of one bin
    last_bin = n_bins - 1

    # Clamp each bin: a value exactly on the upper bound divides out to
    # `n_bins`, which would index past the end of the Q-table.
    position_idx = int((state[0] - env_low[0]) / env_distance[0])
    position_idx = min(max(position_idx, 0), last_bin)
    velocity_idx = int((state[1] - env_low[1]) / env_distance[1])
    velocity_idx = min(max(velocity_idx, 0), last_bin)

    return position_idx + velocity_idx * n_bins

def update_q_table(state, action, reward, next_state):
    """Apply one Q-learning update to the module-level ``q_table`` in place.

    Args:
        state: Row index of the state the action was taken in.
        action: Column index of the action taken.
        reward: Reward credited to this transition.
        next_state: Row index of the resulting state.
    """
    # Bootstrapped target: reward plus discounted best value of the next state.
    td_target = reward + gamma * max(q_table[next_state])
    td_error = td_target - q_table[state][action]
    q_table[state][action] += q_learning_rate * td_error
    
def get_reward_mean(env, q_table, trials=50):
    """Average the environment return of the greedy policy over several episodes.

    Args:
        env: Environment whose ``reset()`` returns an observation and whose
            ``step(action)`` returns ``(observation, reward, done, info)``.
        q_table: Table indexed by discretised state; the action with the
            highest value in the row is taken at every step.
        trials: Number of evaluation episodes to run.

    Returns:
        Sum of the per-episode returns divided by ``trials``.
    """
    episode_returns = []
    for _ in range(trials):
        observation = env.reset()
        episode_return = 0
        finished = False
        while not finished:
            # Greedy action for the current discretised state.
            greedy_action = np.argmax(q_table[idx_state(env, observation)])
            observation, step_reward, finished, _ = env.step(greedy_action)
            episode_return += step_reward
        episode_returns.append(episode_return)
    return sum(episode_returns) / trials


def main():
    """Train a tabular Q-learning agent on MountainCar-v0 using an IRL reward.

    The agent acts greedily with respect to the module-level ``q_table`` and
    is updated with the reward returned by ``get_reward`` for the current
    ``theta`` instead of the environment reward. From episode 10000 onward,
    every 5000 episodes, ``maxent_irl`` is called with the expert and learner
    feature expectations (presumably updating ``theta`` in place -- its
    return value is not used here; confirm in the ``maxent`` module).

    Every 1000 episodes the greedy policy is evaluated with
    ``get_reward_mean`` and both the Q-table and the list of average scores
    are saved under ``./results_mountaincar``.

    Command line:
        -i / --index: experiment index used (zero-padded to two characters)
            as the suffix of the saved file names. Defaults to "0".
    """
    import os  # local import: only needed to create the results directory

    parser = argparse.ArgumentParser()
    # A default is required: without it a missing -i leaves args.index as
    # None and the .zfill() call below raises AttributeError.
    parser.add_argument("-i", "--index", type=str, default="0",
                        help="index of experiment")
    args = parser.parse_args()
    run_tag = args.index.zfill(2)

    # np.save does not create missing directories, so make sure it exists.
    results_dir = "./results_mountaincar"
    os.makedirs(results_dir, exist_ok=True)

    env = gym.make('MountainCar-v0')
    demonstrations, num_demo = idx_demo(env, one_feature)

    expert = expert_feature_expectations(feature_matrix, demonstrations, num_demo)
    learner_feature_expectations = np.zeros(n_states)

    theta = -(np.random.uniform(size=(n_states,)))

    scores_avg = []

    for episode in range(50000):
        state = env.reset()

        # Reward update schedule: episodes 10000, 15000, 20000, ...
        if episode >= 10000 and episode % 5000 == 0:
            # Visitation counts accumulated so far, averaged per episode.
            learner = learner_feature_expectations / episode
            maxent_irl(expert, learner, theta, theta_learning_rate)

        done = False
        while not done:
            state_idx = idx_state(env, state)
            action = np.argmax(q_table[state_idx])
            # The environment reward is deliberately ignored during training.
            next_state, _, done, _ = env.step(action)

            irl_reward = get_reward(feature_matrix, theta, n_states, state_idx)
            next_state_idx = idx_state(env, next_state)
            update_q_table(state_idx, action, irl_reward, next_state_idx)

            learner_feature_expectations += feature_matrix[int(state_idx)]

            state = next_state

        if episode % 1000 == 0:
            score_avg = get_reward_mean(env, q_table, trials=50)
            print('{} episode score is {:.2f}'.format(episode, score_avg))
            scores_avg.append(score_avg)
            np.save("{}/maxent_q_table_{}".format(results_dir, run_tag), arr=q_table)
            np.save("{}/score_avg_{}".format(results_dir, run_tag), scores_avg)


if __name__ == '__main__':
    main()