import torch
from torch import nn
from transformers import Trainer
from transformers.trainer_utils import seed_worker
from transformers.utils import is_datasets_available
import datasets
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import torch.nn.functional as F
import copy, os
import deepspeed
from evaluate_util import get_dataloader, get_all_evals, get_kl_divergence, get_single_dataloader, get_single_evals, get_reft_evals
import copy
import json 
from pathlib import Path
from data_module import get_batch_loss, UnlearnDataset
from utils import merge_dicts, interleave_eval_result_dict, get_forget_quality, get_model_utility
import numpy as np
from scipy.stats import ks_2samp, hmean
import csv 
import pickle
import math

import math
import os
import tqdm


class CustomTrainer(Trainer):
    """Trainer for batches delivered as ``(input_ids, labels, attention_mask)``.

    Both overrides unpack that three-item batch and call the model as
    ``model(input_ids, labels=labels, attention_mask=attention_mask)``,
    relying on the model to compute its own loss.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the loss the model reports for one batch.

        Args:
            model: Callable whose result exposes ``.loss``.
            inputs: Three-item sequence ``(input_ids, labels, attention_mask)``.
            return_outputs: If true, return ``(loss, outputs)`` instead of
                only the loss.
            num_items_in_batch: Unused. Accepted (with a default) so that
                ``transformers`` releases which pass this keyword to
                ``compute_loss`` do not fail with ``TypeError``.

        Returns:
            ``outputs.loss``, or ``(outputs.loss, outputs)`` when
            ``return_outputs`` is true.
        """
        input_ids, labels, attention_mask = inputs
        outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only: bool, ignore_keys=None):
        """Run one gradient-free forward pass for evaluation/prediction.

        Args:
            model: Callable whose result exposes ``.loss`` and ``.logits``.
            inputs: Three-item sequence ``(input_ids, labels, attention_mask)``.
            prediction_loss_only: If true, only the loss is returned and the
                logits/labels slots are ``None``, so the caller does not
                collect per-token logits it did not ask for.
            ignore_keys: Unused; kept for signature compatibility with
                ``Trainer.prediction_step``.

        Returns:
            ``(loss, logits, labels)``, or ``(loss, None, None)`` when
            ``prediction_loss_only`` is true.
        """
        input_ids, labels, attention_mask = inputs
        # Forward pass only; no autograd graph is needed here.
        with torch.no_grad():
            outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
            loss = outputs.loss
            # Only touch the logits when the caller actually wants them.
            logits = None if prediction_loss_only else outputs.logits
        if prediction_loss_only:
            return (loss, None, None)
        return (loss, logits, labels)
    

