import gymnasium as gym
import torch.optim as optim
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import random
from torch.distributions import Normal, Categorical
from collections import deque
import matplotlib.pyplot as plt
"""
Global constants
"""
# Seed shared by Python's `random`, NumPy, PyTorch and the first environment reset.
SEED = 42
# Passed as `max_episode_steps` when the environment is created and reused as
# the PPO rollout batch size (`PPOAgent.batch_size`).
MAX_STEPS = 1000

# Sizes handed to the Actor/Critic constructors. Presumably the observation
# size and the number of discrete actions of CartPole-v1 -- confirm against the
# environment if it is ever changed.
STATE_DIM = 4
ACTION_DIM = 2

# Seed every random number generator this file draws from.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

"""
Import Self-Defined Module
"""
from Networks import Actor, Critic

#%%
"""
PPO Agent
"""
class PPOAgent:
    """Proximal Policy Optimization agent with a categorical policy.

    The agent owns an actor network (its output is passed to ``Categorical``
    as action probabilities), a critic network, one Adam optimizer per
    network, and the PPO hyper-parameters used by :meth:`compute_gae` and
    :meth:`train`.
    """

    def __init__(self):
        self.actor = Actor(STATE_DIM, ACTION_DIM)   # Actor (policy) network
        self.critic = Critic(STATE_DIM)             # Critic (state-value) network

        self.lr = 3e-4                # Adam learning rate, used by both optimizers
        self.gamma = 0.99             # Discount factor
        self.gae_lambda = 0.95        # GAE trace-decay parameter
        self.clip_eps = 0.2           # Clipping range of the probability ratio
        self.max_grad_norm = 0.5      # Gradient-norm clipping threshold
        self.entropy_coef = 0.0       # Weight of the entropy bonus in the actor loss
        self.batch_size = MAX_STEPS   # Stored transitions required before an update
        self.mini_batch_size = 64     # Samples per gradient step
        self.ppo_epochs = 10          # Passes over the buffer per update

        self.actor_optim = optim.Adam(self.actor.parameters(), lr=self.lr)
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=self.lr)

    def select_action(self, state):
        """Sample an action for a single state.

        Args:
            state: One observation (array-like) accepted by the actor.

        Returns:
            Tuple ``(action, log_prob)`` as a Python int and float: the
            sampled action index and its log-probability under the current
            policy.
        """
        state = torch.as_tensor(state, dtype=torch.float32)
        with torch.no_grad():
            probs = self.actor(state)
        dist = Categorical(probs)
        action = dist.sample()
        # The log-probability of one categorical sample is already a scalar,
        # so no reduction over action dimensions is needed.
        return action.item(), dist.log_prob(action).item()

    def evaluate(self, states, actions):
        """Evaluate stored actions under the current actor and critic.

        Args:
            states: Batch of states as a float tensor.
            actions: Batch of action indices as an integer tensor.

        Returns:
            Tuple ``(log_probs, entropy, values)``: log-probabilities of
            ``actions``, the policy entropy for each state, and the raw critic
            output for ``states`` (shape as produced by the critic).
        """
        dist = Categorical(self.actor(states))
        log_probs = dist.log_prob(actions)
        entropy = dist.entropy()
        values = self.critic(states)
        return log_probs, entropy, values

    def compute_gae(self, rewards, dones, values, next_value):
        """Compute Generalized Advantage Estimation over a rollout.

        Args:
            rewards: Per-step rewards.
            dones: Per-step flags; a truthy value stops both bootstrapping
                and the advantage trace at that step.
            values: Per-step value estimates, same length as ``rewards``.
            next_value: Value estimate of the state following the last step.

        Returns:
            Tuple ``(advantages, returns)`` of two lists with one entry per
            step, where ``returns[t] = advantages[t] + values[t]``. The input
            sequences are not modified.
        """
        # Work on a copy so the caller's sequence is untouched and any
        # sequence type (not only ``list``) is accepted.
        padded_values = list(values) + [next_value]
        advantages = []
        gae = 0.0
        for step in reversed(range(len(rewards))):
            not_done = 1 - dones[step]
            delta = rewards[step] + self.gamma * padded_values[step + 1] * not_done - padded_values[step]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages.append(gae)
        # Built back-to-front with O(1) appends; restore chronological order.
        advantages.reverse()
        returns = [adv + val for adv, val in zip(advantages, padded_values[:-1])]
        return advantages, returns

    def train(self, memory):
        """Run the PPO update on a rollout buffer.

        Args:
            memory: Dict with equally long sequences under the keys
                ``'states'``, ``'actions'``, ``'log_probs'``, ``'returns'``
                and ``'advantages'``.

        The actor is updated with the clipped surrogate objective (plus the
        entropy bonus) and the critic with an MSE loss against ``returns``,
        for ``ppo_epochs`` shuffled passes in mini-batches of
        ``mini_batch_size``. An empty buffer is a no-op.
        """
        batch_size = len(memory['states'])
        if batch_size == 0:
            return

        states = torch.tensor(np.array(memory['states']), dtype=torch.float32)
        actions = torch.tensor(np.array(memory['actions']), dtype=torch.int64)
        old_log_probs = torch.tensor(np.array(memory['log_probs']), dtype=torch.float32)
        returns = torch.tensor(np.array(memory['returns']), dtype=torch.float32)
        advantages = torch.tensor(np.array(memory['advantages']), dtype=torch.float32)
        # The sample standard deviation of a single element is NaN, which
        # would poison every parameter; only normalise real batches.
        if batch_size > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(self.ppo_epochs):
            idxs = np.arange(batch_size)
            np.random.shuffle(idxs)
            for start in range(0, batch_size, self.mini_batch_size):
                mb_idx = idxs[start:start + self.mini_batch_size]

                mb_states = states[mb_idx]
                mb_actions = actions[mb_idx]
                mb_old_log_probs = old_log_probs[mb_idx]
                mb_returns = returns[mb_idx]
                mb_advantages = advantages[mb_idx]

                log_probs, entropy, values = self.evaluate(mb_states, mb_actions)
                # A critic emitting (B, 1) would broadcast against the (B,)
                # targets into a (B, B) loss matrix; align the shapes first.
                values = values.reshape(mb_returns.shape)

                ratios = torch.exp(log_probs - mb_old_log_probs)
                surr1 = ratios * mb_advantages
                surr2 = torch.clamp(ratios, 1 - self.clip_eps, 1 + self.clip_eps) * mb_advantages
                actor_loss = -torch.min(surr1, surr2).mean() - self.entropy_coef * entropy.mean()

                critic_loss = F.mse_loss(values, mb_returns)

                self.actor_optim.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                self.actor_optim.step()

                self.critic_optim.zero_grad()
                critic_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
                self.critic_optim.step()

