import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import random
import matplotlib.pyplot as plt
import os
import json
from datetime import datetime
import base64
from PIL import Image
import io
import time


# Define the Q-network
class QNetwork(nn.Module):
    """Fully connected Q-value network mapping a state to one value per action.

    Two ReLU hidden layers of width ``hidden_size`` feed a linear output layer
    with ``action_size`` units. Layer attribute names (``fc1``/``fc2``/``fc3``)
    are kept stable so previously saved state dicts still load.
    """

    def __init__(self, state_size, action_size, hidden_size=512):
        """
        Args:
            state_size: Number of input features fed to ``fc1``.
            action_size: Number of discrete actions (output units).
            hidden_size: Width of both hidden layers (default 512).
        """
        super().__init__()
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)

    def forward(self, state):
        """Return the raw (unnormalized) Q-values for ``state``; last dim is ``action_size``."""
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


# Define the replay buffer
# One replay-memory entry. ``next_state`` may be None, which the learner
# treats as "no successor state to bootstrap from".
Transition = namedtuple(
    'Transition',
    'state action reward next_state done',
)


class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform sampling without replacement."""

    def __init__(self, capacity):
        # With maxlen set, appending to a full deque drops the oldest entry.
        self.buffer = deque(maxlen=capacity)

    def push(self, *args):
        """Pack the positional fields into a ``Transition`` and store it."""
        transition = Transition(*args)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` stored transitions drawn uniformly without replacement.

        Raises ValueError (from ``random.sample``) if fewer are stored.
        """
        return random.sample(self.buffer, k=batch_size)

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.buffer)


