import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import json
import random
from collections import deque
import matplotlib.pyplot as plt

"""
Global constants
"""
# Seeds for the independent training runs: the entry point trains one agent
# per seed and averages the resulting return curves.
SEEDs = [33, 81, 34, 44, 42, 41, 31, 173, 139, 83]
# Episode length cap, passed to gym.make(..., max_episode_steps=MAX_STEPS).
MAX_STEPS = 1000

# Observation / action sizes; presumably these match HalfCheetah-v5's
# observation and action spaces (TODO confirm against env.observation_space).
STATE_DIM = 17
ACTION_DIM = 6
# Per-dimension action bounds as float32 CPU tensors: +1 above, -1 below.
ACTION_HIGH = torch.ones(ACTION_DIM, dtype=torch.float32)
ACTION_LOW = torch.full((ACTION_DIM,), -1.0, dtype=torch.float32)

"""
Import Self-Defined Module
"""

from Evaluation import evaluate_actor
from Panels import Panel_Env_Reward
from Algorithms import ZSPO

#%%
if __name__ == "__main__":
    # Script entry point: for every seed in SEEDs, train a ZSPO agent on
    # HalfCheetah-v5 from preference feedback (Panel_Env_Reward comparing the
    # current actor against its perturbed copy), then save the per-seed
    # return curves to JSON and plot the seed-averaged curve.
    import os  # local import: only the script entry point needs it

    max_iterations = 500
    max_episodes_per_iteration = 1
    output_dir = './ZSPO'
    returns_curves = []

    # Create the output directory before training starts, so that a missing
    # folder cannot crash the final json/savefig step after hours of runs.
    os.makedirs(output_dir, exist_ok=True)

    for SEED in SEEDs:
        env = gym.make('HalfCheetah-v5', max_episode_steps=MAX_STEPS)
        try:
            env.reset(seed=SEED)
            random.seed(SEED)
            np.random.seed(SEED)
            torch.manual_seed(SEED)

            agent = ZSPO()
            agent.load_model("./Models/actor_initial.pth")

            panel = Panel_Env_Reward(link_function='BT', expertise=0.1)
            # panel = Panel_Env_Reward(link_function='L', expertise=0.001)

            returns_curve = []
            # Sliding window over the 100 most recent evaluation returns.
            # It is logged as "Cumulative Returns" below, but it is a moving
            # window, not a running total.
            returns_queue = deque(maxlen=100)
            for policy_iter in range(max_iterations):
                # agent.actor_perturb is evaluated below; presumably this call
                # refreshes it. The returned perturbation vector is unused.
                agent.perturb_actor()

                # Evaluate the current (unperturbed) actor.
                returns_0 = evaluate_actor(agent.actor, env, num_of_episodes=max_episodes_per_iteration, deterministic=1)
                returns_queue.extend(returns_0)

                # Evaluate +epsilon perturbation
                returns_plus = evaluate_actor(agent.actor_perturb, env, num_of_episodes=max_episodes_per_iteration, deterministic=1)

                # Compute preference and train
                results_plus, prob_plus = panel.batch_preference_from_reward(returns_0, returns_plus)
                agent.train(np.mean(results_plus))

                # "+-" is 2 * standard error of the mean over this iteration's
                # episodes; it is always 0 when max_episodes_per_iteration == 1.
                print('SEED:', SEED,
                      ',Policy Iteration:', policy_iter,
                      ',Returns:', round(np.mean(returns_0), 4), '+-',
                      round(2 * np.std(returns_0) / np.sqrt(max_episodes_per_iteration), 4),
                      ',Prob:', round(np.mean(prob_plus), 4),
                      ',Cumulative Returns:', round(np.mean(returns_queue), 4),
                      ",Reward diff (mean1 - mean0):", round(np.mean(returns_plus) - np.mean(returns_0), 4))
                # torch.save(agent.actor.state_dict(), './ZSPO/actor_' + str(policy_iter) + '.pth')

                returns_curve.append(float(np.mean(returns_0)))
        finally:
            # Release the simulator even if training for this seed raises.
            env.close()
        returns_curves.append(returns_curve)

    data = {
        'returns': returns_curves
    }
    with open(os.path.join(output_dir, 'data.json'), 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)

    # Each x position is one policy iteration; the y value is the mean return
    # of the unperturbed actor at that iteration, averaged across seeds.
    plt.figure(figsize=(10, 5))
    plt.plot(np.mean(returns_curves, axis=0))
    plt.xlabel('Policy Iteration')
    plt.ylabel('Return (mean over seeds)')
    plt.title('Return Over Policy Iterations')
    plt.grid()
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "return_plot.png"))  # Optional: save to file
    plt.show()
