import os
import random
import time
from collections import deque
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from einops import rearrange
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter
from environments.nasim_env import NASimWrapper


@dataclass
class Args:
    """Command-line configuration for this training script.

    An instance is created with ``tyro.cli(Args)`` in the ``__main__`` section.
    The bare string that follows each field describes that field. The fields
    under "To be filled on runtime" are overwritten by the script right after
    parsing, and ``trxl_memory_length`` is additionally capped at the
    environment's maximum episode length before the agent is built.
    """

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 4444
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    # NOTE: the default is None even though the annotation is plain `str`.
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    save_model: bool = False
    """whether to save model into the `runs/{run_name}` folder"""

    # Algorithm specific arguments
    env_id: str = "StochPO-v0"
    """the id of the environment"""
    total_timesteps: int = 1000000
    """total timesteps of the experiments"""
    num_evals: int = 1
    """Number of policy evaluations to perform during training."""
    num_eval_episodes: int = 10
    """Number of episodes to run during each policy evaluation."""
    eval_freq: int = 100000
    """How many global step counts between evaluations."""
    init_lr: float = 2.75e-4
    """the initial learning rate of the optimizer"""
    final_lr: float = 1.0e-5
    """the final learning rate of the optimizer after linearly annealing"""
    num_envs: int = 1
    """the number of parallel game environments"""
    num_eval_envs: int = 1
    """the number of parallel game environments for evaluation"""
    num_steps: int = 256
    """the number of steps to run in each environment per policy rollout"""
    min_num_hosts: int = 5
    """Minimum number of hosts in NASim scenarios"""
    max_num_hosts: int = 8
    """Maximum number of hosts in NASim scenarios"""
    # NOTE: this default is evaluated once, while the class body executes, from the
    # *default* values of num_envs and num_steps declared above
    # (1 * 256 * 200 = 51_200). Overriding num_envs or num_steps on the command
    # line does not change it; pass anneal_steps explicitly in that case.
    # NOTE(review): 51_200 is roughly 5% of the default total_timesteps, not the
    # 82% mentioned in the inline comment -- confirm the intended value.
    anneal_steps: int = num_envs * num_steps * 200 # We use approx 82% of the total timesteps for annealing
    """the number of steps to linearly anneal the learning rate and entropy coefficient from initial to final"""
    gamma: float = 0.99
    """the discount factor gamma"""
    gae_lambda: float = 0.99
    """the lambda for the general advantage estimation"""
    num_minibatches: int = 8
    """the number of mini-batches"""
    update_epochs: int = 4
    """the K epochs to update the policy"""
    norm_adv: bool = False
    """Toggles advantages normalization"""
    clip_coef: float = 0.2
    """the surrogate clipping coefficient"""
    clip_vloss: bool = True
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    init_ent_coef: float = 0.0001
    """initial coefficient of the entropy bonus"""
    final_ent_coef: float = 0.000001
    """final coefficient of the entropy bonus after linearly annealing"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.25
    """the maximum norm for the gradient clipping"""
    # NOTE: the default is None even though the annotation is plain `float`.
    target_kl: float = None
    """the target KL divergence threshold"""

    # Transformer-XL specific arguments
    trxl_num_layers: int = 2
    """the number of transformer layers"""
    # The transformer dimension must be divisible by the number of heads
    # (asserted in MultiHeadAttention).
    trxl_num_heads: int = 1
    """the number of heads used in multi-head attention"""
    trxl_dim: int = 384
    """the dimension of the transformer"""
    trxl_memory_length: int = 256
    """the length of TrXL's sliding memory window"""
    # Any value other than "absolute" or "learned" disables positional encoding
    # (see Transformer.__init__ / Transformer.forward).
    trxl_positional_encoding: str = "learned"
    """the positional encoding type of the transformer, choices: "", "absolute", "learned" """
    reconstruction_coef: float = 0.0
    """the coefficient of the observation reconstruction loss, if set to 0.0 the reconstruction loss is not used"""

    # To be filled on runtime
    # batch_size = num_envs * num_steps
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    # minibatch_size = batch_size // num_minibatches
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    # num_iterations = total_timesteps // batch_size
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""


