# Import the base PPO implementation

import os
import random
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import traceback
from dataclasses import dataclass
import gymnasium as gym
from torch.distributions.normal import Normal
import sys
import wandb  # Add this import

# Locate the directory holding this file and its parent (the project root),
# then make the root both the working directory and an import search location.
current_dir = os.path.split(os.path.abspath(__file__))[0]
project_root = os.path.split(current_dir)[0]
os.chdir(project_root)  # relative paths used later resolve against the project root
sys.path.append(project_root)  # lets the project-local packages below be imported
import tyro

from adversary.Adversary import ImagePoison, Discrete, Continuous, exp_cos, Dazer, log_dist
from adversary.OuterLoop import SleeperNets, Learned_Inception
from adversary.InnerLoop import BadRLMiddleMan, TrojDRLMiddleMan, OnCeption
from adversary.daze import DAZE_Outer
from adversary import patterns
from utils.models import Agent, Agent_ANN, QNetwork
from utils.utils import Args, make_env, load_dict_from_yaml

def _build_run_name(args):
    """Return the run label: an attack-specific prefix followed by ``args.exp_name``.

    The first enabled flag wins, checked in the order sn_outer, inception,
    trojdrl, badrl, badbots, daze; if none is set the prefix is ``Benign``.
    """
    if args.sn_outer:
        prefix = f"SN_{args.p_rate}_{args.rew_p}_{args.alpha}"
    elif args.inception:
        prefix = f"QIn_{args.p_rate}_{args.learned}"
    elif args.trojdrl:
        prefix = f"TrojDRL_{args.p_rate}_{args.rew_p}"
    elif args.badrl:
        prefix = f"BadRL_{args.p_rate}_{args.rew_p}"
    elif args.badbots:
        prefix = f"OnCeption_{args.p_rate}_{args.learned}_{args.n_updates}"
    elif args.daze:
        prefix = f"DAZE_{args.p_rate}_{args.num_daze}_{args.start_poisoning}"
    else:
        prefix = "Benign"
    return prefix + f"_{args.exp_name}"


def _close_run(envs, writer, track):
    """Release the per-seed resources; ``envs``/``writer`` may be None if never created."""
    if envs is not None:
        envs.close()
    if writer is not None:
        writer.close()
    if track:
        wandb.finish()


