import json
import os
import random
from collections import deque

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch

# ---------------------------------------------------------------------------
# Global constants
# ---------------------------------------------------------------------------
# Full pool of seeds available for a complete sweep.
ALL_SEEDS = [33, 81, 34, 44, 42, 41, 31, 173, 139, 83]
# Seeds actually run. Previously the full list was assigned and then
# immediately overwritten by [139, 83] (a dead assignment); the subset is now
# derived explicitly. Set `SEEDs = ALL_SEEDS` to run the full sweep.
SEEDs = ALL_SEEDS[-2:]
# Episode length cap passed to gym.make(..., max_episode_steps=...).
MAX_STEPS = 500

# Sizes of CartPole-v1's observation vector and discrete action set
# (not referenced in this script body; presumably used by importers).
STATE_DIM = 4
ACTION_DIM = 2

"""
Import Self-Defined Module
"""

from Algorithms import ZSPO
from Evaluation import evaluate_actor
from Panels import Panel_Env_Reward

#%%
# Directory receiving trained actors, learning-curve JSON and plots.
OUTPUT_DIR = './ZSPO'
# Pretrained actor weights every run starts from.
PRETRAINED_ACTOR = './Models/actor_pretrain.pth'


def _seed_everything(env, seed):
    """Seed the environment and the Python / NumPy / PyTorch RNGs with *seed*."""
    env.reset(seed=seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _train_one_seed(seed, episodes_per_iteration, max_iterations):
    """Fine-tune the pretrained actor with ZSPO for a single seed.

    Each policy iteration perturbs the actor, evaluates the current and the
    perturbed actor for `episodes_per_iteration` episodes each, turns the two
    sets of returns into preference feedback via the panel, and trains the
    agent on the mean preference. The final actor weights are saved to
    OUTPUT_DIR.

    Args:
        seed: Seed for the environment and all RNGs.
        episodes_per_iteration: Episodes per evaluation call (passed to
            `evaluate_actor` as `num_of_episodes`).
        max_iterations: Number of policy iterations.

    Returns:
        List of length `max_iterations` holding the mean return of the
        unperturbed actor at each iteration.
    """
    env = gym.make('CartPole-v1', max_episode_steps=MAX_STEPS)
    try:
        # Seed before constructing agent/panel so their initialisation is
        # reproducible (order kept identical to the original script).
        _seed_everything(env, seed)

        agent = ZSPO()
        agent.load_model(PRETRAINED_ACTOR)

        # Alternative feedback model: Panel_Env_Reward(link_function='BT', expertise=0.1)
        panel = Panel_Env_Reward(link_function='L', expertise=0.001)

        returns_curve = []
        # Sliding window over the 10 most recent unperturbed-episode returns
        # (printed as "Cumulative Returns" for log compatibility).
        recent_returns = deque(maxlen=10)
        for policy_iter in range(max_iterations):
            # Populates agent.actor_perturb; the returned vector is not needed here.
            agent.perturb_actor()

            returns_0 = evaluate_actor(agent.actor, env, num_of_episodes=episodes_per_iteration, deterministic=1)
            recent_returns.extend(returns_0)

            # Evaluate +epsilon perturbation
            returns_plus = evaluate_actor(agent.actor_perturb, env, num_of_episodes=episodes_per_iteration, deterministic=1)

            results_plus, prob_plus = panel.individual_preference_from_reward(returns_0, returns_plus)
            agent.train(np.mean(results_plus))

            # Standard error uses the number of returns actually produced
            # (assumed to equal episodes_per_iteration — confirm in evaluate_actor).
            print('Seed:', seed,
                  'Policy Iteration:', policy_iter,
                  ',Returns:', round(np.mean(returns_0), 4), '+-', round(2 * np.std(returns_0) / np.sqrt(len(returns_0)), 4),
                  ',Prob:', round(np.mean(prob_plus), 4),
                  ',Cumulative Returns:', round(np.mean(recent_returns), 4),
                  ",Reward diff (mean1 - mean0):", round(np.mean(returns_plus) - np.mean(returns_0), 4))

            returns_curve.append(float(np.mean(returns_0)))

        torch.save(agent.actor.state_dict(), f'{OUTPUT_DIR}/actor_{episodes_per_iteration}_{seed}.pth')
        return returns_curve
    finally:
        # Release the environment even if training raises.
        env.close()


def _save_results(returns_curves, episodes_per_iteration):
    """Write per-seed learning curves to JSON and plot their seed-average.

    Args:
        returns_curves: One equal-length list of per-iteration returns per seed.
        episodes_per_iteration: Used only to name the output files.
    """
    with open(f'{OUTPUT_DIR}/data_{episodes_per_iteration}.json', 'w') as f:
        json.dump({'returns': returns_curves}, f, indent=4)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(np.mean(returns_curves, axis=0))
    # The x-axis indexes policy iterations, not individual episodes.
    ax.set_xlabel('Policy iteration')
    ax.set_ylabel('Return')
    ax.set_title('Mean Return Over Policy Iterations (averaged over seeds)')
    ax.grid()
    fig.tight_layout()
    fig.savefig(f'{OUTPUT_DIR}/return_plot_{episodes_per_iteration}.png')
    # Free the figure; otherwise one figure accumulates per sweep value.
    plt.close(fig)


if __name__ == "__main__":
    max_iterations = 200
    # Outputs are written here; create it so a fresh checkout does not crash.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Sweep over the number of evaluation episodes per policy iteration.
    for max_episodes_per_iteration in [1, 4, 16, 64]:
        returns_curves = [
            _train_one_seed(SEED, max_episodes_per_iteration, max_iterations)
            for SEED in SEEDs
        ]
        _save_results(returns_curves, max_episodes_per_iteration)

