import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal, Independent
from collections import deque
import matplotlib.pyplot as plt
import os
import json
from datetime import datetime
import base64
from io import BytesIO
import time
import itertools
import torch.nn.functional as F
import matplotlib.patches as patches
from matplotlib.transforms import Affine2D


class ActorCritic(nn.Module):
    """Actor-critic network with a shared CNN encoder and shared dense trunk.

    The network consumes a 4-tuple ``(current_frame, speed,
    last_decision_frame, last_action)``. Both frames go through the same
    convolutional encoder, their flattened features are concatenated with
    ``speed`` and ``last_action``, and the result feeds three heads: action
    mean, action log standard deviation, and state value.

    Args:
        h: Height of the input frames in pixels.
        w: Width of the input frames in pixels.
        speed_dim: Number of features in the ``speed`` input.
        action_dim: Number of action components. ``forward`` squashes the
            first three mean components as steering/throttle/brake.
    """

    # Bounds applied to the predicted log standard deviation. The upper bound
    # reproduces the original ``std <= 1`` limit; the lower bound keeps
    # ``exp(log_std)`` strictly positive so ``Normal`` never receives a zero
    # scale when the head saturates.
    LOG_STD_MIN = -20.0
    LOG_STD_MAX = 0.0

    def __init__(self, h, w, speed_dim, action_dim):
        super().__init__()
        self.action_dim = action_dim

        # Shared CNN applied to both the current and the last decision frame.
        self.shared_conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        def conv2d_size_out(size, kernel_size, stride):
            """Spatial output size of an unpadded convolution along one axis."""
            return (size - (kernel_size - 1) - 1) // stride + 1

        # Mirror the three conv layers above to find the flattened feature size.
        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w, 8, 4), 4, 2), 3, 1)
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h, 8, 4), 4, 2), 3, 1)
        conv_output_size = convw * convh * 64

        # Two encoded frames + speed + last action.
        linear_input_size = conv_output_size * 2 + speed_dim + action_dim

        # Shared dense trunk with layer normalization.
        self.shared_net = nn.Sequential(
            nn.Linear(linear_input_size, 512),
            nn.LayerNorm(512),
            nn.ReLU()
        )

        # Actor heads: one for the mean, one for the log standard deviation.
        self.actor_mu = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim)
        )
        self.actor_log_std = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim)
        )

        # Critic head producing a single state value.
        self.critic = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        """Compute the action distribution and the state value.

        Args:
            x: Tuple ``(current_frame, speed, last_decision_frame,
                last_action)`` of batched tensors; frames must be laid out
                channel-first so they can be fed to ``nn.Conv2d``.

        Returns:
            ``(dist, value)`` where ``dist`` is an ``Independent`` normal over
            the action vector and ``value`` is the critic output with one
            column per sample.
        """
        current_frame, speed, last_decision_frame, last_action = x

        # Encode both frames with the same CNN and flatten per sample.
        current_features = self.shared_conv(current_frame).flatten(start_dim=1)
        last_features = self.shared_conv(last_decision_frame).flatten(start_dim=1)

        features = torch.cat(
            [current_features, last_features, speed, last_action], dim=1
        )
        shared_out = self.shared_net(features)

        raw_mu = self.actor_mu(shared_out)
        log_std = self.actor_log_std(shared_out)

        # Squash the mean into the action bounds.
        mu_steering = torch.tanh(raw_mu[:, 0:1])  # Steering in [-1, 1]
        mu_throttle = torch.sigmoid(raw_mu[:, 1:2])  # Throttle in [0, 1]
        mu_brake = torch.sigmoid(raw_mu[:, 2:3])  # Brake in [0, 1]
        mu = torch.cat([mu_steering, mu_throttle, mu_brake], dim=1)

        # Clamp in log space rather than after exp(): the result is the same
        # std <= 1 for ordinary values, but exp() can no longer underflow to 0
        # (rejected by Normal) or overflow to inf (NaN gradients through the
        # clamp).
        std = torch.exp(log_std.clamp(min=self.LOG_STD_MIN, max=self.LOG_STD_MAX))

        # Treat the action components as one event so log_prob sums over them.
        dist = Independent(Normal(mu, std), 1)
        value = self.critic(shared_out)

        return dist, value


