from __future__ import annotations
import time
import gymnasium as gym
import argparse
import multiprocessing as mp
import os
import torch
import torch as T
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from copy import deepcopy
from functools import partial
from typing import Any, SupportsFloat
from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.spaces import Box
from torch import nn as nn, Tensor
from torch.nn import init
from math import sqrt
import math
from RISE_Encodings import ResNet18PenultimateFeatureExtractor, ResNet18ThirdLastFeatureExtractor,\
    FasterRCNNResNet50FPNPenultimateFeatureExtractor, ResNet18FullRes, EfficientNetV2PenultimateFeatureExtractor
from gymnasium.spaces import Box, Tuple

############################################ Tensor Queue

class TensorQueue:
    """Fixed-length FIFO of feature vectors per batch element, stored as a single tensor.

    The queue tensor has shape ``(batch_size, context_length, context)``. Along the second axis,
    index ``-1`` holds the most recently enqueued vector and index ``0`` the oldest one.
    The queue starts filled with zeros.
    """

    def __init__(self, batch_size, context_length, context, device='cpu', dtype=torch.float32):
        """
        Initializes the TensorQueue with zeros.

        Args:
            batch_size (int): Number of samples in the batch.
            context_length (int): Length of the context (queue size).
            context (int): Feature dimension of each enqueued vector.
            device (str): Device to store the tensor.
            dtype (torch.dtype): Data type of the tensor.
        """
        self.batch_size = batch_size
        self.context_length = context_length
        self.context = context
        self.device = device
        self.dtype = dtype
        # Initialize the queue with zeros
        self.queue = torch.zeros(batch_size, context_length, context, device=device, dtype=dtype)

    def enqueue(self, new_data, batch_idx=None):
        """
        Push new data onto the queue, dropping the oldest entry.

        If batch_idx is None, one vector is pushed for every batch element. Otherwise only the
        queue of the given batch element is shifted.

        Args:
            new_data (torch.Tensor): Shape [batch_size, context] for a global enqueue,
                or [context] for a selective enqueue. Converted to the queue's device and dtype.
            batch_idx (int, optional): Batch element to enqueue to. Any integer-like value
                (Python int, NumPy integer, integer 0-d tensor) is accepted. Defaults to None.

        Raises:
            ValueError: new_data has the wrong shape.
            TypeError: batch_idx is not integer-like.
            IndexError: batch_idx is outside [0, batch_size).
        """
        import operator

        if batch_idx is None:
            # Global enqueue: new_data should have shape [batch, context]
            if new_data.shape != (self.batch_size, self.context):
                raise ValueError(f"Expected new_data shape [{self.batch_size}, {self.context}], but got {new_data.shape}")

            with torch.no_grad():
                # Match the queue's device/dtype: torch.cat would otherwise fail on a device
                # mismatch or silently promote the whole queue to a wider dtype.
                new_data = new_data.to(device=self.queue.device, dtype=self.queue.dtype)
                # Drop the oldest slot and append the new vectors as the newest one.
                self.queue = torch.cat([self.queue[:, 1:, :], new_data.unsqueeze(1)], dim=1)
            return

        # Selective enqueue: new_data should have shape [context]
        try:
            idx = operator.index(batch_idx)
        except TypeError:
            raise TypeError(f"batch_idx must be an int, but got {type(batch_idx)}") from None
        if not (0 <= idx < self.batch_size):
            raise IndexError(f"batch_idx {batch_idx} out of range for batch size {self.batch_size}")
        if new_data.shape != (self.context,):
            raise ValueError(f"Expected new_data shape [{self.context}], but got {new_data.shape}")

        with torch.no_grad():
            # Shift this element's queue left by one (clone avoids overlapping in-place copy).
            self.queue[idx, :-1, :] = self.queue[idx, 1:, :].clone()
            # Slice assignment casts new_data to the queue's dtype and device.
            self.queue[idx, -1, :] = new_data

    def get_queue(self):
        """
        Retrieves the current state of the queue.

        Returns:
            torch.Tensor: The queue tensor of shape [batch_size, context_length, context].
        """
        return self.queue

    def reset(self):
        """
        Resets the queue to all zeros (in place).
        """
        with torch.no_grad():
            self.queue.zero_()

############################################## Networks Section

class FactorizedNoisyLinear(nn.Module):
    """ The factorized Gaussian noise layer for noisy-nets dqn.

    Effective parameters are ``w = mu_w + sigma_w * eps_w`` and ``b = mu_b + sigma_b * eps_b``.
    Noise is disabled (all eps = 0) at the end of construction; call :meth:`reset_noise` to sample it.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        sigma_0: noise scale; both sigma tensors are initialised to ``sigma_0 / sqrt(in_features)``.
        self_norm: if True, ``mu`` is initialised by :meth:`reset_parameters_self_norm`
            instead of :meth:`reset_parameters`.
    """
    def __init__(self, in_features: int, out_features: int, sigma_0=0.5, self_norm=False) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.sigma_0 = sigma_0

        # weight: w = \mu^w + \sigma^w . \epsilon^w
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        self.register_buffer('weight_epsilon', torch.empty(out_features, in_features))

        # bias: b = \mu^b + \sigma^b . \epsilon^b
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        self.register_buffer('bias_epsilon', torch.empty(out_features))

        if self_norm:
            self.reset_parameters_self_norm()
        else:
            self.reset_parameters()
        self.reset_noise()

        self.disable_noise()

    @torch.no_grad()
    def _init_sigma(self) -> None:
        """Set both sigma tensors to the constant ``sigma_0 / sqrt(in_features)``."""
        scale = 1 / sqrt(self.in_features)
        init.constant_(self.weight_sigma, self.sigma_0 * scale)
        init.constant_(self.bias_sigma, self.sigma_0 * scale)

    @torch.no_grad()
    def reset_parameters(self) -> None:
        """Uniform ``mu`` in ``[-1/sqrt(in_features), 1/sqrt(in_features)]`` plus constant sigma."""
        scale = 1 / sqrt(self.in_features)

        init.uniform_(self.weight_mu, -scale, scale)
        init.uniform_(self.bias_mu, -scale, scale)

        self._init_sigma()

    @torch.no_grad()
    def reset_parameters_self_norm(self) -> None:
        """Normal ``mu_w`` (std ``1/sqrt(out_features)``), uniform ``mu_b``, constant sigma."""
        nn.init.normal_(self.weight_mu, std=1 / math.sqrt(self.out_features))
        # fan_in of a (out_features, in_features) weight is in_features.
        bound = 1 / math.sqrt(self.in_features)
        nn.init.uniform_(self.bias_mu, -bound, bound)
        # The sigma tensors are created with torch.empty, so they must be initialised here too;
        # leaving them as raw memory can inject NaN/inf even while the noise is disabled.
        self._init_sigma()

    @torch.no_grad()
    def _get_noise(self, size: int) -> Tensor:
        """Sample ``size`` values of f(x) = sgn(x) * sqrt(|x|) with x ~ N(0, 1)."""
        noise = torch.randn(size, device=self.weight_mu.device)
        return noise.sign().mul_(noise.abs().sqrt_())

    @torch.no_grad()
    def reset_noise(self) -> None:
        """Resample factorised noise: eps_w = f(eps_out) outer f(eps_in), eps_b = f(eps_out)."""
        # like in eq 10 and 11 of the paper
        epsilon_in = self._get_noise(self.in_features)
        epsilon_out = self._get_noise(self.out_features)
        self.weight_epsilon.copy_(epsilon_out.outer(epsilon_in))
        self.bias_epsilon.copy_(epsilon_out)

    @torch.no_grad()
    def disable_noise(self) -> None:
        """Zero the noise buffers so the layer behaves like a plain linear layer with ``mu``."""
        self.weight_epsilon.zero_()
        self.bias_epsilon.zero_()

    def forward(self, input: Tensor) -> Tensor:
        # y = wx + b, where
        # w = \mu^w + \sigma^w * \epsilon^w
        # b = \mu^b + \sigma^b * \epsilon^b
        return F.linear(input,
                        self.weight_mu + self.weight_sigma * self.weight_epsilon,
                        self.bias_mu + self.bias_sigma * self.bias_epsilon)


class Dueling(nn.Module):
    """Dueling head: merges a state-value stream and an advantage stream.

    The input is flattened, then ``Q = V + (A - mean_over_actions(A))``. With
    ``advantages_only=True`` the raw advantage output is returned instead.
    """
    def __init__(self, value_branch, advantage_branch):
        super().__init__()
        self.flatten = nn.Flatten()
        self.value_branch = value_branch
        self.advantage_branch = advantage_branch

    def forward(self, x, advantages_only=False):
        flat = self.flatten(x)
        adv = self.advantage_branch(flat)
        if not advantages_only:
            # Centre the advantages so V and A are identifiable.
            centred = adv - adv.mean(dim=1, keepdim=True)
            return self.value_branch(flat) + centred
        return adv


class ImpalaCNNResidual(nn.Module):
    """
    Pre-activation residual block of the large IMPALA CNN: ``x + conv_1(act(conv_0(act(x))))``.

    Both convolutions are 3x3 with stride 1 and padding 1 and keep ``depth`` channels, so the output
    has the same shape as the input. ``norm_func`` wraps each convolution module.
    """
    def __init__(self, depth, norm_func, activation=nn.ReLU):
        super().__init__()

        self.activation = activation()

        def same_size_conv():
            return nn.Conv2d(depth, depth, kernel_size=3, stride=1, padding=1)

        self.conv_0 = norm_func(same_size_conv())
        self.conv_1 = norm_func(same_size_conv())

    def forward(self, x):
        branch = x
        for conv in (self.conv_0, self.conv_1):
            branch = conv(self.activation(branch))
        return x + branch


class ImpalaCNNBlock(nn.Module):
    """
    One stage of the large IMPALA CNN (three are stacked): 3x3 conv -> optional LayerNorm ->
    3x3 / stride-2 max-pool -> two residual blocks.

    When ``layer_norm`` is True, only ``layer_norm_shapes[0]`` (the conv output shape before
    pooling) is used.
    """
    def __init__(self, depth_in, depth_out, norm_func, activation=nn.ReLU, layer_norm=False,
                 layer_norm_shapes=False):
        super().__init__()
        self.layer_norm = layer_norm

        self.conv = nn.Conv2d(depth_in, depth_out, kernel_size=3, stride=1, padding=1)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        if layer_norm:
            self.norm_layer1 = nn.LayerNorm(layer_norm_shapes[0])

        residual_kwargs = dict(norm_func=norm_func, activation=activation)
        self.residual_0 = ImpalaCNNResidual(depth_out, **residual_kwargs)
        self.residual_1 = ImpalaCNNResidual(depth_out, **residual_kwargs)

    def forward(self, x):
        out = self.conv(x)
        if self.layer_norm:
            out = self.norm_layer1(out)
        out = self.max_pool(out)
        return self.residual_1(self.residual_0(out))


