import json
import os
from collections import deque

import numpy as np

from AlgorithmNumPy import ZSPO
from Env import GridWorld
from Panel import RewardPanelLinear, RewardPanelWeibull, RewardPanel

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
num_episode = 1000 * 1000  # number of (current, perturbed) rollout pairs
np.random.seed(42)         # seed NumPy's global RNG for reproducibility

# The grid side length and episode horizon are shared by the environment and
# the agent; defining them once keeps the two from silently drifting apart.
GRID_SIZE = 5
HORIZON = 10

env = GridWorld(size=GRID_SIZE, horizon=HORIZON)
agent = ZSPO(time_horizon=HORIZON,
             # assumes one state per grid cell (25 for the 5x5 grid) -- confirm
             # against GridWorld if its state encoding changes
             state_size=GRID_SIZE * GRID_SIZE,
             action_size=4,
             lr=0.1,
             batch_size=1000,
             perturbation_dist=1)
# Start from an all-ones weight array, then draw the first perturbation.
agent.weight = np.ones_like(agent.weight)
agent.perturb_weight()

# Alternative preference panels:
# panel = RewardPanelLinear(size = 100)
# panel = RewardPanelWeibull(size = 100)
panel = RewardPanel(size=100)

# Reference solution used to score the learned policy during training.
value_opt, policy_opt, Q_opt = env.value_iteration()
# NOTE(review): index 12 is presumably the centre cell of the 5x5 grid -- confirm.
print('Optimal Value', value_opt[12], 'Optimal First Action', policy_opt)

# Sliding windows holding the most recent `batch_size` trajectory returns of
# the current and the perturbed policy respectively.
reward_move_ave = deque(maxlen=agent.batch_size)
reward_move_ave_perturb = deque(maxlen=agent.batch_size)

reward_curve = []      # mean current-policy return, one entry per batch
reward_curve_set = []  # full window of current-policy returns, one list per batch
weight_set = []        # agent.weight snapshot taken after each training step


def rollout_return(perturb):
    """Run one episode from ``env.reset()`` and return the summed reward.

    ``perturb`` is forwarded unchanged to ``agent.select_action``; the loop
    below uses 0 for the current policy and 1 for the perturbed policy.
    """
    obs = env.reset()
    total = 0
    for _ in range(agent.time_horizon):
        act, _ = agent.select_action(obs, perturb=perturb)
        obs, r = env.step(act)
        total += r
    return total


for episode in range(num_episode):
    # One rollout of the current policy, then one of the perturbed policy.
    reward_traj = rollout_return(0)
    reward_traj_perturb = rollout_return(1)

    # Ask the panel to compare the two returns and reduce its answer to a
    # single majority label (panel_result presumably holds one outcome per
    # panel member -- confirm in Panel).
    panel_result, panel_prob = panel.query_panel(reward_traj, reward_traj_perturb)
    vote_margin = np.sum(panel_result - 0.5)
    # Positive margin -> 1.0, negative -> 0.0; an exact tie gives np.sign(0) == 0,
    # i.e. 0.5. NOTE(review): confirm add_to_preference handles 0.5.
    preference_major = np.sign(vote_margin) / 2 + 0.5
    agent.add_to_preference(preference_major)

    reward_move_ave.append(reward_traj)
    reward_move_ave_perturb.append(reward_traj_perturb)

    if (episode + 1) % agent.batch_size != 0:
        continue

    # End of a batch: probability the current policy assigns to the optimal
    # action (from value iteration) in every state.
    correct_action_prob = np.zeros(agent.state_size)
    for state_idx in range(agent.state_size):
        _, probs = agent.select_action(state_idx, perturb=0)
        correct_action_prob[state_idx] = probs[policy_opt[state_idx]]

    window_mean = np.average(np.array(reward_move_ave))
    # The 'First Action:' column shows the optimal-action probabilities of
    # states 7..16.
    print('Episode', episode + 1, 'Reward', np.round(reward_traj, 2),
          'Average:', np.round(window_mean, 2),
          np.round(np.average(np.array(reward_move_ave_perturb)), 2),
          'First Action:', np.round(correct_action_prob[7:17], 2))
    reward_curve.append(window_mean)
    reward_curve_set.append(list(reward_move_ave))

    # Update the policy from this batch's preferences, clear the stored
    # preference probabilities, record the new weights, re-perturb.
    agent.train()
    agent.preference_prob = []
    weight_set.append(agent.weight.tolist())
    agent.perturb_weight()
# Persist results: the per-batch reward windows and the weight snapshot taken
# after each training step.
print('Save')
# Create the output directory first; otherwise a missing directory raises
# FileNotFoundError here, after the entire training run, and the results are lost.
os.makedirs('./file/ZSPO', exist_ok=True)
# NOTE(review): json cannot encode NumPy scalar types other than float64;
# confirm env rewards are Python/float64 values.
with open('./file/ZSPO/reward.json', 'w') as f:
    json.dump(reward_curve_set, f)
with open('./file/ZSPO/weight.json', 'w') as f:
    json.dump(weight_set, f)
print('Complete')