import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import random
import matplotlib.pyplot as plt
import os
import json
from datetime import datetime
import base64
from PIL import Image
import io
from torch.distributions import Categorical
import time


# Define the PPO Policy Network
class PPOPolicyNetwork(nn.Module):
    """Actor-critic network with a shared two-layer ReLU trunk.

    The trunk feeds two linear heads: ``fc_policy`` produces one logit per
    action and ``fc_value`` produces a single state-value estimate.
    """

    def __init__(self, state_size, action_size):
        super().__init__()
        hidden_units = 512
        # Attribute names double as state-dict keys, so they are part of the
        # checkpoint format and must not be renamed.
        self.fc1 = nn.Linear(state_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc_policy = nn.Linear(hidden_units, action_size)
        self.fc_value = nn.Linear(hidden_units, 1)

    def forward(self, state):
        """Return ``(policy_logits, value)`` for the given state tensor."""
        hidden = state
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc_policy(hidden), self.fc_value(hidden)


# Define the PPO agent
class PPOAgent:
    """Proximal Policy Optimization agent for a discrete action space.

    Transitions are accumulated in ``trajectory_buffer`` through
    ``store_trajectory``; ``learn`` runs a clipped-surrogate PPO update once
    at least ``buffer_size`` transitions have been collected.
    """

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.policy = PPOPolicyNetwork(state_size, action_size)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=0.0003)

        # PPO hyperparameters
        self.gamma = 0.99            # discount factor
        self.gae_lambda = 0.95       # GAE smoothing factor
        self.clip_epsilon = 0.2      # clipping range for the probability ratio
        self.epochs = 10             # optimisation passes per update
        self.mini_batch_size = 64
        self.buffer_size = 2048      # minimum transitions before an update
        self.trajectory_buffer = []
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy.to(self.device)

    def act(self, state, training=True):
        """Choose an action for a single (unbatched) state.

        Args:
            state: Observation convertible by ``torch.FloatTensor``.
            training: When true the action is sampled from the policy
                distribution; otherwise the highest-logit action is taken.

        Returns:
            Tuple ``(action, log_prob, value)`` of Python scalars, where
            ``log_prob`` is the log-probability of the returned action.
        """
        state = torch.FloatTensor(state).to(self.device)
        with torch.no_grad():
            policy_logits, value = self.policy(state)

        dist = Categorical(logits=policy_logits)
        if training:
            action = dist.sample()
        else:
            action = policy_logits.argmax(dim=-1)
        # Score the action that is actually returned, so the log-prob is
        # consistent in greedy mode as well.
        log_prob = dist.log_prob(action)

        return action.item(), log_prob.item(), value.item()

    def store_trajectory(self, state, action, log_prob, value, reward, done):
        """Append one transition to the trajectory buffer."""
        self.trajectory_buffer.append({
            'state': state,
            'action': action,
            'log_prob': log_prob,
            'value': value,
            'reward': reward,
            'done': done
        })

    def compute_advantages(self, rewards, values, dones, next_value):
        """Compute Generalized Advantage Estimation (GAE) advantages and returns.

        Args:
            rewards: Per-step rewards.
            values: Per-step value estimates recorded at collection time.
            dones: Per-step episode-end flags; a true flag stops both the
                bootstrap and the advantage accumulation at that step.
            next_value: Value estimate used to bootstrap after the last step.

        Returns:
            Tuple ``(advantages, returns)`` of arrays with one entry per step,
            where ``returns = advantages + values``.
        """
        rewards = np.asarray(rewards, dtype=np.float64)
        values = np.asarray(values, dtype=np.float64)
        # Explicit float cast: the flags usually arrive as booleans.
        non_terminal = 1.0 - np.asarray(dones, dtype=np.float64)

        n_steps = len(rewards)
        advantages = np.zeros(n_steps, dtype=np.float32)
        last_advantage = 0.0

        for t in reversed(range(n_steps)):
            bootstrap = next_value if t == n_steps - 1 else values[t + 1]
            delta = rewards[t] + self.gamma * bootstrap * non_terminal[t] - values[t]
            # Use the configured lambda rather than a hard-coded constant.
            advantages[t] = delta + self.gamma * self.gae_lambda * non_terminal[t] * last_advantage
            last_advantage = advantages[t]

        returns = advantages + values
        return advantages, returns

    def learn(self):
        """Run a PPO update if enough transitions have been collected.

        Returns immediately while the buffer holds fewer than ``buffer_size``
        transitions. After an update, only the most recent 10% of
        ``buffer_size`` transitions are kept.
        """
        if len(self.trajectory_buffer) < self.buffer_size:
            return

        buffer = self.trajectory_buffer
        states = np.array([t['state'] for t in buffer])
        actions = np.array([t['action'] for t in buffer])
        old_log_probs = np.array([t['log_prob'] for t in buffer])
        old_values = np.array([t['value'] for t in buffer])
        rewards = np.array([t['reward'] for t in buffer])
        dones = np.array([t['done'] for t in buffer])

        # Bootstrap value for the step after the buffer. No gradient is
        # needed here, so avoid building an autograd graph.
        # NOTE(review): this evaluates the last *stored* state, not its
        # successor (which is not recorded); it only matters when the final
        # transition is not flagged done.
        with torch.no_grad():
            last_state = torch.FloatTensor(buffer[-1]['state']).to(self.device)
            _, next_value = self.policy(last_state)
        next_value = next_value.item()
        advantages, returns = self.compute_advantages(rewards, old_values, dones, next_value)

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Convert to tensors
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        old_log_probs = torch.FloatTensor(old_log_probs).to(self.device)
        returns = torch.FloatTensor(returns).to(self.device)
        advantages = torch.FloatTensor(advantages).to(self.device)

        # Shuffled mini-batches, re-drawn on every epoch
        dataset = torch.utils.data.TensorDataset(states, actions, old_log_probs, returns, advantages)
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.mini_batch_size, shuffle=True)

        for _ in range(self.epochs):
            for batch in loader:
                batch_states, batch_actions, batch_old_log_probs, batch_returns, batch_advantages = batch

                # Current policy outputs
                policy_logits, values = self.policy(batch_states)
                dist = Categorical(logits=policy_logits)
                new_log_probs = dist.log_prob(batch_actions)
                entropy = dist.entropy()

                # Probability ratio between the current and collecting policy
                ratio = (new_log_probs - batch_old_log_probs).exp()

                # Clipped surrogate objective
                policy_loss1 = ratio * batch_advantages
                policy_loss2 = torch.clamp(ratio, 1.0 - self.clip_epsilon,
                                           1.0 + self.clip_epsilon) * batch_advantages
                policy_loss = -torch.min(policy_loss1, policy_loss2).mean()

                # Value loss; squeeze only the trailing value dimension so a
                # mini-batch of size one keeps its batch axis.
                value_loss = 0.5 * (batch_returns - values.squeeze(-1)).pow(2).mean()

                # Entropy bonus (negated so that minimising encourages entropy)
                entropy_loss = -entropy.mean()

                loss = policy_loss + 0.5 * value_loss + 0.01 * entropy_loss

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
                self.optimizer.step()

        # Keep the newest 10% of the buffer. A zero keep size must clear the
        # buffer explicitly because ``buffer[-0:]`` would keep everything.
        # NOTE(review): retained samples carry log-probs and values from an
        # older policy; confirm that this reuse is intended.
        keep_size = int(0.1 * self.buffer_size)
        self.trajectory_buffer = self.trajectory_buffer[-keep_size:] if keep_size > 0 else []

    def save(self, filename):
        """Write the policy and optimizer state dicts to ``filename``."""
        torch.save({
            'policy_state_dict': self.policy.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, filename)

    def load(self, filename):
        """Restore the policy and optimizer state dicts from ``filename``."""
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
        # machine (and vice versa).
        checkpoint = torch.load(filename, map_location=self.device)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


def evaluate_agent(agent, env_seeds, max_steps=200, gravity=-10,
                   enable_wind=False):
    """Evaluate the agent greedily on several seeds and return summary metrics.

    A fresh ``LunarLander-v3`` environment is created for every seed and the
    agent is run with ``training=False`` for at most ``max_steps`` steps.
    Every tenth frame is blended into a single trail image per episode.

    Args:
        agent: Object exposing ``act(observation, training=False)`` that
            returns a tuple whose first element is the action.
        env_seeds: Non-empty iterable of reset seeds, one episode per seed.
        max_steps: Maximum number of steps per episode.
        gravity: Forwarded to ``gym.make``.
        enable_wind: Forwarded to ``gym.make``.

    Returns:
        Dict with ``mean_reward``, ``mean_fuel``, ``success_rate``, ``nws``,
        the base64 PNG and observations of the lowest-reward episode
        (``worst_case_image``, ``worst_case_observations``), and the
        per-episode ``episodes_recorder``.

    Raises:
        ValueError: If ``env_seeds`` is empty.
    """
    env_seeds = list(env_seeds)
    if not env_seeds:
        # Fail early with a clear message instead of dividing by zero below.
        raise ValueError("env_seeds must contain at least one seed")

    total_rewards = []
    total_fuel = 0
    success_count = 0
    episodes_recorder = {}
    image64s = []
    observations = []

    for i, seed in enumerate(env_seeds):
        env = gym.make("LunarLander-v3", render_mode='rgb_array', gravity=gravity,
                       enable_wind=enable_wind)
        try:
            observation, _ = env.reset(seed=seed)
            episode_reward = 0.0
            episode_fuel = 0
            episode_observations = []
            canvas = None
            # Defined up front so they exist even if the step loop never runs.
            terminated = truncated = False

            for step in range(max_steps):
                action, _, _ = agent.act(observation, training=False)
                observation, reward, terminated, truncated, _ = env.step(action)

                episode_reward += float(reward)
                if action in (1, 2, 3):  # Actions that use fuel
                    episode_fuel += 1

                episode_observations.append(observation.tolist())

                # Blend every tenth frame into the trail image; later frames
                # get a higher weight so the trajectory fades in over time.
                if step % 10 == 0:
                    img = env.render()
                    if canvas is None:
                        canvas = np.zeros_like(img, dtype=np.float32)
                    mask = np.any(img != [0, 0, 0], axis=-1)
                    alpha = step / max_steps
                    canvas[mask] = img[mask] * alpha + canvas[mask] * (1 - alpha)

                if terminated or truncated:
                    break

            # Draw the final frame at full opacity on top of the trail.
            img = env.render()
            if canvas is None:
                canvas = np.zeros_like(img, dtype=np.float32)
            mask = np.any(img != [0, 0, 0], axis=-1)
            canvas[mask] = img[mask]
        finally:
            # Release the environment even if the episode raised.
            env.close()

        # Convert canvas to a base64-encoded PNG
        img_pil = Image.fromarray(canvas.astype(np.uint8))
        buffered = io.BytesIO()
        img_pil.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')

        # Record results
        total_rewards.append(episode_reward)
        total_fuel += episode_fuel
        image64s.append(img_str)
        observations.append(episode_observations)

        if episode_reward >= 200:
            success_count += 1

        # Builtin types keep the record JSON-serialisable.
        episodes_recorder[f'{i}'] = {
            'seed': seed,
            'episode_reward': episode_reward,
            'episode_fuel': episode_fuel,
            'observations': episode_observations,
            'terminated': bool(terminated),
            'truncated': bool(truncated)
        }

    # Aggregate metrics
    n_episodes = len(env_seeds)
    mean_reward = float(np.mean(total_rewards))
    mean_fuel = total_fuel / n_episodes
    success_rate = success_count / n_episodes

    # Normalized Weighted Score (α=0.6, β=0.2, γ=0.2)
    nws = (mean_reward / 200) * 0.6 + (1 - min(mean_fuel / 100, 1)) * 0.2 + success_rate * 0.2

    # Episode with the lowest total reward
    worst_idx = int(np.argmin(total_rewards))

    return {
        'mean_reward': mean_reward,
        'mean_fuel': mean_fuel,
        'success_rate': success_rate,
        'nws': nws,
        'worst_case_image': image64s[worst_idx],
        'worst_case_observations': observations[worst_idx],
        'episodes_recorder': episodes_recorder
    }


def _plot_training_rewards(rewards_history, moving_avg_rewards, evaluation_history, save_dir):
    """Save the episode-reward curve with evaluation episodes marked."""
    plt.figure(figsize=(12, 6))
    plt.plot(rewards_history, alpha=0.6, label='Episode Reward')
    plt.plot(moving_avg_rewards, 'r-', linewidth=2, label='Moving Avg (100 episodes)')

    # Mark evaluation points on the moving-average curve (episodes are 1-based)
    eval_episodes = [e['episode'] for e in evaluation_history]
    plt.scatter(eval_episodes, [moving_avg_rewards[e - 1] for e in eval_episodes],
                c='green', marker='o', label='Evaluation Points')

    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('PPO Training Performance with Periodic Evaluation')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(save_dir, 'training_plot.png'))
    plt.close()