class ImpalaCNNLargeIQN(nn.Module):
    """
    Implicit Quantile Network (IQN) on top of the large IMPALA CNN (Espeholt et al., 2018),
    fused with an LSTM over a sequence of external encodings.

    The CNN features and the final LSTM output are combined according to ``combination``:

    * ``"mult"`` / ``"add"``: the LSTM output is linearly projected to the CNN feature size,
      passed through a sigmoid, then multiplied with / added to the CNN features.
    * ``"concat"``: the LSTM output is concatenated to the CNN features.

    The fused features are modulated by a cosine embedding of randomly sampled quantile
    fractions ``tau`` and fed to a noisy dueling head, giving one value per (tau, action).
    The conv stack's layer-norm shapes and the 11x11 output size assume 84x84 inputs.
    """
    def __init__(self, in_depth, actions, model_size=2, device='cuda:0', num_tau=8, maxpool_size=6,
                 linear_size=512, ncos=64, layer_norm=True, encoding_size=512, lstm_out_size=256,
                 combination="mult", pretrain_only=False, maxpool=False):
        super().__init__()

        if combination not in ("mult", "add", "concat"):
            # Fail early with a clear message instead of a missing-attribute error further down.
            raise ValueError(f"combination must be 'mult', 'add' or 'concat', got {combination!r}")

        self.start = time.time()
        self.model_size = model_size
        self.actions = actions
        self.device = device
        self.maxpool = maxpool

        self.combination = combination
        self.encoding_size = encoding_size
        self.lstm_hidden_size = lstm_out_size

        # NOTE(review): pretrain_only=True with combination="concat" feeds only the LSTM output
        # (lstm_out_size wide) into layers sized conv_out_size + lstm_out_size, so forward() fails.
        self.pretrain_only = pretrain_only

        self.in_depth = in_depth

        conv_activation = nn.ReLU
        activation = nn.ReLU

        self.linear_size = linear_size
        self.num_tau = num_tau

        self.maxpool_size = maxpool_size

        self.layer_norm = layer_norm

        # Cosine basis pi * i for i in [0, n_cos), shaped (1, 1, n_cos) to broadcast against taus.
        self.n_cos = ncos
        self.register_buffer('pis', torch.tensor([np.pi * i for i in range(self.n_cos)],
                                                 dtype=torch.float32).view(1, 1, self.n_cos))

        linear_layer = FactorizedNoisyLinear
        norm_func = torch.nn.utils.parametrizations.spectral_norm

        # Channel counts exactly as the blocks build them (int() truncation included), so the
        # flattened feature size below always matches the real conv output.
        ch16 = int(16 * model_size)
        ch32 = int(32 * model_size)
        conv_layers = [
            ImpalaCNNBlock(in_depth, ch16, norm_func=norm_func, activation=conv_activation,
                           layer_norm=self.layer_norm, layer_norm_shapes=([ch16, 84, 84], [ch16, 42, 42])),
            ImpalaCNNBlock(ch16, ch32, norm_func=norm_func, activation=conv_activation,
                           layer_norm=self.layer_norm, layer_norm_shapes=([ch32, 42, 42], [ch32, 21, 21])),
            ImpalaCNNBlock(ch32, ch32, norm_func=norm_func, activation=conv_activation,
                           layer_norm=self.layer_norm, layer_norm_shapes=([ch32, 21, 21], [ch32, 11, 11])),
        ]
        if self.maxpool:
            conv_layers.append(torch.nn.AdaptiveMaxPool2d((self.maxpool_size, self.maxpool_size)))
            spatial = self.maxpool_size
        else:
            spatial = 11  # 84 -> 42 -> 21 -> 11 through the three stride-2 max-pools
        self.conv = nn.Sequential(*conv_layers)
        self.conv_out_size = ch32 * spatial * spatial

        if self.combination == "concat":
            self.conv_lstm_out_size = self.conv_out_size + self.lstm_hidden_size
        else:  # "mult" / "add": the LSTM output is projected to the conv feature size
            self.conv_lstm_out_size = self.conv_out_size

        self.lstm = nn.LSTM(
            input_size=self.encoding_size,
            hidden_size=self.lstm_hidden_size,
            num_layers=1,
            batch_first=True  # Expect input of shape (batch_size, seq_len, input_size)
        )

        if self.combination in ("mult", "add"):
            self.lstm_projection = nn.Linear(self.lstm_hidden_size, self.conv_out_size)

        # Squashes the projected LSTM output before it gates / shifts the CNN features.
        self.up_projection_act = nn.Sigmoid()

        self.conv.add_module('conv_activation', activation())

        self.cos_embedding = nn.Linear(self.n_cos, self.conv_lstm_out_size)

        self.linear_layersV = nn.Sequential()
        self.linear_layersA = nn.Sequential()

        self.linear_layersV.add_module('fc1V', linear_layer(self.conv_lstm_out_size, self.linear_size))
        self.linear_layersA.add_module('fc1A', linear_layer(self.conv_lstm_out_size, self.linear_size))

        self.linear_layersV.add_module('LN_V', nn.LayerNorm(self.linear_size))
        self.linear_layersA.add_module('LN_A', nn.LayerNorm(self.linear_size))

        self.linear_layersV.add_module('actV', activation())
        self.linear_layersA.add_module('actA', activation())

        self.linear_layersV.add_module('fc2V', linear_layer(self.linear_size, 1))
        self.linear_layersA.add_module('fc2A', linear_layer(self.linear_size, actions))

        self.linear_layers = Dueling(self.linear_layersV, self.linear_layersA)

        self.to(device)

    def _get_conv_out(self, shape):
        """Return the flattened size of the conv stack's output for a single input of ``shape``."""
        # Build the probe on the model's own device so this also works after .to('cuda').
        o = self.conv(torch.zeros(1, *shape, device=self.pis.device))
        return int(np.prod(o.size()))

    def forward(self, inputt, encodings, h_0=None, c_0=None, stored=None):
        """
        Quantile calculation for ``num_tau`` randomly sampled quantile fractions.

        Args:
            inputt: image batch; cast to float and divided by 255 before the conv stack.
            encodings: LSTM input. When ``h_0`` is given (and ``stored`` is None), it is a single
                step ``(batch, encoding_size)`` and a sequence axis is added here. Otherwise it is
                passed as-is; with ``batch_first=True`` it should be ``(batch, seq_len, encoding_size)``.
            h_0, c_0: initial LSTM state, used only when ``stored`` is None. ``nn.LSTM`` expects each
                to be ``(num_layers=1, batch, lstm_out_size)``.
            stored: optional ``(h, c)`` tuple passed directly to the LSTM as its initial state.

        Returns:
            quantiles of shape ``(batch_size, num_tau, actions)``, taus of shape
            ``(batch_size, num_tau, 1)``, and the final LSTM state ``h_n, c_n`` (both None when
            no initial state was supplied).
        """
        batch_size = inputt.size()[0]

        inputt = inputt.float() / 255

        if stored is not None:
            lstm_out, (h_n, c_n) = self.lstm(encodings, stored)
        elif h_0 is not None:
            lstm_out, (h_n, c_n) = self.lstm(encodings.unsqueeze(1), (h_0, c_0))
        else:
            # The LSTM output is already tanh-bounded, so no extra activation is applied.
            lstm_out, _ = self.lstm(encodings)
            h_n, c_n = None, None

        # Keep only the last time step: (batch, seq, hidden) -> (batch, hidden).
        # (No squeeze afterwards: that would also drop the feature axis when hidden size is 1.)
        lstm_out = lstm_out[:, -1, :]

        if self.combination in ("mult", "add"):
            lstm_out = self.up_projection_act(self.lstm_projection(lstm_out))

        if not self.pretrain_only:
            x = self.conv(inputt)
            x = x.view(batch_size, -1)

            # combine CNN with RISE output
            if self.combination == "mult":
                x = x * lstm_out
            elif self.combination == "add":
                x = x + lstm_out
            else:  # "concat"
                x = torch.concat((x, lstm_out), dim=-1)
        else:
            x = lstm_out

        cos, taus = self.calc_cos(batch_size, self.num_tau)  # cos: (batch, num_tau, n_cos)
        cos = cos.view(batch_size * self.num_tau, self.n_cos)
        cos_x = torch.relu(self.cos_embedding(cos)).view(batch_size, self.num_tau, self.conv_lstm_out_size)

        # x: (batch, features) -> (batch, 1, features), modulated by each tau embedding.
        x = (x.unsqueeze(1) * cos_x).view(batch_size * self.num_tau, self.conv_lstm_out_size)

        out = self.linear_layers(x)

        return out.view(batch_size, self.num_tau, self.actions), taus, h_n, c_n

    def qvals(self, inputs, encoding, h_0, c_0):
        """Return Q-values (mean over the sampled quantiles) and the updated LSTM state."""
        quantiles, _, h_0, c_0 = self.forward(inputs, encoding, h_0, c_0)

        actions = quantiles.mean(dim=1)

        return actions, h_0, c_0

    def calc_cos(self, batch_size, n_tau=8):
        """
        Sample taus ~ U[0, 1) of shape (batch_size, n_tau, 1) and return cos(tau * pi * i)
        of shape (batch_size, n_tau, n_cos) together with the taus.
        """
        taus = torch.rand(batch_size, n_tau, 1, device=self.pis.device)

        cos = torch.cos(taus * self.pis)

        return cos, taus

    def save_checkpoint(self, name):
        """Save the state dict to ``name + ".model"``."""
        torch.save(self.state_dict(), name + ".model")

    def load_checkpoint(self, name):
        """Load a state dict from the file ``name`` exactly as given (no suffix is appended).

        Tensors are mapped onto this model's device so GPU checkpoints also load on CPU-only hosts.
        """
        self.load_state_dict(torch.load(name, map_location=self.pis.device))

#################################### Full Res Wrapper
class MultiObsAtariPreprocessing(gym.wrappers.AtariPreprocessing):
    """``AtariPreprocessing`` that also returns the full-resolution RGB screen.

    Observations are ``(preprocessed_obs, rgb_screen)``, where ``rgb_screen`` is a
    ``(210, 160, 3)`` uint8 array read from the ALE.
    """

    def __init__(self, env, **kwargs):
        super().__init__(env, **kwargs)

        # Reusable buffer that the ALE writes the RGB screen into.
        self.rgb_obs = np.zeros((210, 160, 3), dtype=np.uint8)
        self.observation_space = Tuple((self.observation_space, Box(low=0, high=255, shape=(210, 160, 3), dtype=np.uint8)))

    def _get_obs(self):
        self.ale.getScreenRGB(self.rgb_obs)
        # Hand out a copy: self.rgb_obs is overwritten on every call, so returning the buffer itself
        # would make every previously returned (e.g. stored) RGB observation change silently.
        return super()._get_obs(), self.rgb_obs.copy()


