import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import random
import matplotlib.pyplot as plt
import os
import json
from datetime import datetime
import base64
from PIL import Image
import io
import time
import copy
from io import BytesIO
import matplotlib.patches as patches
from matplotlib.transforms import Affine2D
import itertools


# Define the Q-network
class DQN(nn.Module):
    """Q-network over (current frame, speed, last decision frame, last action).

    Two independent (non-shared) CNN towers encode the current RGB frame and
    the frame at which the previous action was chosen. Their flattened
    features are concatenated with the speed features and the previous action
    vector, and an MLP head maps the result to one Q-value per discrete action.

    Args:
        h: Height of the input frames in pixels.
        w: Width of the input frames in pixels.
        speed_dim: Number of speed features per sample.
        action_dim: Number of discrete actions (size of the Q-value output).
        last_action_dim: Length of the previous-action vector fed to the head.
            Defaults to 3, the [steer, gas, brake] vector used in this file.

    Raises:
        ValueError: If the frames are too small to survive the conv stack.
    """

    # (kernel_size, stride) of the three conv layers, in order.
    _CONV_SPECS = ((8, 4), (4, 2), (3, 1))

    def __init__(self, h, w, speed_dim, action_dim, last_action_dim=3):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable
        # so existing checkpoints still load.
        self.current_frame_conv = self._make_conv_tower()
        self.last_frame_conv = self._make_conv_tower()

        convw = self._conv_out_size(w)
        convh = self._conv_out_size(h)
        # Check each dimension separately: two negative sizes would multiply
        # into a positive (and meaningless) feature count.
        if convw <= 0 or convh <= 0:
            raise ValueError(
                f"input frames of size {h}x{w} are too small for the conv stack")
        conv_output_size = convw * convh * 64

        # Both frame encodings + speed features + previous action vector.
        linear_input_size = conv_output_size * 2 + speed_dim + last_action_dim

        self.head = nn.Sequential(
            nn.Linear(linear_input_size, 512),
            nn.ReLU(),
            nn.Linear(512, action_dim),
        )

    @staticmethod
    def _make_conv_tower():
        """Build one conv stack mapping 3-channel frames to 64 feature maps."""
        return nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

    @classmethod
    def _conv_out_size(cls, size):
        """Return the spatial size after the conv stack (no padding, dilation 1)."""
        for kernel_size, stride in cls._CONV_SPECS:
            size = (size - (kernel_size - 1) - 1) // stride + 1
        return size

    def forward(self, x):
        """Return Q-values of shape (batch, action_dim).

        Args:
            x: Tuple ``(current_frame, speed, last_decision_frame, last_action)``.
                Frames are (batch, 3, h, w). ``speed`` and ``last_action`` are
                reshaped to (batch, -1) before concatenation.
        """
        current_frame, speed, last_decision_frame, last_action = x

        current_features = torch.flatten(self.current_frame_conv(current_frame), 1)
        last_features = torch.flatten(self.last_frame_conv(last_decision_frame), 1)

        # reshape (not flatten) so 1-D inputs of shape (batch,) become (batch, 1).
        speed = speed.reshape(speed.size(0), -1)
        last_action = last_action.reshape(last_action.size(0), -1)

        features = torch.cat([current_features, last_features, speed, last_action], dim=1)
        return self.head(features)


# Record type stored in the replay buffer. ``next_state`` may be None, which
# the learning step treats as a terminal transition (no bootstrap).
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward', 'done'],
)