class PPOCarRacingAgent:
    """PPO agent for the continuous-action CarRacing environment.

    Owns the ``ActorCritic`` policy, its Adam optimizer and StepLR scheduler,
    and the bookkeeping (last decision frame / action) that the training and
    evaluation loops use to build the policy input.

    Args:
        env: Environment whose ``observation_space.shape`` is
            ``(height, width, channels)``.
        frame_skip: Number of environment steps each sampled action is
            repeated for by the training/evaluation loops.
    """

    def __init__(self, env, frame_skip=4):
        self.env = env
        self.frame_skip = frame_skip
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Continuous action space parameters
        self.action_dim = 3  # [steering, throttle, brake]

        # Get environment dimensions
        self.screen_height, self.screen_width, _ = self.env.observation_space.shape
        self.speed_dim = 1
        self.last_action_dim = 3

        # PPO hyperparameters
        self.gamma = 0.99
        self.gae_lambda = 0.95
        self.ppo_epochs = 4
        self.ppo_clip = 0.2
        self.entropy_coef = 0.01  # Lower for continuous actions
        self.critic_coef = 0.5
        self.max_grad_norm = 0.5
        self.batch_size = 256
        self.mini_batch_size = 64

        # Initialize networks
        self.policy = ActorCritic(
            self.screen_height, self.screen_width,
            self.speed_dim, self.action_dim
        ).to(self.device)

        self.optimizer = optim.Adam(self.policy.parameters(), lr=3e-5)  # Lower learning rate
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.99)

        # Frame skipping variables
        self.last_decision_frame = None
        self.last_decision_action = torch.zeros(3).to(self.device)

    def preprocess_state(self, current_frame, speed, last_decision_frame, last_action):
        """Convert raw state components into batched tensors on ``self.device``.

        Args:
            current_frame: Image indexed ``(height, width, channels)``; it is
                permuted to channel-first and divided by 255.
            speed: Scalar speed value.
            last_decision_frame: Image with the same layout as
                ``current_frame``, processed the same way.
            last_action: Sequence with the previous action components.

        Returns:
            Tuple ``(current_frame, speed, last_decision_frame, last_action)``
            where every tensor has a leading batch dimension of 1.
        """
        current_frame = torch.FloatTensor(current_frame).permute(2, 0, 1).to(self.device) / 255.0
        last_decision_frame = torch.FloatTensor(last_decision_frame).permute(2, 0, 1).to(self.device) / 255.0
        speed = torch.FloatTensor([speed]).to(self.device)
        last_action = torch.FloatTensor(last_action).to(self.device)

        # Add batch dimension
        return (
            current_frame.unsqueeze(0),
            speed.unsqueeze(0),
            last_decision_frame.unsqueeze(0),
            last_action.unsqueeze(0)
        )

    def act(self, state):
        """Sample an action from the current policy without tracking gradients.

        Args:
            state: Tuple produced by ``preprocess_state``.

        Returns:
            ``(action, log_prob, value)`` as returned by the policy's
            distribution and critic for ``state``.
        """
        with torch.no_grad():
            dist, value = self.policy(state)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            return action, log_prob, value

    def compute_gae(self, rewards, values, dones, next_values):
        """Compute generalized advantage estimates and value targets.

        Args:
            rewards: 1-D tensor of per-step rewards.
            values: 1-D tensor of critic values for each step's state.
            dones: 1-D tensor whose entries are truthy where the step ended
                the episode.
            next_values: 1-D tensor of critic values for each step's
                successor state.

        Returns:
            ``(advantages, returns)`` with ``returns = advantages + values``.
        """
        advantages = torch.zeros_like(rewards)
        last_advantage = 0

        # Walk backwards so each step can reuse the advantage of its successor.
        for t in reversed(range(len(rewards))):
            if dones[t]:
                # Terminal step: no bootstrap and no carry-over from later steps.
                last_advantage = rewards[t] - values[t]
            else:
                delta = rewards[t] + self.gamma * next_values[t] - values[t]
                last_advantage = delta + self.gamma * self.gae_lambda * last_advantage

            advantages[t] = last_advantage

        returns = advantages + values
        return advantages, returns

    def update(self, states, next_states, actions, log_probs_old, rewards, dones, values):
        """Run the PPO clipped-objective update on one collected rollout.

        Args:
            states: List of state tuples as produced by ``preprocess_state``.
            next_states: List of successor state tuples, aligned with
                ``states``.
            actions: List of action tensors returned by ``act``.
            log_probs_old: List of log-probability tensors returned by ``act``.
            rewards: List of one-element reward tensors.
            dones: List of one-element tensors, 1 where the step ended the
                episode and 0 otherwise.
            values: List of value tensors returned by ``act``.
        """
        # Stack the per-step tuples component-wise into batched tensors.
        states = tuple(torch.cat([s[i] for s in states]) for i in range(4))
        next_states = tuple(torch.cat([ns[i] for ns in next_states]) for i in range(4))

        actions = torch.cat(actions)
        log_probs_old = torch.cat(log_probs_old)
        rewards = torch.cat(rewards)
        dones = torch.cat(dones)
        values = torch.cat(values).squeeze(-1)

        with torch.no_grad():
            _, next_values_all = self.policy(next_states)
            next_values_all = next_values_all.squeeze(-1)

        # Successor values are zeroed where the episode ended.
        next_values = next_values_all * (1 - dones)

        advantages, returns = self.compute_gae(
            rewards=rewards,
            values=values,
            dones=dones,
            next_values=next_values
        )

        # Normalize advantages. With a single sample the standard deviation is
        # NaN, which would turn the loss and then every weight into NaN, so
        # normalization is skipped in that case.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        sample_count = len(actions)
        for _ in range(self.ppo_epochs):
            indices = torch.randperm(sample_count)
            for start in range(0, sample_count, self.mini_batch_size):
                idx = indices[start:start + self.mini_batch_size]

                mini_states = tuple(component[idx] for component in states)
                mini_actions = actions[idx]
                mini_log_probs_old = log_probs_old[idx]
                mini_advantages = advantages[idx]
                mini_returns = returns[idx]

                # Re-evaluate the sampled actions under the current policy.
                dist, new_values = self.policy(mini_states)
                new_log_probs = dist.log_prob(mini_actions)
                entropy = dist.entropy().mean()

                # Clipped surrogate objective.
                ratio = (new_log_probs - mini_log_probs_old).exp()
                surr1 = ratio * mini_advantages
                surr2 = torch.clamp(ratio, 1.0 - self.ppo_clip, 1.0 + self.ppo_clip) * mini_advantages
                actor_loss = -torch.min(surr1, surr2).mean()

                # squeeze(-1) keeps the batch axis even for a one-sample
                # mini-batch, so the shapes always match mini_returns.
                critic_loss = F.mse_loss(new_values.squeeze(-1), mini_returns)
                loss = actor_loss + self.critic_coef * critic_loss - self.entropy_coef * entropy

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.optimizer.step()

        # One scheduler step per rollout update.
        self.scheduler.step()

    def save(self, filename):
        """Write policy, optimizer and scheduler state dicts to ``filename``."""
        torch.save({
            'policy_state_dict': self.policy.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }, filename)

    def load(self, filename):
        """Restore policy, optimizer and scheduler state from ``filename``.

        The checkpoint is mapped onto ``self.device`` so a file written on a
        CUDA machine can be loaded on a CPU-only one.
        """
        checkpoint = torch.load(filename, map_location=self.device)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])


