# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_procgenpy
import math
import os
import random
import sys
import time
from dataclasses import dataclass
from typing import Optional

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from procgen import ProcgenEnv
from torch import vmap
from torch.amp import autocast
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
# allow importing local modules (e.g. poly/, utils.py) when running this file directly
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from poly.wd_regularization_torch import polynomial_regularization, precompute_chebyshev_matrix
from utils import pca


@dataclass
class Args:
    """Command-line configuration for PPO on Procgen with polynomial (wd) regularization.

    Parsed by ``tyro.cli``; the string under each field is its CLI help text.
    """

    exp_name: str = 'lambda0.001_warmup0.25_seed0_resoluton4_degree3_random_onlyactor_new'
    """the name of this experiment"""
    seed: int = 0
    """seed of the experiment"""
    torch_deterministic: bool = True
    """value assigned to `torch.backends.cudnn.deterministic`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: Optional[str] = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""

    # Device (aligned with options.py)
    gpu_id: str = "2"
    """which gpu to use (used as `cuda:<gpu_id>`)"""

    # Algorithm specific arguments
    env_id: str = "starpilot"
    """the id of the environment"""
    total_timesteps: int = int(25e6)
    """total timesteps of the experiments"""
    learning_rate: float = 5e-4
    """the learning rate of the optimizer"""
    num_envs: int = 64
    """the number of parallel game environments"""
    num_steps: int = 256
    """the number of steps to run in each environment per policy rollout"""
    anneal_lr: bool = False
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.999
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for the general advantage estimation"""
    num_minibatches: int = 8
    """the number of mini-batches"""
    update_epochs: int = 3
    """the K epochs to update the policy"""
    norm_adv: bool = True
    """Toggles advantages normalization"""
    clip_coef: float = 0.2
    """the surrogate clipping coefficient"""
    clip_vloss: bool = True
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    ent_coef: float = 0.01
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    target_kl: Optional[float] = None
    """the target KL divergence threshold (None disables KL early stopping)"""

    distribution_mode: str = "easy"
    """procgen distribution mode, shared by the training and test environments"""

    # --- Generalization & test configuration ---
    num_train_levels: int = 500        # Training set limited to 500 levels
    start_level: int = 0               # Training set starts from seed 0
    num_test_levels: int = 0           # 0 means unlimited/random generation (for testing unseen levels)
    test_start_level: int = 500        # Test set starts from seed 500, so it does not overlap training levels 0-499
    test_interval: int = 10            # Run a zero-shot evaluation every N training iterations (not env steps)
    num_test_episodes: int = 10        # NOTE(review): not read by the evaluation loop, which runs a fixed number of steps

    # Regularization
    use_l2: bool = False               # L2 regularization (applied through Adam's weight_decay)
    l2_weight_decay: float = 1e-5      # L2 weight decay coefficient
    use_batchnorm: bool = False        # NOTE(review): not read by Agent; currently has no effect

    # === WD regularization args (mirrors poly-fit/options.py) ===

    min_lambda_reg: float = 0
    """kept for compatibility; not used by default in PPO"""
    max_degree: int = 3
    """Chebyshev max degree"""
    miu: float = 0.0
    """kept for compatibility with polynomial_regularization signature"""
    use_norm: bool = False
    """use norm in wd computation"""
    nums_pairs: int = 1024
    """number of pairs sampled per minibatch for reg (capped at minibatch_size // 2)"""
    remove_const: bool = False
    """if True, remove 0-order term in reg"""
    label: bool = True
    """(deprecated) kept for CLI compatibility; endpoints are no longer hard-labeled"""
    smooth: bool = False
    """if True, smooth sequence along resolution (NOTE: not applied by compute_ppo_reg_loss)"""
    random_alpha: bool = True
    """if True, use random cosine sampling for alpha per pair"""
    pca_reg: int = 10
    """if >0, apply PCA (first component) to actor sequence"""
    square: bool = False
    """use square instead of abs in wd"""
    degree_mode: str = "index"
    """degrees vector, index means 0,1,2,... else means 0,0,1,1,1,..."""
    resolution: int = 4
    """number of points along the interpolation path (includes endpoints; must be >= 3)"""

    lambda_reg_actor: float = 0.001
    """actor reg coefficient"""
    lambda_reg_critic: float = 0.002
    """critic reg coefficient"""

    # Fraction of num_iterations used as the schedule boundary by adjust_lambda_reg_sin.
    lambda_reg_actor_warmup: float = 0.3333
    lambda_reg_actor_warm_type: str = "up"  # up / down

    lambda_reg_critic_warmup: float = 0.6667
    lambda_reg_critic_warm_type: str = "down"  # up / down

    # to be filled in runtime
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise ``layer.weight`` (gain ``std``) and fill ``layer.bias``.

    The layer is modified in place and also returned, so the call can wrap a
    constructor expression directly.
    """
    nn.init.orthogonal_(layer.weight, gain=std)
    nn.init.constant_(layer.bias, val=bias_const)
    return layer


# taken from https://github.com/AIcrowd/neurips2020-procgen-starter-kit/blob/142d09586d2272a17f44481a115c4bd817cf6a94/models/impala_cnn_torch.py
class ResidualBlock(nn.Module):
    """Two ReLU -> 3x3 conv stages (same padding, fixed channel count) plus an identity skip."""

    def __init__(self, channels):
        super().__init__()
        same_conv = dict(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)
        self.conv0 = nn.Conv2d(**same_conv)
        self.conv1 = nn.Conv2d(**same_conv)

    def forward(self, x):
        # activation precedes each conv; the untouched input is added back at the end
        branch = self.conv0(nn.functional.relu(x))
        branch = self.conv1(nn.functional.relu(branch))
        return branch + x


class ConvSequence(nn.Module):
    """IMPALA-style stage: 3x3 conv, 3x3/stride-2 max-pool, then two ResidualBlocks.

    ``input_shape`` is ``(C, H, W)``; the output has ``out_channels`` channels and
    spatial size ``ceil(H / 2) x ceil(W / 2)``.
    """

    def __init__(self, input_shape, out_channels):
        super().__init__()
        self._input_shape = input_shape
        self._out_channels = out_channels
        in_channels = input_shape[0]
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
        self.res_block0 = ResidualBlock(out_channels)
        self.res_block1 = ResidualBlock(out_channels)

    def forward(self, x):
        pooled = nn.functional.max_pool2d(self.conv(x), kernel_size=3, stride=2, padding=1)
        out = self.res_block1(self.res_block0(pooled))
        assert out.shape[1:] == self.get_output_shape()
        return out

    def get_output_shape(self):
        """Return ``(out_channels, ceil(H / 2), ceil(W / 2))`` for the configured input shape."""
        _, height, width = self._input_shape
        return (self._out_channels, (height + 1) // 2, (width + 1) // 2)


class Agent(nn.Module):
    """Shared IMPALA-CNN trunk feeding a categorical policy head and a scalar value head.

    Observations are expected channels-last (``[B, H, W, C]``) with pixel values in
    [0, 255]; they are permuted to channels-first and scaled to [0, 1] before the trunk.
    """

    def __init__(self, envs):
        super().__init__()
        height, width, channels = envs.single_observation_space.shape
        feature_shape = (channels, height, width)
        layers = []
        for depth in (16, 32, 32):
            stage = ConvSequence(feature_shape, depth)
            feature_shape = stage.get_output_shape()
            layers.append(stage)
        flat_features = feature_shape[0] * feature_shape[1] * feature_shape[2]
        layers.extend(
            [
                nn.Flatten(),
                nn.ReLU(),
                nn.Linear(in_features=flat_features, out_features=256),
                nn.ReLU(),
            ]
        )
        self.network = nn.Sequential(*layers)
        self.actor = layer_init(nn.Linear(256, envs.single_action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(256, 1), std=1)

    def _encode(self, x):
        """Run the shared trunk on a "bhwc" batch (converted to "bchw", scaled by 1/255)."""
        return self.network(x.permute((0, 3, 1, 2)) / 255.0)

    def get_value(self, x):
        return self.critic(self._encode(x))

    def get_action_and_value(self, x, action=None):
        """Return (action, log_prob(action), entropy, value); samples an action when none is given."""
        hidden = self._encode(x)
        dist = Categorical(logits=self.actor(hidden))
        if action is None:
            action = dist.sample()
        return action, dist.log_prob(action), dist.entropy(), self.critic(hidden)



def adjust_lambda_reg_sin(lambda_reg, iteration, args, lambda_reg_warm_type, lambda_reg_warmup):
    """Return ``lambda_reg`` scaled by a quarter-period sine/cosine schedule.

    The schedule boundary is ``boundary = args.num_iterations * lambda_reg_warmup``.

    - ``"up"``: for ``iteration < boundary`` the coefficient is
      ``lambda_reg * sin(pi/2 * iteration / boundary)``, rising toward ``lambda_reg``;
      from ``boundary`` on it is ``lambda_reg``.
    - any other value (intended: ``"down"``): the coefficient is ``lambda_reg``
      until ``boundary``, then ``lambda_reg * cos(pi/2 * progress)`` where
      ``progress`` goes from 0 at ``boundary`` to 1 at ``args.num_iterations``
      (clamped to 1 so the coefficient never becomes negative).

    Args:
        lambda_reg: base coefficient; values <= 0 are returned unchanged (reg disabled).
        iteration: current training iteration (the caller counts from 1).
        args: object providing ``num_iterations``.
        lambda_reg_warm_type: ``"up"`` or ``"down"``.
        lambda_reg_warmup: fraction of ``num_iterations`` at which the schedule switches.

    Returns:
        The scheduled coefficient.
    """
    if lambda_reg <= 0:
        return lambda_reg

    total = args.num_iterations
    boundary = total * lambda_reg_warmup

    if lambda_reg_warm_type == "up":
        # Sine ramp-up: angle goes 0 -> pi/2, so the scale goes 0 -> 1.
        if 0 < boundary and iteration < boundary:
            return lambda_reg * math.sin((iteration / boundary) * math.pi / 2)
        return lambda_reg

    # "down" (also the fallback for any other warm type, matching the previous behaviour)
    if iteration < boundary:
        return lambda_reg
    decay_span = total * (1 - lambda_reg_warmup)
    if decay_span <= 0:
        # warmup >= 1 leaves no room for a ramp-down; the old code divided by zero here.
        return lambda_reg
    # Cosine ramp-down: angle goes 0 -> pi/2, so the scale goes 1 -> 0.
    progress = min((iteration - boundary) / decay_span, 1.0)
    return lambda_reg * math.cos(progress * math.pi / 2)

def _sample_full_alpha(args: Args, num_pairs: int, device: torch.device) -> torch.Tensor:
    """Return interpolation coordinates ``alpha`` in [-1, 1] (float32 on ``device``).

    - ``args.random_alpha`` true: stratified cosine sampling. The range [0, pi] is
      split into ``resolution`` equal bins, one angle is drawn uniformly inside each
      bin per pair, and ``alpha = -cos(theta)``. The result has shape
      ``[num_pairs, resolution]`` and is non-decreasing along the last axis.
    - otherwise: the fixed ``"alpha_values"`` entry returned by
      ``precompute_chebyshev_matrix``, presumably of shape ``[resolution]``.

    Raises:
        ValueError: if ``args.resolution < 3``.
    """
    resolution = int(args.resolution)
    if resolution < 3:
        raise ValueError(f"args.resolution must be >= 3, got {resolution}")

    if not args.random_alpha:
        table = precompute_chebyshev_matrix(resolution, int(args.max_degree), device=str(device))
        return table["alpha_values"].to(device=device, dtype=torch.float32)

    step = float(np.pi) / float(resolution)
    bin_starts = torch.arange(resolution, device=device, dtype=torch.float32).view(1, -1)  # [1, R]
    offsets = torch.rand((num_pairs, resolution), device=device, dtype=torch.float32)  # [P, R]
    angles = (bin_starts + offsets) * step  # [P, R], bin k covers [k*step, (k+1)*step)
    return (-torch.cos(angles)).to(dtype=torch.float32)


def _agent_actor_probs_and_value(agent: Agent, x: torch.Tensor):
    """Run ``agent``'s trunk once and return (softmax action probs [B, A], value [B, 1]).

    ``x`` is a channels-last batch ``[B, H, W, C]`` with values in [0, 255]; unlike
    ``Agent.get_action_and_value`` this returns differentiable probabilities
    instead of a Categorical sample.
    """
    features = agent.network(x.permute((0, 3, 1, 2)) / 255.0)
    return torch.softmax(agent.actor(features), dim=-1), agent.critic(features)


def compute_ppo_reg_loss(
    args: Args,
    agent: Agent,
    batch_obs: torch.Tensor,
    batch_actions: torch.Tensor,
):
    """Compute the polynomial (wd) regularization terms for actor and critic on a minibatch.

    Procedure:
    - Draw ``num_pairs = min(args.nums_pairs, B // 2)`` disjoint observation pairs
      from a single random permutation of the minibatch.
    - Place ``args.resolution`` points (endpoints included) on the segment between
      each pair.
      - With ``args.random_alpha``, per-pair coordinates come from stratified cosine
        sampling and the mixing weight is ``(alpha + 1) / 2``, which lies in [0, 1].
      - Otherwise the fixed coordinates from ``precompute_chebyshev_matrix`` are mapped
        with ``-0.1 + (alpha + 1) * 0.6``. For alpha in [-1, 1] this spans [-0.1, 1.1],
        so the path slightly extrapolates past both endpoints.
    - Run every path point through the shared trunk to get actor softmax probabilities
      ``[P, R, A]`` and critic values ``[P, R, 1]``. The pass is not under ``no_grad``,
      so the terms are differentiable w.r.t. the agent.
    - If ``args.pca_reg > 0``, project the actor sequence with ``utils.pca``.
    - Score each pair's sequence with ``polynomial_regularization`` (vmapped over pairs).
      The alphas passed to it are the raw [-1, 1] coordinates, not the mixing weights.

    Args:
        args: run configuration. Reads nums_pairs, resolution, random_alpha, pca_reg,
            remove_const, miu, max_degree, use_norm, square, degree_mode.
        agent: the PPO agent.
        batch_obs: minibatch observations ``[B, H, W, C]`` with values in [0, 255].
        batch_actions: unused; kept so existing call sites keep working.

    Returns:
        ``(reg_actor, reg_critic)``: the means of the per-pair terms. Both are a zero
        tensor when the minibatch holds fewer than two observations.

    NOTE(review): ``args.smooth`` is exposed by Args but is not applied here.
    """
    device = batch_obs.device
    batch_size = batch_obs.shape[0]
    num_pairs = int(min(args.nums_pairs, batch_size // 2))
    if num_pairs <= 0:
        zero = torch.tensor(0.0, device=device)
        return zero, zero

    resolution = int(args.resolution)
    perm = torch.randperm(batch_size, device=device)
    idx1 = perm[:num_pairs]
    idx2 = perm[num_pairs : 2 * num_pairs]

    x1 = batch_obs[idx1].to(torch.float32)
    x2 = batch_obs[idx2].to(torch.float32)

    # alpha_full: [P, R] when random_alpha (one path per pair), else a shared [R] vector.
    alpha_full = _sample_full_alpha(args, num_pairs=num_pairs, device=device)
    if args.random_alpha:
        alpha_mix_view = ((alpha_full + 1.0) * 0.5).view(num_pairs, resolution, 1, 1, 1)
        alpha_in_dims = 0  # vmap over the per-pair alpha rows
    else:
        alpha_mix = -0.1 + (alpha_full + 1.0) * 0.6
        alpha_mix_view = alpha_mix.view(1, resolution, 1, 1, 1).expand(num_pairs, -1, -1, -1, -1)
        alpha_in_dims = None  # the same alpha vector is shared by every pair

    # full interpolation path (including endpoints); every point goes through the network
    full_samples = x1.unsqueeze(1) + alpha_mix_view * (x2.unsqueeze(1) - x1.unsqueeze(1))  # [P,R,H,W,C]
    full_samples_flat = full_samples.reshape((-1,) + batch_obs.shape[1:])

    actor_probs_flat, values_flat = _agent_actor_probs_and_value(agent, full_samples_flat)
    num_actions = actor_probs_flat.shape[-1]
    actor_sequence = actor_probs_flat.view(num_pairs, resolution, num_actions)  # [P,R,A]
    critic_sequence = values_flat.view(num_pairs, resolution, 1)  # [P,R,1]

    if int(getattr(args, "pca_reg", 0) or 0) > 0:
        actor_sequence = pca(actor_sequence, num_pairs, int(args.pca_reg))  # presumably [P,R,1]

    have_const = not args.remove_const

    def reg_wrapper(alpha, sample_output):
        # Per-pair regularizer; positional order follows polynomial_regularization's signature.
        return polynomial_regularization(
            alpha,
            sample_output,
            resolution,
            args.miu,
            int(args.max_degree),
            have_const,
            bool(args.use_norm),
            bool(args.random_alpha),
            bool(args.square),
            str(args.degree_mode),
        )

    batched_reg_func = vmap(reg_wrapper, in_dims=(alpha_in_dims, 0))
    reg_actor_terms = batched_reg_func(alpha_full, actor_sequence)
    reg_critic_terms = batched_reg_func(alpha_full, critic_sequence)

    return torch.mean(reg_actor_terms), torch.mean(reg_critic_terms)


def run_zero_shot_eval(agent, test_envs, device, t_obs, fresh, num_steps):
    """Roll the current policy in the held-out test envs for ``num_steps`` vector steps.

    The test envs are kept running between evaluations instead of being reset each
    time. ``test_envs.reset()`` is only called on the first evaluation (``t_obs is None``).
    NOTE(review): ProcgenEnv wraps a gym3 env whose baselines adapter does not restart
    episodes that are already running on ``reset()``. RecordEpisodeStatistics, however,
    zeroes its counters on reset. Resetting every evaluation therefore truncated the
    first episodes of every later evaluation.

    An episode that was already running when an evaluation started was partly played
    by an older policy. Such episodes are excluded via the per-env ``fresh`` mask:
    ``fresh[i]`` is True when env ``i``'s current episode started inside an evaluation
    window.

    Args:
        agent: object with ``get_action_and_value(obs) -> (action, ...)``.
        test_envs: vector env whose ``step`` returns ``(obs, reward, done, infos)`` with
            ``infos`` a per-env list; finished episodes carry ``info["episode"]``.
        device: torch device for observations.
        t_obs: last test observation tensor, or ``None`` before the first evaluation.
        fresh: bool array of shape ``[num_envs]`` (all True before the first evaluation).
        num_steps: number of vector steps to run.

    Returns:
        ``(episode_returns, episode_lengths, t_obs, fresh)``. Pass the updated
        ``t_obs`` and ``fresh`` to the next call.
    """
    if t_obs is None:
        t_obs = torch.Tensor(test_envs.reset()).to(device)
    fresh = np.array(fresh, dtype=bool)  # copy: never mutate the caller's mask
    episode_returns, episode_lengths = [], []
    last_done = None
    for _ in range(num_steps):
        with torch.no_grad():
            t_action, _, _, _ = agent.get_action_and_value(t_obs)
        next_obs_np, _, t_done, t_infos = test_envs.step(t_action.cpu().numpy())
        t_obs = torch.Tensor(next_obs_np).to(device)
        last_done = np.asarray(t_done, dtype=bool).reshape(-1)

        for env_idx, item in enumerate(t_infos):
            if "episode" in item and fresh[env_idx]:
                episode_returns.append(item["episode"]["r"])
                episode_lengths.append(item["episode"]["l"])
        # An env that just finished starts its next episode inside this window.
        fresh[last_done] = True

    if last_done is not None:
        # Envs still mid-episode will resume that (now stale) episode next time.
        fresh = last_done.copy()
    return episode_returns, episode_lengths, t_obs, fresh


if __name__ == "__main__":
    args = tyro.cli(Args)
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size
    # Print all experiment parameters
    print("Experiment Arguments:")
    for arg, value in vars(args).items():
        print(f"  {arg}: {value}")
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    if torch.cuda.is_available() and args.cuda:
        device = torch.device(f"cuda:{args.gpu_id}")
    else:
        device = torch.device("cpu")

    # env setup: training uses levels [start_level, start_level + num_train_levels)
    envs = ProcgenEnv(
        num_envs=args.num_envs,
        env_name=args.env_id,
        num_levels=args.num_train_levels,  # 500
        start_level=args.start_level,      # 0
        distribution_mode=args.distribution_mode,
        rand_seed=args.seed,
    )

    envs = gym.wrappers.TransformObservation(envs, lambda obs: obs["rgb"])
    envs.single_action_space = envs.action_space
    envs.single_observation_space = envs.observation_space["rgb"]
    envs.is_vector_env = True
    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    if args.capture_video:
        envs = gym.wrappers.RecordVideo(envs, f"videos/{run_name}")
    envs = gym.wrappers.NormalizeReward(envs, gamma=args.gamma)
    envs = gym.wrappers.TransformReward(envs, lambda reward: np.clip(reward, -10, 10))
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    # ================= Test environment (zero-shot evaluation) =================
    # Starts at test_start_level so it does not overlap the training levels.
    test_envs = ProcgenEnv(
        num_envs=args.num_envs,
        env_name=args.env_id,
        num_levels=args.num_test_levels,   # 0 means unlimited/random generation (for testing unseen levels)
        start_level=args.test_start_level, # 500 (avoiding training set 0-499)
        distribution_mode=args.distribution_mode,
        rand_seed=args.seed + 1000,  # Different random seed from training environment
    )
    test_envs = gym.wrappers.TransformObservation(test_envs, lambda obs: obs["rgb"])
    test_envs.single_action_space = test_envs.action_space
    test_envs.single_observation_space = test_envs.observation_space["rgb"]
    test_envs.is_vector_env = True
    test_envs = gym.wrappers.RecordEpisodeStatistics(test_envs)
    # No NormalizeReward here: we want raw scores. RecordEpisodeStatistics sits inside the
    # clipping wrapper below, so the logged episode returns are unclipped.
    test_envs = gym.wrappers.TransformReward(test_envs, lambda reward: np.clip(reward, -10, 10))
    # Evaluation state carried across evaluations (see run_zero_shot_eval).
    test_obs = None
    test_fresh = np.ones(args.num_envs, dtype=bool)

    agent = Agent(envs).to(device)
    optimizer = optim.Adam(
        agent.parameters(),
        lr=args.learning_rate,
        eps=1e-5,
        weight_decay=args.l2_weight_decay if args.use_l2 else 0.0,
    )

    # ALGO Logic: Storage setup
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs = torch.Tensor(envs.reset()).to(device)
    next_done = torch.zeros(args.num_envs).to(device)

    for iteration in range(1, args.num_iterations + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (iteration - 1.0) / args.num_iterations
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        for step in range(0, args.num_steps):
            global_step += args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, next_done, info = envs.step(action.cpu().numpy())
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

            for item in info:
                if "episode" in item.keys():
                    print(f"global_step={global_step}, episodic_return={item['episode']['r']}")
                    writer.add_scalar("charts/episodic_return", item["episode"]["r"], global_step)
                    writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step)
                    break

        # bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)
        # Scheduled regularization coefficients for this iteration.
        lambda_reg_actor = adjust_lambda_reg_sin(args.lambda_reg_actor, iteration, args, args.lambda_reg_actor_warm_type, args.lambda_reg_actor_warmup)
        lambda_reg_critic = adjust_lambda_reg_sin(args.lambda_reg_critic, iteration, args, args.lambda_reg_critic_warm_type, args.lambda_reg_critic_warmup)
        # Last regularization values computed in this iteration (None if none were computed).
        reg_actor = reg_critic = None
        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                # Polynomial (wd) regularization on an interpolation path between minibatch observations.
                if (lambda_reg_actor > 0) or (lambda_reg_critic > 0):
                    reg_actor, reg_critic = compute_ppo_reg_loss(
                        args,
                        agent,
                        b_obs[mb_inds],
                        b_actions.long()[mb_inds],
                    )
                    if lambda_reg_actor > 0:
                        loss = loss + lambda_reg_actor * reg_actor
                    if lambda_reg_critic > 0:
                        loss = loss + lambda_reg_critic * reg_critic

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            if args.target_kl is not None and approx_kl > args.target_kl:
                break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        if (args.lambda_reg_actor > 0) or (args.lambda_reg_critic > 0):
            writer.add_scalar("losses/lambda_reg_actor", float(lambda_reg_actor), global_step)
            writer.add_scalar("losses/lambda_reg_critic", float(lambda_reg_critic), global_step)
            # Values from the last regularized minibatch of this iteration; skipped (rather than
            # logging a stale value from an earlier iteration) when no minibatch was regularized.
            if reg_actor is not None:
                writer.add_scalar("losses/reg_actor", float(reg_actor.detach().cpu().item()), global_step)
                writer.add_scalar("losses/reg_critic", float(reg_critic.detach().cpu().item()), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)

        # ================= Zero-shot testing and TensorBoard logging =================
        # test_interval counts training iterations; a value <= 0 disables evaluation.
        if args.test_interval > 0 and iteration % args.test_interval == 0:
            agent.eval()  # Switch to evaluation mode (affects BatchNorm / Dropout etc.)
            print(f"Starting evaluation at global step {global_step}...")

            # Procgen levels vary in length, so use enough steps to collect several episodes.
            NUM_TEST_STEPS = 4096
            test_episode_rewards, test_episode_lengths, test_obs, test_fresh = run_zero_shot_eval(
                agent, test_envs, device, test_obs, test_fresh, NUM_TEST_STEPS
            )

            # Calculate averages and log to TensorBoard
            if len(test_episode_rewards) > 0:
                mean_test_return = np.mean(test_episode_rewards)
                mean_test_length = np.mean(test_episode_lengths)

                print(f"Evaluated {len(test_episode_rewards)} episodes.")
                print(f"Mean Test Return: {mean_test_return:.2f}, Mean Test Length: {mean_test_length:.2f}")

                writer.add_scalar("charts/test_episodic_return", mean_test_return, global_step)
                writer.add_scalar("charts/test_episodic_length", mean_test_length, global_step)
            else:
                print("Evaluation finished but no episodes completed (increase NUM_TEST_STEPS).")

            agent.train()  # Switch back to training mode
        # ==========================================================================

        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    envs.close()
    test_envs.close()
    writer.close()