class ReplayMemory:
    """Fixed-capacity ring buffer of ``Transition`` records for experience replay.

    Entries are appended until ``capacity`` is reached. After that, each push
    overwrites the oldest entry. ``sample`` draws uniformly without replacement.

    Args:
        capacity: Maximum number of stored transitions; must be positive.

    Raises:
        ValueError: If ``capacity`` is not positive.
    """

    def __init__(self, capacity):
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        self.memory = []  # stored transitions, at most ``capacity`` of them
        self.position = 0  # slot the next push writes to

    def push(self, *args):
        """Store one transition; ``args`` are forwarded to ``Transition``."""
        # Build the record before touching the buffer. If construction fails,
        # no half-written (None) slot is left behind for ``sample`` to return.
        transition = Transition(*args)
        if len(self.memory) < self.capacity:
            # While filling up, ``position`` always equals len(self.memory).
            self.memory.append(transition)
        else:
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                transitions (propagated from ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


class CarRacingAgent:
    """DQN agent for CarRacing with a fixed set of discrete [steer, gas, brake] actions.

    Holds:
    - a policy network and a soft-updated target network (both ``DQN``);
    - an Adam optimizer;
    - a 30000-entry ``ReplayMemory``;
    - an epsilon-greedy schedule that decays on every ``select_action`` call.

    Args:
        env: Environment whose ``observation_space.shape`` unpacks as
            (height, width, channels).
        frame_skip: Number of env steps each chosen action is repeated for by
            the training/evaluation loops.
    """

    def __init__(self, env, frame_skip=1):
        self.env = env
        self.frame_skip = frame_skip

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Each entry is the continuous [steer, gas, brake] command sent to env.step.
        self.DISCRETE_ACTIONS = [
            [0.0, 0.0, 0.0],  # No action
            [0.0, 1.0, 0.0],  # Accelerate straight
            [0.0, 0.0, 0.2],  # Brake
            [-1.0, 0.0, 0.2],  # Brake + left
            [1.0, 0.0, 0.2],  # Brake + right
            [-0.5, 0.0, 0.0],  # Left
            [0.5, 0.0, 0.0],  # Right
            [-1.0, 1.0, 0.0],  # Left + accelerate
            [1.0, 1.0, 0.0]  # Right + accelerate
        ]

        # Dimensions
        self.n_actions = len(self.DISCRETE_ACTIONS)
        self.screen_height, self.screen_width, _ = self.env.observation_space.shape
        self.speed_dim = 1
        self.last_action_dim = len(self.DISCRETE_ACTIONS[0])

        self.gamma = 0.98
        self.epsilon = 1.0  # initial exploration rate, copied into eps_threshold
        self.epsilon_min = 0.1
        self.epsilon_decay = 0.995  # multiplicative decay per select_action call
        self.batch_size = 128
        self.tau = 0.01  # soft-update rate of the target network

        self.policy_net = DQN(self.screen_height, self.screen_width,
                              self.speed_dim, self.n_actions).to(self.device)
        self.target_net = DQN(self.screen_height, self.screen_width,
                              self.speed_dim, self.n_actions).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=0.0005)
        self.memory = ReplayMemory(30000)

        # Current exploration rate actually used by select_action.
        self.eps_threshold = self.epsilon

        # Frame at which the last action was chosen, and that action; the
        # training/evaluation loops set these and feed them into the state.
        self.last_decision_frame = None
        self.last_decision_action = None

    def select_action(self, state):
        """Pick an action index epsilon-greedily, decaying epsilon first.

        ``eps_threshold`` is multiplied by ``epsilon_decay`` (floored at
        ``epsilon_min``) on every call, before the random draw.

        Args:
            state: Tuple produced by ``preprocess_state`` (batch size 1).

        Returns:
            A (1, 1) ``torch.long`` tensor holding the chosen action index.
        """
        sample = random.random()
        self.eps_threshold = max(self.epsilon_min, self.eps_threshold * self.epsilon_decay)

        if sample > self.eps_threshold:
            with torch.no_grad():
                return self.policy_net(state).max(1)[1].view(1, 1)
        return torch.tensor([[random.randrange(self.n_actions)]],
                            device=self.device, dtype=torch.long)

    def preprocess_state(self, current_frame, speed, last_decision_frame, last_action):
        """Convert raw observations into the network's input tuple.

        Frames are (H, W, C) arrays. Each is permuted to (1, C, H, W),
        converted to float and scaled by 1/255. ``speed`` becomes a (1, 1)
        tensor and ``last_action`` a (1, len(last_action)) tensor. All tensors
        are placed on ``self.device``.

        NOTE(review): these float32 device tensors are what the training loop
        stores in replay memory, so memory use scales with 30000 x 2 float
        frames. Storing uint8 on CPU would be far cheaper if that matters.
        """
        current_frame = torch.from_numpy(current_frame).float().permute(2, 0, 1).unsqueeze(0).to(self.device) / 255.0
        last_decision_frame = torch.from_numpy(last_decision_frame).float().permute(2, 0, 1).unsqueeze(0).to(self.device) / 255.0
        speed = torch.tensor([speed], dtype=torch.float32).unsqueeze(0).to(self.device)
        last_action = torch.tensor(last_action, dtype=torch.float32).unsqueeze(0).to(self.device)
        return (current_frame, speed, last_decision_frame, last_action)

    def soft_update(self):
        """Move target weights toward the policy weights: t <- tau * p + (1 - tau) * t."""
        for target_param, policy_param in zip(self.target_net.parameters(), self.policy_net.parameters()):
            target_param.data.copy_(self.tau * policy_param.data + (1.0 - self.tau) * target_param.data)

    @staticmethod
    def _cat_states(states):
        """Concatenate a sequence of 4-part state tuples along the batch dimension."""
        return tuple(torch.cat(parts) for parts in zip(*states))

    def learn(self):
        """Run one gradient step on a uniformly sampled minibatch.

        Does nothing until the replay memory holds at least
        ``5 * batch_size`` transitions. The TD target is
        ``reward + gamma * max_a Q_target(next_state, a)``; the bootstrap term
        is zero for transitions whose ``next_state`` is ``None``. Gradients are
        clamped element-wise to [-1, 1] before the optimizer step, and the
        target network is soft-updated afterwards.
        """
        if len(self.memory) < self.batch_size * 5:
            return

        transitions = self.memory.sample(self.batch_size)
        batch = Transition(*zip(*transitions))

        state_batch = self._cat_states(batch.state)
        action_batch = torch.cat(batch.action).to(self.device)
        reward_batch = torch.cat(batch.reward).to(self.device)
        # batch.done is not used: terminal transitions are the ones whose
        # next_state is None.

        # Q(s_t, a_t) for the actions that were actually taken.
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)

        # max_a Q_target(s_{t+1}, a), left at zero for terminal transitions.
        next_state_values = torch.zeros(self.batch_size, device=self.device)
        non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                      dtype=torch.bool, device=self.device)
        if non_final_mask.any():
            non_final_next_states = self._cat_states(
                [s for s in batch.next_state if s is not None])
            with torch.no_grad():
                next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0]

        expected_state_action_values = reward_batch + self.gamma * next_state_values

        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        # Element-wise clamp to [-1, 1]; parameters without a gradient are skipped.
        nn.utils.clip_grad_value_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()
        self.soft_update()

    def save(self, filename):
        """Save both networks, the optimizer state and the current epsilon to ``filename``."""
        torch.save({
            'policy_state_dict': self.policy_net.state_dict(),
            'target_state_dict': self.target_net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.eps_threshold,
        }, filename)

    def load(self, filename):
        """Restore a checkpoint written by ``save``.

        Tensors are mapped onto ``self.device``, so a checkpoint saved on a
        GPU machine can be loaded on a CPU-only one.
        """
        checkpoint = torch.load(filename, map_location=self.device)
        self.policy_net.load_state_dict(checkpoint['policy_state_dict'])
        self.target_net.load_state_dict(checkpoint['target_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.eps_threshold = checkpoint['epsilon']


def _greedy_action(agent, state):
    """Return ``agent.policy_net``'s arg-max action index as a (1, 1) long tensor.

    Evaluation uses this instead of ``agent.select_action``. It never
    explores and never decays the agent's ``eps_threshold``, so evaluation
    results are not randomised and training exploration is not altered.
    """
    with torch.no_grad():
        return agent.policy_net(state).max(1)[1].view(1, 1)


def _hull_speed(env):
    """Return the magnitude of the car hull's linear velocity."""
    velocity = env.unwrapped.car.hull.linearVelocity
    return np.sqrt(velocity[0] ** 2 + velocity[1] ** 2)


def _is_marker_index(idx, count, interval):
    """Return True for the first and last of ``count`` items and every ``interval``-th one."""
    return idx == 0 or idx == count - 1 or idx % interval == 0


def _tile_fill_color(color):
    """Map a ``road_poly`` color to a plot fill: pure red stays red, anything else is white."""
    try:
        rgb = tuple(color)
    except TypeError:
        return '#FFFFFF'
    return '#FF0000' if rgb == (255, 0, 0) else '#FFFFFF'


def _run_eval_episode(agent, env, seed, max_steps,
                      view_length=46.0, view_width=38.0, view_offset=14.0):
    """Play one greedy episode from ``env.reset(seed=seed)`` and record the car's path.

    After every env step this records:
    - the hull position and hull angle;
    - a view rectangle ``(center_x, center_y, heading, width, length)``
      whose center lies ``view_offset`` units along
      ``heading = hull.angle + pi/2``.

    Returns:
        dict with keys 'reward', 'terminated', 'truncated', 'trajectory',
        'car_angles' and 'view_rectangles'.
    """
    current_frame, _ = env.reset(seed=seed)
    agent.last_decision_frame = current_frame.copy()
    agent.last_decision_action = [0.0, 0.0, 0.0]
    state = agent.preprocess_state(
        current_frame=current_frame,
        speed=_hull_speed(env),
        last_decision_frame=agent.last_decision_frame,
        last_action=agent.last_decision_action,
    )

    episode_reward = 0
    trajectory, car_angles, view_rectangles = [], [], []
    done = False
    truncated = False  # defined even if no env step is taken
    step = 0
    while not done and step < max_steps:
        action_idx = _greedy_action(agent, state)
        action = np.array(agent.DISCRETE_ACTIONS[action_idx.item()])

        agent.last_decision_frame = current_frame.copy()
        agent.last_decision_action = action

        # Repeat the action for frame_skip env steps (fewer if the episode ends).
        skip_reward = 0
        for _ in range(agent.frame_skip):
            if done or step >= max_steps:
                break
            current_frame, reward, done, truncated, _ = env.step(action)
            step += 1

            hull = env.unwrapped.car.hull
            trajectory.append((hull.position.x, hull.position.y))
            car_angles.append(hull.angle)

            heading = hull.angle + np.pi / 2
            view_rectangles.append((
                hull.position.x + np.cos(heading) * view_offset,
                hull.position.y + np.sin(heading) * view_offset,
                heading,
                view_width,
                view_length,
            ))
            skip_reward += reward

        state = agent.preprocess_state(
            current_frame=current_frame,
            speed=_hull_speed(env),
            last_decision_frame=agent.last_decision_frame,
            last_action=agent.last_decision_action,
        )
        episode_reward += skip_reward

    return {
        'reward': episode_reward,
        'terminated': done,
        'truncated': truncated,
        'trajectory': trajectory,
        'car_angles': car_angles,
        'view_rectangles': view_rectangles,
    }


def _render_eval_plot(env, episode, track_coverage, arrow_interval=50, arrow_length=4.0):
    """Draw the track, trajectory, heading arrows and sampled view areas.

    Returns:
        The figure as a base64-encoded PNG string.
    """
    green_color = '#62f972'
    view_color = '#8000FF'
    arrow_color = '#FF6A00'

    fig = plt.figure(figsize=(9, 8))
    try:
        ax = plt.gca()
        ax.set_facecolor(green_color)

        for polygon in env.unwrapped.road_poly:
            vertices, color = polygon[0], polygon[1]
            x_coords = [v[0] for v in vertices] + [vertices[0][0]]
            y_coords = [v[1] for v in vertices] + [vertices[0][1]]
            plt.fill(x_coords, y_coords, color=_tile_fill_color(color), alpha=1.0)

        view_rectangles = episode['view_rectangles']
        for idx, (center_x, center_y, heading, width, length) in enumerate(view_rectangles):
            if not _is_marker_index(idx, len(view_rectangles), arrow_interval):
                continue
            # The rectangle's first extent lies on its local x axis. The rotation
            # aligns that axis with the heading, so ``length`` runs along the car.
            rect_patch = patches.Rectangle(
                (-length / 2, -width / 2),
                length,
                width,
                linewidth=0,
                edgecolor='none',
                facecolor=view_color,
                alpha=0.1,
            )
            rect_patch.set_transform(
                Affine2D().rotate(heading).translate(center_x, center_y) + ax.transData)
            ax.add_patch(rect_patch)

        if episode['trajectory']:
            trajectory = np.array(episode['trajectory'])
            plt.plot(trajectory[:, 0], trajectory[:, 1], 'k-', linewidth=1, label='Trajectory')

            for idx, car_angle in enumerate(episode['car_angles']):
                if not _is_marker_index(idx, len(trajectory), arrow_interval):
                    continue
                heading = car_angle + np.pi / 2
                # Same scale on both axes so the arrow points along the heading.
                dx = np.cos(heading) * arrow_length
                dy = np.sin(heading) * arrow_length
                plt.arrow(trajectory[idx, 0] - dx * 0.3, trajectory[idx, 1] - dy * 0.3, dx, dy,
                          head_width=3, head_length=4, fc=arrow_color, ec=arrow_color)

        custom_handles = [
            patches.Patch(color=green_color, label='Off-Track Area (Grass)'),
            patches.Patch(color='white', label='Track'),
            patches.Patch(color='red', label='Curbing (indicating the track limits during sharp turns)'),
            patches.Patch(color=view_color, alpha=0.1, label="Agent's Dynamic Visual Field"),
        ]
        plotted_handles, _ = ax.get_legend_handles_labels()

        # Keep the first handle for each label.
        seen_labels = set()
        unique_handles = []
        for handle in custom_handles + plotted_handles:
            label = handle.get_label()
            if label not in seen_labels:
                seen_labels.add(label)
                unique_handles.append(handle)

        plt.title(
            f"Track with Car Trajectory and Corresponding Dynamic View Areas\n"
            f"Track Completion Rate: {track_coverage:.1f} %")
        plt.axis('equal')
        plt.legend(handles=unique_handles)

        buffer = BytesIO()
        plt.savefig(buffer, format="png", bbox_inches='tight')
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
    finally:
        plt.close(fig)


def evaluate_agent(agent, env, env_seeds, max_steps=1200):
    """Evaluate the agent greedily on each seed and return aggregate metrics.

    Each episode runs with the arg-max policy; there is no exploration and the
    agent's epsilon is not modified. It ends when the env reports termination
    or after ``max_steps`` env steps.

    Args:
        agent: Agent providing ``policy_net``, ``preprocess_state``,
            ``DISCRETE_ACTIONS`` and ``frame_skip``.
        env: CarRacing environment (accessed through ``env.unwrapped``).
        env_seeds: Iterable of seeds, one evaluation episode each.
        max_steps: Env-step budget per episode.

    Returns:
        dict with:
        - 'mean_reward';
        - 'nws': mean track-completion percentage across seeds;
        - 'worst_case_image': base64 PNG of the lowest-reward episode;
        - 'episodes_recorder': per-seed details keyed by ``str(seed)``.

    Raises:
        ValueError: If ``env_seeds`` is empty.
    """
    seeds = list(env_seeds)
    if not seeds:
        raise ValueError("env_seeds must contain at least one seed")

    total_rewards = []
    track_coverages = []
    image64s = []
    episodes_recorder = {}

    for seed in seeds:
        episode = _run_eval_episode(agent, env, seed, max_steps)
        track_coverage = env.unwrapped.tile_visited_count / len(env.unwrapped.track) * 100

        image64s.append(_render_eval_plot(env, episode, track_coverage))
        total_rewards.append(episode['reward'])
        track_coverages.append(track_coverage)

        episodes_recorder[f'{seed}'] = {
            'seed': seed,
            'episode_reward': episode['reward'],
            'terminated': episode['terminated'],
            'truncated': episode['truncated'],
            'track_coverage': track_coverage,
        }

    mean_reward = np.mean(total_rewards)
    # "NWS" is currently just the mean track-completion percentage across seeds.
    nws = np.mean(track_coverages)
    worst_idx = np.argmin(total_rewards)

    return {
        'mean_reward': mean_reward,
        'nws': nws,
        'worst_case_image': image64s[worst_idx],
        'episodes_recorder': episodes_recorder,
    }


def _run_training_episode(env, agent, seed, max_steps):
    """Play one epsilon-greedy training episode, learning after every decision.

    Each decision pushes one transition to ``agent.memory``. ``next_state``
    is ``None`` only when the environment reports termination. When the
    episode merely runs out of ``max_steps``, the real next state is stored,
    so ``agent.learn`` still bootstraps from it: a step budget is not a
    terminal state of the task.

    Returns:
        ``(total_reward, step_count)`` for the episode.
    """
    def hull_speed():
        velocity = env.unwrapped.car.hull.linearVelocity
        return np.sqrt(velocity[0] ** 2 + velocity[1] ** 2)

    current_frame, _ = env.reset(seed=seed)
    agent.last_decision_frame = current_frame.copy()
    agent.last_decision_action = [0.0, 0.0, 0.0]
    state = agent.preprocess_state(
        current_frame=current_frame,
        speed=hull_speed(),
        last_decision_frame=agent.last_decision_frame,
        last_action=agent.last_decision_action,
    )

    total_reward = 0
    step_count = 0
    done = False
    while not done and step_count < max_steps:
        action_idx = agent.select_action(state)
        action = np.array(agent.DISCRETE_ACTIONS[action_idx.item()])

        agent.last_decision_frame = current_frame.copy()
        agent.last_decision_action = action

        # Repeat the action for frame_skip env steps (fewer if the episode ends).
        skip_reward = 0
        for _ in range(agent.frame_skip):
            if done or step_count >= max_steps:
                break
            # NOTE(review): gymnasium's ``truncated`` flag does not end the
            # episode here; length is capped by ``max_steps`` instead. Confirm
            # that stepping past the env's own time limit is intended.
            current_frame, reward, done, _truncated, _info = env.step(action)
            step_count += 1
            skip_reward += reward

        if done:
            next_state = None  # terminal: no bootstrap
        else:
            next_state = agent.preprocess_state(
                current_frame=current_frame,
                speed=hull_speed(),
                last_decision_frame=agent.last_decision_frame,
                last_action=agent.last_decision_action,
            )

        reward_tensor = torch.tensor([skip_reward], dtype=torch.float32)
        agent.memory.push(state, action_idx, next_state, reward_tensor, done)

        state = next_state
        total_reward += skip_reward
        agent.learn()

    return total_reward, step_count


def _save_training_artifacts(agent, save_dir, start_time, env_seeds, rewards_history,
                             episode_lengths, moving_avg_rewards, moving_avg_window,
                             evaluation_history):
    """Write the latest model, training-history JSON and plots into ``save_dir``.

    Called after every evaluation so an interrupted run still leaves
    up-to-date artifacts behind.

    Returns:
        Seconds elapsed since ``start_time``.
    """
    end_time = time.time()
    total_time = end_time - start_time

    agent.save(os.path.join(save_dir, "dqn_model_final.pth"))

    training_data = {
        'rewards': rewards_history,
        'episode_lengths': episode_lengths,
        'moving_avg_rewards': moving_avg_rewards,
        'evaluation_history': evaluation_history,
        'env_seeds': env_seeds,
        'total_time': total_time,
        'start_time': start_time,
        'end_time': end_time,
    }
    with open(os.path.join(save_dir, 'training_history.json'), 'w') as f:
        json.dump(training_data, f)

    # Training curve with evaluation points marked on the moving average.
    plt.figure(figsize=(12, 6))
    plt.plot(rewards_history, alpha=0.6, label='Episode Reward')
    plt.plot(moving_avg_rewards, 'r-', linewidth=2, label=f'Moving Avg ({moving_avg_window} episodes)')
    eval_episodes = [e['episode'] for e in evaluation_history]
    # moving_avg_rewards[e] is the average recorded for episode e, i.e. the
    # value current when the evaluation at episode e ran.
    plt.scatter(eval_episodes, [moving_avg_rewards[e] for e in eval_episodes],
                c='green', marker='o', label='Evaluation Points')
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('DQN Training Performance with Periodic Evaluation')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(save_dir, 'training_plot.png'))
    plt.close()

    # Evaluation metrics over time.
    plt.figure(figsize=(12, 8))

    plt.subplot(1, 2, 1)
    plt.plot(eval_episodes, [e['metrics']['mean_reward'] for e in evaluation_history], 'o-')
    plt.xlabel('Episode')
    plt.ylabel('Mean Reward')
    plt.title('Evaluation Mean Reward')
    plt.grid()

    plt.subplot(1, 2, 2)
    plt.plot(eval_episodes, [e['metrics']['nws'] for e in evaluation_history], 'o-')
    plt.xlabel('Episode')
    plt.ylabel('Track Coverage')
    plt.title('Track Coverage')
    plt.grid()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'evaluation_metrics.png'))
    plt.close()

    plot_best_nws_progress(evaluation_history, save_dir)
    return total_time