def make_env(env_id, min_num_hosts, max_num_hosts, seed=None, render_mode=None):
    """Return a zero-argument factory that builds one wrapped NASim environment.

    The returned callable is what ``gym.vector.SyncVectorEnv`` expects: nothing
    is constructed until it is invoked.

    Args:
        env_id: Environment id forwarded to ``NASimWrapper``.
        min_num_hosts: Forwarded to ``NASimWrapper`` as ``min_num_hosts``.
        max_num_hosts: Forwarded to ``NASimWrapper`` as ``max_num_hosts``.
        seed: Forwarded to ``NASimWrapper`` as its second positional argument.
        render_mode: Forwarded to ``NASimWrapper`` as ``render_mode``.

    Returns:
        A callable creating the environment wrapped in
        ``gym.wrappers.RecordEpisodeStatistics``.
    """
    wrapper_kwargs = {
        "render_mode": render_mode,
        "min_num_hosts": min_num_hosts,
        "max_num_hosts": max_num_hosts,
    }

    def thunk():
        base_env = NASimWrapper(env_id, seed, **wrapper_kwargs)
        # Episode statistics are picked up from `infos` by the training loop.
        return gym.wrappers.RecordEpisodeStatistics(base_env)

    return thunk


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise ``layer.weight`` in place and return the layer.

    Args:
        layer: Module exposing a ``weight`` tensor (e.g. ``nn.Linear``, ``nn.Conv2d``).
        std: Gain passed to the orthogonal initialiser.
        bias_const: Accepted for signature compatibility only; the bias is left
            at whatever value the layer was constructed with.

    Returns:
        The same ``layer`` object, to allow inline use when building modules.
    """
    nn.init.orthogonal_(layer.weight, gain=std)
    return layer


def batched_index_select(input, dim, index):
    """Gather entries along ``dim`` using a separate index list per batch item.

    ``index`` gets a singleton axis inserted at every non-batch position other
    than ``dim`` and is then broadcast to ``input``'s shape on those axes, so
    that ``torch.gather`` selects whole slices.

    Args:
        input: Tensor whose first axis is the batch axis.
        dim: Axis of ``input`` to select along.
        index: Integer tensor of positions; in this file it is passed as
            ``(batch, num_selected)``.

    Returns:
        Tensor shaped like ``input`` except that axis ``dim`` has the length of
        ``index``'s selection axis.
    """
    for axis in range(1, input.dim()):
        if axis == dim:
            continue
        index = index.unsqueeze(axis)

    # -1 keeps the index tensor's own size on the batch axis and on `dim`.
    target_shape = list(input.shape)
    target_shape[0] = -1
    target_shape[dim] = -1
    return input.gather(dim, index.expand(target_shape))


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding over relative (reversed) positions.

    Args:
        dim: Width of the encoding. With the default ``min_timescale`` of 2.0
            the output has ``dim`` columns when ``dim`` is even.
        min_timescale: Step used when enumerating frequency indices
            ``0, step, 2 * step, ...`` below ``dim``.
        max_timescale: Base of the geometric frequency progression.
    """

    def __init__(self, dim, min_timescale=2.0, max_timescale=1e4):
        super().__init__()
        freqs = torch.arange(0, dim, min_timescale)
        inv_freqs = max_timescale ** (-freqs / dim)
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer("inv_freqs", inv_freqs)

    def forward(self, seq_len):
        """Return the encoding table for ``seq_len`` positions.

        Args:
            seq_len: Number of positions to encode.

        Returns:
            Tensor of shape ``(seq_len, 2 * len(inv_freqs))``. Row ``r`` encodes
            position ``seq_len - 1 - r``; the sine half comes first, then the
            cosine half.
        """
        # Build the positions on the buffer's device; relying on the default
        # device breaks as soon as the module has been moved elsewhere.
        seq = torch.arange(seq_len - 1, -1, -1.0, device=self.inv_freqs.device)
        sinusoidal_inp = seq.unsqueeze(1) * self.inv_freqs.unsqueeze(0)
        return torch.cat((sinusoidal_inp.sin(), sinusoidal_inp.cos()), dim=-1)


class MultiHeadAttention(nn.Module):
    """Dropout-free multi-head attention.

    Inspired by https://github.com/aladdinpersson/Machine-Learning-Collection.
    The three input projections act on a single head slice and are shared by
    all heads; the scores are scaled by ``sqrt(embed_dim)``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_size = embed_dim // num_heads

        assert embed_dim == num_heads * self.head_size, "Embedding dimension needs to be divisible by the number of heads"

        # Per-head projections (bias-free), followed by the output projection.
        self.values = nn.Linear(self.head_size, self.head_size, bias=False)
        self.keys = nn.Linear(self.head_size, self.head_size, bias=False)
        self.queries = nn.Linear(self.head_size, self.head_size, bias=False)
        self.fc_out = nn.Linear(self.num_heads * self.head_size, embed_dim)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys`` / ``values``.

        Args:
            values: Tensor reshaped to ``(N, value_len, heads, head_size)``.
            keys: Tensor reshaped to ``(N, key_len, heads, head_size)``.
            query: Tensor reshaped to ``(N, query_len, heads, head_size)``.
            mask: Optional tensor whose first axis is the batch; entries equal
                to 0 mark key positions to suppress. ``None`` disables masking.

        Returns:
            Tuple ``(output, attention)`` with ``output`` of shape
            ``(N, query_len, embed_dim)`` and ``attention`` of shape
            ``(N, heads, query_len, key_len)``.
        """
        batch, q_len = query.shape[0], query.shape[1]
        per_head = (self.num_heads, self.head_size)

        # Split the feature axis into heads, then project each head slice.
        v = self.values(values.reshape(batch, values.shape[1], *per_head))
        k = self.keys(keys.reshape(batch, keys.shape[1], *per_head))
        q = self.queries(query.reshape(batch, q_len, *per_head))

        # Raw dot-product scores: (N, heads, query_len, key_len)
        scores = torch.einsum("nqhd,nkhd->nhqk", q, k)

        if mask is not None:
            # A large finite negative is used because -inf would yield NaN
            # when every key of a row is masked.
            scores = scores.masked_fill(mask[:, None, None] == 0, -1e20)

        attention = (scores / (self.embed_dim ** 0.5)).softmax(dim=-1)

        # Weighted sum of values, then merge the heads back together.
        context = torch.einsum("nhql,nlhd->nqhd", attention, v)
        merged = context.reshape(batch, q_len, self.num_heads * self.head_size)

        return self.fc_out(merged), attention


