import gymnasium as gym
import torch.optim as optim
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import random
from torch.distributions import Normal
from collections import deque
import matplotlib.pyplot as plt
# Global constants
SEED = 42
MAX_STEPS = 500  # passed as max_episode_steps when the env is created

# Observation / action sizes (presumably those of HalfCheetah-v5 — confirm
# against env.observation_space / env.action_space if the env changes).
STATE_DIM = 17
ACTION_DIM = 6
# Per-dimension action bounds: +1 / -1 on every dimension.
ACTION_HIGH = torch.ones(ACTION_DIM, dtype=torch.float32)
ACTION_LOW = -torch.ones(ACTION_DIM, dtype=torch.float32)

# Seed every RNG the script draws from (Python, NumPy, PyTorch).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

"""
Import Self-Defined Module
"""
from Networks import Actor, Critic
class PPOAgent:
    """PPO agent with a tanh-squashed Gaussian policy and a separate critic.

    The actor maps a batch of states to ``(mean, std)`` of a diagonal Gaussian
    over pre-squash actions; sampled actions are squashed with ``tanh``. The
    critic maps a batch of states to values whose last dimension is squeezed.
    """

    # Bound applied before atanh so saturated actions (exactly +/-1.0 in
    # float32) map to a finite pre-squash value.
    _ATANH_LIMIT = 0.999999
    # Added inside log(1 - a^2) so the tanh Jacobian term stays finite.
    _JACOBIAN_EPS = 1e-6

    def __init__(self):
        self.actor = Actor(STATE_DIM, ACTION_DIM)   # state -> (mean, std)
        self.critic = Critic(STATE_DIM)             # state -> value estimate

        self.lr = 3e-4               # Adam learning rate for both networks
        self.gamma = 0.99            # discount factor
        self.gae_lambda = 0.95       # GAE lambda
        self.clip_eps = 0.2          # PPO ratio clipping range
        self.max_grad_norm = 0.5     # gradient-norm clip, applied per network
        self.entropy_coef = 0.01     # weight of the entropy bonus in the actor loss
        self.ppo_epochs = 5          # full-batch update passes per rollout

        self.actor_optim = optim.Adam(self.actor.parameters(), lr=self.lr)
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=self.lr)

    def load_model(self, pth):
        """Load actor weights from the state dict stored at ``pth``."""
        self.actor.load_state_dict(torch.load(pth, weights_only=True))

    def _squashed_log_prob(self, dist, action):
        """Return log pi(action) summed over the last (action) dimension.

        The pre-squash value is recovered as atanh of the clamped action, and
        both ``select_action`` and ``evaluate`` go through this helper. That way
        the stored (old) and recomputed (new) log-probs agree for the same
        action and parameters, even when tanh saturated to +/-1.
        """
        x = torch.clamp(action, -self._ATANH_LIMIT, self._ATANH_LIMIT)
        pre_tanh = torch.atanh(x)
        # Change of variables: log|d tanh(u)/du| = log(1 - tanh(u)^2)
        log_det = torch.log1p(-x.pow(2) + self._JACOBIAN_EPS).sum(-1)
        return dist.log_prob(pre_tanh).sum(-1) - log_det

    def select_action(self, state):
        """Sample a squashed action for a single state.

        Returns:
            (action, log_prob): the action as a numpy array (batch dim
            removed) and its log-probability as a Python float.
        """
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            mean, std = self.actor(state)
            dist = Normal(mean, std)
            # Rollout sampling needs no gradients; everything here is detached.
            action = torch.tanh(dist.rsample())
            logp = self._squashed_log_prob(dist, action)
        return action.squeeze(0).numpy(), logp.item()

    def evaluate(self, states, actions):
        """Score stored actions under the current actor and critic.

        Returns:
            (log_probs, entropy, values), one entry per row of ``states``.
            ``entropy`` is that of the pre-squash Gaussian; no tanh
            correction is applied to it.
        """
        mean, std = self.actor(states)
        dist = Normal(mean, std)
        logp = self._squashed_log_prob(dist, actions)
        entropy = dist.entropy().sum(dim=-1)
        values = self.critic(states).squeeze(-1)
        return logp, entropy, values

    def compute_gae(self, rewards, dones, values, next_value):
        """Compute GAE(lambda) advantages and value targets for one trajectory.

        Args:
            rewards: per-step rewards.
            dones: per-step flags; a true flag stops both bootstrapping and
                advantage propagation across that step.
            values: critic estimates for each step, same length as
                ``rewards``. The list is not modified.
            next_value: value used to bootstrap after the last step.

        Returns:
            (advantages, returns) lists with returns[t] = advantages[t] + values[t].
        """
        gamma, lam = self.gamma, self.gae_lambda
        advantages = []
        gae = 0
        next_v = next_value
        # Walk backwards so each step can reuse the advantage of the step after it.
        for reward, done, value in zip(reversed(rewards), reversed(dones), reversed(values)):
            not_done = 1.0 - float(done)
            delta = reward + gamma * next_v * not_done - value
            gae = delta + gamma * lam * not_done * gae
            advantages.append(gae)
            next_v = value
        advantages.reverse()  # O(n) instead of repeated insert(0, ...)
        returns = [adv + val for adv, val in zip(advantages, values)]
        return advantages, returns

    def train(self, memory):
        """Run ``ppo_epochs`` full-batch PPO updates on one rollout.

        Args:
            memory: dict of equal-length lists under 'states', 'actions',
                'log_probs', 'returns' and 'advantages'.

        Returns:
            (actor_loss, critic_loss) as floats from the last completed epoch,
            or (nan, nan) if no epoch completed (NaN detected on the first
            epoch, or ``ppo_epochs == 0``).
        """
        states = torch.tensor(np.array(memory['states']), dtype=torch.float32)
        actions = torch.tensor(np.array(memory['actions']), dtype=torch.float32)
        old_log_probs = torch.tensor(np.array(memory['log_probs']), dtype=torch.float32)
        returns = torch.tensor(np.array(memory['returns']), dtype=torch.float32)
        advantages = torch.tensor(np.array(memory['advantages']), dtype=torch.float32)
        if advantages.numel() > 1:
            # std of a single sample is NaN and would poison every gradient.
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        actor_loss_value = critic_loss_value = float('nan')
        for epoch in range(self.ppo_epochs):
            log_probs, entropy, values = self.evaluate(states, actions)

            # Check for NaN values
            if torch.isnan(log_probs).any() or torch.isnan(values).any():
                print(f"Warning: NaN detected at epoch {epoch}")
                print(f"NaN in log_probs: {torch.isnan(log_probs).sum().item()}")
                print(f"NaN in values: {torch.isnan(values).sum().item()}")
                break

            ratios = torch.exp(log_probs - old_log_probs)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
            actor_loss = -torch.min(surr1, surr2).mean() - self.entropy_coef * entropy.mean()
            critic_loss = F.mse_loss(values, returns)
            actor_loss_value = actor_loss.item()
            critic_loss_value = critic_loss.item()

            self.actor_optim.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
            self.actor_optim.step()

            self.critic_optim.zero_grad()
            critic_loss.backward()
            nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
            self.critic_optim.step()

        return actor_loss_value, critic_loss_value