def _plot_evaluation_metrics(evaluation_history, save_dir):
    """Save a 2x2 grid with one panel per evaluation metric."""
    panels = (
        ('mean_reward', 'Mean Reward', 'Evaluation Mean Reward'),
        ('mean_fuel', 'Mean Fuel', 'Evaluation Mean Fuel'),
        ('success_rate', 'Success Rate', 'Evaluation Success Rate'),
        ('nws', 'NWS', 'Normalized Weighted Score'),
    )
    eval_episodes = [e['episode'] for e in evaluation_history]

    plt.figure(figsize=(12, 8))
    for position, (metric, ylabel, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(eval_episodes, [e['metrics'][metric] for e in evaluation_history], 'o-')
        plt.xlabel('Episode')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'evaluation_metrics.png'))
    plt.close()


def train_ppo(env, agent, episodes=1000, max_steps=200, save_interval=100):
    """Train ``agent`` on ``env`` with periodic evaluation and checkpointing.

    Episodes cycle through a fixed list of reset seeds. Every
    ``save_interval`` episodes (and on the last episode) the agent is
    evaluated on those seeds, the best model by NWS is saved, and the
    training history and plots are written to the run directory. The final
    model is saved once, after the last episode.

    Args:
        env: Gymnasium environment used for training rollouts.
        agent: Agent exposing ``act``, ``store_trajectory``, ``learn`` and
            ``save``.
        episodes: Number of training episodes.
        max_steps: Maximum steps per training and evaluation episode.
        save_interval: Number of episodes between evaluations.

    Returns:
        Path of the directory holding the results of this run.
    """
    best_nws = -np.inf
    # Fixed evaluation seeds, also cycled through for training resets
    env_seeds = [42, 520, 1231, 114, 886]
    seed_cycle = len(env_seeds)

    rewards_history = []
    episode_lengths = []
    moving_avg_rewards = []
    evaluation_history = []

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = f"results/ppo_lunarlander_{timestamp}"
    os.makedirs(save_dir, exist_ok=True)

    start_time = time.time()
    last_eval_time = start_time
    best_nws_plot_path = None

    for episode in range(1, episodes + 1):
        # Cycle through the specified seeds
        current_seed = env_seeds[(episode - 1) % seed_cycle]
        state, _ = env.reset(seed=current_seed)

        total_reward = 0.0
        done = False
        steps = 0

        while not done and steps < max_steps:
            action, log_prob, value = agent.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            steps += 1

            # Flag the step-limit cut-off as an episode boundary too.
            # Otherwise advantage estimation would bootstrap from the first
            # state of the *next* episode stored right after this transition.
            episode_ended = done or steps >= max_steps
            agent.store_trajectory(state, action, log_prob, value, reward, episode_ended)

            state = next_state
            total_reward += reward

        # Learn from the collected trajectories (no-op until the buffer fills)
        agent.learn()

        # Builtin floats keep the history JSON-serialisable.
        rewards_history.append(float(total_reward))
        episode_lengths.append(steps)

        # Moving average over the last (up to) 100 episodes
        moving_avg_rewards.append(float(np.mean(rewards_history[-100:])))

        print(
            f"Episode {episode}/{episodes}, Seed: {current_seed}, Total Reward: {total_reward:.2f}, Steps: {steps}")

        # Periodic evaluation
        if episode % save_interval == 0 or episode == episodes:
            current_time = time.time()
            elapsed_time = current_time - start_time
            eval_time = current_time - last_eval_time
            last_eval_time = current_time

            eval_results = evaluate_agent(agent, env_seeds, max_steps)
            evaluation_history.append({
                'episode': episode,
                'metrics': eval_results,
                'elapsed_time': elapsed_time,
                'eval_time': eval_time
            })

            print(f"\nEvaluation at Episode {episode}:")
            print(f"  Mean Reward: {eval_results['mean_reward']:.2f}")
            print(f"  Mean Fuel: {eval_results['mean_fuel']:.2f}")
            print(f"  Success Rate: {eval_results['success_rate']:.2f}")
            print(f"  NWS: {eval_results['nws']:.2f}")
            print(f"  Elapsed Time: {elapsed_time:.2f} seconds")
            print(f"  Time Since Last Eval: {eval_time:.2f} seconds")

            # Save the trail image of the worst evaluation episode
            img_data = base64.b64decode(eval_results['worst_case_image'])
            img_path = os.path.join(save_dir, f"eval_ep{episode}_worst.png")
            with open(img_path, 'wb') as f:
                f.write(img_data)

            # Numbered checkpoint every fifth evaluation interval
            if episode % (5 * save_interval) == 0:
                agent.save(os.path.join(save_dir, f"ppo_model_ep{episode}.pth"))

            if eval_results['nws'] > best_nws:
                best_nws = eval_results['nws']
                # NOTE(review): the "dqn_" prefix looks like a leftover from
                # another script; kept as-is in case external tooling loads
                # this exact filename.
                model_path = os.path.join(save_dir, "dqn_model_best.pth")
                agent.save(model_path)
                print(f"  New Best Model Saved! NWS: {best_nws:.2f}")

            # Progress snapshot: overwrite the history file and plots so an
            # interrupted run still leaves up-to-date results on disk.
            end_time = time.time()
            training_data = {
                'rewards': rewards_history,
                'episode_lengths': episode_lengths,
                'moving_avg_rewards': moving_avg_rewards,
                'evaluation_history': evaluation_history,
                'env_seeds': env_seeds,
                'total_time': end_time - start_time,
                'start_time': start_time,
                'end_time': end_time
            }
            with open(os.path.join(save_dir, 'training_history.json'), 'w') as f:
                json.dump(training_data, f)

            _plot_training_rewards(rewards_history, moving_avg_rewards,
                                   evaluation_history, save_dir)
            _plot_evaluation_metrics(evaluation_history, save_dir)
            best_nws_plot_path = plot_best_nws_progress(evaluation_history, save_dir)

    # Finalisation runs once, after the last episode.
    total_time = time.time() - start_time
    agent.save(os.path.join(save_dir, "ppo_model_final.pth"))

    if best_nws_plot_path is not None:
        print(f"Best NWS progression plot saved to {best_nws_plot_path}")
    print(f"Training completed in {total_time:.2f} seconds. Results saved in: {save_dir}")
    return save_dir


