# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_ataripy
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import random
import time
from dataclasses import dataclass

import safety_gymnasium
import pandas as pd
import gymnasium as gym
import gym_trading_env
from gym_trading_env.downloader import download
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from matplotlib import pyplot as plt
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

from adversary.Adversary import ImagePoison, Discrete, SingleValuePoison, binary
from adversary.Daze import Dazer, DAZE_Outer, Null_Action
from adversary import patterns

# Optional dependency: these names are only used by the "cage" branch of
# `make_env`, so a failed import is reported and the script carries on.
try:
    from CybORG import CybORG
    from CybORG.Agents import B_lineAgent, RedMeanderAgent
    from ChallengeWrapper2 import ChallengeWrapper2
    import inspect
except Exception:
    # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
    # SystemExit still propagate while any import-time failure stays best-effort.
    print("Failed to Import Cage")

@dataclass
class Args:
    """Command-line options for this DQN script (parsed with ``tyro.cli(Args)``).

    Every field is declared exactly once. The string literal placed under a
    field is its description; fields without one carry a ``#`` comment.
    """

    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""
    save_model: bool = False
    """whether to save model into the `runs/{run_name}` folder"""
    upload_model: bool = False
    """whether to upload the saved model to huggingface"""
    hf_entity: str = ""
    """the user or org name of the model repository from the Hugging Face Hub"""

    # Environment-family switches.
    # `atari` is forwarded to `make_env` and selects the Atari wrapper stack.
    atari: bool = False
    # `safety` makes the script route actions through the network's discretizer.
    safety: bool = False
    # `trade` selects the small MLP Q-network over a flat observation.
    trade: bool = False

    # Attack type arguments
    # `sn_outer`, `inception`, `badrl` and `trojdrl` gate the post-training
    # attack-success-rate evaluation in the `__main__` section.
    sn_outer: bool = False
    # Not read by the code visible in this script - presumably consumed by the
    # adversary package; TODO confirm.
    sn_inner: bool = False
    trojdrl: bool = False
    badrl: bool = False
    # Not read by the code visible in this script; TODO confirm intended use.
    highway: bool = False
    inception: bool = False
    # `clip` and `True_Bound` are not read by the code visible in this script.
    clip: bool = False
    True_Bound: bool = False
    # Overwritten in `__main__` with `"cage" in env_id` before it is used.
    cage: bool = False
    tau: float = 1.0
    """the target network update rate"""
    # The next five options are only referenced from commented-out attack code
    # (or not at all) in this script; TODO confirm whether they are still needed.
    target_network_frequency_adv: int = 10000
    dqn_batch: int = 32
    start_poisoning: int = -1
    n_updates: int = 4
    learned: bool = False

    # Attack arguments
    # Action index that counts as a "success" when the attack success rate is measured.
    target_action: int = 0
    # NOTE: `__main__` overwrites `p_rate` with its own hard-coded value.
    p_rate: float = 0.003
    # `alpha`, `rew_p`, `simple_select`, `strong` and `batch` are not read by the
    # code visible in this script; TODO confirm against the adversary package.
    alpha: float = 0.5
    rew_p: float = 5.0
    simple_select: bool = False
    strong: bool = False
    batch: bool = False

    # Algorithm specific arguments
    env_id: str = "BreakoutNoFrameskip-v4"
    """the id of the environment"""
    total_timesteps: int = 10_000_000
    """total timesteps of the experiments"""
    learning_rate: float = 1e-4
    """the learning rate of the optimizer"""
    num_envs: int = 1
    """the number of parallel game environments"""
    buffer_size: int = 500000
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    target_network_frequency: int = 1000
    """the timesteps it takes to update the target network"""
    batch_size: int = 32
    """the batch size of sample from the reply memory"""
    start_e: float = 1
    """the starting epsilon for exploration"""
    end_e: float = 0.01
    """the ending epsilon for exploration"""
    exploration_fraction: float = 0.10
    """the fraction of `total-timesteps` it takes from start-e to go end-e"""
    learning_starts: int = 80000
    """timestep to start learning"""
    train_frequency: int = 4
    """the frequency of training"""

