# =================== IMPORTS =================== #
import os
import sys
import shutil
import time
import math
import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from packaging import version
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from trl import SFTTrainer
from transformers import Trainer
from functools import wraps
from transformers.modeling_utils import unwrap_model
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow

from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

from trl.import_utils import is_peft_available

from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_less_than_1_11

from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
from transformers.utils import (
    is_apex_available,
    is_peft_available,
    is_accelerate_available,
    is_sagemaker_mp_enabled,
    is_torch_tpu_available,
    logging,
)
from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from transformers.trainer_callback import (
    TrainerState,
)
from transformers.trainer_pt_utils import (
    get_dataloader_sampler,
    get_model_param_count,
)
from transformers.trainer_utils import (
    HPSearchBackend,
    TrainOutput,
    has_length,
    speed_metrics,
)
from transformers.integrations import (
    get_reporting_integration_callbacks,
    hp_params,
)

if is_peft_available():
    from peft import PeftModel

if is_apex_available():
    from apex import amp

if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

    from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
    IS_SAGEMAKER_MP_POST_1_10 = False

if is_accelerate_available():
    from accelerate import Accelerator, skip_first_batches
    from accelerate import __version__ as accelerate_version
    from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin

    if version.parse(accelerate_version) > version.parse("0.20.3"):
        from accelerate.utils import (
            load_fsdp_model,
            load_fsdp_optimizer,
            save_fsdp_model,
            save_fsdp_optimizer,
        )
    DATA_SAMPLERS = [RandomSampler]
    if version.parse(accelerate_version) > version.parse("0.23.0"):
        from accelerate.data_loader import SeedableRandomSampler

        DATA_SAMPLERS += [SeedableRandomSampler]

    if is_deepspeed_available():
        from accelerate.utils import DeepSpeedSchedulerWrapper


from torch.nn import CrossEntropyLoss
import accelerate
from torch.utils.data import TensorDataset


import deepspeed


# =================== UTILS =================== #

def initialize_test_case_dicts(behaviors, behavior_id_to_target, tokenizer, num_single=10, num_multi=10, seed=42):
    """
    Initializes test case dictionaries for adversarial training. Each test case dictionary contains the following keys:
    - optim_tokens: the tokenization of a string to be optimized (e.g., an adversarial suffix)
    - behavior: the behavior string
    - behavior_id: the ID of the behavior
    - optim_target: the target string for optimization, sampled from the candidates stored for the behavior
    - model_target: the fixed refusal string "Sorry, I can't do that."
    - type: 'single' (multi-behavior test cases are currently not generated)
    - loss: None until a loss has been recorded for the test case
    - idx: the index of the sampled behavior within the filtered behavior list

    :param behaviors: a list of behaviors loaded from a CSV with DictReader; each needs 'Tags', 'Behavior' and 'BehaviorID'
    :param behavior_id_to_target: a dictionary mapping behavior IDs to an indexable collection of candidate target strings
    :param tokenizer: the tokenizer to use; called as tokenizer(text, return_tensors='pt') and read via ['input_ids']
    :param num_single: the number of single-behavior test case dictionaries to initialize
    :param num_multi: accepted for interface compatibility; currently unused
    :param seed: the random seed to use
    :return: a list of test case dictionaries
    :raises ValueError: (from numpy) if num_single exceeds the number of behaviors left after filtering
    """
    # remove long-context behaviors
    behaviors = [b for b in behaviors if 'context' not in b['Tags']]

    # =========== SINGLE-BEHAVIOR TEST CASES =========== #
    # A single generator drives every draw below. Re-seeding it per test case would give
    # every test case the same init length and the same target draw.
    rng = np.random.default_rng(seed)

    # randomly sample num_single behaviors without replacement
    behavior_idxs = rng.choice(len(behaviors), size=num_single, replace=False)

    test_case_dicts = []
    for i in behavior_idxs:
        behavior = behaviors[i]
        behavior_id = behavior['BehaviorID']

        # randomize length of optimized tokens (15 to 20 repetitions of '!', inclusive)
        num_init_tokens = int(rng.integers(15, 20, endpoint=True))
        init_string = ('! ' * num_init_tokens).rstrip(' ')
        init_tokens = tokenizer(init_string, return_tensors='pt')['input_ids']

        # randomly sample a target
        targets = behavior_id_to_target[behavior_id]
        target_idx = rng.choice(len(targets), size=1, replace=False)[0]

        test_case_dicts.append({
            'optim_tokens': init_tokens,
            'behavior': behavior['Behavior'],
            'behavior_id': behavior_id,
            'optim_target': targets[target_idx],
            'model_target': "Sorry, I can't do that.",
            'type': 'single',
            'loss': None,
            'idx': i,
        })

    # NOTE: multi-behavior test cases (type 'multi') are not implemented yet, so num_multi has no effect.

    return test_case_dicts

def reset_test_case_dicts(test_case_dicts, behaviors, behavior_id_to_target, tokenizer, seed=42):
    """
    Randomly resets 20% of test cases (rounded down). These test cases have their optim_tokens
    reinitialized and their loss cleared. For single-behavior test cases, the behavior and its
    optimization target are also randomly re-sampled.
    This is run every so often to improve exploration and reduce overfitting.

    The test case dictionaries are modified in place; the same list object is returned.

    :param test_case_dicts: a list of test cases
    :param behaviors: a list of behaviors loaded from a CSV with DictReader
    :param behavior_id_to_target: a dictionary mapping behavior IDs to an indexable collection of candidate target strings
    :param tokenizer: the tokenizer to use; called as tokenizer(text, return_tensors='pt') and read via ['input_ids']
    :param seed: the random seed to use
    :return: the modified list of test cases
    """
    # Drop long-context behaviors, using the same 'context' tag test that
    # initialize_test_case_dicts applies, so a reset cannot reintroduce them.
    behaviors = [b for b in behaviors if 'context' not in (b.get('Tags') or '')]

    rng = np.random.default_rng(seed)
    num_test_case_dicts = len(test_case_dicts)
    num_reset = int(0.2 * num_test_case_dicts)  # fewer than 5 test cases -> nothing is reset
    reset_idxs = rng.choice(num_test_case_dicts, size=num_reset, replace=False)

    for i in reset_idxs:
        test_case_dict = test_case_dicts[i]

        # randomize length of optimized tokens (15 to 20 repetitions of '!', inclusive)
        num_init_tokens = int(rng.integers(15, 20, endpoint=True))
        init_string = ('! ' * num_init_tokens).rstrip(' ')
        test_case_dict['optim_tokens'] = tokenizer(init_string, return_tensors='pt')['input_ids']
        test_case_dict['loss'] = None

        if test_case_dict['type'] == 'single':
            behavior_idx = rng.choice(len(behaviors), size=1, replace=False)[0]
            behavior = behaviors[behavior_idx]
            test_case_dict['behavior'] = behavior['Behavior']
            test_case_dict['behavior_id'] = behavior['BehaviorID']
            targets = behavior_id_to_target[test_case_dict['behavior_id']]
            target_idx = rng.choice(len(targets), size=1, replace=False)[0]  # randomly sample a target
            test_case_dict['optim_target'] = targets[target_idx]

    return test_case_dicts