if __name__ == "__main__":
    import os

    def _empty_memory():
        """Return a fresh per-episode rollout buffer (one list per recorded field)."""
        return {
            'states': [], 'actions': [], 'rewards': [], 'log_probs': [],
            'dones': [], 'truncateds': [], 'values': [],
        }

    def _state_value(agent, obs):
        """Return the critic's scalar value for one observation.

        Runs under no_grad: the value is only kept as a float, so building an
        autograd graph for it on every environment step is wasted work.
        """
        with torch.no_grad():
            obs_t = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
            return agent.critic(obs_t).item()

    def _plot_curve(curve, ylabel, title, path):
        """Plot a per-episode curve, save it to ``path``, then display it."""
        plt.figure(figsize=(10, 5))
        plt.plot(curve)
        plt.xlabel('Episode')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid()
        plt.tight_layout()
        plt.savefig(path)
        plt.show()

    # torch.save / open / savefig do not create missing directories.
    os.makedirs('./Models', exist_ok=True)
    os.makedirs('./Pretrain', exist_ok=True)

    # Training Loop
    env = gym.make('HalfCheetah-v5', max_episode_steps=MAX_STEPS)
    env.reset(seed=SEED)

    agent = PPOAgent()
    torch.save(agent.actor.state_dict(), './Models/actor_initial.pth')

    num_episodes = 5000
    log_interval = 1  # print progress and save a checkpoint every N episodes
    returns_queue = deque(maxlen=100)
    returns_curve = []
    average_returns_curve = []

    try:
        for episode in range(num_episodes):
            memory = _empty_memory()
            state, _ = env.reset()
            total_reward = 0

            while True:
                action, log_prob = agent.select_action(state)
                value = _state_value(agent, state)

                next_state, reward, terminated, truncated, _ = env.step(action)

                memory['states'].append(state)
                memory['actions'].append(action)
                memory['rewards'].append(reward)
                memory['log_probs'].append(log_prob)
                # Only true terminations are stored as 'dones' for GAE masking.
                memory['dones'].append(terminated)
                memory['truncateds'].append(truncated)
                memory['values'].append(value)

                total_reward += reward
                state = next_state

                if terminated or truncated:
                    break

            # On truncation, bootstrap from the value of the final observed
            # state; otherwise use 0.
            next_value = _state_value(agent, state) if truncated else 0.0
            advantages, returns = agent.compute_gae(
                memory['rewards'], memory['dones'], memory['values'], next_value
            )
            memory['advantages'] = advantages
            memory['returns'] = returns

            agent.train(memory)

            returns_queue.append(total_reward)
            returns_curve.append(total_reward)
            avg_return = np.mean(returns_queue)
            average_returns_curve.append(avg_return)
            if episode % log_interval == 0:
                print(f"Episode {episode}, Return: {total_reward:.2f}, Avg Return: {avg_return:.2f}")
                torch.save(agent.actor.state_dict(), './Pretrain/actor_' + str(episode) + '.pth')
    finally:
        env.close()

    data = {
        'returns': returns_curve,
        'average_returns': average_returns_curve
    }
    with open('./Pretrain/data.json', 'w') as f:
        json.dump(data, f, indent=4)

    _plot_curve(average_returns_curve, 'Average Return (last 100 episodes)',
                'Average Return Over Training Episodes', "./Pretrain/average_return_plot.png")
    _plot_curve(returns_curve, 'Return',
                'Return Over Training Episodes', "./Pretrain/return_plot.png")