def train_dqn(env, agent, num_episodes=1000, max_steps=200, save_interval=100, env_seeds=(42,)):
    """Train ``agent`` on ``env`` and periodically evaluate and checkpoint it.

    Training seeds cycle through ``env_seeds``. Evaluation runs:
    - every ``save_interval`` episodes;
    - after the final episode.

    Each evaluation:
    - saves the best model by NWS and the worst-case evaluation image;
    - every ``5 * save_interval`` episodes, saves a numbered model;
    - refreshes the "final" model, history JSON and plots.

    Returns:
        The timestamped results directory.
    """
    moving_avg_window = 100
    best_nws = -np.inf
    seed_cycle = itertools.cycle(env_seeds)

    rewards_history = []
    episode_lengths = []
    moving_avg_rewards = []
    evaluation_history = []

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = f"results/dqn_carracing_{timestamp}"
    os.makedirs(save_dir, exist_ok=True)

    start_time = time.time()
    last_eval_time = start_time

    for i_episode in range(num_episodes):
        current_seed = next(seed_cycle)
        total_reward, step_count = _run_training_episode(env, agent, current_seed, max_steps)

        rewards_history.append(total_reward)
        episode_lengths.append(step_count)
        # The slice covers all episodes so far until the window fills up.
        moving_avg_rewards.append(np.mean(rewards_history[-moving_avg_window:]))

        print(
            f"Episode {i_episode}/{num_episodes}, Seed: {current_seed}, Total Reward: {total_reward:.2f}, Epsilon: {agent.eps_threshold:.2f}, Steps: {step_count}")

        is_last_episode = i_episode == num_episodes - 1
        if i_episode % save_interval != 0 and not is_last_episode:
            continue

        current_time = time.time()
        elapsed_time = current_time - start_time
        eval_time = current_time - last_eval_time
        last_eval_time = current_time

        eval_results = evaluate_agent(agent, env, env_seeds, max_steps)
        evaluation_history.append({
            'episode': i_episode,
            'metrics': eval_results,
            'elapsed_time': elapsed_time,
            'eval_time': eval_time,
        })

        print(f"\nEvaluation at Episode {i_episode}:")
        print(f"  Mean Reward: {eval_results['mean_reward']:.2f}")
        print(f"  NWS: {eval_results['nws']:.2f}")
        print(f"  Elapsed Time: {elapsed_time:.2f} seconds")
        print(f"  Time Since Last Eval: {eval_time:.2f} seconds")

        if eval_results['nws'] > best_nws:
            best_nws = eval_results['nws']
            agent.save(os.path.join(save_dir, "dqn_model_best.pth"))
            print(f"  New Best Model Saved! NWS: {best_nws:.2f}")

        img_path = os.path.join(save_dir, f"eval_ep{i_episode}_worst.png")
        with open(img_path, 'wb') as f:
            f.write(base64.b64decode(eval_results['worst_case_image']))

        if i_episode % (5 * save_interval) == 0:
            agent.save(os.path.join(save_dir, f"model_ep{i_episode}.pth"))

        total_time = _save_training_artifacts(
            agent, save_dir, start_time, env_seeds, rewards_history, episode_lengths,
            moving_avg_rewards, moving_avg_window, evaluation_history)
        print(f"Checkpoint written after {total_time:.2f} seconds. Results saved in: {save_dir}")

    return save_dir