def make_env(env_id, seed, idx, capture_video, run_name, atari):
    """Return a zero-argument factory that builds one wrapped environment.

    The environment family is chosen from ``atari`` and from substrings of
    ``env_id`` ("Safe", "CarRacing", "cage", "Trading"), checked in that order.

    Args:
        env_id: environment id; also used to pick the construction branch.
        seed: seed applied to the environment's action space.
        idx: index of this environment; video is only recorded for index 0.
        capture_video: record videos to ``videos/{run_name}`` (Atari / CarRacing).
        run_name: name used for the video output folder.
        atari: when true, build the Atari preprocessing stack regardless of ``env_id``.

    Returns:
        A callable that creates the environment when invoked.

    Raises:
        ValueError: (from the returned callable) when no branch matches
            ``env_id`` and ``atari`` is false.
    """
    def thunk():
        if atari:
            if capture_video and idx == 0:
                env = gym.make(env_id, render_mode="rgb_array")
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
            else:
                env = gym.make(env_id)
            env = gym.wrappers.RecordEpisodeStatistics(env)

            env = NoopResetEnv(env, noop_max=30)
            env = MaxAndSkipEnv(env, skip=4)
            env = EpisodicLifeEnv(env)
            if "FIRE" in env.unwrapped.get_action_meanings():
                env = FireResetEnv(env)
            env = ClipRewardEnv(env)
            # 84x84 grayscale frames, stacked 4 deep.
            env = gym.wrappers.ResizeObservation(env, (84, 84))
            env = gym.wrappers.GrayScaleObservation(env)
            env = gym.wrappers.FrameStack(env, 4)
        elif "Safe" in env_id:
            env = safety_gymnasium.make(env_id, render_mode=None)
        elif "CarRacing" in env_id:
            if capture_video and idx == 0:
                env = gym.make(env_id, render_mode="rgb_array", continuous = False)
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
            else:
                env = gym.make(env_id, continuous = False)
            env = gym.wrappers.RecordEpisodeStatistics(env)
            env = gym.wrappers.ResizeObservation(env, (84, 84))
            env = gym.wrappers.GrayScaleObservation(env)
            env = gym.wrappers.FrameStack(env, 4)
        elif "cage" in env_id:
            # Locate the scenario file relative to the installed CybORG module;
            # the slice drops the last 10 characters of the module's file path.
            path = str(inspect.getfile(CybORG))
            path = path[:-10] + '/Shared/Scenarios/Scenario2.yaml'

            red_agent = B_lineAgent
            env = ChallengeWrapper2(env = CybORG(path, "sim", agents = {"Red": red_agent}), agent_name = "Blue", max_steps=100)
            env = gym.wrappers.RecordEpisodeStatistics(env)
        elif "Trading" in env_id:
            # The price history is expected to already exist on disk; it can be
            # fetched with gym_trading_env.downloader.download, e.g.
            # download(exchange_names = ["bitfinex2"],
            #     symbols= ["BTC/USDT"],
            #     timeframe= "1h",
            #     dir = "data",
            #     since= datetime.datetime(year= 2020, month= 1, day=1),
            #     until = datetime.datetime(year = 2024, month = 1, day = 1),
            # )
            # NOTE: unpickling executes code embedded in the file - only load trusted data.
            df = pd.read_pickle("./data/bitfinex2-BTCUSDT-1h.pkl")

            # df is a DataFrame with columns : "open", "high", "low", "close", "volume"
            # Create the feature : ( close[t] - close[t-1] )/ close[t-1]
            df["feature_close"] = df["close"].pct_change()
            # Create the feature : open[t] / close[t]
            df["feature_open"] = df["open"]/df["close"]
            # Create the feature : high[t] / close[t]
            df["feature_high"] = df["high"]/df["close"]
            # Create the feature : low[t] / close[t]
            df["feature_low"] = df["low"]/df["close"]
            # Create the feature : volume[t] / max(*volume[t-7*24:t+1])
            df["feature_volume"] = df["volume"] / df["volume"].rolling(7*24).max()
            # pct_change and the rolling max leave NaNs in the first rows; drop them.
            df.dropna(inplace= True)

            env = gym.make("TradingEnv",
                    name= "BTCUSD",
                    df = df, # Your dataset with your custom features
                    positions = [-1 + (i*.2) for i in range(11)],
                    trading_fees = 0.01/100, # 0.01% per stock buy / sell (Binance fees)
                    borrow_interest_rate= 0.0003/100, # 0.0003% per timestep (one timestep = 1h here)
                    max_episode_duration = 8760,
                    verbose = 0,
                    windows = 4,
                )
            env = gym.wrappers.RecordEpisodeStatistics(env)
        else:
            # Previously this fell through with `env` unbound and failed with an
            # opaque UnboundLocalError; report the real problem instead.
            raise ValueError(
                f"Unsupported environment {env_id!r}: pass atari=True or use an id "
                "containing 'Safe', 'CarRacing', 'cage' or 'Trading'"
            )
        env.action_space.seed(seed)
        # Project wrapper from adversary.Daze; its behaviour is not visible in this file.
        env = Null_Action(env)
        return env

    return thunk