class CustomTrainerForgetting(Trainer):
    def __init__(self, *args, **kwargs):
        """Pull the unlearning options out of ``kwargs`` and initialise the base Trainer.

        Every option listed below is mandatory: it is removed from ``kwargs``
        with ``pop`` (so a missing one raises ``KeyError``) and stored on the
        instance before the remaining arguments are forwarded to
        ``Trainer.__init__``.
        """
        # The objective selector is passed as `forget_loss` but stored as `loss_type`.
        self.loss_type = kwargs.pop('forget_loss')

        # All other options are stored under the same name as their keyword.
        # The pop order below matches the order the options are consumed in.
        same_name_options = (
            'oracle_model',
            'eval_cfg',
            'seed',
            # Coefficients of the individual loss terms (used in the ablation study).
            'npo_coeff',
            'grad_diff_coeff',
            'KL_coeff',
            'ref_policy',
            'beta',
            'gamma',
            'lambda_entropy',
            'entropy_lower_bound',
            'cl_coeff',
            'tau',
            'mix_retain_coeff',
        )
        for option in same_name_options:
            setattr(self, option, kwargs.pop(option))

        # Two independent, initially empty bookkeeping lists.
        self.xp_terms, self.grad_Rs = [], []

        super().__init__(*args, **kwargs)

        # The oracle model is always needed to compute the KL distance at
        # evaluation time, so under DeepSpeed it gets its own engine wrapper.
        if self.is_deepspeed_enabled:
            self.oracle_model = self.e_prepare_deepspeed(self.oracle_model)


    def get_train_dataloader(self):
        """
        Build the training DataLoader with a shuffle generator seeded per call.

        Unused columns are dropped from a `datasets.Dataset`; for any other
        dataset the collator is wrapped to drop them instead. For datasets that
        are not `IterableDataset`, shuffling uses `shuffle=True` together with a
        `torch.Generator` seeded with `self.seed + self.state.global_step`
        (no explicit sampler is passed). The loader is returned after going
        through `self.accelerator.prepare`.

        Raises:
            ValueError: if `self.train_dataset` is None.
        """
        dataset = self.train_dataset
        if dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        collate_fn = self.data_collator
        if is_datasets_available() and isinstance(dataset, datasets.Dataset):
            dataset = self._remove_unused_columns(dataset, description="training")
        else:
            collate_fn = self._get_collator_with_removed_columns(collate_fn, description="training")

        loader_kwargs = dict(
            batch_size=self._train_batch_size,
            collate_fn=collate_fn,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            persistent_workers=self.args.dataloader_persistent_workers,
        )

        # The seed is offset by the current global step, so a loader built at a
        # different step gets a different shuffle seed.
        step = self.state.global_step
        shuffle_rng = torch.Generator()
        shuffle_rng.manual_seed(self.seed + step)
        print(f'Generator........Epoch-{step}')

        # Shuffle-related options are only added for non-iterable datasets.
        if not isinstance(dataset, torch.utils.data.IterableDataset):
            loader_kwargs.update(
                generator=shuffle_rng,
                shuffle=True,  # shuffle with the generator above instead of a sampler
                drop_last=self.args.dataloader_drop_last,
                worker_init_fn=seed_worker,
            )

        return self.accelerator.prepare(DataLoader(dataset, **loader_kwargs))


    def e_prepare_deepspeed(self, model):
        """Wrap ``model`` in an optimizer-less DeepSpeed engine, frozen and in eval mode.

        A deep copy of the accelerator's DeepSpeed config is adjusted and passed
        to ``deepspeed.initialize``; the plugin's own config is left untouched.

        Args:
            model: The model to wrap (here, the oracle/reference model).

        Returns:
            The engine returned by ``deepspeed.initialize``, switched to
            ``eval()`` with ``requires_grad`` disabled on every parameter.
        """
        ds_config = copy.deepcopy(self.accelerator.state.deepspeed_plugin.deepspeed_config)
        zero_cfg = ds_config["zero_optimization"]

        # Work out a hidden size from the model config, if one is exposed:
        # the largest entry of `hidden_sizes` when present and non-empty,
        # otherwise `hidden_size` (which may be absent -> None).
        hidden_size = None
        if model is not None and hasattr(model, "config"):
            widths = getattr(model.config, "hidden_sizes", None)
            if widths:
                hidden_size = max(widths)
            else:
                hidden_size = getattr(model.config, "hidden_size", None)

        if hidden_size is not None and zero_cfg["stage"] == 3:
            # NOTE(review): these are stored as flat, dotted top-level keys rather
            # than nested under "zero_optimization", and the prefetch size is a
            # float -- confirm DeepSpeed interprets them as intended.
            ds_config["zero_optimization.reduce_bucket_size"] = hidden_size * hidden_size
            ds_config["zero_optimization.stage3_param_persistence_threshold"] = 10 * hidden_size
            ds_config["zero_optimization.stage3_prefetch_bucket_size"] = 0.9 * hidden_size * hidden_size

        # With ZeRO-3 both the active and the reference model are sharded.
        # Otherwise the reference model is assumed to fit in memory and is
        # initialised on each device with ZeRO disabled (stage 0).
        # (Original note: "for BLUE, we use ZeRO-3 here".)
        if zero_cfg["stage"] != 3:
            zero_cfg["stage"] = 0

        # No optimizer is configured for this engine.
        ds_config["optimizer"] = {"type": None}
        engine, *_ = deepspeed.initialize(model=model, config=ds_config)
        engine.eval()

        # Disable gradients for every parameter.
        for weight in engine.parameters():
            weight.requires_grad = False

        return engine
    
    def compute_loss(self, model, inputs, return_outputs=False):
        if self.loss_type == "grad_ascent":
            forget_inputs, retain_inputs = inputs
            input_ids, labels, attention_mask = forget_inputs
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)         ##attention_mask is used to indicate which tokens to attend to ()
            forget_loss = outputs.loss
            forget_loss = forget_loss * -1
            loss = forget_loss

        elif self.loss_type == "fine_tuned":
            _, retain_inputs = inputs
            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            retain_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
            retain_loss = retain_outputs.loss
            loss = retain_loss

        elif self.loss_type == "grad_ascent_forgetKL":
            forget_inputs, retain_inputs = inputs
            input_ids, labels, attention_mask = forget_inputs
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)         
            forget_loss = -1 * outputs.loss

            with torch.no_grad():
                oracle_outputs = self.oracle_model(input_ids,labels=labels, attention_mask=attention_mask)
            oracle_probs = F.log_softmax(oracle_outputs.logits, dim=-1)
            oracle_probs = oracle_probs.view(-1, oracle_outputs.logits.shape[-1])
            current_probs = F.log_softmax(outputs.logits, dim=-1)
            current_probs = current_probs.view(-1, outputs.logits.shape[-1])

            kl_loss = nn.functional.kl_div(current_probs, oracle_probs, reduction='batchmean', log_target=True)
            loss = forget_loss + kl_loss

        elif self.loss_type == "grad_diff":
            forget_inputs, retain_inputs = inputs
            input_ids, labels, attention_mask = forget_inputs
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            forget_loss = outputs.loss
            forget_loss = forget_loss * -1

            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            retain_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
            retain_loss = retain_outputs.loss
            loss = forget_loss + retain_loss

        elif self.loss_type == "entropy_grad_diff":
            forget_inputs, retain_inputs = inputs
            input_ids, labels, attention_mask = forget_inputs
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            forget_loss = outputs.loss
            forget_loss = forget_loss * -1

            # compute entropy
            probs = F.softmax(outputs.logits, dim=-1)  # (B, T, V)
            entropy = -torch.sum(probs * torch.log(probs + 1e-6), dim=-1)  # shape: (B, T)
            # entropy = torch.clamp(self.entropy_lower_bound - entropy_org, min=0.0)
            valid_mask = (labels != -100) & (attention_mask == 1)
            masked_entropy = entropy * valid_mask
            avg_entropy = masked_entropy.sum() / valid_mask.sum()

            # if int(os.environ.get('RANK', '0')) == 0:
            #     import pdb; pdb.set_trace()

            if int(os.environ.get('RANK', '0')) == 0:
                # masked_entropy_org = entropy_org * valid_mask
                # avg_entropy_org = masked_entropy_org.sum() / valid_mask.sum()
                # print(f"entropy_org: {avg_entropy_org}", f"entropy: {avg_entropy}")
                print(f"entropy: {avg_entropy}")

            coefficient = torch.clamp(avg_entropy.detach() - self.entropy_lower_bound, min=0.0)

            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            retain_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
            retain_loss = retain_outputs.loss
            loss = forget_loss * self.lambda_entropy * coefficient  + retain_loss

        elif self.loss_type == "high_conf_penalty_grad_diff":
            forget_inputs, retain_inputs = inputs
            input_ids, labels, attention_mask = forget_inputs
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            # forget_loss = outputs.loss
            # forget_loss = forget_loss * -1

            valid_mask = (labels != -100) & (attention_mask == 1)

            logits = outputs.logits
            probs = F.softmax(logits, dim=-1)
            max_probs, _ = probs.max(dim=-1)  # shape: (B, T)
            penalty_mask = (max_probs > 0.9).float()  # (B, T)

            token_weights = 1.0 - self.lambda_entropy * penalty_mask
            log_probs = F.log_softmax(logits, dim=-1)

            safe_labels = labels.clone()
            safe_labels[~valid_mask] = 0  # replace -100 with dummy index

            nll = -log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)  # shape: (B, T)
            weighted_nll = nll * token_weights

            forget_loss = - weighted_nll.sum() / valid_mask.sum()

            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            retain_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
            retain_loss = retain_outputs.loss
            loss = forget_loss + retain_loss

        elif self.loss_type == "cl_grad_diff":
            forget_inputs, retain_inputs, forget_question_inputs, retain_question_inputs, mix_question_inputs = inputs
            input_ids, labels, attention_mask = forget_inputs

            if input_ids.shape[0] != 1:
                raise("Not implemented.")

            outputs = model(input_ids,labels=labels, attention_mask=attention_mask, output_hidden_states=True)
            forget_loss = outputs.loss
            forget_loss = forget_loss * -1

            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            retain_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask, output_hidden_states=True)
            retain_loss = retain_outputs.loss

            mix_input_ids, mix_labels, mix_attention_mask = mix_question_inputs
            mix_outputs = model(mix_input_ids,labels=mix_labels, attention_mask=mix_attention_mask, output_hidden_states=True)

            forget_question_input_ids, forget_question_labels, forget_question_attention_mask = forget_question_inputs
            retain_question_input_ids, retain_question_labels, retain_question_attention_mask = retain_question_inputs

            mix_length = mix_attention_mask.sum(dim=1).item()
            forget_length = forget_question_attention_mask.sum(dim=1).item()
            retain_length = retain_question_attention_mask.sum(dim=1).item()
            neg_rep = outputs.hidden_states[-1][:, forget_length-1].detach()
            pos_rep = retain_outputs.hidden_states[-1][:, retain_length-1].detach()
            mix_rep = mix_outputs.hidden_states[-1][:, mix_length-1]

            tau = self.tau
            sim_pos = F.cosine_similarity(mix_rep, pos_rep, dim=-1)
            sim_neg = F.cosine_similarity(mix_rep, neg_rep, dim=-1)
            prob_pos = torch.exp(sim_pos / tau) / (torch.exp(sim_pos / tau) + torch.exp(sim_neg / tau))
            cl_loss = -torch.log(prob_pos + 1e-8).mean()

            # if int(os.environ.get('RANK', '0')) == 0:
            #     import pdb; pdb.set_trace()

            print(f"cl_loss: {cl_loss.item()}")

            loss = forget_loss + retain_loss + self.cl_coeff * cl_loss

        elif self.loss_type == "mix_retain_grad_diff":
            forget_inputs, retain_inputs, mix_retain_inputs = inputs
            input_ids, labels, attention_mask = forget_inputs

            outputs = model(input_ids,labels=labels, attention_mask=attention_mask, output_hidden_states=True)
            forget_loss = outputs.loss
            forget_loss = forget_loss * -1

            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            retain_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask, output_hidden_states=True)
            retain_loss = retain_outputs.loss

            mix_retain_input_ids, mix_retain_labels, mix_retain_attention_mask = mix_retain_inputs
            mix_retain_outputs = model(mix_retain_input_ids,labels=mix_retain_labels, attention_mask=mix_retain_attention_mask, output_hidden_states=True)
            mix_retain_loss = mix_retain_outputs.loss

            print(f"mix_retain_loss: {mix_retain_loss.item()}")

            # print("retain_input_ids.shape: ", retain_input_ids.shape)
            # print("mix_retain_input_ids.shape: ", mix_retain_input_ids.shape)

            # if int(os.environ.get('RANK', '0')) == 0:
            #     import pdb; pdb.set_trace()

            loss = forget_loss + retain_loss + self.mix_retain_coeff * mix_retain_loss

        elif self.loss_type == "manual_grad_diff":
            forget_inputs, retain_inputs = inputs
            forget_input_ids, forget_labels, forget_attention_mask = forget_inputs
            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            input_ids = torch.cat((forget_input_ids, retain_input_ids), dim=0)
            labels = torch.cat((forget_labels, retain_labels), dim=0)
            attention_mask = torch.cat((forget_attention_mask, retain_attention_mask), dim=0)
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)

            # # # Default loss (averaged over all samples and tokens)
            # # default_loss = outputs.loss
            # # if int(os.environ.get('RANK', '0')) == 0:
            # #     print("Default batch loss:", default_loss)

            # # Compute per-token loss manually
            # loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=-100)  # No reduction

            # shift_logits = outputs.logits[:, :-1, :].contiguous()
            # shift_labels = labels[:, 1:].contiguous()
            # shift_mask = attention_mask[:, 1:].contiguous()  # Mask for valid tokens
            # losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            # per_token_loss = losses.view(shift_labels.shape)
            # masked_loss = per_token_loss * shift_mask  # Zero out padding token losses
            # per_sample_loss = masked_loss.sum(dim=1) / shift_mask.sum(dim=1)

            # # if int(os.environ.get('RANK', '0')) == 0:
            # #     # import pdb; pdb.set_trace()
            # # # print("Per-sample loss:", per_sample_loss)  # Tensor of shape (batch_size,)
            # #     print("Manual loss:", per_sample_loss.mean())  # Tensor of shape (batch_size,)

            # # if int(os.environ.get('RANK', '0')) == 0:
            # #     import pdb; pdb.set_trace()

            # loss = - per_sample_loss[:input_ids.shape[0]].mean() + per_sample_loss[input_ids.shape[0]:].mean()

            forget_loss = get_batch_loss(outputs.logits[:forget_input_ids.shape[0]], labels[:forget_input_ids.shape[0]])
            retain_loss = get_batch_loss(outputs.logits[forget_input_ids.shape[0]:], labels[forget_input_ids.shape[0]:])

            # if int(os.environ.get('RANK', '0')) == 0:
            #     import pdb; pdb.set_trace()

            loss = - forget_loss.mean() + retain_loss.mean()

        elif self.loss_type == "mix_retain_manual_grad_diff":
            forget_inputs, retain_inputs, mix_retain_inputs = inputs
            forget_input_ids, forget_labels, forget_attention_mask = forget_inputs
            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            mix_retain_input_ids, mix_retain_labels, mix_retain_attention_mask = mix_retain_inputs
            input_ids = torch.cat((forget_input_ids, retain_input_ids, mix_retain_input_ids), dim=0)
            labels = torch.cat((forget_labels, retain_labels, mix_retain_labels), dim=0)
            attention_mask = torch.cat((forget_attention_mask, retain_attention_mask, mix_retain_attention_mask), dim=0)
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)

            forget_loss = get_batch_loss(outputs.logits[:forget_input_ids.shape[0]], labels[:forget_input_ids.shape[0]])
            retain_loss = get_batch_loss(outputs.logits[forget_input_ids.shape[0]:forget_input_ids.shape[0]+retain_input_ids.shape[0]], labels[forget_input_ids.shape[0]:forget_input_ids.shape[0]+retain_input_ids.shape[0]])
            mix_retain_loss = get_batch_loss(outputs.logits[forget_input_ids.shape[0]+retain_input_ids.shape[0]:], labels[forget_input_ids.shape[0]+retain_input_ids.shape[0]:])
            loss = - forget_loss.mean() + retain_loss.mean() + self.mix_retain_coeff * mix_retain_loss.mean()
        
        elif self.loss_type == "KL":
            forget_inputs, retain_inputs = inputs
            input_ids, labels, attention_mask = forget_inputs
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            forget_loss = outputs.loss
            forget_loss = forget_loss * -1
            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            with torch.no_grad():
                retain_outputs = self.oracle_model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
            
            retain_probs = F.log_softmax(retain_outputs.logits, dim=-1)
            retain_probs = retain_probs.view(-1, retain_outputs.logits.shape[-1])

            current_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
            current_probs = F.log_softmax(current_outputs.logits, dim=-1)
            current_probs = current_probs.view(-1, current_outputs.logits.shape[-1])

            #minimum KL divergence
            retain_loss = nn.functional.kl_div(current_probs, retain_probs, reduction='batchmean', log_target=True)
            loss = forget_loss + retain_loss

        elif self.loss_type == "idk":
            idk_inputs, retain_inputs = inputs
            idk_input_ids, idk_labels, idk_attention_mask = idk_inputs
            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            
            #concatenate the inputs. single forward pass is much more efficient
            input_ids = torch.cat((idk_input_ids, retain_input_ids), dim=0)
            labels = torch.cat((idk_labels, retain_labels), dim=0)
            attention_mask = torch.cat((idk_attention_mask, retain_attention_mask), dim=0)
            
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            loss = outputs.loss

        elif self.loss_type == "mix_retain_idk":

            idk_inputs, retain_inputs, mix_retain_inputs = inputs
            idk_input_ids, idk_labels, idk_attention_mask = idk_inputs
            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            mix_retain_input_ids, mix_retain_labels, mix_retain_attention_mask = mix_retain_inputs
            
            #concatenate the inputs. single forward pass is much more efficient
            input_ids = torch.cat((idk_input_ids, retain_input_ids), dim=0)
            labels = torch.cat((idk_labels, retain_labels), dim=0)
            attention_mask = torch.cat((idk_attention_mask, retain_attention_mask), dim=0)
            
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)

            mix_retain_outputs = model(mix_retain_input_ids,labels=mix_retain_labels, attention_mask=mix_retain_attention_mask, output_hidden_states=True)
            mix_retain_loss = mix_retain_outputs.loss
            
            loss = outputs.loss + self.mix_retain_coeff * mix_retain_loss
        
        elif self.loss_type in ["dpo","dpo_grad_diff","dpo_KL"]:
            idk_inputs, forget_inputs, retain_inputs = inputs
            idk_input_ids, idk_labels, idk_attention_mask = idk_inputs
            forget_input_ids, forget_labels, forget_attention_mask = forget_inputs
            idk_outputs = model(idk_input_ids,labels=idk_labels, attention_mask=idk_attention_mask)
            forget_outputs = model(forget_input_ids,labels=forget_labels, attention_mask=forget_attention_mask)
            with torch.no_grad():
                idk_outputs_oracle = self.oracle_model(idk_input_ids,labels=idk_labels, attention_mask=idk_attention_mask)
                forget_outputs_oracle = self.oracle_model(forget_input_ids,labels=forget_labels, attention_mask=forget_attention_mask)
                idk_logits_oracle = idk_outputs_oracle.logits
                forget_logits_oracle = forget_outputs_oracle.logits

                idk_loss_oracle = -1 * get_batch_loss(idk_logits_oracle, idk_labels)
                forget_loss_oracle = -1 * get_batch_loss(forget_logits_oracle, forget_labels)
            
            idk_loss_current = -1 * get_batch_loss(idk_outputs.logits, idk_labels)
            forget_loss_current = -1 * get_batch_loss(forget_outputs.logits, forget_labels)

            pi_logratios = idk_loss_current - forget_loss_current
            ref_logratios = idk_loss_oracle - forget_loss_oracle
            loss = -F.logsigmoid(self.beta * (pi_logratios - ref_logratios)).mean()*2/self.beta

            if self.loss_type == 'dpo_grad_diff':
                retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
                retain_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
                retain_loss = retain_outputs.loss
                loss = loss + retain_loss

            elif self.loss_type == 'dpo_KL':
                retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
                with torch.no_grad():
                    retain_outputs = self.oracle_model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
                retain_probs = F.log_softmax(retain_outputs.logits, dim=-1)
                retain_probs = retain_probs.view(-1, retain_outputs.logits.shape[-1])

                current_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
                current_probs = F.log_softmax(current_outputs.logits, dim=-1)
                current_probs = current_probs.view(-1, current_outputs.logits.shape[-1])

                #minimum KL divergence
                retain_loss = nn.functional.kl_div(current_probs, retain_probs, reduction='batchmean', log_target=True)
                loss = loss + retain_loss

        elif self.loss_type == "simnpo":
            forget_inputs, _ = inputs
            input_ids, labels, attention_mask = forget_inputs
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            loss_mask = labels != -100
            forget_loss = get_batch_loss(outputs.logits, labels) / loss_mask.sum(-1) - self.gamma

            loss = -F.logsigmoid(self.beta * forget_loss).mean() * 2 / self.beta

        elif self.loss_type == 'simnpo_grad_diff':
            forget_inputs, retain_inputs = inputs
            input_ids, labels, attention_mask = forget_inputs
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            loss_mask = labels != -100
            forget_loss = get_batch_loss(outputs.logits, labels) / loss_mask.sum(-1) - self.gamma
            forget_loss = -F.logsigmoid(self.beta * forget_loss).mean() * 2 / self.beta

            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            retain_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
            retain_loss = retain_outputs.loss
            loss = self.npo_coeff * forget_loss + self.grad_diff_coeff * retain_loss

        ### Implement the NPO
        elif self.loss_type == 'npo':
            forget_inputs, retain_inputs = inputs
            input_ids, labels, attention_mask = forget_inputs
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            forget_loss_current = get_batch_loss(outputs.logits, labels) 

            if self.ref_policy == 'fine_tuned':
                with torch.no_grad():
                    forget_outputs_oracle = self.oracle_model(input_ids,labels=labels, attention_mask=attention_mask)
                    forget_logits_oracle = forget_outputs_oracle.logits
                    forget_loss_oracle = get_batch_loss(forget_logits_oracle, labels)
                neg_log_ratios = forget_loss_current - forget_loss_oracle
            else:
                raise NotImplementedError
            loss = -F.logsigmoid(self.beta * neg_log_ratios).mean() * 2 / self.beta

        elif self.loss_type == 'npo_grad_diff':
            # print(inputs)
            forget_inputs, retain_inputs = inputs
            input_ids, labels, attention_mask = forget_inputs
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            forget_loss_current = get_batch_loss(outputs.logits, labels) 

            if self.ref_policy == 'fine_tuned':
                with torch.no_grad():
                    forget_outputs_oracle = self.oracle_model(input_ids,labels=labels, attention_mask=attention_mask)
                    forget_logits_oracle = forget_outputs_oracle.logits
                    forget_loss_oracle = get_batch_loss(forget_logits_oracle, labels)
                neg_log_ratios = forget_loss_current - forget_loss_oracle
            else:
                raise NotImplementedError
            forget_loss = -F.logsigmoid(self.beta * neg_log_ratios).mean() * 2 / self.beta

            # forget_loss = -(self.beta * neg_log_ratios).mean() * 2 / self.beta

            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            retain_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
            retain_loss = retain_outputs.loss
            # print(forget_loss)
            loss = self.npo_coeff * forget_loss + self.grad_diff_coeff * retain_loss

        elif self.loss_type == 'mix_retain_npo_grad_diff':
            # print(inputs)
            forget_inputs, retain_inputs, mix_retain_inputs = inputs
            input_ids, labels, attention_mask = forget_inputs
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            forget_loss_current = get_batch_loss(outputs.logits, labels) 

            if self.ref_policy == 'fine_tuned':
                with torch.no_grad():
                    forget_outputs_oracle = self.oracle_model(input_ids,labels=labels, attention_mask=attention_mask)
                    forget_logits_oracle = forget_outputs_oracle.logits
                    forget_loss_oracle = get_batch_loss(forget_logits_oracle, labels)
                neg_log_ratios = forget_loss_current - forget_loss_oracle
            else:
                raise NotImplementedError
            forget_loss = -F.logsigmoid(self.beta * neg_log_ratios).mean() * 2 / self.beta

            # forget_loss = -(self.beta * neg_log_ratios).mean() * 2 / self.beta

            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            retain_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
            retain_loss = retain_outputs.loss
            # print(forget_loss)

            mix_retain_input_ids, mix_retain_labels, mix_retain_attention_mask = mix_retain_inputs
            mix_retain_outputs = model(mix_retain_input_ids,labels=mix_retain_labels, attention_mask=mix_retain_attention_mask)
            mix_retain_loss = mix_retain_outputs.loss

            loss = self.npo_coeff * forget_loss + self.grad_diff_coeff * retain_loss + self.mix_retain_coeff * mix_retain_loss

        elif self.loss_type == 'manual_npo_grad_diff':
            # print(inputs)
            forget_inputs, retain_inputs = inputs
            forget_input_ids, forget_labels, forget_attention_mask = forget_inputs
            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            input_ids = torch.cat((forget_input_ids, retain_input_ids), dim=0)
            labels = torch.cat((forget_labels, retain_labels), dim=0)
            attention_mask = torch.cat((forget_attention_mask, retain_attention_mask), dim=0)
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)

            forget_loss_current = get_batch_loss(outputs.logits[:forget_input_ids.shape[0]], labels[:forget_input_ids.shape[0]])
            retain_loss = get_batch_loss(outputs.logits[forget_input_ids.shape[0]:], labels[forget_input_ids.shape[0]:])

            if self.ref_policy == 'fine_tuned':
                with torch.no_grad():
                    forget_outputs_oracle = self.oracle_model(forget_input_ids,labels=forget_labels, attention_mask=forget_attention_mask)
                    forget_logits_oracle = forget_outputs_oracle.logits
                    forget_loss_oracle = get_batch_loss(forget_logits_oracle, forget_labels)
                neg_log_ratios = forget_loss_current - forget_loss_oracle
            else:
                raise NotImplementedError
            forget_loss = -F.logsigmoid(self.beta * neg_log_ratios).mean() * 2 / self.beta

            # forget_loss = -(self.beta * neg_log_ratios).mean() * 2 / self.beta

            # retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            # retain_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
            # retain_loss = retain_outputs.loss
            # print(forget_loss)
            loss = self.npo_coeff * forget_loss.mean() + self.grad_diff_coeff * retain_loss.mean()

        elif self.loss_type == 'mix_retain_manual_npo_grad_diff':
            # print(inputs)
            forget_inputs, retain_inputs, mix_retain_inputs = inputs
            forget_input_ids, forget_labels, forget_attention_mask = forget_inputs
            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            mix_retain_input_ids, mix_retain_labels, mix_retain_attention_mask = mix_retain_inputs
            input_ids = torch.cat((forget_input_ids, retain_input_ids, mix_retain_input_ids), dim=0)
            labels = torch.cat((forget_labels, retain_labels, mix_retain_labels), dim=0)
            attention_mask = torch.cat((forget_attention_mask, retain_attention_mask, mix_retain_attention_mask), dim=0)
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)

            forget_loss_current = get_batch_loss(outputs.logits[:forget_input_ids.shape[0]], labels[:forget_input_ids.shape[0]])
            retain_loss = get_batch_loss(outputs.logits[forget_input_ids.shape[0]:forget_input_ids.shape[0]+retain_input_ids.shape[0]], labels[forget_input_ids.shape[0]:forget_input_ids.shape[0]+retain_input_ids.shape[0]])
            mix_retain_loss = get_batch_loss(outputs.logits[forget_input_ids.shape[0]+retain_input_ids.shape[0]:], labels[forget_input_ids.shape[0]+retain_input_ids.shape[0]:])

            if self.ref_policy == 'fine_tuned':
                with torch.no_grad():
                    forget_outputs_oracle = self.oracle_model(forget_input_ids,labels=forget_labels, attention_mask=forget_attention_mask)
                    forget_logits_oracle = forget_outputs_oracle.logits
                    forget_loss_oracle = get_batch_loss(forget_logits_oracle, forget_labels)
                neg_log_ratios = forget_loss_current - forget_loss_oracle
            else:
                raise NotImplementedError
            forget_loss = -F.logsigmoid(self.beta * neg_log_ratios).mean() * 2 / self.beta

            loss = self.npo_coeff * forget_loss.mean() + self.grad_diff_coeff * retain_loss.mean() + self.mix_retain_coeff * mix_retain_loss.mean()
            
        elif self.loss_type == 'npo_KL':
            forget_inputs, retain_inputs = inputs
            input_ids, labels, attention_mask = forget_inputs
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            forget_loss_current = get_batch_loss(outputs.logits, labels) 
            if self.ref_policy == 'fine_tuned':
                with torch.no_grad():
                    forget_outputs_oracle = self.oracle_model(input_ids,labels=labels, attention_mask=attention_mask)
                    forget_logits_oracle = forget_outputs_oracle.logits
                    forget_loss_oracle = get_batch_loss(forget_logits_oracle, labels)
                neg_log_ratios = forget_loss_current - forget_loss_oracle
            else:
                raise NotImplementedError
            forget_loss = -F.logsigmoid(self.beta * neg_log_ratios).mean() * 2 / self.beta

            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            with torch.no_grad():
                retain_outputs = self.oracle_model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
            retain_probs = F.log_softmax(retain_outputs.logits, dim=-1)
            retain_probs = retain_probs.view(-1, retain_outputs.logits.shape[-1])

            current_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
            current_probs = F.log_softmax(current_outputs.logits, dim=-1)
            current_probs = current_probs.view(-1, current_outputs.logits.shape[-1])

            #minimum KL divergence
            retain_loss = nn.functional.kl_div(current_probs, retain_probs, reduction='batchmean', log_target=True)
            loss = self.npo_coeff * forget_loss + self.KL_coeff * retain_loss

        elif self.loss_type == 'kto_sigmoid':
            idk_inputs, forget_inputs, retain_inputs = inputs
            idk_input_ids, idk_labels, idk_attention_mask = idk_inputs
            forget_input_ids, forget_labels, forget_attention_mask = forget_inputs
            
            with torch.no_grad():
                idk_outputs = model(idk_input_ids,labels=idk_labels, attention_mask=idk_attention_mask)
                idk_outputs_oracle = self.oracle_model(idk_input_ids,labels=idk_labels, attention_mask=idk_attention_mask)
                idk_loss_log = -1 * get_batch_loss(idk_outputs.logits, idk_labels)
                idk_loss_log_oracle = -1 * get_batch_loss(idk_outputs_oracle.logits, idk_labels)
                
                KL_term = (idk_loss_log - idk_loss_log_oracle).mean()

                forget_outputs_oracle = self.oracle_model(forget_input_ids,labels=forget_labels, attention_mask=forget_attention_mask)
                forget_loss_oracle = -1 * get_batch_loss(forget_outputs_oracle.logits, forget_labels)

            forget_outputs = model(forget_input_ids,labels=forget_labels, attention_mask=forget_attention_mask)
            forget_loss = -1 * get_batch_loss(forget_outputs.logits, forget_labels)
            log_ratios = forget_loss - forget_loss_oracle
            loss = 1.0 - F.sigmoid(KL_term - self.beta * log_ratios).mean() * 2 / self.beta

        elif self.loss_type == 'kto_logsigmoid':
            idk_inputs, forget_inputs, retain_inputs = inputs
            idk_input_ids, idk_labels, idk_attention_mask = idk_inputs
            forget_input_ids, forget_labels, forget_attention_mask = forget_inputs
            
            with torch.no_grad():
                idk_outputs = model(idk_input_ids,labels=idk_labels, attention_mask=idk_attention_mask)
                idk_outputs_oracle = self.oracle_model(idk_input_ids,labels=idk_labels, attention_mask=idk_attention_mask)
                idk_loss_log = -1 * get_batch_loss(idk_outputs.logits, idk_labels)
                idk_loss_log_oracle = -1 * get_batch_loss(idk_outputs_oracle.logits, idk_labels)
                
                KL_term = (idk_loss_log - idk_loss_log_oracle).mean()

                forget_outputs_oracle = self.oracle_model(forget_input_ids,labels=forget_labels, attention_mask=forget_attention_mask)
                forget_loss_oracle = -1 * get_batch_loss(forget_outputs_oracle.logits, forget_labels)

            forget_outputs = model(forget_input_ids,labels=forget_labels, attention_mask=forget_attention_mask)
            forget_loss = -1 * get_batch_loss(forget_outputs.logits, forget_labels)
            log_ratios = forget_loss - forget_loss_oracle
            loss = 1.0 - F.logsigmoid(KL_term - self.beta * log_ratios).mean() * 2 / self.beta

        elif self.loss_type == 'kto_logsigmoid_grad_diff':
            idk_inputs, forget_inputs, retain_inputs = inputs
            idk_input_ids, idk_labels, idk_attention_mask = idk_inputs
            forget_input_ids, forget_labels, forget_attention_mask = forget_inputs
            
            with torch.no_grad():
                idk_outputs = model(idk_input_ids,labels=idk_labels, attention_mask=idk_attention_mask)
                idk_outputs_oracle = self.oracle_model(idk_input_ids,labels=idk_labels, attention_mask=idk_attention_mask)
                idk_loss_log = -1 * get_batch_loss(idk_outputs.logits, idk_labels)
                idk_loss_log_oracle = -1 * get_batch_loss(idk_outputs_oracle.logits, idk_labels)
                
                KL_term = (idk_loss_log - idk_loss_log_oracle).mean()

                forget_outputs_oracle = self.oracle_model(forget_input_ids,labels=forget_labels, attention_mask=forget_attention_mask)
                forget_loss_oracle = -1 * get_batch_loss(forget_outputs_oracle.logits, forget_labels)

            forget_outputs = model(forget_input_ids,labels=forget_labels, attention_mask=forget_attention_mask)
            forget_loss = -1 * get_batch_loss(forget_outputs.logits, forget_labels)
            log_ratios = forget_loss - forget_loss_oracle
            forget_loss = 1.0 - F.logsigmoid(KL_term - self.beta * log_ratios).mean() * 2 / self.beta

            retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            retain_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
            retain_loss = retain_outputs.loss

            loss = forget_loss + retain_loss

        elif self.loss_type == "reft_grad_ascent":
            forget_inputs, retain_inputs = inputs
            input_ids, labels, attention_mask = forget_inputs
            outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
            forget_loss = - outputs.loss
            loss = forget_loss

        return (loss, outputs) if return_outputs else loss
    

    def prediction_step(self, model, inputs, prediction_loss_only: bool, ignore_keys=None):
        '''
        Run a single gradient-free forward pass over one batch.

        Args:
            model: callable invoked as ``model(input_ids, labels=..., attention_mask=...)``
                whose result exposes ``.logits`` and ``.loss``.
            inputs: a 3-item sequence unpacked as ``(input_ids, labels, attention_mask)``.
            prediction_loss_only: accepted but not consulted by this implementation.
            ignore_keys: accepted but not consulted by this implementation.

        Returns:
            The tuple ``(loss, logits, labels)``; ``labels`` is handed back untouched.
        '''
        ids, targets, mask = inputs
        # No gradients are needed when only scoring the batch.
        with torch.no_grad():
            result = model(ids, labels=targets, attention_mask=mask)
            batch_logits, batch_loss = result.logits, result.loss
        return (batch_loss, batch_logits, targets)

    def evaluate(
        self,
        eval_dataset = None,
        ignore_keys = None,
        metric_key_prefix = "eval",
    ):
        '''
        Run every configured evaluation task for the current checkpoint.

        RZ: Trainer.train() calls this function when evaluating the performance
        of each checkpoint.

        For each task described by the parallel lists in ``self.eval_cfg``
        (``data_path``, ``split_list``, ``question_key``, ``answer_key``,
        ``eval_task``, ``base_answer_key``, ``perturbed_answer_key``) the method
        builds the dataloaders, collects the logs from ``get_all_evals`` plus the
        KL divergence against ``self.oracle_model``, and writes them as JSON into
        ``<eval_cfg.save_dir>/checkpoint-<global_step>/``. After all processes
        reach the barrier, the local main process merges the per-process files
        and writes ``eval_log_aggregated.json``, ``aggregate_stat.txt`` and
        ``truth_ratio.pkl`` into the same directory.

        Args:
            eval_dataset: accepted but not used by this implementation.
            ignore_keys: accepted but not used by this implementation.
            metric_key_prefix: accepted but not used by this implementation.

        Returns:
            None. All results are written to disk.
        '''

        def _read_json(path):
            # Read through a context manager so the file handle is always closed.
            with open(path, 'r') as json_file:
                return json.load(json_file)

        args = self.args
        model = self._wrap_model(self.model, training=False, dataloader=None)

        print('####### Evaluating the model...... #######')
        print(self.is_in_train, args.device, model.dtype, self.args.dataloader_num_workers, self.eval_cfg.split_list)

        # Only prepare the model through accelerate when nothing has been prepared yet.
        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            if self.is_fsdp_enabled:
                self.model = model

            # for the rest of this function `model` is the outside model, whether it was wrapped or not
            if model is not self.model:
                self.model_wrapped = model

            # backward compatibility
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        model.eval()
        curr_step = self.state.global_step
        eval_cfg = self.eval_cfg

        # Loop-invariant; defined up front so the aggregation below can rely on it
        # even when the task loop does not execute a single iteration.
        world_size = self.accelerator.num_processes

        curr_save_dir = os.path.join(eval_cfg.save_dir, f"checkpoint-{curr_step}")
        Path(curr_save_dir).mkdir(parents=True, exist_ok=True)

        forget_rate = eval_cfg.split_list[-1].split('_')[0]

        with torch.no_grad():
            task_specs = zip(
                eval_cfg.data_path,
                eval_cfg.split_list,
                eval_cfg.question_key,
                eval_cfg.answer_key,
                eval_cfg.eval_task,
                eval_cfg.base_answer_key,
                eval_cfg.perturbed_answer_key,
            )
            for folder, split, question_key, answer_key, eval_task, base_answer_key, perturbed_answer_key in task_specs:

                # For some reason, Hydra is not interpreting the split correctly
                if eval_task == 'eval_log_forget':
                    split = eval_cfg.split
                print(f'Working on eval task {eval_task} with split {split}')

                # With several processes each one writes its own shard file.
                # NOTE(review): shards are named by local_process_index while
                # world_size is the global process count; these presumably only
                # agree on single-node runs -- confirm before multi-node use.
                if world_size == 1:
                    save_filename = os.path.join(curr_save_dir, f"{eval_task}.json")
                else:
                    save_filename = os.path.join(curr_save_dir, f"{eval_task}_{self.accelerator.local_process_index}.json")

                if os.path.exists(save_filename) and not eval_cfg.overwrite:
                    print(f"Skipping {eval_task} because {save_filename} already exists")
                    continue

                eval_dataloader, base_eval_dataloader, perturb_dataloader = get_dataloader(eval_cfg, eval_task, self.tokenizer, folder, split, question_key, answer_key, base_answer_key, perturbed_answer_key)

                eval_dataloader = self.accelerator.prepare(eval_dataloader)
                base_eval_dataloader = self.accelerator.prepare(base_eval_dataloader)
                perturb_dataloader = self.accelerator.prepare(perturb_dataloader)

                eval_logs = get_all_evals(eval_cfg, model, self.tokenizer, eval_task, eval_dataloader, base_eval_dataloader, perturb_dataloader)

                kl_divergence_log = get_kl_divergence(model, self.oracle_model, eval_dataloader)
                eval_logs['kl_divergence'] = kl_divergence_log

                with open(save_filename, "w") as f:
                    # pretty write json to f
                    json.dump(eval_logs, f, indent=4)

            # wait for all processes to finish writing their files
            self.accelerator.wait_for_everyone()

            # Only the local main process aggregates; the others are done.
            if not self.accelerator.is_local_main_process:
                return

            aggregated_eval_logs = {}
            for eval_task in eval_cfg.eval_task:
                merged_filename = os.path.join(curr_save_dir, f"{eval_task}.json")

                if world_size > 1:
                    # Merge the per-process shards into one log using merge_dicts.
                    shard_filenames = [
                        os.path.join(curr_save_dir, f"{eval_task}_{rank}.json")
                        for rank in range(world_size)
                    ]
                    eval_logs = _read_json(shard_filenames[0])
                    for shard_filename in shard_filenames[1:]:
                        eval_logs = merge_dicts(eval_logs, _read_json(shard_filename))

                    with open(merged_filename, "w") as f:
                        # pretty write json to f
                        json.dump(eval_logs, f, indent=4)

                    # Delete the shards only after the merged file is written and closed.
                    for shard_filename in shard_filenames:
                        os.remove(shard_filename)
                else:
                    eval_logs = _read_json(merged_filename)

                aggregated_eval_logs[f'{eval_task}.json'] = eval_logs

            aggregated_eval_logs = interleave_eval_result_dict(aggregated_eval_logs, forget_rate, large_bsz=eval_cfg.batch_size, num_processes=world_size)
            aggregated_eval_log_filename = os.path.join(curr_save_dir, "eval_log_aggregated.json")

            with open(aggregated_eval_log_filename, 'w') as f:
                json.dump(aggregated_eval_logs, f, indent=4)

            model_utility = get_model_utility(aggregated_eval_logs)
            retain_result = _read_json(eval_cfg.retain_result)
            forget_quality, truth_ratio = get_forget_quality(aggregated_eval_logs, retain_result)
            aggregate_stat = {**model_utility, **forget_quality}

            aggregate_stat['curr_step'] = curr_step
            aggregate_stat['seed'] = self.seed
            aggregate_stat['loss_type'] = self.loss_type

            with open(os.path.join(curr_save_dir, "aggregate_stat.txt"), 'w') as txtfile:
                for key, value in aggregate_stat.items():
                    txtfile.write(f"{key}: {value}\n")

            with open(os.path.join(curr_save_dir, "truth_ratio.pkl"), 'wb') as picklefile:
                pickle.dump(truth_ratio, picklefile)


    def evaluate_save_presentation(
        self,
        eval_dataset = None,
        ignore_keys = None,
        metric_key_prefix = "eval",
    ):
        '''
        Run every configured evaluation task and also save hidden states.

        Follows the same flow as the per-checkpoint evaluation, with two
        differences visible in this method: results go to
        ``<eval_cfg.save_dir>/eval_only_with_ans`` (when ``eval_cfg.with_ans`` is
        truthy) or ``<eval_cfg.save_dir>/eval_only``, and the hidden states
        returned by ``get_all_evals(..., output_hidden_states=True)`` are
        concatenated along axis 0 and stored as ``<eval_task>.npy`` next to the
        JSON logs. After all processes reach the barrier, the local main process
        merges the per-process JSON and ``.npy`` files and writes
        ``eval_log_aggregated.json``, ``aggregate_stat.txt`` and
        ``truth_ratio.pkl``.

        Args:
            eval_dataset: accepted but not used by this implementation.
            ignore_keys: accepted but not used by this implementation.
            metric_key_prefix: accepted but not used by this implementation.

        Returns:
            None. All results are written to disk.
        '''

        def _read_json(path):
            # Read through a context manager so the file handle is always closed.
            with open(path, 'r') as json_file:
                return json.load(json_file)

        args = self.args
        model = self._wrap_model(self.model, training=False, dataloader=None)

        print('####### Evaluating the model...... #######')
        print(self.is_in_train, args.device, model.dtype, self.args.dataloader_num_workers, self.eval_cfg.split_list)

        # Only prepare the model through accelerate when nothing has been prepared yet.
        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            if self.is_fsdp_enabled:
                self.model = model

            # for the rest of this function `model` is the outside model, whether it was wrapped or not
            if model is not self.model:
                self.model_wrapped = model

            # backward compatibility
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        model.eval()
        curr_step = self.state.global_step
        eval_cfg = self.eval_cfg

        # Loop-invariant; defined up front so the aggregation below can rely on it
        # even when the task loop does not execute a single iteration.
        world_size = self.accelerator.num_processes

        if eval_cfg.with_ans:
            curr_save_dir = os.path.join(eval_cfg.save_dir, "eval_only_with_ans")
        else:
            curr_save_dir = os.path.join(eval_cfg.save_dir, "eval_only")
        Path(curr_save_dir).mkdir(parents=True, exist_ok=True)

        forget_rate = eval_cfg.split_list[-1].split('_')[0]

        with torch.no_grad():
            task_specs = zip(
                eval_cfg.data_path,
                eval_cfg.split_list,
                eval_cfg.question_key,
                eval_cfg.answer_key,
                eval_cfg.eval_task,
                eval_cfg.base_answer_key,
                eval_cfg.perturbed_answer_key,
            )
            for folder, split, question_key, answer_key, eval_task, base_answer_key, perturbed_answer_key in task_specs:

                # For some reason, Hydra is not interpreting the split correctly
                if eval_task == 'eval_log_forget':
                    split = eval_cfg.split
                print(f'Working on eval task {eval_task} with split {split}')

                # With several processes each one writes its own JSON/.npy shard.
                # NOTE(review): shards are named by local_process_index while
                # world_size is the global process count; these presumably only
                # agree on single-node runs -- confirm before multi-node use.
                if world_size == 1:
                    save_filename = os.path.join(curr_save_dir, f"{eval_task}.json")
                    save_numpy_filename = os.path.join(curr_save_dir, f"{eval_task}.npy")
                else:
                    save_filename = os.path.join(curr_save_dir, f"{eval_task}_{self.accelerator.local_process_index}.json")
                    save_numpy_filename = os.path.join(curr_save_dir, f"{eval_task}_{self.accelerator.local_process_index}.npy")

                if os.path.exists(save_filename) and not eval_cfg.overwrite:
                    print(f"Skipping {eval_task} because {save_filename} already exists")
                    continue

                eval_dataloader, base_eval_dataloader, perturb_dataloader = get_dataloader(eval_cfg, eval_task, self.tokenizer, folder, split, question_key, answer_key, base_answer_key, perturbed_answer_key)

                eval_dataloader = self.accelerator.prepare(eval_dataloader)
                base_eval_dataloader = self.accelerator.prepare(base_eval_dataloader)
                perturb_dataloader = self.accelerator.prepare(perturb_dataloader)

                # NOTE(review): eval_cfg is passed both positionally and as a keyword;
                # this only works if get_all_evals names its first parameter
                # something other than `eval_cfg` -- confirm against evaluate_util.
                eval_logs, hidden_states = get_all_evals(eval_cfg, model, self.tokenizer, eval_task, eval_dataloader, base_eval_dataloader, perturb_dataloader, output_hidden_states=True, eval_cfg=eval_cfg)

                # Stack the per-batch arrays into one array along the first axis.
                # (Presumably last-token hidden states, judging by the name -- confirm.)
                last_token_hidden_states = np.concatenate(hidden_states, axis=0)

                kl_divergence_log = get_kl_divergence(model, self.oracle_model, eval_dataloader)
                eval_logs['kl_divergence'] = kl_divergence_log

                with open(save_filename, "w") as f:
                    # pretty write json to f
                    json.dump(eval_logs, f, indent=4)

                np.save(save_numpy_filename, last_token_hidden_states)

            # wait for all processes to finish writing their files
            self.accelerator.wait_for_everyone()

            # Only the local main process aggregates; the others are done.
            if not self.accelerator.is_local_main_process:
                return

            aggregated_eval_logs = {}
            for eval_task in eval_cfg.eval_task:
                merged_filename = os.path.join(curr_save_dir, f"{eval_task}.json")

                if world_size > 1:
                    json_shards = [
                        os.path.join(curr_save_dir, f"{eval_task}_{rank}.json")
                        for rank in range(world_size)
                    ]
                    numpy_shards = [
                        os.path.join(curr_save_dir, f"{eval_task}_{rank}.npy")
                        for rank in range(world_size)
                    ]

                    # Merge the per-process JSON logs using merge_dicts.
                    eval_logs = _read_json(json_shards[0])
                    for json_shard in json_shards[1:]:
                        eval_logs = merge_dicts(eval_logs, _read_json(json_shard))

                    # Concatenate the per-process hidden states in process order.
                    merged_hidden_states = np.concatenate(
                        [np.load(numpy_shard) for numpy_shard in numpy_shards],
                        axis=0,
                    )
                    np.save(os.path.join(curr_save_dir, f"{eval_task}.npy"), merged_hidden_states)
                    for numpy_shard in numpy_shards:
                        os.remove(numpy_shard)

                    with open(merged_filename, "w") as f:
                        # pretty write json to f
                        json.dump(eval_logs, f, indent=4)

                    # Delete the shards only after the merged file is written and closed.
                    for json_shard in json_shards:
                        os.remove(json_shard)
                else:
                    eval_logs = _read_json(merged_filename)

                aggregated_eval_logs[f'{eval_task}.json'] = eval_logs

            aggregated_eval_logs = interleave_eval_result_dict(aggregated_eval_logs, forget_rate, large_bsz=eval_cfg.batch_size, num_processes=world_size)
            aggregated_eval_log_filename = os.path.join(curr_save_dir, "eval_log_aggregated.json")

            with open(aggregated_eval_log_filename, 'w') as f:
                json.dump(aggregated_eval_logs, f, indent=4)

            model_utility = get_model_utility(aggregated_eval_logs)
            retain_result = _read_json(eval_cfg.retain_result)
            forget_quality, truth_ratio = get_forget_quality(aggregated_eval_logs, retain_result)
            aggregate_stat = {**model_utility, **forget_quality}

            aggregate_stat['curr_step'] = curr_step
            aggregate_stat['seed'] = self.seed
            aggregate_stat['loss_type'] = self.loss_type

            with open(os.path.join(curr_save_dir, "aggregate_stat.txt"), 'w') as txtfile:
                for key, value in aggregate_stat.items():
                    txtfile.write(f"{key}: {value}\n")

            with open(os.path.join(curr_save_dir, "truth_ratio.pkl"), 'wb') as picklefile:
                pickle.dump(truth_ratio, picklefile)

    def evaluate_one_dataloader(
            self,
            eval_dataset = None,
            ignore_keys = None,
            metric_key_prefix = "eval",
        ):
        '''
        RZ: Call this function in Trainer.train() when evakluating the performace of each checkpoint.
        '''

        load_data = torch.load("temp.pth")

        # print(self.model.model.embed_tokens(load_data["input_ids"]))

        args = self.args
        model = self._wrap_model(self.model, training=False, dataloader=None)

        # print(model.model.embed_tokens.weight)
        # print(self.model.model.embed_tokens(load_data["input_ids"]))

        print('####### Evaluating the model...... #######')
        print(self.is_in_train, args.device, model.dtype, self.args.dataloader_num_workers, self.eval_cfg.split_list)

        # import pdb ; pdb.set_trace()

        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            print(model.model.embed_tokens.weight)
            print(self.model.model.embed_tokens(load_data["input_ids"]))

            if self.is_fsdp_enabled:
                self.model = model

            print(model.model.embed_tokens.weight)
            print(self.model.model.embed_tokens(load_data["input_ids"]))

            # for the rest of this function `model` is the outside model, whether it was wrapped or not
            if model is not self.model:
                self.model_wrapped = model

            print(model.model.embed_tokens.weight)
            print(self.model.model.embed_tokens(load_data["input_ids"]))

            # backward compatibility
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

            print(model.model.embed_tokens.weight)
            print(self.model.model.embed_tokens(load_data["input_ids"]))

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        # print(model.model.embed_tokens.weight)
        # print(self.model.model.embed_tokens(load_data["input_ids"]))
        
        model.eval()
        curr_step = self.state.global_step
        eval_cfg = self.eval_cfg

        # import pdb ; pdb.set_trace()

        if eval_cfg.with_ans:
            curr_save_dir = os.path.join(eval_cfg.save_dir, f"eval_only_with_ans")
        else:
            curr_save_dir = os.path.join(eval_cfg.save_dir, f"eval_only")
        Path(curr_save_dir).mkdir(parents=True, exist_ok=True)

        forget_rate = eval_cfg.split_list[-1].split('_')[0]

        # if int(os.environ.get('RANK', '0')) == 0:
        #     import pdb; pdb.set_trace()

        with torch.no_grad():
            for i, (folder, split, question_key, answer_key, eval_task, base_answer_key, perturbed_answer_key) in enumerate(zip(eval_cfg.data_path, eval_cfg.split_list, eval_cfg.question_key, eval_cfg.answer_key, eval_cfg.eval_task, eval_cfg.base_answer_key, eval_cfg.perturbed_answer_key)):
                
                world_size = self.accelerator.num_processes

                # For some reason, Hydra is not interprating the split correctly
                if eval_task == 'eval_log_forget':
                    split = eval_cfg.split
                print(f'Working on eval task {eval_task} with split {split}')
                save_filename = os.path.join(curr_save_dir, f"{eval_task}.json")
                save_numpy_filename = os.path.join(curr_save_dir, f"{eval_task}.npy")

                save_filename = save_filename if world_size == 1 else os.path.join(curr_save_dir, f"{eval_task}_{self.accelerator.local_process_index}.json")
                save_numpy_filename = save_numpy_filename if world_size == 1 else os.path.join(curr_save_dir, f"{eval_task}_{self.accelerator.local_process_index}.npy")
                if os.path.exists(save_filename) and not eval_cfg.overwrite:
                    print(f"Skipping {eval_task} because {save_filename} already exists")
                    continue

                eval_dataloader = get_single_dataloader(eval_cfg, eval_task, self.tokenizer, folder, split, question_key, answer_key, base_answer_key, perturbed_answer_key)

                eval_dataloader = self.accelerator.prepare(eval_dataloader)
                # print('dataset condition: ', len(eval_dataloader.dataset), self.accelerator.local_process_index)
                # base_eval_dataloader = self.accelerator.prepare(base_eval_dataloader)
                # perturb_dataloader = self.accelerator.prepare(perturb_dataloader)

                # if int(os.environ.get('RANK', '0')) == 0:
                #    import pdb; pdb.set_trace()

                eval_logs, hidden_states = get_single_evals(eval_cfg, model, self.tokenizer, eval_task, eval_dataloader, output_hidden_states=True, eval_cfg=eval_cfg)

                # last_token_hidden_states = []
                # for i in range(len(hidden_states)):
                #     last_token_hidden_states.append(hidden_states[i][:, -1, :])
                last_token_hidden_states = np.concatenate(hidden_states, axis=0)

                # if int(os.environ.get('RANK', '0')) == 0:
                #     import pdb; pdb.set_trace()
                
                # kl_divergence_log = get_kl_divergence(model, self.oracle_model, eval_dataloader)
                # eval_logs['kl_divergence'] = kl_divergence_log

                with open(save_filename, "w") as f:
                    # pretty write json to f
                    json.dump(eval_logs, f, indent=4)

                np.save(save_numpy_filename, last_token_hidden_states)

                # if int(os.environ.get('RANK', '0')) == 0:
                #     import pdb; pdb.set_trace()
            
                #wait for all process to finish
            self.accelerator.wait_for_everyone()
            aggregated_eval_logs = {}
            for eval_task in eval_cfg.eval_task:
                #read the saved file as json and merge them using merge_dicts

                if world_size > 1:
                    if self.accelerator.is_local_main_process:
                        eval_logs = json.load(open(os.path.join(curr_save_dir, f"{eval_task}_0.json")))

                        for i in range(1, world_size):
                            filename = os.path.join(curr_save_dir, f"{eval_task}_{i}.json")
                            eval_logs = merge_dicts(eval_logs, json.load(open(filename)))

                        process_last_token_hidden_states = []
                        for i in range(0, world_size):
                            numpy_filename = os.path.join(curr_save_dir, f"{eval_task}_{i}.npy")
                            last_token_hidden_states = np.load(numpy_filename)
                            process_last_token_hidden_states.append(last_token_hidden_states)
                        # import pdb ; pdb.set_trace()

                        new_numpy_save_filename = os.path.join(curr_save_dir, f"{eval_task}.npy")
                        process_last_token_hidden_states = np.concatenate(process_last_token_hidden_states, axis=0)
                        np.save(new_numpy_save_filename, process_last_token_hidden_states)

                        for i in range(world_size):
                            numpy_filename = os.path.join(curr_save_dir, f"{eval_task}_{i}.npy")
                            os.remove(numpy_filename)
                        
                        aggregated_eval_logs[f'{eval_task}.json'] = eval_logs

                        new_save_filename = os.path.join(curr_save_dir, f"{eval_task}.json")
                        with open(new_save_filename, "w") as f:
                            # pretty write json to f
                            json.dump(eval_logs, f, indent=4)
                            #delete old files use shutil
                            for i in range(world_size):
                                filename = os.path.join(curr_save_dir, f"{eval_task}_{i}.json")
                                os.remove(filename)

                    # eval_logs = json.load(open(os.path.join(curr_save_dir, f"{eval_task}.json")))
                    # aggregated_eval_logs[f'{eval_task}.json'] = eval_logs

                else:
                    if self.accelerator.is_local_main_process:
                        eval_logs = json.load(open(os.path.join(curr_save_dir, f"{eval_task}.json")))
                        aggregated_eval_logs[f'{eval_task}.json'] = eval_logs
                                
            if self.accelerator.is_local_main_process:

                # for i in range(1, world_size):
                #     numpy_filename = os.path.join(curr_save_dir, f"{eval_task}_{i}.json")

                aggregated_eval_logs = interleave_eval_result_dict(aggregated_eval_logs, forget_rate, large_bsz=eval_cfg.batch_size, num_processes=world_size)
                aggregated_eval_log_filename = os.path.join(curr_save_dir, "eval_log_aggregated.json")

                with open(aggregated_eval_log_filename, 'w') as f:
                    json.dump(aggregated_eval_logs, f, indent=4)
                
                model_utility = get_model_utility(aggregated_eval_logs, no_ratio=True)
                retain_result = json.load(open(eval_cfg.retain_result, 'r'))
                # forget_quality, trust_ratio = get_forget_quality(aggregated_eval_logs, retain_result)
                # aaggregate_stat = {**model_utility, **forget_quality}
                aaggregate_stat = {**model_utility}

                aaggregate_stat['curr_step'] = curr_step
                aaggregate_stat['seed'] = self.seed
                aaggregate_stat['loss_type'] = self.loss_type

                with open(os.path.join(curr_save_dir, "aggregate_stat.txt"), 'w') as txtfile:
                    for key, value in aaggregate_stat.items():
                        txtfile.write(f"{key}: {value}\n")

                # with open(os.path.join(curr_save_dir, "truth_ratio.pkl"), 'wb') as picklefile:
                #     pickle.dump(trust_ratio, picklefile)

    def evaluate_reft(
        self,
        eval_dataset=None,
        ignore_keys=None,
        metric_key_prefix="eval",
    ):
        """Evaluate a ReFT-style wrapped model on every task in ``self.eval_cfg``.

        RZ: call this function in ``Trainer.train()`` when evaluating the
        performance of each checkpoint.

        For each configured task the method builds a dataloader with
        ``get_single_dataloader``, runs ``get_reft_evals`` and writes the
        resulting logs (``<task>.json``) and concatenated hidden states
        (``<task>.npy``) into ``eval_only`` / ``eval_only_with_ans`` under
        ``eval_cfg.save_dir``. With more than one process every rank writes
        ``<task>_<rank>.*`` files, which the local main process merges and
        deletes after a barrier. The local main process finally writes
        ``eval_log_aggregated.json`` and ``aggregate_stat.txt``.

        This method does not cast the model for fp16/bf16 full evaluation and
        does not call ``eval()`` on it.

        Args:
            eval_dataset: Unused by this implementation.
            ignore_keys: Unused by this implementation.
            metric_key_prefix: Unused by this implementation.

        Returns:
            None. All results are written to disk.
        """
        args = self.args
        model = self._wrap_model(self.model, training=False, dataloader=None)

        print('####### Evaluating the model...... #######')
        print(self.is_in_train, args.device, model.model.dtype, self.args.dataloader_num_workers, self.eval_cfg.split_list)

        if len(self.accelerator._models) == 0 and model is self.model:
            # NOTE(review): in the non-DeepSpeed branch the prepared *outer*
            # model is assigned to ``model.model``; confirm this is intended.
            model.model = (
                self.accelerator.prepare(model.model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            if self.is_fsdp_enabled:
                self.model = model

            # for the rest of this function `model` is the outside model, whether it was wrapped or not
            if model is not self.model:
                self.model_wrapped = model

            # backward compatibility
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        curr_step = self.state.global_step
        eval_cfg = self.eval_cfg

        save_subdir = "eval_only_with_ans" if eval_cfg.with_ans else "eval_only"
        curr_save_dir = os.path.join(eval_cfg.save_dir, save_subdir)
        Path(curr_save_dir).mkdir(parents=True, exist_ok=True)

        forget_rate = eval_cfg.split_list[-1].split('_')[0]

        # Bound before the loop so the merge phase below can still read it when
        # the task list is empty.
        world_size = self.accelerator.num_processes
        # NOTE(review): per-rank files are keyed by the *local* process index
        # while the merge iterates over range(num_processes); this presumably
        # assumes a single-node run - confirm before using multiple nodes.
        rank_suffix = "" if world_size == 1 else f"_{self.accelerator.local_process_index}"

        with torch.no_grad():
            task_specs = zip(
                eval_cfg.data_path,
                eval_cfg.split_list,
                eval_cfg.question_key,
                eval_cfg.answer_key,
                eval_cfg.eval_task,
                eval_cfg.base_answer_key,
                eval_cfg.perturbed_answer_key,
            )
            for folder, split, question_key, answer_key, eval_task, base_answer_key, perturbed_answer_key in task_specs:

                # For some reason, Hydra is not interpreting the split correctly
                if eval_task == 'eval_log_forget':
                    split = eval_cfg.split
                print(f'Working on eval task {eval_task} with split {split}')

                save_filename = os.path.join(curr_save_dir, f"{eval_task}{rank_suffix}.json")
                save_numpy_filename = os.path.join(curr_save_dir, f"{eval_task}{rank_suffix}.npy")
                if os.path.exists(save_filename) and not eval_cfg.overwrite:
                    print(f"Skipping {eval_task} because {save_filename} already exists")
                    continue

                eval_dataloader = get_single_dataloader(eval_cfg, eval_task, self.tokenizer, folder, split, question_key, answer_key, base_answer_key, perturbed_answer_key)
                eval_dataloader = self.accelerator.prepare(eval_dataloader)

                eval_logs, hidden_states = get_reft_evals(eval_cfg, model, self.tokenizer, eval_task, eval_dataloader, output_hidden_states=True, eval_cfg=eval_cfg)

                # Join the per-batch hidden-state arrays along the first axis.
                last_token_hidden_states = np.concatenate(hidden_states, axis=0)

                with open(save_filename, "w") as f:
                    # pretty write json to f
                    json.dump(eval_logs, f, indent=4)

                np.save(save_numpy_filename, last_token_hidden_states)

            # wait for all processes to finish writing their per-rank files
            self.accelerator.wait_for_everyone()
            aggregated_eval_logs = {}
            for eval_task in eval_cfg.eval_task:
                # Only the local main process merges and collects results.
                if not self.accelerator.is_local_main_process:
                    continue

                merged_json_filename = os.path.join(curr_save_dir, f"{eval_task}.json")

                if world_size > 1:
                    rank_json_files = [
                        os.path.join(curr_save_dir, f"{eval_task}_{rank}.json")
                        for rank in range(world_size)
                    ]
                    rank_numpy_files = [
                        os.path.join(curr_save_dir, f"{eval_task}_{rank}.npy")
                        for rank in range(world_size)
                    ]

                    # Merge the per-rank logs in rank order using merge_dicts.
                    with open(rank_json_files[0]) as f:
                        eval_logs = json.load(f)
                    for filename in rank_json_files[1:]:
                        with open(filename) as f:
                            eval_logs = merge_dicts(eval_logs, json.load(f))

                    # Concatenate the per-rank hidden states in rank order.
                    merged_hidden_states = np.concatenate(
                        [np.load(numpy_filename) for numpy_filename in rank_numpy_files],
                        axis=0,
                    )
                    np.save(os.path.join(curr_save_dir, f"{eval_task}.npy"), merged_hidden_states)
                    for numpy_filename in rank_numpy_files:
                        os.remove(numpy_filename)

                    with open(merged_json_filename, "w") as f:
                        # pretty write json to f
                        json.dump(eval_logs, f, indent=4)
                    # The per-rank json files are no longer needed once merged.
                    for filename in rank_json_files:
                        os.remove(filename)
                else:
                    with open(merged_json_filename) as f:
                        eval_logs = json.load(f)

                aggregated_eval_logs[f'{eval_task}.json'] = eval_logs

            if self.accelerator.is_local_main_process:
                aggregated_eval_logs = interleave_eval_result_dict(aggregated_eval_logs, forget_rate, large_bsz=eval_cfg.batch_size, num_processes=world_size)
                aggregated_eval_log_filename = os.path.join(curr_save_dir, "eval_log_aggregated.json")

                with open(aggregated_eval_log_filename, 'w') as f:
                    json.dump(aggregated_eval_logs, f, indent=4)

                model_utility = get_model_utility(aggregated_eval_logs, no_ratio=True)
                # Only the (currently disabled) forget-quality computation
                # would consume this; it is still loaded so that a bad
                # ``retain_result`` path keeps failing here as before.
                with open(eval_cfg.retain_result, 'r') as f:
                    retain_result = json.load(f)
                aaggregate_stat = {**model_utility}

                aaggregate_stat['curr_step'] = curr_step
                aaggregate_stat['seed'] = self.seed
                aaggregate_stat['loss_type'] = self.loss_type

                with open(os.path.join(curr_save_dir, "aggregate_stat.txt"), 'w') as txtfile:
                    for key, value in aaggregate_stat.items():
                        txtfile.write(f"{key}: {value}\n")


class CustomTrainerRetraining(Trainer):
    """Trainer used for retraining, with a file-based multi-task evaluation.

    Batches are consumed as ``(input_ids, labels, attention_mask)`` triples.
    Two extra keyword arguments are required at construction time:

    * ``eval_cfg``: evaluation configuration read by :meth:`evaluate`.
    * ``seed``: base seed for the training dataloader's shuffling generator.
    """

    def __init__(self, *args, **kwargs):
        # The custom kwargs are removed before delegating to Trainer.__init__.
        self.eval_cfg = kwargs.pop('eval_cfg')
        self.seed = kwargs.pop('seed')
        super().__init__(*args, **kwargs)

    def get_train_dataloader(self):
        """Build the training dataloader with an explicitly seeded shuffle.

        The shuffling generator is seeded with ``self.seed`` plus the current
        ``global_step``. For non-iterable datasets the loader uses
        ``shuffle=True`` with that generator and no custom sampler.

        Returns:
            The ``DataLoader`` after ``self.accelerator.prepare``.

        Raises:
            ValueError: if ``self.train_dataset`` is ``None``.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        generator = torch.Generator()
        generator.manual_seed(self.seed + self.state.global_step)
        print(f'Generator........Epoch-{self.state.global_step}')

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["generator"] = generator
            dataloader_params["shuffle"] = True  # set shuffle=True with specified generator.
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker

        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

    def compute_loss(self, model, inputs, return_outputs=False):
        """Return the model's own loss for an ``(input_ids, labels, attention_mask)`` batch.

        Returns ``(loss, outputs)`` when ``return_outputs`` is true, otherwise
        just ``loss``.
        """
        input_ids, labels, attention_mask = inputs
        outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only: bool, ignore_keys=None):
        """Run a gradient-free forward pass and return ``(loss, logits, labels)``.

        ``prediction_loss_only`` and ``ignore_keys`` are not consulted.
        """
        input_ids, labels, attention_mask = inputs
        # forward pass
        with torch.no_grad():
            outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
            logits = outputs.logits
            loss = outputs.loss
        return (loss, logits, labels)

    def evaluate(
        self,
        eval_dataset=None,
        ignore_keys=None,
        metric_key_prefix="eval",
    ):
        """Run every task in ``self.eval_cfg`` and write the results to disk.

        Results go to ``checkpoint-<global_step>`` under ``eval_cfg.save_dir``.
        Each task's logs are written as ``<task>.json``; with more than one
        process every rank writes ``<task>_<rank>.json`` and the local main
        process merges and deletes them after a barrier. The local main
        process finally writes ``eval_log_aggregated.json``.

        Args:
            eval_dataset: Unused by this implementation.
            ignore_keys: Unused by this implementation.
            metric_key_prefix: Unused by this implementation.

        Returns:
            None. All results are written to disk.
        """
        args = self.args
        model = self._wrap_model(self.model, training=False, dataloader=None)

        print('####### Evaluating the model...... #######')

        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            if self.is_fsdp_enabled:
                self.model = model

            # for the rest of this function `model` is the outside model, whether it was wrapped or not
            if model is not self.model:
                self.model_wrapped = model

            # backward compatibility
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        model.eval()
        curr_step = self.state.global_step
        eval_cfg = self.eval_cfg

        curr_save_dir = os.path.join(eval_cfg.save_dir, f"checkpoint-{curr_step}")
        Path(curr_save_dir).mkdir(parents=True, exist_ok=True)

        forget_rate = eval_cfg.split.split('_')[0]

        # Bound before the loop so the merge phase below can still read it when
        # the task list is empty.
        world_size = self.accelerator.num_processes
        # NOTE(review): per-rank files are keyed by the *local* process index
        # while the merge iterates over range(num_processes); this presumably
        # assumes a single-node run - confirm before using multiple nodes.
        rank_suffix = "" if world_size == 1 else f"_{self.accelerator.local_process_index}"

        with torch.no_grad():
            task_specs = zip(
                eval_cfg.data_path,
                eval_cfg.split_list,
                eval_cfg.question_key,
                eval_cfg.answer_key,
                eval_cfg.eval_task,
                eval_cfg.base_answer_key,
                eval_cfg.perturbed_answer_key,
            )
            for folder, split, question_key, answer_key, eval_task, base_answer_key, perturbed_answer_key in task_specs:

                # For some reason, Hydra is not interpreting the split correctly
                if eval_task == 'eval_log_forget':
                    split = eval_cfg.split
                print(f'Working on eval task {eval_task} with split {split}')

                save_filename = os.path.join(curr_save_dir, f"{eval_task}{rank_suffix}.json")
                if os.path.exists(save_filename) and not eval_cfg.overwrite:
                    print(f"Skipping {eval_task} because {save_filename} already exists")
                    continue

                eval_dataloader, base_eval_dataloader, perturb_dataloader = get_dataloader(eval_cfg, eval_task, self.tokenizer, folder, split, question_key, answer_key, base_answer_key, perturbed_answer_key)
                eval_dataloader = self.accelerator.prepare(eval_dataloader)
                base_eval_dataloader = self.accelerator.prepare(base_eval_dataloader)
                perturb_dataloader = self.accelerator.prepare(perturb_dataloader)

                # The hidden states returned alongside the logs are not stored here.
                eval_logs, _ = get_all_evals(eval_cfg, model, self.tokenizer, eval_task, eval_dataloader, base_eval_dataloader, perturb_dataloader, output_hidden_states=True)
                with open(save_filename, "w") as f:
                    # pretty write json to f
                    json.dump(eval_logs, f, indent=4)

            # wait for all processes to finish writing their per-rank files
            self.accelerator.wait_for_everyone()
            aggregated_eval_logs = {}
            for eval_task in eval_cfg.eval_task:
                # Only the local main process merges and collects results.
                if not self.accelerator.is_local_main_process:
                    continue

                merged_json_filename = os.path.join(curr_save_dir, f"{eval_task}.json")

                if world_size > 1:
                    rank_json_files = [
                        os.path.join(curr_save_dir, f"{eval_task}_{rank}.json")
                        for rank in range(world_size)
                    ]

                    # Merge the per-rank logs in rank order using merge_dicts.
                    with open(rank_json_files[0]) as f:
                        eval_logs = json.load(f)
                    for filename in rank_json_files[1:]:
                        with open(filename) as f:
                            eval_logs = merge_dicts(eval_logs, json.load(f))

                    with open(merged_json_filename, "w") as f:
                        # pretty write json to f
                        json.dump(eval_logs, f, indent=4)
                    # The per-rank json files are no longer needed once merged.
                    for filename in rank_json_files:
                        os.remove(filename)
                else:
                    with open(merged_json_filename) as f:
                        eval_logs = json.load(f)

                aggregated_eval_logs[f'{eval_task}.json'] = eval_logs

            if self.accelerator.is_local_main_process:
                aggregated_eval_logs = interleave_eval_result_dict(aggregated_eval_logs, forget_rate, large_bsz=eval_cfg.batch_size, num_processes=world_size)
                aggregated_eval_log_filename = os.path.join(curr_save_dir, "eval_log_aggregated.json")

                with open(aggregated_eval_log_filename, 'w') as f:
                    json.dump(aggregated_eval_logs, f, indent=4)


def custom_data_collator_forget(samples):
    """Collate per-sample tuples of data types into one stacked batch per type.

    Every element of ``samples`` is a tuple with one entry per data type, and
    each entry is indexed positionally as ``(input_ids, labels,
    attention_mask)``. The layout is selected by the length of ``samples[0]``:

    * 2 -> ``forget, retain``
    * 3 -> ``forget, retain, mix_retain``
    * 5 -> ``forget, retain, forget_question, retain_question, mix_question``

    Args:
        samples: Non-empty sequence of per-sample tuples as described above.

    Returns:
        A list holding, for each data type in the order listed above, a tuple
        ``(input_ids, labels, attention_mask)`` of tensors produced by
        ``torch.stack`` over the batch.

    Raises:
        ValueError: if ``samples[0]`` has an unsupported number of entries.
    """
    layouts = {
        2: ("forget", "retain"),
        3: ("forget", "retain", "mix_retain"),
        5: ("forget", "retain", "forget_question", "retain_question", "mix_question"),
    }
    num_types = len(samples[0])
    if num_types not in layouts:
        raise ValueError(
            f"Unsupported number of data types per sample: {num_types} "
            f"(expected one of {sorted(layouts)})"
        )

    rets = []
    # The data type's name only documents its position inside each sample.
    for position in range(num_types):
        data = [sample[position] for sample in samples]
        input_ids = [s[0] for s in data]
        labels = [s[1] for s in data]
        attention_mask = [s[2] for s in data]
        rets.append((torch.stack(input_ids), torch.stack(labels), torch.stack(attention_mask)))
    return rets

def custom_data_collator_forget_wmdp(samples):
    """Collate per-sample tuples of mapping entries into one stacked batch per type.

    Every element of ``samples`` is a tuple with one entry per data type, and
    each entry is a mapping with the keys ``"input_ids"``, ``"labels"`` and
    ``"attention_mask"``. The layout is selected by the length of
    ``samples[0]``:

    * 2 -> ``forget, retain``
    * 3 -> ``forget, retain, mix_retain``

    Args:
        samples: Non-empty sequence of per-sample tuples as described above.

    Returns:
        A list holding, for each data type in the order listed above, a tuple
        ``(input_ids, labels, attention_mask)`` of tensors produced by
        ``torch.stack`` over the batch.

    Raises:
        ValueError: if ``samples[0]`` has an unsupported number of entries.
    """
    layouts = {
        2: ("forget", "retain"),
        3: ("forget", "retain", "mix_retain"),
    }
    num_types = len(samples[0])
    if num_types not in layouts:
        raise ValueError(
            f"Unsupported number of data types per sample: {num_types} "
            f"(expected one of {sorted(layouts)})"
        )

    rets = []
    # The data type's name only documents its position inside each sample.
    for position in range(num_types):
        data = [sample[position] for sample in samples]
        input_ids = [s["input_ids"] for s in data]
        labels = [s["labels"] for s in data]
        attention_mask = [s["attention_mask"] for s in data]
        rets.append((torch.stack(input_ids), torch.stack(labels), torch.stack(attention_mask)))
    return rets

def compute_metrics(pred):
    """Compute next-token accuracy and cross-entropy loss for a prediction object.

    ``pred.predictions`` (logits) and ``pred.label_ids`` are NumPy arrays. The
    argmax prediction at position ``t`` is compared with the label at position
    ``t + 1``; every position counts towards the mean, whatever its label.

    Returns:
        A dict with ``"eval accuracy"`` (a scalar tensor) and ``"eval loss"``
        (a Python float).
    """
    raw_logits = pred.predictions
    logits = torch.from_numpy(raw_logits)
    labels = torch.from_numpy(pred.label_ids)

    predicted_ids = torch.from_numpy(raw_logits.argmax(-1))
    targets = labels[..., 1:].contiguous()
    # Drop the final prediction: it has no following label to compare with.
    hits = (predicted_ids[..., :-1] == targets).float()
    accuracy = torch.mean(hits)

    loss_value = get_loss(logits, labels).item()
    return {"eval accuracy": accuracy, "eval loss": loss_value}

def get_loss(output, labels):
    """Return the mean next-token cross-entropy between logits and labels.

    The logits at position ``t`` are scored against the label at position
    ``t + 1``; label positions equal to ``-100`` are ignored.

    Args:
        output: Logits tensor whose last dimension is the class dimension and
            whose second-to-last dimension is the sequence position.
        labels: Integer label tensor whose last dimension is the sequence
            position.

    Returns:
        A scalar loss tensor.
    """
    targets = labels[..., 1:].contiguous()
    logits = output[..., :-1, :].contiguous()
    num_classes = logits.size(-1)

    # Flatten to (N, C) logits and (N,) targets for token-level cross-entropy.
    return F.cross_entropy(
        logits.view(-1, num_classes),
        targets.view(-1),
        ignore_index=-100,
    )