import gym
import numpy as np
import pickle as pkl

from maxent import *
import argparse

# Discretisation: every observation component is split into `one_feature`
# bins.  The state index built in idx_demo()/idx_state() is a 4-digit number
# in base `one_feature`, so the number of discrete states is derived from it
# instead of being hard-coded separately (the two must always agree).
one_feature = 6
n_states = one_feature ** 4
n_actions = 2

# Tabular action values, one row per discrete state, one column per action.
q_table = np.zeros((n_states, n_actions))
# One-hot state features: row i is the feature vector of discrete state i.
feature_matrix = np.eye(n_states)

gamma = 0.99                 # discount factor used in update_q_table()
q_learning_rate = 0.03       # step size of the Q-table update
theta_learning_rate = 0.05   # step size passed to maxent_irl()

# Fixed seed so the random initialisation of theta in main() is reproducible.
np.random.seed(1)

def int_min_max(x):
    """Clamp ``x`` to the valid bin range and return it as an ``int``.

    Values below 0 map to bin 0, values above the last bin map to
    ``one_feature - 1`` (module-level bin count); anything in between is
    truncated to an integer.
    """
    last_bin = one_feature - 1
    if x < 0:
        return 0
    if x > last_bin:
        return int(last_bin)
    return int(x)

def idx_demo(env, one_feature):
    """Load the expert demonstrations and convert them to discrete state indices.

    The pickle file is expected to hold a 5-tuple
    ``(obs, actions, rewards, next_obs, dones)``; only the observations,
    actions and dones are used.  Each of the four observation components is
    split into ``one_feature`` equal-width bins over the symmetric range
    ``[-max|x| / 3, +max|x| / 3]`` (computed per component from the expert
    data); values outside that range are clamped to the first/last bin.  The
    four bin indices are combined into one state index, read as a 4-digit
    number in base ``one_feature`` with the first component as the least
    significant digit.

    Args:
        env: Unused; kept for interface compatibility with existing callers.
        one_feature: Number of bins per observation component.

    Returns:
        A tuple ``(demonstrations, num_demo, env_low, env_high)`` where
        ``demonstrations`` is a float array of shape ``(n_steps, 2)`` holding
        ``[state_idx, action]`` per step, ``num_demo`` is
        ``sum(dones_expert)[0]``, and ``env_low`` / ``env_high`` are the
        per-component bounds used for binning.

    Raises:
        ValueError: If the concatenated data does not have four observation
            columns plus one action column, or if binning yields non-finite
            values (NaN/inf observations, or a component that is always 0 and
            therefore has zero bin width).
    """
    # NOTE(security): pickle.load can execute arbitrary code while loading.
    # Only use demonstration files that come from a trusted source.
    with open('../rl_baselines_zoo/experts/CartPole-v1_expert_demo.pkl', 'rb') as f:
        obs_expert, actions_expert, _rewards, _next_obs, dones_expert = pkl.load(f)
    obs = np.concatenate(obs_expert)
    actions = np.array(actions_expert).reshape(-1, 1)

    raw_demo = np.concatenate([obs, actions], axis=1)
    if raw_demo.shape[1] != 5:
        raise ValueError(
            "expected 4 observation columns plus 1 action column, got "
            "{} columns".format(raw_demo.shape[1]))

    #env_low = env.observation_space.low
    #env_high = env.observation_space.high
    # Bounds come from the expert data rather than the environment limits.
    env_high = np.abs(obs).max(axis=0) / 3
    env_low = -env_high
    env_distance = (env_high - env_low) / one_feature

    bins = (raw_demo[:, :4] - env_low) / env_distance
    if not np.isfinite(bins).all():
        raise ValueError(
            "cannot discretise demonstrations: non-finite bin values "
            "(NaN/inf observation or zero-width feature range)")

    # Clamp with the *parameter* so the clamp, the bin width and the radix
    # below all use the same bin count.
    bin_idx = np.clip(bins, 0, one_feature - 1).astype(int)
    radix = one_feature ** np.arange(4)  # [1, f, f**2, f**3]

    demonstrations = np.zeros((raw_demo.shape[0], 2))
    demonstrations[:, 0] = bin_idx @ radix
    demonstrations[:, 1] = raw_demo[:, 4]

    return demonstrations, sum(dones_expert)[0], env_low, env_high



def idx_state(env, state):
    """Map a 4-component observation to its discrete state index.

    The bin bounds are the module-level ``env_low`` / ``env_high`` (assigned
    by ``main()``), not ``env.observation_space``; ``env`` itself is unused.
    Each component is binned into ``one_feature`` bins, clamped with
    ``int_min_max``, and the four bin indices are combined as digits of a
    base-``one_feature`` number (first component least significant).
    """
    bin_width = (env_high - env_low) / one_feature
    scaled = (state - env_low) / bin_width
    pos_bin, vel_bin, ang_bin, angvel_bin = (int_min_max(v) for v in scaled)
    # Horner form of pos + vel*f + ang*f**2 + angvel*f**3.
    return pos_bin + one_feature * (
        vel_bin + one_feature * (ang_bin + one_feature * angvel_bin))