class TransformerLayer(nn.Module):
    """Pre-layer-norm transformer block: attention plus a one-layer ReLU projection.

    Both sub-layers are wrapped in a residual (skip) connection.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        self.attention = MultiHeadAttention(dim, num_heads)
        self.layer_norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.layer_norm_attn = nn.LayerNorm(dim)
        self.fc_projection = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())

    def forward(self, value, key, query, mask):
        """Apply the block.

        Args:
            value: Memory tensor attended over; it also serves as the keys.
            key: Ignored. The normalised ``value`` is used as keys instead.
            query: Current input; also the source of the first skip connection.
            mask: Attention mask handed to ``MultiHeadAttention``.

        Returns:
            Tuple ``(out, attention_weights)``.
        """
        # Normalise inputs before attention (pre-LN); keys and values share
        # the same normalised memory.
        normed_query = self.layer_norm_q(query)
        normed_memory = self.norm_kv(value)
        attended, attention_weights = self.attention(normed_memory, normed_memory, normed_query, mask)

        # First skip connection uses the un-normalised query.
        residual = attended + query

        # Pre-LN projection followed by the second skip connection.
        projected = self.fc_projection(self.layer_norm_attn(residual))
        return projected + residual, attention_weights


class Transformer(nn.Module):
    """Stack of ``TransformerLayer`` blocks that attend over an episodic memory.

    Args:
        num_layers: Number of transformer layers.
        dim: Feature width of the layers.
        num_heads: Number of attention heads per layer.
        max_episode_steps: Number of positions covered by the positional table.
        positional_encoding: ``"absolute"`` for sinusoidal encodings,
            ``"learned"`` for a trainable table; any other value disables
            positional encoding.
    """

    def __init__(self, num_layers, dim, num_heads, max_episode_steps, positional_encoding):
        super().__init__()
        self.max_episode_steps = max_episode_steps
        self.positional_encoding = positional_encoding
        if positional_encoding == "absolute":
            self.pos_embedding = PositionalEncoding(dim)
        elif positional_encoding == "learned":
            self.pos_embedding = nn.Parameter(torch.randn(max_episode_steps, dim))
        self.transformer_layers = nn.ModuleList([TransformerLayer(dim, num_heads) for _ in range(num_layers)])

    def forward(self, x, memories, mask, memory_indices):
        """Run the current step through all layers.

        Args:
            x: Current input features; used as a length-1 query per batch item
                (assumed shape ``(batch, dim)``).
            memories: Memory window indexed as ``memories[:, :, layer]``
                (assumed shape ``(batch, window, num_layers, dim)``).
            mask: Attention mask forwarded to every layer.
            memory_indices: Episode positions of the window entries, used to
                look up positional encodings.

        Returns:
            Tuple ``(x, new_memories)`` where ``new_memories`` stacks the
            detached input of every layer along dimension 1.
        """
        # Add positional encoding to every transformer layer input
        if self.positional_encoding == "absolute":
            pos_embedding = self.pos_embedding(self.max_episode_steps)[memory_indices]
            memories = memories + pos_embedding.unsqueeze(2)
        elif self.positional_encoding == "learned":
            memories = memories + self.pos_embedding[memory_indices].unsqueeze(2)

        # Forward transformer layers and return new memories (i.e. hidden states)
        out_memories = []
        for i, layer in enumerate(self.transformer_layers):
            # The input of layer i becomes the memory entry of layer i.
            out_memories.append(x.detach())
            layer_memory = memories[:, :, i]
            x, _ = layer(layer_memory, layer_memory, x.unsqueeze(1), mask)  # args: value, key, query, mask
            # Drop only the length-1 query axis. A bare squeeze() would also
            # remove the batch axis (batch of 1) or the feature axis (dim of 1).
            x = x.squeeze(1)
        return x, torch.stack(out_memories, dim=1)


class Agent(nn.Module):
    """Actor-critic network: observation encoder -> Transformer -> policy/value heads.

    Observations with more than one dimension go through a convolutional
    encoder built for 3 input channels and a 64 x 7 x 7 feature map; flat
    observations go through a single linear layer. One categorical policy
    head is created per entry of ``action_space_shape``.
    """

    def __init__(self, args, observation_space, action_space_shape, max_episode_steps):
        super().__init__()
        self.obs_shape = observation_space.shape
        self.max_episode_steps = max_episode_steps
        print("Agent setup with max_episode_step=", max_episode_steps)

        embed_dim = args.trxl_dim

        # Observation encoder
        if len(self.obs_shape) > 1:
            self.encoder = nn.Sequential(
                layer_init(nn.Conv2d(3, 32, 8, stride=4)),
                nn.ReLU(),
                layer_init(nn.Conv2d(32, 64, 4, stride=2)),
                nn.ReLU(),
                layer_init(nn.Conv2d(64, 64, 3, stride=1)),
                nn.ReLU(),
                nn.Flatten(),
                layer_init(nn.Linear(64 * 7 * 7, embed_dim)),
                nn.ReLU(),
            )
        else:
            self.encoder = layer_init(nn.Linear(observation_space.shape[0], embed_dim))

        # Episodic-memory transformer
        self.transformer = Transformer(
            args.trxl_num_layers,
            embed_dim,
            args.trxl_num_heads,
            self.max_episode_steps,
            args.trxl_positional_encoding,
        )

        # Shared hidden layer feeding both the policy and the value head
        self.hidden_post_trxl = nn.Sequential(
            layer_init(nn.Linear(embed_dim, embed_dim)),
            nn.ReLU(),
        )

        # One policy head per action dimension, then the value head
        policy_heads = [
            layer_init(nn.Linear(embed_dim, out_features=num_actions), np.sqrt(0.01))
            for num_actions in action_space_shape
        ]
        self.actor_branches = nn.ModuleList(policy_heads)
        self.critic = layer_init(nn.Linear(embed_dim, 1), 1)

        # Optional decoder used for the observation reconstruction loss
        if args.reconstruction_coef > 0.0:
            self.transposed_cnn = nn.Sequential(
                layer_init(nn.Linear(embed_dim, 64 * 7 * 7)),
                nn.ReLU(),
                nn.Unflatten(1, (64, 7, 7)),
                layer_init(nn.ConvTranspose2d(64, 64, 3, stride=1)),
                nn.ReLU(),
                layer_init(nn.ConvTranspose2d(64, 32, 4, stride=2)),
                nn.ReLU(),
                layer_init(nn.ConvTranspose2d(32, 3, 8, stride=4)),
                nn.Sigmoid(),
            )

    def _encode(self, x):
        """Encode raw observations; image inputs are made channel-first and divided by 255."""
        if len(self.obs_shape) > 1:
            x = x.permute((0, 3, 1, 2)) / 255.0
        return self.encoder(x)

    def get_value(self, x, memory, memory_mask, memory_indices):
        """Return the critic's value estimate, flattened to one entry per batch item."""
        hidden, _ = self.transformer(self._encode(x), memory, memory_mask, memory_indices)
        hidden = self.hidden_post_trxl(hidden)
        return self.critic(hidden).flatten()

    def get_action_and_value(self, x, memory, memory_mask, memory_indices, action=None):
        """Sample (or score) actions and evaluate the state.

        Args:
            x: Batch of observations.
            memory: Memory window passed to the transformer.
            memory_mask: Attention mask for the memory window.
            memory_indices: Episode positions of the memory window entries.
            action: Actions to score, one column per policy head. When
                ``None``, actions are sampled from the policy.

        Returns:
            Tuple ``(action, log_probs, entropy, value, new_memory)`` where
            ``log_probs`` has one column per policy head and ``entropy`` is
            summed over the heads.
        """
        hidden, new_memory = self.transformer(self._encode(x), memory, memory_mask, memory_indices)
        hidden = self.hidden_post_trxl(hidden)
        # Kept for reconstruct_observation(), which decodes this activation.
        self.x = hidden

        policies = [Categorical(logits=head(hidden)) for head in self.actor_branches]
        if action is None:
            action = torch.stack([policy.sample() for policy in policies], dim=1)

        log_probs = torch.stack(
            [policy.log_prob(action[:, head_idx]) for head_idx, policy in enumerate(policies)], dim=1
        )
        entropies = torch.stack([policy.entropy() for policy in policies], dim=1).sum(1).reshape(-1)
        value = self.critic(hidden).flatten()
        return action, log_probs, entropies, value, new_memory

    def reconstruct_observation(self):
        """Decode the activation cached by the last ``get_action_and_value`` call.

        Requires the decoder, which only exists when ``reconstruction_coef > 0``.
        The result is returned channel-last.
        """
        decoded = self.transposed_cnn(self.x)
        return decoded.permute((0, 2, 3, 1))