def _running_best(values):
    """Return the running maximum of ``values``; entry i is the best of values[:i+1]."""
    best = -float('inf')
    history = []
    for value in values:
        if value > best:
            best = value
        history.append(best)
    return history


def plot_best_nws_progress(evaluation_history, save_dir):
    """Plot the best NWS reached so far at each evaluation and save it as a PNG.

    Args:
        evaluation_history: Evaluation records in order. Each is a dict with an
            'episode' key and a 'metrics' dict containing 'nws'.
        save_dir: Directory the plot is written to.

    Returns:
        Path of the saved 'best_nws_progression.png'.

    Raises:
        ValueError: If ``evaluation_history`` is empty.
    """
    evaluation_history = list(evaluation_history)
    if not evaluation_history:
        # Fail before opening a figure so nothing is leaked.
        raise ValueError("evaluation_history is empty; nothing to plot")

    episodes = [point['episode'] for point in evaluation_history]
    best_nws_history = _running_best(point['metrics']['nws'] for point in evaluation_history)
    # argmax returns the first index of the maximum: the evaluation where the
    # final best score was first reached.
    max_idx = np.argmax(best_nws_history)

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(episodes, best_nws_history, 'g-o', linewidth=2, markersize=8)
        plt.annotate(f'Best NWS: {best_nws_history[max_idx]:.3f}',
                     xy=(episodes[max_idx], best_nws_history[max_idx]),
                     xytext=(0, -30), textcoords='offset points',
                     bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
                     arrowprops=dict(arrowstyle='->'))

        plt.xlabel('Episode')
        plt.ylabel('Best NWS')
        plt.title('Progression of Best Normalized Weighted Score')
        plt.grid(True)

        plot_path = os.path.join(save_dir, 'best_nws_progression.png')
        plt.savefig(plot_path)
    finally:
        plt.close(fig)

    return plot_path

def main(run_id):
    """Train one DQN agent on CarRacing-v3 and report where results were written.

    The environment is closed even if training raises.

    Args:
        run_id: Index of this run; not currently used inside the function.
    """
    carracing_env = gym.make('CarRacing-v3', render_mode='rgb_array', domain_randomize=False, continuous=True)
    try:
        agent = CarRacingAgent(carracing_env, frame_skip=1)
        save_dir = train_dqn(
            carracing_env,
            agent,
            num_episodes=8000,
            max_steps=1200,
            save_interval=30,
            env_seeds=(40, 1231, 516, 413),
        )
    finally:
        carracing_env.close()
    print(f"Training completed. Results saved in: {save_dir}")

if __name__ == "__main__":
    # A single training run; widen the range to repeat the experiment.
    for run_index in range(1):
        main(run_index)