def plot_best_nws_progress(evaluation_history, save_dir):
    """Plot the running best NWS over the evaluation points and save it.

    Args:
        evaluation_history: Sequence of dicts, each with an ``'episode'``
            number and a ``'metrics'`` dict containing ``'nws'``.
        save_dir: Directory in which ``best_nws_progression.png`` is written.

    Returns:
        Path of the saved plot. With an empty history an empty, unannotated
        plot is written.
    """
    episodes = [point['episode'] for point in evaluation_history]

    # Running maximum: each entry is the best NWS seen up to that evaluation.
    best_nws = -float('inf')
    best_nws_history = []
    for point in evaluation_history:
        best_nws = max(best_nws, point['metrics']['nws'])
        best_nws_history.append(best_nws)

    plt.figure(figsize=(10, 6))
    plt.plot(episodes, best_nws_history, 'g-o', linewidth=2, markersize=8)

    # Annotate the first evaluation at which the overall best was reached.
    # Skipped for an empty history, where argmax would raise.
    if best_nws_history:
        max_idx = int(np.argmax(best_nws_history))
        plt.annotate(f'Best NWS: {best_nws_history[max_idx]:.3f}',
                     xy=(episodes[max_idx], best_nws_history[max_idx]),
                     xytext=(0, -30), textcoords='offset points',
                     bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
                     arrowprops=dict(arrowstyle='->'))

    plt.xlabel('Episode')
    plt.ylabel('Best NWS')
    plt.title('Progression of Best Normalized Weighted Score')
    plt.grid(True)

    # Save the plot
    plot_path = os.path.join(save_dir, 'best_nws_progression.png')
    plt.savefig(plot_path)
    plt.close()

    return plot_path

def main(run_id):
    """Create the LunarLander environment, train a PPO agent, and report the result.

    Args:
        run_id: Identifier of the run. Currently unused; kept so callers
            can launch several numbered runs.
    """
    env = gym.make("LunarLander-v3")
    try:
        state_size = env.observation_space.shape[0]
        action_size = env.action_space.n

        agent = PPOAgent(state_size, action_size)

        save_dir = train_ppo(env, agent, episodes=10000, max_steps=200, save_interval=30)
    finally:
        # Release the environment even if setup or training raised.
        env.close()

    print(f"Training completed. Results saved in: {save_dir}")

# Script entry point: run a single training session with run id 1.
if __name__ == "__main__":
    main(1)
    # Disabled alternative: launch three consecutive runs.
    # for i in range(3):
    #     main(i)