# System prompt text keyed by model-family name; read by get_system_prompt().
system_prompts = {
'llama-2':
'''You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.'''
}

def get_system_prompt(model_name):
    """
    Look up the system prompt for a supported chat model.

    The comparison is case-insensitive. Only the Llama-2 chat checkpoints
    (7b, 13b and 70b) are recognized; they all share the 'llama-2' prompt.

    :param model_name: the model name to look up
    :return: the system prompt string
    :raises ValueError: if the model name is not one of the recognized names
    """
    llama_2_chat_models = ('llama-2-7b-chat-hf', 'llama-2-13b-chat-hf', 'llama-2-70b-chat-hf')

    if model_name.lower() not in llama_2_chat_models:
        raise ValueError("Model name not recognized.")

    return system_prompts['llama-2']


def get_test_case_wrapping(tokenizer, system_prompt=''):
    """
    Derive the text that surrounds a test case in the tokenizer's chat template.

    A three-turn conversation (system, user, assistant) is rendered with placeholder
    contents, and the rendered string is cut around the placeholders. This makes it easy
    to adapt red teaming methods to models fine-tuned with different conversation templates.

    :param tokenizer: a tokenizer providing apply_chat_template(chat, tokenize=False)
    :param system_prompt: the content of the system turn
    :return: (before_tc, after_tc) - the rendered text preceding the test case, and the text
        between the test case and the target
    """
    test_case_marker = '<test case>'
    target_marker = '<target>'

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": test_case_marker},
        {"role": "assistant", "content": target_marker},
    ]
    rendered = tokenizer.apply_chat_template(conversation, tokenize=False)

    pieces = rendered.split(test_case_marker)
    before_tc = pieces[0]
    # TODO: Should we add the after_target string?
    after_tc = pieces[1].split(target_marker)[0]

    return before_tc, after_tc



def get_nonascii_toks(tokenizer, device='cpu'):
    """
    Collect the token ids that must not be used when restricting to printable ASCII.

    The result contains, in this order:
    - every id from 3 up to tokenizer.vocab_size - 1 whose decoded text is not both ASCII and printable
      (ids 0-2 are not scanned)
    - the bos, eos, pad and unk token ids, for each one that is set
    - ids 101-999 when "Baichuan2" appears in tokenizer.name_or_path

    Duplicates are not removed.

    :param tokenizer: the tokenizer to scan
    :param device: the device on which to create the returned tensor
    :return: a 1-D tensor of disallowed token ids
    """
    def _is_printable_ascii(text):
        return text.isascii() and text.isprintable()

    blocked_ids = [
        tok_id
        for tok_id in range(3, tokenizer.vocab_size)
        if not _is_printable_ascii(tokenizer.decode([tok_id]))
    ]

    special_ids = (
        tokenizer.bos_token_id,
        tokenizer.eos_token_id,
        tokenizer.pad_token_id,
        tokenizer.unk_token_id,
    )
    blocked_ids.extend(tok_id for tok_id in special_ids if tok_id is not None)

    if "Baichuan2" in tokenizer.name_or_path:
        blocked_ids.extend(range(101, 1000))

    return torch.tensor(blocked_ids, device=device)

def sample_control(control_toks, grad, search_width, topk=256, temp=1, not_allowed_tokens=None):
    """
    Sample search_width candidate token sequences, each differing from control_toks in one position.

    For each candidate, one position is chosen (positions are spread evenly over the sequence)
    and its token is replaced by a uniformly random pick among the topk token ids with the
    lowest gradient value at that position.

    :param control_toks: 1-D tensor of the current token ids
    :param grad: 2-D tensor of gradients, indexed as [position, token id]; not modified
    :param search_width: the number of candidates to produce
    :param topk: the number of lowest-gradient token ids considered per position
    :param temp: unused
    :param not_allowed_tokens: optional tensor of token ids that should not be picked
    :return: a (search_width, len(control_toks)) tensor of candidate token ids on grad's device
    """
    if not_allowed_tokens is not None:
        # Work on a copy and push disallowed ids above every other gradient value,
        # which ranks them last in the top-k selection below.
        grad = grad.clone()
        banned_ids = not_allowed_tokens.to(grad.device)
        grad[:, banned_ids] = grad.max() + 1

    device = grad.device
    best_ids_per_position = torch.topk(-grad, topk, dim=1).indices
    control_toks = control_toks.to(device)
    num_positions = len(control_toks)

    # One position per candidate, stepping through the sequence in equal fractional strides.
    positions = torch.arange(
        0, num_positions, num_positions / search_width, device=device
    ).type(torch.int64)

    # For every candidate, pick one of the top-k ids at its position uniformly at random.
    random_rank = torch.randint(0, topk, (search_width, 1), device=device)
    replacement_ids = torch.gather(best_ids_per_position[positions], 1, random_rank)

    candidates = control_toks.repeat(search_width, 1)
    candidates.scatter_(1, positions.unsqueeze(-1), replacement_ids)

    return candidates