def evaluate_policy(args, agent, envs, device, max_episode_steps):
    """Roll out the stochastic policy until enough evaluation episodes finish.

    All environments of ``envs`` are stepped together. Every finished episode
    is recorded in completion order until ``args.num_eval_episodes`` results
    have been collected; surplus episodes that finish during the final step
    are discarded. The agent is put into eval mode for the rollout and back
    into train mode afterwards, even if the rollout raises.

    Args:
        args: Configuration providing ``num_eval_envs``, ``num_eval_episodes``,
            ``trxl_memory_length``, ``trxl_num_layers`` and ``trxl_dim``.
        agent: Policy exposing ``eval()``, ``train()`` and
            ``get_action_and_value(obs, memory, mask, indices)``.
        envs: Vectorised environment with ``reset()`` and ``step(actions)``.
            Its ``num_envs`` attribute is used when present, otherwise
            ``args.num_eval_envs``. Finished sub-environments are assumed to
            reset themselves -- TODO confirm for the Gymnasium version in use.
        device: Torch device for all tensors created here.
        max_episode_steps: Length of the per-environment episodic memory.
            Episodes are assumed to end within this many steps.

    Returns:
        Tuple ``(episode_returns, episode_lengths)`` of equally long lists
        with at most ``args.num_eval_episodes`` entries.
    """
    # Use the real size of the vector env so a mismatch with
    # args.num_eval_envs cannot produce wrongly shaped batches.
    num_envs = getattr(envs, "num_envs", args.num_eval_envs)
    memory_length = args.trxl_memory_length
    target_episodes = args.num_eval_episodes

    # Attention mask: row t exposes the first t memory slots.
    memory_mask = torch.tril(torch.ones((memory_length, memory_length), device=device), diagonal=-1)
    # Sliding-window indices into the episodic memory: the first
    # (memory_length - 1) rows repeat the initial window, then it slides.
    repetitions = torch.repeat_interleave(
        torch.arange(0, memory_length, device=device).unsqueeze(0),
        memory_length - 1,
        dim=0,
    ).long()
    sliding = torch.stack(
        [
            torch.arange(i, i + memory_length, device=device)
            for i in range(max_episode_steps - memory_length + 1)
        ]
    ).long()
    memory_indices = torch.cat((repetitions, sliding))

    # Statistics tracking
    episode_returns = []
    episode_lengths = []

    # Per-environment episodic memory and episode bookkeeping
    memory = torch.zeros(
        (num_envs, max_episode_steps, args.trxl_num_layers, args.trxl_dim),
        dtype=torch.float32,
        device=device,
    )
    env_steps = np.zeros(num_envs, dtype=np.int32)
    current_returns = np.zeros(num_envs)
    env_index = torch.arange(num_envs, device=device)

    agent.eval()
    try:
        obs, _ = envs.reset()
        # Convert exactly like the observations returned by step(); keeping
        # the raw numpy dtype here could feed float64 into float32 layers.
        obs = torch.as_tensor(obs, dtype=torch.float32, device=device)

        with torch.no_grad():  # No gradient computation needed for evaluation
            while len(episode_returns) < target_episodes:
                steps = torch.as_tensor(env_steps, dtype=torch.long, device=device)
                indices_batch = memory_indices[steps]
                mask_batch = memory_mask[steps.clamp(max=memory_length - 1)]
                # Window of each environment's own memory: (envs, window, layers, dim)
                memory_batch = memory[env_index.unsqueeze(1), indices_batch]

                # Get actions from the agent (stochastic)
                action, _, _, _, new_memory = agent.get_action_and_value(
                    obs, memory_batch, mask_batch, indices_batch
                )

                # Write the new hidden states at each environment's current step
                memory[env_index, steps] = new_memory

                next_obs, rewards, terminations, truncations, _ = envs.step(action.cpu().numpy())
                dones = np.logical_or(terminations, truncations)

                current_returns += rewards
                env_steps += 1

                for i in np.flatnonzero(dones):
                    # Several environments may finish in the same step; keep
                    # only as many results as were requested.
                    if len(episode_returns) < target_episodes:
                        episode_returns.append(float(current_returns[i]))
                        episode_lengths.append(int(env_steps[i]))
                    # Reset bookkeeping for the next episode
                    current_returns[i] = 0
                    env_steps[i] = 0
                    memory[i] = 0

                obs = torch.as_tensor(next_obs, dtype=torch.float32, device=device)
    finally:
        agent.train()

    # Print summary statistics
    if episode_returns:
        print("\nEvaluation Summary:")
        print(f"Total episodes: {len(episode_returns)}")
        print(f"Mean return: {np.mean(episode_returns):.2f} ± {np.std(episode_returns):.2f}")
        print(f"Min/Max return: {np.min(episode_returns):.2f} / {np.max(episode_returns):.2f}")
        print(f"Mean episode length: {np.mean(episode_lengths):.2f}")

    return episode_returns, episode_lengths

