import json
import os
import random
from collections import deque

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch

# ---------------------------------------------------------------------------
# Global constants
# ---------------------------------------------------------------------------
# One independent training run is performed for each seed in this list.
SEEDs = [
    33, 81, 34, 44, 42,
    41, 31, 173, 139, 83,
]
# Step cap handed to gym.make(..., max_episode_steps=MAX_STEPS).
MAX_STEPS = 500

# Not referenced in the visible part of this script; presumably the CartPole
# observation size and the number of discrete actions -- TODO confirm usage.
STATE_DIM, ACTION_DIM = 4, 2

"""
Import Self-Defined Module
"""

from Algorithms import DPO
from Evaluation import evaluate_actor
from Panels import Panel_Env_Reward

#%%
if __name__ == "__main__":
    # Experiment driver. For every seed: load a pretrained DPO agent, repeatedly
    # collect two rollouts, turn their returns into panel preferences, train
    # the agent on them, and record the evaluation return of the trained actor.
    # Afterwards the per-seed curves are dumped to JSON and their mean plotted.
    max_iterations = 500
    max_episodes_per_iteration = 1

    # Create the output directory up front. The JSON dump and the figure are
    # only written after every seed has finished, so a missing directory would
    # otherwise throw away the complete run at the very last step.
    os.makedirs('./DPO', exist_ok=True)

    returns_curves = []
    for SEED in SEEDs:
        env = gym.make('CartPole-v1', max_episode_steps=MAX_STEPS)
        try:
            # Seed every random source used by this script for this run.
            env.reset(seed=SEED)
            random.seed(SEED)
            np.random.seed(SEED)
            torch.manual_seed(SEED)

            agent = DPO()
            agent.load_model("./Models/actor_pretrain.pth")

            panel = Panel_Env_Reward(link_function='BT', expertise=0.1)
            # panel = Panel_Env_Reward(link_function='L', expertise=0.001)

            returns_curve = []
            # Rolling window over the 10 most recently recorded returns.
            returns_queue = deque(maxlen=10)
            for policy_iter in range(max_iterations):
                # NOTE(review): both rollouts of the preference pair are drawn
                # from agent.actor_ref, not agent.actor -- confirm intended.
                returns_0, states_0, actions_0, log_probs_0 = evaluate_actor(
                    agent.actor_ref, env,
                    num_of_episodes=max_episodes_per_iteration, require_trajs=1)
                returns_1, states_1, actions_1, log_probs_1 = evaluate_actor(
                    agent.actor_ref, env,
                    num_of_episodes=max_episodes_per_iteration, require_trajs=1)

                # NOTE(review): the window receives returns_0 here and the
                # evaluation returns below, so "Cumulative Returns" averages
                # over both kinds of rollouts -- confirm intended.
                returns_queue.extend(returns_0)

                # Only the first return value is used; the second is ignored.
                results, _ = panel.individual_preference_from_reward(returns_0, returns_1)

                # Training batch; 'probs' is the mean of `results` along axis 1.
                memory = {
                    'states_0': states_0,
                    'actions_0': actions_0,
                    'states_1': states_1,
                    'actions_1': actions_1,
                    'log_probs_0': log_probs_0,
                    'log_probs_1': log_probs_1,
                    'probs': np.mean(results, axis=1),
                }

                # Evaluate the current actor before this iteration's update.
                returns = evaluate_actor(agent.actor, env, num_of_episodes=max_episodes_per_iteration, deterministic=1)
                returns_queue.extend(returns)

                loss = agent.train(memory)

                print('SEED:', SEED,
                      ',Policy Iteration:', policy_iter,
                      ',Returns:', round(np.mean(returns), 4), '+-', round(np.std(returns) / np.sqrt(max_episodes_per_iteration), 4),
                      ',Loss:', round(loss, 4),
                      ',Cumulative Returns:', round(np.mean(returns_queue), 4))
                # torch.save(agent.actor.state_dict(), './DPO/actor_' + str(policy_iter) + '.pth')

                returns_curve.append(float(np.mean(returns)))
            returns_curves.append(returns_curve)
        finally:
            # Release the environment even if this seed's run fails part-way.
            env.close()

    data = {
        'returns': returns_curves
    }
    with open('./DPO/data.json', 'w') as f:
        json.dump(data, f, indent=4)

    # One point per policy iteration: the evaluation return averaged over seeds.
    plt.figure(figsize=(10, 5))
    plt.plot(np.mean(returns_curves, axis=0))
    plt.xlabel('Episode')
    plt.ylabel('Return')
    plt.title('Return Over Training Episodes')
    plt.grid()
    plt.tight_layout()
    plt.savefig("./DPO/return_plot.png")  # Optional: save to file
    plt.show()