def _hull_speed(env):
    """Return the magnitude of the car hull's linear velocity."""
    velocity = env.unwrapped.car.hull.linearVelocity
    return np.sqrt(velocity[0] ** 2 + velocity[1] ** 2)


def _render_evaluation_figure(env, trajectory, car_angles, view_rectangles, track_coverage):
    """Draw the track, trajectory and view areas; return the PNG as base64 text.

    Args:
        env: Environment whose ``unwrapped.road_poly`` holds
            ``(vertices, color)`` entries for the track polygons.
        trajectory: List of ``(x, y)`` car positions, one per environment step.
        car_angles: Car hull angles aligned with ``trajectory``.
        view_rectangles: Tuples ``(center_x, center_y, angle, a, b)`` aligned
            with ``trajectory``.
        track_coverage: Percentage shown in the figure title.
    """
    green_color = '#62f972'
    view_color = '#8000FF'
    arrow_color = '#FF6A00'
    arrow_interval = 50

    plt.figure(figsize=(9, 8))
    try:
        plt.gca().set_facecolor(green_color)

        # Track polygons: everything is drawn white except pure-red entries.
        for polygon in env.unwrapped.road_poly:
            vertices, color = polygon[0], polygon[1]
            if hasattr(color, '__iter__') and not isinstance(color, tuple):
                color = tuple(color)
            fill_color = '#FFFFFF'
            if isinstance(color, tuple) and len(color) == 3 and color == (255, 0, 0):
                fill_color = '#FF0000'
            x_coords = [v[0] for v in vertices] + [vertices[0][0]]
            y_coords = [v[1] for v in vertices] + [vertices[0][1]]
            plt.fill(x_coords, y_coords, color=fill_color, alpha=1.0)

        # View rectangles at the first, last and every arrow_interval-th step.
        last_rect = len(view_rectangles) - 1
        for idy, rect in enumerate(view_rectangles):
            if idy != last_rect and idy % arrow_interval != 0:
                continue
            # NOTE(review): the tuples are stored as (..., view_width,
            # view_length) but unpacked here as (length, width). The original
            # ordering is kept so the rendered output does not change; confirm
            # which orientation is intended.
            center_x, center_y, angle, length, width = rect

            rect_patch = patches.Rectangle(
                (-length / 2, -width / 2),  # lower-left corner relative to the center
                length,
                width,
                linewidth=0,
                edgecolor='none',
                facecolor=view_color,
                alpha=0.1
            )
            placement = Affine2D().rotate(angle).translate(center_x, center_y) + plt.gca().transData
            rect_patch.set_transform(placement)
            plt.gca().add_patch(rect_patch)

        if trajectory:
            points = np.array(trajectory)
            plt.plot(points[:, 0], points[:, 1], 'k-', linewidth=1, label='Trajectory')

            # Heading arrows at the first, last and every arrow_interval-th step.
            last_point = len(points) - 1
            for i in range(len(points)):
                if i != last_point and i % arrow_interval != 0:
                    continue
                x, y = points[i, 0], points[i, 1]
                angle = car_angles[i] + np.pi / 2
                # NOTE(review): x and y use different scale factors (3 vs 5),
                # which skews the arrow away from the true heading; kept as-is
                # pending confirmation of the intended arrow length.
                dx = np.cos(angle) * 3
                dy = np.sin(angle) * 5

                plt.arrow(x - dx * 0.3, y - dy * 0.3, dx, dy,
                          head_width=3, head_length=4, fc=arrow_color, ec=arrow_color)

        custom_handles = [
            patches.Patch(color=green_color, label='Off-Track Area (Grass)'),
            patches.Patch(color='white', label='Track'),
            patches.Patch(color='red', label='Curbing (indicating the track limits during sharp turns)'),
            patches.Patch(color=view_color, alpha=0.1, label="Agent's Dynamic Visual Field"),
        ]
        plotted_handles, _ = plt.gca().get_legend_handles_labels()

        # Keep the first handle for every distinct label.
        seen_labels = set()
        unique_handles = []
        for handle in custom_handles + plotted_handles:
            label = handle.get_label()
            if label not in seen_labels:
                seen_labels.add(label)
                unique_handles.append(handle)

        plt.title(
            f"Track with Car Trajectory and Corresponding Dynamic View Areas\n"
            f"Track Completion Rate: {track_coverage:.1f} %")
        plt.axis('equal')
        plt.legend(handles=unique_handles)

        buffer = BytesIO()
        plt.savefig(buffer, format="png", bbox_inches='tight')
        buffer.seek(0)
        return base64.b64encode(buffer.read()).decode("utf-8")
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close()


