from __future__ import annotations
import time
import gymnasium as gym
import argparse
import multiprocessing as mp
import os
import torch
import torch as T
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from copy import deepcopy
from functools import partial
from typing import Any, SupportsFloat
from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.spaces import Box
from torch import nn as nn, Tensor
from torch.nn import init
from math import sqrt
import math
from gymnasium.spaces import Box, Tuple
import matplotlib.pyplot as plt


############################################ Tensor Queue

class TensorQueue:
    """Fixed-length FIFO of 3-D feature tensors, kept as one batched tensor.

    The backing tensor has shape
    ``[batch_size, context_length, context[0], context[1], context[2]]``.
    Index 0 along the second axis is the oldest entry and index -1 the newest.
    """

    def __init__(self, batch_size, context_length, context, device='cpu', dtype=torch.float32):
        """
        Creates the queue filled with zeros.

        Args:
            batch_size (int): Number of independent queues (first axis).
            context_length (int): Number of entries each queue holds.
            context (sequence of 3 ints): Shape of a single entry.
            device (str): Device on which the backing tensor is allocated.
            dtype (torch.dtype): Data type of the backing tensor.
        """
        self.batch_size = batch_size
        self.context_length = context_length
        self.context = context
        self.device = device
        self.dtype = dtype

        # One zero-filled slot per (batch element, queue position).
        buffer_shape = (batch_size, context_length, context[0], context[1], context[2])
        self.queue = torch.zeros(*buffer_shape, device=device, dtype=dtype)

    def enqueue(self, new_data, batch_idx=None):
        """
        Pushes one new entry, dropping the oldest one.

        Args:
            new_data (torch.Tensor): ``[batch, *context]`` when ``batch_idx`` is None,
                otherwise a single entry of shape ``context`` (or broadcastable to it).
            batch_idx (int, optional): When given, only that batch element's queue is
                advanced (in place). When None, every queue is advanced and
                ``self.queue`` is rebound to a freshly allocated tensor.

        Raises:
            TypeError: ``batch_idx`` is given but is not an ``int``.
            IndexError: ``batch_idx`` is outside ``[0, batch_size)``.
        """
        if batch_idx is not None:
            # Selective path: validate the index, then shift that single queue in place.
            if not isinstance(batch_idx, int):
                raise TypeError(f"batch_idx must be an int, but got {type(batch_idx)}")
            if batch_idx < 0 or batch_idx >= self.batch_size:
                raise IndexError(f"batch_idx {batch_idx} out of range for batch size {self.batch_size}")

            with torch.no_grad():
                lane = self.queue[batch_idx]  # view into the backing tensor
                # clone() avoids reading and writing overlapping memory during the shift
                lane[:-1] = lane[1:].clone()
                lane[-1] = new_data
            return

        # Global path: add a length-1 time axis so the entry can be concatenated.
        newest = new_data.unsqueeze(1)
        with torch.no_grad():
            self.queue = torch.cat((self.queue[:, 1:], newest), dim=1)

    def get_queue(self):
        """
        Returns the backing tensor itself (not a copy).

        Returns:
            torch.Tensor: Tensor of shape ``[batch, context_length, *context]``.
        """
        return self.queue

    def reset(self):
        """
        Zeroes every entry of every queue in place.
        """
        with torch.no_grad():
            self.queue.zero_()


############################################## Networks Section

class FactorizedNoisyLinear(nn.Module):
    """Linear layer with factorized Gaussian parameter noise (NoisyNet-DQN).

    Effective parameters are ``mu + sigma * epsilon`` for both weight and bias.
    ``epsilon`` is stored in buffers and is only refreshed by :meth:`reset_noise`.
    The constructor finishes by calling :meth:`disable_noise`, so a freshly built
    layer is deterministic until :meth:`reset_noise` is called.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        sigma_0: Scale used for the initial ``sigma`` values
            (``sigma_0 / sqrt(in_features)``).
        self_norm: When True, ``weight_mu`` is drawn from a normal distribution
            with std ``1 / sqrt(out_features)`` instead of the uniform fan-in init.
    """

    def __init__(self, in_features: int, out_features: int, sigma_0=0.5, self_norm=False) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.sigma_0 = sigma_0

        # weight: w = \mu^w + \sigma^w . \epsilon^w
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        self.register_buffer('weight_epsilon', torch.empty(out_features, in_features))

        # bias: b = \mu^b + \sigma^b . \epsilon^b
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        self.register_buffer('bias_epsilon', torch.empty(out_features))

        if self_norm:
            self.reset_parameters_self_norm()
        else:
            self.reset_parameters()
        self.reset_noise()

        # Start with zero noise; callers opt in by calling reset_noise().
        self.disable_noise()

    @torch.no_grad()
    def _reset_sigma(self) -> None:
        """Set every sigma entry to ``sigma_0 / sqrt(in_features)``."""
        sigma_init = self.sigma_0 / sqrt(self.in_features)
        init.constant_(self.weight_sigma, sigma_init)
        init.constant_(self.bias_sigma, sigma_init)

    @torch.no_grad()
    def reset_parameters(self) -> None:
        """Uniform fan-in initialisation of the means plus constant sigmas."""
        scale = 1 / sqrt(self.in_features)

        init.uniform_(self.weight_mu, -scale, scale)
        init.uniform_(self.bias_mu, -scale, scale)

        self._reset_sigma()

    @torch.no_grad()
    def reset_parameters_self_norm(self) -> None:
        """Normal initialisation of ``weight_mu`` (std ``1 / sqrt(out_features)``).

        ``bias_mu`` is drawn uniformly within ``+-1 / sqrt(in_features)``.
        The sigma parameters are initialised exactly as in :meth:`reset_parameters`;
        previously they were left as uninitialised ``torch.empty`` memory on this path.
        """
        nn.init.normal_(self.weight_mu, std=1 / math.sqrt(self.out_features))

        # fan-in of a (out_features, in_features) weight matrix is in_features
        bound = 1 / math.sqrt(self.in_features)
        nn.init.uniform_(self.bias_mu, -bound, bound)

        self._reset_sigma()

    @torch.no_grad()
    def _get_noise(self, size: int) -> Tensor:
        """Return ``size`` samples of f(x) = sgn(x) * sqrt(|x|) with x ~ N(0, 1)."""
        noise = torch.randn(size, device=self.weight_mu.device)
        # f(x) = sgn(x)sqrt(|x|)
        return noise.sign().mul_(noise.abs().sqrt_())

    @torch.no_grad()
    def reset_noise(self) -> None:
        """Resample the factorized noise (outer product of input/output noise vectors)."""
        # like in eq 10 and 11 of the paper
        epsilon_in = self._get_noise(self.in_features)
        epsilon_out = self._get_noise(self.out_features)
        self.weight_epsilon.copy_(epsilon_out.outer(epsilon_in))
        self.bias_epsilon.copy_(epsilon_out)

    @torch.no_grad()
    def disable_noise(self) -> None:
        """Zero the noise buffers so the layer acts as a plain linear layer on the means."""
        self.weight_epsilon[:] = 0
        self.bias_epsilon[:] = 0

    def forward(self, input: Tensor) -> Tensor:
        """Apply the affine map using the currently stored noise."""
        # y = wx + b, where
        # w = \mu^w + \sigma^w * \epsilon^w
        # b = \mu^b + \sigma^b * \epsilon^b
        return F.linear(input,
                        self.weight_mu + self.weight_sigma * self.weight_epsilon,
                        self.bias_mu + self.bias_sigma * self.bias_epsilon)


