import numpy as np
from Algorithm import RM, PPO
from Env import GridWorld
from Panel import RewardPanelWeibull, RewardPanelLinear, RewardPanel
from collections import deque
import json

#%%
# ---- Experiment configuration ---------------------------------------------
# Shared sizes are named once so the environment, reward model and agent
# cannot silently disagree about horizon / state / action dimensions.
GRID_SIZE = 5
HORIZON = 10
N_STATES = GRID_SIZE * GRID_SIZE  # one state index per grid cell (25)
N_ACTIONS = 4
BATCH_SIZE = 1000
PANEL_SIZE = 100

num_episode = 500_000
np.random.seed(42)  # seed before any object that may draw random numbers

env = GridWorld(size=GRID_SIZE, horizon=HORIZON)
reward_model = RM(
    time_horizon=HORIZON,
    state_size=N_STATES,
    action_size=N_ACTIONS,
    lr=0.005,
    batch_size=BATCH_SIZE,
)
agent = PPO(
    time_horizon=HORIZON,
    state_size=N_STATES,
    action_size=N_ACTIONS,
    lr=0.3,
    batch_size=BATCH_SIZE,
)
agent.compute_policy()

# Alternative preference panels: RewardPanelLinear(size=PANEL_SIZE) or
# RewardPanelWeibull(size=PANEL_SIZE).
panel = RewardPanel(size=PANEL_SIZE)

# Reference solution from exact value iteration.
# NOTE(review): index 12 is the middle cell of a row-major 5x5 grid —
# presumably the start state; confirm against GridWorld.reset().
value_opt, policy_opt, Q_opt = env.value_iteration()
print('Optimal Value', value_opt[12], 'Optimal First Action', policy_opt)

#%% RM training
# Collect preference-labelled trajectory pairs for the reward model.
# Set GENERATE_PREFERENCES = False to skip sampling and reload a replay
# buffer previously written by this script.
GENERATE_PREFERENCES = True
# Single source of truth for the buffer location so the save and load
# branches can never point at different files.
REPLAY_BUFFER_PATH = './file/PPO/replay_buffer.pkl'


def _rollout(env, agent):
    """Roll out one episode of ``agent.time_horizon`` steps from ``env.reset()``.

    Returns:
        (traj, total_reward) where ``traj`` is a list of ``[state, action]``
        pairs in visitation order and ``total_reward`` is the undiscounted
        sum of the rewards returned by ``env.step``.
    """
    state = env.reset()
    traj = []
    total_reward = 0
    for _ in range(agent.time_horizon):
        action, _action_prob = agent.select_action(state)
        next_state, reward = env.step(action)
        traj.append([state, action])
        total_reward += reward
        state = next_state
    return traj, total_reward


if GENERATE_PREFERENCES:
    num_pairs = num_episode // 2  # each iteration consumes two episodes
    save_every = reward_model.batch_size * 100
    for episode in range(num_pairs):
        traj1, reward_traj1 = _rollout(env, agent)
        traj2, reward_traj2 = _rollout(env, agent)

        # The panel compares the two summed environment returns; the mean of
        # its result is stored as the preference label for this pair.
        panel_result, _ = panel.query_panel(reward_traj1, reward_traj2)
        reward_model.add_to_replay_buffer(traj1, traj2, np.average(panel_result))

        if (episode + 1) % reward_model.batch_size == 0:
            print('Episode:', episode + 1)

        if (episode + 1) % save_every == 0:
            reward_model.save_replay_buffer(REPLAY_BUFFER_PATH)

    # Periodic checkpoints only fire on multiples of save_every; save once
    # more so pairs collected after the last checkpoint also reach disk.
    # (Assumes save_replay_buffer writes the whole in-memory buffer — confirm.)
    if num_pairs % save_every != 0:
        reward_model.save_replay_buffer(REPLAY_BUFFER_PATH)
else:
    # Previously read './file/replay_buffer.pkl', which this script never
    # writes; now reads the same path the branch above saves to.
    reward_model.read_replay_buffer(REPLAY_BUFFER_PATH)
    print()

# Fit the reward model on the collected preference pairs, checkpointing after
# every epoch. Printed epoch numbers are 1-based, while checkpoint suffixes are
# 0-based (RM_BT_0.pth is written right after "Epoch 1").
N_RM_EPOCHS = 5
for epoch in range(N_RM_EPOCHS):
    loss = reward_model.train()
    print('Epoch', epoch + 1, 'Loss', loss)
    reward_model.save_RM(f'./file/PPO/RM_BT_{epoch}.pth')