def evaluate_agent(agent, env, env_seeds, max_steps=1200):
    """Run one episode per seed and collect rewards, coverage and a plot.

    Actions are sampled from the agent's policy via ``agent.act``. The
    agent's ``last_decision_frame`` / ``last_decision_action`` attributes are
    overwritten while evaluating.

    Args:
        agent: Agent exposing ``act``, ``preprocess_state``, ``frame_skip``
            and the ``last_decision_*`` attributes.
        env: Environment to evaluate in; it is reset once per seed.
        env_seeds: Iterable of seeds passed to ``env.reset``.
        max_steps: Maximum number of environment steps per episode.

    Returns:
        Dict with keys ``mean_reward``, ``nws`` (mean track coverage in
        percent), ``track_coverages``, ``total_rewards``,
        ``worst_case_image`` (base64 PNG of the lowest-reward episode, or
        ``None`` when no episode ran) and ``episodes_recorder`` (per-seed
        summary keyed by the seed as a string).
    """
    total_rewards = []
    episodes_recorder = {}
    track_coverages = []
    image64s = []

    # Geometry of the rectangle drawn for the agent's field of view.
    view_length = 46.0
    view_width = 38.0
    view_offset = 14.0

    for seed in env_seeds:
        current_frame, _ = env.reset(seed=seed)

        # Before the first decision, the "last decision" is the initial frame
        # with an all-zero action.
        agent.last_decision_frame = current_frame.copy()
        agent.last_decision_action = np.array([0.0, 0.0, 0.0])

        state = agent.preprocess_state(
            current_frame=current_frame,
            speed=_hull_speed(env),
            last_decision_frame=agent.last_decision_frame,
            last_action=agent.last_decision_action
        )

        episode_reward = 0
        trajectory = []
        car_angles = []
        view_rectangles = []
        # Both flags start defined so the summary below is valid even when the
        # loop body never runs.
        terminated = False
        truncated = False
        step = 0

        # Stop on either end signal: stepping an environment after truncation
        # is not part of the Gymnasium contract.
        while not (terminated or truncated) and step < max_steps:
            action_tensor, _, _ = agent.act(state)
            action = action_tensor.squeeze(0).cpu().numpy()

            # Remember the frame and action this decision was based on.
            agent.last_decision_frame = current_frame.copy()
            agent.last_decision_action = action

            # Repeat the action for frame_skip environment steps.
            total_skip_reward = 0
            for _ in range(agent.frame_skip):
                if terminated or truncated or step >= max_steps:
                    break
                current_frame, reward, terminated, truncated, _ = env.step(action)
                step += 1

                car_pos = env.unwrapped.car.hull.position
                car_angle = env.unwrapped.car.hull.angle
                trajectory.append((car_pos.x, car_pos.y))
                car_angles.append(car_angle)

                # The view rectangle is centered view_offset ahead of the car
                # along the hull angle rotated by a quarter turn.
                corrected_angle = car_angle + np.pi / 2
                view_center_x = car_pos.x + np.cos(corrected_angle) * view_offset
                view_center_y = car_pos.y + np.sin(corrected_angle) * view_offset
                view_rectangles.append(
                    (view_center_x, view_center_y, corrected_angle, view_width, view_length))

                total_skip_reward += reward

            state = agent.preprocess_state(
                current_frame=current_frame,
                speed=_hull_speed(env),
                last_decision_frame=agent.last_decision_frame,
                last_action=agent.last_decision_action
            )

            episode_reward += total_skip_reward

        # Share of track tiles visited, in percent.
        track_coverage = env.unwrapped.tile_visited_count / len(env.unwrapped.track) * 100
        total_rewards.append(episode_reward)
        track_coverages.append(track_coverage)

        image64s.append(_render_evaluation_figure(
            env, trajectory, car_angles, view_rectangles, track_coverage))

        episodes_recorder[f'{seed}'] = {
            'seed': seed,
            'episode_reward': episode_reward,
            'terminated': terminated,
            'truncated': truncated,
            'track_coverage': track_coverage,
        }

    mean_reward = np.mean(total_rewards)
    nws = np.mean(track_coverages)
    # argmin raises on an empty list, so only look up the worst episode when
    # at least one was run.
    worst_case_image = image64s[int(np.argmin(total_rewards))] if image64s else None

    return {
        'mean_reward': mean_reward,
        'nws': nws,
        'track_coverages': track_coverages,
        'total_rewards': total_rewards,
        'worst_case_image': worst_case_image,
        'episodes_recorder': episodes_recorder
    }