class NatureC51(nn.Module):
    """Dueling categorical (C51) Q-network on a three-layer conv torso with an LSTM gate.

    The conv stack (8x8/4, 4x4/2, 3x3/1 kernels with 32/64/64 channels) is applied to
    every frame-stack of an input sequence. The features of the last step are multiplied
    element-wise by a sigmoid gate computed from an LSTM run over the (detached) features
    of the whole sequence. Noisy value/advantage heads then produce per-action logits
    over ``atoms`` return bins.

    Args:
        in_depth: Number of input channels per sequence element (the frame stack).
        actions: Number of discrete actions.
        atoms: Number of bins of the categorical return distribution.
        Vmin: Smallest return represented by the support.
        Vmax: Largest return represented by the support.
        device: Device the module is moved to at the end of construction.
        linear_size: Hidden width of the value and advantage heads.
        lstm_out_size: Hidden size of the LSTM.
        maxpool: Stored on the instance; not used by the code in this class.
    """

    def __init__(self, in_depth, actions, atoms=51, Vmin=-10, Vmax=10, device='cuda:0',
                 linear_size=512, lstm_out_size=256,
                 maxpool=False):
        super().__init__()

        self.start = time.time()
        self.actions = actions
        self.device = device
        self.maxpool = maxpool

        self.atoms = atoms

        self.lstm_hidden_size = lstm_out_size

        self.in_depth = in_depth

        self.linear_size = linear_size

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_depth, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # 64 channels * 7 * 7: the flattened conv output for 84x84 inputs
        # (84 -> 20 -> 9 -> 7 through the three convolutions above).
        conv_out_size = 3136
        self.conv_out_size = conv_out_size
        self.conv_lstm_out_size = self.conv_out_size

        # Projects the LSTM state back up to the conv feature width so it can gate it.
        self.lstm_projection = nn.Linear(self.lstm_hidden_size, self.conv_out_size)

        self.up_projection_act = nn.Sigmoid()

        self.lstm = nn.LSTM(
            input_size=self.conv_out_size,
            hidden_size=self.lstm_hidden_size,
            num_layers=1,
            batch_first=True  # Expect input of shape (batch_size, seq_len, input_size)
        )

        self.fc1V = FactorizedNoisyLinear(self.conv_lstm_out_size, linear_size)
        self.fc1A = FactorizedNoisyLinear(self.conv_lstm_out_size, linear_size)
        self.fcV2 = FactorizedNoisyLinear(linear_size, self.atoms)
        self.fcA2 = FactorizedNoisyLinear(linear_size, actions * self.atoms)

        # linspace always yields exactly `atoms` points ending at Vmax; a float-step
        # arange can produce one element too many for some (Vmin, Vmax, atoms).
        self.register_buffer("supports", torch.linspace(Vmin, Vmax, atoms))
        self.softmax = nn.Softmax(dim=1)

        self.to(device)

    def fc_val(self, x):
        """Value head: returns ``atoms`` logits per sample."""
        x = F.relu(self.fc1V(x))
        x = self.fcV2(x)

        return x

    def fc_adv(self, x):
        """Advantage head: returns ``actions * atoms`` logits per sample."""
        x = F.relu(self.fc1A(x))
        x = self.fcA2(x)

        return x

    def _get_conv_out(self, shape):
        """Return the flattened conv output size for a single input of ``shape``.

        The probe tensor is created on the default device, so this only works while
        the conv stack lives there as well.
        """
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        """Compute dueling distribution logits.

        Args:
            x: Tensor of shape ``(batch, sequence, framestack, X, Y)``; values are
                divided by 255 before the conv stack.

        Returns:
            Tensor of shape ``(batch, actions, atoms)`` holding un-normalised logits.
        """
        inputt = x.float() / 255

        # inputt currently has shape (batch, sequence, framestack, x, y)
        # this doesn't work with torch CNN, so need to reshape to 4D, then reshape back to 5D
        batch_size, seq_len, C, X, Y = inputt.shape
        x = self.conv(inputt.contiguous().view(batch_size * seq_len, C, X, Y))
        # reshape back:
        x = x.reshape(batch_size, seq_len, -1)

        # get the final output for the CNN stream. The entire of x is still used for LSTM stream
        conv_out = x[:, -1, :].clone()

        # pass through lstm; detach() keeps LSTM gradients out of the conv stack
        lstm_out, _ = self.lstm(x.detach())  # this has a tanh on the end anyway, don't need another activation

        # get final output
        lstm_out = lstm_out[:, -1, :]

        # upscale layer and activation
        lstm_out = self.lstm_projection(lstm_out)
        lstm_out = self.up_projection_act(lstm_out)

        # combine CNN with LSTM output (sigmoid gate on the conv features)
        x = conv_out * lstm_out

        val_out = self.fc_val(x).view(batch_size, 1, self.atoms)
        adv_out = self.fc_adv(x).view(batch_size, -1, self.atoms)
        adv_mean = adv_out.mean(dim=1, keepdim=True)
        return val_out + (adv_out - adv_mean)

    def both(self, x):
        """Return ``(logits, q_values)``.

        ``q_values`` has shape ``(batch, actions)`` and is the expectation of the
        support under the softmax of the logits.
        """
        cat_out = self(x)
        probs = self.apply_softmax(cat_out)
        weights = probs * self.supports
        res = weights.sum(dim=2)
        return cat_out, res

    def qvals(self, x, advantages_only=False):
        """Return only the expected Q-values; ``advantages_only`` is accepted but ignored."""
        return self.both(x)[1]

    def apply_softmax(self, t):
        """Softmax over the trailing ``atoms`` axis, preserving the shape of ``t``."""
        return self.softmax(t.view(-1, self.atoms)).view(t.size())

    def save_checkpoint(self, name):
        """Write the state dict to ``name + ".model"``."""
        # print('... saving checkpoint ...')
        torch.save(self.state_dict(), name + ".model")

    def load_checkpoint(self, name):
        """Load a state dict from the path ``name``.

        Unlike :meth:`save_checkpoint`, no ``".model"`` suffix is appended here, so pass
        the full file name. Tensors are mapped onto ``self.device`` so that a checkpoint
        written on a GPU can be restored on a machine without one.
        """
        # print('... loading checkpoint ...')
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        self.load_state_dict(torch.load(name, map_location=self.device))


#################################### Full Res Wrapper

class MultiObsAtariPreprocessing(gym.wrappers.AtariPreprocessing):
    """AtariPreprocessing variant whose observation is a pair.

    Each observation is ``(preprocessed_frame, rgb_frame)``: the first element is
    whatever the parent wrapper produces, the second is a 210x160x3 uint8 buffer
    filled via ``ale.getScreenRGB``.

    NOTE: the same ``rgb_obs`` array object is returned on every call, so callers
    that keep observations around need to copy it.
    """

    def __init__(self, env, **kwargs):
        super().__init__(env, **kwargs)

        full_res_shape = (210, 160, 3)
        processed_space = self.observation_space
        full_res_space = Box(low=0, high=255, shape=full_res_shape, dtype=np.uint8)

        # Reusable target buffer for the raw RGB screen.
        self.rgb_obs = np.zeros(full_res_shape, dtype=np.uint8)
        self.observation_space = Tuple((processed_space, full_res_space))

    def _get_obs(self):
        # Grab the raw screen first, then let the parent build its usual observation.
        self.ale.getScreenRGB(self.rgb_obs)
        processed = super()._get_obs()
        return processed, self.rgb_obs


