import gymnasium as gym
import torch.optim as optim
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import random
from torch.distributions import Normal
from collections import deque
import matplotlib.pyplot as plt

"""
Global constants
"""
# Single seed shared by every random number generator used in this script.
SEED = 31
# Upper bound on environment steps per episode (passed to gym.make below).
MAX_STEPS = 1000

# Dimensions handed to Actor(STATE_DIM, ACTION_DIM).
STATE_DIM = 11
ACTION_DIM = 3
# Per-dimension action bounds, stored as float32 tensors of +1 and -1.
ACTION_HIGH = torch.ones(ACTION_DIM, dtype=torch.float32)
ACTION_LOW = -ACTION_HIGH

# Seed Python, NumPy and PyTorch so that runs are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

"""
Import Self-Defined Module
"""

from Networks import Actor
from Algorithms import ZSPO
from Evaluation import evaluate_actor
from Panels import Panel_Env_Reward

#%%
def _build_antithetic_actor(agent, perturb_vec):
    """Return a fresh Actor carrying the antithetic (-epsilon) parameters.

    The new actor starts from a copy of ``agent.actor_perturb`` and then has
    ``2 * direction * agent.perturbation_dist`` subtracted from its
    parameters, consuming one entry of ``perturb_vec`` per parameter.

    Parameters are visited in this order: first every parameter of each
    ``nn.Linear`` sub-module (in ``modules()`` order), then every parameter
    registered directly on the actor (no ``.`` in its name).
    NOTE(review): this order presumably mirrors the one used by
    ``agent.perturb_actor()`` -- confirm against the ZSPO implementation.

    Args:
        agent: object exposing ``actor_perturb`` and ``perturbation_dist``.
        perturb_vec: iterable of perturbation tensors, one per parameter.

    Returns:
        The newly built ``Actor``.

    Raises:
        IndexError: if ``perturb_vec`` has fewer entries than the number of
            parameters visited.
    """
    actor_minus = Actor(STATE_DIM, ACTION_DIM)
    actor_minus.load_state_dict(agent.actor_perturb.state_dict())

    # The perturbation tensors are only read, never modified, so no defensive
    # clone is needed; a deque gives O(1) consumption from the front.
    pending = deque(perturb_vec)
    with torch.no_grad():
        for module in actor_minus.modules():
            if isinstance(module, nn.Linear):
                for param in module.parameters():
                    param.sub_(2 * pending.popleft() * agent.perturbation_dist)
        for name, param in actor_minus.named_parameters():
            # A name without a dot is a parameter owned by the actor itself
            # rather than by one of its sub-modules.
            if '.' not in name:
                param.sub_(2 * pending.popleft() * agent.perturbation_dist)
    return actor_minus


if __name__ == "__main__":
    import os

    max_iterations = 200
    max_episodes_per_iteration = 5
    output_dir = './ZSPO'

    # Create the output directory up front so that the first checkpoint save
    # does not fail after a full round of rollouts has already been spent.
    os.makedirs(output_dir, exist_ok=True)

    env = gym.make('Hopper-v5', max_episode_steps=MAX_STEPS)
    try:
        env.reset(seed=SEED)

        agent = ZSPO(expertise=0.01)
        agent.load_model("./Models/actor_pretrain.pth")

        panel = Panel_Env_Reward(link_function='L', expertise = 0.01, size = 100)

        returns_curve = []
        stds_curve = []
        # Sliding window over the most recent unperturbed-policy returns.
        returns_queue = deque(maxlen=10)
        for policy_iter in range(max_iterations):
            perturb_vec = agent.perturb_actor()

            # Baseline: returns of the current (unperturbed) actor.
            returns_0 = evaluate_actor(agent.actor, env, num_of_episodes=max_episodes_per_iteration, deterministic=1)
            returns_queue.extend(returns_0)

            # +epsilon candidate: a detached copy of the agent's perturbed actor.
            actor_plus = Actor(STATE_DIM, ACTION_DIM)
            actor_plus.load_state_dict(agent.actor_perturb.state_dict())
            returns_plus = evaluate_actor(actor_plus, env, num_of_episodes=max_episodes_per_iteration, deterministic=1)

            # -epsilon candidate: the antithetic counterpart of actor_plus.
            actor_minus = _build_antithetic_actor(agent, perturb_vec)
            returns_minus = evaluate_actor(actor_minus, env, num_of_episodes=max_episodes_per_iteration, deterministic=1)

            # Query the panel for the three pairwise comparisons.
            # NOTE(review): each probability is presumably that of the second
            # argument being preferred over the first -- confirm in Panels.
            results, prob_truth = panel.batch_preference_from_reward(returns_minus, returns_plus)
            results_plus, prob_plus = panel.batch_preference_from_reward(returns_0, returns_plus)
            results_minus, prob_minus = panel.batch_preference_from_reward(returns_0, returns_minus)

            mean_0 = np.mean(returns_0)
            std_0 = np.std(returns_0)
            diff_plus = np.mean(returns_plus) - mean_0
            diff_minus = np.mean(returns_minus) - mean_0

            # Train only when the plus-vs-minus comparison and the matching
            # comparison against the baseline both exceed 0.5.
            if prob_truth > 0.5 and prob_plus > 0.5:
                agent.train(prob_plus)
            elif prob_truth < 0.5 and prob_minus > 0.5:
                agent.train(1 - prob_minus)

            print('Policy Iteration:', policy_iter,
                  ',Returns:', round(mean_0, 4), '+-', round(std_0, 4),
                  ',Prob:', round(np.mean(results_plus),4), round(np.mean(results_minus), 4),
                  ',Cumulative Returns:', round(np.mean(returns_queue), 4),
                  ",Reward diff (mean1 - mean0):", round(diff_plus,4), round(diff_minus,4))
            torch.save(agent.actor.state_dict(), output_dir + '/actor_' + str(policy_iter) + '.pth')

            returns_curve.append(float(mean_0))
            stds_curve.append(float(std_0))

        # Persist the learning curve (one entry per policy iteration).
        data = {
            'returns': returns_curve,
            'stds': stds_curve
        }
        with open(output_dir + '/data.json', 'w') as f:
            json.dump(data, f, indent=4)

        # NOTE(review): the x-axis holds one point per policy iteration even
        # though the label and title say "Episode(s)".
        plt.figure(figsize=(10, 5))
        plt.plot(returns_curve)
        plt.xlabel('Episode')
        plt.ylabel('Return')
        plt.title('Return Over Training Episodes')
        plt.grid()
        plt.tight_layout()
        plt.savefig(output_dir + "/return_plot.png")  # Optional: save to file
        plt.show()
    finally:
        # Release the environment even if training or plotting raised.
        env.close()