def plot_best_nws_progress(evaluation_history, save_dir):
    """Plot the running best NWS over evaluations and save it as a PNG.

    Args:
        evaluation_history: Sequence of dicts, each with an ``'episode'``
            entry and a ``'metrics'`` dict containing ``'nws'``.
        save_dir: Directory the image is written to.

    Returns:
        Path of the saved ``best_nws_progression.png`` file.
    """
    best_nws = -float('inf')
    best_nws_history = []
    episodes = []

    # Running maximum: each evaluation contributes the best value seen so far.
    for eval_point in evaluation_history:
        current_nws = eval_point['metrics']['nws']
        if current_nws > best_nws:
            best_nws = current_nws
        best_nws_history.append(best_nws)
        episodes.append(eval_point['episode'])

    plt.figure(figsize=(10, 6))
    try:
        plt.plot(episodes, best_nws_history, 'g-o', linewidth=2, markersize=8)

        # Annotate the first evaluation that reached the overall best value.
        # Skipped for an empty history, where argmax would raise.
        if best_nws_history:
            max_idx = np.argmax(best_nws_history)
            plt.annotate(f'Best NWS: {best_nws_history[max_idx]:.3f}',
                         xy=(episodes[max_idx], best_nws_history[max_idx]),
                         xytext=(0, -30), textcoords='offset points',
                         bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
                         arrowprops=dict(arrowstyle='->'))

        plt.xlabel('Episode')
        plt.ylabel('Best NWS')
        plt.title('Progression of Best Normalized Weighted Score')
        plt.grid(True)

        plot_path = os.path.join(save_dir, 'best_nws_progression.png')
        plt.savefig(plot_path)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close()

    return plot_path