def run_gcg(behavior,
            target,
            before_tc,
            after_tc,
            model,
            tokenizer,
            accelerator,
            indices_dataloader,
            num_steps=1,
            adv_tokens_init=None,
            allow_non_ascii=False,
            search_width=256,
            num_accelerate_gpus=2,
            verbose=False):
    """
    Generate predicted trigger for the provided target.

    The model input is laid out as: before_tc, behavior, optimized tokens, after_tc, target.
    The key/value states of the fixed prefix (before_tc + behavior) are computed once and reused.
    Each step computes the gradient of the target loss with respect to a one-hot encoding of the
    optimized tokens, samples search_width single-token replacements with sample_control,
    scores all of them, and keeps the candidate with the lowest loss.

    The model is switched to eval mode for the duration of the call and to train mode before returning.

    :param behavior: the behavior string placed before the optimized tokens
    :param target: the target string whose likelihood is optimized
    :param before_tc: the text preceding the behavior (see get_test_case_wrapping)
    :param after_tc: the text between the optimized tokens and the target
    :param model: the model to attack; may be wrapped (a `module` attribute is unwrapped for the embedding lookup)
    :param tokenizer: the tokenizer to use
    :param accelerator: the accelerate Accelerator used for device placement, gathering and main-process checks
    :param indices_dataloader: iterable of index tensors selecting which candidates this process scores.
        NOTE(review): each batch presumably has to hold search_width // num_accelerate_gpus indices,
        since the prefix cache is expanded to exactly that batch size - confirm against the caller.
    :param num_steps: the number of optimization steps; must be at least 1
    :param adv_tokens_init: optional (1, num_tokens) tensor of initial optimized token ids;
        defaults to the tokenization of twenty '!' separated by spaces
    :param allow_non_ascii: if False, the token ids returned by get_nonascii_toks are excluded from sampling
    :param search_width: the number of candidates sampled per step
    :param num_accelerate_gpus: the number of processes the candidates are split across
    :param verbose: if True, print the trigger and loss every 10 steps and on the last step
    :return: (optim_ids, trigger, loss) - the best token ids as a (1, num_tokens) tensor, their decoded
        string, and the lowest candidate loss of the final step as a float
    :raises ValueError: if num_steps is less than 1
    """
    if num_steps < 1:
        # Without at least one step there is no trigger or loss to return.
        raise ValueError(f"num_steps must be at least 1, got {num_steps}")

    model.eval()
    # behavior=' '
    # check if the model has a "module" attribute (for distributed training)
    if hasattr(model, 'module'):
        embed_layer = model.module.get_input_embeddings()
    else:
        embed_layer = model.get_input_embeddings()

    vocab_embeds = embed_layer.weight.data
    vocab_size = vocab_embeds.shape[0]

    if allow_non_ascii:
        # Nothing is excluded; sample_control accepts None.
        not_allowed_tokens = None
    else:
        # remove redundant tokens
        not_allowed_tokens = torch.unique(get_nonascii_toks(tokenizer))

    # ========== setup optim_ids and cached components ========== #
    if adv_tokens_init is None:
        adv_string_init = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !"
        optim_ids = tokenizer(adv_string_init, return_tensors="pt", add_special_tokens=False).to(accelerator.device)['input_ids']
    else:
        optim_ids = adv_tokens_init.to(accelerator.device)

    num_optim_tokens = len(optim_ids[0])
    search_batch_size = search_width // num_accelerate_gpus

    cache_input_ids = tokenizer([before_tc, behavior, after_tc, target], padding=False, add_special_tokens=False)['input_ids']
    # make tensor separately because can't return_tensors='pt' in tokenizer
    cache_input_ids = [torch.tensor(input_ids, device=accelerator.device).unsqueeze(0) for input_ids in cache_input_ids]
    before_ids, behavior_ids, after_ids, target_ids = cache_input_ids
    before_embeds, behavior_embeds, after_embeds, target_embeds = [embed_layer(input_ids) for input_ids in cache_input_ids]

    # first precompute and cache states for system prompt and user prompt
    input_embeds = torch.cat([before_embeds, behavior_embeds], dim=1)
    with torch.no_grad():
        outputs = model(inputs_embeds=input_embeds, use_cache=True)
        prefix_cache = outputs.past_key_values

    # Cast prefix_cache to bfloat16 (because the keys and values are cast to fp32 in the forward pass for some reason, but past_key_values can't be fp32)
    prefix_cache = [tuple([tensor.to(torch.bfloat16) for tensor in layer]) for layer in prefix_cache]

    # Expand prefix cache to search_batch_size dimension
    prefix_cache_batch = [
        [tensor.repeat(search_batch_size, 1, 1, 1) for tensor in layer]
        for layer in prefix_cache
    ]

    # ========== run optimization ========== #
    for i in range(num_steps):
        # ========== compute coordinate token_gradient ========== #
        # create input
        optim_ids_onehot = torch.zeros((1, num_optim_tokens, vocab_size), device=accelerator.device).to(vocab_embeds.dtype)
        optim_ids_onehot.scatter_(2, optim_ids.unsqueeze(2), 1.0).requires_grad_()
        optim_embeds = torch.matmul(optim_ids_onehot.squeeze(0), vocab_embeds).unsqueeze(0)
        input_embeds = torch.cat([optim_embeds, after_embeds, target_embeds], dim=1)

        # forward pass
        outputs = model(inputs_embeds=input_embeds, past_key_values=prefix_cache)
        logits = outputs.logits

        # compute loss
        # Shift so that tokens < n predict n
        tmp = input_embeds.shape[1] - target_embeds.shape[1]
        shift_logits = logits[..., tmp-1:-1, :].contiguous()
        shift_labels = target_ids
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        token_grad = torch.autograd.grad(outputs=loss, inputs=[optim_ids_onehot],
                                         retain_graph=False, create_graph=False,
                                         only_inputs=True, allow_unused=False)[0]

        with torch.no_grad():
            # Candidates are sampled on the main process only and then broadcast,
            # so that every process scores the same set.
            if accelerator.is_main_process:
                sampled_top_indices = sample_control(optim_ids.squeeze(0), token_grad.squeeze(0), search_width, topk=256, temp=1, not_allowed_tokens=not_allowed_tokens)
            else:
                # For other processes, create a dummy tensor to receive the broadcasted data
                sampled_top_indices = torch.empty((search_width, num_optim_tokens), dtype=torch.long, device=accelerator.device)

            # Broadcast using accelerate's broadcast function
            sampled_top_indices = accelerate.utils.broadcast(sampled_top_indices)

            # ========== Compute loss on these candidates and take the argmin. ========== #
            # Create input
            sampled_top_embeds = embed_layer(sampled_top_indices)
            input_embeds = torch.cat([sampled_top_embeds,
                                      after_embeds.repeat(search_width, 1, 1),
                                      target_embeds.repeat(search_width, 1, 1)], dim=1)

            # Forward pass
            logits = []
            indices = []
            for batch_indices in indices_dataloader:
                # Using the indices, we select the appropriate input embeddings
                current_input_embeds = input_embeds[batch_indices]

                # Forward pass
                outputs = model(inputs_embeds=current_input_embeds, past_key_values=prefix_cache_batch)
                current_logits = outputs.logits.to(input_embeds.dtype)
                logits.append(current_logits)
                indices.append(batch_indices.unsqueeze(0))

            # Concatenate all logits to evaluate on them together
            logits = torch.cat(logits, dim=0)
            indices = torch.cat(indices, dim=0)
            logits = accelerator.gather(logits)
            indices = accelerator.gather(indices)

            # Ensure that the logits are in the correct order.
            # indices[k] is the candidate that gathered row k belongs to, so argsort(indices)
            # lists, for each candidate in order, the gathered row holding its logits.
            # (Indexing with `indices` itself is only right when the gathered order is already sorted.)
            indices = indices.view(-1)
            logits = logits[torch.argsort(indices)]

            # Compute loss
            # Shift so that tokens < n predict n
            tmp = input_embeds.shape[1] - target_embeds.shape[1]
            shift_logits = logits[..., tmp-1:-1, :].contiguous()
            shift_labels = target_ids.repeat(search_width, 1)
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss = loss.view(search_width, -1).mean(dim=1)

            # ========== Update the optim_ids with the best candidate ========== #
            optim_ids = sampled_top_indices[loss.argmin()].unsqueeze(0)

            trigger = tokenizer.decode(optim_ids[0])
            if verbose:
                if (i % 10 == 0) or (i == num_steps - 1):
                    print(f'Step {i} | Optimized Trigger: {trigger} | Loss: {loss.min().item()}\n')

    model.train()
    return optim_ids, trigger, loss.min().item()



