import numpy as np
from Algorithm import RM, PPO
from Env import GridWorld
from collections import deque
import json

#%%
# Experiment configuration. The environment, the reward model and the PPO
# agent must agree on the episode length and on the state/action space sizes,
# so those values are defined once here and passed to all three.
GRID_SIZE = 5
HORIZON = 10
# Assumes one state per grid cell (matches the original 5 -> 25 pairing).
STATE_SIZE = GRID_SIZE * GRID_SIZE
ACTION_SIZE = 4

num_episode = 1000 * 1000  # total number of trajectories sampled during training
# Seeds NumPy's global RNG only.
# NOTE(review): if Algorithm/Env sample with another library (e.g. torch),
# that RNG is not seeded here — confirm if full reproducibility is required.
np.random.seed(42)


env = GridWorld(size=GRID_SIZE, horizon=HORIZON)

# Reward model restored from a saved checkpoint; its rewards are what PPO
# trains on below. ("_BT" presumably means a Bradley-Terry-trained model —
# confirm against the training script that wrote the file.)
reward_model = RM(time_horizon=HORIZON, state_size=STATE_SIZE, action_size=ACTION_SIZE, lr=0.005, batch_size=1000)
reward_model.read_RM('./file/PPO/RM_BT.pth')
reward_model.compute_reward()

agent = PPO(time_horizon=HORIZON, state_size=STATE_SIZE, action_size=ACTION_SIZE, lr=0.6, batch_size=1000)
agent.compute_policy()

# Ground-truth optimum from value iteration on the true environment, printed
# as a reference point for the rewards reported during training.
value_opt, policy_opt, Q_opt = env.value_iteration()
# NOTE(review): index 12 is the centre cell of a 5x5 grid and presumably the
# state env.reset() starts from — confirm, and update it if GRID_SIZE changes.
print('Optimal Value', value_opt[12], 'Optimal First Action', policy_opt)

# Sliding window over the true returns of the 1000 most recent episodes.
reward_move_ave = deque(maxlen=1000)
# NOTE(review): not read anywhere in this file; kept so the module namespace
# is unchanged.
reward_move_ave_perturb = deque(maxlen=1000)

# train with PPO
# Each episode rolls out one trajectory, scores every (state, action) pair with
# the learned reward model and queues it for PPO. Every `agent.batch_size`
# episodes the agent is optimised, and snapshots of the recent true rewards and
# of the current policy are recorded.
reward_curve = []      # mean of the reward window at each update
reward_curve_set = []  # full reward window at each update
policy_set = []        # policy (as nested lists) after each update
for episode in range(num_episode):
    # Roll out one trajectory. The true environment reward is only used for
    # monitoring; PPO learns from the reward-model scores in `rewards_rm`.
    state = env.reset()
    reward_traj = 0
    states, actions, rewards_rm, probs = [], [], [], []
    for step in range(agent.time_horizon):
        action, action_prob = agent.select_action(state)
        next_state, reward = env.step(action)
        reward_from_RM = reward_model.query_RM(state, action)

        reward_traj += reward

        states.append(state)
        actions.append(action)
        rewards_rm.append(reward_from_RM)
        # Presumably the probability the sampling policy gave the chosen
        # action; stored as the "old" probability for the PPO ratio.
        probs.append(action_prob[action])

        state = next_state
    reward_move_ave.append(reward_traj)

    # Undiscounted reward-to-go: returns[t] == sum(rewards_rm[t:]).
    # Accumulated back-to-front and reversed once instead of insert(0, ...)
    # on every step; the additions happen in the same order as before.
    returns = []
    cumulative_return = 0
    for reward_rm in reversed(rewards_rm):
        cumulative_return = reward_rm + cumulative_return
        returns.append(cumulative_return)
    returns.reverse()

    agent.replay_buffer.append([states, actions, rewards_rm, probs, returns])

    if (episode + 1) % agent.batch_size == 0:
        # Several optimisation passes over the same batch; only the loss of
        # the final pass is reported.
        for epoch in range(5):
            loss = agent.train()
        agent.replay_buffer = []
        average_reward = np.average(reward_move_ave)
        print('Episode:', episode + 1, 'Reward', average_reward, 'Loss', loss)
        agent.update_reference_policy()
        agent.compute_policy()
        reward_curve.append(average_reward)
        # list(...) and .tolist() already return fresh lists; no extra copy needed.
        reward_curve_set.append(list(reward_move_ave))
        policy_set.append(agent.policy.tolist())

print()

#%%
# Persist the per-update snapshots collected during training as JSON:
# the reward windows first, then the policy snapshots.
for output_path, payload in (
    ('./file/PPO/reward_BT.json', reward_curve_set),
    ('./file/PPO/policy_BT.json', policy_set),
):
    with open(output_path, 'w') as out_file:
        json.dump(payload, out_file)