def _run_training_episode(env, agent, seed, max_steps):
    """Roll out one training episode and return its transitions.

    Returns:
        ``(rollout, episode_reward, step)`` where ``rollout`` is a dict whose
        keys match the keyword arguments of ``agent.update``, and ``step`` is
        the number of environment steps taken.
    """
    current_frame, _ = env.reset(seed=seed)
    velocity = env.unwrapped.car.hull.linearVelocity
    speed = np.sqrt(velocity[0] ** 2 + velocity[1] ** 2)

    # Before the first decision, the "last decision" is the initial frame with
    # an all-zero action.
    agent.last_decision_frame = current_frame.copy()
    agent.last_decision_action = np.array([0.0, 0.0, 0.0])

    state = agent.preprocess_state(
        current_frame=current_frame,
        speed=speed,
        last_decision_frame=agent.last_decision_frame,
        last_action=agent.last_decision_action
    )

    rollout = {
        'states': [], 'next_states': [], 'actions': [], 'log_probs_old': [],
        'rewards': [], 'dones': [], 'values': [],
    }
    episode_reward = 0
    terminated = False
    truncated = False
    step = 0

    # Stop on either end signal: stepping an environment after truncation is
    # not part of the Gymnasium contract.
    while not (terminated or truncated) and step < max_steps:
        action_tensor, log_prob, value = agent.act(state)
        action = action_tensor.squeeze(0).cpu().numpy()

        rollout['states'].append(state)
        rollout['actions'].append(action_tensor)
        rollout['log_probs_old'].append(log_prob)
        rollout['values'].append(value)

        # Repeat the action for frame_skip environment steps.
        total_skip_reward = 0
        for _ in range(agent.frame_skip):
            if terminated or truncated or step >= max_steps:
                break
            current_frame, reward, terminated, truncated, _ = env.step(action)
            step += 1
            total_skip_reward += reward

        velocity = env.unwrapped.car.hull.linearVelocity
        next_speed = np.sqrt(velocity[0] ** 2 + velocity[1] ** 2)
        next_state = agent.preprocess_state(
            current_frame=current_frame,
            speed=next_speed,
            last_decision_frame=agent.last_decision_frame,
            last_action=action
        )
        rollout['next_states'].append(next_state)

        rollout['rewards'].append(torch.FloatTensor([total_skip_reward]).to(agent.device))
        # Only true termination disables value bootstrapping; a truncated or
        # step-capped episode still bootstraps from the successor state.
        rollout['dones'].append(torch.FloatTensor([float(terminated)]).to(agent.device))

        episode_reward += total_skip_reward

        # The frame just observed is the one the next decision is based on.
        agent.last_decision_frame = current_frame.copy()
        agent.last_decision_action = action
        state = next_state

    return rollout, episode_reward, step