if __name__ == "__main__":
    import os

    # The checkpoints, JSON log and plots below are written into these
    # directories; create them up front so a fresh checkout does not crash.
    os.makedirs('./Models', exist_ok=True)
    os.makedirs('./Pretrain', exist_ok=True)

    # Training Loop
    env = gym.make('CartPole-v1', max_episode_steps=MAX_STEPS)
    env.reset(seed=SEED)

    agent = PPOAgent()
    torch.save(agent.actor.state_dict(), './Models/actor_initial.pth')

    num_episodes = 1000
    log_interval = 1      # Print progress and checkpoint the actor every N episodes
    returns_queue = deque(maxlen=100)
    returns_curve = []
    average_returns_curve = []

    memory_keys = ('states', 'actions', 'rewards', 'log_probs', 'dones', 'values')
    memory = {key: [] for key in memory_keys}
    for episode in range(num_episodes):
        state, _ = env.reset()
        total_reward = 0
        done = False

        while not done:
            action, log_prob = agent.select_action(state)
            # Pure inference: no autograd graph is needed for stored values.
            with torch.no_grad():
                value = agent.critic(torch.tensor(state, dtype=torch.float32).unsqueeze(0)).item()

            next_state, reward, terminated, truncated, _ = env.step(action)
            done = bool(terminated or truncated)
            total_reward += reward

            # A time-limit truncation is not a real terminal state: fold the
            # discounted value of the state we were cut off in into the stored
            # reward, so marking the step as done does not drop the bootstrap.
            if truncated and not terminated:
                with torch.no_grad():
                    cutoff_value = agent.critic(
                        torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)
                    ).item()
                reward = reward + agent.gamma * cutoff_value

            memory['states'].append(state)
            memory['actions'].append(action)
            memory['rewards'].append(reward)
            memory['log_probs'].append(log_prob)
            memory['dones'].append(done)
            memory['values'].append(value)

            state = next_state

        if len(memory['states']) >= agent.batch_size:
            # The buffer always ends on an episode boundary (last done flag is
            # set), so the bootstrap value is masked out inside compute_gae.
            advantages, returns = agent.compute_gae(
                memory['rewards'], memory['dones'], memory['values'], 0.0
            )
            memory['advantages'] = advantages
            memory['returns'] = returns
            agent.train(memory)
            memory = {key: [] for key in memory_keys}

        returns_queue.append(total_reward)
        returns_curve.append(total_reward)
        average_returns_curve.append(np.mean(returns_queue))
        if episode % log_interval == 0:
            print(f"Episode {episode}, Return: {total_reward:.2f}, Avg Return: {np.mean(returns_queue):.2f}")
            torch.save(agent.actor.state_dict(), f'./Pretrain/actor_{episode}.pth')

    data = {
        'returns': returns_curve,
        'average_returns': average_returns_curve
    }
    with open('./Pretrain/data.json', 'w') as f:
        json.dump(data, f, indent=4)

    # One figure per curve: (series, y-axis label, title, output file).
    plots = (
        (average_returns_curve, 'Average Return (last 100 episodes)',
         'Average Return Over Training Episodes', "./Pretrain/average_return_plot.png"),
        (returns_curve, 'Return',
         'Return Over Training Episodes', "./Pretrain/return_plot.png"),
    )
    for curve, ylabel, title, path in plots:
        plt.figure(figsize=(10, 5))
        plt.plot(curve)
        plt.xlabel('Episode')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid()
        plt.tight_layout()
        plt.savefig(path)  # Optional: save to file
        plt.show()

    env.close()
