import json
from collections import deque
from pathlib import Path

import numpy as np

from Algorithm import DPO
from Env import GridWorld
from Panel import RewardPanelLinear, RewardPanelWeibull, RewardPanel

#%%
# --- Experiment configuration -------------------------------------------
# Shared by the environment and the agent so the two can never disagree.
grid_size = 5      # GridWorld side length
horizon = 10       # steps per trajectory, used by both env and agent
num_policy_updates = 1000
num_trajs_per_policy = 1000   # preference pairs collected per policy update
num_episode = num_policy_updates * num_trajs_per_policy
# Seed before constructing env/agent/panel in case their constructors draw
# from numpy's global RNG.
np.random.seed(42)

env = GridWorld(size=grid_size, horizon=horizon)
agent = DPO(
    time_horizon=horizon,
    # assumes GridWorld has one state per cell (5x5 -> 25) — TODO confirm
    state_size=grid_size ** 2,
    action_size=4,
    lr=0.5,
    # One training batch == all preference pairs gathered for one policy.
    batch_size=num_trajs_per_policy,
)
agent.compute_policy()

# Alternative preference panels; swap in to compare labelling models.
# panel = RewardPanelLinear(size = 100)
# panel = RewardPanelWeibull(size =100)
panel = RewardPanel(size=100)

value_opt, policy_opt, Q_opt = env.value_iteration()
# Centre cell of the grid (12 for the 5x5 grid) — presumably the start
# state returned by env.reset(); NOTE(review): confirm against GridWorld.
center_state = grid_size ** 2 // 2
print('Optimal Value', value_opt[center_state], 'Optimal First Action', policy_opt)

# Sliding window of evaluation returns over the most recent batch.
reward_move_ave = deque(maxlen=agent.batch_size)

#%% train with DPO
def _rollout(env, agent, old, record=True):
    """Run one episode of ``agent.time_horizon`` steps starting from ``env.reset()``.

    Parameters
    ----------
    env : GridWorld
        Environment to roll out in (``reset()`` and ``step(action)`` are used).
    agent : DPO
        Agent whose ``select_action`` chooses each action.
    old : int
        Forwarded unchanged to ``agent.select_action``. This script passes 1
        when sampling preference-data trajectories and 0 when evaluating the
        current policy (presumably old vs. current policy — confirm in DPO).
    record : bool
        When True, also collect ``[state, action, action_prob[action]]`` for
        every step; evaluation rollouts do not need this.

    Returns
    -------
    tuple
        ``(total_reward, traj)`` where ``total_reward`` is the sum of the
        per-step rewards and ``traj`` is the recorded trajectory (empty list
        when ``record`` is False).
    """
    state = env.reset()
    total_reward = 0
    traj = []
    for _ in range(agent.time_horizon):
        action, action_prob = agent.select_action(state, old=old)
        next_state, reward = env.step(action)
        total_reward += reward
        if record:
            traj.append([state, action, action_prob[action]])
        state = next_state
    return total_reward, traj


num_train_epochs = 5  # passes of agent.train() over each full batch
reward_curve = []      # mean evaluation return per policy update
reward_curve_set = []  # full window of evaluation returns per policy update
policy_set = []        # snapshot of agent.policy after each update
for episode in range(num_episode):
    # Two independent trajectories sampled with old=1 form one preference pair.
    reward_traj_1, traj1 = _rollout(env, agent, old=1)
    reward_traj_2, traj2 = _rollout(env, agent, old=1)

    # The panel compares the two returns; the mean of its results is stored
    # as the soft preference label for (traj1, traj2).
    panel_result, panel_prob = panel.query_panel(reward_traj_1, reward_traj_2)
    preference_prob = np.average(panel_result)
    agent.replay_buffer.append([traj1, traj2, preference_prob])

    # Evaluate the current policy (old=0) with a single unrecorded rollout.
    reward_traj, _ = _rollout(env, agent, old=0, record=False)
    reward_move_ave.append(reward_traj)

    # Once a full batch of preference pairs is collected, train and refresh.
    if (episode + 1) % agent.batch_size == 0:
        for epoch in range(num_train_epochs):
            loss = agent.train()
        agent.replay_buffer = []
        agent.compute_policy()
        avg_reward = np.average(reward_move_ave)
        print('Episode:', episode + 1, 'Reward', np.round(avg_reward, 2), 'Loss', np.round(loss, 2))
        reward_curve.append(avg_reward)
        # list(...) and .tolist() already produce fresh lists; no extra copy needed.
        reward_curve_set.append(list(reward_move_ave))
        policy_set.append(agent.policy.tolist())
#%%
# Persist per-update evaluation-reward windows and policy snapshots as JSON.
output_dir = Path('./file/DPO')
# Create the directory up front so a long training run is not lost to a
# FileNotFoundError at the very end.
output_dir.mkdir(parents=True, exist_ok=True)
# NOTE(review): json cannot encode numpy integer scalars; if env rewards are
# np.int* values, reward_curve_set will need converting first — confirm.
with open(output_dir / 'reward.json', 'w', encoding='utf-8') as f:
    json.dump(reward_curve_set, f)
with open(output_dir / 'policy.json', 'w', encoding='utf-8') as f:
    json.dump(policy_set, f)