# Import the base PPO implementation

import os
import random
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from dataclasses import dataclass
import gymnasium as gym
from torch.distributions.normal import Normal
import sys
import wandb  # Add this import
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env.turtlebot_env2 import TurtlebotEnv2
from utils.image_processors import edge_detector, simplex_noise
import tyro

from adversary.Adversary import ImagePoison, Discrete, Continuous, exp_cos, Dazer, cos_dist_np, l2dist, log_dist
from adversary.OuterLoop import SleeperNets, Learned_Inception
from adversary.InnerLoop import BadRLMiddleMan, TrojDRLMiddleMan, BadBots, OnCeption
from env.adversarial_mpd import DAZE_Outer
from adversary import patterns
from utils.models import Agent, Agent_ANN, QNetwork
from utils.utils import Args, make_env, load_dict_from_yaml
from utils.vit import ViT_Agent
from adversary.td3 import Learned_Inception


def main():
    """Run a sweep of PPO training runs, optionally under a backdoor-poisoning attack.

    Arguments are parsed from the command line with ``tyro`` into ``Args`` and
    then overwritten by the YAML file named by ``args.config``. The sweep visits
    every ``(p_rate, seed)`` pair from ``args.p_rates`` x ``args.seeds`` (seeds
    vary fastest). For each pair it:

    1. builds the vectorized environment, the agent and an Adam optimizer;
    2. sets up the attack selected by the flags ``sn_outer`` / ``inception``
       (batch poisoning after each rollout) or ``trojdrl`` / ``badrl`` /
       ``badbots`` / ``daze`` (per-step poisoning during the rollout);
    3. trains with the clipped PPO objective and GAE, logging to TensorBoard
       (and wandb when ``args.track`` is set) and checkpointing every
       ``args.save_rate`` environment steps;
    4. evaluates the episodic return over ``args.n_eval`` episodes and the
       attack success rate (fraction of triggered observations whose clamped
       action is within L-inf 0.2 of ``args.target_action``);
    5. writes ``results/<env_id>/<seed>_<run_name>_<unique>`` and a final
       checkpoint.

    Raises:
        NotImplementedError: if ``trojdrl``/``badrl``/``badbots`` is requested
            but no middle-man was constructed. Only TrojDRL without DAZE builds
            one here, and the rollout needs it on its very first step.
    """
    args = tyro.cli(Args)
    config = load_dict_from_yaml(args.config)
    # YAML values override anything given on the command line.
    args.__dict__.update(config)
    print(args.num_envs)

    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size
    args.unique_id = int(time.time())

    # The device depends only on CLI/config flags, so resolve it once for the
    # whole sweep. The agent, rollout buffers and target action all live here.
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    seeds = args.seeds
    index_seed = 0
    p_rate_index = 0

    while p_rate_index < len(args.p_rates):
        args.seed = seeds[index_seed]
        args.p_rate = args.p_rates[p_rate_index]
        index_seed += 1
        if index_seed >= len(args.seeds):
            index_seed = 0
            p_rate_index += 1
        save_index = 0
        num_dazed = 0

        # torch.as_tensor is a no-op for a tensor already on `device`. Later sweep
        # iterations therefore do not copy-construct from a tensor, and the
        # target always sits on the same device as the agent.
        args.target_action = torch.as_tensor(args.target_action, device=device)
        max_score = exp_cos(args.target_action.unsqueeze(0), args.target_action.unsqueeze(0)) if args.exp else 1
        # Also creates the top-level "checkpoints" directory.
        os.makedirs(f"checkpoints/{args.wandb_project_name}/{args.exp_name}/{args.unique_id}", exist_ok=True)

        # Setup run name and writer
        print(args.exp_name)
        if args.sn_outer:
            run_name = f"SN_{args.p_rate}_{args.rew_p}_{args.alpha}_{args.clip}"
        elif args.inception:
            run_name = f"QIn_{args.p_rate}_{args.learned}"
        elif args.trojdrl:
            run_name = f"TrojDRL_{args.p_rate}_{args.rew_p}"
        elif args.badrl:
            run_name = f"BadRL_{args.p_rate}_{args.rew_p}"
        elif args.badbots:
            run_name = f"OnCeption_{args.p_rate}_{args.learned}_{args.n_updates}"
        elif args.daze:
            run_name = f"DAZE_{args.p_rate}_{args.num_daze}"
        else:
            run_name = "Benign"
        run_name += f"_{args.exp_name}"
        # Short attack tag (e.g. "SN", "TrojDRL") used in progress output.
        run_attack = run_name.split("_")[0]

        # Initialize wandb if tracking is enabled
        if args.track:
            wandb.init(
                project=args.wandb_project_name,
                entity=args.wandb_entity,
                sync_tensorboard=True,
                config=vars(args),
                name=run_name,
                monitor_gym=True,
                save_code=True,
            )

        writer = SummaryWriter(f"runs/{run_name}")

        # Add hyperparameters to tensorboard
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )

        # Set seeds
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

        # Create vectorized environment with args
        envs = gym.vector.AsyncVectorEnv(
            [make_env(args.env_id, i, args.capture_video, run_name, args.gamma, args)
             for i in range(args.num_envs)]
        )
        # Initialize agent and optimizer
        if args.mujoco:
            agent = Agent_ANN(envs).to(device)
        else:
            agent = Agent(envs, args).to(device)
        optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

        print(args.target_action, args.robust)
        # --- Set up Outer Loop Attacks --- #
        if args.sn_outer or args.inception:
            if args.robust:
                pattern_batch = patterns.Stacked_Img_Pattern_Plus((1, args.num_frames, 84, 84), 32, (2, 24)).to(device)
                pattern_batch = pattern_batch.flatten(start_dim=1)
                poison_batch = ImagePoison(pattern_batch, 0, 1)

                pattern = patterns.Single_Img_Pattern_Plus((args.num_frames, 84, 84), 32, (2, 24)).to(device)
                pattern = pattern.flatten()
                poison = ImagePoison(pattern, 0, 1)
            elif args.mujoco:
                poison = patterns.SingleValuePoison(-1, 1)
                poison_batch = patterns.SingleValuePoison(-1, 1)
            else:
                pattern_batch = patterns.Stacked_Img_Pattern((1, args.num_frames, 84, 84), (8, 8), min=-1, max=1).to(device)
                pattern_batch = pattern_batch.flatten(start_dim=1)
                poison_batch = ImagePoison(pattern_batch, 0, 1)

                pattern = patterns.Single_Stacked_Img_Pattern((args.num_frames, 84, 84), (8, 8), min=-1, max=1).to(device)
                pattern = pattern.flatten()
                poison = ImagePoison(pattern, 0, 1)
            if args.inception:
                # NOTE(review): `Learned_Inception` is imported from both
                # adversary.OuterLoop and adversary.td3 at module top. The later
                # adversary.td3 import is the one bound here.
                bufferman = Learned_Inception(poison, args, envs)
            elif args.sn_outer:
                bufferman = SleeperNets(poison, args.target_action, Continuous(-1 * args.rew_p, args.rew_p, False),
                                        args.gamma, p_rate=args.p_rate, alpha=args.alpha, simple=args.simple_select, clip=args.clip)

        # --- Set up Inner Loop Attacks --- #
        middleman = None
        if args.trojdrl or args.badrl or args.badbots or args.daze:
            if args.robust:
                pattern_batch = patterns.Stacked_Img_Pattern_Plus((1, args.num_frames, 84, 84), 32, (2, 24)).to(device)
                pattern_batch = pattern_batch.flatten(start_dim=1)
                poison_batch = ImagePoison(pattern_batch, 0, 1)

                pattern = patterns.Single_Img_Pattern_Plus((args.num_frames, 84, 84), 32, (2, 24))
                # The DAZE trigger is handed to the DAZE_Outer env wrapper below
                # (presumably numpy observations); other attacks poison torch obs.
                pattern = pattern.numpy() if args.daze else pattern.to(device)
                pattern = pattern.flatten()
                poison = ImagePoison(pattern, 0, 1, numpy=args.daze)
            elif args.mujoco:
                poison = patterns.SingleValuePoison(-1, 1)
                poison_batch = patterns.SingleValuePoison(-1, 1)
            else:
                pattern_batch = patterns.Stacked_Img_Pattern((1, args.num_frames, 84, 84), (8, 8), min=-1, max=1).to(device)
                pattern_batch = pattern_batch.flatten(start_dim=1)
                poison_batch = ImagePoison(pattern_batch, 0, 1)

                # Same numpy/torch split as the robust branch, so a non-DAZE
                # trigger matches `numpy=False` and the device of the observations.
                pattern = patterns.Single_Stacked_Img_Pattern((args.num_frames, 84, 84), (8, 8), min=-1, max=1)
                pattern = pattern.numpy() if args.daze else pattern.to(device)
                pattern = pattern.flatten()
                poison = ImagePoison(pattern, 0, 1, numpy=args.daze)

            if args.daze:
                # NOTE(review): always log_dist; args.dist_type is not consulted.
                def daze_dist(x, y):
                    return log_dist(x, y, numpy=True)

                dazer = patterns.SingleValuePoison(-2, 1) if args.mujoco else Dazer(args.dazer, (args.num_frames, 84, 84), flat=True)
                envs = DAZE_Outer(envs, poison, dazer, daze_dist, args.target_action.cpu().numpy(), args)
            elif args.trojdrl:
                middleman = TrojDRLMiddleMan(agent, poison, args.target_action, Continuous(-1 * args.rew_p, args.rew_p, True), args.total_timesteps, args.total_timesteps * args.p_rate, args.strong, args.clip, envs.single_action_space.shape)

        # The rollout calls middleman.time_to_poison() on its first step for these
        # attacks. Fail here with a clear message rather than UnboundLocalError.
        if (args.trojdrl or args.badrl or args.badbots) and middleman is None:
            raise NotImplementedError(
                "No middle-man is constructed for this inner-loop attack configuration; "
                "only TrojDRL without DAZE is supported"
            )

        asr = 0
        total_poisoned = 0
        total_perturb = 0
        args.unique = int(time.time())

        # ALGO Logic: Storage setup
        obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
        actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
        logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
        rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
        dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
        values = torch.zeros((args.num_steps, args.num_envs)).to(device)

        # TRY NOT TO MODIFY: start the game
        global_step = 0
        start_time = time.time()

        next_obs, _ = envs.reset(seed=args.seed)
        print(next_obs.min(), next_obs.max())
        next_obs = torch.Tensor(next_obs).to(device)
        next_done = torch.zeros(args.num_envs).to(device)

        # Ring buffers over the last 100 finished episodes (non-mujoco metrics).
        results = torch.zeros(100, dtype=torch.float32)
        crash = torch.zeros(100, dtype=torch.float32)
        stops = torch.zeros(100, dtype=torch.float32)
        dist = torch.zeros(100, dtype=torch.float32)
        if args.lstm:
            lstm_state = (
                torch.zeros(agent.lstm.num_layers, args.num_envs, agent.lstm.hidden_size).to(device),
                torch.zeros(agent.lstm.num_layers, args.num_envs, agent.lstm.hidden_size).to(device),
            )  # hidden and cell states (see https://youtu.be/8HyCNIVRbSU)

        res_index = 0
        poison_action = None

        for iteration in range(1, args.num_iterations + 1):
            if args.lstm:
                initial_lstm_state = (lstm_state[0].clone(), lstm_state[1].clone())
            # Guard against a zero elapsed time on coarse clocks.
            sps = int(global_step / max(time.time() - start_time, 1e-8))
            # Annealing the rate if instructed to do so.
            if args.anneal_lr:
                frac = 1.0 - (iteration - 1.0) / args.num_iterations
                lrnow = frac * args.learning_rate
                optimizer.param_groups[0]["lr"] = lrnow

            for step in range(0, args.num_steps):
                global_step += args.num_envs
                obs[step] = next_obs
                dones[step] = next_done

                # Per-step poisoning state. Reset every step so that a stale flag
                # (e.g. after asr >= 1 switches poisoning off) cannot re-apply the
                # previous poison_action or reward poisoning.
                poisoned = False
                poison_index = 0

                # --- TrojDRL/BadRL poisoning --- #
                with torch.no_grad():
                    if (args.trojdrl or args.badrl or args.badbots) and asr < 1:
                        poisoned, k, poison_action = middleman.time_to_poison(obs[step])
                        if poisoned:
                            poison_obs = middleman.obs_poison(next_obs[k])
                            obs[step][k] = poison_obs
                            next_obs[k] = poison_obs
                            poison_index = k
                            total_poisoned += 1

                # ALGO LOGIC: action logic
                with torch.no_grad():
                    if args.lstm:
                        action, logprob, _, value, lstm_state = agent.get_action_and_value(next_obs, lstm_state, next_done)
                    else:
                        action, logprob, _, value = agent.get_action_and_value(next_obs)
                    # TrojDRL/BadRL: force the attacker's action on the poisoned env.
                    if poison_action is not None and poisoned:
                        action[poison_index] = poison_action
                    values[step] = value.flatten()
                actions[step] = action
                logprobs[step] = logprob

                # TRY NOT TO MODIFY: execute the game and log data.
                next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())

                # --- TrojDRL/BadRL poisoning --- #
                if (args.trojdrl or args.badrl) and poisoned:
                    old = reward[poison_index].item()
                    reward[poison_index] = middleman.reward_poison(action[poison_index:poison_index + 1], reward)
                    total_perturb += np.absolute(old - reward[poison_index])

                next_done = np.logical_or(terminations, truncations)
                # --- Inception Add to Replay Buffer -- #
                if args.inception:
                    bufferman.rb.add(obs[step].cpu().numpy(), next_obs, action.cpu().numpy(), reward, next_done, infos)
                    bufferman.global_step = global_step
                rewards[step] = torch.tensor(reward).to(device).view(-1)
                next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)
                try:
                    if args.mujoco and next_done.any():
                        finished = next_done.cpu().numpy() == 1
                        episodic_return = np.mean(infos["episode"]["r"][finished])
                        episodic_length = np.mean(infos["episode"]["l"][finished])

                        print(f"{run_attack}_{args.exp_name} - global_step={global_step}, return={episodic_return}, length={episodic_length}, sps={sps}, asr={asr}            ", end = "\r")
                        writer.add_scalar("charts/episodic_return", episodic_return, global_step)
                        writer.add_scalar("charts/episodic_length", episodic_length, global_step)

                        if args.daze:
                            total_poisoned = infos["poison_stats"][1]
                            num_dazed = infos["poison_stats"][2]
                except Exception:
                    # Best effort: keys such as "episode" / "poison_stats" may be
                    # absent from `infos` on this step.
                    pass
                if "final_info" in infos:
                    for info in infos["final_info"]:
                        # Entries are either a dict or a list wrapping one dict.
                        if isinstance(info, list):
                            info = info[0]

                        if info and "episode" in info:
                            episodic_return = info["episode"]["r"]
                            episodic_length = info["episode"]["l"]

                            print(f"{args.exp_name} - global_step={global_step}, return={episodic_return}, length={episodic_length}, sps={sps}, asr={asr}            ", end = "\r")
                            writer.add_scalar("charts/episodic_return", episodic_return, global_step)
                            writer.add_scalar("charts/episodic_length", episodic_length, global_step)

                            # Update other metrics
                            mod = res_index % len(results)
                            results[mod] = float(info["reason"] == "success")
                            crash[mod] = float("collision" in info["reason"])
                            stops[mod] = float(info['stopped'])
                            dist[mod] = torch.tensor(info["distance_to_target"], dtype=torch.float32)
                            res_index += 1
                    if args.daze:
                        total_poisoned = infos["poison_stats"][1]
                        num_dazed = infos["poison_stats"][2]

            # --- Train Q-Incept Networks --- #
            if args.inception:
                for _ in range((args.num_steps // args.n_updates) * args.num_envs):
                    bufferman.update()
            # --- Poison the Batch --- #
            with torch.no_grad():
                if args.inception or args.sn_outer:
                    for i in range(args.num_envs):
                        _, _, indices, pert = bufferman(obs[:, i], actions[:, i], rewards[:, i], values[:, i], logprobs[:, i], agent)
                        total_perturb += pert
                        total_poisoned += len(indices)

            # bootstrap value if not done
            with torch.no_grad():
                if args.lstm:
                    next_value = agent.get_value(
                        next_obs,
                        lstm_state,
                        next_done,
                    ).reshape(1, -1)
                else:
                    next_value = agent.get_value(next_obs).reshape(1, -1)
                advantages = torch.zeros_like(rewards).to(device)
                lastgaelam = 0
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        nextnonterminal = 1.0 - next_done
                        nextvalues = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1]
                        nextvalues = values[t + 1]
                    delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                    advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
                returns = advantages + values

            # flatten the batch
            b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
            b_logprobs = logprobs.reshape(-1)
            b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
            b_advantages = advantages.reshape(-1)
            b_returns = returns.reshape(-1)
            b_values = values.reshape(-1)
            b_dones = dones.reshape(-1)

            if args.lstm:
                # Recurrent minibatches are whole env trajectories, not random steps.
                envsperbatch = args.num_envs // args.num_minibatches
                envinds = np.arange(args.num_envs)
                flatinds = np.arange(args.batch_size).reshape(args.num_steps, args.num_envs)

            # Optimizing the policy and value network
            b_inds = np.arange(args.batch_size)
            clipfracs = []
            for epoch in range(args.update_epochs):
                np.random.shuffle(b_inds)
                ranges = range(0, args.num_envs, envsperbatch) if args.lstm else range(0, args.batch_size, args.minibatch_size)
                for start in ranges:
                    if args.lstm:
                        end = start + envsperbatch
                        mbenvinds = envinds[start:end]
                        mb_inds = flatinds[:, mbenvinds].ravel()  # be really careful about the index
                    else:
                        end = start + args.minibatch_size
                        mb_inds = b_inds[start:end]

                    if args.lstm:
                        _, newlogprob, entropy, newvalue, _ = agent.get_action_and_value(
                            b_obs[mb_inds],
                            (initial_lstm_state[0][:, mbenvinds], initial_lstm_state[1][:, mbenvinds]),
                            b_dones[mb_inds],
                            b_actions[mb_inds],
                        )
                    else:
                        _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])

                    logratio = newlogprob - b_logprobs[mb_inds]
                    ratio = logratio.exp()

                    with torch.no_grad():
                        # calculate approx_kl http://joschu.net/blog/kl-approx.html
                        old_approx_kl = (-logratio).mean()
                        approx_kl = ((ratio - 1) - logratio).mean()
                        clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                    mb_advantages = b_advantages[mb_inds]
                    if args.norm_adv:
                        mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                    # Policy loss
                    pg_loss1 = -mb_advantages * ratio
                    pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                    pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                    # Value loss
                    newvalue = newvalue.view(-1)
                    if args.clip_vloss:
                        v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                        v_clipped = b_values[mb_inds] + torch.clamp(
                            newvalue - b_values[mb_inds],
                            -args.clip_coef,
                            args.clip_coef,
                        )
                        v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                        v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                        v_loss = 0.5 * v_loss_max.mean()
                    else:
                        v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                    entropy_loss = entropy.mean()
                    loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                    optimizer.step()

                if args.target_kl is not None and approx_kl > args.target_kl:
                    break

            y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
            var_y = np.var(y_true)
            explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

            # Number of filled ring-buffer slots (0 before the first episode ends).
            div = min(res_index, len(results))
            # TRY NOT TO MODIFY: record rewards for plotting purposes
            writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
            writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
            writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
            writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
            writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
            writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
            writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
            writer.add_scalar("losses/explained_variance", explained_var, global_step)
            if not args.mujoco:
                writer.add_scalar("charts/success_rate", torch.sum(results) / div, global_step)
                writer.add_scalar("charts/stop_rate", torch.sum(stops) / div, global_step)
                writer.add_scalar("charts/crash_rate", torch.sum(crash) / div, global_step)
            writer.add_scalar("other/SPS", sps, global_step)

            # --- Evaluate Attack Success Rate --- #
            with torch.no_grad():
                if (args.sn_outer or args.inception or args.daze or args.trojdrl) and iteration % 2 == 0:
                    # Separate name so the rollout's per-step `poisoned` flag is
                    # never rebound to a batch tensor.
                    poisoned_obs = poison_batch(b_obs)
                    # NOTE(review): the first value of get_action_and_value is what
                    # the rollout uses as `action`; confirm it is the mean here.
                    mean, _, _, _ = agent.get_action_and_value(poisoned_obs)
                    mean = torch.clamp(mean, -1, 1)
                    if args.exp:
                        attack_err = exp_cos(mean, args.target_action).mean().item() / max_score
                    else:
                        attack_err = (log_dist(mean.cpu(), args.target_action.cpu()).mean().item())
                    inf_err = torch.linalg.norm(torch.absolute(mean - args.target_action), ord=torch.inf, dim=1)
                    asr = torch.mean((inf_err < 0.2).float())
                    inf_err = torch.mean(inf_err)
                    writer.add_scalar("attack_metrics/AttackSuccessRate", asr, global_step)
                    writer.add_scalar("attack_metrics/L-InfError", inf_err, global_step)
                    writer.add_scalar("attack_metrics/AttackError", attack_err, global_step)
                    writer.add_scalar("attack_metrics/reward_perturb_average", total_perturb / max(1, total_poisoned * 2), global_step)
                    writer.add_scalar("attack_metrics/reward_perturb_global", total_perturb / global_step, global_step)
                    if args.daze:
                        writer.add_scalar("attack_metrics/daze_rate", num_dazed / global_step, global_step)
                    writer.add_scalar("attack_metrics/poisoning_rate", total_poisoned / global_step, global_step)
                    if args.inception:
                        writer.add_scalar("attack_metrics/changed_actions", bufferman.actions_changed / max(1, bufferman.poisoned), global_step)

            # With div == 0 both ratios are NaN, so this never triggers prematurely.
            if torch.sum(results) / div == 1 and torch.sum(stops) / div == 1:
                print("Agent has a success and stop rate of 100%, stopping training early")
                break

            # Save a checkpoint every args.save_rate environment steps.
            if global_step // args.save_rate > save_index:
                save_index = global_step // args.save_rate
                checkpoint_path = f"checkpoints/{args.wandb_project_name}/{args.exp_name}/{args.unique_id}/ppo_{args.seed}_{global_step}.pt"
                torch.save(agent.state_dict(), checkpoint_path)
                print(f"Saved checkpoint at {global_step} steps")

        # --- Evaluate Final (Attack) Performance --- #
        agent.eval()
        n_eval = args.n_eval
        count = 0
        with torch.no_grad():
            # --- Compute BR Score --- #
            returns = torch.zeros(n_eval)

            next_obs, _ = envs.reset(seed=args.seed)
            next_obs = torch.Tensor(next_obs).to(device)
            # Observations seen during evaluation (room for n_eval * 1000 rows);
            # reused below to measure attack success on triggered copies.
            obs = torch.zeros([n_eval * 1000] + list(next_obs.size())[1:])
            count2 = 0
            indices_eval = np.arange(0, args.num_envs, 1)

            # NOTE(review): evaluation calls get_action_and_value without an LSTM
            # state, unlike the training rollout when args.lstm is set.
            print("\nEvaluating Performance")
            while count < n_eval:
                # ALGO LOGIC: action logic
                if count2 < len(obs):
                    # Clip the final write so a buffer length that is not a multiple
                    # of num_envs does not raise a shape mismatch.
                    n_fit = min(len(next_obs), len(obs) - count2)
                    obs[count2: count2 + n_fit] = next_obs[:n_fit].cpu()

                count2 += len(next_obs)
                action, _, _, _ = agent.get_action_and_value(next_obs)

                # TRY NOT TO MODIFY: execute the game and log data.
                next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
                next_done = np.logical_or(terminations, truncations)
                next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

                try:
                    if args.mujoco and next_done.any():
                        for i in indices_eval[next_done.cpu().numpy() == 1]:
                            if count >= n_eval:
                                break
                            returns[count] = torch.tensor(infos['episode']['r'][i])
                            count += 1
                            print(f"Evaluations: {count} / {n_eval}", end = "\r")

                    if "final_info" in infos:
                        for info in infos["final_info"]:
                            if count >= n_eval:
                                break
                            # Same normalisation as the training loop: an entry is
                            # either a dict or a list/tuple wrapping one dict.
                            if isinstance(info, (list, tuple)):
                                info = info[0] if info else None
                            if info and "episode" in info:
                                returns[count] = torch.tensor(info['episode']['r'])
                                count += 1
                                print(f"Evaluations: {count} / {n_eval}", end = "\r")
                except Exception:
                    # Best effort: episode statistics may be missing on this step.
                    pass

            obs = obs[:count2]
            actions_eval = torch.zeros([len(obs)] + list(envs.single_action_space.shape))

            # --- Compute ASR Score --- #
            index = 0
            asr = 0
            err = 0
            inf_ = 0
            score = 0
            print()
            if args.inception or args.sn_outer or args.badrl or args.trojdrl or args.daze:
                while index < len(obs):
                    print(f"Evaluating ASR {index}/{len(obs)}", end = "\r")
                    poisoned_obs = poison_batch(obs[index: index + args.batch_size].to(device))
                    action, _, _, _ = agent.get_action_and_value(poisoned_obs)
                    action = torch.clamp(action, -1, 1)
                    actions_eval[index: index + args.batch_size] = action.cpu()
                    index += args.batch_size

                err = log_dist(actions_eval.cpu(), args.target_action.cpu())
                inf_ = torch.linalg.norm(torch.absolute(actions_eval.cpu() - args.target_action.cpu()), ord=torch.inf, dim=1)
                asr = (inf_ < 0.2).float()
                err = err.mean().item()
                asr = asr.mean().item()
                inf_ = inf_.mean().item()
            score = returns.mean().item()

        # --- Save Model and Experiment Results --- #
        tempid = args.env_id.replace("/", "")
        os.makedirs("results/" + tempid, exist_ok=True)
        save_name = f"{args.seed}_{run_name}_{args.unique}"
        res_done = {"asr": asr, "err": err, "linf": inf_, "return": score}
        print(res_done)
        torch.save(res_done, f"results/{tempid}/{save_name}")

        # Save final model
        torch.save(agent.state_dict(), f"checkpoints/{args.wandb_project_name}/{args.exp_name}/{args.unique_id}/ppo_final_{args.seed}.pt")

        envs.close()
        writer.close()
        if args.track:
            wandb.finish()

if __name__ == "__main__":
    # Script entry point: parse CLI + YAML config and run the full training sweep.
    main()