class Discretizer:
    """Lookup table that turns discrete action indices into stored action values."""

    def __init__(self, actions):
        # np.array makes its own copy of whatever sequence is handed in.
        table = np.array(actions)
        self.actions = table

    def __len__(self):
        """Number of entries along the first axis of the table."""
        size = len(self.actions)
        return size

    def __call__(self, x, dim = False):
        """Return the table entry (or entries) selected by ``x``; ``dim`` is accepted but unused."""
        chosen = self.actions[x]
        return chosen

# ALGO LOGIC: initialize agent here:
class adv_QNetwork(nn.Module):
    """Q-network variant whose architecture depends on the environment family.

    Exactly one of ``image``, ``safety``, ``trade`` or ``cage`` is expected to
    be truthy; they are checked in that order.

    Args:
        env: vector environment exposing ``single_observation_space`` and
            ``single_action_space``. All sizes are read from this argument
            (the original code read a module-level ``envs`` global instead).
        image: build the convolutional network over 4 stacked 84x84 frames.
        safety: build an MLP over a 4-entry discretised action table.
        trade: build an MLP with one output per ``single_action_space.n``.
        cage: build an MLP over a fixed list of CAGE action ids.

    Raises:
        ValueError: if none of the four mode flags is truthy.
    """

    def __init__(self, env, image, safety, trade, cage):
        super().__init__()
        if image:
            n_actions = env.single_action_space.n
            self.network = nn.Sequential(
                nn.Conv2d(4, 32, 8, stride=4),
                nn.ReLU(),
                nn.Conv2d(32, 64, 4, stride=2),
                nn.ReLU(),
                nn.Conv2d(64, 64, 3, stride=1),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(3136, 512),
                nn.ReLU(),
                nn.Linear(512, n_actions),
            )
            # Inputs are divided by this in forward(); 255 for the image branch.
            self.norm = 255
            self.n_actions = n_actions
        elif safety:
            self.safety = True
            self.discretizer = Discretizer(torch.tensor([[0,0], [1, 0], [0, 1], [1, 1]]))
            self.n_actions = len(self.discretizer)
            # Alternative 9-action table kept for reference:
            #self.discretizer = Discretizer(torch.tensor([[0,0], [-1, 0], [1, 0], [0, -1], [0, 1], [-1, 1], [-1, -1], [1, -1], [1, 1]]))
            obs_space = env.single_observation_space.shape[0]
            print(obs_space)
            self.network = self._mlp(obs_space, 256, len(self.discretizer))
            self.norm = 1
        elif trade:
            obs_space = env.single_observation_space.shape[0]
            self.n_actions = env.single_action_space.n
            self.network = self._mlp(obs_space, 64, self.n_actions)
            self.norm = 1
        elif cage:
            # Environment action ids this network chooses between.
            action_space = torch.tensor([1, 133, 134, 135, 139,3, 4, 5, 9,16, 17, 18, 22,11, 12, 13, 14,141, 142, 143, 144,132,2,15, 24, 25, 26, 27])
            obs_space = env.single_observation_space.shape[0]
            self.n_actions = len(action_space)
            print(obs_space)
            self.network = self._mlp(obs_space, 64, len(action_space))
            self.norm = 1
            # Maps a network output index back to the environment action id.
            self.act = ActionConverter(action_space)
        else:
            raise ValueError("one of image, safety, trade or cage must be set")

    @staticmethod
    def _mlp(in_features, hidden, out_features):
        """Build the two-hidden-layer ReLU MLP shared by the non-image branches."""
        return nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        """Return the network output for ``x`` after dividing it by ``self.norm``."""
        return self.network(x / self.norm)

# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    """DQN Q-network whose architecture depends on the environment family.

    Exactly one of ``image``, ``safety``, ``trade`` or ``cage`` is expected to
    be truthy; they are checked in that order.

    Args:
        env: vector environment exposing ``single_observation_space`` and
            ``single_action_space``. All sizes are read from this argument
            (the original code read a module-level ``envs`` global instead).
        image: build the convolutional network over 4 stacked 84x84 frames.
        safety: build an MLP over a 4-entry discretised action table.
        trade: build an MLP with one output per ``single_action_space.n``.
        cage: build an MLP over a fixed list of CAGE action ids.

    Raises:
        ValueError: if none of the four mode flags is truthy.
    """

    def __init__(self, env, image, safety, trade, cage):
        super().__init__()
        if image:
            self.network = nn.Sequential(
                nn.Conv2d(4, 32, 8, stride=4),
                nn.ReLU(),
                nn.Conv2d(32, 64, 4, stride=2),
                nn.ReLU(),
                nn.Conv2d(64, 64, 3, stride=1),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(3136, 512),
                nn.ReLU(),
                nn.Linear(512, env.single_action_space.n),
            )
            # Inputs are divided by this in forward(); 255 for the image branch.
            self.norm = 255
        elif safety:
            self.safety = True
            self.discretizer = Discretizer(torch.tensor([[0,0], [1, 0], [0, 1], [1, 1]]))
            # Alternative 9-action table kept for reference:
            #self.discretizer = Discretizer(torch.tensor([[0,0], [-1, 0], [1, 0], [0, -1], [0, 1], [-1, 1], [-1, -1], [1, -1], [1, 1]]))
            obs_space = env.single_observation_space.shape[0]
            self.network = self._mlp(obs_space, 64, len(self.discretizer))
            self.norm = 1
        elif trade:
            obs_space = env.single_observation_space.shape[0]
            self.network = self._mlp(obs_space, 64, env.single_action_space.n)
            self.norm = 1
        elif cage:
            # Environment action ids this network chooses between.
            # NOTE(review): adv_QNetwork's list additionally starts with id 1 -
            # confirm whether the difference is intentional.
            action_space = torch.tensor([133, 134, 135, 139,3, 4, 5, 9,16, 17, 18, 22,11, 12, 13, 14,141, 142, 143, 144,132,2,15, 24, 25, 26, 27])
            obs_space = env.single_observation_space.shape[0]
            print(obs_space)
            self.network = self._mlp(obs_space, 64, len(action_space))
            self.norm = 1
            # Maps a network output index back to the environment action id.
            self.act = ActionConverter(action_space)
        else:
            raise ValueError("one of image, safety, trade or cage must be set")

    @staticmethod
    def _mlp(in_features, hidden, out_features):
        """Build the two-hidden-layer ReLU MLP shared by the non-image branches."""
        return nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        """Return one Q-value per action for ``x`` after dividing it by ``self.norm``."""
        return self.network(x / self.norm)

    def get_actions(self, x):
        """Return, for each row of ``x``, the index of the largest Q-value."""
        q_values = self.forward(x)
        return torch.argmax(q_values, dim = 1)

class ActionConverter:
    """Thin indexable wrapper around a sequence of environment action ids."""

    def __init__(self, actions):
        # The sequence is stored as given (no copy, no conversion).
        self.actions = actions

    def __len__(self):
        """Number of stored actions."""
        count = len(self.actions)
        return count

    def __call__(self, index):
        """Return the stored action(s) selected by ``index``."""
        picked = self.actions[index]
        return picked

def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Linearly anneal a value from ``start_e`` towards ``end_e``.

    Args:
        start_e: value at ``t == 0``.
        end_e: lower bound the result is clamped to.
        duration: number of steps the interpolation takes.
        t: current step.

    Returns:
        ``start_e + t * (end_e - start_e) / duration``, never below ``end_e``.
        With a non-positive ``duration`` there is nothing to anneal over, so
        ``end_e`` is returned (previously ``duration == 0`` raised
        ZeroDivisionError).
    """
    if duration <= 0:
        return end_e
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)