class MultiObsAtariFrameStackObservation(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Frame-stacks the first element of a two-part observation.

    The wrapped environment must return observations of the form
    ``(frame_84x84, rgb_frame)``. This wrapper keeps the last ``stack_size`` 84x84
    frames (oldest first) and returns ``(stack, rgb_frame)``; the second element is
    passed through untouched.

    Args:
        env: Environment producing ``(84x84 uint8 frame, 210x160x3 uint8 frame)`` pairs.
        stack_size: Number of frames kept in the stack; must be a positive integer.

    Raises:
        TypeError: ``stack_size`` is not an integer type.
        ValueError: ``stack_size`` is not greater than zero.
    """

    def __init__(self, env: gym.Env, stack_size: int):
        gym.utils.RecordConstructorArgs.__init__(self, stack_size=stack_size)
        gym.Wrapper.__init__(self, env)

        if not np.issubdtype(type(stack_size), np.integer):
            raise TypeError(
                f"The stack_size is expected to be an integer, actual type: {type(stack_size)}"
            )
        if not 0 < stack_size:
            raise ValueError(
                f"The stack_size needs to be greater than zero, actual value: {stack_size}"
            )

        self.stack_size = stack_size
        self.obs_stack = np.zeros((stack_size, 84, 84), dtype=np.uint8)
        # The declared space follows stack_size so it matches the array actually returned
        # (it used to be hard-coded to 4 frames regardless of stack_size).
        self.observation_space = Tuple((Box(low=0, high=255, shape=(stack_size, 84, 84), dtype=np.uint8),
                                        Box(low=0, high=255, shape=(210, 160, 3), dtype=np.uint8)))

    def step(self, action):
        """Step the env and push the new 84x84 frame onto the end of the stack."""
        obs, reward, terminated, truncated, info = self.env.step(action)

        # np.roll returns a new array, so stacks handed out earlier are not modified.
        self.obs_stack = np.roll(self.obs_stack, -1, axis=0)
        self.obs_stack[-1] = obs[0]

        return (self.obs_stack, obs[1]), reward, terminated, truncated, info

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
        """Reset the env; the stack becomes all zeros except the newest slot."""
        obs, info = self.env.reset(seed=seed, options=options)

        self.obs_stack = np.zeros((self.stack_size, 84, 84), dtype=np.uint8)
        self.obs_stack[-1] = obs[0]

        return (self.obs_stack, obs[1]), info


################# Atari Processing Section (by default, this is effectively identical to that of gymnasium's wrapper)

# NOTE(review): "AtariPreprocessing" is not defined in the visible part of this file (the
# wrapper below is named AtariPreprocessingCustom). If it is not defined elsewhere in the
# module, `from <module> import *` raises AttributeError - TODO confirm and update.
__all__ = ["AtariPreprocessing"]


class AtariPreprocessingCustom(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Implements the common preprocessing techniques for Atari environments (excluding frame stacking).

    For frame stacking use :class:`gymnasium.wrappers.FrameStackObservation`.
    No vector version of the wrapper exists

    This class follows the guidelines in Machado et al. (2018),
    "Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents".

    Specifically, the following preprocess stages applies to the atari environment:

    - Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, default max 30 no-ops.
    - Frame skipping: The number of frames skipped between steps, 4 by default.
    - Max-pooling: Pools over the most recent two observations from the frame skips.
    - Termination signal when a life is lost: When the agent loses a life during the environment, then the environment is terminated.
        Turned off by default. Not recommended by Machado et al. (2018).
    - Life information: Reports a lost life through ``info["lost_life"]`` without terminating the episode.
        Turned off by default.
    - Resize to a square image: Resizes the atari environment original observation shape from 210x160 to 84x84 by default.
    - Grayscale observation: Makes the observation greyscale, enabled by default.
    - Grayscale new axis: Extends the last channel of the observation such that the image is 3-dimensional, not enabled by default.
    - Scale observation: Whether to scale the observation between [0, 1) or [0, 255), not scaled by default.

    Example:
        >>> import gymnasium as gym # doctest: +SKIP
        >>> env = gym.make("ALE/Adventure-v5", frameskip=1) # doctest: +SKIP
        >>> env = AtariPreprocessingCustom(env, noop_max=10, frame_skip=4, screen_size=84, terminal_on_life_loss=True, grayscale_obs=False, grayscale_newaxis=False) # doctest: +SKIP

    Change logs:
     * Added in gym v0.12.2 (gym #1455)


     This version is slightly modified, to include the option to use life information.
     This however is NOT enabled by default, and should NEVER be used when comparing
     against published results.
    """

    def __init__(
            self,
            env: gym.Env,
            noop_max: int = 30,
            frame_skip: int = 4,
            screen_size: int = 84,
            terminal_on_life_loss: bool = False,
            life_information: bool = False,
            grayscale_obs: bool = True,
            grayscale_newaxis: bool = False,
            scale_obs: bool = False,
    ):
        """Wrapper for Atari 2600 preprocessing.

        Args:
            env (Env): The environment to apply the preprocessing
            noop_max (int): For No-op reset, the max number no-ops actions are taken at reset, to turn off, set to 0.
            frame_skip (int): The number of frames between new observation the agents observations effecting the frequency at which the agent experiences the game.
            screen_size (int): resize Atari frame.
            terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
                life is lost.
            life_information (bool): `if True` (and ``terminal_on_life_loss`` is False), :meth:`step()` sets
                ``info["lost_life"]`` to True on the step in which a life is lost and stops the frame-skip
                loop early. Ignored when ``terminal_on_life_loss`` is True.
            grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
                is returned.
            grayscale_newaxis (bool): `if True and grayscale_obs=True`, then a channel axis is added to
                grayscale observations to make them 3-dimensional.
            scale_obs (bool): if True, then observation normalized in range [0,1) is returned. It also limits memory
                optimization benefits of FrameStack Wrapper.

        Raises:
            DependencyNotInstalled: opencv-python package not installed
            ValueError: Disable frame-skipping in the original env
        """
        # life_information is recorded too, so the wrapper can be rebuilt from its
        # recorded constructor arguments without silently losing that option.
        gym.utils.RecordConstructorArgs.__init__(
            self,
            noop_max=noop_max,
            frame_skip=frame_skip,
            screen_size=screen_size,
            terminal_on_life_loss=terminal_on_life_loss,
            life_information=life_information,
            grayscale_obs=grayscale_obs,
            grayscale_newaxis=grayscale_newaxis,
            scale_obs=scale_obs,
        )
        gym.Wrapper.__init__(self, env)

        try:
            import cv2  # noqa: F401
        except ImportError:
            raise gym.error.DependencyNotInstalled(
                "opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
            )

        assert frame_skip > 0
        assert screen_size > 0
        assert noop_max >= 0
        if frame_skip > 1 and getattr(env.unwrapped, "_frameskip", None) != 1:
            raise ValueError(
                "Disable frame-skipping in the original env. Otherwise, more than one frame-skip will happen as through this wrapper"
            )
        self.noop_max = noop_max
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

        self.frame_skip = frame_skip
        self.screen_size = screen_size
        self.terminal_on_life_loss = terminal_on_life_loss
        self.life_information = life_information
        self.grayscale_obs = grayscale_obs
        self.grayscale_newaxis = grayscale_newaxis
        self.scale_obs = scale_obs

        # buffer of most recent two observations for max pooling
        assert isinstance(env.observation_space, Box)
        if grayscale_obs:
            self.obs_buffer = [
                np.empty(env.observation_space.shape[:2], dtype=np.uint8),
                np.empty(env.observation_space.shape[:2], dtype=np.uint8),
            ]
        else:
            self.obs_buffer = [
                np.empty(env.observation_space.shape, dtype=np.uint8),
                np.empty(env.observation_space.shape, dtype=np.uint8),
            ]

        self.lives = 0
        self.game_over = False

        _low, _high, _obs_dtype = (
            (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
        )
        _shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
        if grayscale_obs and not grayscale_newaxis:
            _shape = _shape[:-1]  # Remove channel axis
        self.observation_space = Box(
            low=_low, high=_high, shape=_shape, dtype=_obs_dtype
        )

    @property
    def ale(self):
        """Make ale as a class property to avoid serialization error."""
        return self.env.unwrapped.ale

    def step(
            self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Applies the preprocessing for an :meth:`env.step`.

        Repeats ``action`` for up to ``frame_skip`` emulator frames, summing the rewards.
        ``info`` is the info dict of the last frame executed, with a ``"lost_life"``
        entry added by this wrapper.
        """
        total_reward, terminated, truncated, info = 0.0, False, False, {}

        for t in range(self.frame_skip):
            _, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            self.game_over = terminated

            # using agent terminal without resetting env require terminal_on_life_loss == False

            if self.terminal_on_life_loss:
                new_lives = self.ale.lives()
                terminated = terminated or new_lives < self.lives
                self.game_over = terminated
                self.lives = new_lives
                info["lost_life"] = False
            elif self.life_information:
                # code allows telling agent if they lost a life, so they can learn from terminal
                new_lives = self.ale.lives()
                info["lost_life"] = new_lives < self.lives
                self.lives = new_lives

                # added this line to not finish the 4 actions and overwrite info!
                # NOTE: leaving here skips the screen captures below, so the returned
                # observation is built from the previously captured buffers.
                if info["lost_life"]:
                    break
            else:
                info["lost_life"] = False

            if terminated or truncated:
                break
            # Capture the last two frames of the skip window for max-pooling.
            if t == self.frame_skip - 2:
                if self.grayscale_obs:
                    self.ale.getScreenGrayscale(self.obs_buffer[1])
                else:
                    self.ale.getScreenRGB(self.obs_buffer[1])
            elif t == self.frame_skip - 1:
                if self.grayscale_obs:
                    self.ale.getScreenGrayscale(self.obs_buffer[0])
                else:
                    self.ale.getScreenRGB(self.obs_buffer[0])
        return self._get_obs(), total_reward, terminated, truncated, info

    def reset(
            self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the environment using preprocessing."""
        # NoopReset
        _, reset_info = self.env.reset(seed=seed, options=options)

        noops = (
            self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
            if self.noop_max > 0
            else 0
        )
        for _ in range(noops):
            _, _, terminated, truncated, step_info = self.env.step(0)
            reset_info.update(step_info)
            if terminated or truncated:
                _, reset_info = self.env.reset(seed=seed, options=options)

        self.lives = self.ale.lives()
        if self.grayscale_obs:
            self.ale.getScreenGrayscale(self.obs_buffer[0])
        else:
            self.ale.getScreenRGB(self.obs_buffer[0])
        # Only one frame is available after reset; zero the other so max-pooling is a no-op.
        self.obs_buffer[1].fill(0)

        return self._get_obs(), reset_info

    def _get_obs(self):
        """Max-pool the two buffered frames, resize, and apply scaling/axis options."""
        if self.frame_skip > 1:  # more efficient in-place pooling
            np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])

        import cv2

        obs = cv2.resize(
            self.obs_buffer[0],
            (self.screen_size, self.screen_size),
            interpolation=cv2.INTER_AREA,
        )

        if self.scale_obs:
            obs = np.asarray(obs, dtype=np.float32) / 255.0
        else:
            obs = np.asarray(obs, dtype=np.uint8)

        if self.grayscale_obs and self.grayscale_newaxis:
            obs = np.expand_dims(obs, axis=-1)  # Add a channel axis
        return obs


################# Now Entering the Prioritized Experience Replay Section

# SumTree
# a binary tree data structure where the parent’s value is the sum of its children
class SumTree:
    """Array-backed binary tree in which every internal node stores the sum of its children.

    Leaves live at ``tree_start .. tree_start + size - 1`` and hold the priorities;
    ``sum_tree[0]`` is therefore the total priority. Leaves are written in ring-buffer
    order by :meth:`append`.

    Args:
        size: Number of leaves (capacity).
        procgen: Accepted for interface compatibility; unused by this class.
    """

    def __init__(self, size, procgen=False):
        self.index = 0
        self.size = size
        self.full = False  # Used to track actual capacity
        self.tree_start = 2 ** (size - 1).bit_length() - 1  # Put all used node leaves on last tree level
        node_count = self.tree_start + self.size
        # A parent always reads both children. If the last real leaf is a left child
        # (odd capacity) its right sibling must exist too, otherwise propagating from
        # that leaf indexes past the end of the array. The padding leaf stays 0.
        if node_count % 2 == 0:
            node_count += 1
        self.sum_tree = np.zeros((node_count,), dtype=np.float32)
        self.max = 1  # Initial max value to return (1 = 1^ω)

    # Updates nodes values from current tree
    def _update_nodes(self, indices):
        children_indices = indices * 2 + np.expand_dims([1, 2], axis=1)
        self.sum_tree[indices] = np.sum(self.sum_tree[children_indices], axis=0)

    # Propagates changes up tree given tree indices (all on the same tree level)
    def _propagate(self, indices):
        # Index 0 is the root: nothing above it to refresh.
        while indices[0] > 0:
            parents = (indices - 1) // 2
            self._update_nodes(np.unique(parents))
            indices = parents

    # Propagates single value up tree given a tree index for efficiency
    def _propagate_index(self, index):
        while index > 0:
            parent = (index - 1) // 2
            self.sum_tree[parent] = self.sum_tree[2 * parent + 1] + self.sum_tree[2 * parent + 2]
            index = parent

    # Updates values given tree indices
    def update(self, indices, values):
        if np.size(indices) == 0:
            return  # empty batch: nothing to write, and np.max([]) would raise
        self.sum_tree[indices] = values  # Set new values
        self._propagate(indices)  # Propagate values
        current_max_value = np.max(values)
        self.max = max(current_max_value, self.max)

    # Updates single value given a tree index for efficiency
    def _update_index(self, index, value):
        self.sum_tree[index] = value  # Set new value
        self._propagate_index(index)  # Propagate value
        self.max = max(value, self.max)

    def append(self, value):
        """Write ``value`` into the next ring-buffer leaf (also tracks the running max)."""
        self._update_index(self.index + self.tree_start, value)  # Update tree (and max)
        self.index = (self.index + 1) % self.size  # Update index
        self.full = self.full or self.index == 0  # Save when capacity reached

    # Searches for the location of values in sum tree
    def _retrieve(self, indices, values):
        last_leaf = self.tree_start + self.size - 1
        while indices.size:
            children_indices = (indices * 2 + np.expand_dims([1, 2], axis=1))  # Make matrix of children indices
            # All indices share one level; a leaf's left child lies beyond 2 * tree_start
            if children_indices[0, 0] > 2 * self.tree_start:
                break
            # If children indices correspond to leaf nodes, bound rare outliers in case total slightly overshoots
            if children_indices[0, 0] >= self.tree_start:
                children_indices = np.minimum(children_indices, last_leaf)
            left_children_values = self.sum_tree[children_indices[0]]
            successor_choices = np.greater(values, left_children_values).astype(
                np.int32)  # Classify which values are in left or right branches
            # Use classification to index into the indices matrix
            indices = children_indices[successor_choices, np.arange(indices.size)]
            # Subtract the left branch values when searching in the right branch
            values = values - successor_choices * left_children_values
        return indices

    # Searches for values in sum tree and returns values, data indices and tree indices
    def find(self, values):
        indices = self._retrieve(np.zeros(values.shape, dtype=np.int32), values)
        data_index = indices - self.tree_start
        return (self.sum_tree[indices], data_index, indices)  # Return values, data indices, tree indices

    def total(self):
        """Return the sum of all leaf values (the root node)."""
        return self.sum_tree[0]


def save_image(tensor, filename):
    """
    Debug helper: print the sum of all values of an image tensor.

    Despite its name, this function does not write anything to disk and
    ``filename`` is unused. The tensor is expected to be 3-dimensional
    (it is permuted from (C, H, W) to (H, W, C) before summing).
    """
    # Move to CPU and go channel-last before handing the data to numpy
    channel_last = tensor.cpu().permute(1, 2, 0)
    pixel_total = np.sum(channel_last.numpy())
    print(f"Image number: {pixel_total}")


class PER:
    """Prioritised experience replay that stores every frame once and transitions as pointers.

    Frames, actions, rewards and termination flags live in flat ring buffers
    (``state_mem``, ``action_mem``, ``reward_mem``, ``done_mem``, ``trun_mem``).
    A transition is a record of integer indices into those buffers
    (``pointer_mem``), and its sampling priority is kept in ``self.st`` (a ``SumTree``).
    Each environment stream has its own sliding windows of pending indices
    (``state_buffer``, ``reward_buffer``, ``encodings_buffer``).
    """

    def __init__(self, size, device, n, envs, gamma, alpha=0.2, beta=0.4, framestack=4, imagex=84, imagey=84, rgb=False,
                 context=10):
        """
        Args:
            size: Maximum number of transitions (pointer records / sum-tree entries).
            device: Torch device the sampled batch is moved to.
            n: Number of steps used for n-step returns.
            envs: Number of parallel environment streams.
            gamma: Discount factor used when accumulating n-step rewards.
            alpha: Exponent applied to priorities before they enter the sum tree.
            beta: Importance-sampling exponent (see the note in ``append_pointer``).
            framestack: Number of frames stacked into one observation.
            imagex, imagey: Frame dimensions.
            rgb: Store 3-channel frames instead of single-channel ones.
            context: Number of consecutive stacked observations returned per sample.
        """
        self.st = SumTree(size)
        self.data = [None] * size
        self.index = 0
        self.size = size

        # This is the number of frames, not the number of transitions.
        # The technical size needed to ensure there are no errors with overwritten memory is, in theory, very high:
        # (2*framestack - overlap) * first_states + non_first_states
        # With N=3, framestack=4, size=1M, average ep length 20, we need a total frame storage of around 1.35M.
        # This however is still pretty light given it uses discrete memory. Be careful with RGB, where a larger
        # multiplier is used.
        if rgb:
            self.storage_size = int(size * 4)
        else:
            self.storage_size = int(size * 1.25)
        self.gamma = gamma
        self.capacity = 0

        self.point_mem_idx = 0

        self.state_mem_idx = 0
        self.reward_mem_idx = 0
        self.encoding_mem_idx = 0

        self.imagex = imagex
        self.imagey = imagey

        self.max_prio = 1

        self.framestack = framestack

        self.alpha = alpha
        self.beta = beta
        self.eps = 1e-6  # small constant to stop 0 probability
        self.device = device

        # Every stream starts as "just terminated" so its first append stores a full frame stack.
        self.last_terminal = [True for _ in range(envs)]
        self.tstep_counter = [0 for _ in range(envs)]

        self.context = context

        self.n_step = n
        self.state_buffer = [[] for _ in range(envs)]
        self.reward_buffer = [[] for _ in range(envs)]
        self.encodings_buffer = [[] for _ in range(envs)]

        if rgb:
            self.state_mem = np.zeros((self.storage_size, 3, self.imagex, self.imagey), dtype=np.uint8)
        else:
            self.state_mem = np.zeros((self.storage_size, self.imagex, self.imagey), dtype=np.uint8)
        self.action_mem = np.zeros(self.storage_size, dtype=np.int64)
        self.reward_mem = np.zeros(self.storage_size, dtype=float)
        self.done_mem = np.zeros(self.storage_size, dtype=bool)
        self.trun_mem = np.zeros(self.storage_size, dtype=bool)

        # Everything here is stored as ints as they are just pointers to the actual memory.
        # 'reward' contains N pointers; the first one also addresses the action, and the set of N addresses
        # both the rewards and the done/truncation flags.
        encodings_len = self._encodings_len()
        self.trans_dtype = np.dtype([('state', int, self.framestack), ('n_state', int, self.framestack),
                                     ('reward', int, self.n_step),
                                     ('encodings', int, encodings_len)])

        self.blank_trans = (np.zeros(self.framestack, dtype=int), np.zeros(self.framestack, dtype=int),
                            np.zeros(self.n_step, dtype=int),
                            np.zeros(encodings_len, dtype=int))

        # All-zero records; equivalent to replicating ``blank_trans`` ``size`` times
        # without building a temporary list of ``size`` tuples.
        self.pointer_mem = np.zeros(size, dtype=self.trans_dtype)

        self.overlap = self.framestack - self.n_step

    def _encodings_len(self):
        """Number of frame pointers stored per transition in the 'encodings' field."""
        return self.context + self.n_step + self.framestack - 1

    def append(self, state, action, reward, n_state, done, trun, stream, prio=True):
        """Record one environment step for ``stream`` and emit any transitions that became complete.

        When the step ends the episode (``done`` or ``trun``), the remaining
        partial windows are flushed and the stream's buffers are cleared.
        """
        # append to memory
        self.append_memory(state, action, reward, n_state, done, trun, stream)

        # append to pointer
        self.append_pointer(stream, prio)

        if done or trun:
            self.finalize_experiences(stream)
            self.state_buffer[stream] = []
            self.reward_buffer[stream] = []
            self.encodings_buffer[stream] = []

        self.last_terminal[stream] = done or trun

    def _commit_transition(self, states, n_states, rewards, encodings):
        """Write one pointer record at the write head, give it max priority and advance the head."""
        self.pointer_mem[self.point_mem_idx] = (
            np.array(states, dtype=int), np.array(n_states, dtype=int),
            np.array(rewards, dtype=int), np.array(encodings, dtype=int))

        self.st.append(self.max_prio ** self.alpha)

        self.capacity = min(self.size, self.capacity + 1)
        self.point_mem_idx = (self.point_mem_idx + 1) % self.size

    def append_pointer(self, stream, prio):
        """Emit every full n-step transition currently available in the stream's windows.

        ``prio`` is accepted for interface compatibility and is not used.
        """
        states = self.state_buffer[stream]
        rewards = self.reward_buffer[stream]
        encodings = self.encodings_buffer[stream]
        encodings_len = self._encodings_len()

        while len(states) >= self.framestack + self.n_step and len(rewards) >= self.n_step:
            self._commit_transition(
                states[:self.framestack],                          # current observation
                states[self.n_step:self.n_step + self.framestack],  # observation N frames later
                rewards[:self.n_step],                             # first N rewards
                encodings[:encodings_len])

            # Remove the oldest entry from each buffer to slide the window
            states.pop(0)
            rewards.pop(0)
            encodings.pop(0)
            # NOTE(review): this forces beta to 0 as soon as one transition is stored, which makes the
            # importance-sampling weights in ``sample`` all equal to 1 regardless of the constructor
            # argument. Kept as-is to preserve behaviour - confirm whether this is intentional.
            self.beta = 0

    def finalize_experiences(self, stream):
        """Flush the shortened transitions left in the buffers at the end of an episode.

        Missing reward/encoding pointers are padded with index 0.
        """
        states = self.state_buffer[stream]
        rewards = self.reward_buffer[stream]
        encodings = self.encodings_buffer[stream]
        encodings_len = self._encodings_len()

        while len(states) >= self.framestack and len(rewards) > 0:
            # Zero-pad at the end; the padded encodings belong to the next-state part of the window.
            padded_encodings = encodings + [0] * (encodings_len - len(encodings))
            padded_rewards = rewards + [0] * (self.n_step - len(rewards))

            self._commit_transition(
                states[:self.framestack],
                states[-self.framestack:],  # next state is the final `framestack` frames of the episode
                padded_rewards,
                padded_encodings)

            # Remove the oldest entry from each buffer to slide the window
            states.pop(0)
            encodings.pop(0)
            rewards.pop(0)

    def _store_frame(self, frame, stream):
        """Write one frame to ``state_mem`` and register its index in the stream's state/encoding windows."""
        self.state_mem[self.state_mem_idx] = frame
        self.state_buffer[stream].append(self.state_mem_idx)
        self.encodings_buffer[stream].append(self.state_mem_idx)
        self.state_mem_idx = (self.state_mem_idx + 1) % self.storage_size

    def _store_step(self, action, reward, done, trun, stream):
        """Write action/reward/flags to their ring buffers and register the index in the reward window."""
        self.action_mem[self.reward_mem_idx] = action
        self.reward_mem[self.reward_mem_idx] = reward
        self.done_mem[self.reward_mem_idx] = done
        self.trun_mem[self.reward_mem_idx] = trun

        self.reward_buffer[stream].append(self.reward_mem_idx)
        self.reward_mem_idx = (self.reward_mem_idx + 1) % self.storage_size

    def append_memory(self, state, action, reward, n_state, done, trun, stream):
        """Store the raw data of one step.

        On the first step after an episode end the whole frame stack of ``state``
        is stored (and the encoding window is pre-filled); on later steps only the
        newest frame of ``n_state`` is new and needs storing.
        """
        if self.last_terminal[stream]:
            # Pre-fill the encoding window with `context - 1` copies of the index that the
            # first frame is about to receive.
            for _ in range(self.context - 1):
                self.encodings_buffer[stream].append(self.state_mem_idx)

            for i in range(self.framestack):
                self._store_frame(state[i], stream)

            self.tstep_counter[stream] = 0

        # remember n_step is not applied in this memory: only the newest frame is stored
        self._store_frame(n_state[self.framestack - 1], stream)
        self._store_step(action, reward, done, trun, stream)

    def sample(self, batch_size):
        """Draw a stratified, priority-proportional batch.

        Returns:
            ``(tree_idxs, states, actions, rewards, n_states, dones, weights)``.
            For single-channel storage ``states`` and ``n_states`` have shape
            (batch, context, framestack, imagex, imagey). All tensors except
            ``tree_idxs`` (a NumPy array) are placed on ``self.device``.

        Raises:
            RuntimeError: if the stratified sample positions cannot be drawn
                (e.g. the total priority is not a finite number).
        """
        # get total sumtree priority
        p_total = self.st.total()

        # first use sumtree prios to get the indices: one uniform draw per equal-mass segment
        segment_length = p_total / batch_size
        segment_starts = np.arange(batch_size) * segment_length
        try:
            samples = np.random.uniform(0.0, segment_length, [batch_size]) + segment_starts
        except Exception as e:
            raise RuntimeError(
                f"Could not draw stratified samples (segment_length={segment_length}, "
                f"segment_starts={segment_starts})") from e

        prios, idxs, tree_idxs = self.st.find(samples)

        probs = prios / p_total

        # fetch the pointers by using indices, then read whole fields at once
        pointers = self.pointer_mem[idxs]
        reward_pointers = pointers['reward']
        # the action of a transition is addressed by its first reward pointer
        if self.n_step > 1:
            action_pointers = reward_pointers[:, 0]
        else:
            action_pointers = reward_pointers
        encoding_pointers = pointers['encodings']

        # reward and dones just use the same pointer. actions just use the first one
        rewards = self.reward_mem[reward_pointers]
        dones = self.done_mem[reward_pointers]
        truns = self.trun_mem[reward_pointers]
        actions = self.action_mem[action_pointers]

        # Fancy indexing already produced a fresh uint8 array, so it can be wrapped without another copy.
        encodings = torch.from_numpy(self.state_mem[encoding_pointers])
        # Turn the run of single frames into overlapping stacks of `framestack` frames
        # (the buffer only stores each frame once).
        encodings = encodings.unfold(dimension=1, size=self.framestack, step=1)

        # unfold puts the stack dim at the end, move it to the middle.
        # NOTE(review): this 5-axis permute matches single-channel storage; the rgb layout has an
        # extra channel axis and does not look supported here - confirm.
        encodings = encodings.permute(0, 1, 4, 2, 3)
        states = encodings[:, :self.context, :, :, :]  # first `context` stacks along the sequence dimension
        n_states = encodings[:, self.n_step:, :, :, :]  # the same window shifted by n_step

        # apply n_step cumulation to rewards and dones
        if self.n_step > 1:
            rewards, dones = self.compute_discounted_rewards_batch(rewards, dones, truns)

        # Compute importance-sampling weights w
        weights = (self.capacity * probs) ** -self.beta

        weights = torch.tensor(weights / weights.max(), dtype=torch.float32,
                               device=self.device)  # Normalise by max importance-sampling weight from batch

        if torch.isnan(weights).any():
            print("Nan Found is sample!")
            print(f"Prios {prios}")
            print(f"Probs {probs}")
            print(f"Weights {weights}")

        # move to pytorch tensors on the target device
        states = states.to(torch.float32).to(self.device)
        n_states = n_states.to(torch.float32).to(self.device)
        rewards = torch.tensor(rewards, dtype=torch.float32, device=self.device)
        dones = torch.tensor(dones, dtype=torch.bool, device=self.device)
        actions = torch.tensor(actions, dtype=torch.int64, device=self.device)

        # return batch
        return tree_idxs, states, actions, rewards, n_states, dones, weights

    def compute_discounted_rewards_batch(self, rewards_batch, dones_batch, truns_batch):
        """
        Compute discounted n-step rewards for a batch.

        For every row, rewards are accumulated with discount ``gamma ** j`` up to and
        including the first step flagged as done or truncated; later steps are ignored.

        Parameters:
        rewards_batch (np.ndarray): 2D array of rewards with shape (batch_size, n_step)
        dones_batch (np.ndarray): 2D array of dones with shape (batch_size, n_step)
        truns_batch (np.ndarray): 2D array of truncation flags with shape (batch_size, n_step)

        Returns:
        np.ndarray: 1D array of discounted rewards for each element in the batch
        np.ndarray: 1D bool array, True where the accumulation stopped on a done step
                    (a step flagged only as truncated stops accumulation without setting it)
        """
        rewards_batch = np.asarray(rewards_batch, dtype=float)
        batch_size, n_step = rewards_batch.shape
        if n_step == 0:
            return np.zeros(batch_size), np.zeros(batch_size, dtype=bool)

        dones = np.asarray(dones_batch) == 1
        truns = np.asarray(truns_batch) == 1
        ended = dones | truns

        # A step contributes only if no done/truncation happened strictly before it.
        ends_before = np.cumsum(ended, axis=1) - ended
        contributes = ends_before == 0

        discounts = self.gamma ** np.arange(n_step)
        # np.where (not multiplication by the mask) so ignored steps cannot leak NaN/inf.
        discounted_rewards = np.where(contributes, rewards_batch * discounts, 0.0).sum(axis=1)

        # Index of the first ending step (0 when the row never ends, where dones is False anyway).
        first_end = np.argmax(ended, axis=1)
        cumulative_dones = dones[np.arange(batch_size), first_end]

        return discounted_rewards, cumulative_dones

    def update_priorities(self, idxs, priorities):
        """Set new priorities (plus ``self.eps``, raised to ``alpha``) for the given tree indices."""
        priorities = priorities + self.eps

        if np.isnan(priorities).any():
            print("NaN found in priority!")
            print(f"priorities: {priorities}")

        self.max_prio = max(self.max_prio, np.max(priorities))
        self.st.update(idxs, priorities ** self.alpha)


############## A few smaller functions to assist the main program

class EpsilonGreedy:
    """Epsilon-greedy exploration schedule over a fixed collection of actions."""

    def __init__(self, eps_start, eps_steps, eps_final, action_space):
        self.eps = eps_start
        self.steps = eps_steps
        self.eps_final = eps_final
        self.action_space = action_space

    def update_eps(self):
        """Move epsilon towards ``eps_final`` by 1/``steps`` of the remaining gap, never going below it."""
        decrement = (self.eps - self.eps_final) / self.steps
        self.eps = max(self.eps - decrement, self.eps_final)

    def choose_action(self):
        """Return a random action with probability ~epsilon, otherwise ``None`` (caller acts greedily)."""
        if np.random.random() > self.eps:
            return None
        return np.random.choice(self.action_space)


def randomise_action_batch(x, probs, n_actions):
    """Overwrite, in place, each element of ``x`` with a random action with probability ``probs``.

    Random actions are integers drawn from ``[0, n_actions)``. The (mutated)
    input tensor is returned.
    """
    # Decide per element whether it gets replaced.
    replace = torch.rand(x.shape).lt(probs)

    # Candidate replacement for every position; only the selected ones are used.
    candidates = torch.randint(0, n_actions, x.shape)

    x[replace] = candidates[replace]
    return x


def choose_eval_action(observation, eval_net, n_actions, device, rng):
    """Pick greedy actions for a batch of observations, optionally with random exploration.

    Args:
        observation: Batch of observations, converted to a float tensor on ``device``.
        eval_net: Network exposing ``qvals(state, advantages_only=True)``.
        n_actions: Number of discrete actions (upper bound for random actions).
        device: Device the observation tensor is moved to.
        rng: Probability with which each chosen action is replaced by a random one;
            values <= 0 disable randomisation.

    Returns:
        CPU tensor holding one action index per batch element.
    """
    with torch.no_grad():
        state = T.tensor(observation, dtype=T.float).to(device)
        qvals = eval_net.qvals(state, advantages_only=True)
        x = T.argmax(qvals, dim=1).cpu()

        if rng > 0.:
            # Use the caller-supplied probability (it used to be hard-coded to 0.01,
            # so `rng` only acted as an on/off switch).
            x = randomise_action_batch(x, rng, n_actions)

    return x


def create_network(input_dims, n_actions, device, linear_size, lstm_out_size=256):
    """Build a ``NatureC51`` network; its first argument is ``input_dims[0]``."""
    first_dim = input_dims[0]
    network = NatureC51(first_dim, n_actions, device=device, linear_size=linear_size,
                        lstm_out_size=lstm_out_size)
    return network


#################### The big ol agent class, be prepared

class Agent:
    """Distributional (C51-style) Double-DQN agent with prioritised replay and AMP training.

    Owns the online network (``net``), the target network (``tgt_net``), the
    replay memory (``memory``), the optimiser and the gradient scaler, and
    exposes action selection, transition storage and learning steps.
    """

    def __init__(self, n_actions, input_dims, device, num_envs, agent_name, total_frames, testing=False, batch_size=256,
                 rr=1, lr=1e-4, target_replace=500, discount=0.99, linear_size=512, non_factorised=False,
                 replay_period=1, framestack=4, rgb=False, imagex=84, imagey=84, per_alpha=0.5, max_mem_size=1048576,
                 eps_steps=2000000, n=3, grad_clip=10, lstm_out_size=256, context=40):
        """
        Args:
            n_actions: Number of discrete actions.
            input_dims: Observation dimensions; ``input_dims[0]`` is forwarded to the network.
            device: Torch device for the networks and sampled batches.
            num_envs: Number of parallel environments (replay streams).
            agent_name: Prefix used for checkpoint names.
            total_frames: Total training frames, used to derive ``total_grad_steps``.
            testing: Use a much smaller warm-up (``min_sampling_size``) when True.
            batch_size: Replay batch size.
            rr: Replay ratio; values below 1 mean one learn call every ``int(1 / rr)`` calls to ``learn``.
            lr: Adam learning rate.
            target_replace: Gradient steps between target-network syncs.
            discount: Discount factor gamma.
            linear_size, lstm_out_size: Network sizes forwarded to ``create_network``.
            non_factorised: Stored flag (not otherwise used in this class).
            replay_period: Environment steps per learn call, used for ``total_grad_steps``.
            framestack, rgb, imagex, imagey, context: Observation layout, forwarded to the replay memory.
            per_alpha: Priority exponent of the replay memory.
            max_mem_size: Replay capacity in transitions.
            eps_steps: Step count of the epsilon schedule.
            n: Number of steps for n-step returns.
            grad_clip: Max gradient norm.
        """
        self.per_alpha = per_alpha

        # Support of the categorical value distribution.
        self.Vmin = -10
        self.Vmax = 10
        self.N_ATOMS = 51

        self.scaler = torch.cuda.amp.GradScaler()

        self.procgen = bool(input_dims[1] == 64)
        self.grad_clip = grad_clip

        self.n_actions = n_actions
        self.input_dims = input_dims
        self.device = device
        self.agent_name = agent_name
        self.testing = testing

        self.loading_checkpoint = False

        self.per_beta = 0.45

        self.lstm_out_size = lstm_out_size

        self.replay_ratio = int(rr) if rr > 0.99 else float(rr)
        self.total_frames = total_frames
        self.num_envs = num_envs

        # Number of environment steps collected before learning starts.
        self.min_sampling_size = 4000 if self.testing else 80000

        self.lr = lr

        # this is the number of env steps per grad step
        self.replay_period = replay_period

        # replay ratio however does not take into account parallel envs
        # in this code, every {replay period} steps, we take {replay_ratio} grad steps
        self.total_grad_steps = (self.total_frames - self.min_sampling_size) / (self.replay_period / self.replay_ratio)

        self.priority_weight_increase = (1 - self.per_beta) / self.total_grad_steps

        self.action_space = list(range(self.n_actions))
        self.learn_step_counter = 0

        self.chkpt_dir = ""

        self.n = n
        self.gamma = discount
        self.discount_anneal = False
        self.batch_size = batch_size

        # this option is only available for non-impala. I could add it, but factorised seemed
        # to perform the same and is faster
        self.non_factorised = non_factorised

        # 1 Million rounded to the nearest power of 2 for tree implementation
        self.max_mem_size = max_mem_size

        # This is the number of grad steps - could be a little jank
        # when changing num_envs/batch size/replay ratio.
        # Best used value is 32000 frames per replace. For bs 256, this is 500. For bs 16, this is every 8000!
        self.replace_target_cnt = target_replace

        # Epsilon-greedy is configured with start = final = 0, for training, testing
        # and checkpoint loading alike.
        self.eps_start = 0.0
        self.eps_steps = eps_steps
        self.eps_final = 0.00

        self.epsilon = EpsilonGreedy(self.eps_start, self.eps_steps, self.eps_final, self.action_space)

        self.linear_size = linear_size
        self.imagex = imagex
        self.imagey = imagey

        self.framestack = framestack
        self.rgb = rgb
        self.context = context
        self.memory = PER(self.max_mem_size, device, self.n, num_envs, self.gamma, alpha=self.per_alpha,
                          beta=self.per_beta, framestack=self.framestack, rgb=self.rgb, imagex=imagex, imagey=imagey,
                          context=self.context)

        self.network_creator_fn = partial(create_network, self.input_dims, self.n_actions, self.device,
                                          self.linear_size, lstm_out_size=self.lstm_out_size)

        self.net = self.network_creator_fn()
        self.tgt_net = self.network_creator_fn()

        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr, eps=0.005 / self.batch_size)  # 0.00015

        self.net.train()

        self.eval_net = None

        # The target network is only ever updated by copying weights.
        for param in self.tgt_net.parameters():
            param.requires_grad = False

        self.env_steps = 0
        self.grad_steps = 0

        self.replay_ratio_cnt = 0
        self.eval_mode = False

    def prep_evaluation(self):
        """Snapshot the online network into ``eval_net`` and disable its noisy layers."""
        self.eval_net = deepcopy(self.net)
        self.disable_noise(self.eval_net)

    @torch.no_grad()
    def reset_noise(self, net):
        """Call ``reset_noise`` on every ``FactorizedNoisyLinear`` module of ``net``."""
        for m in net.modules():
            if isinstance(m, FactorizedNoisyLinear):
                m.reset_noise()

    @torch.no_grad()
    def disable_noise(self, net):
        """Call ``disable_noise`` on every ``FactorizedNoisyLinear`` module of ``net``."""
        for m in net.modules():
            if isinstance(m, FactorizedNoisyLinear):
                m.disable_noise()

    @torch.autocast(device_type="cuda")
    def choose_action(self, observation):
        """Return the greedy action (CPU tensor) for each element of a batch of observations.

        Works with a batch of 1 as well. Noise is resampled first unless ``eval_mode`` is set.
        """
        with T.no_grad():
            if not self.eval_mode:
                self.reset_noise(self.net)

            state = T.tensor(observation, dtype=T.float).to(self.net.device)

            qvals = self.net.qvals(state, advantages_only=True)
            x = T.argmax(qvals, dim=1).cpu()

            return x

    def store_transition(self, state, action, reward, next_state, done, trun, stream, prio=True):
        """Append one environment step of ``stream`` to the replay memory and advance the step counters."""
        if self.rgb:
            # expand dims to create "framestack" dim, so it works with my replay buffer
            state = np.expand_dims(state, axis=0)
            next_state = np.expand_dims(next_state, axis=0)

        self.memory.append(state, action, reward, next_state, done, trun, stream, prio=prio)

        self.epsilon.update_eps()
        self.env_steps += 1

    def replace_target_network(self):
        """Copy the online network's weights into the target network."""
        self.tgt_net.load_state_dict(self.net.state_dict())

    def save_model(self):
        """Checkpoint the online network; the name suffix counts blocks of 250k env steps, labelled "M"."""
        self.net.save_checkpoint(self.agent_name + "_" + str(int((self.env_steps // 250000))) + "M")

    def load_models(self, name):
        """Load the checkpoint ``name`` into both the online and the target network."""
        self.net.load_checkpoint(name)
        self.tgt_net.load_checkpoint(name)

    def learn(self):
        """Run the scheduled number of learn calls for one invocation.

        With a replay ratio below 1, one learn call is made every
        ``int(1 / replay_ratio)`` invocations; otherwise ``replay_ratio`` learn
        calls are made per invocation.
        """
        if self.replay_ratio < 1:
            if self.replay_ratio_cnt == 0:
                self.learn_call()
            self.replay_ratio_cnt = (self.replay_ratio_cnt + 1) % (int(1 / self.replay_ratio))
        else:
            for _ in range(self.replay_ratio):
                self.learn_call()

    @torch.autocast(device_type="cuda")
    def learn_call(self):
        """Perform one gradient step on a sampled batch (no-op until ``min_sampling_size`` env steps)."""
        if self.env_steps < self.min_sampling_size:
            return

        self.reset_noise(self.tgt_net)

        if self.grad_steps % self.replace_target_cnt == 0:
            self.replace_target_network()

        idxs, states, actions, rewards, next_states, dones, weights = self.memory.sample(self.batch_size)

        self.optimizer.zero_grad()

        # Tip: when applying this agent to a custom env, visualise a few sampled states first
        # (e.g. plt.imshow on individual frames of `states`) to check they are correct.

        distr_v, qvals_v = self.net.both(states)
        state_action_values = distr_v[range(self.batch_size), actions.data]
        state_log_sm_v = F.log_softmax(state_action_values, dim=1)

        with torch.no_grad():
            # Double DQN: the online network selects the next action, the target network evaluates it.
            next_distr_v, next_qvals_v = self.tgt_net.both(next_states)
            action_distr_v, action_qvals_v = self.net.both(next_states)

            next_actions_v = action_qvals_v.max(1)[1]

            next_best_distr_v = next_distr_v[range(self.batch_size), next_actions_v.data]
            next_best_distr_v = self.tgt_net.apply_softmax(next_best_distr_v)
            next_best_distr = next_best_distr_v.data.cpu()

            # Rewards are already n-step returns, hence the gamma ** n bootstrap discount.
            proj_distr = distr_projection(next_best_distr, rewards.cpu(), dones.cpu(), self.Vmin, self.Vmax,
                                          self.N_ATOMS, self.gamma ** self.n)

            proj_distr_v = proj_distr.to(self.net.device)

        # Cross-entropy between the projected target distribution and the prediction.
        loss_v = -state_log_sm_v * proj_distr_v

        weights = T.squeeze(weights)
        loss_v = weights.to(self.net.device) * loss_v.sum(dim=1)

        loss = loss_v.mean()

        # update PER prios with the per-sample (weighted) losses
        self.memory.update_priorities(idxs, loss_v.cpu().detach().numpy())

        ###### pytorch AMP
        self.scaler.scale(loss).backward()

        # Unscale the gradients before clipping - Gradients get completely zeroed without this
        self.scaler.unscale_(self.optimizer)

        torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.grad_clip)
        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.grad_steps += 1
        if self.grad_steps % 10000 == 0:
            print("Completed " + str(self.grad_steps) + " gradient steps")


def distr_projection(next_distr, rewards, dones, Vmin, Vmax, n_atoms, gamma):
    """
    Perform distribution projection aka Categorical Algorithm from the
    "A Distributional Perspective on RL" paper.

    Each atom of the next-state distribution is shifted by the reward, scaled by
    ``gamma``, clamped to [Vmin, Vmax] and its probability mass is split between
    the two neighbouring atoms of the fixed support. For terminal transitions the
    target collapses to the (clamped) reward alone.

    Args:
        next_distr: Tensor indexed as [batch, atom] with next-state atom probabilities.
        rewards: 1D tensor of rewards, one per batch element.
        dones: 1D bool tensor marking terminal transitions.
        Vmin, Vmax: Bounds of the value support.
        n_atoms: Number of atoms in the support.
        gamma: Discount applied to the atom values.

    Returns:
        float32 tensor of shape (batch, n_atoms) with the projected distribution.
    """
    batch_size = len(rewards)
    proj_distr = T.zeros((batch_size, n_atoms), dtype=T.float32)
    delta_z = (Vmax - Vmin) / (n_atoms - 1)

    # Native torch ops are used throughout: the inputs are tensors, and NumPy ufuncs
    # would only work on them through implicit array wrapping.
    for atom in range(n_atoms):
        tz_j = T.clamp(rewards + (Vmin + atom * delta_z) * gamma, min=Vmin, max=Vmax)
        b_j = (tz_j - Vmin) / delta_z  # fractional atom position of the shifted value
        l = T.floor(b_j).to(T.int64)
        u = T.ceil(b_j).to(T.int64)

        # Landed exactly on an atom: all mass goes there.
        eq_mask = u == l
        proj_distr[eq_mask, l[eq_mask]] += next_distr[eq_mask, atom]

        # Otherwise split the mass between the lower and upper neighbour by distance.
        ne_mask = u != l
        proj_distr[ne_mask, l[ne_mask]] += next_distr[ne_mask, atom] * (u - b_j)[ne_mask]
        proj_distr[ne_mask, u[ne_mask]] += next_distr[ne_mask, atom] * (b_j - l)[ne_mask]

    if dones.any():
        # Terminal transitions: no bootstrapping, the target is the reward itself.
        proj_distr[dones] = 0.0
        tz_j = T.clamp(rewards[dones], min=Vmin, max=Vmax)
        b_j = (tz_j - Vmin) / delta_z
        l = T.floor(b_j).to(T.int64)
        u = T.ceil(b_j).to(T.int64)

        # eq_mask / ne_mask index the terminal subset; scatter them back to full-batch masks.
        eq_mask = u == l
        eq_dones = T.clone(dones)
        eq_dones[dones] = eq_mask
        if eq_dones.any():
            proj_distr[eq_dones, l[eq_mask]] = 1.0

        ne_mask = u != l
        ne_dones = T.clone(dones)
        ne_dones[dones] = ne_mask
        if ne_dones.any():
            proj_distr[ne_dones, l[ne_mask]] = (u - b_j)[ne_mask]
            proj_distr[ne_dones, u[ne_mask]] = (b_j - l)[ne_mask]
    return proj_distr


def calculate_huber_loss(td_errors, k=1.0, taus=8):
    """
    Element-wise Huber loss with threshold kappa ``k``.

    Quadratic (0.5 * e^2) where |e| <= k, linear (k * (|e| - 0.5 * k)) elsewhere.
    Asserts that the result has shape (batch, taus, taus).
    """
    magnitude = td_errors.abs()
    quadratic_part = 0.5 * td_errors.pow(2)
    linear_part = k * (magnitude - 0.5 * k)
    loss = torch.where(magnitude <= k, quadratic_part, linear_part)

    expected_shape = (td_errors.shape[0], taus, taus)
    assert loss.shape == expected_shape, "huber loss has wrong shape"
    return loss


def huber_loss(td_errors, k=1.0):
    """
    Element-wise Huber loss with threshold kappa ``k`` (no shape check).

    Quadratic (0.5 * e^2) where |e| <= k, linear (k * (|e| - 0.5 * k)) elsewhere.
    """
    abs_errors = td_errors.abs()
    within_threshold = abs_errors <= k
    return torch.where(
        within_threshold,
        0.5 * td_errors.pow(2),
        k * (abs_errors - 0.5 * k),
    )


def make_env(envs_create, game, life_info, framestack, repeat_probs):
    """Create an ``AsyncVectorEnv`` (spawn context) of ``envs_create`` frame-stacked ALE environments.

    Each sub-environment is ``ALE/<game>-v5`` with ``frameskip=1``, wrapped in
    ``AtariPreprocessingCustom`` and then ``FrameStack`` (without lz4 compression).
    """
    env_id = "ALE/" + game + "-v5"

    def build_single_env():
        base_env = gym.make(env_id, frameskip=1, repeat_action_probability=repeat_probs)
        preprocessed = AtariPreprocessingCustom(base_env, life_information=life_info)
        return gym.wrappers.FrameStack(preprocessed, framestack, lz4_compress=False)

    env_factories = [build_single_env for _ in range(envs_create)]
    return gym.vector.AsyncVectorEnv(env_factories, context="spawn")


def non_default_args(args, parser):
    """Build an underscore-joined tag of all arguments that differ from their parser defaults.

    Each part is ``<name><value>``. Arguments whose default is "NameThisGame" and a
    fixed set of evaluation/analysis flags are left out; ``repeat`` is always placed last.
    """
    ignored_names = ["include_evals", "eval_envs", "num_eval_episodes", "analy", "save_allll"]

    parts = [
        f"{name}{value}"
        for name, value in vars(args).items()
        if name != 'repeat'  # handled separately so it ends up at the end
        and value != parser.get_default(name)
        and parser.get_default(name) != "NameThisGame"
        and name not in ignored_names
    ]

    # 'repeat' goes last, and only when it was changed from its default
    repeat_val = getattr(args, 'repeat', None)
    if repeat_val != parser.get_default('repeat'):
        parts.append(f"repeat{repeat_val}")

    return '_'.join(parts)


def format_arguments(arg_string):
    """
    Compact an argument description string for use in a run name.

    Applies, in this order: drop every ``=``, turn ``True`` into ``1``,
    turn ``False`` into ``0`` and turn ``", "`` separators into ``_``.

    Args:
        arg_string (str): Text to compact.

    Returns:
        str: The compacted text.
    """
    # Order matters: each substitution runs on the output of the previous one.
    substitutions = (
        ('=', ''),
        ('True', '1'),
        ('False', '0'),
        (', ', '_'),
    )
    for old, new in substitutions:
        arg_string = arg_string.replace(old, new)
    return arg_string


def evaluate_agent(net_state_dict, network_creator, eval_envs, num_eval_episodes, agent_name, testing, game, life_info,
                   n_actions, device, index, framestack, repeat_probs):
    """
    Play ``num_eval_episodes`` evaluation episodes and record their scores.

    A fresh vector environment and a fresh network are created, the given
    weights are loaded, and episodes are played until enough have finished.
    Unless ``testing`` is set, row ``index`` of ``<agent_name>Evaluation.npy``
    (in the current working directory) is overwritten with the scores.

    Args:
        net_state_dict (dict): Network weights; each tensor is moved to ``device``.
        network_creator (callable): Zero-argument factory returning the network.
        eval_envs (int): Number of parallel evaluation environments.
        num_eval_episodes (int): Number of finished episodes to collect.
        agent_name (str): Prefix of the evaluation results file.
        testing (bool): If true, results are not written to disk.
        game (str): Atari game name forwarded to ``make_env``.
        life_info: Forwarded to ``make_env``.
        n_actions (int): Number of actions, forwarded to ``choose_eval_action``.
        device (torch.device): Device the weights are moved to.
        index (int): Zero-based evaluation number; selects the row to write.
        framestack (int): Frame-stack depth forwarded to ``make_env``.
        repeat_probs (float): Sticky-action probability forwarded to ``make_env``.
    """
    eval_env = make_env(eval_envs, game, life_info, framestack, repeat_probs)
    try:
        evals = []
        eval_episodes = 0
        # Float accumulator: assigning into an integer array would silently
        # truncate any fractional reward on every `+=`.
        eval_scores = np.zeros(eval_envs, dtype=np.float64)
        eval_observation, eval_info = eval_env.reset()

        eval_net = network_creator()

        # move state dict to gpu - pytorch doesn't allow sharing across threads on gpu
        state_dict_gpu = {k: v.to(device) for k, v in net_state_dict.items()}

        eval_net.load_state_dict(state_dict_gpu)

        # this massively helps speed up training since agents get stuck in some games, causing evals to last a very
        # long time
        # NOTE(review): rng is presumably the random-action probability used by choose_eval_action - confirm.
        rng = 0.01 if index <= 125 else 0.0

        while eval_episodes < num_eval_episodes:
            eval_action = choose_eval_action(eval_observation, eval_net, n_actions, device, rng)
            eval_observation_, eval_reward, eval_done_, eval_trun_, eval_info = eval_env.step(eval_action)
            eval_done_ = np.logical_or(eval_done_, eval_trun_)

            for i in range(eval_envs):
                eval_scores[i] += eval_reward[i]
                if eval_done_[i]:
                    eval_episodes += 1
                    evals.append(eval_scores[i])
                    eval_scores[i] = 0
                    # Stop as soon as the quota is met so that exactly
                    # num_eval_episodes scores are recorded.
                    if eval_episodes >= num_eval_episodes:
                        break

            eval_observation = eval_observation_
    finally:
        # Always shut down the worker processes of the vector environment,
        # even if evaluation fails part-way through.
        eval_env.close()

    if not testing:
        fname = agent_name + "Evaluation.npy"
        data = np.load(fname)

        # Update the specified index in the 0th dimension
        data[index] = evals
        print("Evaluation " + str(index + 1) + "M Complete, average score:")
        print(np.mean(evals))

        # Save the updated array back to the file
        np.save(fname, data)


def _parse_bool_flag(value):
    """Convert a command-line string into a bool (for use as argparse ``type=``)."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays False, matching what bool("") used to give.
    if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """
    Command-line entry point: train the agent and periodically evaluate it.

    Steps performed:
      1. Parse the command line and derive the run name from non-default arguments.
      2. Unless ``--testing`` is set, create a fresh results directory, change
         into it and create an empty evaluation file.
      3. Build the vector environment, the agent and the observation queue.
      4. Run the interaction/learning loop until ``n_steps`` environment steps
         have been taken, logging scores and, every ``eval_every`` steps, saving
         the score history and optionally launching an evaluation process.
      5. Wait for outstanding evaluation processes and close the environment.
    """
    parser = argparse.ArgumentParser()

    # environment setup
    parser.add_argument('--game', type=str, default="NameThisGame")

    parser.add_argument('--envs', type=int, default=64)
    parser.add_argument('--frames', type=int, default=200000000)
    parser.add_argument('--eval_envs', type=int, default=10)
    parser.add_argument('--life_info', type=int, default=0)  # DO NOT USE - THIS IS CHEATING AND NOT COMPARABLE

    parser.add_argument('--bs', type=int, default=256)
    parser.add_argument('--rr', type=float, default=1)

    parser.add_argument('--repeat', type=int, default=0)
    parser.add_argument('--include_evals', type=int, default=0)

    parser.add_argument('--num_eval_episodes', type=int, default=100)
    parser.add_argument('--framestack', type=int, default=4)
    parser.add_argument('--sticky', type=int, default=1)

    # agent setup
    parser.add_argument('--nstep', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-4)
    # type=bool would treat every non-empty string (including "False" and "0")
    # as True, so an explicit converter is used instead.
    parser.add_argument('--testing', type=_parse_bool_flag, default=False)
    parser.add_argument('--grad_clip', type=int, default=10)

    parser.add_argument('--discount', type=float, default=0.99)
    parser.add_argument('--c', type=int, default=500)
    parser.add_argument('--linear_size', type=int, default=512)

    parser.add_argument('--per_alpha', type=float, default=0.5)
    parser.add_argument('--per_beta_anneal', type=int, default=0)
    parser.add_argument('--eps_steps', type=int, default=2000000)

    # specific to this file
    parser.add_argument('--context', type=int, default=40)
    parser.add_argument('--lstm_size', type=int, default=256)

    args = parser.parse_args()

    arg_string = non_default_args(args, parser)
    formatted_string = format_arguments(arg_string)
    print(formatted_string)

    game = args.game
    envs = args.envs
    bs = args.bs
    rr = args.rr
    c = args.c
    lr = args.lr
    life_info = args.life_info  # THIS FEATURE HAS BEEN DISABLED DONT USE
    num_eval_episodes = args.num_eval_episodes
    framestack = args.framestack
    sticky = args.sticky
    repeat_probs = 0 if not sticky else 0.25

    nstep = args.nstep
    grad_clip = args.grad_clip
    discount = args.discount
    linear_size = args.linear_size
    # NOTE(review): presumably 4 emulator frames per agent step - confirm against the preprocessing wrapper.
    frames = args.frames // 4
    per_alpha = args.per_alpha
    eps_steps = args.eps_steps

    context = args.context
    lstm_size = args.lstm_size

    frame_name = str(int(args.frames / 1000000)) + "M"

    include_evals = bool(args.include_evals)
    agent_name = "Rainbow_" + game + frame_name + "_LSTM"

    if len(formatted_string) > 2:
        agent_name += '_' + formatted_string

    print("Agent Name:" + str(agent_name))
    testing = args.testing

    # creates new directory for results and models
    if not testing:
        counter = 0
        while True:
            if counter == 0:
                new_dir_name = agent_name
            else:
                new_dir_name = f"{agent_name}_{counter}"
            if not os.path.exists(new_dir_name):
                break
            counter += 1
        os.mkdir(new_dir_name)
        print(f"Created directory: {new_dir_name}")
        os.chdir(new_dir_name)

    # create blank evaluation file: one row per million frames, one column per evaluation episode
    fname = agent_name + "Evaluation.npy"
    if not testing:
        np.save(fname, np.zeros((args.frames // 1000000, num_eval_episodes)))

    if testing:
        # goes easy on the PC when debugging
        num_envs = 8
        eval_envs = 2
        eval_every = 11580000
        num_eval_episodes = 5
        n_steps = 11560000
        bs = 64
    else:
        num_envs = envs
        eval_envs = args.eval_envs
        n_steps = frames
        eval_every = 250000
    next_eval = eval_every

    print("Currently Playing Game: " + str(game))

    gpu = "0"
    device = torch.device('cuda:' + gpu if torch.cuda.is_available() else 'cpu')
    print("Device: " + str(device))

    env = make_env(num_envs, game, life_info, framestack, repeat_probs)

    print(env.observation_space)
    print(env.action_space[0])
    n_actions = env.action_space[0].n

    agent = Agent(n_actions=env.action_space[0].n, input_dims=[framestack, 84, 84], device=device, num_envs=num_envs,
                  agent_name=agent_name, total_frames=n_steps, testing=testing, batch_size=bs, rr=rr, lr=lr,
                  target_replace=c, discount=discount, linear_size=linear_size, replay_period=num_envs,
                  framestack=framestack, per_alpha=per_alpha, eps_steps=eps_steps, n=nstep, grad_clip=grad_clip,
                  lstm_out_size=lstm_size, context=context)

    scores_temp = []
    steps = 0
    last_steps = 0
    last_time = time.time()
    episodes = 0
    current_eval = 0
    scores_count = [0 for _ in range(num_envs)]
    scores = []

    # The channel dimension must follow --framestack (same as the agent's
    # input_dims); a hard-coded 4 breaks any other frame-stack depth.
    obs_queue = TensorQueue(num_envs, context, (framestack, 84, 84), device, dtype=torch.uint8)

    observation, info = env.reset()

    # Fill the whole context window with the initial observation.
    for _ in range(context):
        obs_queue.enqueue(torch.from_numpy(observation).clone().to(device))

    processes = []

    while steps < n_steps:
        steps += num_envs
        action = agent.choose_action(obs_queue.get_queue())

        # Learn while the environments are stepping in their worker processes.
        env.step_async(action)
        agent.learn()
        observation_, reward, done_, trun_, info = env.step_wait()

        obs_queue.enqueue(torch.from_numpy(observation_).clone().to(device))

        for i in range(num_envs):
            scores_count[i] += reward[i]
            if done_[i] or trun_[i]:
                episodes += 1
                scores.append([scores_count[i], steps])
                scores_temp.append(scores_count[i])
                scores_count[i] = 0

        # Scores above use the raw reward; the replay buffer gets the clipped one.
        reward = np.clip(reward, -1., 1.)

        for stream in range(num_envs):
            terminal_in_buffer = done_[stream]  # or info["lost_life"][stream]
            next_obs = observation_[stream] if not trun_[stream] else np.array(info["final_observation"][stream])

            agent.store_transition(observation[stream], action[stream], reward[stream], next_obs,
                                   terminal_in_buffer, trun_[stream], stream=stream)

            if done_[stream] or trun_[stream]:
                # Episode ended: overwrite this stream's context with the new observation.
                for _ in range(context):
                    obs_queue.enqueue(torch.from_numpy(observation_[stream]).clone().to(device), stream)

        observation = observation_

        if steps % 1200 == 0 and len(scores) > 0:
            avg_score = np.mean(scores_temp[-50:])
            print('{} {} avg score {:.2f} total_steps {:.0f} fps {:.2f} games {}'
                  .format(agent_name, game, avg_score, steps, (steps - last_steps) / (time.time() - last_time),
                          episodes),
                  flush=True)
            last_steps = steps
            last_time = time.time()

        # Evaluation
        if steps >= next_eval or steps >= n_steps:
            print("Evaluating")

            # Save model at selected evaluations only, and never in testing mode.
            # (Without the grouping, `and` binds tighter than `or`, so the
            # `not testing` guard only covered the first comparison.)
            if not testing and (current_eval + 1) in (1, 100, 200):
                agent.save_model()

            fname = agent_name + "Experiment.npy"
            if not testing:
                np.save(fname, np.array(scores))

            if include_evals:

                # wait for our evaluations to finish before we start the next evaluation
                for process in processes:
                    process.join()

                agent.disable_noise(agent.net)
                net_state_dict = deepcopy({k: v.cpu() for k, v in agent.net.state_dict().items()})
                network_creator = deepcopy(agent.network_creator_fn)

                # Start evaluation in a separate process
                eval_process = mp.Process(target=evaluate_agent,
                                          args=(
                                              net_state_dict, network_creator, eval_envs, num_eval_episodes,
                                              agent_name, testing, game,
                                              life_info, n_actions, device, current_eval, framestack, repeat_probs))
                eval_process.start()
                processes.append(eval_process)

            current_eval += 1

            next_eval += eval_every

    # wait for our evaluations to finish before we quit the program
    for process in processes:
        process.join()

    # Shut down the training environment's worker processes.
    env.close()

    print("Evaluations finished, job completed successfully!")


if __name__ == '__main__':
    # Script entry point. The multiprocessing start method is fixed to 'spawn'
    # before main() runs, so every mp.Process created there starts in a fresh
    # interpreter.
    # NOTE(review): presumably chosen because forked children cannot safely reuse CUDA state - confirm.
    mp.set_start_method('spawn')
    main()