def _write_training_snapshot(agent, save_dir, training_data, moving_avg_window):
    """Save the latest model, the metrics JSON and the summary plots."""
    agent.save(os.path.join(save_dir, "ppo_model_final.pth"))

    with open(os.path.join(save_dir, 'training_data.json'), 'w') as f:
        json.dump(training_data, f)

    rewards_history = training_data['rewards']
    moving_avg_rewards = training_data['moving_avg_rewards']
    evaluation_history = training_data['evaluation_history']
    eval_episodes = [e['episode'] for e in evaluation_history]

    # Training reward curve with evaluation points marked on the moving average.
    plt.figure(figsize=(12, 6))
    plt.plot(rewards_history, alpha=0.6, label='Episode Reward')
    plt.plot(moving_avg_rewards, 'r-', linewidth=2,
             label=f'Moving Avg ({moving_avg_window} episodes)')
    # moving_avg_rewards[e] is the entry recorded for episode index e.
    plt.scatter(eval_episodes, [moving_avg_rewards[e] for e in eval_episodes],
                c='green', marker='o', label='Evaluation Points')
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('PPO Training Performance with Periodic Evaluation')
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(save_dir, 'training_plot.png'))
    plt.close()

    # Evaluation metrics side by side.
    plt.figure(figsize=(12, 8))

    plt.subplot(1, 2, 1)
    plt.plot(eval_episodes,
             [e['metrics']['mean_reward'] for e in evaluation_history], 'o-')
    plt.xlabel('Episode')
    plt.ylabel('Mean Reward')
    plt.title('Evaluation Mean Reward')
    plt.grid()

    plt.subplot(1, 2, 2)
    plt.plot(eval_episodes,
             [e['metrics']['nws'] for e in evaluation_history], 'o-')
    plt.xlabel('Episode')
    plt.ylabel('Track Coverage')
    plt.title('Track Coverage')
    plt.grid()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'evaluation_metrics.png'))
    plt.close()

    plot_best_nws_progress(evaluation_history, save_dir)