def update_q_table(state, action, reward, next_state):
    """Apply one Q-learning update to the module-level ``q_table`` in place.

    The target is ``reward + gamma * max_a Q(next_state, a)``; the entry for
    ``(state, action)`` is moved towards it by ``q_learning_rate``.
    """
    td_target = reward + gamma * max(q_table[next_state])
    td_error = td_target - q_table[state][action]
    q_table[state][action] += q_learning_rate * td_error
    
def get_reward_mean(env, q_table, trials=50):
    """Return the mean episode return of the greedy policy over ``trials`` runs.

    Each episode starts from ``env.reset()`` and always takes the action with
    the highest value in ``q_table`` for the current discrete state, summing
    the environment rewards until ``env.step`` reports ``done``.
    """
    episode_returns = []
    for _ in range(trials):
        observation = env.reset()
        total = 0
        finished = False
        while not finished:
            greedy_action = np.argmax(q_table[idx_state(env, observation)])
            observation, step_reward, finished, _ = env.step(greedy_action)
            total += step_reward
        episode_returns.append(total)
    return sum(episode_returns) / trials


def main():
    """Train a tabular Q-learning agent on CartPole-v1 with a MaxEnt IRL reward.

    Loads and discretises the expert demonstrations, computes the expert
    feature expectations, then runs greedy Q-learning where the reward for
    each step comes from ``get_reward`` (driven by ``theta``) instead of the
    environment.  Every 500 episodes the greedy policy is evaluated with
    ``get_reward_mean`` and the Q-table and score history are saved under
    ``./results_cartpole``.

    Uses the Gym API in which ``env.reset()`` returns the observation and
    ``env.step()`` returns a 4-tuple.

    Command line:
        -i / --index: experiment index used to name the result files
            (zero-padded to two characters; defaults to "0").
    """
    import os  # local import: only main() touches the output directory

    # idx_state() reads these module-level bounds, so publish them from here.
    global env_low, env_high

    parser = argparse.ArgumentParser()
    # A default is required: without it args.index is None when -i is
    # omitted and the .zfill() call below raises AttributeError.
    parser.add_argument("-i", "--index", type=str, default="0",
                        help="index of experiment")

    args = parser.parse_args()
    run_tag = args.index.zfill(2)

    n_episodes = 10000
    irl_start = 10000      # first episode at which theta is updated
    irl_interval = 5000    # afterwards, update theta every this many episodes
    eval_interval = 500

    # np.save does not create missing directories.
    results_dir = "./results_cartpole"
    os.makedirs(results_dir, exist_ok=True)

    # NOTE(review): with the values above the IRL update below is never
    # reached (episode < n_episodes <= irl_start), so theta keeps its random
    # initial value.  The schedule is left unchanged because the intended
    # values cannot be determined from this file; lower irl_start or raise
    # n_episodes to enable the update.
    if irl_start >= n_episodes:
        print('warning: irl_start ({}) >= n_episodes ({}); '
              'maxent_irl will never be called'.format(irl_start, n_episodes))

    env = gym.make('CartPole-v1')
    demonstrations, num_demo, env_low, env_high = idx_demo(env, one_feature)
    print(env_low, env_high)

    expert = expert_feature_expectations(feature_matrix, demonstrations, num_demo)
    learner_feature_expectations = np.zeros(n_states)

    theta = -(np.random.uniform(size=(n_states,)))

    episodes, scores = [], []
    scores_avg = []

    for episode in range(n_episodes):
        print('{}/{}'.format(episode, n_episodes), end='\r')
        state = env.reset()
        score = 0

        if episode == irl_start or (episode > irl_start and episode % irl_interval == 0):
            # Average visitation counts over the episodes run so far.
            learner = learner_feature_expectations / episode
            # The return value is ignored; presumably maxent_irl updates
            # theta in place -- TODO confirm against maxent.py.
            maxent_irl(expert, learner, theta, theta_learning_rate)

        while True:
            state_idx = idx_state(env, state)
            action = np.argmax(q_table[state_idx])
            next_state, reward, done, _ = env.step(action)

            # Learn from the IRL reward; the environment reward is only
            # used for the reported score.
            irl_reward = get_reward(feature_matrix, theta, n_states, state_idx)
            next_state_idx = idx_state(env, next_state)
            update_q_table(state_idx, action, irl_reward, next_state_idx)

            learner_feature_expectations += feature_matrix[int(state_idx)]

            score += reward
            state = next_state

            if done:
                scores.append(score)
                episodes.append(episode)
                break

        if episode % eval_interval == 0:
            score_avg = get_reward_mean(env, q_table, trials=50)
            print('{} episode score is {:.2f}'.format(episode, score_avg))
            scores_avg.append(score_avg)
            np.save("./results_cartpole/maxent_q_table_{}".format(run_tag), arr=q_table)
            np.save("./results_cartpole/score_avg_{}".format(run_tag), scores_avg)

# Run the training loop only when executed as a script, not when imported.
if __name__ == '__main__':
    main()