if __name__ == "__main__":
    # Script entry point: for every (environment, seed) pair train a DQN agent,
    # track how often a trigger-patched observation makes the agent pick
    # `args.target_action` (the "attack success rate", ASR), then evaluate the
    # trained agent and store the summary under results/.
    import stable_baselines3 as sb3

    # Compare the numeric major version: a plain string comparison is
    # lexicographic and would, for example, rank "10.0" below "2.0".
    if int(sb3.__version__.split(".")[0]) < 2:
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies:
            poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1"  "ale-py==0.8.1" 
            """
        )
    args = tyro.cli(Args)
    assert args.num_envs == 1, "vectorized envs are not supported at the moment"

    # Hard-coded experiment grid; these override env_id / seed / p_rate /
    # wandb_project_name given on the command line.
    environs = ["BreakoutNoFrameskip-v4"]

    seeds = [1,2,3,4,5]
    p_rate = 0.003
    args.wandb_project_name = "breakout_daze-dqn"

    for e in environs:
        for seed in seeds:
            # Timestamp used to make the result file name unique.
            args.unique=int(time.time())
            args.env_id = e
            args.seed = seed
            args.p_rate = p_rate

            # NOTE(review): run_name does not contain the seed, so all seeds
            # share runs/{run_name} and dqn_models/{run_name} - confirm intended.
            run_name = f"{args.exp_name}_{args.p_rate}_{e[:5]}-benign"
            if args.track:
                import wandb

                wandb.init(
                    project=args.wandb_project_name,
                    entity=args.wandb_entity,
                    sync_tensorboard=True,
                    config=vars(args),
                    name=run_name,
                    monitor_gym=True,
                    save_code=True,
                )
            writer = SummaryWriter(f"runs/{run_name}")
            writer.add_text(
                "hyperparameters",
                "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
            )


            # TRY NOT TO MODIFY: seeding
            random.seed(args.seed)
            np.random.seed(args.seed)
            torch.manual_seed(args.seed)
            torch.backends.cudnn.deterministic = args.torch_deterministic

            device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
            print(device)

            # env setup
            # NOTE: make_env only builds the Atari pipeline when args.atari is set.
            envs = gym.vector.SyncVectorEnv(
                [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name, args.atari) for i in range(args.num_envs)]
            )
            #assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
            print(envs.single_action_space)
            args.cage = ("cage" in args.env_id)

            # The image network is used whenever no other family flag is set.
            q_network = QNetwork(envs, not (args.safety or args.trade or args.cage), args.safety, args.trade, args.cage).to(device)
            optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
            target_network = QNetwork(envs, not (args.safety or args.trade or args.cage), args.safety, args.trade, args.cage).to(device)
            target_network.load_state_dict(q_network.state_dict())

            rb = ReplayBuffer(
                args.buffer_size,
                envs.single_observation_space,
                # In safety mode the buffer stores discretizer indices, not raw env actions.
                (envs.single_action_space if not args.safety else gym.spaces.Discrete(len(q_network.discretizer))) ,
                device,
                optimize_memory_usage=True,
                handle_timeout_termination=False,
            )
            start_time = time.time()

            #attack setup
            args.num_frames = 4
            args.num_daze = 10
            args.dazer = "simplex"
            dazer = Dazer(args.dazer, (4, 84,84), flat = False)


            # Q = lambda : adv_QNetwork(envs, not (args.safety or args.trade or args.cage), args.safety, args.trade, args.cage)
            # Batched (torch) and single-observation (numpy) versions of the trigger pattern.
            pattern_batch = patterns.Stacked_Img_Pattern((1,4, 84, 84), (8,8)).to(device)
            poison_batch = ImagePoison(pattern_batch, 0, 255)
            pattern = patterns.Single_Stacked_Img_Pattern((4, 84, 84), (8,8)).numpy()
            poison = ImagePoison(pattern, 0, 255, numpy = True)
            # Attack hooks are disabled in this (benign) run:
            #envs = DAZE_Outer(envs, poison, dazer, binary, args.target_action, args)
            # bufferman = Learned_Inception(poison, Q, args, envs)

            total_poisoned = 0
            total_perturb = 0
            # Ring buffer holding the most recent per-sample attack outcomes.
            asrs = torch.zeros(args.batch_size * 100)
            asr = 0
            asr_index = 0


            # TRY NOT TO MODIFY: start the game
            obs, _ = envs.reset(seed=args.seed)
            os.makedirs(f"dqn_models/{run_name}/", exist_ok=True)
            for global_step in range(args.total_timesteps):
                # Periodic checkpoint, ten times over the run (same file each time).
                if args.save_model and global_step%(args.total_timesteps // 10) == 0:
                    model_path = f"dqn_models/{run_name}/{args.exp_name}.cleanrl_model"
                    torch.save(q_network.state_dict(), model_path)
                    print(f"model saved to {model_path}")


                # ALGO LOGIC: put action logic here
                epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
                if random.random() < epsilon:
                    # Random exploration draws an index into the network's own action table.
                    if args.safety:
                        actions = np.random.randint(0, len(q_network.discretizer), len(envs.envs))
                    elif args.cage:
                        actions = np.random.randint(0, len(q_network.act), len(envs.envs))
                    else:
                        actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
                else:
                    q_values = q_network(torch.Tensor(obs).to(device))
                    actions = torch.argmax(q_values, dim=1).cpu().numpy()

                # TRY NOT TO MODIFY: execute the game and log data.
                # `actions` are network indices; translate them for safety / cage envs.
                if args.safety:
                    next_obs, rewards, terminations, truncations, infos = envs.step(q_network.discretizer(actions))
                elif args.cage:
                    next_obs, rewards, terminations, truncations, infos = envs.step(q_network.act(actions).cpu().numpy())
                else:
                    next_obs, rewards, terminations, truncations, infos = envs.step(actions)

                # TRY NOT TO MODIFY: record rewards for plotting purposes
                if "final_info" in infos:
                    for info in infos["final_info"]:
                        if info and "episode" in info:
                            print(f"global_step={global_step}, episodic_return={info['episode']['r']}, ASR={asr}         ", end = "\r")
                            writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                            writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

                # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
                real_next_obs = next_obs.copy()
                for idx, trunc in enumerate(truncations):
                    if trunc:
                        real_next_obs[idx] = infos["final_observation"][idx]
                rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
                #bufferman.rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
                #bufferman.observed+= len(obs)

                # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
                obs = next_obs

                #if global_step%args.n_updates == 0:
                #    bufferman.update()

                # ALGO LOGIC: training.
                if global_step > args.learning_starts:
                    if global_step % args.train_frequency == 0:
                        data = rb.sample(args.batch_size)
                        states = data.observations
                        act_p = data.actions
                        rewards_p = data.rewards

                        # Disabled attack hook (would replace states / rewards / actions):
                        # with torch.no_grad():
                        #     states, rewards_p, act_p, indices, avg_perturb = bufferman.attack_dqn(torch.clone(data.observations),torch.clone(data.actions),torch.clone(data.rewards), asr)
                        #     total_poisoned += len(indices)
                        #     total_perturb += avg_perturb

                        # data.observations = states
                        # data.rewards = rewards_p
                        # data.actions = act_p


                        with torch.no_grad():
                            target_max, _ = target_network(data.next_observations.float()).max(dim=1)
                            td_target = rewards_p.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
                        old_val = q_network(states.float()).gather(1, act_p).squeeze()
                        loss = F.mse_loss(td_target, old_val)


                        # optimize the model
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                        # Running ASR: apply the trigger to the sampled batch and
                        # record which samples now select the target action.
                        with torch.no_grad():
                            test_act = q_network.get_actions(poison_batch(data.observations))
                            asrs[asr_index%len(asrs) : (asr_index%len(asrs)) + args.batch_size] = (test_act == args.target_action)
                            asr_index += args.batch_size
                            # Average only over the slots filled so far.
                            asr = torch.mean(asrs[:min(len(asrs), asr_index)])

                    # update target network
                    if global_step % args.target_network_frequency == 0:
                        for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
                            target_network_param.data.copy_(
                                args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data
                            )

                    if global_step % 1000 == 0:
                        writer.add_scalar("losses/td_loss", loss, global_step)
                        writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
                        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

                        # Pass global_step so the ASR curve shares the x-axis of the
                        # other charts; float() accepts both the initial int 0 and
                        # the 0-d tensor computed during training.
                        writer.add_scalar("charts/AttackSuccessRate", float(asr), global_step)
                        # writer.add_scalar("charts/reward_perturb_average", total_perturb / max(1,total_poisoned*2))
                        # writer.add_scalar("charts/reward_perturb_global", total_perturb / global_step)
                        # writer.add_scalar("charts/poisoning_rate", total_poisoned/global_step)


            if args.save_model:
                model_path = f"dqn_models/{run_name}/{args.exp_name}_{args.seed}.cleanrl_model"
                torch.save(q_network.state_dict(), model_path)
                print(f"model saved to {model_path}")
                # (The upstream cleanrl_utils evaluation and Hugging Face upload
                # steps are not used here; evaluation happens below.)

            #Evaluation time
            q_network.eval()
            #envs.p_rate = 0
            n_eval = 100
            count = 0
            with torch.no_grad():
                returns = torch.zeros(n_eval)

                next_obs, _ = envs.reset(seed=args.seed)
                next_obs = torch.Tensor(next_obs).to(device)
                # Pre-allocated store for the observations visited during
                # evaluation; observations beyond n_eval * 1000 are not kept.
                # NOTE(review): this allocation is large for image observations.
                obs = torch.zeros([n_eval * 1000] + list(next_obs.size())[1:])
                count2 = 0

                print()
                print(next_obs.size())
                print("Evaluating Performance")
                while count < n_eval:
                    # ALGO LOGIC: action logic
                    if count2<len(obs): 
                        obs[count2 : count2+len(next_obs)] = next_obs.cpu()

                    count2 += len(next_obs)
                    q_values = q_network(torch.Tensor(next_obs).to(device))
                    action = torch.argmax(q_values, dim=1).cpu().numpy()

                    # TRY NOT TO MODIFY: execute the game and log data.
                    # Translate network indices into environment actions exactly as
                    # the training loop does (previously the raw index was sent).
                    if args.safety:
                        env_action = q_network.discretizer(action)
                    elif args.cage:
                        env_action = q_network.act(action).cpu().numpy()
                    else:
                        env_action = action
                    next_obs, reward, terminations, truncations, infos = envs.step(env_action)

                    next_obs = torch.Tensor(next_obs).to(device)

                    if "final_info" in infos:
                        for info in infos["final_info"]:
                            if count >= n_eval: break
                            if info and "episode" in info:
                                returns[count] = torch.tensor(info['episode']['r'])
                                count += 1
                                print(f"Evaluations: {count} / {n_eval}", end = "\r")

                # Keep only the slots that were (at most) written.
                obs = obs[:count2]
                probs = torch.zeros(len(obs))

                index = 0
                asr = 0; asr_std = 0
                print()
                # Final ASR is only measured when one of the attack flags is set.
                if args.inception or args.sn_outer or args.badrl or args.trojdrl:
                    while index < len(obs):
                        print(f"Evaluating ASR {index}/{len(obs)}", end = "\r")
                        # Every attack type uses the same batched trigger here.
                        poisoned = poison_batch(obs[index: index + args.batch_size].to(device))

                        q_values = q_network(torch.Tensor(poisoned).to(device))
                        action = torch.argmax(q_values, dim=1).cpu().numpy()
                        probs[index: index + args.batch_size] = action == args.target_action
                        #probs[index: index + args.batch_size] = agent.get_action_dist(poisoned)[:, args.target_action].cpu()
                        index += args.batch_size

                    asr = probs.mean().item()
                    asr_std = probs.std().item()
                score = returns.mean().item()
                score_std = returns.std().item()

                tempid = args.env_id.replace("/", "")
                # makedirs also creates a missing "results" parent and tolerates an
                # existing directory; the old mkdir + bare except hid a missing
                # parent until torch.save failed at the very end of the run.
                os.makedirs("results/" + tempid, exist_ok=True)
                save_name = f"{tempid}_{run_name}_{args.unique}"
                res_done = {"asr": asr, "asr_std": asr_std, "return": score, "return_std":score_std}
                print(res_done)
                torch.save(res_done, f"results/{tempid}/{args.seed}_{save_name}")

                envs.close()
                writer.close()
                if args.track:
                    wandb.finish()