def train_ppo(env, agent, num_episodes=1000, max_steps=1200, save_interval=50, env_seeds=(42,)):
    """Train the PPO agent, evaluating and saving artifacts periodically.

    Each episode uses the next seed from ``env_seeds`` (cycled) and is
    followed by one ``agent.update`` on the collected rollout. Every
    ``save_interval`` episodes, and after the last episode, the agent is
    evaluated on all seeds and the model, metrics JSON and plots in the
    results directory are refreshed.

    Args:
        env: Environment used for both training and evaluation.
        agent: Agent exposing ``act``, ``update``, ``save``,
            ``preprocess_state``, ``frame_skip`` and ``device``.
        num_episodes: Number of training episodes.
        max_steps: Maximum number of environment steps per episode.
        save_interval: Number of episodes between evaluations.
        env_seeds: Seeds cycled through for training and used for evaluation.

    Returns:
        Path of the directory holding the saved models, metrics and plots.
    """
    moving_avg_window = 100
    best_nws = -np.inf
    seed_cycle = itertools.cycle(env_seeds)

    # Training metrics
    rewards_history = []
    episode_lengths = []
    moving_avg_rewards = []
    evaluation_history = []

    # Create save directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = f"results/ppo_carracing_continuous_{timestamp}"
    os.makedirs(save_dir, exist_ok=True)

    start_time = time.time()
    last_eval_time = start_time

    for i_episode in range(num_episodes):
        current_seed = next(seed_cycle)
        rollout, episode_reward, step = _run_training_episode(
            env, agent, current_seed, max_steps)

        if rollout['states']:
            agent.update(**rollout)

        # Record training metrics; slicing handles histories shorter than the
        # window.
        rewards_history.append(episode_reward)
        episode_lengths.append(step)
        moving_avg_rewards.append(np.mean(rewards_history[-moving_avg_window:]))

        print(
            f"Episode {i_episode + 1}/{num_episodes}, Seed: {current_seed}, Reward: {episode_reward:.1f}, Steps: {step}")

        is_final_episode = i_episode == num_episodes - 1
        if i_episode % save_interval != 0 and not is_final_episode:
            continue

        # Periodic evaluation
        current_time = time.time()
        elapsed_time = current_time - start_time
        eval_time = current_time - last_eval_time
        last_eval_time = current_time

        eval_results = evaluate_agent(agent, env, env_seeds, max_steps)
        evaluation_history.append({
            'episode': i_episode,
            'metrics': eval_results,
            'elapsed_time': elapsed_time,
            'eval_time': eval_time
        })

        print(f"\nEvaluation at Episode {i_episode}:")
        print(f"  Mean Reward: {eval_results['mean_reward']:.2f}")
        print(f"  NWS: {eval_results['nws']:.2f}")
        print(f"  Elapsed Time: {elapsed_time:.2f} seconds")
        print(f"  Time Since Last Eval: {eval_time:.2f} seconds")

        # Save best model
        if eval_results['nws'] > best_nws:
            best_nws = eval_results['nws']
            agent.save(os.path.join(save_dir, "ppo_model_best.pth"))
            print(f"  New best model saved! NWS: {best_nws:.1f}")

        # Save the plot of the lowest-reward evaluation episode.
        img_data = base64.b64decode(eval_results['worst_case_image'])
        img_path = os.path.join(save_dir, f"eval_ep{i_episode}_worst.png")
        with open(img_path, 'wb') as f:
            f.write(img_data)

        # Keep a numbered checkpoint every fifth evaluation interval.
        if i_episode % (5 * save_interval) == 0:
            agent.save(os.path.join(save_dir, f"model_ep{i_episode}.pth"))

        # Refresh the latest-model checkpoint, metrics and plots at every
        # evaluation so an interrupted run still leaves usable artifacts.
        end_time = time.time()
        total_time = end_time - start_time
        training_data = {
            'rewards': rewards_history,
            'episode_lengths': episode_lengths,
            'moving_avg_rewards': moving_avg_rewards,
            'evaluation_history': evaluation_history,
            'env_seeds': env_seeds,
            'total_time': total_time,
            'start_time': start_time,
            'end_time': end_time
        }
        _write_training_snapshot(agent, save_dir, training_data, moving_avg_window)

        # Announce completion only once training has actually finished.
        if is_final_episode:
            print(f"Training completed in {total_time:.2f} seconds. Results saved in: {save_dir}")

    return save_dir

def main(run_id):
    """Create the CarRacing environment, train a PPO agent, and clean up.

    Args:
        run_id: Index of the run supplied by the caller; not used by the
            training setup itself.
    """
    # Initialize environment with continuous actions
    env = gym.make('CarRacing-v3', render_mode='rgb_array', domain_randomize=False, continuous=True)

    # Close the environment even if agent construction or training raises.
    try:
        agent = PPOCarRacingAgent(env, frame_skip=1)

        train_ppo(
            env=env,
            agent=agent,
            num_episodes=8000,
            max_steps=1200,
            save_interval=30,
            env_seeds=(42, 1231, 516, 413)
        )
    finally:
        env.close()


if __name__ == "__main__":
    # Script entry point: launch the training runs one after another,
    # passing each run its zero-based index.
    run_total = 1
    for run_i in range(run_total):
        main(run_i)