class MultiObsAtariFrameStackObservation(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Stack the last ``stack_size`` 84x84 frames of a ``(frame, rgb)`` observation tuple.

    The wrapped env is expected to emit ``(frame, rgb)`` observations. This wrapper emits
    ``(stacked_frames, rgb)``, where ``stacked_frames`` has shape ``(stack_size, 84, 84)`` with the
    newest frame last, and ``rgb`` is passed through unchanged. On reset the stack is zero-filled
    and the first frame is written into the last slot.
    """

    def __init__(self, env: gym.Env, stack_size: int):
        gym.utils.RecordConstructorArgs.__init__(self, stack_size=stack_size)
        gym.Wrapper.__init__(self, env)

        if not np.issubdtype(type(stack_size), np.integer):
            raise TypeError(
                f"The stack_size is expected to be an integer, actual type: {type(stack_size)}"
            )
        if not 0 < stack_size:
            raise ValueError(
                f"The stack_size needs to be greater than zero, actual value: {stack_size}"
            )

        self.stack_size = stack_size
        self.obs_stack = np.zeros((stack_size, 84, 84), dtype=np.uint8)
        # The stacked-frame space must follow stack_size so it matches what step()/reset() return.
        self.observation_space = Tuple((Box(low=0, high=255, shape=(stack_size, 84, 84), dtype=np.uint8),
                                        Box(low=0, high=255, shape=(210, 160, 3), dtype=np.uint8)))

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)

        # np.roll returns a new array, so stacks handed out earlier are not mutated.
        self.obs_stack = np.roll(self.obs_stack, -1, axis=0)
        self.obs_stack[-1] = obs[0]

        return (self.obs_stack, obs[1]), reward, terminated, truncated, info

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
        obs, info = self.env.reset(seed=seed, options=options)

        self.obs_stack = np.zeros((self.stack_size, 84, 84), dtype=np.uint8)
        self.obs_stack[-1] = obs[0]

        return (self.obs_stack, obs[1]), info


################# Atari Processing Section (by default, this is effectively identical to that of gymnasium's wrapper)

# NOTE(review): no `AtariPreprocessing` is defined in the visible part of this module (the wrapper
# below is `AtariPreprocessingCustom`). If it is not defined elsewhere, `from <module> import *`
# raises AttributeError; even if it is, this __all__ hides every other public name from star-imports.
# Confirm the intended export list.
__all__ = ["AtariPreprocessing"]

class AtariPreprocessingCustom(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Implements the common preprocessing techniques for Atari environments (excluding frame stacking).

    For frame stacking use :class:`gymnasium.wrappers.FrameStackObservation`.
    No vector version of the wrapper exists.

    This class follows the guidelines in Machado et al. (2018),
    "Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents".

    Specifically, the following preprocessing stages are applied to the Atari environment:

    - Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, default max 30 no-ops.
    - Frame skipping: The number of frames skipped between steps, 4 by default.
    - Max-pooling: Pools over the most recent two observations from the frame skips.
    - Termination signal when a life is lost: When the agent loses a life, the episode is reported as terminated.
        Turned off by default. Not recommended by Machado et al. (2018).
    - Resize to a square image: Resizes the original Atari observation shape from 210x160 to 84x84 by default.
    - Grayscale observation: Makes the observation greyscale, enabled by default.
    - Grayscale new axis: Extends the last channel of the observation such that the image is 3-dimensional, not enabled by default.
    - Scale observation: Whether to scale the observation to [0, 1] or keep it in [0, 255], not scaled by default.

    Example:
        >>> import gymnasium as gym # doctest: +SKIP
        >>> env = gym.make("ALE/Adventure-v5") # doctest: +SKIP
        >>> env = AtariPreprocessingCustom(env, noop_max=10, frame_skip=1, screen_size=84, terminal_on_life_loss=True, grayscale_obs=False, grayscale_newaxis=False) # doctest: +SKIP

    Change logs:
     * Added in gym v0.12.2 (gym #1455)


     This version is slightly modified to include the option to report life-loss information
     (``life_information``). It is NOT enabled by default and should NEVER be used when comparing
     results in papers: results obtained with it are not comparable to the standard protocol.
    """

    def __init__(
        self,
        env: gym.Env,
        noop_max: int = 30,
        frame_skip: int = 4,
        screen_size: int = 84,
        terminal_on_life_loss: bool = False,
        life_information: bool = False,
        grayscale_obs: bool = True,
        grayscale_newaxis: bool = False,
        scale_obs: bool = False,
    ):
        """Wrapper for Atari 2600 preprocessing.

        Args:
            env (Env): The environment to apply the preprocessing
            noop_max (int): For No-op reset, the max number no-ops actions are taken at reset, to turn off, set to 0.
            frame_skip (int): The number of frames between new observations, i.e. how often the agent experiences the game.
            screen_size (int): resize Atari frame.
            terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
                life is lost.
            life_information (bool): if True (and `terminal_on_life_loss` is False), :meth:`step()` sets
                `info["lost_life"]` to True when a life is lost and stops repeating the action early.
            grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
                is returned.
            grayscale_newaxis (bool): `if True and grayscale_obs=True`, then a channel axis is added to
                grayscale observations to make them 3-dimensional.
            scale_obs (bool): if True, then observation normalized in range [0,1] is returned. It also limits memory
                optimization benefits of FrameStack Wrapper.

        Raises:
            DependencyNotInstalled: opencv-python package not installed
            ValueError: Disable frame-skipping in the original env
        """
        gym.utils.RecordConstructorArgs.__init__(
            self,
            noop_max=noop_max,
            frame_skip=frame_skip,
            screen_size=screen_size,
            terminal_on_life_loss=terminal_on_life_loss,
            # Recorded too, so re-creating the wrapper from its saved args keeps this setting.
            life_information=life_information,
            grayscale_obs=grayscale_obs,
            grayscale_newaxis=grayscale_newaxis,
            scale_obs=scale_obs,
        )
        gym.Wrapper.__init__(self, env)

        try:
            import cv2  # noqa: F401
        except ImportError:
            raise gym.error.DependencyNotInstalled(
                "opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
            )

        assert frame_skip > 0
        assert screen_size > 0
        assert noop_max >= 0
        if frame_skip > 1 and getattr(env.unwrapped, "_frameskip", None) != 1:
            raise ValueError(
                "Disable frame-skipping in the original env. Otherwise, more than one frame-skip will happen as through this wrapper"
            )
        self.noop_max = noop_max
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

        self.frame_skip = frame_skip
        self.screen_size = screen_size
        self.terminal_on_life_loss = terminal_on_life_loss
        self.life_information = life_information
        self.grayscale_obs = grayscale_obs
        self.grayscale_newaxis = grayscale_newaxis
        self.scale_obs = scale_obs

        # buffer of most recent two observations for max pooling
        assert isinstance(env.observation_space, Box)
        if grayscale_obs:
            self.obs_buffer = [
                np.empty(env.observation_space.shape[:2], dtype=np.uint8),
                np.empty(env.observation_space.shape[:2], dtype=np.uint8),
            ]
        else:
            self.obs_buffer = [
                np.empty(env.observation_space.shape, dtype=np.uint8),
                np.empty(env.observation_space.shape, dtype=np.uint8),
            ]

        self.lives = 0
        self.game_over = False

        _low, _high, _obs_dtype = (
            (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
        )
        _shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
        if grayscale_obs and not grayscale_newaxis:
            _shape = _shape[:-1]  # Remove channel axis
        self.observation_space = Box(
            low=_low, high=_high, shape=_shape, dtype=_obs_dtype
        )

    @property
    def ale(self):
        """Make ale as a class property to avoid serialization error."""
        return self.env.unwrapped.ale

    def step(
        self, action: WrapperActType
    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Applies the preprocessing for an :meth:`env.step`.

        Repeats `action` for up to `frame_skip` frames, summing the rewards. `info["lost_life"]` is
        always set; it can only be True when `life_information` is enabled and
        `terminal_on_life_loss` is not.
        """
        total_reward, terminated, truncated, info = 0.0, False, False, {}

        for t in range(self.frame_skip):
            _, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            self.game_over = terminated

            # using agent terminal without resetting env require terminal_on_life_loss == False

            if self.terminal_on_life_loss:
                new_lives = self.ale.lives()
                terminated = terminated or new_lives < self.lives
                self.game_over = terminated
                self.lives = new_lives
                info["lost_life"] = False
            elif self.life_information:
                # code allows telling agent if they lost a life, so they can learn from terminal
                new_lives = self.ale.lives()
                info["lost_life"] = new_lives < self.lives
                self.lives = new_lives

                # Stop here so later frames do not overwrite info. Like the termination break
                # below, this skips the screen captures, so the returned observation is built
                # from the buffers of earlier frames.
                if info["lost_life"]:
                    break
            else:
                info["lost_life"] = False

            if terminated or truncated:
                break
            # Capture the last two frames of the skip window for max-pooling in _get_obs.
            if t == self.frame_skip - 2:
                if self.grayscale_obs:
                    self.ale.getScreenGrayscale(self.obs_buffer[1])
                else:
                    self.ale.getScreenRGB(self.obs_buffer[1])
            elif t == self.frame_skip - 1:
                if self.grayscale_obs:
                    self.ale.getScreenGrayscale(self.obs_buffer[0])
                else:
                    self.ale.getScreenRGB(self.obs_buffer[0])
        return self._get_obs(), total_reward, terminated, truncated, info

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Resets the environment using preprocessing."""
        # NoopReset
        _, reset_info = self.env.reset(seed=seed, options=options)

        noops = (
            self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
            if self.noop_max > 0
            else 0
        )
        for _ in range(noops):
            _, _, terminated, truncated, step_info = self.env.step(0)
            reset_info.update(step_info)
            if terminated or truncated:
                _, reset_info = self.env.reset(seed=seed, options=options)

        self.lives = self.ale.lives()
        if self.grayscale_obs:
            self.ale.getScreenGrayscale(self.obs_buffer[0])
        else:
            self.ale.getScreenRGB(self.obs_buffer[0])
        self.obs_buffer[1].fill(0)

        return self._get_obs(), reset_info

    def _get_obs(self):
        """Max-pool the two buffered frames (in place), resize, and optionally scale/expand."""
        if self.frame_skip > 1:  # more efficient in-place pooling
            np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])

        import cv2

        obs = cv2.resize(
            self.obs_buffer[0],
            (self.screen_size, self.screen_size),
            interpolation=cv2.INTER_AREA,
        )

        if self.scale_obs:
            obs = np.asarray(obs, dtype=np.float32) / 255.0
        else:
            obs = np.asarray(obs, dtype=np.uint8)

        if self.grayscale_obs and self.grayscale_newaxis:
            obs = np.expand_dims(obs, axis=-1)  # Add a channel axis
        return obs

################# Now Entering the Prioritized Experience Replay Section

# SumTree
# a binary tree data structure where the parent’s value is the sum of its children
class SumTree():
  """Array-backed binary sum tree: every internal node holds the sum of its two children.

  Leaves store priorities in ring-buffer order (``append`` overwrites the oldest leaf once
  full), so the root is the total priority and ``find`` maps sampled prefix sums to leaves.

  Args:
    size: number of data leaves. Any positive size works; when it is not a power of two the
      tree is padded with extra leaves that always hold 0 and are never returned by ``find``.
    procgen: unused; kept for backwards compatibility.
  """
  def __init__(self, size, procgen=False):
    self.index = 0  # next data index that append() writes
    self.size = size
    self.full = False  # Used to track actual capacity
    self.tree_start = 2**(size-1).bit_length()-1  # Put all used node leaves on last tree level
    # Allocate the complete tree: tree_start internal nodes + (tree_start + 1) leaves.
    # Allocating only tree_start + size nodes left the last internal node without a right
    # child when size was odd and not a power of two, so writing the last leaf raised
    # IndexError. For power-of-two sizes the length is unchanged (2*tree_start + 1 == tree_start + size).
    self.sum_tree = np.zeros((2 * self.tree_start + 1,), dtype=np.float32)
    self.max = 1  # Initial max value to return (1 = 1^ω)

  # Recomputes the given internal nodes as the sum of their two children
  def _update_nodes(self, indices):
    children_indices = indices * 2 + np.expand_dims([1, 2], axis=1)
    self.sum_tree[indices] = np.sum(self.sum_tree[children_indices], axis=0)

  # Propagates changes up the tree from tree indices that all sit on the same level
  def _propagate(self, indices):
    if indices[0] == 0:  # already at the root (only for a 1-leaf tree): nothing above it
      return
    parents = (indices - 1) // 2
    unique_parents = np.unique(parents)
    self._update_nodes(unique_parents)
    if parents[0] != 0:
      self._propagate(parents)

  # Propagates a single value up the tree given a tree index (faster than the vectorised path)
  def _propagate_index(self, index):
    if index == 0:  # already at the root (only for a 1-leaf tree): nothing above it
      return
    parent = (index - 1) // 2
    left, right = 2 * parent + 1, 2 * parent + 2
    self.sum_tree[parent] = self.sum_tree[left] + self.sum_tree[right]
    if parent != 0:
      self._propagate_index(parent)

  # Sets values at the given tree (not data) indices and refreshes their ancestors
  def update(self, indices, values):
    self.sum_tree[indices] = values  # Set new values
    self._propagate(indices)  # Propagate values
    current_max_value = np.max(values)
    self.max = max(current_max_value, self.max)

  # Sets a single value at a tree index and refreshes its ancestors
  def _update_index(self, index, value):
    self.sum_tree[index] = value  # Set new value
    self._propagate_index(index)  # Propagate value
    self.max = max(value, self.max)

  # Writes value into the next leaf in ring-buffer order
  def append(self, value):
    self._update_index(self.index + self.tree_start, value)  # Update tree (also updates self.max)
    self.index = (self.index + 1) % self.size  # Update index
    self.full = self.full or self.index == 0  # Save when capacity reached

  # Descends from `indices` to leaves, steering each of `values` by the left-subtree sums
  def _retrieve(self, indices, values):
    children_indices = (indices * 2 + np.expand_dims([1, 2], axis=1))  # Make matrix of children indices
    # If indices correspond to leaf nodes, return them
    if children_indices[0, 0] >= self.sum_tree.shape[0]:
      return indices
    # If children are leaves, bound them to the last *data* leaf so padded leaves and rare
    # overshoots (value slightly above the total) never escape the data range
    elif children_indices[0, 0] >= self.tree_start:
      children_indices = np.minimum(children_indices, self.tree_start + self.size - 1)
    left_children_values = self.sum_tree[children_indices[0]]
    successor_choices = np.greater(values, left_children_values).astype(np.int32)  # Classify which values are in left or right branches
    successor_indices = children_indices[successor_choices, np.arange(indices.size)]  # Use classification to index into the indices matrix
    successor_values = values - successor_choices * left_children_values  # Subtract the left branch values when searching in the right branch
    return self._retrieve(successor_indices, successor_values)

  # Searches for values in sum tree and returns values, data indices and tree indices
  def find(self, values):
    indices = self._retrieve(np.zeros(values.shape, dtype=np.int32), values)
    data_index = indices - self.tree_start
    return (self.sum_tree[indices], data_index, indices)  # Return values, data indices, tree indices

  # Total of all stored priorities (the root value)
  def total(self):
    return self.sum_tree[0]

class PER:
    """Prioritized experience replay with n-step returns, stored as de-duplicated ring buffers.

    Frames, per-step scalars (action/reward/done/truncated), encodings and recurrent states
    are written once into flat ring buffers of length ``storage_size``. On non-first steps only
    the newest frame is written, so consecutive transitions of an episode share frames.
    ``pointer_mem`` (``size`` records) holds, per transition, integer indices into those
    buffers. ``st`` (a ``SumTree`` with ``size`` leaves) holds each record's sampling
    priority ``p ** alpha``.

    Args:
        size: number of transition records (and sum-tree leaves).
        device: torch device that sampled tensors are moved to.
        n: n-step return length.
        envs: number of parallel environment streams.
        gamma: discount used for the n-step return.
        encoding_size: length of each stored encoding vector.
        alpha: priority exponent.
        beta: initial importance-sampling exponent (see the note in ``append_pointer``).
        framestack: frames per state.
        imagex, imagey: frame dimensions.
        rgb: store 3-channel frames and use a larger ring-buffer multiplier.
        context: number of history encoding pointers stored before the n look-ahead ones.
        lstm_size: width of each stored recurrent state vector.
    """

    def __init__(self, size, device, n, envs, gamma, encoding_size, alpha=0.2, beta=0.4, framestack=4, imagex=84, imagey=84, rgb=False, context=10,
                 lstm_size=256):

        self.st = SumTree(size)
        self.data = [None for _ in range(size)]
        self.index = 0
        self.size = size

        # storage_size counts individual frames/steps, not transitions. Because transitions of an
        # episode share frames, the author's estimate of the frame storage needed to never overwrite
        # live data is (2*framestack - overlap) * first_states + non_first_states; with N=3,
        # framestack=4, size=1M and an average episode length of 20 that is around 1.35M.
        # Grayscale uses a 1.25x multiplier, RGB uses 4x.
        if rgb:
            self.storage_size = int(size * 4)
        else:
            self.storage_size = int(size * 1.25)
        self.gamma = gamma
        self.capacity = 0  # number of valid records in pointer_mem (saturates at size)
        self.encoding_size = encoding_size
        self.lstm_size = lstm_size

        # Write cursors: one for pointer_mem and one per ring buffer
        self.point_mem_idx = 0

        self.state_mem_idx = 0
        self.reward_mem_idx = 0
        self.encoding_mem_idx = 0
        self.stored_mem_idx = 0

        self.imagex = imagex
        self.imagey = imagey

        self.max_prio = 1  # largest (pre-alpha) priority seen; new records start at this value

        self.framestack = framestack

        self.alpha = alpha
        self.beta = beta
        self.eps = 1e-6  # small constant to stop 0 probability
        self.device = device

        # Per-stream state: whether the previous step ended an episode, and a step counter
        self.last_terminal = [True for i in range(envs)]
        self.tstep_counter = [0 for i in range(envs)]

        self.context = context

        self.n_step = n
        # Per-stream sliding windows of ring-buffer slot indices not yet consumed into records
        self.state_buffer = [[] for i in range(envs)]
        self.stored_buffer = [[] for i in range(envs)]
        self.reward_buffer = [[] for i in range(envs)]
        self.encodings_buffer = [[] for i in range(envs)]

        if rgb:
            self.state_mem = np.zeros((self.storage_size, 3, self.imagex, self.imagey), dtype=np.uint8)
        else:
            self.state_mem = np.zeros((self.storage_size, self.imagex, self.imagey), dtype=np.uint8)
        # action/reward/done/trun share one cursor (reward_mem_idx)
        self.action_mem = np.zeros(self.storage_size, dtype=np.int64)
        self.reward_mem = np.zeros(self.storage_size, dtype=float)
        self.done_mem = np.zeros(self.storage_size, dtype=bool)
        self.trun_mem = np.zeros(self.storage_size, dtype=bool)
        self.encodings_mem = np.zeros((self.storage_size, self.encoding_size), dtype=np.float32)
        # recurrent state per slot: h_0 and c_0
        self.stored_state_mem = np.zeros((self.storage_size, 2, self.lstm_size), dtype=np.float32)

        # Every field is an array of int pointers into the ring buffers above:
        #   state / n_state: framestack frame slots for the state and the n-step next state
        #   reward: n slots into action/reward/done/trun mems (the first also gives the action)
        #   encodings: context + n encoding slots
        #   stored: n + 1 recurrent-state slots (sample() reads the first and the last)
        self.trans_dtype = np.dtype([('state', int, self.framestack), ('n_state', int, self.framestack),
                                     ('reward', int, self.n_step), ('encodings', int, self.context + self.n_step),
                                     ('stored', int, self.n_step + 1)])

        self.blank_trans = (np.zeros(self.framestack, dtype=int), np.zeros(self.framestack, dtype=int),
                            np.zeros(self.n_step, dtype=int), np.zeros(self.context + self.n_step, dtype=int),
                            np.zeros(self.n_step + 1, dtype=int))

        self.pointer_mem = np.array([self.blank_trans] * size, dtype=self.trans_dtype)

        self.overlap = self.framestack - self.n_step

    def append(self, state, action, reward, n_state, done, trun, encodings, next_encodings, stream, prio=True,
               stored=None, next_stored=None):
        """Store one environment step for ``stream`` and emit any transitions it completes.

        On ``done`` or ``trun`` the episode's remaining (fewer-than-n-step) transitions are
        flushed and the stream's sliding buffers are cleared.
        """
        # append to memory
        self.append_memory(state, action, reward, n_state, done, trun, encodings, next_encodings, stream, stored, next_stored)

        # append to pointer
        self.append_pointer(stream, prio)

        if done or trun:
            self.finalize_experiences(stream)
            self.state_buffer[stream] = []
            self.stored_buffer[stream] = []
            self.reward_buffer[stream] = []
            self.encodings_buffer[stream] = []

        # Truncation must count too: the buffers were just cleared, so the next step has to
        # start a fresh frame history through append_memory's first-step branch.
        self.last_terminal[stream] = done or trun

    def append_pointer(self, stream, prio):
        """Emit every full n-step transition available in ``stream``'s sliding buffers.

        Each record is inserted with priority ``max_prio ** alpha``. ``prio`` is accepted but unused.
        """
        while len(self.state_buffer[stream]) >= self.framestack + self.n_step and len(self.reward_buffer[stream]) >= self.n_step:
            # First array in the experience
            state_array = self.state_buffer[stream][:self.framestack]

            # Second array in the experience (starts after N frames)
            n_state_array = self.state_buffer[stream][self.n_step:self.n_step + self.framestack]

            # Reward array (first N rewards)
            reward_array = self.reward_buffer[stream][:self.n_step]

            encodings_array = self.encodings_buffer[stream][:self.n_step + self.context]

            # First = the `stored` slot of this step, last = the `next_stored` slot n steps later
            stored_array = self.stored_buffer[stream][:self.n_step + 1]

            # Add the experience to the list
            self.pointer_mem[self.point_mem_idx] = (np.array(state_array, dtype=int), np.array(n_state_array, dtype=int),
                                                    np.array(reward_array, dtype=int), np.array(encodings_array, dtype=int),
                                                    np.array(stored_array, dtype=int))

            self.st.append(self.max_prio ** self.alpha)

            self.capacity = min(self.size, self.capacity + 1)
            self.point_mem_idx = (self.point_mem_idx + 1) % self.size

            # Remove the first state and reward from the buffers to slide the window
            self.state_buffer[stream].pop(0)
            self.reward_buffer[stream].pop(0)
            self.encodings_buffer[stream].pop(0)
            self.stored_buffer[stream].pop(0)
            # NOTE(review): this resets beta to 0 for every stored transition, so every
            # importance-sampling weight computed in sample() is 1. Kept as-is; confirm intended.
            self.beta = 0

    def finalize_experiences(self, stream):
        """Flush the transitions left in ``stream``'s buffers at the end of an episode.

        These have fewer than n real steps, so each record's next state is the episode's final
        state (the last ``framestack`` frames). Short pointer lists are padded to the record widths.
        """
        while len(self.state_buffer[stream]) >= self.framestack and len(self.reward_buffer[stream]) > 0:
            # First array in the experience
            first_array = self.state_buffer[stream][:self.framestack]

            # Second array in the experience (final `framestack` elements)
            second_array = self.state_buffer[stream][-self.framestack:]

            # NOTE(review): encodings are padded with pointer 0 (an arbitrary slot); they fall in the
            # next-state part of the window. Confirm the network ignores them for truncated episodes.
            encodings_array = self.encodings_buffer[stream][:]
            encodings_array += [0] * (self.n_step + self.context - len(encodings_array))

            # Reward padding is never read: compute_discounted_rewards_batch stops at the first
            # done/trun flag, and the last real pointer here is the step that ended the episode.
            reward_array = self.reward_buffer[stream][:]
            reward_array += [0] * (self.n_step - len(reward_array))

            # Pad with the last real pointer (the final step's next_stored) instead of 0, so
            # sample()'s stored[-1] matches the final-state `second_array`. Padding with 0 made
            # next_storeds point at an unrelated slot. That matters for truncated episodes, which
            # sample() reports with done=False.
            stored_array = self.stored_buffer[stream][:]
            stored_array += [stored_array[-1]] * (self.n_step + 1 - len(stored_array))

            # Add the experience
            self.pointer_mem[self.point_mem_idx] = (np.array(first_array, dtype=int), np.array(second_array, dtype=int),
                                                    np.array(reward_array, dtype=int), np.array(encodings_array, dtype=int),
                                                    np.array(stored_array, dtype=int))

            self.st.append(self.max_prio ** self.alpha)

            self.point_mem_idx = (self.point_mem_idx + 1) % self.size
            self.capacity = min(self.size, self.capacity + 1)

            # Remove the first state and reward from the buffers to slide the window
            self.state_buffer[stream].pop(0)
            self.encodings_buffer[stream].pop(0)
            self.stored_buffer[stream].pop(0)

            if len(self.reward_buffer[stream]) > 0:
                self.reward_buffer[stream].pop(0)

    def append_memory(self, state, action, reward, n_state, done, trun, encodings, next_encodings, stream, stored,
                      next_stored):
        """Write one step's raw data into the ring buffers and queue the slot indices for ``stream``.

        On the first step of an episode (``last_terminal[stream]``) this writes:
        all ``framestack`` frames of ``state`` plus the newest frame of ``n_state``;
        ``stored`` and ``next_stored``; and ``encodings`` repeated ``context`` times, then
        ``next_encodings``. On later steps only the newest frame, ``next_stored`` and
        ``next_encodings`` are written, because the current values were stored as the
        previous step's "next" ones.
        """
        if self.last_terminal[stream]:
            # add full transition
            for i in range(self.framestack):
                self.state_mem[self.state_mem_idx] = state[i]
                self.state_buffer[stream].append(self.state_mem_idx)
                self.state_mem_idx = (self.state_mem_idx + 1) % self.storage_size

            # remember n_step is not applied in this memory
            self.state_mem[self.state_mem_idx] = n_state[self.framestack - 1]
            self.state_buffer[stream].append(self.state_mem_idx)
            self.state_mem_idx = (self.state_mem_idx + 1) % self.storage_size

            self.action_mem[self.reward_mem_idx] = action
            self.reward_mem[self.reward_mem_idx] = reward
            self.done_mem[self.reward_mem_idx] = done
            self.trun_mem[self.reward_mem_idx] = trun

            self.reward_buffer[stream].append(self.reward_mem_idx)
            self.reward_mem_idx = (self.reward_mem_idx + 1) % self.storage_size

            self.stored_state_mem[self.stored_mem_idx] = stored
            self.stored_buffer[stream].append(self.stored_mem_idx)
            self.stored_mem_idx = (self.stored_mem_idx + 1) % self.storage_size

            self.stored_state_mem[self.stored_mem_idx] = next_stored
            self.stored_buffer[stream].append(self.stored_mem_idx)
            self.stored_mem_idx = (self.stored_mem_idx + 1) % self.storage_size

            # repeat encodings "context" times
            for i in range(self.context):  # context then current state
                self.encodings_mem[self.encoding_mem_idx] = encodings
                self.encodings_buffer[stream].append(self.encoding_mem_idx)
                self.encoding_mem_idx = (self.encoding_mem_idx + 1) % self.storage_size

            self.encodings_mem[self.encoding_mem_idx] = next_encodings
            self.encodings_buffer[stream].append(self.encoding_mem_idx)
            self.encoding_mem_idx = (self.encoding_mem_idx + 1) % self.storage_size

            self.tstep_counter[stream] = 0

        else:
            # just add relevant info
            self.state_mem[self.state_mem_idx] = n_state[self.framestack - 1]
            self.state_buffer[stream].append(self.state_mem_idx)
            self.state_mem_idx = (self.state_mem_idx + 1) % self.storage_size

            # `stored` was already written as the previous step's next_stored. A former write of it
            # to stored_state_mem[reward_mem_idx] used the wrong buffer's cursor and overwrote a
            # live recurrent state, so it has been removed.
            self.action_mem[self.reward_mem_idx] = action
            self.reward_mem[self.reward_mem_idx] = reward
            self.done_mem[self.reward_mem_idx] = done
            self.trun_mem[self.reward_mem_idx] = trun

            self.reward_buffer[stream].append(self.reward_mem_idx)
            self.reward_mem_idx = (self.reward_mem_idx + 1) % self.storage_size

            self.stored_state_mem[self.stored_mem_idx] = next_stored
            self.stored_buffer[stream].append(self.stored_mem_idx)
            self.stored_mem_idx = (self.stored_mem_idx + 1) % self.storage_size

            self.encodings_mem[self.encoding_mem_idx] = next_encodings
            self.encodings_buffer[stream].append(self.encoding_mem_idx)
            self.encoding_mem_idx = (self.encoding_mem_idx + 1) % self.storage_size

    def sample(self, batch_size, count=0):
        """Draw a prioritized batch.

        One value is drawn uniformly from each of ``batch_size`` equal slices of the total
        priority, and the sum tree maps each value to a record. If the weights contain NaN,
        sampling is retried; an Exception is raised after 5 retries.

        Returns:
            (tree_idxs, states, actions, rewards, n_states, dones, encodings, weights, storeds,
            next_storeds). ``tree_idxs`` is a NumPy array for ``update_priorities``; the rest are
            torch tensors on ``self.device``. When n_step > 1, rewards are n-step discounted returns
            and dones are True if any step in the window was done.
        """
        # get total sumtree priority
        p_total = self.st.total()

        # stratified sampling over the total priority
        segment_length = p_total / batch_size
        segment_starts = np.arange(batch_size) * segment_length
        try:
            samples = np.random.uniform(0.0, segment_length, [batch_size]) + segment_starts
        except Exception as e:
            print(segment_length)
            print(segment_starts)
            print(e)
            raise Exception("Stop") from e

        prios, idxs, tree_idxs = self.st.find(samples)

        probs = prios / p_total

        # fetch the pointers by using indices
        pointers = self.pointer_mem[idxs]

        # Extract the pointers into separate arrays
        state_pointers = np.array([p[0] for p in pointers])
        n_state_pointers = np.array([p[1] for p in pointers])
        reward_pointers = np.array([p[2] for p in pointers])

        storeds_pointers = np.array([p[4][0] for p in pointers])
        next_storeds_pointers = np.array([p[4][-1] for p in pointers])

        if self.n_step > 1:
            action_pointers = np.array([p[2][0] for p in pointers])
        else:
            action_pointers = np.array([p[2] for p in pointers])

        encoding_pointers = np.array([p[3] for p in pointers])

        # get state info
        states = torch.tensor(self.state_mem[state_pointers], dtype=torch.uint8)
        n_states = torch.tensor(self.state_mem[n_state_pointers], dtype=torch.uint8)

        # reward and dones just use the same pointer. actions just use the first one
        rewards = self.reward_mem[reward_pointers]
        dones = self.done_mem[reward_pointers]
        truns = self.trun_mem[reward_pointers]
        actions = self.action_mem[action_pointers]
        storeds = self.stored_state_mem[storeds_pointers]
        next_storeds = self.stored_state_mem[next_storeds_pointers]
        encodings = self.encodings_mem[encoding_pointers]

        # apply n_step cumulation to rewards and dones
        if self.n_step > 1:
            rewards, dones = self.compute_discounted_rewards_batch(rewards, dones, truns)

        # Compute importance-sampling weights w
        weights = (self.capacity * probs) ** -self.beta

        weights = torch.tensor(weights / weights.max(), dtype=torch.float32,
                               device=self.device)  # Normalise by max importance-sampling weight from batch

        if torch.isnan(weights).any():
            print("Nan Found is sample!")
            print(f"Prios {prios}")
            print(f"Probs {probs}")
            print(f"Weights {weights}")
            if count >= 5:
                raise Exception("5 Nans")
            return self.sample(batch_size, count + 1)

        # move to pytorch GPU tensors
        states = states.to(torch.float32).to(self.device)
        n_states = n_states.to(torch.float32).to(self.device)
        rewards = torch.tensor(rewards, dtype=torch.float32, device=self.device)
        dones = torch.tensor(dones, dtype=torch.bool, device=self.device)
        actions = torch.tensor(actions, dtype=torch.int64, device=self.device)
        storeds = torch.tensor(storeds, dtype=torch.float32, device=self.device)
        next_storeds = torch.tensor(next_storeds, dtype=torch.float32, device=self.device)
        encodings = torch.tensor(encodings, dtype=torch.float32, device=self.device)

        # return batch
        return tree_idxs, states, actions, rewards, n_states, dones, encodings, weights, storeds, next_storeds

    def compute_discounted_rewards_batch(self, rewards_batch, dones_batch, truns_batch):
        """
        Compute discounted rewards for a batch of rewards and dones.

        Accumulation for a row stops after the first step flagged done or truncated; only
        ``done`` sets the returned done flag.

        Parameters:
        rewards_batch (np.ndarray): 2D array of rewards with shape (batch_size, n_step)
        dones_batch (np.ndarray): 2D array of dones with shape (batch_size, n_step)
        truns_batch (np.ndarray): 2D array of truncation flags with shape (batch_size, n_step)

        Returns:
        np.ndarray: 1D array of discounted rewards for each element in the batch
        np.ndarray: 1D array of cumulative dones (True if any done is True in the sequence)
        """
        batch_size, n_step = rewards_batch.shape
        discounted_rewards = np.zeros(batch_size)
        cumulative_dones = np.zeros(batch_size, dtype=bool)

        for i in range(batch_size):
            cumulative_discount = 1
            for j in range(n_step):
                discounted_rewards[i] += cumulative_discount * rewards_batch[i, j]
                if dones_batch[i, j] == 1:
                    cumulative_dones[i] = True
                    break
                elif truns_batch[i, j] == 1:
                    break
                cumulative_discount *= self.gamma

        return discounted_rewards, cumulative_dones

    def update_priorities(self, idxs, priorities):
        """Set new priorities (``(p + eps) ** alpha``) at the given sum-tree indices.

        ``idxs`` are the tree indices returned by ``sample``. NaN priorities are reported but
        still written.
        """
        priorities = priorities + self.eps

        if np.isnan(priorities).any():
            print("NaN found in priority!")
            print(f"priorities: {priorities}")

        self.max_prio = max(self.max_prio, np.max(priorities))
        self.st.update(idxs, priorities ** self.alpha)


############## A few smaller functions to assist the main program

class EpsilonGreedy:
    """Linearly annealed epsilon-greedy exploration.

    Args:
        eps_start: initial epsilon.
        eps_steps: number of ``update_eps`` calls over which epsilon moves from
            ``eps_start`` to ``eps_final``.
        eps_final: value epsilon is clamped to once reached.
        action_space: sequence of actions sampled uniformly when exploring.
    """
    def __init__(self, eps_start, eps_steps, eps_final, action_space):
        self.eps_start = eps_start
        self.eps = eps_start
        self.steps = eps_steps
        self.eps_final = eps_final
        self.action_space = action_space

    def update_eps(self):
        """Move epsilon one constant step from ``eps_start`` toward ``eps_final``, never past it.

        The previous rule, ``eps - (eps - eps_final) / steps``, shrank the step every call. That
        gave an exponential decay which covered only ~63% of the gap after ``steps`` calls, and
        its ``max(..., eps_final)`` clamp could never bind.
        """
        step = (self.eps_start - self.eps_final) / self.steps
        if step >= 0:
            self.eps = max(self.eps - step, self.eps_final)
        else:
            # eps_start below eps_final: anneal upwards, clamped at eps_final
            self.eps = min(self.eps - step, self.eps_final)

    def choose_action(self):
        """Return a random action with probability eps, otherwise None (caller acts greedily)."""
        if np.random.random() > self.eps:
            return None
        else:
            return np.random.choice(self.action_space)


def randomise_action_batch(x, probs, n_actions):
    """Independently replace each element of ``x`` with a uniform random action.

    Each element is replaced with probability ``probs`` by an integer in ``[0, n_actions)``.
    ``x`` is modified in place and also returned. The random mask and values are drawn on
    ``x.device``, so CPU and accelerator tensors both work; CPU results are unchanged.
    """
    mask = torch.rand(x.shape, device=x.device) < probs

    # Generate random values to replace the selected elements
    random_values = torch.randint(0, n_actions, x.shape, device=x.device)

    # Apply the mask to replace elements in the tensor with random values
    x[mask] = random_values[mask]

    return x


def choose_eval_action(observation, eval_net, n_actions, device, rng):
    """Pick greedy evaluation actions from ``eval_net``'s advantage values.

    Args:
        observation: batch of observations, converted to a float tensor on ``device``.
        eval_net: network exposing ``qvals(state, advantages_only=True)``.
        n_actions: number of discrete actions (used for random replacement).
        device: device the observations are moved to.
        rng: when > 0, each action is replaced by a random one with probability 0.01
            (the value of ``rng`` itself is not used as the probability).

    Returns:
        CPU tensor of action indices, one per batch row.
    """
    with torch.no_grad():
        obs_tensor = T.tensor(observation, dtype=T.float).to(device)
        q_values = eval_net.qvals(obs_tensor, advantages_only=True)
        greedy_actions = q_values.argmax(dim=1).cpu()

        if rng > 0.:
            greedy_actions = randomise_action_batch(greedy_actions, 0.01, n_actions)

    return greedy_actions


def create_network(input_dims, n_actions, device, model_size, maxpool_size,
                   linear_size, num_tau, ncos, layer_norm=True, lstm_out_size=256,
                   encoding_size=512, combination="mult", pretrain_only=False, maxpool=True):
    """Construct an ``ImpalaCNNLargeIQN``.

    ``input_dims[0]`` is passed as the network's first positional argument and ``n_actions``
    as its second. Every other argument is forwarded unchanged as a keyword argument.
    """
    first_dim = input_dims[0]
    network_kwargs = dict(
        device=device,
        model_size=model_size,
        num_tau=num_tau,
        maxpool_size=maxpool_size,
        linear_size=linear_size,
        ncos=ncos,
        layer_norm=layer_norm,
        lstm_out_size=lstm_out_size,
        encoding_size=encoding_size,
        combination=combination,
        pretrain_only=pretrain_only,
        maxpool=maxpool,
    )
    return ImpalaCNNLargeIQN(first_dim, n_actions, **network_kwargs)


#################### The big ol agent class, be prepared

class Agent:
    def __init__(self, n_actions, input_dims, device, num_envs, agent_name, total_frames, testing=False, batch_size=256
                 , rr=1, maxpool_size=6, lr=1e-4, target_replace=500, discount=0.997, taus=8, model_size=2,
                 linear_size=512, ncos=64, non_factorised=False, replay_period=1, framestack=4, rgb=False, imagex=84,
                 imagey=84, per_alpha=0.2, max_mem_size=1048576, eps_steps=2000000, eps_disable=True, n=3,
                 munch_alpha=0.9, grad_clip=10, layer_norm=True, lstm_out_size=256, context=40, encoding_type="RP",
                 combination="mult", encoding_size=512, pretrain_only=False, lower_lr=False, full_res=False,
                 maxpool=True, burn=10):
        """Build the agent: hyper-parameters, replay memory, online/target networks, optimizer and encoder.

        Notable arguments, as used in this constructor:
            input_dims: observation shape. ``input_dims[0]`` goes to the network builder, and
                ``input_dims[1] == 64`` sets ``self.procgen``.
            rr: replay ratio; stored as int when > 0.99, otherwise as float.
            replay_period: env steps between rounds of gradient steps.
            context, burn: the replay memory keeps ``context + burn`` history encoding pointers.
            encoding_type: "RP", "FFB", "Downsample" (forces encoding size 784), "ResNet18Pen",
                "ResNet18ThirdLast", "OD50" or "EffNet". It is ignored for the encoder when
                ``full_res`` is True.
            lower_lr: when True, parameters whose name contains "lstm" use ``lr / 10``.

        Raises:
            Exception: if ``full_res`` is False and ``encoding_type`` is not recognised.
        """
        self.per_alpha = per_alpha

        # Inputs whose second dimension is 64 are treated as Procgen
        self.procgen = input_dims[1] == 64
        self.grad_clip = grad_clip

        self.n_actions = n_actions
        self.input_dims = input_dims
        self.device = device
        self.agent_name = agent_name
        self.testing = testing
        self.maxpool = maxpool
        self.burn = burn

        self.full_res = full_res

        self.layer_norm = layer_norm

        self.loading_checkpoint = False

        self.per_beta = 0.45

        self.lower_lr = lower_lr

        self.lstm_out_size = lstm_out_size
        self.encodings_size = encoding_size
        if encoding_type == "Downsample":
            self.encodings_size = 784

        self.combination = combination

        self.replay_ratio = int(rr) if rr > 0.99 else float(rr)
        self.total_frames = total_frames
        self.num_envs = num_envs

        if self.testing:
            self.min_sampling_size = 30
        else:
            self.min_sampling_size = 200000

        self.lr = lr

        # this is the number of env steps per grad step
        self.replay_period = replay_period

        # replay ratio however does not take into account parallel envs

        # in this code, every {replay period} steps, we take {replay_ratio} grad steps

        self.total_grad_steps = (self.total_frames - self.min_sampling_size) / (self.replay_period / self.replay_ratio)

        # per-grad-step increment that would take beta from per_beta to 1 over training
        self.priority_weight_increase = (1 - self.per_beta) / self.total_grad_steps

        self.action_space = [i for i in range(self.n_actions)]
        self.learn_step_counter = 0

        self.chkpt_dir = ""

        self.n = n
        self.gamma = discount
        self.discount_anneal = False
        self.batch_size = batch_size

        self.model_size = model_size  # Scaling of IMPALA network
        self.maxpool_size = maxpool_size

        # this option is only available for non-impala. I could add it, but factorised seemed
        # to perform the same and is faster
        self.non_factorised = non_factorised

        self.ncos = ncos

        self.entropy_tau = 0.03
        self.lo = -1
        self.alpha = munch_alpha

        # Default 1048576 = 2**20 (about 1 million) transitions
        self.max_mem_size = max_mem_size

        self.replace_target_cnt = target_replace  # This is the number of grad steps - could be a little jank
        # when changing num_envs/batch size/replay ratio

        # Best used value is 32000 frames per replace. For bs 256, this is 500. For bs 16, this is every 8000!

        self.num_tau = taus

        if not self.loading_checkpoint and not self.testing:
            self.eps_start = 1.0
            # eps_steps is passed to EpsilonGreedy unchanged (no frameskip rescaling happens here)
            self.eps_steps = eps_steps
            self.eps_final = 0.01
        else:
            self.eps_start = 0.00
            self.eps_steps = eps_steps
            self.eps_final = 0.00

        self.eps_disable = eps_disable
        self.epsilon = EpsilonGreedy(self.eps_start, self.eps_steps, self.eps_final, self.action_space)

        self.linear_size = linear_size
        self.imagex = imagex
        self.imagey = imagey

        self.framestack = framestack
        self.rgb = rgb
        self.context = context
        self.encoding_type = encoding_type
        self.memory = PER(self.max_mem_size, device, self.n, num_envs, self.gamma, self.encodings_size, alpha=self.per_alpha,
                          beta=self.per_beta, framestack=self.framestack, rgb=self.rgb, imagex=imagex, imagey=imagey, context=self.context + self.burn,
                          lstm_size=lstm_out_size)

        self.pretrain_only = pretrain_only

        self.network_creator_fn = partial(create_network, self.input_dims, self.n_actions, self.device, self.model_size,
                                          self.maxpool_size, self.linear_size, self.num_tau, self.ncos,
                                          layer_norm=self.layer_norm, lstm_out_size=self.lstm_out_size,
                                          encoding_size=self.encodings_size, combination=self.combination,
                                          pretrain_only=pretrain_only, maxpool=self.maxpool)

        self.net = self.network_creator_fn()
        self.tgt_net = self.network_creator_fn()

        if not self.lower_lr:
            self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr, eps=0.005 / self.batch_size)  # 0.00015
        else:
            # Separate the parameters into two groups using their names.
            custom_params = []
            other_params = []

            for name, param in self.net.named_parameters():
                if "lstm" in name:
                    custom_params.append(param)
                else:
                    other_params.append(param)

            self.optimizer = optim.Adam([
                {'params': other_params},  # uses the optimizer default lr=self.lr
                {'params': custom_params, 'lr': self.lr/10}  # LSTM parameters use self.lr / 10
            ], lr=self.lr, eps=0.005 / self.batch_size)

        if self.full_res:
            self.resnet_extractor = ResNet18FullRes(device=self.device, projection_size=self.encodings_size)
        elif self.encoding_type == "RP":
            # Random projection matrix with unit-norm rows, shape (encodings_size, imagex * imagey)
            projection_matrix = nn.Parameter(
                torch.randn(self.encodings_size, self.imagex * self.imagey), requires_grad=False
            )

            projection_matrix = F.normalize(projection_matrix, p=2, dim=1)

            # Define the sequential model
            self.random_projection = nn.Sequential(
                nn.Flatten(),  # flatten each sample to imagex * imagey values
                # nn.Linear(in_features, out_features) stores its weight as (out_features, in_features),
                # which is exactly projection_matrix's shape. The arguments used to be swapped, which
                # only worked because the weight is overwritten below.
                nn.Linear(self.imagex * self.imagey, self.encodings_size, bias=False),  # Linear layer as projection
                nn.ReLU(),  # Add non-linearity after projection
                # clamping to [0, 1] is applied by the caller after this module
            )

            # Set the projection matrix as the weight of the Linear layer
            self.random_projection[1].weight.data = projection_matrix

            self.random_projection.to(self.device)
        elif self.encoding_type == "FFB":
            # Sobel, Laplacian, Gaussian filters
            self.sobel_x_filter = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32,
                                          device=self.device)
            self.sobel_y_filter = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32,
                                          device=self.device)
            self.laplacian_filter = torch.tensor([[[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]], dtype=torch.float32,
                                            device=self.device)
            self.gaussian_filter = torch.tensor([[[[1, 4, 7, 4, 1],
                                              [4, 16, 26, 16, 4],
                                              [7, 26, 41, 26, 7],
                                              [4, 16, 26, 16, 4],
                                              [1, 4, 7, 4, 1]]]], dtype=torch.float32, device=self.device) / 273

        elif self.encoding_type == "Downsample":
            pass
        elif self.encoding_type == "ResNet18Pen":
            self.resnet_extractor = ResNet18PenultimateFeatureExtractor(device=self.device,
                                                                        projection_size=self.encodings_size)
        elif self.encoding_type == "ResNet18ThirdLast":
            self.resnet_extractor = ResNet18ThirdLastFeatureExtractor(device=self.device,
                                                                      projection_size=self.encodings_size)
        elif self.encoding_type == "OD50":
            self.resnet_extractor = FasterRCNNResNet50FPNPenultimateFeatureExtractor(device=self.device,
                                                                      projection_size=self.encodings_size)
        elif self.encoding_type == "EffNet":
            self.resnet_extractor = EfficientNetV2PenultimateFeatureExtractor(device=self.device,
                                                                              projection_size=self.encodings_size)
        else:
            raise Exception("Invalid Encoding Type")

        self.net.train()

        self.eval_net = None

        # The target network is never optimised directly
        for param in self.tgt_net.parameters():
            param.requires_grad = False

        self.env_steps = 0
        self.grad_steps = 0

        self.replay_ratio_cnt = 0
        self.eval_mode = False

    def create_encodings(self, observations):
        """Encode a batch of observations into feature vectors, without tracking gradients.

        Every branch converts ``observations`` with ``torch.from_numpy``, moves it to
        ``self.device`` and divides by 255. Every branch except the full-resolution one
        then keeps only ``observations[:, -1]`` (presumably the newest frame of the stack).

        Branches, checked in this order:
          * ``encoding_type == "RP"``: ``self.random_projection`` on the last frame,
            clamped to [0, 1].
          * ``encoding_type == "FFB"``: Sobel-x, Sobel-y, Laplacian and Gaussian filter
            responses of the last frame, flattened and concatenated, then truncated or
            zero-padded on the right to ``self.encodings_size`` columns.
          * ``encoding_type == "Downsample"``: bilinear resize of the last frame to 28x28,
            flattened to 784 values.
          * ``self.full_res``: the whole (unsliced) batch through ``self.resnet_extractor``.
          * otherwise: the last frame with a channel axis inserted at dim 1, through
            ``self.resnet_extractor``.

        Args:
            observations: numpy array accepted by ``torch.from_numpy``. Assumed to be
                (batch, stack, H, W) uint8 pixels; TODO confirm against the env wrappers.

        Returns:
            torch.Tensor with one encoding row per batch entry.
        """
        def to_unit_tensor(array):
            # numpy -> tensor on the agent's device, scaled by 1/255
            return torch.from_numpy(array).to(self.device) / 255

        kind = self.encoding_type
        with torch.no_grad():
            if kind == "RP":
                newest = to_unit_tensor(observations)[:, -1]
                projected = self.random_projection(newest)
                return torch.clamp(projected, min=0., max=1.0)

            if kind == "FFB":
                frame = to_unit_tensor(observations)[:, -1].unsqueeze(1)
                conv = torch.nn.functional.conv2d
                responses = (
                    conv(frame, self.sobel_x_filter, padding=1),
                    conv(frame, self.sobel_y_filter, padding=1),
                    conv(frame, self.laplacian_filter, padding=1),
                    conv(frame, self.gaussian_filter, padding=2),
                )
                features = torch.cat([resp.flatten(start_dim=1) for resp in responses], dim=1)
                width = features.shape[1]
                if width > self.encodings_size:
                    return features[:, :self.encodings_size]
                # right-pad with zeros up to the requested encoding width
                return torch.nn.functional.pad(features, (0, self.encodings_size - width))

            if kind == "Downsample":
                frame = to_unit_tensor(observations)[:, -1].unsqueeze(1)
                shrunk = F.interpolate(frame, size=(28, 28), mode='bilinear', align_corners=False)
                return torch.flatten(shrunk, start_dim=1)

            if self.full_res:
                return self.resnet_extractor(to_unit_tensor(observations))

            frame = to_unit_tensor(observations)[:, -1].unsqueeze(1)
            return self.resnet_extractor(frame)

    def prep_evaluation(self):
        """Store a deep copy of ``self.net`` in ``self.eval_net`` and switch off its noise."""
        snapshot = deepcopy(self.net)
        self.eval_net = snapshot
        self.disable_noise(snapshot)

    @torch.no_grad()
    def reset_noise(self, net):
        """Call ``reset_noise()`` on every ``FactorizedNoisyLinear`` module inside ``net``."""
        noisy_layers = (module for module in net.modules() if isinstance(module, FactorizedNoisyLinear))
        for layer in noisy_layers:
            layer.reset_noise()

    @torch.no_grad()
    def disable_noise(self, net):
        """Call ``disable_noise()`` on every ``FactorizedNoisyLinear`` module inside ``net``."""
        noisy_layers = (module for module in net.modules() if isinstance(module, FactorizedNoisyLinear))
        for layer in noisy_layers:
            layer.disable_noise()

    def choose_action(self, observation, encoding, h_0, c_0):
        """Choose one action per batch row (a batch of 1 also works).

        Outside eval mode, the online network's noise is resampled first. The observation
        is converted to a float tensor on ``self.net.device`` and passed, together with
        ``encoding`` and the recurrent state, to ``self.net.qvals``. The greedy action is
        the argmax over dim 1, moved to the CPU. While ``env_steps`` is below
        ``min_sampling_size``, or below half of ``total_frames`` when ``eps_disable`` is
        set, the actions are passed through ``randomise_action_batch`` with the current
        epsilon.

        Returns:
            (actions, h, c): CPU action tensor plus the recurrent state returned by ``qvals``.
        """
        with T.no_grad():
            if not self.eval_mode:
                self.reset_noise(self.net)

            obs_tensor = T.tensor(observation, dtype=T.float).to(self.net.device)
            q_values, h_next, c_next = self.net.qvals(obs_tensor, encoding, h_0, c_0)
            chosen = T.argmax(q_values, dim=1).cpu()

            explore = (self.env_steps < self.min_sampling_size
                       or (self.env_steps < self.total_frames / 2 and self.eps_disable))
            if explore:
                chosen = randomise_action_batch(chosen, self.epsilon.eps, self.n_actions)

            return chosen, h_next, c_next

    def store_transition(self, state, action, reward, next_state, done, trun, encodings, next_encodings, stream, prio=True,
                         stored=None, next_stored=None):
        """Append one transition to the replay memory and advance the step counters.

        When ``self.rgb`` is set, a leading axis of size 1 is added to ``state`` and
        ``next_state``, standing in for the framestack axis the replay buffer works with.
        Everything is forwarded unchanged to ``self.memory.append``. Afterwards the
        epsilon schedule is updated and ``self.env_steps`` is incremented by one.
        """
        if self.rgb:
            state, next_state = (np.expand_dims(frame, axis=0) for frame in (state, next_state))

        self.memory.append(
            state, action, reward, next_state, done, trun, encodings, next_encodings, stream,
            prio=prio, stored=stored, next_stored=next_stored,
        )
        self.epsilon.update_eps()
        self.env_steps += 1

    def replace_target_network(self):
        """Hard-copy the online network's weights into the target network."""
        online_weights = self.net.state_dict()
        self.tgt_net.load_state_dict(online_weights)

    def save_model(self):
        """Checkpoint ``self.net`` as ``<agent_name>_<env_steps // 250000>M``.

        The 250000 divisor matches ``eval_every`` in ``main``. Each unit is presumably
        1M frames at 4 frames per agent step.
        """
        millions = int(self.env_steps // 250000)
        self.net.save_checkpoint(f"{self.agent_name}_{millions}M")

    def load_models(self, name):
        """Load checkpoint ``name`` into the online network, then into the target network."""
        for network in (self.net, self.tgt_net):
            network.load_checkpoint(name)

    def learn(self):
        """Run the gradient updates owed for one environment step, per ``self.replay_ratio``.

        * ratio < 1: one ``learn_call`` every ``int(1 / replay_ratio)`` calls, tracked by
          the cyclic counter ``self.replay_ratio_cnt``.
        * ratio >= 1: ``int(replay_ratio)`` calls to ``learn_call``.
        """
        if self.replay_ratio < 1:
            if self.replay_ratio_cnt == 0:
                self.learn_call()
            self.replay_ratio_cnt = (self.replay_ratio_cnt + 1) % int(1 / self.replay_ratio)
        else:
            # int(): the ratio may be a float (``--rr`` is parsed with type=float), and
            # range() raises TypeError for floats such as 2.0
            for _ in range(int(self.replay_ratio)):
                self.learn_call()

    def learn_call(self):
        """Perform one gradient update of ``self.net`` from a prioritized replay sample.

        Does nothing until ``self.env_steps`` reaches ``self.min_sampling_size``. Otherwise:

        1. resamples the target network's noise and hard-syncs the target network every
           ``self.replace_target_cnt`` gradient steps;
        2. samples ``self.batch_size`` sequences and rebuilds LSTM states from the stored
           hidden/cell states;
        3. burns in the LSTM over the first ``self.burn`` encodings without gradients, then
           runs the online network over the remaining context;
        4. builds a Munchausen-style soft target: softmax policy at temperature
           ``self.entropy_tau``, log-policy bonus scaled by ``self.alpha`` and clipped to
           ``[self.lo, 0]``, discounted by ``self.gamma ** self.n``;
        5. computes a quantile Huber loss weighted by the PER importance weights and updates
           priorities from the absolute TD errors before the backward pass;
        6. clips gradients to ``self.grad_clip``, steps the optimizer and increments
           ``self.grad_steps``.
        """

        if self.env_steps < self.min_sampling_size:
            return

        # resample the noise of the target network's noisy layers for this update
        self.reset_noise(self.tgt_net)

        # hard update: copy the online weights into the target network every replace_target_cnt gradient steps
        if self.grad_steps % self.replace_target_cnt == 0:
            self.replace_target_network()

        # NOTE(review): the layouts of these tensors are defined by self.memory.sample (not visible here);
        # the shape comments below are inferred from how they are indexed - confirm against the replay buffer.
        idxs, states, actions, rewards, next_states, dones, encodings, weights, storeds, next_storeds = self.memory.sample(self.batch_size)

        self.optimizer.zero_grad()

        # use this code to check your states are correct!
        # If you apply BTR to a custom env and don't check your states first, you are killing both
        # trees and your own time

        # plt.imshow(states[0][0].unsqueeze(dim=0).cpu().permute(1, 2, 0))
        # plt.show()
        #
        # plt.imshow(states[0][1].unsqueeze(dim=0).cpu().permute(1, 2, 0))
        # plt.show()
        #
        # plt.imshow(states[0][2].unsqueeze(dim=0).cpu().permute(1, 2, 0))
        # plt.show()
        #
        # plt.imshow(states[1][0].unsqueeze(dim=0).cpu().permute(1, 2, 0))
        # plt.show()
        #
        # plt.imshow(states[2][0].unsqueeze(dim=0).cpu().permute(1, 2, 0))
        # plt.show()

        # print(current_encodings.shape)  # == (batch, burn_in + context_length, encoding_size)

        # stored state is from before the burn-in period
        h_0 = storeds[:, 0, :].unsqueeze(0)  # (1, batch_size, hidden_size)
        c_0 = storeds[:, 1, :].unsqueeze(0)  # (1, batch_size, hidden_size)

        next_h_0 = next_storeds[:, 0, :].unsqueeze(0)  # (1, batch_size, hidden_size)
        next_c_0 = next_storeds[:, 1, :].unsqueeze(0)  # (1, batch_size, hidden_size)

        # two windows of the sampled encoding sequence, offset by self.n steps:
        # current = all but the last n steps, next = all but the first n steps
        current_encodings = encodings[:, :-self.n]
        next_encodings = encodings[:, self.n:]

        # the first self.burn steps only warm up the LSTM state; the remainder is the training context
        burn_in_encodings = current_encodings[:, :self.burn]
        context_encodings = current_encodings[:, self.burn:]

        # print("\nLearn")
        # print(f"Obs: {torch.sum(states[-1], dim=(-1, -2, -3))}")
        # print(f"context encodings: {torch.sum(context_encodings[-1], dim=-1)}")
        # print(f"burn encodings: {torch.sum(burn_in_encodings[-1], dim=-1)}")
        # print(f"H_0: {torch.sum(h_0[0, -1], dim=-1)}")
        #
        # print(f"Next Obs: {torch.sum(next_states[-1], dim=(-1, -2, -3))}")
        # print(f"next encodings: {torch.sum(next_encodings[-1], dim=-1)}")
        # print(f"next H_0: {torch.sum(next_h_0[0, -1], dim=-1)}\n")

        with torch.no_grad():
            # get the burned in hidden and cell state
            _, (h_burn, c_burn) = self.net.lstm(burn_in_encodings, (h_0.contiguous(), c_0.contiguous()))

        # online forward pass over the context, starting from the burned-in state; the returned
        # taus weight the quantile Huber loss below
        q_k, taus, _, _ = self.net(states, context_encodings, stored=(h_burn, c_burn))

        q_k_detached = q_k.detach()

        with torch.no_grad():

            # NOTE(review): the target pass runs over the whole shifted window from next_storeds with no
            # separate burn-in. main() stores hidden-queue index 1 as next_stored while these encodings are
            # shifted by self.n - confirm the replay buffer aligns the two.
            Q_targets_next, _, _, _ = self.tgt_net(next_states, next_encodings, stored=(next_h_0.contiguous(), next_c_0.contiguous()))

            # (batch, num_tau, actions)
            # mean over the quantile axis -> expected next-state Q-value per action
            q_t_n = Q_targets_next.mean(dim=1)

            actions = actions.unsqueeze(1)
            rewards = rewards.unsqueeze(1)
            dones = dones.unsqueeze(1)
            weights = weights.unsqueeze(1)

            # calculate log-pi
            logsum = torch.logsumexp(
                (q_t_n - q_t_n.max(1)[0].unsqueeze(-1)) / self.entropy_tau, 1).unsqueeze(-1)  # logsum trick
            # assert logsum.shape == (self.batch_size, 1), "log pi next has wrong shape: {}".format(logsum.shape)
            tau_log_pi_next = (q_t_n - q_t_n.max(1)[0].unsqueeze(-1) - self.entropy_tau * logsum).unsqueeze(1)

            # softmax policy over next-state Q-values at temperature self.entropy_tau
            pi_target = F.softmax(q_t_n / self.entropy_tau, dim=1).unsqueeze(1)

            # n-step discounted soft value of the next state; ~dones zeroes it for terminal transitions
            Q_target = (self.gamma ** self.n * (
                    pi_target * (Q_targets_next - tau_log_pi_next) * (~dones.unsqueeze(-1))).sum(2)).unsqueeze(1)

            # assert Q_target.shape == (self.batch_size, 1, self.num_tau)
            q_k_target = q_k_detached.mean(dim=1)
            v_k_target = q_k_target.max(1)[0].unsqueeze(-1)
            tau_log_pik = q_k_target - v_k_target - self.entropy_tau * torch.logsumexp(
                (q_k_target - v_k_target) / self.entropy_tau, 1).unsqueeze(-1)

            # assert tau_log_pik.shape == (self.batch_size, self.n_actions), "shape instead is {}".format(
            # tau_log_pik.shape)
            munchausen_addon = tau_log_pik.gather(1, actions)

            # calc munchausen reward:
            # reward plus self.alpha times the tau-scaled log-policy of the taken action, clipped to [self.lo, 0]
            munchausen_reward = (
                    rewards + self.alpha * torch.clamp(munchausen_addon, min=self.lo, max=0)).unsqueeze(-1)
            # assert munchausen_reward.shape == (self.batch_size, 1, 1)
            # Compute Q targets for current states
            Q_targets = munchausen_reward + Q_target

        # Get expected Q values from local model (now done at top)
        Q_expected = q_k.gather(2, actions.unsqueeze(-1).expand(self.batch_size, self.num_tau, 1))
        # assert Q_expected.shape == (self.batch_size, self.num_tau, 1)

        # Quantile Huber loss
        # pairwise TD errors: targets (batch, 1, num_tau) broadcast against predictions (batch, num_tau, 1)
        td_error = Q_targets - Q_expected
        # per-sample priority signal: |TD error| summed over dim 1, averaged over dim 2
        loss_v = torch.abs(td_error).sum(dim=1).mean(dim=1).detach()
        # assert td_error.shape == (self.batch_size, self.num_tau, self.num_tau), "wrong td error shape"
        huber_l = calculate_huber_loss(td_error, 1.0, self.num_tau)
        quantil_l = abs(taus - (td_error.detach() < 0).float()) * huber_l / 1.0

        loss = quantil_l.sum(dim=1).mean(dim=1, keepdim=True)  # , keepdim=True if per weights get multipl

        # PER weights
        loss = loss * weights.to(self.net.device)

        loss = loss.mean()

        # update PER prios
        self.memory.update_priorities(idxs, loss_v.cpu().detach().numpy())

        ###### pytorch AMP
        """
        self.scaler.scale(loss).backward()

        # Unscale the gradients before clipping - Gradients get completely zeroed without this
        self.scaler.unscale_(self.optimizer)

        torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.grad_clip)
        self.scaler.step(self.optimizer)
        self.scaler.update()"""
        ##### This is non AMP version
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.grad_clip)
        self.optimizer.step()

        self.grad_steps += 1
        if self.grad_steps % 10000 == 0:
            print("Completed " + str(self.grad_steps) + " gradient steps")

def calculate_huber_loss(td_errors, k=1.0, taus=8):
    """Element-wise Huber loss of ``td_errors`` with threshold ``k``.

    Entries with ``|e| <= k`` get ``0.5 * e**2``; larger ones get ``k * (|e| - 0.5 * k)``.
    ``taus`` is the expected size of the last two dimensions. An AssertionError is raised
    unless the result has shape ``(td_errors.shape[0], taus, taus)``.
    """
    magnitude = td_errors.abs()
    quadratic = 0.5 * td_errors.pow(2)
    linear = k * (magnitude - 0.5 * k)
    loss = torch.where(magnitude <= k, quadratic, linear)
    assert loss.shape == (td_errors.shape[0], taus, taus), "huber loss has wrong shape"
    return loss

def huber_loss(td_errors, k=1.0):
    """Element-wise Huber loss with threshold ``k`` (no shape check).

    Entries with ``|e| <= k`` get ``0.5 * e**2``; larger ones get ``k * (|e| - 0.5 * k)``.
    """
    magnitude = td_errors.abs()
    within = magnitude <= k
    return torch.where(within, 0.5 * td_errors.pow(2), k * (magnitude - 0.5 * k))

def make_env(envs_create, game, life_info, framestack, repeat_probs):
    """Build an ``AsyncVectorEnv`` (spawn context) of ``envs_create`` identical Atari envs.

    Each worker creates ``ALE/<game>-v5`` with ``frameskip=1`` and the given sticky-action
    probability, wraps it in ``AtariPreprocessingCustom`` (forwarding ``life_info``), and
    stacks ``framestack`` frames with ``gym.wrappers.FrameStack`` (lz4 compression off).
    """
    def build_single_env():
        base = gym.make("ALE/" + game + "-v5", frameskip=1, repeat_action_probability=repeat_probs)
        preprocessed = AtariPreprocessingCustom(base, life_information=life_info)
        return gym.wrappers.FrameStack(preprocessed, framestack, lz4_compress=False)

    return gym.vector.AsyncVectorEnv([build_single_env] * envs_create, context="spawn")

def make_env_full_res(envs_create, game, life_info, framestack, repeat_probs):
    """Build an ``AsyncVectorEnv`` (spawn context) of ``envs_create`` multi-observation Atari envs.

    Each worker creates ``ALE/<game>-v5`` with ``frameskip=1`` and the given sticky-action
    probability, wraps it in ``MultiObsAtariPreprocessing`` and then
    ``MultiObsAtariFrameStackObservation`` with ``framestack`` frames.
    ``life_info`` is accepted for signature parity with ``make_env`` but is not used.
    """
    def build_single_env():
        base = gym.make("ALE/" + game + "-v5", frameskip=1, repeat_action_probability=repeat_probs)
        return MultiObsAtariFrameStackObservation(MultiObsAtariPreprocessing(base), framestack)

    return gym.vector.AsyncVectorEnv([build_single_env] * envs_create, context="spawn")


def non_default_args(args, parser):
    """Summarise the command-line options that differ from their parser defaults.

    Each differing option becomes ``<name><value>`` (e.g. ``lr0.001``), in the order the
    options appear in ``vars(args)``. The parts are joined with underscores. Options whose
    default is ``"NameThisGame"`` and the evaluation/bookkeeping options listed below are
    never included. ``repeat``, if non-default, is always appended last.
    """
    ignored = ("include_evals", "eval_envs", "num_eval_episodes", "analy", "save_allll")
    options = vars(args)

    parts = [
        f"{name}{value}"
        for name, value in options.items()
        if name != 'repeat'
        and value != parser.get_default(name)
        and parser.get_default(name) != "NameThisGame"
        and name not in ignored
    ]

    repeat_value = options.get('repeat')
    if repeat_value != parser.get_default('repeat'):
        parts.append(f"repeat{repeat_value}")

    return '_'.join(parts)


def format_arguments(arg_string):
    """Compact an argument summary string for use in file names.

    Applied in order: drop ``=``, ``True`` -> ``1``, ``False`` -> ``0``, ``", "`` -> ``_``.
    """
    for old, new in (('=', ''), ('True', '1'), ('False', '0'), (', ', '_')):
        arg_string = arg_string.replace(old, new)
    return arg_string


def evaluate_agent(net_state_dict, network_creator, eval_envs, num_eval_episodes, agent_name, testing, game, life_info,
                   n_actions, device, index, framestack, repeat_probs):
    """Play ``num_eval_episodes`` episodes with a snapshot of the network and record the returns.

    ``main`` launches this in a separate process. The network is rebuilt with
    ``network_creator()`` and loaded from ``net_state_dict``, which is moved to ``device``
    here. Actions come from ``choose_eval_action``, with a random-action probability of
    0.01 while ``index <= 125`` and 0.0 afterwards. Unless ``testing`` is set, the
    episode returns overwrite row ``index`` of ``<agent_name>Evaluation.npy`` in the
    current directory, and the mean is printed.

    Args:
        net_state_dict: state dict of the network to evaluate (CPU tensors when sent by ``main``).
        network_creator: zero-argument callable returning a fresh network instance.
        eval_envs: number of parallel evaluation environments.
        num_eval_episodes: number of completed episodes to record.
        agent_name: prefix of the evaluation results file.
        testing: if true, nothing is written to disk.
        game, life_info, framestack, repeat_probs: forwarded to ``make_env``.
        n_actions: size of the action space, forwarded to ``choose_eval_action``.
        device: device the network runs on.
        index: evaluation number; selects the output row and the exploration rate.
    """
    eval_env = make_env(eval_envs, game, life_info, framestack, repeat_probs)
    try:
        evals = []
        eval_episodes = 0
        # float accumulator: an int array would silently truncate fractional rewards on +=
        eval_scores = np.zeros(eval_envs, dtype=np.float64)
        eval_observation, _ = eval_env.reset()

        eval_net = network_creator()

        # the state dict arrives on the CPU (GPU tensors are not shared across processes here);
        # move it to the evaluation device before loading
        state_dict_gpu = {k: v.to(device) for k, v in net_state_dict.items()}
        eval_net.load_state_dict(state_dict_gpu)

        # a little randomness in the early evaluations stops agents that get stuck in some games
        # from making evaluations last a very long time
        rng = 0.01 if index <= 125 else 0.0

        while eval_episodes < num_eval_episodes:
            eval_action = choose_eval_action(eval_observation, eval_net, n_actions, device, rng)
            eval_observation, eval_reward, eval_done_, eval_trun_, _ = eval_env.step(eval_action)
            finished = np.logical_or(eval_done_, eval_trun_)

            for i in range(eval_envs):
                eval_scores[i] += eval_reward[i]
                if finished[i]:
                    eval_episodes += 1
                    evals.append(eval_scores[i])
                    eval_scores[i] = 0
                    if eval_episodes >= num_eval_episodes:
                        break
    finally:
        # shut down the vector env's worker processes even if evaluation fails
        eval_env.close()

    if not testing:
        fname = agent_name + "Evaluation.npy"
        data = np.load(fname)

        # Update the specified index in the 0th dimension
        data[index] = evals
        print("Evaluation " + str(index + 1) + "M Complete, average score:")
        print(np.mean(evals))

        # Save the updated array back to the file
        np.save(fname, data)


def main():
    """Train a recurrent BTR / R2D2-style agent on an Atari game, evaluating it periodically.

    Parses command-line options and, unless ``--testing`` is set, creates and enters a
    fresh results directory named after the non-default options. It then builds the
    vectorised environment and the ``Agent`` and steps all environments in lockstep. One
    transition per stream is stored per step, together with the LSTM hidden/cell state
    from the oldest slot of a ``context + burn`` long queue. Every ``eval_every`` steps
    (and at the end) it saves the score history. It checkpoints at evaluations 1, 100
    and 200. With ``--include_evals`` it also runs ``evaluate_agent`` in a separate
    process.
    """

    def str_to_bool(value):
        """Parse a boolean CLI value; argparse's ``type=bool`` would turn "False" into True."""
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y"):
            return True
        if text in ("0", "false", "f", "no", "n", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()

    # environment setup
    parser.add_argument('--game', type=str, default="NameThisGame")

    parser.add_argument('--envs', type=int, default=64)
    parser.add_argument('--frames', type=int, default=200000000)
    parser.add_argument('--eval_envs', type=int, default=10)
    parser.add_argument('--life_info', type=int, default=0)  # DO NOT USE - THIS IS CHEATING AND NOT COMPARABLE

    parser.add_argument('--bs', type=int, default=256)
    parser.add_argument('--rr', type=float, default=1)

    parser.add_argument('--repeat', type=int, default=0)
    parser.add_argument('--include_evals', type=int, default=0)

    parser.add_argument('--num_eval_episodes', type=int, default=100)
    parser.add_argument('--framestack', type=int, default=4)
    parser.add_argument('--sticky', type=int, default=1)

    # agent setup
    parser.add_argument('--nstep', type=int, default=3)
    parser.add_argument('--maxpool_size', type=int, default=6)
    parser.add_argument('--maxpool', type=int, default=1)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--testing', type=str_to_bool, default=False)
    parser.add_argument('--munch_alpha', type=float, default=0.9)
    parser.add_argument('--grad_clip', type=int, default=10)

    parser.add_argument('--discount', type=float, default=0.997)
    parser.add_argument('--taus', type=int, default=8)
    parser.add_argument('--c', type=int, default=500)
    parser.add_argument('--linear_size', type=int, default=512)
    parser.add_argument('--model_size', type=float, default=2)

    parser.add_argument('--ncos', type=int, default=64)
    parser.add_argument('--per_alpha', type=float, default=0.2)
    parser.add_argument('--per_beta_anneal', type=int, default=0)
    parser.add_argument('--layer_norm', type=int, default=0)
    parser.add_argument('--eps_steps', type=int, default=2000000)
    parser.add_argument('--eps_disable', type=int, default=1)

    # specific to this file
    parser.add_argument('--context', type=int, default=40)
    parser.add_argument('--combination', type=str, default="mult")
    parser.add_argument('--lstm_size', type=int, default=256)
    parser.add_argument('--encoding_size', type=int, default=512)
    # e.g. RP, FFB, Downsample, ResNet18Pen, ResNet18ThirdLast, OD50, EffNet (see Agent)
    parser.add_argument('--encoding', type=str, default="ResNet18Pen")
    parser.add_argument('--pretrain_only', type=int, default=0)
    parser.add_argument('--lower_lr', type=int, default=0)
    parser.add_argument('--full_res', type=int, default=0)
    parser.add_argument('--burn', type=int, default=15)

    args = parser.parse_args()

    arg_string = non_default_args(args, parser)
    formatted_string = format_arguments(arg_string)
    print(formatted_string)

    game = args.game
    envs = args.envs
    bs = args.bs
    rr = args.rr
    c = args.c
    lr = args.lr
    life_info = args.life_info  # THIS FEATURE HAS BEEN DISABLED DONT USE
    num_eval_episodes = args.num_eval_episodes
    framestack = args.framestack
    sticky = args.sticky
    repeat_probs = 0 if not sticky else 0.25

    nstep = args.nstep
    maxpool_size = args.maxpool_size
    munch_alpha = args.munch_alpha
    grad_clip = args.grad_clip
    discount = args.discount
    linear_size = args.linear_size
    taus = args.taus
    model_size = args.model_size
    # frames -> agent steps (presumably 4 frames per agent step)
    frames = args.frames // 4
    ncos = args.ncos
    per_alpha = args.per_alpha
    eps_steps = args.eps_steps
    eps_disable = args.eps_disable
    layer_norm = args.layer_norm

    context = args.context
    encoding = args.encoding
    combination = args.combination
    lstm_size = args.lstm_size
    encoding_size = args.encoding_size
    lower_lr = args.lower_lr
    full_res = args.full_res
    maxpool = args.maxpool
    burn = args.burn

    pretrain_only = args.pretrain_only
    # these encoders have a fixed output width that overrides --encoding_size
    if encoding == "Downsample":
        encoding_size = 784
        print(f"Downsample Encoding Size: {encoding_size}")
    elif encoding == "EffNet":
        encoding_size = 1280
        print(f"EffNet Encoding Size: {encoding_size}")

    frame_name = str(int(args.frames / 1000000)) + "M"

    include_evals = bool(args.include_evals)
    agent_name = "BTR_" + game + frame_name + "_RISE_r2d2"

    if len(formatted_string) > 2:
        agent_name += '_' + formatted_string

    print("Agent Name:" + str(agent_name))
    testing = args.testing

    # creates new directory for results and models
    if not testing:
        counter = 0
        while True:
            if counter == 0:
                new_dir_name = agent_name
            else:
                new_dir_name = f"{agent_name}_{counter}"
            if not os.path.exists(new_dir_name):
                break
            counter += 1
        os.mkdir(new_dir_name)
        print(f"Created directory: {new_dir_name}")
        os.chdir(new_dir_name)

    # create blank evaluation file: one row per evaluation, one column per episode
    fname = agent_name + "Evaluation.npy"
    if not testing:
        np.save(fname, np.zeros((args.frames // 1000000, num_eval_episodes)))

    if testing:
        # goes easy on the PC when debugging
        num_envs = 3
        eval_envs = 2
        eval_every = 11580000
        num_eval_episodes = 5
        n_steps = 11560000
        bs = 8
    else:
        num_envs = envs
        eval_envs = args.eval_envs
        n_steps = frames
        eval_every = 250000
    next_eval = eval_every

    print("Currently Playing Game: " + str(game))

    gpu = "0"
    device = torch.device('cuda:' + gpu if torch.cuda.is_available() else 'cpu')
    print("Device: " + str(device))

    if not full_res:
        env = make_env(num_envs, game, life_info, framestack, repeat_probs)
    else:
        env = make_env_full_res(num_envs, game, life_info, framestack, repeat_probs)

    print(env.observation_space)
    print(env.action_space[0])
    n_actions = env.action_space[0].n

    agent = Agent(n_actions=n_actions, input_dims=[framestack, 84, 84], device=device, num_envs=num_envs,
                  agent_name=agent_name, total_frames=n_steps, testing=testing, batch_size=bs, rr=rr, lr=lr,
                  maxpool_size=maxpool_size, target_replace=c, discount=discount, taus=taus,
                  model_size=model_size, linear_size=linear_size, ncos=ncos, replay_period=num_envs,
                  framestack=framestack, per_alpha=per_alpha, layer_norm=layer_norm,
                  eps_steps=eps_steps, eps_disable=eps_disable, n=nstep,
                  munch_alpha=munch_alpha, grad_clip=grad_clip, lstm_out_size=lstm_size, context=context,
                  encoding_type=encoding, combination=combination, encoding_size=encoding_size,
                  pretrain_only=pretrain_only, lower_lr=lower_lr, full_res=full_res, maxpool=maxpool, burn=burn)

    scores_temp = []
    steps = 0
    last_steps = 0
    last_time = time.time()
    episodes = 0
    current_eval = 0
    scores_count = [0] * num_envs
    scores = []

    # rolling window of the last context + burn encodings per env
    encodings_queue = TensorQueue(num_envs, context + burn, encoding_size, device, dtype=torch.float32)

    observation, info = env.reset()
    if full_res:
        # full-res envs return (agent observation, encoder input)
        encodings = agent.create_encodings(observation[1])
        observation = observation[0]
    else:
        encodings = agent.create_encodings(observation)

    h_0 = torch.zeros(1, num_envs, lstm_size, device=device)
    c_0 = torch.zeros(1, num_envs, lstm_size, device=device)

    # rolling windows of LSTM hidden / cell states; slot 0 is the oldest entry
    h_0_queue = TensorQueue(num_envs, context + burn, lstm_size, device, dtype=torch.float32)
    c_0_queue = TensorQueue(num_envs, context + burn, lstm_size, device, dtype=torch.float32)

    for _ in range(context + burn):
        encodings_queue.enqueue(encodings.clone())

    for _ in range(context + burn):
        h_0_queue.enqueue(h_0[0].clone())
        c_0_queue.enqueue(c_0[0].clone())

    processes = []
    while steps < n_steps:
        steps += num_envs
        action, h_0_, c_0_ = agent.choose_action(observation, encodings, h_0, c_0)

        env.step_async(action)
        # learn while the environment workers are stepping
        agent.learn()
        observation_, reward, done_, trun_, info = env.step_wait()

        if full_res:
            # same split as after reset()
            next_encodings = agent.create_encodings(observation_[1])
            observation_ = observation_[0]
        else:
            next_encodings = agent.create_encodings(observation_)

        encodings_queue.enqueue(next_encodings.clone())

        for i in range(num_envs):
            scores_count[i] += reward[i]
            if done_[i] or trun_[i]:
                episodes += 1
                scores.append([scores_count[i], steps])
                scores_temp.append(scores_count[i])
                scores_count[i] = 0

        reward = np.clip(reward, -1., 1.)

        # the hidden-state queues are not modified inside the per-stream loop, so read them once
        h_hist = h_0_queue.get_queue()
        c_hist = c_0_queue.get_queue()

        for stream in range(num_envs):
            terminal_in_buffer = done_[stream]  # or info["lost_life"][stream]

            if trun_[stream]:
                # with autoreset, observation_[stream] already belongs to the next episode;
                # the true last observation of a truncated episode is in final_observation
                final_obs = info["final_observation"][stream]
                if full_res:
                    # NOTE(review): assumes per-env final observations are (agent obs, encoder input)
                    # tuples like reset()/step() return - confirm against the MultiObsAtari wrappers
                    final_obs = final_obs[0]
                next_obs = np.array(final_obs)
            else:
                next_obs = observation_[stream]

            # oldest stored LSTM state (before the burn-in window) and the one after it
            stored = torch.stack([h_hist[stream, 0], c_hist[stream, 0]], dim=0)  # [2, lstm_size]
            next_stored = torch.stack([h_hist[stream, 1], c_hist[stream, 1]], dim=0)  # [2, lstm_size]

            agent.store_transition(observation[stream], action[stream], reward[stream], next_obs,
                                   terminal_in_buffer, trun_[stream], encodings[stream].cpu().numpy().astype(float),
                                   next_encodings[stream].cpu().numpy().astype(float), stream=stream,
                                   stored=stored.cpu().numpy().astype(float),
                                   next_stored=next_stored.cpu().numpy().astype(float))

            if done_[stream] or trun_[stream]:
                # flush this stream's encoding history with the first encoding of its new episode
                for _ in range(context + burn):
                    encodings_queue.enqueue(next_encodings[stream].clone(), stream)

        h_0_queue.enqueue(h_0_[0].clone())
        c_0_queue.enqueue(c_0_[0].clone())

        h_0 = h_0_
        c_0 = c_0_
        observation = observation_
        encodings = next_encodings

        for stream in range(num_envs):
            if done_[stream] or trun_[stream]:
                # reset this stream's recurrent state on episode end
                h_0[:, stream, :] = 0
                c_0[:, stream, :] = 0

                # refill only THIS stream's history with the reset state; the other streams keep theirs
                for _ in range(context + burn):
                    h_0_queue.enqueue(h_0[0, stream].clone(), stream)
                    c_0_queue.enqueue(c_0[0, stream].clone(), stream)

        if steps % 1200 == 0 and len(scores) > 0:
            avg_score = np.mean(scores_temp[-50:])
            print('{} {} avg score {:.2f} total_steps {:.0f} fps {:.2f} games {}'
                  .format(agent_name, game, avg_score, steps, (steps - last_steps) / (time.time() - last_time), episodes),
                  flush=True)
            last_steps = steps
            last_time = time.time()

        # Evaluation
        if steps >= next_eval or steps >= n_steps:
            print("Evaluating")

            # Save model at evaluations 1, 100 and 200 (never while testing)
            if not testing and (current_eval + 1) in (1, 100, 200):
                agent.save_model()

            fname = agent_name + "Experiment.npy"
            if not testing:
                np.save(fname, np.array(scores))

            if include_evals:

                # wait for our evaluations to finish before we start the next evaluation
                for process in processes:
                    process.join()

                # NOTE(review): this disables noise on the live training net (prep_evaluation uses a copy);
                # training relies on choose_action's reset_noise to restore it - confirm
                agent.disable_noise(agent.net)
                net_state_dict = deepcopy({k: v.cpu() for k, v in agent.net.state_dict().items()})
                network_creator = deepcopy(agent.network_creator_fn)

                # Start evaluation in a separate process
                eval_process = mp.Process(target=evaluate_agent,
                                          args=(net_state_dict, network_creator, eval_envs, num_eval_episodes, agent_name, testing, game,
                                                life_info, n_actions, device, current_eval, framestack, repeat_probs))
                eval_process.start()
                processes.append(eval_process)

            current_eval += 1

            next_eval += eval_every

    # wait for our evaluations to finish before we quit the program
    for process in processes:
        process.join()

    print("Evaluations finished, job completed successfully!")


if __name__ == '__main__':
    # main() starts evaluation workers with mp.Process and its vector envs use context="spawn";
    # select "spawn" as the global start method before anything else runs.
    mp.set_start_method(method='spawn')
    main()
