import json
import os
import random
from collections import deque

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch

"""
Global constants
"""
SEEDs = [33, 81, 34, 44, 42, 41, 31, 173, 139, 83]
MAX_STEPS = 1000

STATE_DIM = 17
ACTION_DIM = 6
ACTION_HIGH = torch.FloatTensor(np.ones(ACTION_DIM))
ACTION_LOW = - torch.FloatTensor(np.ones(ACTION_DIM))

"""
Import Self-Defined Module
"""
from Networks import Reward
from Evaluation import evaluate_actor
from Algorithms import PPO
from Panels import Panel_Env_Reward

#%%
if __name__ == "__main__":
    max_iterations = 500
    max_episodes_per_iteration = 1
    returns_curves = []

    for SEED in SEEDs:
        env = gym.make('HalfCheetah-v5', max_episode_steps=MAX_STEPS)
        env.reset(seed=SEED)
        random.seed(SEED)
        np.random.seed(SEED)
        torch.manual_seed(SEED)

        agent = PPO()
        agent.load_model("./Models/actor_initial.pth")

        reward_net = Reward(state_dim=STATE_DIM, action_dim=ACTION_DIM)
        reward_net.load_state_dict(torch.load('./Models/reward_net.pth', weights_only=True))

        returns_curve = []
        returns_queue = deque(maxlen=100)
        policy_iter = 0
        memory = {
            'states': [], 'actions': [], 'rewards': [], 'log_probs': [],
            'dones': [], 'truncateds': [], 'values': []
        }
        for episode in range(max_iterations * max_episodes_per_iteration):
            state, _ = env.reset()
            total_reward = 0

            while 1:
                action, log_prob = agent.select_action(state)
                value = agent.critic(torch.tensor(state, dtype=torch.float32).unsqueeze(0)).item()

                next_state, reward, terminated, truncated, _ = env.step(action)
                reward_prox = reward_net(torch.FloatTensor(state).unsqueeze(0), torch.FloatTensor(action).unsqueeze(0)).item()
                done = terminated or truncated

                memory['states'].append(state)
                memory['actions'].append(action)
                memory['rewards'].append(reward_prox)
                memory['log_probs'].append(log_prob)
                memory['dones'].append(terminated)
                memory['truncateds'].append(truncated)
                memory['values'].append(value)

                total_reward += reward
                state = next_state

                if done:
                    last_state = next_state
                    last_truncated = truncated
                    break

            if last_truncated:
                next_value = agent.critic(torch.tensor(last_state, dtype=torch.float32).unsqueeze(0)).item()
            else:
                next_value = 0.0
            advantages, returns = agent.compute_gae(memory['rewards'], memory['dones'], memory['values'], next_value)
            memory['advantages'] = advantages
            memory['returns'] = returns

            returns_queue.append(total_reward)

            if (episode + 1) % max_episodes_per_iteration == 0:
                returns = evaluate_actor(agent.actor, env, num_of_episodes=max_episodes_per_iteration, deterministic=1)
                returns_curve.append(float(np.mean(returns)))
                agent.train(memory)
                memory = {
                    'states': [], 'actions': [], 'rewards': [], 'log_probs': [],
                    'dones': [], 'truncateds': [], 'values': []
                }
                print('SEED:', SEED,
                      'Policy Iteration:', policy_iter,
                      ',Returns:', round(np.mean(returns), 4), '+-', round(np.std(returns), 4)
                      )
                # torch.save(agent.actor.state_dict(), './RMPPO/actor_' + str(policy_iter) + '.pth')
                policy_iter += 1
        returns_curves.append(returns_curve)
        env.close()

    data = {
        'returns': returns_curves
    }
    with open('./RMPPO/data.json', 'w') as f:
        json.dump(data, f, indent=4)

    plt.figure(figsize=(10, 5))
    plt.plot(np.mean(returns_curves, axis=0))
    plt.xlabel('Episode')
    plt.ylabel('Return')
    plt.title('Return Over Training Episodes')
    plt.grid()
    plt.tight_layout()
    plt.savefig("./RMPPO/return_plot.png")  # Optional: save to file
    plt.show()