def log_1_minus_p_loss(logits, labels, threshold=-5.0):
    """
    Compute the mean of -log(1 - P(label)) over the non-ignored token positions.

    Positions whose label is -100 are ignored: they contribute nothing and are not counted in
    the mean. Positions where log P(label) is already below `threshold` contribute zero loss
    but are still counted in the mean.

    :param logits: float tensor of unnormalized scores whose last dimension indexes the classes
    :param labels: integer tensor of class ids with the shape of logits minus its last dimension;
        -100 marks ignored positions
    :param threshold: log-probability below which a label's loss is zeroed out
    :return: a scalar tensor; 0 when every position is ignored
    """
    ignored_values = labels == -100

    # Temporarily replace -100 labels with 0 so they are valid indices for gather/scatter
    label_index = labels.masked_fill(ignored_values, 0).unsqueeze(-1)

    # log(sum(exp(logits))) for each token position
    log_sum_exp_all = torch.logsumexp(logits, dim=-1)

    # log(P(label))
    log_p = torch.gather(logits, -1, label_index).squeeze(-1) - log_sum_exp_all

    # Exclude the true label's logit from the second logsumexp. The dtype's most negative finite
    # value is used rather than a multiplied-in constant: a constant such as -1e10 overflows to
    # -inf in float16 and 0 * -inf then turns every other logit into NaN.
    label_mask = torch.zeros_like(logits).scatter_(-1, label_index, 1.0).bool()
    masked_logits = logits.masked_fill(label_mask, torch.finfo(logits.dtype).min)

    # log(1 - P(label)) = log(sum(exp(logits excluding true label))) - log(sum(exp(logits)))
    log_1_minus_p = torch.logsumexp(masked_logits, dim=-1) - log_sum_exp_all

    # Zero the loss for ignored positions and for labels that are already unlikely enough
    below_threshold = log_p < threshold
    log_1_minus_p = log_1_minus_p.masked_fill(ignored_values | below_threshold, 0.0)

    # Mean over the non-ignored positions; the clamp avoids 0/0 = NaN when all are ignored
    num_scored = (~ignored_values).sum().clamp(min=1).float()
    loss = -log_1_minus_p.sum() / num_scored

    return loss








# =================== CUSTOM TRAINER =================== #

# Module-level logger, created through the `logging` helper imported from transformers.utils
logger = logging.get_logger(__name__)

# Names of the files used for checkpointing
# (TRAINER_STATE_NAME is read when resuming from a checkpoint directory)
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"