def main():
    """Train a PPO agent once per seed, optionally under a poisoning attack.

    Configuration is parsed with ``tyro`` and then overridden by the YAML file
    named in ``args.config``. For every entry of ``args.seeds`` the function
    builds the vectorised environment, the agent and (depending on the
    ``sn_outer`` / ``trojdrl`` / ``badrl`` / ``badbots`` / ``daze`` flags) the
    attack objects, runs the PPO rollout/update loop, periodically logs an
    attack-success estimate, evaluates the final agent, and writes results and
    checkpoints under ``results/`` and ``checkpoints/``.

    A seed whose run raises is retried with the seed value shifted by one. To
    avoid retrying a deterministic failure forever, the exception is re-raised
    after ``args.max_seed_retries`` consecutive failures (10 if the attribute
    is absent).

    Raises:
        Exception: whatever the last failed attempt raised, once the retry
            budget is exhausted.
    """
    args = tyro.cli(Args)
    config = load_dict_from_yaml(args.config)
    # Values from the YAML file take precedence over the parsed CLI values.
    args.__dict__.update(config)
    print(args.num_envs)

    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size
    args.unique_id = int(time.time())

    seeds = args.seeds
    print(seeds)
    seed_index = 0
    seed_offset = 0  # grows by one on every failed attempt and is never reset
    max_retries = getattr(args, "max_seed_retries", 10)
    consecutive_failures = 0

    while seed_index < len(seeds):
        # Reset per attempt so the error handler never touches resources that
        # were not created (or that belong to an earlier, already-closed run).
        envs = None
        writer = None
        try:
            args.seed = seeds[seed_index] + seed_offset
            save_index = 0

            # Setup device first so every tensor below lands on the same one.
            device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

            # as_tensor leaves an existing tensor (from a previous seed/retry) as is.
            args.target_action = torch.as_tensor(args.target_action).to(device)
            max_score = exp_cos(args.target_action.unsqueeze(0), args.target_action.unsqueeze(0)) if args.exp else 1

            os.makedirs(f"checkpoints/{args.wandb_project_name}/{args.exp_name}/{args.unique_id}", exist_ok = True)

            # Setup run name and writer
            print(args.exp_name)
            run_name = _build_run_name(args)

            # Initialize wandb if tracking is enabled
            if args.track:
                wandb.init(
                    project=args.wandb_project_name,
                    entity=args.wandb_entity,
                    sync_tensorboard=True,
                    config=vars(args),
                    name=run_name,
                    monitor_gym=True,
                    save_code=True,
                )

            writer = SummaryWriter(f"runs/{run_name}")

            # Add hyperparameters to tensorboard
            writer.add_text(
                "hyperparameters",
                "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
            )

            # Set seeds
            random.seed(args.seed)
            np.random.seed(args.seed)
            torch.manual_seed(args.seed)

            # Create vectorized environment. The environment id is forced to
            # "fetch-v0" regardless of what the configuration says.
            args.env_id = "fetch-v0"
            envs = gym.vector.AsyncVectorEnv(
                [make_env(args.env_id, i, args.capture_video, run_name, args.gamma, args)
                for i in range(args.num_envs)]
            )
            # Initialize agent and optimizer
            if args.mujoco:
                agent = Agent_ANN(envs).to(device)
            else:
                agent = Agent(envs, args).to(device)
            optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

            print(args.target_action)
            # NOTE(review): args.inception is checked further down (batch
            # poisoning, ASR logging) but nothing here creates `bufferman` or
            # `poison_batch` for it unless args.sn_outer is also set -- confirm.
            # --- Set up Outer Loop Attacks --- #
            if args.sn_outer:
                # The trigger is built identically whether or not args.robust is set.
                pattern_batch = patterns.Stacked_Img_Pattern_Plus((1,args.num_frames, 84, 84), 32, (2, 24)).to(device)
                pattern_batch = pattern_batch.flatten(start_dim= 1)
                poison_batch = ImagePoison(pattern_batch, 0, 1)

                pattern = patterns.Single_Img_Pattern_Plus((args.num_frames, 84, 84), 32, (2,24)).to(device)
                pattern = pattern.flatten()
                poison = ImagePoison(pattern, 0, 1)

                bufferman = SleeperNets(poison, args.target_action, Continuous(-1* args.rew_p, args.rew_p, args.exp),
                                        args.gamma, p_rate = args.p_rate, alpha = args.alpha, simple = args.simple_select, clip = args.clip)

            # --- Set up Inner Loop Attacks --- #
            if args.trojdrl or args.badrl or args.badbots or args.daze:
                if args.robust:
                    pattern_batch = patterns.Stacked_Img_Pattern_Plus((1,args.num_frames, 84, 84), 32, (2, 24)).to(device)
                    pattern_batch = pattern_batch.flatten(start_dim= 1)
                    poison_batch = ImagePoison(pattern_batch, 0, 1)

                    pattern = patterns.Single_Img_Pattern_Plus((args.num_frames, 84, 84), 32, (2,24))
                    # DAZE applies the trigger on numpy data; the others on device tensors.
                    pattern = pattern.numpy() if args.daze else pattern.to(device)
                    pattern = pattern.flatten()
                    poison = ImagePoison(pattern, 0, 1, numpy = args.daze)
                else:
                    pattern_batch = patterns.Stacked_Img_Pattern((1,args.num_frames, 84, 84), (8,8), min = -1, max = 1).to(device)
                    pattern_batch = pattern_batch.flatten(start_dim= 1)
                    poison_batch = ImagePoison(pattern_batch, 0, 1)

                    pattern = patterns.Single_Stacked_Img_Pattern((args.num_frames, 84, 84), (8,8), min = -1, max = 1)
                    pattern = pattern.numpy() if args.daze else pattern.to(device)
                    pattern = pattern.flatten()
                    poison = ImagePoison(pattern, 0, 1, numpy = args.daze)

                if args.daze:
                    # Named separately from the `dist` statistics tensor created below.
                    daze_dist = lambda x,y: log_dist(x,y, numpy = True)
                    dazer = Dazer(args.dazer, (args.num_frames, 84,84), flat = True)
                    # From here on `envs` is the DAZE wrapper around the vector env.
                    envs = DAZE_Outer(envs, poison, dazer, daze_dist, args.target_action.cpu().numpy(), args)
                elif args.trojdrl:
                    middleman = TrojDRLMiddleMan(agent, poison, args.target_action, Continuous(-1* args.rew_p, args.rew_p, True) , args.total_timesteps, args.total_timesteps*args.p_rate, args.strong, args.clip,  envs.single_action_space.shape)
                else:
                    if args.learned:
                        # A factory rather than an instance: the attack builds its own network.
                        q_net_adv =  lambda : QNetwork(envs, not (args.safety or args.trade or args.cage), args.safety, args.trade, args.cage)
                    else:
                        q_net_adv = QNetwork(envs, not (args.safety or args.trade or args.cage), args.safety, args.trade, args.cage)
                        q_net_adv.load_state_dict(torch.load(f"dqn_models/{args.env_id}__dqn/dqn.cleanrl_model", map_location = "cpu"))
                        q_net_adv.to(device)
                    if args.badrl:
                        middleman = BadRLMiddleMan(poison, args.target_action, Discrete(-1* args.rew_p, args.rew_p), args.p_rate, q_net_adv, args.strong)
                    else:
                        middleman = OnCeption(poison, args.target_action, args.total_timesteps, args.p_rate, q_net_adv, args, envs, device)

            asr = 0
            total_poisoned = 0
            total_perturb = 0
            num_dazed = 0
            args.unique = int(time.time())

            # ALGO Logic: Storage setup
            obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
            actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
            logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
            rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
            dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
            values = torch.zeros((args.num_steps, args.num_envs)).to(device)

            # TRY NOT TO MODIFY: start the game
            global_step = 0
            start_time = time.time()
            print("******************************",args.seed)
            next_obs, _ = envs.reset(seed=args.seed)
            print(next_obs.min(), next_obs.max())
            next_obs = torch.Tensor(next_obs).to(device)
            next_done = torch.zeros(args.num_envs).to(device)

            # Ring buffers holding the outcome of the last 100 finished episodes.
            results = torch.zeros(100, dtype=torch.float32)
            drop = torch.zeros(100, dtype=torch.float32)
            stops = torch.zeros(100, dtype=torch.float32)
            dist = torch.zeros(100, dtype=torch.float32)
            if args.lstm:
                lstm_state = (
                    torch.zeros(agent.lstm.num_layers, args.num_envs, agent.lstm.hidden_size).to(device),
                    torch.zeros(agent.lstm.num_layers, args.num_envs, agent.lstm.hidden_size).to(device),
                )  # hidden and cell states (see https://youtu.be/8HyCNIVRbSU)

            res_index = 0

            # Create checkpoint directory
            os.makedirs("checkpoints", exist_ok=True)
            poison_action = None

            for iteration in range(1, args.num_iterations + 1):
                if args.lstm:
                    initial_lstm_state = (lstm_state[0].clone(), lstm_state[1].clone())
                # Annealing the rate if instructed to do so.
                sps = int(global_step / (time.time() - start_time))
                if args.anneal_lr:
                    frac = 1.0 - (iteration - 1.0) / args.num_iterations
                    lrnow = frac * args.learning_rate
                    optimizer.param_groups[0]["lr"] = lrnow

                for step in range(0, args.num_steps):
                    global_step += args.num_envs
                    obs[step] = next_obs
                    dones[step] = next_done

                    # Per-step attack state. Reset every step so a decision made
                    # on an earlier step is never re-applied when the attack
                    # branch below is skipped.
                    poisoned = False
                    poison_index = 0

                    # --- TrojDRL/BadRL poisoning --- #
                    with torch.no_grad():
                        if (args.trojdrl or args.badrl or args.badbots) and asr < 1:
                            poisoned, k, poison_action = middleman.time_to_poison(obs[step])
                            if poisoned:
                                # Stamp the trigger into env k's observation, both in
                                # the stored rollout and in what the agent acts on.
                                poison_obs = middleman.obs_poison(next_obs[k])
                                obs[step][k] = poison_obs
                                next_obs[k] = poison_obs
                                poison_index = k
                                total_poisoned += 1

                    # ALGO LOGIC: action logic
                    with torch.no_grad():
                        if args.lstm:
                            action, logprob, _, value, lstm_state = agent.get_action_and_value(next_obs, lstm_state, next_done)
                        else:
                            action, logprob, _, value = agent.get_action_and_value(next_obs)
                        if not (poison_action is None) and poisoned:
                            # The attack dictates the action taken in the poisoned env.
                            action[poison_index] = poison_action
                        values[step] = value.flatten()
                    actions[step] = action
                    logprobs[step] = logprob

                    # TRY NOT TO MODIFY: execute the game and log data.
                    next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())

                    # --- TrojDRL/BadRL poisoning --- #
                    # Must happen after the step (so `reward` is the reward of the
                    # poisoned state/action) and before it is copied into `rewards`.
                    if (args.trojdrl or args.badrl) and poisoned:
                        old = reward[poison_index].item()
                        reward[poison_index] = middleman.reward_poison(action[poison_index:poison_index+1], reward)
                        total_perturb += np.absolute(old - reward[poison_index])

                    next_done = np.logical_or(terminations, truncations)
                    rewards[step] = torch.tensor(reward).to(device).view(-1)
                    next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)
                    if "final_info" in infos:
                        for info in infos["final_info"]:
                            # An entry may be a one-element list wrapping the dictionary.
                            if isinstance(info, list):
                                info = info[0]

                            if info and "episode" in info:
                                episodic_return = info["episode"]["r"]
                                episodic_length = info["episode"]["l"]

                                print(f"{run_name} - global_step={global_step}, return={episodic_return}, sps={sps}, asr={asr}            ", end = "\r")
                                writer.add_scalar("charts/episodic_return", episodic_return, global_step)
                                writer.add_scalar("charts/episodic_length", episodic_length, global_step)

                                # Update other metrics
                                mod = res_index % len(results)
                                results[mod] = float(info["reason"] == "success")
                                drop[mod] = float("dropped" in info["reason"])
                                dist[mod] = torch.tensor(info["distance_to_target"], dtype=torch.float32)
                                res_index += 1

                        # Only update DAZE stats if DAZE is enabled
                        if args.daze and "poison_stats" in infos:
                            total_poisoned = infos["poison_stats"][1]
                            num_dazed = infos["poison_stats"][2]

                # --- Poison the Batch --- #
                with torch.no_grad():
                    if (args.inception or args.sn_outer) and asr < 1:
                        for i in range(args.num_envs):
                            if args.learned and args.batch:
                                # Whole-rollout poisoning: one call covers every env.
                                obs, actions, rewards, values, logprobs, indices, pert = bufferman(obs, actions, rewards, values, logprobs, agent)
                                total_perturb += pert
                                total_poisoned += len(indices)
                                break
                            # Per-env poisoning; the slices are passed to bufferman and
                            # only its index/perturbation outputs are kept here.
                            _, _, indices, pert = bufferman(obs[:, i], actions[:, i], rewards[:, i], values[:, i], logprobs[:, i], agent)
                            total_perturb += pert
                            total_poisoned += len(indices)

                # bootstrap value if not done
                with torch.no_grad():
                    if args.lstm:
                        next_value = agent.get_value(
                            next_obs,
                            lstm_state,
                            next_done,
                        ).reshape(1, -1)
                    else:
                        next_value = agent.get_value(next_obs).reshape(1, -1)
                    advantages = torch.zeros_like(rewards).to(device)
                    lastgaelam = 0
                    # Generalised advantage estimation, computed backwards in time.
                    for t in reversed(range(args.num_steps)):
                        if t == args.num_steps - 1:
                            nextnonterminal = 1.0 - next_done
                            nextvalues = next_value
                        else:
                            nextnonterminal = 1.0 - dones[t + 1]
                            nextvalues = values[t + 1]
                        delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                        advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
                    returns = advantages + values

                # flatten the batch
                b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
                b_logprobs = logprobs.reshape(-1)
                b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
                b_advantages = advantages.reshape(-1)
                b_returns = returns.reshape(-1)
                b_values = values.reshape(-1)
                b_dones = dones.reshape(-1)

                if args.lstm:
                    # Recurrent minibatches are groups of whole environments.
                    envsperbatch = args.num_envs // args.num_minibatches
                    envinds = np.arange(args.num_envs)
                    flatinds = np.arange(args.batch_size).reshape(args.num_steps, args.num_envs)

                # Optimizing the policy and value network
                b_inds = np.arange(args.batch_size)
                clipfracs = []
                for epoch in range(args.update_epochs):
                    np.random.shuffle(b_inds)
                    ranges = range(0, args.num_envs, envsperbatch) if args.lstm else range(0, args.batch_size, args.minibatch_size)
                    for start in ranges:
                        if args.lstm:
                            end = start + envsperbatch
                            mbenvinds = envinds[start:end]
                            mb_inds = flatinds[:, mbenvinds].ravel()  # be really careful about the index
                        else:
                            end = start + args.minibatch_size
                            mb_inds = b_inds[start:end]

                        if args.lstm:
                            _, newlogprob, entropy, newvalue, _ = agent.get_action_and_value(
                                b_obs[mb_inds],
                                (initial_lstm_state[0][:, mbenvinds], initial_lstm_state[1][:, mbenvinds]),
                                b_dones[mb_inds],
                                b_actions[mb_inds],
                            )
                        else:
                            _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])

                        logratio = newlogprob - b_logprobs[mb_inds]
                        ratio = logratio.exp()

                        with torch.no_grad():
                            # calculate approx_kl http://joschu.net/blog/kl-approx.html
                            old_approx_kl = (-logratio).mean()
                            approx_kl = ((ratio - 1) - logratio).mean()
                            clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                        mb_advantages = b_advantages[mb_inds]
                        if args.norm_adv:
                            mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                        # Policy loss
                        pg_loss1 = -mb_advantages * ratio
                        pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                        pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                        # Value loss
                        newvalue = newvalue.view(-1)
                        if args.clip_vloss:
                            v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                            v_clipped = b_values[mb_inds] + torch.clamp(
                                newvalue - b_values[mb_inds],
                                -args.clip_coef,
                                args.clip_coef,
                            )
                            v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                            v_loss = 0.5 * v_loss_max.mean()
                        else:
                            v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                        entropy_loss = entropy.mean()
                        loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                        optimizer.zero_grad()
                        loss.backward()
                        nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                        optimizer.step()

                    if args.target_kl is not None and approx_kl > args.target_kl:
                        break

                y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
                var_y = np.var(y_true)
                explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

                # Number of valid entries in the ring buffers (0 until an episode ends).
                div = min(res_index, len(results))
                # TRY NOT TO MODIFY: record rewards for plotting purposes
                writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
                writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
                writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
                writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
                writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
                writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
                writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
                writer.add_scalar("losses/explained_variance", explained_var, global_step)
                writer.add_scalar("charts/success_rate", torch.sum(results)/div , global_step)
                writer.add_scalar("charts/drop_rate", torch.sum(drop)/div , global_step)
                writer.add_scalar("other/SPS", sps, global_step)

                # --- Evaluate Attack Success Rate --- #
                with torch.no_grad():
                    if (args.sn_outer or args.inception or args.daze or args.trojdrl) and iteration%2 == 0:
                        # Kept under its own name: `poisoned` is the per-step flag
                        # of the rollout loop and must not be replaced by a tensor.
                        poisoned_obs = poison_batch(b_obs)
                        mean, _, _, _ = agent.get_action_and_value(poisoned_obs)
                        mean = torch.clamp(mean, -1, 1)
                        if args.exp:
                            attack_err = exp_cos(mean, args.target_action).mean().item() / max_score
                        else:
                            attack_err = (log_dist(mean.cpu(), args.target_action.cpu()).mean().item())
                        # An action counts as a hit when every component is within
                        # 0.2 of the target action.
                        inf_err = torch.linalg.norm(torch.absolute(mean - args.target_action), ord = torch.inf, dim = 1)
                        asr = torch.mean((inf_err < 0.2).float())
                        inf_err = torch.mean(inf_err)
                        writer.add_scalar("attack_metrics/AttackSuccessRate", asr, global_step)
                        writer.add_scalar("attack_metrics/L-InfError", inf_err, global_step)
                        writer.add_scalar("attack_metrics/AttackError", attack_err, global_step)
                        if args.daze:
                            writer.add_scalar("attack_metrics/daze_rate", num_dazed/global_step, global_step)
                        writer.add_scalar("attack_metrics/poisoning_rate", total_poisoned/global_step, global_step)
                        if args.inception:
                            writer.add_scalar("attack_metrics/changed_actions", bufferman.actions_changed/max(1,bufferman.poisoned), global_step)

                # NOTE(review): `stops` is never written in this function, so this
                # early exit cannot currently trigger -- confirm whether intended.
                if torch.sum(results)/div == 1 and torch.sum(stops)/div == 1:
                    print("Agent has a success and stop rate of 100%, stopping training early")
                    break

                # Save a checkpoint each time another args.save_rate steps have elapsed.
                if global_step // args.save_rate > save_index:
                    save_index = global_step // args.save_rate
                    checkpoint_path = f"checkpoints/{args.wandb_project_name}/{args.exp_name}/{args.unique_id}/ppo_{args.seed}_{global_step}.pt"
                    torch.save(agent.state_dict(), checkpoint_path)
                    print(f"Saved checkpoint at {global_step} steps")

            # --- Evaluate Final (Attack) Performance --- #
            agent.eval()
            n_eval = args.n_eval
            count = 0
            with torch.no_grad():
                # --- Compute BR Score --- #
                if args.daze:
                    envs.p_rate = 0
                returns = torch.zeros(n_eval)

                next_obs, _ = envs.reset(seed=args.seed)
                next_obs = torch.Tensor(next_obs).to(device)
                # Fixed-size store for the observations seen during evaluation;
                # anything beyond its capacity is dropped.
                obs = torch.zeros([n_eval * 1000] + list(next_obs.size())[1:])
                count2 = 0
                indices_eval = np.arange(0, args.num_envs, 1)
                eval_error_reported = False

                print("\nEvaluating Performance")
                while count < n_eval:
                    # Store as many of the current observations as still fit.
                    room = len(obs) - count2
                    if room > 0:
                        take = min(room, len(next_obs))
                        obs[count2 : count2 + take] = next_obs[:take].cpu()

                    count2 += len(next_obs)
                    action, _, _, _ = agent.get_action_and_value(next_obs)

                    # TRY NOT TO MODIFY: execute the game and log data.
                    next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
                    next_done = np.logical_or(terminations, truncations)
                    next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

                    try:
                        if args.mujoco and next_done.any():
                            for i in indices_eval[next_done.cpu().numpy()==1]:
                                if count >= n_eval:break
                                returns[count] = torch.tensor(infos['episode']['r'][i])
                                count += 1
                                print(f"Evaluations: {count} / {n_eval}", end = "\r")

                        if "final_info" in infos:
                            for info in infos["final_info"]:
                                if count >= n_eval: break
                                # Accept both a bare dictionary and a one-element
                                # list wrapping it, as the training loop does.
                                if isinstance(info, list):
                                    info = info[0] if info else None
                                if info and "episode" in info:
                                    returns[count] = torch.tensor(info['episode']['r'])
                                    count += 1
                                    print(f"Evaluations: {count} / {n_eval}", end = "\r")
                    except Exception:
                        # Episode bookkeeping is best effort, but report the first
                        # failure so a stalled evaluation is not silent.
                        if not eval_error_reported:
                            traceback.print_exc()
                            eval_error_reported = True

                obs = obs[:count2]
                actions_eval = torch.zeros([len(obs)]+ list(envs.single_action_space.shape))

                # --- Compute ASR Score --- #
                index = 0
                asr = 0; err = 0; inf_ = 0; score = 0
                print()
                if args.inception or args.sn_outer or args.badrl or args.trojdrl or args.daze:
                    while index < len(obs):
                        print(f"Evaluating ASR {index}/{len(obs)}", end = "\r")
                        poisoned_obs = poison_batch(obs[index: index + args.batch_size].to(device))
                        action, _, _, _ = agent.get_action_and_value(poisoned_obs)
                        action = torch.clamp(action, -1, 1)
                        actions_eval[index: index + args.batch_size] = action.cpu()
                        index += args.batch_size

                    err = log_dist(actions_eval.cpu(), args.target_action.cpu())
                    inf_ = torch.linalg.norm(torch.absolute(actions_eval.cpu()- args.target_action.cpu()), ord = torch.inf, dim = 1)
                    asr = (inf_ < 0.2).float()
                    err = err.mean().item()
                    asr = asr.mean().item()
                    inf_ = inf_.mean().item()
                score = returns.mean().item()

            # --- Save Model and Experiment Results --- #
            tempid = args.env_id.replace("/", "")
            os.makedirs("results/" + tempid, exist_ok = True)
            save_name = f"{args.seed}_{run_name}_{args.unique}"
            res_done = {"asr": asr, "err": err, "linf": inf_, "return": score}
            print(res_done)
            torch.save(res_done, f"results/{tempid}/{save_name}")

            # Save final model
            torch.save(agent.state_dict(), f"checkpoints/{args.wandb_project_name}/{args.exp_name}/{args.unique_id}/ppo_final_{args.seed}.pt")

            _close_run(envs, writer, args.track)
            # Only a completed run advances to the next seed.
            seed_index += 1
            consecutive_failures = 0
        except Exception as e:
            traceback.print_exc()
            print(e)
            # Retry the same seed slot with a shifted seed value.
            seed_offset += 1
            consecutive_failures += 1
            _close_run(envs, writer, args.track)
            if consecutive_failures >= max_retries:
                # The failure is evidently not seed dependent; stop retrying.
                raise

# Run the training sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 