# Define the DQN agent
class DQNAgent:
    """DQN agent: epsilon-greedy policy, experience replay and a soft-updated
    (Polyak-averaged) target network.
    """

    def __init__(self, state_size, action_size):
        """
        Args:
            state_size: Number of features in an observation.
            action_size: Number of discrete actions.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.memory = ReplayBuffer(10000)
        self.gamma = 0.99            # discount factor
        self.epsilon = 1.0           # current exploration probability
        self.epsilon_min = 0.01      # lower bound for epsilon
        self.epsilon_decay = 0.9995  # multiplicative decay applied per learn() update
        self.batch_size = 128
        self.tau = 0.01              # soft target-update coefficient

        self.policy_net = QNetwork(state_size, action_size)
        self.target_net = QNetwork(state_size, action_size)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=0.001)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy_net.to(self.device)
        self.target_net.to(self.device)

    def act(self, state, training=True):
        """Choose an action for ``state``.

        With ``training=True`` a random action is taken with probability
        ``epsilon``; otherwise (and always when ``training=False``) the action
        with the highest Q-value from ``policy_net`` is returned as an int.
        """
        if training and random.random() < self.epsilon:
            return random.randrange(self.action_size)

        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.policy_net(state)
        return q_values.argmax().item()

    def learn(self):
        """Do one optimization step on a sampled minibatch.

        After the gradient step the target network is soft-updated and epsilon
        is decayed. Does nothing until the buffer holds ``batch_size`` items.
        """
        if len(self.memory) < self.batch_size:
            return

        transitions = self.memory.sample(self.batch_size)
        # List of Transitions -> one Transition whose fields are per-field tuples.
        batch = Transition(*zip(*transitions))

        states = torch.FloatTensor(np.array(batch.state)).to(self.device)
        actions = torch.LongTensor(np.array(batch.action)).unsqueeze(1).to(self.device)
        rewards = torch.FloatTensor(np.array(batch.reward)).unsqueeze(1).to(self.device)
        dones = torch.FloatTensor(np.array(batch.done).astype(np.float32)).unsqueeze(1).to(self.device)

        current_q = self.policy_net(states).gather(1, actions)

        # Transitions stored with next_state=None keep a zero bootstrap term.
        non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                      device=self.device, dtype=torch.bool)
        next_q = torch.zeros(self.batch_size, device=self.device)
        if non_final_mask.any():
            next_states = torch.FloatTensor(
                np.array([s for s in batch.next_state if s is not None])).to(self.device)
            with torch.no_grad():
                next_q[non_final_mask] = self.target_net(next_states).max(1)[0]

        expected_q = rewards + self.gamma * next_q.unsqueeze(1) * (1 - dones)

        criterion = nn.SmoothL1Loss()
        loss = criterion(current_q, expected_q)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()

        # Soft update in place: target <- tau * policy + (1 - tau) * target.
        with torch.no_grad():
            for target_param, policy_param in zip(self.target_net.parameters(),
                                                  self.policy_net.parameters()):
                target_param.mul_(1 - self.tau).add_(policy_param, alpha=self.tau)

        # Decay exploration without ever undershooting epsilon_min.
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def save(self, filename):
        """Write networks, optimizer state and epsilon to ``filename``."""
        torch.save({
            'policy_state_dict': self.policy_net.state_dict(),
            'target_state_dict': self.target_net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
        }, filename)

    def load(self, filename):
        """Restore state written by ``save``.

        Tensors are mapped onto this agent's device, so a checkpoint saved on a
        GPU machine can be loaded on a CPU-only one (and vice versa).
        """
        checkpoint = torch.load(filename, map_location=self.device)
        self.policy_net.load_state_dict(checkpoint['policy_state_dict'])
        self.target_net.load_state_dict(checkpoint['target_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epsilon = checkpoint['epsilon']


def _overlay_frame(canvas, img, alpha):
    """Blend the non-black pixels of ``img`` into ``canvas`` with weight ``alpha``.

    Creates a black float32 canvas shaped like ``img`` when ``canvas`` is None
    and returns the (possibly new) canvas. Pixels that are black in ``img``
    leave the canvas untouched.
    """
    if canvas is None:
        canvas = np.zeros_like(img, dtype=np.float32)
    mask = np.any(img != [0, 0, 0], axis=-1)
    canvas[mask] = img[mask] * alpha + canvas[mask] * (1 - alpha)
    return canvas


def _encode_png_base64(canvas):
    """Encode an image canvas (cast to uint8) as a base64 PNG string."""
    buffered = io.BytesIO()
    Image.fromarray(canvas.astype(np.uint8)).save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def evaluate_agent(agent, env_seeds, max_steps=200, gravity=-10,
                   enable_wind=False):
    """Evaluate the agent on multiple seeds and return metrics.

    For each seed a fresh ``LunarLander-v3`` env is created and the agent acts
    greedily (``training=False``) for at most ``max_steps`` steps. A trail
    image is built by blending every 10th rendered frame (later frames
    weighted more), and the final frame is pasted opaquely on top.

    Args:
        agent: Object exposing ``act(observation, training=False)``.
        env_seeds: Seeds passed to ``env.reset``; one episode per seed.
        max_steps: Maximum number of steps per evaluation episode.
        gravity: Forwarded to ``gym.make``.
        enable_wind: Forwarded to ``gym.make``.

    Returns:
        Dict with 'mean_reward', 'mean_fuel' (mean count of actions 1-3 per
        episode), 'success_rate' (fraction of episodes with reward >= 200),
        'nws' (weighted score), the base64 PNG and observations of the
        lowest-reward episode, and a per-episode 'episodes_recorder'.

    Raises:
        ValueError: If ``env_seeds`` is empty.
    """
    if len(env_seeds) == 0:
        raise ValueError("env_seeds must contain at least one seed")

    total_rewards = []
    total_fuel = 0
    success_count = 0
    episodes_recorder = {}
    image64s = []
    observations = []

    for i, seed in enumerate(env_seeds):
        env = gym.make("LunarLander-v3", render_mode='rgb_array', gravity=gravity,
                       enable_wind=enable_wind)
        try:
            observation, _ = env.reset(seed=seed)
            episode_reward = 0
            episode_fuel = 0
            episode_observations = []
            canvas = None
            # Defaults so the episode record is well-defined even if max_steps <= 0.
            terminated = truncated = False

            for step in range(max_steps):
                action = agent.act(observation, training=False)
                observation, reward, terminated, truncated, _ = env.step(action)

                episode_reward += reward
                if action in (1, 2, 3):  # Actions that use fuel
                    episode_fuel += 1

                episode_observations.append(observation.tolist())

                # Add every 10th frame to the trail image.
                if step % 10 == 0:
                    canvas = _overlay_frame(canvas, env.render(), step / max_steps)

                if terminated or truncated:
                    break

            # alpha=1.0 pastes the final frame opaquely over the trail.
            canvas = _overlay_frame(canvas, env.render(), 1.0)
            img_str = _encode_png_base64(canvas)
        finally:
            # Always release the env, even if stepping/rendering raised.
            env.close()

        # Record results
        total_rewards.append(episode_reward)
        total_fuel += episode_fuel
        image64s.append(img_str)
        observations.append(episode_observations)

        if episode_reward >= 200:
            success_count += 1

        episodes_recorder[f'{i}'] = {
            'seed': seed,
            'episode_reward': episode_reward,
            'episode_fuel': episode_fuel,
            'observations': episode_observations,
            'terminated': terminated,
            'truncated': truncated
        }

    # Calculate metrics
    mean_reward = np.mean(total_rewards)
    mean_fuel = total_fuel / len(env_seeds)
    success_rate = success_count / len(env_seeds)

    # Normalized Weighted Score (α=0.6, β=0.2, γ=0.2)
    nws = (mean_reward / 200) * 0.6 + (1 - min(mean_fuel / 100, 1)) * 0.2 + success_rate * 0.2

    # Index of the lowest-reward episode
    worst_idx = int(np.argmin(total_rewards))

    return {
        'mean_reward': mean_reward,
        'mean_fuel': mean_fuel,
        'success_rate': success_rate,
        'nws': nws,
        'worst_case_image': image64s[worst_idx],
        'worst_case_observations': observations[worst_idx],
        'episodes_recorder': episodes_recorder
    }


def _plot_training_curves(rewards_history, moving_avg_rewards, evaluation_history, save_dir):
    """Save episode rewards, their moving average and evaluation markers to training_plot.png."""
    # Episodes are 1-based, so plot on a 1-based axis to line up with the markers.
    episode_axis = range(1, len(rewards_history) + 1)
    plt.figure(figsize=(12, 6))
    plt.plot(episode_axis, rewards_history, alpha=0.6, label='Episode Reward')
    plt.plot(episode_axis, moving_avg_rewards, 'r-', linewidth=2, label='Moving Avg (100 episodes)')

    # Mark evaluation points on the moving-average curve
    eval_episodes = [e['episode'] for e in evaluation_history]
    plt.scatter(eval_episodes, [moving_avg_rewards[e - 1] for e in eval_episodes],
                c='green', marker='o', label='Evaluation Points')

    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('DQN Training Performance with Periodic Evaluation')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(save_dir, 'training_plot.png'))
    plt.close()


def _plot_evaluation_metrics(evaluation_history, save_dir):
    """Save a 2x2 grid of evaluation metrics versus episode to evaluation_metrics.png."""
    panels = [
        ('mean_reward', 'Mean Reward', 'Evaluation Mean Reward'),
        ('mean_fuel', 'Mean Fuel', 'Evaluation Mean Fuel'),
        ('success_rate', 'Success Rate', 'Evaluation Success Rate'),
        ('nws', 'NWS', 'Normalized Weighted Score'),
    ]
    eval_episodes = [e['episode'] for e in evaluation_history]

    plt.figure(figsize=(12, 8))
    for position, (key, ylabel, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(eval_episodes, [e['metrics'][key] for e in evaluation_history], 'o-')
        plt.xlabel('Episode')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'evaluation_metrics.png'))
    plt.close()


def train_dqn(env, agent, episodes=1000, max_steps=200, save_interval=100):
    """Train ``agent`` on ``env``, evaluating and checkpointing periodically.

    Training episodes cycle through a fixed list of reset seeds. Every
    ``save_interval`` episodes, and on the final episode, the function does
    the following, writing into a new timestamped directory under
    ``results/``:

    - evaluates the agent with ``evaluate_agent``;
    - saves the worst-case evaluation image;
    - saves the best, final and every-5th-interval model checkpoints;
    - writes the JSON training history;
    - writes the training and evaluation plots.

    Args:
        env: Gymnasium-style env used for training (``reset``/``step``).
        agent: Object exposing ``act``, ``memory.push``, ``learn``, ``save``
            and an ``epsilon`` attribute (e.g. ``DQNAgent``).
        episodes: Number of training episodes.
        max_steps: Per-episode step cap for training and evaluation.
        save_interval: Evaluate/checkpoint every this many episodes.

    Returns:
        Path of the directory holding all saved artifacts.
    """
    best_nws = -np.inf
    # Fixed seeds. NOTE(review): the same seeds drive both training resets and
    # evaluation, so evaluation scores reflect already-seen start states.
    env_seeds = [42, 520, 1231, 114, 886]
    seed_cycle = len(env_seeds)

    rewards_history = []
    episode_lengths = []
    moving_avg_rewards = []
    evaluation_history = []

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = f"results/dqn_lunarlander_{timestamp}"
    os.makedirs(save_dir, exist_ok=True)

    # Record start time
    start_time = time.time()
    last_eval_time = start_time

    for episode in range(1, episodes + 1):
        # Cycle through the specified seeds
        current_seed = env_seeds[(episode - 1) % seed_cycle]
        state, _ = env.reset(seed=current_seed)

        total_reward = 0
        done = False
        steps = 0

        while not done and steps < max_steps:
            action = agent.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # Only a true terminal state stops bootstrapping; a time-limit
            # truncation still has a meaningful successor state.
            agent.memory.push(state, action, reward,
                              None if terminated else next_state, terminated)
            agent.learn()

            state = next_state
            total_reward += reward
            steps += 1

        rewards_history.append(total_reward)
        episode_lengths.append(steps)
        # Mean over the most recent (up to) 100 episodes.
        moving_avg_rewards.append(np.mean(rewards_history[-100:]))

        print(
            f"Episode {episode}/{episodes}, Seed: {current_seed}, Total Reward: {total_reward:.2f}, Epsilon: {agent.epsilon:.2f}, Steps: {steps}")

        # Periodic evaluation (and always on the last episode)
        if episode % save_interval != 0 and episode != episodes:
            continue

        current_time = time.time()
        elapsed_time = current_time - start_time
        eval_time = current_time - last_eval_time
        last_eval_time = current_time

        eval_results = evaluate_agent(agent, env_seeds, max_steps)
        evaluation_history.append({
            'episode': episode,
            'metrics': eval_results,
            'elapsed_time': elapsed_time,
            'eval_time': eval_time
        })

        print(f"\nEvaluation at Episode {episode}:")
        print(f"  Mean Reward: {eval_results['mean_reward']:.2f}")
        print(f"  Mean Fuel: {eval_results['mean_fuel']:.2f}")
        print(f"  Success Rate: {eval_results['success_rate']:.2f}")
        print(f"  NWS: {eval_results['nws']:.2f}")
        print(f"  Elapsed Time: {elapsed_time:.2f} seconds")
        print(f"  Time Since Last Eval: {eval_time:.2f} seconds")

        # Save evaluation image
        img_path = os.path.join(save_dir, f"eval_ep{episode}_worst.png")
        with open(img_path, 'wb') as f:
            f.write(base64.b64decode(eval_results['worst_case_image']))

        if eval_results['nws'] > best_nws:
            best_nws = eval_results['nws']
            agent.save(os.path.join(save_dir, "dqn_model_best.pth"))
            print(f"  New Best Model Saved! NWS: {best_nws:.2f}")

        end_time = time.time()
        total_time = end_time - start_time

        if episode % (5 * save_interval) == 0:
            agent.save(os.path.join(save_dir, f"model_ep{episode}.pth"))
        # Overwritten at every evaluation so an interrupted run keeps its latest model.
        agent.save(os.path.join(save_dir, "dqn_model_final.pth"))

        training_data = {
            'rewards': rewards_history,
            'episode_lengths': episode_lengths,
            'moving_avg_rewards': moving_avg_rewards,
            'evaluation_history': evaluation_history,
            'env_seeds': env_seeds,
            'total_time': total_time,
            'start_time': start_time,
            'end_time': end_time
        }
        # NOTE(review): rewritten in full at every evaluation and embeds every
        # evaluation's metrics (including the base64 worst-case image), so this
        # file grows with each evaluation.
        with open(os.path.join(save_dir, 'training_history.json'), 'w') as f:
            json.dump(training_data, f)

        _plot_training_curves(rewards_history, moving_avg_rewards, evaluation_history, save_dir)
        _plot_evaluation_metrics(evaluation_history, save_dir)
        plot_best_nws_progress(evaluation_history, save_dir)

        if episode == episodes:
            print(f"Training completed in {total_time:.2f} seconds. Results saved in: {save_dir}")

    return save_dir


def plot_best_nws_progress(evaluation_history, save_dir):
    """Plot the running-best NWS across evaluation points and save it as a PNG.

    Args:
        evaluation_history: List of dicts, each with an 'episode' key and a
            'metrics' dict containing 'nws'.
        save_dir: Directory in which 'best_nws_progression.png' is written.

    Returns:
        Path of the saved plot. An empty history yields an empty plot instead
        of raising.
    """
    episodes = [eval_point['episode'] for eval_point in evaluation_history]
    best_nws_history = []
    best_nws = -float('inf')
    for eval_point in evaluation_history:
        # Running maximum: keep the previous best unless this point beats it.
        best_nws = max(best_nws, eval_point['metrics']['nws'])
        best_nws_history.append(best_nws)

    plt.figure(figsize=(10, 6))
    plt.plot(episodes, best_nws_history, 'g-o', linewidth=2, markersize=8)

    # Annotate the first evaluation point at which the overall best was reached.
    if best_nws_history:
        max_idx = int(np.argmax(best_nws_history))
        plt.annotate(f'Best NWS: {best_nws_history[max_idx]:.3f}',
                     xy=(episodes[max_idx], best_nws_history[max_idx]),
                     xytext=(0, -30), textcoords='offset points',
                     bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
                     arrowprops=dict(arrowstyle='->'))

    plt.xlabel('Episode')
    plt.ylabel('Best NWS')
    plt.title('Progression of Best Normalized Weighted Score')
    plt.grid(True)

    # Save the plot
    plot_path = os.path.join(save_dir, 'best_nws_progression.png')
    plt.savefig(plot_path)
    plt.close()

    return plot_path

def main(run_id):
    """Train one DQN agent on LunarLander-v3 and report where results were saved.

    Args:
        run_id: Index of this run. NOTE(review): currently unused inside this
            function.
    """
    env = gym.make("LunarLander-v3")
    try:
        state_size = env.observation_space.shape[0]
        action_size = env.action_space.n

        agent = DQNAgent(state_size, action_size)

        save_dir = train_dqn(env, agent, episodes=10000, max_steps=200, save_interval=30)
    finally:
        # Release the environment even if training raises or is interrupted.
        env.close()
    print(f"Training completed. Results saved in: {save_dir}")

if __name__ == "__main__":
    # Launch independent training runs; run ids start at 0.
    num_runs = 1
    for run_id in range(num_runs):
        main(run_id)