if __name__ == "__main__":
    args = tyro.cli(Args)
    args.batch_size = int(args.num_envs * args.num_steps)
    print("Batch size:", args.batch_size)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size
    print("Total number of iterations:", args.num_iterations)
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    eval_counter = 0 # Number of already performed evaluations
    eval_counter_vec = [False] * args.num_evals
    eval_step_list = [step for step in range(args.eval_freq, args.total_timesteps+1, args.eval_freq)]

    # Tracking evaluation stats for saving in .npz files
    evaluations_results: list[list[float]] = []
    evaluations_length: list[list[int]] = []
    evaluations_timesteps: list[int] = []

    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    # TODO: Re-enable CUDA after testing.
    # Determine the device to be used for training and set the default tensor type
    if args.cuda:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch.set_default_device(device)
    else:
        device = torch.device("cpu")

    print("Min_num_hosts:", args.min_num_hosts)
    print("Max_num_hosts:", args.max_num_hosts)

    # Environment setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.min_num_hosts, args.max_num_hosts) for i in range(args.num_envs)],
    )
    # Evaluation environment setup
    eval_envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.min_num_hosts, args.max_num_hosts) for i in range(args.num_envs)],
    )
    observation_space = envs.single_observation_space
    action_space_shape = (
        (envs.single_action_space.n,)
        if isinstance(envs.single_action_space, gym.spaces.Discrete)
        else tuple(envs.single_action_space.nvec)
    )
    env_ids = range(args.num_envs)
    env_current_episode_step = torch.zeros((args.num_envs,), dtype=torch.long)
    # Determine maximum episode steps
    max_episode_steps = envs.envs[0].max_episode_steps
    print(f"Max episode steps: {max_episode_steps}")
    # Cap the transformer memory length so that it never exceeds the maximum episode length
    args.trxl_memory_length = min(args.trxl_memory_length, max_episode_steps)

    agent = Agent(args, observation_space, action_space_shape, max_episode_steps).to(device)
    optimizer = optim.AdamW(agent.parameters(), lr=args.init_lr)
    bce_loss = nn.BCELoss()  # Binary cross entropy loss for observation reconstruction

    # ALGO Logic: Storage setup
    rewards = torch.zeros((args.num_steps, args.num_envs))
    actions = torch.zeros((args.num_steps, args.num_envs, len(action_space_shape)), dtype=torch.long)
    dones = torch.zeros((args.num_steps, args.num_envs))
    obs = torch.zeros((args.num_steps, args.num_envs) + observation_space.shape)
    log_probs = torch.zeros((args.num_steps, args.num_envs, len(action_space_shape)))
    values = torch.zeros((args.num_steps, args.num_envs))
    # The length of stored-memories is equal to the number of sampled episodes during training data sampling
    # (num_episodes, max_episode_length, num_layers, embed_dim)
    stored_memories = []
    # Memory mask used during attention
    stored_memory_masks = torch.zeros((args.num_steps, args.num_envs, args.trxl_memory_length), dtype=torch.bool)
    # Index to select the correct episode memory from stored_memories
    stored_memory_index = torch.zeros((args.num_steps, args.num_envs), dtype=torch.long)
    # Indices to slice the episode memories into windows
    stored_memory_indices = torch.zeros((args.num_steps, args.num_envs, args.trxl_memory_length), dtype=torch.long)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    episode_infos = deque(maxlen=100)  # Store episode results for monitoring statistics
    next_obs, _ = envs.reset(seed=args.seed)
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs)
    # Set up placeholders for each environment's current episodic memory
    next_memory = torch.zeros((args.num_envs, max_episode_steps, args.trxl_num_layers, args.trxl_dim), dtype=torch.float32)
    # Generate episodic memory mask used in attention
    memory_mask = torch.tril(torch.ones((args.trxl_memory_length, args.trxl_memory_length)), diagonal=-1)
    """ e.g. memory mask tensor looks like this if memory_length = 6
    0, 0, 0, 0, 0, 0
    1, 0, 0, 0, 0, 0
    1, 1, 0, 0, 0, 0
    1, 1, 1, 0, 0, 0
    1, 1, 1, 1, 0, 0
    1, 1, 1, 1, 1, 0
    """
    # Setup memory window indices to support a sliding window over the episodic memory
    repetitions = torch.repeat_interleave(
        torch.arange(0, args.trxl_memory_length).unsqueeze(0), args.trxl_memory_length - 1, dim=0
    ).long()
    memory_indices = torch.stack(
        [torch.arange(i, i + args.trxl_memory_length) for i in range(max_episode_steps - args.trxl_memory_length + 1)]
    ).long()
    memory_indices = torch.cat((repetitions, memory_indices))
    """ e.g. the memory window indices tensor looks like this if memory_length = 4 and max_episode_length = 7:
    0, 1, 2, 3
    0, 1, 2, 3
    0, 1, 2, 3
    0, 1, 2, 3
    1, 2, 3, 4
    2, 3, 4, 5
    3, 4, 5, 6
    """

    for iteration in range(1, args.num_iterations + 1):
        sampled_episode_infos = []

        # Annealing the learning rate and entropy coefficient if instructed to do so
        do_anneal = args.anneal_steps > 0 and global_step < args.anneal_steps
        frac = 1 - global_step / args.anneal_steps if do_anneal else 0
        lr = (args.init_lr - args.final_lr) * frac + args.final_lr
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        ent_coef = (args.init_ent_coef - args.final_ent_coef) * frac + args.final_ent_coef

        # Init episodic memory buffer using each environments' current episodic memory
        stored_memories = [next_memory[e] for e in range(args.num_envs)]
        for e in range(args.num_envs):
            stored_memory_index[:, e] = e

        for step in range(args.num_steps):
            global_step += args.num_envs

            # ALGO LOGIC: action logic
            with torch.no_grad():
                obs[step] = next_obs
                dones[step] = next_done
                stored_memory_masks[step] = memory_mask[torch.clip(env_current_episode_step, 0, args.trxl_memory_length - 1)]
                stored_memory_indices[step] = memory_indices[env_current_episode_step]
                # Retrieve the memory window from the entire episodic memory
                memory_window = batched_index_select(next_memory, 1, stored_memory_indices[step])
                action, logprob, _, value, new_memory = agent.get_action_and_value(
                    next_obs, memory_window, stored_memory_masks[step], stored_memory_indices[step]
                )
                next_memory[env_ids, env_current_episode_step] = new_memory
                # Store the action, log_prob, and value in the buffer
                actions[step], log_probs[step], values[step] = action, logprob, value

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
            next_done = np.logical_or(terminations, truncations)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

            # Reset and process episodic memory if done
            for id, done in enumerate(next_done):
                if done:
                    # Reset the environment's current timestep
                    env_current_episode_step[id] = 0
                    # Break the reference to the environment's episodic memory
                    mem_index = stored_memory_index[step, id]
                    stored_memories[mem_index] = stored_memories[mem_index].clone()
                    # Reset episodic memory
                    next_memory[id] = torch.zeros(
                        (max_episode_steps, args.trxl_num_layers, args.trxl_dim), dtype=torch.float32
                    )
                    if step < args.num_steps - 1:
                        # Store memory inside the buffer
                        stored_memories.append(next_memory[id])
                        # Point the remaining steps of this rollout at the newly stored episodic memory
                        stored_memory_index[step + 1 :, id] = len(stored_memories) - 1
                else:
                    # Increment environment timestep if not done
                    env_current_episode_step[id] += 1

            if "final_info" in infos:
                for info in infos["final_info"]:
                    if info and "episode" in info:
                        sampled_episode_infos.append(info["episode"])

        # Bootstrap value if not done
        with torch.no_grad():
            start = torch.clip(env_current_episode_step - args.trxl_memory_length, 0)
            end = torch.clip(env_current_episode_step, args.trxl_memory_length)
            indices = torch.stack([torch.arange(start[b], end[b]) for b in range(args.num_envs)]).long()
            memory_window = batched_index_select(next_memory, 1, indices)  # Retrieve the memory window from the entire episode
            next_value = agent.get_value(
                next_obs,
                memory_window,
                memory_mask[torch.clip(env_current_episode_step, 0, args.trxl_memory_length - 1)],
                stored_memory_indices[-1],
            )
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

        # Flatten the batch
        b_obs = obs.reshape(-1, *obs.shape[2:])
        b_logprobs = log_probs.reshape(-1, *log_probs.shape[2:])
        b_actions = actions.reshape(-1, *actions.shape[2:])
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)
        b_memory_index = stored_memory_index.reshape(-1)
        b_memory_indices = stored_memory_indices.reshape(-1, *stored_memory_indices.shape[2:])
        b_memory_mask = stored_memory_masks.reshape(-1, *stored_memory_masks.shape[2:])
        stored_memories = torch.stack(stored_memories, dim=0)

        # Remove unnecessary padding from TrXL memory, if applicable
        actual_max_episode_steps = (stored_memory_indices * stored_memory_masks).max().item() + 1
        if actual_max_episode_steps < args.trxl_memory_length:
            b_memory_indices = b_memory_indices[:, :actual_max_episode_steps]
            b_memory_mask = b_memory_mask[:, :actual_max_episode_steps]
            stored_memories = stored_memories[:, :actual_max_episode_steps]

        # Optimizing the policy and value network
        # Runs `args.update_epochs` passes over the flattened batch, each split into
        # minibatches of `args.minibatch_size` samples, taking one optimizer step per minibatch.
        # `clipfracs` collects, per minibatch, the fraction of ratios outside the clip range.
        clipfracs = []
        for epoch in range(args.update_epochs):
            # Fresh random ordering of the batch indices for every epoch
            b_inds = torch.randperm(args.batch_size)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]
                # For each sample, look up the stacked memory tensor it refers to via `b_memory_index`
                mb_memories = stored_memories[b_memory_index[mb_inds]]
                # NOTE(review): `batched_index_select` is defined elsewhere; presumably it gathers,
                # along dim 1, the memory entries listed in `b_memory_indices` — confirm
                mb_memory_windows = batched_index_select(mb_memories, 1, b_memory_indices[mb_inds])

                # Re-evaluate the stored actions under the current parameters; only the new
                # log-probabilities, the entropy and the value estimate are used here
                _, newlogprob, entropy, newvalue, _ = agent.get_action_and_value(
                    b_obs[mb_inds], mb_memory_windows, b_memory_mask[mb_inds], b_memory_indices[mb_inds], b_actions[mb_inds]
                )

                # Policy loss
                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    # Standardize within the minibatch; the epsilon avoids a division by zero
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
                mb_advantages = mb_advantages.unsqueeze(1).repeat(
                    1, len(action_space_shape)
                )  # Repeat is necessary for multi-discrete action spaces
                # ratio = exp(new log-prob - stored log-prob)
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = torch.exp(logratio)
                # Clipped surrogate: the element-wise max of the two negated terms is the
                # negated min of the unclipped and clipped objectives
                pgloss1 = -mb_advantages * ratio
                pgloss2 = -mb_advantages * torch.clamp(ratio, 1.0 - args.clip_coef, 1.0 + args.clip_coef)
                pg_loss = torch.max(pgloss1, pgloss2).mean()

                # Value loss
                v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                if args.clip_vloss:
                    # Keep the new prediction within +/- clip_coef of the stored value estimate,
                    # then use the larger of the clipped and unclipped squared errors
                    v_loss_clipped = b_values[mb_inds] + (newvalue - b_values[mb_inds]).clamp(
                        min=-args.clip_coef, max=args.clip_coef
                    )
                    v_loss = torch.max(v_loss_unclipped, (v_loss_clipped - b_returns[mb_inds]) ** 2).mean()
                else:
                    v_loss = v_loss_unclipped.mean()

                # Entropy loss
                entropy_loss = entropy.mean()

                # Combined losses
                # The entropy term is subtracted (scaled by `ent_coef`), the value loss is
                # added (scaled by `args.vf_coef`)
                loss = pg_loss - ent_coef * entropy_loss + v_loss * args.vf_coef

                # Add reconstruction loss if used
                # The zero default keeps `r_loss` defined for the logging code when
                # reconstruction is disabled
                r_loss = torch.tensor(0.0)
                if args.reconstruction_coef > 0.0:
                    # NOTE(review): the division by 255.0 suggests observations are 0-255 image
                    # data and `bce_loss` expects targets in [0, 1] — confirm
                    r_loss = bce_loss(agent.reconstruct_observation(), b_obs[mb_inds] / 255.0)
                    loss += args.reconstruction_coef * r_loss

                # One optimization step; the gradient norm is clipped to `args.max_grad_norm` first
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=args.max_grad_norm)
                optimizer.step()

                # Diagnostics only, computed without tracking gradients
                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

            # Early stopping, checked once per epoch against the `approx_kl` of the epoch's
            # last minibatch; skips the remaining epochs when the threshold is exceeded
            if args.target_kl is not None and approx_kl > args.target_kl:
                break

        # Explained variance of the batch returns given the stored value estimates;
        # reported as NaN when the returns have zero variance.
        y_pred = b_values.cpu().numpy()
        y_true = b_returns.cpu().numpy()
        var_y = np.var(y_true)
        if var_y == 0:
            explained_var = np.nan
        else:
            explained_var = 1 - np.var(y_true - y_pred) / var_y

        # Log and monitor training statistics
        # Newly sampled episode infos are appended to `episode_infos`, and every key of the
        # first entry is averaged over all entries and stored under "<key>_mean".
        episode_infos.extend(sampled_episode_infos)
        episode_result = {}
        if len(episode_infos) > 0:
            for key in episode_infos[0].keys():
                episode_result[key + "_mean"] = np.mean([info[key] for info in episode_infos])

        # `episode_result` stays empty until at least one episode info is available, so the
        # return/length means are looked up with a NaN fallback instead of raising KeyError.
        print(
            "{:9} SPS={:4} return={:.2f} length={:.1f} pi_loss={:.3f} v_loss={:.3f} entropy={:.3f} r_loss={:.3f} value={:.3f} adv={:.3f}".format(
                iteration,
                int(global_step / (time.time() - start_time)),
                episode_result.get("r_mean", np.nan),
                episode_result.get("l_mean", np.nan),
                pg_loss.item(),
                v_loss.item(),
                entropy_loss.item(),
                r_loss.item(),
                torch.mean(values),
                torch.mean(advantages),
            )
        )

        # Perform policy evaluation
        # If the current global step has reached the threshold of the next pending
        # evaluation and that evaluation has not been run yet, evaluate now.
        print("Global step count:", global_step)
        # `eval_counter` is advanced after every evaluation, so it can end up one past the
        # last scheduled entry; the bounds check keeps the lookups below from raising
        # IndexError on the iterations that follow the final evaluation.
        if (
            eval_counter < min(len(eval_step_list), len(eval_counter_vec))
            and global_step >= eval_step_list[eval_counter]
            and not eval_counter_vec[eval_counter]
        ):
            eval_counter += 1
            start_time_eval = time.time()
            episode_returns, episode_lengths = evaluate_policy(
                args,
                agent,
                device=device,
                envs=eval_envs,
                max_episode_steps=max_episode_steps,
            )

            # Check the result types before any statistics are derived from them.
            assert isinstance(episode_returns, list)
            assert isinstance(episode_lengths, list)

            score = np.mean(episode_returns)
            mean_length = np.mean(episode_lengths)
            print("Evaluation {} r_mean={:.2f} l_mean={:.2f} time={:.2f}s".format(
                eval_counter, score, mean_length, time.time() - start_time_eval))
            writer.add_scalar("eval/mean_return", score, global_step)
            writer.add_scalar("eval/std_return", np.std(episode_returns), global_step)
            writer.add_scalar("eval/mean_length", mean_length, global_step)
            writer.add_scalar("eval/min_return", np.min(episode_returns), global_step)
            writer.add_scalar("eval/max_return", np.max(episode_returns), global_step)

            # The recorded timestep is the nominal one (eval_freq * number of evaluations
            # so far), not the actual `global_step` at which the evaluation ran.
            evaluations_timesteps.append(args.eval_freq * eval_counter)
            evaluations_results.append(episode_returns)
            evaluations_length.append(episode_lengths)

            # Rewrite the full evaluation history after every evaluation.
            np.savez(
                f"runs/{run_name}/evaluations",
                timesteps=evaluations_timesteps,
                results=evaluations_results,
                ep_lengths=evaluations_length,
            )

            if args.save_model:
                # Checkpoint written to the same path each time, replacing the previous one.
                model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
                model_data = {
                    "model_weights": agent.state_dict(),
                    "args": vars(args),
                }
                torch.save(model_data, model_path)
                print(f"model saved to {model_path}")

            # Mark the evaluation that just ran as done.
            eval_counter_vec[eval_counter - 1] = True

        # Per-episode aggregates; nothing is written while the dict is empty.
        for stat_name, stat_value in episode_result.items():
            writer.add_scalar("episode/" + stat_name, stat_value, global_step)

        # Remaining diagnostics of this iteration, written in this fixed order.
        iteration_scalars = (
            ("episode/value_mean", torch.mean(values)),
            ("episode/advantage_mean", torch.mean(advantages)),
            ("charts/learning_rate", lr),
            ("charts/entropy_coefficient", ent_coef),
            ("losses/policy_loss", pg_loss.item()),
            ("losses/value_loss", v_loss.item()),
            ("losses/loss", loss.item()),
            ("losses/entropy", entropy_loss.item()),
            ("losses/reconstruction_loss", r_loss.item()),
            ("losses/old_approx_kl", old_approx_kl.item()),
            ("losses/approx_kl", approx_kl.item()),
            ("losses/clipfrac", np.mean(clipfracs)),
            ("losses/explained_variance", explained_var),
        )
        for tag, scalar_value in iteration_scalars:
            writer.add_scalar(tag, scalar_value, global_step)

        # Steps per second since `start_time`, sampled at this point.
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    # Optionally write a checkpoint holding the agent weights and the run arguments.
    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        torch.save({"model_weights": agent.state_dict(), "args": vars(args)}, model_path)
        print(f"model saved to {model_path}")

    # Release the summary writer first, then the training and evaluation environments.
    for resource in (writer, envs, eval_envs):
        resource.close()