import gymnasium as gym
import numpy as np
import torch
import json
import random
from collections import deque
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Global constants
# ---------------------------------------------------------------------------
SEED = 31          # shared seed for Python's `random`, NumPy, PyTorch and env.reset
MAX_STEPS = 500    # episode step cap handed to gym.make(max_episode_steps=...)

STATE_DIM = 4      # state size passed to Reward(state_dim=...)
ACTION_DIM = 2     # number of discrete actions; used for Reward(action_dim=...) and one-hot encoding

# Seed every RNG source the script draws from, so runs are reproducible.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

"""
Import project-local modules
"""

from Networks import Reward
from Algorithms import PPO
from Panels import Panel_Env_Reward

def rollout_episode(env, agent):
    """Run one episode with ``agent`` in ``env`` and record every step.

    Args:
        env: Environment following the Gymnasium API (``reset()`` returns
            ``(obs, info)``; ``step(action)`` returns
            ``(obs, reward, terminated, truncated, info)``).
        agent: Object whose ``select_action(state)`` returns a 2-tuple whose
            first element is the action; the second element is ignored here.

    Returns:
        Dict with keys ``'states'`` (per-step observations converted with
        ``.tolist()``, recorded *before* the action is applied),
        ``'actions'`` and ``'rewards'``, all of equal length.
    """
    states, actions, rewards = [], [], []
    state, _ = env.reset()
    done = False
    while not done:
        action, _ = agent.select_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        states.append(state.tolist())
        actions.append(action)
        rewards.append(reward)
        done = terminated or truncated
        state = next_state
    return {'states': states, 'actions': actions, 'rewards': rewards}


def predicted_return(reward_net, traj, action_dim):
    """Return the reward model's predicted return (summed output) for one trajectory.

    Args:
        reward_net: Callable invoked as ``reward_net(states, actions)``; every
            element of its output is summed.
        traj: Dict with ``'states'`` (list of per-step state lists) and
            ``'actions'`` (list of integer action indices).
        action_dim: Number of discrete actions, used for one-hot encoding.

    Returns:
        A 0-d tensor, differentiable w.r.t. any parameters ``reward_net`` uses.
    """
    states = torch.as_tensor(np.asarray(traj['states']), dtype=torch.float32)
    actions = torch.as_tensor(np.asarray(traj['actions']), dtype=torch.long)
    # NOTE(review): one_hot yields an int64 tensor (unchanged behaviour);
    # confirm Reward casts it to float before any Linear layer.
    actions = torch.nn.functional.one_hot(actions, action_dim)
    return torch.sum(reward_net(states, actions))


def preference_loss(reward_net, trajs, labels, action_dim):
    """Logistic (Bradley-Terry-style) cross-entropy over trajectory pairs.

    For pair ``i`` the logit is ``R_hat(trajs[1][i]) - R_hat(trajs[0][i])``,
    so ``sigmoid(logit)`` is the predicted probability that the second
    trajectory is preferred, and ``labels[i]`` is its target.

    Working on logits via ``binary_cross_entropy_with_logits`` keeps the loss
    and its gradient finite and non-zero even when the sigmoid saturates. An
    explicit ``log(sigmoid(x) + eps)`` caps the loss and zeroes the gradient
    exactly there.

    Args:
        reward_net: Callable ``reward_net(states, actions)`` (see ``predicted_return``).
        trajs: Two equal-length lists of trajectory dicts; ``trajs[0][i]`` and
            ``trajs[1][i]`` form the i-th pair.
        labels: Sequence or tensor of per-pair targets in ``[0, 1]``.
        action_dim: Number of discrete actions.

    Returns:
        Scalar mean loss tensor.
    """
    logits = []
    for traj_0, traj_1 in zip(trajs[0], trajs[1]):
        return_hat_0 = predicted_return(reward_net, traj_0, action_dim)
        return_hat_1 = predicted_return(reward_net, traj_1, action_dim)
        logits.append(return_hat_1 - return_hat_0)
    logits = torch.stack(logits)
    labels = torch.as_tensor(labels, dtype=logits.dtype)
    return torch.nn.functional.binary_cross_entropy_with_logits(logits, labels)


#%%
if __name__ == "__main__":
    max_iterations = 50
    max_episodes_per_iteration = 100
    episodes_reward_model = max_episodes_per_iteration * max_iterations
    num_epochs = 5
    # When False, skip collection and train on the previously saved data file.
    collect_data = True
    data_path = './Models/data.json'

    env = gym.make('CartPole-v1', max_episode_steps=MAX_STEPS)
    env.reset(seed=SEED)

    agent = PPO()
    agent.load_model("./Models/actor_pretrain.pth")

    panel = Panel_Env_Reward(link_function='L', expertise=0.001)
    reward_net = Reward(state_dim=STATE_DIM, action_dim=ACTION_DIM)
    optimizer = torch.optim.AdamW(params=reward_net.parameters(), lr=1e-3)

    try:
        if collect_data:
            # Alternate episodes into two groups so trajs[0][i] and
            # trajs[1][i] form the i-th comparison pair.
            trajs = [[], []]
            for episode in range(episodes_reward_model):
                traj = rollout_episode(env, agent)
                trajs[episode % 2].append(traj)
                print('episode:', episode, 'reward:', np.sum(traj['rewards']))

            # Label each pair via the panel's preferences over the pair's
            # environment returns.
            returns_0 = [np.sum(traj['rewards']) for traj in trajs[0]]
            returns_1 = [np.sum(traj['rewards']) for traj in trajs[1]]
            results, _ = panel.individual_preference_from_reward(returns_0, returns_1)
            # NOTE(review): averaging over axis 1 presumably aggregates panel
            # members per pair — confirm against Panels.
            labels = np.mean(results, axis=1).tolist()

            with open(data_path, 'w') as f:
                json.dump({'trajs': trajs, 'labels': labels}, f)

        # Train the reward model: full batch, one optimizer step per epoch.
        with open(data_path, 'r') as f:
            data = json.load(f)
        trajs = data['trajs']
        # Convert once, outside the epoch loop.
        labels = torch.tensor(data['labels'], dtype=torch.float32)
        for _ in range(num_epochs):
            loss = preference_loss(reward_net, trajs, labels, ACTION_DIM)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(loss.item())
        torch.save(reward_net.state_dict(), './Models/reward_net.pth')
    finally:
        env.close()
    print()
