import gymnasium as gym
import torch.optim as optim
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import random
from torch.distributions import Normal
from collections import deque
import matplotlib.pyplot as plt
"""
Global constants
"""
# Reproducibility seed shared by every random number generator used here.
SEED = 42
# Upper bound on environment steps per episode.
MAX_STEPS = 1000

# Sizes of the observation and action vectors.
STATE_DIM = 17
ACTION_DIM = 6
# Per-dimension action bounds as float32 tensors: +1 for high, -1 for low.
ACTION_HIGH = torch.FloatTensor([1.0] * ACTION_DIM)
ACTION_LOW = torch.FloatTensor([-1.0] * ACTION_DIM)

# Seed Python's, NumPy's and PyTorch's global generators with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

"""
Import Self-Defined Module
"""
from Algorithms import PPO

if __name__ == "__main__":
    # Training entry point. Rolls out a PPO agent on HalfCheetah-v5 one
    # episode at a time, calls agent.train() once at least agent.batch_size
    # transitions are buffered, saves an actor checkpoint on every logged
    # episode, then writes the return curves to JSON and plots them.
    import os  # only this entry point needs it, so it is imported locally

    # torch.save / open / savefig do not create missing parent directories,
    # so make sure both output folders exist before anything is written.
    for output_dir in ('./Models', './Pretrain'):
        os.makedirs(output_dir, exist_ok=True)

    env = gym.make('HalfCheetah-v5', max_episode_steps=MAX_STEPS)
    env.reset(seed=SEED)  # seeds the environment; each episode resets again below

    agent = PPO()
    torch.save(agent.actor.state_dict(), './Models/actor_initial.pth')

    num_episodes = 1000
    log_interval = 1  # print and checkpoint every `log_interval` episodes
    returns_queue = deque(maxlen=100)  # window for the running average
    returns_curve = []
    average_returns_curve = []

    def _new_memory():
        """Return an empty rollout buffer with one list per stored quantity."""
        return {'states': [], 'actions': [], 'rewards': [], 'log_probs': [], 'dones': [], 'values': []}

    def _estimate_value(observation):
        """Return the critic's value for one observation as a Python float.

        Only the scalar is kept, so the forward pass runs without autograd.
        """
        with torch.no_grad():
            batch = torch.tensor(observation, dtype=torch.float32).unsqueeze(0)
            return agent.critic(batch).item()

    def _plot_curve(values, ylabel, title, path):
        """Plot ``values`` against the episode index, save to ``path``, show it."""
        plt.figure(figsize=(10, 5))
        plt.plot(values)
        plt.xlabel('Episode')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid()
        plt.tight_layout()
        plt.savefig(path)  # Optional: save to file
        plt.show()

    try:
        memory = _new_memory()
        for episode in range(num_episodes):
            state, _ = env.reset()
            total_reward = 0
            done = False

            while not done:
                action, log_prob = agent.select_action(state)
                value = _estimate_value(state)

                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                memory['states'].append(state)
                memory['actions'].append(action)
                memory['rewards'].append(reward)
                memory['log_probs'].append(log_prob)
                memory['dones'].append(done)
                memory['values'].append(value)

                total_reward += reward
                state = next_state

            # `state` is the final observation of the episode. Bootstrapping
            # from it (rather than from a freshly reset observation) keeps the
            # value estimate tied to the rollout that was just collected. The
            # environment is reset only once per episode, at the top of the loop.
            # NOTE(review): whether compute_gae masks this value using the last
            # done flag is not visible from this file - confirm in Algorithms.
            next_value = _estimate_value(state)
            advantages, returns = agent.compute_gae(
                memory['rewards'], memory['dones'], memory['values'], next_value
            )
            memory['advantages'] = advantages
            memory['returns'] = returns

            # Transitions keep accumulating across episodes until a full batch
            # is available; GAE is recomputed over the whole buffer each time.
            if len(memory['states']) >= agent.batch_size:
                agent.train(memory)
                memory = _new_memory()

            # Store plain Python floats so json.dump below does not depend on
            # the NumPy scalar type returned by the environment or np.mean.
            episode_return = float(total_reward)
            returns_queue.append(episode_return)
            average_return = float(np.mean(returns_queue))
            returns_curve.append(episode_return)
            average_returns_curve.append(average_return)
            if episode % log_interval == 0:
                print(f"Episode {episode}, Return: {episode_return:.2f}, Avg Return: {average_return:.2f}")
                torch.save(agent.actor.state_dict(), './Pretrain/actor_' + str(episode) + '.pth')
    finally:
        # Release the simulator even if training is interrupted or fails.
        env.close()

    data = {
        'returns': returns_curve,
        'average_returns': average_returns_curve
    }
    with open('./Pretrain/data.json', 'w') as f:
        json.dump(data, f, indent=4)

    _plot_curve(
        average_returns_curve,
        'Average Return (last 100 episodes)',
        'Average Return Over Training Episodes',
        "./Pretrain/average_return_plot.png",
    )
    _plot_curve(
        returns_curve,
        'Return',
        'Return Over Training Episodes',
        "./Pretrain/return_plot.png",
    )