class CustomSFTTrainer(SFTTrainer):
    def __init__(
        self,
        behaviors,
        behavior_id_to_target,
        *args,
        before_tc='',
        after_tc='',
        num_gcg_steps=2,
        gcg_search_width=512,
        num_test_cases_to_update_per_step=2,
        num_accelerate_gpus=2,
        num_single_behavior_test_cases=10,
        num_multi_behavior_test_cases=10,
        save_every=100,
        reset_every=100,
        **kwargs,
    ):
        """
        Store the adversarial-training settings on the instance, then initialize the SFTTrainer.

        Every argument named below is saved as an attribute of the same name before the parent
        constructor runs. All remaining positional and keyword arguments are forwarded to
        SFTTrainer.__init__ unchanged.

        :param behaviors: the list of behaviors
        :param behavior_id_to_target: a dictionary mapping behavior IDs to targets
        :param before_tc: the text placed before a test case
        :param after_tc: the text placed after a test case
        :param num_gcg_steps: presumably the number of GCG steps per update - confirm against the training loop
        :param gcg_search_width: presumably the GCG search width - confirm against the training loop
        :param num_test_cases_to_update_per_step: presumably how many test cases are optimized per training step - TODO confirm
        :param num_accelerate_gpus: the number of accelerate processes/GPUs
        :param num_single_behavior_test_cases: the number of single-behavior test cases
        :param num_multi_behavior_test_cases: the number of multi-behavior test cases
        :param save_every: presumably a save interval in steps - TODO confirm
        :param reset_every: presumably a test-case reset interval in steps - TODO confirm
        """
        # Behavior data
        self.behaviors, self.behavior_id_to_target = behaviors, behavior_id_to_target

        # Chat-template text around a test case
        self.before_tc, self.after_tc = before_tc, after_tc

        # GCG settings
        self.num_gcg_steps, self.gcg_search_width = num_gcg_steps, gcg_search_width
        self.num_test_cases_to_update_per_step = num_test_cases_to_update_per_step
        self.num_accelerate_gpus = num_accelerate_gpus

        # Test-case pool sizes
        self.num_single_behavior_test_cases = num_single_behavior_test_cases
        self.num_multi_behavior_test_cases = num_multi_behavior_test_cases

        # Periodic intervals
        self.save_every, self.reset_every = save_every, reset_every

        super().__init__(*args, **kwargs)
    
    def get_synchronized_seed(self):
        """
        Get a synchronized random seed across all processes.

        The main process draws the seed from numpy's global random state and broadcasts it;
        every process returns the broadcast value.

        :return: the shared seed as a Python int in [0, 10**8)
        """
        device = self.accelerator.device
        if self.accelerator.is_main_process:
            # Get synchronized randomness. The tensor is built with the same shape and dtype
            # as the receive buffers below so sender and receivers agree.
            seed = torch.tensor([np.random.randint(0, 10**8)], dtype=torch.int64, device=device)
        else:
            # For other processes, create a dummy tensor to receive the broadcasted data
            seed = torch.empty((1,), dtype=torch.int64, device=device)
        seed = accelerate.utils.broadcast(seed)

        return seed.item()
    
    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        self.accelerator.free_memory()
        self._train_batch_size = batch_size
        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size

        len_dataloader = None
        num_train_tokens = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            num_examples = self.num_examples(train_dataloader)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
                )
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
                if args.include_tokens_per_second:
                    num_train_tokens = (
                        self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
                    )
            else:
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
                if args.include_tokens_per_second:
                    num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            max_steps = args.max_steps
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
            if args.include_tokens_per_second:
                num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
            )

        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torch.distributed.launch)."
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

        delay_optimizer_creation = False # is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled

        # We need to reset the scheduler, as its parameters may be different on subsequent calls
        if self._created_lr_scheduler:
            self.lr_scheduler = None
            self._created_lr_scheduler = False

        if self.is_deepspeed_enabled:
            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

        if not delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState()
        self.state.is_hyper_param_search = trial is not None

        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps is not None:
            if args.logging_steps < 1:
                self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
            else:
                self.state.logging_steps = args.logging_steps
        if args.eval_steps is not None:
            if args.eval_steps < 1:
                self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
            else:
                self.state.eval_steps = args.eval_steps
        if args.save_steps is not None:
            if args.save_steps < 1:
                self.state.save_steps = math.ceil(max_steps * args.save_steps)
            else:
                self.state.save_steps = args.save_steps

        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            if args.gradient_checkpointing_kwargs is None:
                gradient_checkpointing_kwargs = {}
            else:
                gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs

            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

        model = self._wrap_model(self.model_wrapped)

        # as the model is wrapped, don't use `accelerator.prepare`
        # this is for unhandled cases such as
        # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
        use_accelerator_prepare = True if model is self.model else False

        if delay_optimizer_creation:
            if use_accelerator_prepare:
                self.model = self.accelerator.prepare(self.model)
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        # prepare using `accelerator` prepare
        if use_accelerator_prepare:
            self.model.train()
            if hasattr(self.lr_scheduler, "step"):
                if self.use_apex:
                    model = self.accelerator.prepare(self.model)
                else:
                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
            else:
                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.lr_scheduler
                )

        if self.is_fsdp_enabled:
            self.model = self.model_wrapped = model

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

        # backward compatibility
        if self.is_deepspeed_enabled:
            self.deepspeed = self.model_wrapped

        # ckpt loading
        if resume_from_checkpoint is not None:
            if self.is_deepspeed_enabled:
                deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
            elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
                self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

        # important: at this point:
        # self.model         is the Transformers Model
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
        # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.

        # Train!
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples:,}")
        logger.info(f"  Num Epochs = {num_train_epochs:,}")
        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
        if self.args.per_device_train_batch_size != self._train_batch_size:
            logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps:,}")
        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        steps_trained_progress_bar = None

        # Check if continuing training from a checkpoint
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first"
                    f" {steps_trained_in_current_epoch} batches in the first epoch."
                )

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        if self.hp_name is not None and self._trial is not None:
            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
            # parameter to Train when using DDP.
            self.state.trial_name = self.hp_name(self._trial)
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()

        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not args.ignore_data_skip:
            for epoch in range(epochs_trained):
                sampler = get_dataloader_sampler(train_dataloader)
                sampler_kinds = [RandomSampler]
                if version.parse(accelerate_version) > version.parse("0.23.0"):
                    sampler_kinds.append(SeedableRandomSampler)
                is_random_sampler = isinstance(sampler, tuple(sampler_kinds))
                if is_torch_less_than_1_11 or not is_random_sampler:
                    # We just need to begin an iteration to create the randomization of the sampler.
                    for _ in train_dataloader:
                        break
                else:
                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
                    # AT THE VERY END!
                    sampler = sampler if sampler is not None else []
                    _ = list(sampler)

        # =========== START ADVERSARIAL TRAINING CODE =========== #
        # Initialize test cases
        seed = self.get_synchronized_seed()
        logger.info("Initializing test case dictionaries...")
        self.test_case_dicts = initialize_test_case_dicts(self.behaviors,
                                                          self.behavior_id_to_target,
                                                          self.tokenizer,
                                                          num_single=self.num_single_behavior_test_cases,
                                                          num_multi=self.num_multi_behavior_test_cases,
                                                          seed=seed)

        # Create a DataLoader for test case indices
        indices = torch.arange(len(self.test_case_dicts))
        test_case_index_loader = DataLoader(TensorDataset(indices),
                                            batch_size=self.num_test_cases_to_update_per_step, shuffle=True,
                                            drop_last=True)

        # Create a DataLoader for GCG candidate evaluation step (for parallel evaluation of candidates)
        indices_dataset = torch.arange(self.gcg_search_width)
        indices_dataloader = DataLoader(indices_dataset, batch_size=self.gcg_search_width//self.num_accelerate_gpus)
        indices_dataloader = self.accelerator.prepare(indices_dataloader)

        # Create a DataLoader for splitting test cases among GPUs
        # NOTE: for now, num_test_cases_to_update_per_step should be a multiple of num_accelerate_gpus
        indices = torch.arange(self.num_test_cases_to_update_per_step)
        test_case_parallel_loader = DataLoader(TensorDataset(indices),
                                            batch_size=self.num_test_cases_to_update_per_step//self.num_accelerate_gpus,
                                            shuffle=False)
        test_case_parallel_loader = self.accelerator.prepare(test_case_parallel_loader)
        assert len(test_case_parallel_loader) == 1, 'should only be one batch per process'

        # Initialize EMA losses
        clean_loss_ema = None
        gcg_loss_ema = None
        away_loss_ema = None
        toward_loss_ema = None
        total_loss_ema = None
        # =========== END ADVERSARIAL TRAINING CODE =========== #

        total_batched_samples = 0
        for epoch in range(epochs_trained, num_train_epochs):
            epoch_iterator = train_dataloader
            if hasattr(epoch_iterator, "set_epoch"):
                epoch_iterator.set_epoch(epoch)

            # Reset the past mems state at the beginning of each epoch if necessary.
            if args.past_index >= 0:
                self._past = None

            steps_in_epoch = (
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

            rng_to_sync = False
            steps_skipped = 0
            if steps_trained_in_current_epoch > 0:
                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
                steps_skipped = steps_trained_in_current_epoch
                steps_trained_in_current_epoch = 0
                rng_to_sync = True

            step = -1
            for step, inputs in enumerate(epoch_iterator):
                
                # =========== START ADVERSARIAL TRAINING CODE =========== #
                # Reset test cases every args.reset_every parameter updates
                tmp = self.reset_every * args.gradient_accumulation_steps
                if step % tmp == (tmp - 1):
                    logger.info("Resetting test cases...")
                    seed = self.get_synchronized_seed()
                    self.test_case_dicts = reset_test_case_dicts(self.test_case_dicts, self.behaviors, self.behavior_id_to_target, self.tokenizer, seed=seed)
                
                # call self.save_model every 500 steps
                if total_batched_samples % (self.save_every * args.gradient_accumulation_steps) == 0:
                    # save to output dir
                    logger.info("Saving model checkpoint to %s", self.args.output_dir)
                    self.save_model('./{}/step_{}'.format(self.args.output_dir, total_batched_samples))
                # =========== END ADVERSARIAL TRAINING CODE =========== #

                total_batched_samples += 1
                if rng_to_sync:
                    self._load_rng_state(resume_from_checkpoint)
                    rng_to_sync = False

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
                    continue
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None

                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                with self.accelerator.accumulate(model):
                    # tr_loss_step = self.training_step(model, inputs)
                    # =========== START ADVERSARIAL TRAINING CODE =========== #
                    clean_loss, away_loss, toward_loss, gcg_loss, total_loss, test_case_dicts = self.adv_training_step(model, inputs, self.test_case_dicts, test_case_index_loader, indices_dataloader, test_case_parallel_loader)
                    self.test_case_dicts = test_case_dicts  # TODO: logging all the different losses separately

                    clean_loss_tmp = clean_loss
                    clean_loss_ema = 0.99 * clean_loss_ema + 0.01 * clean_loss_tmp if clean_loss_ema is not None else clean_loss_tmp
                    away_loss_tmp = away_loss
                    away_loss_ema = 0.99 * away_loss_ema + 0.01 * away_loss_tmp if away_loss_ema is not None else away_loss_tmp
                    toward_loss_tmp = toward_loss
                    toward_loss_ema = 0.99 * toward_loss_ema + 0.01 * toward_loss_tmp if toward_loss_ema is not None else toward_loss_tmp
                    gcg_loss_tmp = gcg_loss
                    gcg_loss_ema = 0.99 * gcg_loss_ema + 0.01 * gcg_loss_tmp if gcg_loss_ema is not None else gcg_loss_tmp
                    total_loss_tmp = total_loss
                    total_loss_ema = 0.99 * total_loss_ema + 0.01 * total_loss_tmp if total_loss_ema is not None else total_loss_tmp

                    self.accelerator.log(
                        {
                            "clean_loss_ema": clean_loss_ema,
                            "gcg_loss_ema": gcg_loss_ema,
                            "away_loss_ema": away_loss_ema,
                            "toward_loss_ema": toward_loss_ema,
                            "total_loss_ema": total_loss_ema,
                        },
                        step=total_batched_samples,
                    )

                    tr_loss_step = total_loss  # used for standard Trainer logging
                    # =========== END ADVERSARIAL TRAINING CODE =========== #

                if (
                    args.logging_nan_inf_filter
                    and not is_torch_tpu_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                else:
                    tr_loss += tr_loss_step

                self.current_flos += float(self.floating_point_ops(inputs))

                is_last_step_and_steps_less_than_grad_acc = (
                    steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
                )

                if (
                    total_batched_samples % args.gradient_accumulation_steps == 0
                    or
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    is_last_step_and_steps_less_than_grad_acc
                ):
                    # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
                    # in accelerate. So, explicitly enable sync gradients to True in that case.
                    if is_last_step_and_steps_less_than_grad_acc or (
                        version.parse(accelerate_version) <= version.parse("0.20.3")
                    ):
                        self.accelerator.gradient_state._set_sync_gradients(True)

                    # Gradient clipping
                    if args.max_grad_norm is not None and args.max_grad_norm > 0:
                        # deepspeed does its own clipping

                        if is_sagemaker_mp_enabled() and args.fp16:
                            self.optimizer.clip_master_grads(args.max_grad_norm)
                        elif self.use_apex:
                            # Revert to normal clipping otherwise, handling Apex or full precision
                            nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer),
                                args.max_grad_norm,
                            )
                        else:
                            self.accelerator.clip_grad_norm_(
                                model.parameters(),
                                args.max_grad_norm,
                            )

                    # Optimizer step
                    self.optimizer.step()
                    optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
                    if optimizer_was_run:
                        # Delay optimizer scheduling until metrics are generated
                        if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                            self.lr_scheduler.step()

                    model.zero_grad()
                    self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break
            if step < 0:
                logger.warning(
                    "There seems to be not a single sample in your epoch_iterator, stopping training at step"
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True

            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_tpu_available():
                xm.rendezvous("load_best_model_at_end")
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
                dist.barrier()
            elif is_sagemaker_mp_enabled():
                smp.barrier()

            self._load_best_model()

        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        train_loss = self._total_loss_scalar / self.state.global_step

        metrics = speed_metrics(
            "train",
            start_time,
            num_samples=num_train_samples,
            num_steps=self.state.max_steps,
            num_tokens=num_train_tokens,
        )
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
        metrics["train_loss"] = train_loss

        self.is_in_train = False

        self._memory_tracker.stop_and_update_metrics(metrics)

        self.log(metrics)

        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
            for checkpoint in checkpoints_sorted:
                if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint)

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        # Wait for the checkpoint to be uploaded.
        self._finish_current_push()

        # After training we make sure to retrieve back the original forward pass method
        # for the embedding layer by removing the forward post hook.
        if self.neftune_noise_alpha is not None:
            self._deactivate_neftune(self.model)

        return TrainOutput(self.state.global_step, train_loss, metrics)


    def adv_training_step(self, model, inputs, test_case_dicts, test_case_index_loader, indices_dataloader, test_case_parallel_loader):
        """
        Perform one adversarial training step.

        The step (1) refreshes one batch of test cases by running GCG on each of them, (2) builds "away"
        examples (prompt + optimization target) and "toward" examples (prompt + model target + EOS) for the
        test cases assigned to this process, (3) runs a single forward pass over the away, toward and clean
        examples together, and (4) backpropagates ``clean_loss + 0.5 * away_loss + 0.5 * toward_loss``.

        :param model: The model to train.
        :param inputs: The inputs that would normally be passed to self.training_step. Must provide
            'input_ids', 'attention_mask' and 'labels'; this method does not modify those tensors in place.
        :param test_case_dicts: A list of test case dictionaries. The entries selected for this step get their
            'optim_tokens' and 'loss' fields overwritten.
        :param test_case_index_loader: A DataLoader to select which test cases to update (all GPUs see the same batches here)
        :param indices_dataloader: A prepared DataLoader to parallelize GCG candidate evaluation
        :param test_case_parallel_loader: A DataLoader to split test cases among GPUs (different GPUs see different batches here)
        :return: The clean loss, the away loss, the toward loss and the mean GCG loss (Python floats), the total
            loss (detached tensor, divided by gradient_accumulation_steps for logging), and the updated test
            case dictionaries
        """
        model.train()
        inputs = self._prepare_inputs(inputs)
        device = self.accelerator.device

        # =========== UPDATE TEST CASES =========== #
        # Get indices of test cases to update. A fresh iterator is created on every call, so this always takes
        # the first batch of a new pass over the loader. Converting to Python ints keeps list indexing explicit.
        test_case_indices = next(iter(test_case_index_loader))[0].tolist()

        # Update test cases (all processes run the full loop through all test cases to update)
        # Parallelization happens within run_gcg using indices_dataloader
        optim_ids_list = []
        optim_targets = []
        model_targets = []
        behaviors = []
        gcg_losses = []
        for index in test_case_indices:
            test_case_dict = test_case_dicts[index]
            behavior = test_case_dict['behavior']
            optim_target = test_case_dict['optim_target']
            # The optimized string returned by run_gcg is not needed here.
            optim_ids, _, gcg_loss = run_gcg(
                behavior,
                optim_target,
                self.before_tc,
                self.after_tc,
                model,
                self.tokenizer,
                self.accelerator,
                indices_dataloader,
                num_steps=self.num_gcg_steps,
                adv_tokens_init=test_case_dict['optim_tokens'],
                allow_non_ascii=False,
                search_width=self.gcg_search_width,
                num_accelerate_gpus=self.num_accelerate_gpus,
                verbose=False)

            # Persist the new adversarial tokens so the next GCG run on this test case starts from them.
            test_case_dict['optim_tokens'] = optim_ids
            test_case_dict['loss'] = gcg_loss
            test_case_dicts[index] = test_case_dict

            # Collect batch for model forward pass
            optim_ids_list.append(optim_ids)
            optim_targets.append(optim_target)
            model_targets.append(test_case_dict['model_target'])  # TODO: accommodate multi-behavior test case dicts
            behaviors.append(behavior)

            # Collect loss for logging
            gcg_losses.append(gcg_loss)

        gcg_loss = torch.tensor(gcg_losses).mean().item()

        # =========== COMPUTE LOSS FOR MODEL ON UPDATED TEST CASES =========== #
        def to_row(token_ids):
            # Turn a flat list of token ids into a tensor with a leading batch dimension of 1.
            return torch.tensor(token_ids, device=device).unsqueeze(0)

        def with_target(prompt_ids, target_ids):
            # Append the target to the prompt and supervise only the target tokens (-100 is the ignore index).
            example_ids = torch.cat([prompt_ids, target_ids], dim=1)
            example_labels = torch.cat([torch.full_like(prompt_ids, -100), target_ids], dim=1)
            return example_ids, example_labels

        # Construct inputs for model forward pass.
        # padding=False: every sequence is consumed individually below, so padding at this point would insert pad
        # tokens into the supervised target span (and, for model targets, place EOS after the pad tokens).
        optim_target_token_ids = self.tokenizer(optim_targets, padding=False, truncation=True, add_special_tokens=False)['input_ids']
        model_target_token_ids = self.tokenizer(model_targets, padding=False, truncation=True, add_special_tokens=False)['input_ids']
        behavior_token_ids = self.tokenizer(behaviors, padding=False, truncation=True, add_special_tokens=False)['input_ids']

        optim_target_ids_list = [to_row(token_ids) for token_ids in optim_target_token_ids]
        # Add EOS token directly after each model target
        model_target_ids_list = [to_row(list(token_ids) + [self.tokenizer.eos_token_id]) for token_ids in model_target_token_ids]
        behavior_ids_list = [to_row(token_ids) for token_ids in behavior_token_ids]

        # Only the text before the test case gets special tokens added (add_special_tokens=True).
        before_ids = to_row(self.tokenizer(self.before_tc, padding=False, add_special_tokens=True)['input_ids'])
        after_ids = to_row(self.tokenizer(self.after_tc, padding=False, add_special_tokens=False)['input_ids'])

        input_ids_list_away = []
        labels_list_away = []

        input_ids_list_toward = []
        labels_list_toward = []

        # Positions (within the batch of test cases updated above) that this process trains on.
        updated_test_case_indices = next(iter(test_case_parallel_loader))[0].tolist()
        # TODO: decouple updated test cases from the test cases to use for updating the model?
        # It's actually unclear whether this would be best, or whether the model should train
        # on the same test cases that were just updated. Let's keep it like this for now.

        for index in updated_test_case_indices:
            # NOTE(review): assumes run_gcg returns optim_ids with shape (1, num_tokens) on the accelerator device.
            prompt_ids = torch.cat([before_ids, behavior_ids_list[index], optim_ids_list[index], after_ids], dim=1)

            # ========== Get inputs for loss away from optim targets ========== #
            away_ids, away_labels = with_target(prompt_ids, optim_target_ids_list[index])
            input_ids_list_away.append(away_ids)
            labels_list_away.append(away_labels)

            # ========== Get inputs for loss toward model targets ========== #
            toward_ids, toward_labels = with_target(prompt_ids, model_target_ids_list[index])
            input_ids_list_toward.append(toward_ids)
            labels_list_toward.append(toward_labels)

        # ========== Compute loss ========== #
        # ====== Create combined batch (with away+toward+clean examples, in that order) ====== #
        num_away = len(input_ids_list_away)
        num_toward = len(input_ids_list_toward)

        adv_input_ids = input_ids_list_away + input_ids_list_toward
        adv_labels = labels_list_away + labels_list_toward
        clean_attention_mask = inputs['attention_mask']

        # tokenizer.pad() works on list[list[int]] wrapped in a dict. The attention mask is passed explicitly so
        # that padding already present in the clean examples stays masked after re-padding; if it were omitted,
        # tokenizer.pad() would mark every pre-existing position (including old padding) as attended.
        combined = {
            'input_ids': [ids.squeeze(0).tolist() for ids in adv_input_ids] + inputs['input_ids'].tolist(),
            'attention_mask': [[1] * ids.size(1) for ids in adv_input_ids] + clean_attention_mask.tolist(),
        }
        padded = self.tokenizer.pad(combined, padding=True, return_tensors='pt')
        input_ids = padded['input_ids']
        attention_mask = padded['attention_mask']

        # Clean labels are ignored wherever the clean batch was padding. Work on a copy so the caller's
        # inputs['labels'] tensor is left untouched.
        clean_labels = inputs['labels'].clone()
        clean_labels[clean_attention_mask == 0] = -100

        combined_labels = {'input_ids': [lab.squeeze(0).tolist() for lab in adv_labels] + clean_labels.tolist()}
        padded_labels = self.tokenizer.pad(combined_labels, padding=True, return_tensors='pt')
        labels = padded_labels['input_ids']
        # set labels to -100 on the newly added padding (workaround b/c tokenizer.pad() pads with the pad token id)
        labels[padded_labels['attention_mask'] == 0] = -100

        # move input_ids, attention_mask, and labels to device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        # print if input_ids is over 2048 tokens long
        if input_ids.size(1) > 2048:
            print(f'input_ids is {input_ids.size(1)} tokens long')
            print(input_ids.shape)
            print(input_ids)
            # print decoded
            for row in range(input_ids.size(0)):
                print(self.tokenizer.decode(input_ids[row]))

        # ====== Forward pass ====== #
        output = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = output.logits

        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        vocab_size = shift_logits.size(-1)
        num_adv = num_away + num_toward

        # nn.CrossEntropyLoss ignores positions labelled -100 by default.
        loss_fct = nn.CrossEntropyLoss()

        # ====== Compute away loss ====== #
        # log(1 - P(label)) loss on the optimization targets
        away_loss = log_1_minus_p_loss(shift_logits[:num_away], shift_labels[:num_away])

        # ====== Compute toward loss ====== #
        toward_loss = loss_fct(
            shift_logits[num_away:num_adv].reshape(-1, vocab_size),
            shift_labels[num_away:num_adv].reshape(-1),
        )

        # ====== Compute clean loss ====== #
        clean_loss = loss_fct(
            shift_logits[num_adv:].reshape(-1, vocab_size),
            shift_labels[num_adv:].reshape(-1),
        )

        # ========== Compute total loss ========== #
        loss = clean_loss + away_loss*0.5 + toward_loss*0.5

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)

        # The returned total loss is divided by gradient_accumulation_steps; the value passed to backward is not.
        return clean_loss.item(), away_loss.item(), toward_loss.item(), gcg_loss, loss.detach() / self.args.gradient_accumulation_steps, test_case_dicts

    def create_accelerator_and_postprocess(self):
        """
        Create ``self.accelerator`` and derive the DeepSpeed/FSDP flags and plugin settings from it.

        Side effects visible in this method:
          - sets ``self.accelerator`` (created with ``log_with="wandb"``) and initializes its trackers under the
            project name "harmbench_adv_training";
          - sets ``self.is_deepspeed_enabled`` and ``self.is_fsdp_enabled`` from the accelerator state;
          - when FSDP is enabled, copies ``limit_all_gathers`` (and, if ``is_accelerate_available("0.23.0")``,
            ``activation_checkpointing``) from ``self.args.fsdp_config`` onto the FSDP plugin;
          - when DeepSpeed is enabled and ``self.args`` has no ``hf_deepspeed_config``, wraps the plugin's config
            in ``HfTrainerDeepSpeedConfig`` and processes it with the training arguments.

        :raises ValueError: If FSDP activation checkpointing and ``args.gradient_checkpointing`` are both enabled.
        """
        # Gradient accumulation is delegated to accelerate, using the step count from the training arguments.
        grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
        # `sync_with_dataloader` is only passed for accelerate versions newer than 0.20.3
        # (presumably older versions do not accept the argument -- TODO confirm).
        if version.parse(accelerate_version) > version.parse("0.20.3"):
            grad_acc_kwargs["sync_with_dataloader"] = False
        gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)

        # create accelerator object
        # NOTE: the tracker backend ("wandb") and the tracker project name below are hard-coded.
        self.accelerator = Accelerator(
            dispatch_batches=self.args.dispatch_batches,
            split_batches=self.args.split_batches,
            deepspeed_plugin=self.args.deepspeed_plugin,
            gradient_accumulation_plugin=gradient_accumulation_plugin,
            log_with="wandb",
        )
        self.accelerator.init_trackers("harmbench_adv_training")

        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

        # post accelerator creation setup
        if self.is_fsdp_enabled:
            fsdp_plugin = self.accelerator.state.fsdp_plugin
            # Values in args.fsdp_config take precedence; the plugin's current value is the fallback.
            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
                "limit_all_gathers", fsdp_plugin.limit_all_gathers
            )
            if is_accelerate_available("0.23.0"):
                fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
                    "activation_checkpointing", fsdp_plugin.activation_checkpointing
                )
                # The two checkpointing mechanisms are mutually exclusive; fail early instead of combining them.
                if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
                    raise ValueError(
                        "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
                        "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
                        "when using FSDP."
                    )

        if self.is_deepspeed_enabled:
            # Only rebuild the DeepSpeed config when the training arguments do not already carry one.
            if getattr(self.args, "hf_deepspeed_config", None) is None:
                from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

                ds_plugin = self.accelerator.state.deepspeed_plugin

                # Re-wrap the plugin's raw config, then let it process the training arguments.
                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
                ds_plugin.hf_ds_config.trainer_config_process(self.args)
