from transformers import Trainer, TrainingArguments
import torch.nn.functional as F
import torch
from torch.utils.data import Sampler, Dataset, RandomSampler, SequentialSampler
from typing import Optional
from copy import deepcopy
from typing import Iterator, Optional, Sized
import numpy as np
import random

class GroupRandomSampler(Sampler[int]):
    r"""Yield indices that interleave the ``loss_type`` groups of a dataset.

    The dataset is partitioned by each item's ``'loss_type'`` value. Let ``m``
    be the size of the smallest group and ``k_g = len(group_g) // m``. Every
    epoch shuffles the block ids ``0 .. m-1``; for each block id the groups are
    visited in sorted ``loss_type`` order and ``k_g`` indices are yielded from
    each:

    * group ``0`` draws its indices from a fresh per-epoch random permutation
      of that group (so a different random subset is used each epoch when its
      size is not a multiple of ``m``);
    * any other group yields the contiguous slice
      ``[block_id * k_g, (block_id + 1) * k_g)`` of its index list, so its
      trailing ``len(group_g) % m`` indices are never visited.

    Each epoch therefore yields ``m * sum(k_g)`` indices, which is what
    ``__len__`` reports.

    Args:
        data_source (Dataset): dataset whose items support ``item['loss_type']``
            with hashable, sortable values (presumably small ints; the trainers
            in this file use 0/1/2).
    """

    def __init__(self, data_source: Dataset) -> None:
        self.data_source = data_source
        # Index the groups once; the dataset is assumed not to change between epochs.
        groups = {}
        for idx, item in enumerate(data_source):
            groups.setdefault(item['loss_type'], []).append(idx)
        self.group_types = sorted(groups)
        self.groups = {group_type: groups[group_type] for group_type in self.group_types}
        if self.groups:
            self.min_length = min(len(indices) for indices in self.groups.values())
        else:
            self.min_length = 0
        # Integer block size per group; equals int(len / min_length) for these sizes.
        self.block_sizes = {
            group_type: len(indices) // self.min_length
            for group_type, indices in self.groups.items()
        }

    def __iter__(self) -> Iterator[int]:
        if not self.groups:
            return
        # Fresh seed per epoch, drawn from torch's global RNG (as RandomSampler does).
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        block_ids = list(range(self.min_length))
        # A private Random instance produces the same shuffle as reseeding the
        # global `random` module, without clobbering that state for other code.
        random.Random(seed).shuffle(block_ids)

        generator = torch.Generator()
        generator.manual_seed(seed)
        normal_group = self.groups.get(0, [])
        normal_order = torch.randperm(len(normal_group), generator=generator).tolist()

        normal_pos = -1
        for block_id in block_ids:
            for group_type in self.group_types:
                block_size = self.block_sizes[group_type]
                indices = self.groups[group_type]
                for offset in range(block_size):
                    if group_type == 0:
                        normal_pos = (normal_pos + 1) % len(normal_group)
                        yield normal_group[normal_order[normal_pos]]
                    else:
                        yield indices[block_id * block_size + offset]

    def __len__(self) -> int:
        # Must match the number of indices __iter__ yields, since DataLoader and
        # Trainer derive the number of steps per epoch from it.
        return self.min_length * sum(self.block_sizes.values())

class NPOTrainer(Trainer):
    """Trainer mixing SFT on normal samples with an NPO penalty on negative samples.

    Each batch must carry a ``loss_type`` tensor: ``0`` marks normal samples,
    ``1`` negative samples and ``2`` positive samples.

    Args:
        alpha: weight of the NPO term; the term is skipped when ``alpha <= 0``.
        beta: scale applied to the policy/reference log-ratio of negatives.
        theta: weight of the SFT loss on positive samples; skipped when ``<= 0``.
        reference_model: model queried under ``torch.no_grad()`` for reference
            log-probabilities; it is moved to the accelerator device.
    """

    def __init__(self, alpha, beta, theta, reference_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.beta = beta
        self.theta = theta
        self.reference_model = reference_model.to(self.accelerator.device)
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

    def compute_logps(self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
        """Mean log-probability of each sequence's response tokens.

        Args:
            prompt_attention_mask: 1 at attended prompt positions, 0 elsewhere.
            chosen_inputs: token ids of the full sequences.
            chosen_attention_mask: attention mask of the full sequences.
            logits: logits for ``chosen_inputs``; ``logits[:, t]`` scores token ``t + 1``.

        Returns:
            float64 tensor with one length-normalised log-probability per row
            (0 for rows that contain no response tokens).
        """
        # Response targets are the attended positions 1..T-1 that are not prompt.
        # Indexing the attention mask with [:, :-1] instead would misalign by one:
        # it scores the token after the last real one and goes negative under
        # left padding.
        mask = chosen_attention_mask[:, 1:] - prompt_attention_mask[:, 1:]
        per_token_logps = torch.gather(
            logits[:, :-1, :].log_softmax(-1),
            dim=2,
            index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
        ).squeeze(2)
        is_response = mask.bool()
        summed = per_token_logps.to(torch.float64).masked_fill(~is_response, 0.0).sum(dim=1)
        counts = is_response.sum(dim=1).clamp(min=1).to(torch.float64)
        return summed / counts

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return ``SFT(normal) [+ theta * SFT(positive)] - alpha * mean(logsigmoid(-beta * ratio))``.

        ``ratio`` is the policy minus reference mean log-probability of each
        negative sample. ``num_items_in_batch`` is accepted for compatibility
        with ``Trainer.training_step`` and not used. With ``return_outputs`` a
        ``(loss, {"logits": logits})`` pair is returned, as
        ``Trainer.prediction_step`` expects.
        """
        types = inputs.pop("loss_type")
        labels = inputs.get("labels")
        neg_idxs = types == 1
        pos_idxs = types == 2
        normal_idxs = types == 0

        loss, logits = compute_crossentropy_loss(model, inputs, mean=False, return_logits=True)

        normal_loss = get_subbatch_loss(loss, normal_idxs, labels)
        pos_loss = None
        loss_sft = normal_loss
        if pos_idxs.any() and self.theta > 0:
            pos_loss = self.theta * get_subbatch_loss(loss, pos_idxs, labels)
            loss_sft = normal_loss + pos_loss

        neg_prob = None
        dpo_loss = None
        total_loss = loss_sft
        if neg_idxs.any() and self.alpha > 0:
            attention_mask = inputs['attention_mask']
            input_ids = inputs['input_ids']
            # Prompt positions = attended positions whose label is masked out.
            prompt_attention_mask = attention_mask.clone()
            prompt_attention_mask[labels != -100] = 0

            log_probs = self.compute_logps(prompt_attention_mask, input_ids, attention_mask, logits)
            neg_prob = log_probs[neg_idxs]

            with torch.no_grad():
                _, ref_logits = compute_crossentropy_loss(
                    self.reference_model, inputs, mean=False, return_logits=True
                )
                ref_neg_prob = self.compute_logps(
                    prompt_attention_mask, input_ids, attention_mask, ref_logits
                )[neg_idxs]

            log_delta = -self.beta * (neg_prob - ref_neg_prob)
            # logsigmoid is the numerically stable form of log(sigmoid(x)).
            dpo_loss = F.logsigmoid(log_delta)
            total_loss = loss_sft - self.alpha * dpo_loss.mean()

        neg_mean = neg_prob.mean().item() if neg_prob is not None else 0.
        pos_value = pos_loss.item() if pos_loss is not None else 0.
        self.log({
            'Negative Geometric Mean': neg_mean,
            'DPO Loss': -self.alpha * dpo_loss.mean().item() if dpo_loss is not None else 0.,
            'Normal Loss': normal_loss.item(),
            'Positive Loss': pos_value,
            'Normal prob': -normal_loss.item(),
            'Positive prob': -pos_value,
            'Negative prob': neg_mean,
        })

        if return_outputs:
            return total_loss, {"logits": logits}
        return total_loss

    def _get_train_sampler(self, train_dataset=None) -> Optional[torch.utils.data.Sampler]:
        # Newer Trainer versions pass the dataset explicitly; older ones pass nothing.
        dataset = train_dataset if train_dataset is not None else self.train_dataset
        return GroupRandomSampler(dataset)

class DPOTrainer(Trainer):
    """Trainer combining SFT on normal samples with a DPO term on positive/negative rows.

    Batches carry a ``loss_type`` tensor: ``0`` normal, ``1`` negative
    (rejected), ``2`` positive (chosen). Positive and negative rows are compared
    element-wise, so each batch is assumed to hold them in matching number and
    order (TODO confirm against the dataset / sampler).

    Args:
        alpha: weight of the DPO term.
        beta: scale applied to the policy/reference log-ratio margin.
        reference_model: model queried under ``torch.no_grad()`` for reference
            log-probabilities; it is moved to the accelerator device.
    """

    def __init__(self, alpha, beta, reference_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.beta = beta
        self.reference_model = reference_model.to(self.accelerator.device)
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

    def compute_logps(self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
        """Mean log-probability of each sequence's response tokens.

        Args:
            prompt_attention_mask: 1 at attended prompt positions, 0 elsewhere.
            chosen_inputs: token ids of the full sequences.
            chosen_attention_mask: attention mask of the full sequences.
            logits: logits for ``chosen_inputs``; ``logits[:, t]`` scores token ``t + 1``.

        Returns:
            float64 tensor with one length-normalised log-probability per row
            (0 for rows that contain no response tokens).
        """
        # Response targets are the attended positions 1..T-1 that are not prompt.
        # Indexing the attention mask with [:, :-1] instead would misalign by one.
        mask = chosen_attention_mask[:, 1:] - prompt_attention_mask[:, 1:]
        per_token_logps = torch.gather(
            logits[:, :-1, :].log_softmax(-1),
            dim=2,
            index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
        ).squeeze(2)
        is_response = mask.bool()
        summed = per_token_logps.to(torch.float64).masked_fill(~is_response, 0.0).sum(dim=1)
        counts = is_response.sum(dim=1).clamp(min=1).to(torch.float64)
        return summed / counts

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """Return ``SFT(normal) - alpha * mean(logsigmoid(beta * margin))``.

        The DPO term is only added when the batch has both positive and
        negative rows. ``num_items_in_batch`` is accepted for compatibility
        with ``Trainer.training_step`` and not used. With ``return_outputs`` a
        ``(loss, {"logits": logits})`` pair is returned.
        """
        types = inputs.pop("loss_type")
        labels = inputs.get("labels")
        neg_idxs = types == 1
        pos_idxs = types == 2
        normal_idxs = types == 0

        loss, logits = compute_crossentropy_loss(model, inputs, mean=False, return_logits=True)
        loss_sft = get_subbatch_loss(loss, normal_idxs, labels)

        pos_prob = neg_prob = dpo_loss = None
        total_loss = loss_sft
        if pos_idxs.any() and neg_idxs.any():
            attention_mask = inputs['attention_mask']
            input_ids = inputs['input_ids']
            # Prompt positions = attended positions whose label is masked out.
            prompt_attention_mask = attention_mask.clone()
            prompt_attention_mask[labels != -100] = 0

            log_probs = self.compute_logps(prompt_attention_mask, input_ids, attention_mask, logits)
            pos_prob = log_probs[pos_idxs]
            neg_prob = log_probs[neg_idxs]

            with torch.no_grad():
                _, ref_logits = compute_crossentropy_loss(
                    self.reference_model, inputs, mean=False, return_logits=True
                )
                ref_log_probs = self.compute_logps(
                    prompt_attention_mask, input_ids, attention_mask, ref_logits
                )
            ref_pos_prob = ref_log_probs[pos_idxs]
            ref_neg_prob = ref_log_probs[neg_idxs]

            log_delta = self.beta * (pos_prob - ref_pos_prob - neg_prob + ref_neg_prob)
            # logsigmoid is the numerically stable form of log(sigmoid(x)).
            dpo_loss = F.logsigmoid(log_delta)
            total_loss = loss_sft - self.alpha * dpo_loss.mean()

        pos_mean = pos_prob.mean().item() if pos_prob is not None else 0.0
        neg_mean = neg_prob.mean().item() if neg_prob is not None else 0.0
        self.log(
            {
                "Positive Geometric Mean": pos_mean,
                "Negative Geometric Mean": neg_mean,
                "DPO Loss": dpo_loss.mean().item() if dpo_loss is not None else 0.0,
                "Normal prob": -loss_sft.item(),
                "Positive prob": pos_mean,
                "Negative prob": neg_mean,
            }
        )

        if return_outputs:
            return total_loss, {"logits": logits}
        return total_loss

    def _get_train_sampler(self, train_dataset=None) -> Optional[torch.utils.data.Sampler]:
        # Newer Trainer versions pass the dataset explicitly; older ones pass nothing.
        dataset = train_dataset if train_dataset is not None else self.train_dataset
        return GroupRandomSampler(dataset)

class DPOSFTTrainer(Trainer):
    """Trainer combining SFT on normal and positive samples with a DPO term.

    Batches carry a ``loss_type`` tensor: ``0`` normal, ``1`` negative
    (rejected), ``2`` positive (chosen). Positive and negative rows are compared
    element-wise, so each batch is assumed to hold them in matching number and
    order (TODO confirm against the dataset / sampler).

    Args:
        alpha: weight of the DPO term.
        beta: scale applied to the policy/reference log-ratio margin.
        theta: weight of the SFT loss on positive samples.
        reference_model: model queried under ``torch.no_grad()`` for reference
            log-probabilities; it is moved to the accelerator device.
    """

    def __init__(self, alpha, beta, theta, reference_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.beta = beta
        self.theta = theta
        self.reference_model = reference_model.to(self.accelerator.device)
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

    def compute_logps(self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
        """Mean log-probability of each sequence's response tokens.

        Args:
            prompt_attention_mask: 1 at attended prompt positions, 0 elsewhere.
            chosen_inputs: token ids of the full sequences.
            chosen_attention_mask: attention mask of the full sequences.
            logits: logits for ``chosen_inputs``; ``logits[:, t]`` scores token ``t + 1``.

        Returns:
            float64 tensor with one length-normalised log-probability per row
            (0 for rows that contain no response tokens).
        """
        # Response targets are the attended positions 1..T-1 that are not prompt.
        # Indexing the attention mask with [:, :-1] instead would misalign by one.
        mask = chosen_attention_mask[:, 1:] - prompt_attention_mask[:, 1:]
        per_token_logps = torch.gather(
            logits[:, :-1, :].log_softmax(-1),
            dim=2,
            index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
        ).squeeze(2)
        is_response = mask.bool()
        summed = per_token_logps.to(torch.float64).masked_fill(~is_response, 0.0).sum(dim=1)
        counts = is_response.sum(dim=1).clamp(min=1).to(torch.float64)
        return summed / counts

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return ``SFT(normal) - alpha * mean(logsigmoid(beta * margin)) + theta * SFT(positive)``.

        The positive SFT term needs positive rows and the DPO term needs both
        positive and negative rows; missing terms are skipped instead of
        contributing NaN. ``num_items_in_batch`` is accepted for compatibility
        with ``Trainer.training_step`` and not used.
        """
        types = inputs.pop("loss_type")
        labels = inputs.get("labels")
        neg_idxs = types == 1
        pos_idxs = types == 2
        normal_idxs = types == 0

        loss, logits = compute_crossentropy_loss(model, inputs, mean=False, return_logits=True)

        loss_sft = get_subbatch_loss(loss, normal_idxs, labels)
        total_loss = loss_sft

        loss_pos_sft = None
        if pos_idxs.any():
            loss_pos_sft = get_subbatch_loss(loss, pos_idxs, labels)
            total_loss = total_loss + self.theta * loss_pos_sft

        pos_prob = neg_prob = dpo_loss = None
        if pos_idxs.any() and neg_idxs.any():
            attention_mask = inputs['attention_mask']
            input_ids = inputs['input_ids']
            # Prompt positions = attended positions whose label is masked out.
            prompt_attention_mask = attention_mask.clone()
            prompt_attention_mask[labels != -100] = 0

            log_probs = self.compute_logps(prompt_attention_mask, input_ids, attention_mask, logits)
            pos_prob = log_probs[pos_idxs]
            neg_prob = log_probs[neg_idxs]

            with torch.no_grad():
                _, ref_logits = compute_crossentropy_loss(
                    self.reference_model, inputs, mean=False, return_logits=True
                )
                ref_log_probs = self.compute_logps(
                    prompt_attention_mask, input_ids, attention_mask, ref_logits
                )
            ref_pos_prob = ref_log_probs[pos_idxs]
            ref_neg_prob = ref_log_probs[neg_idxs]

            log_delta = self.beta * (pos_prob - ref_pos_prob - neg_prob + ref_neg_prob)
            # logsigmoid is the numerically stable form of log(sigmoid(x)).
            dpo_loss = F.logsigmoid(log_delta)
            total_loss = total_loss - self.alpha * dpo_loss.mean()

        self.log({
            'Positive Geometric Mean': pos_prob.mean().item() if pos_prob is not None else 0.,
            'Negative Geometric Mean': neg_prob.mean().item() if neg_prob is not None else 0.,
            'DPO Loss': -self.alpha * dpo_loss.mean().item() if dpo_loss is not None else 0.,
            'Normal Loss': loss_sft.item(),
            'Positive Loss': self.theta * loss_pos_sft.item() if loss_pos_sft is not None else 0.,
        })

        if return_outputs:
            return total_loss, {"logits": logits}
        return total_loss

    def _get_train_sampler(self, train_dataset=None) -> Optional[torch.utils.data.Sampler]:
        # Newer Trainer versions pass the dataset explicitly; older ones pass nothing.
        dataset = train_dataset if train_dataset is not None else self.train_dataset
        return GroupRandomSampler(dataset)

class ORPOTrainer(Trainer):
    """Trainer combining SFT on normal samples with the ORPO odds-ratio term.

    Batches carry a ``loss_type`` tensor: ``0`` normal, ``1`` negative
    (rejected), ``2`` positive (chosen). Positive and negative rows are compared
    element-wise; the sequential sampler keeps dataset order, so the dataset is
    assumed to place them in matching pairs (TODO confirm).

    Args:
        alpha: weight of the odds-ratio term.
    """

    def __init__(self, alpha, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

    def compute_custom_loss(self, logits, labels):
        """Per-sequence next-token cross-entropy, averaged over all shifted positions.

        Ignored (``-100``) positions contribute 0 but still count in the average.

        Raises:
            ValueError: if ``labels`` is None.
        """
        if labels is None:
            raise ValueError("compute_custom_loss requires labels")
        logits = logits.contiguous()
        # Move labels to the logits' device to support model parallelism.
        labels = labels.to(logits.device)
        # Shift so that tokens < n predict n.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        return self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(dim=-1)

    def compute_logps(self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
        """Mean log-probability of each sequence's response tokens.

        Args:
            prompt_attention_mask: 1 at attended prompt positions, 0 elsewhere.
            chosen_inputs: token ids of the full sequences.
            chosen_attention_mask: attention mask of the full sequences.
            logits: logits for ``chosen_inputs``; ``logits[:, t]`` scores token ``t + 1``.

        Returns:
            float64 tensor with one length-normalised log-probability per row
            (0 for rows that contain no response tokens).
        """
        # Response targets are the attended positions 1..T-1 that are not prompt.
        # Indexing the attention mask with [:, :-1] instead would misalign by one.
        mask = chosen_attention_mask[:, 1:] - prompt_attention_mask[:, 1:]
        per_token_logps = torch.gather(
            logits[:, :-1, :].log_softmax(-1),
            dim=2,
            index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
        ).squeeze(2)
        is_response = mask.bool()
        summed = per_token_logps.to(torch.float64).masked_fill(~is_response, 0.0).sum(dim=1)
        counts = is_response.sum(dim=1).clamp(min=1).to(torch.float64)
        return summed / counts

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return ``SFT(normal) - alpha * mean(logsigmoid(log_odds))``.

        The SFT term is token-averaged with ``get_subbatch_loss`` so that
        ignored positions do not dilute it. The odds-ratio term needs both
        positive and negative rows and is skipped otherwise.
        ``num_items_in_batch`` is accepted for compatibility with
        ``Trainer.training_step`` and not used.
        """
        types = inputs.pop("loss_type")
        labels = inputs.get("labels")
        neg_idxs = types == 1
        pos_idxs = types == 2
        normal_idxs = types == 0

        loss, logits = compute_crossentropy_loss(model, inputs, mean=False, return_logits=True)
        loss_sft = get_subbatch_loss(loss, normal_idxs, labels)

        pos_prob = neg_prob = ratio = log_odds = None
        total_loss = loss_sft
        if pos_idxs.any() and neg_idxs.any():
            attention_mask = inputs['attention_mask']
            # Prompt positions = attended positions whose label is masked out.
            prompt_attention_mask = attention_mask.clone()
            prompt_attention_mask[labels != -100] = 0

            log_probs = self.compute_logps(prompt_attention_mask, inputs['input_ids'], attention_mask, logits)
            pos_prob = log_probs[pos_idxs]
            neg_prob = log_probs[neg_idxs]
            # log(1 - p) computed as log1p(-exp(log p)) for numerical stability.
            log_odds = (pos_prob - neg_prob) - (
                torch.log1p(-torch.exp(pos_prob)) - torch.log1p(-torch.exp(neg_prob))
            )
            ratio = F.logsigmoid(log_odds)
            total_loss = loss_sft - self.alpha * ratio.mean()

        self.log({
            'Positive Geometric Mean': pos_prob.mean().item() if pos_prob is not None else 0.,
            'Negative Geometric Mean': neg_prob.mean().item() if neg_prob is not None else 0.,
            'Log Odds Ratio': ratio.mean().item() if ratio is not None else 0.,
            'Log Odds': log_odds.mean().item() if log_odds is not None else 0.,
        })

        if return_outputs:
            return total_loss, {"logits": logits}
        return total_loss

    def _get_train_sampler(self, train_dataset=None) -> Optional[torch.utils.data.Sampler]:
        # Newer Trainer versions pass the dataset explicitly; older ones pass nothing.
        dataset = train_dataset if train_dataset is not None else self.train_dataset
        return SequentialSampler(dataset)

class GD_Trainer(Trainer):
    """Plain gradient-descent fine-tuning with a token-averaged cross-entropy loss."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the mean cross-entropy over all non-ignored target tokens.

        The previous ``mean=False`` call returned a per-token (batch, seq - 1)
        tensor, which ``backward()`` rejects because it is not a scalar.
        ``num_items_in_batch`` is accepted for compatibility with
        ``Trainer.training_step`` and not used.
        """
        # Datasets shared with the other trainers carry a loss_type column the
        # model forward does not expect.
        inputs.pop("loss_type", None)
        loss, logits = compute_crossentropy_loss(model, inputs, mean=True, return_logits=True)
        if return_outputs:
            return loss, {"logits": logits}
        return loss

class AlignSFT_Trainer(Trainer):
    """SFT on normal samples plus a weighted SFT loss on alignment samples.

    Batches carry a ``loss_type`` tensor: ``0`` normal, ``2`` alignment.

    Args:
        alpha: weight for the alignment loss.
    """

    def __init__(self, alpha, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return ``SFT(normal) + alpha * SFT(alignment)``, each token-averaged.

        Token averaging uses ``get_subbatch_loss``. A plain ``.mean()`` over the
        per-token loss would also count ignored (``-100``) positions and so
        under-weight long responses. A group absent from the batch contributes
        0 instead of NaN. ``num_items_in_batch`` is accepted for compatibility
        with ``Trainer.training_step`` and not used.
        """
        types = inputs.pop("loss_type")
        labels = inputs.get("labels")
        align_idxs = types == 2
        normal_idxs = types == 0

        # A single forward pass provides both the per-token losses and the logits.
        loss, logits = compute_crossentropy_loss(model, inputs, mean=False, return_logits=True)
        # Graph-connected zero so backward still works when a group is absent.
        zero = loss.sum() * 0.0

        loss_align = get_subbatch_loss(loss, align_idxs, labels) * self.alpha if align_idxs.any() else zero
        loss_normal = get_subbatch_loss(loss, normal_idxs, labels) if normal_idxs.any() else zero
        adjusted_loss = loss_align + loss_normal

        self.log({'loss_align': loss_align.item(), "loss_normal": loss_normal.item()})
        if return_outputs:
            return adjusted_loss, {"logits": logits}
        return adjusted_loss

    def _get_train_sampler(self, train_dataset=None) -> Optional[torch.utils.data.Sampler]:
        # Newer Trainer versions pass the dataset explicitly; older ones pass nothing.
        dataset = train_dataset if train_dataset is not None else self.train_dataset
        return GroupRandomSampler(dataset)

class GA_GD_Trainer(Trainer):
    """Gradient ascent on forget samples plus gradient descent on normal samples.

    Batches carry a ``loss_type`` tensor: ``1`` GA (forget) samples and ``0``
    GD (normal) samples.

    Args:
        pretrain_model: accepted for interface compatibility; not used.
        theta_GA: weight of the (negated) GA loss.
        theta_GD: weight of the GD loss.
        eos_index: if not None, targets equal to this token id are excluded
            from the GA loss.
    """

    def __init__(self, pretrain_model=None, theta_GA=1.0, theta_GD=1.0, eos_index=None, **kwargs):
        super().__init__(**kwargs)
        self.theta_GA = theta_GA
        self.theta_GD = theta_GD
        self.eos_index = eos_index

    def _exclude_eos(self, token_loss, labels):
        """Drop EOS targets from a per-token loss and from its token count.

        Returns the inputs unchanged when ``eos_index`` is None. Otherwise EOS
        label positions become ``-100`` and the matching shifted loss entries
        become 0, so ``get_subbatch_loss`` ignores them.
        """
        if self.eos_index is None:
            return token_loss, labels
        is_eos = labels == self.eos_index
        return token_loss.masked_fill(is_eos[..., 1:], 0.0), labels.masked_fill(is_eos, -100)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return ``-theta_GA * SFT(GA rows) + theta_GD * SFT(GD rows)``.

        Both terms are token-averaged with ``get_subbatch_loss``. A group
        absent from the batch contributes 0 instead of NaN.
        ``num_items_in_batch`` is accepted for compatibility with
        ``Trainer.training_step`` and not used.
        """
        types = inputs.pop("loss_type")
        labels = inputs.get("labels")
        ga_idxs = types == 1
        gd_idxs = types == 0

        # A single forward pass provides both the per-token losses and the logits.
        loss_origin, logits = compute_crossentropy_loss(model, inputs, mean=False, return_logits=True)
        ga_token_loss, ga_labels = self._exclude_eos(loss_origin, labels)
        # Graph-connected zero so backward still works when a group is absent.
        zero = loss_origin.sum() * 0.0

        loss_ga = (
            get_subbatch_loss(ga_token_loss, ga_idxs, ga_labels) * -self.theta_GA
            if ga_idxs.any() else zero
        )
        loss_gd_normal = (
            get_subbatch_loss(loss_origin, gd_idxs, labels) * self.theta_GD
            if gd_idxs.any() else zero
        )
        adjusted_loss = loss_ga + loss_gd_normal

        self.log({'loss_ga_harmful': loss_ga.item(), "loss_gd_normal": loss_gd_normal.item()})
        if return_outputs:
            return adjusted_loss, {"logits": logits}
        return adjusted_loss

    def _get_train_sampler(self, train_dataset=None) -> Optional[torch.utils.data.Sampler]:
        # Newer Trainer versions pass the dataset explicitly; older ones pass nothing.
        dataset = train_dataset if train_dataset is not None else self.train_dataset
        return RandomSampler(dataset)

class GA_GD_GD_Trainer(Trainer):
    """Gradient ascent on harmful responses with GD on refusals and on benign data.

    ``loss_type`` values: ``0`` benign query + benign response (weighted by
    ``theta_KL``), ``1`` harmful query + harmful response (gradient ascent,
    weighted by ``theta_GA``), ``2`` harmful query + benign response (weighted
    by ``theta_GD``).

    Args:
        theta_GA, theta_GD, theta_KL: weights of the three terms.
        eos_index: stored but not used by ``compute_loss``.
        ga_limit: if set, GA rows whose token-averaged loss exceeds it are
            dropped from the optimised GA term (see ``get_subbatch_loss``).
    """

    def __init__(self, theta_GA=1.0, theta_GD=1.0, theta_KL=1.0, eos_index=None, ga_limit = None, **kwargs):
        super().__init__(**kwargs)
        self.theta_GA = theta_GA
        self.theta_GD = theta_GD
        self.theta_KL = theta_KL
        self.eos_index = eos_index
        self.ga_limit = ga_limit

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the sum of the finite GA, harmful-GD and normal-GD terms.

        ``get_subbatch_loss`` yields NaN for an empty sub-batch (or when every
        GA row exceeds ``ga_limit``). Every such term is dropped, not only one
        of them. ``num_items_in_batch`` is accepted for compatibility with
        ``Trainer.training_step`` and not used.
        """
        types = inputs.pop("loss_type")
        labels = inputs.get("labels")
        ga_idxs = types == 1
        harmful_gd_idxs = types == 2
        normal_gd_idxs = types == 0

        # A single forward pass provides both the per-token losses and the logits;
        # an extra model(**inputs) call would only waste compute and memory.
        loss_origin, logits = compute_crossentropy_loss(model, inputs, mean=False, return_logits=True)

        loss_ga = get_subbatch_loss(loss_origin, ga_idxs, labels, self.ga_limit) * -self.theta_GA
        # Unfiltered GA loss, for logging only.
        real_loss_ga = get_subbatch_loss(loss_origin, ga_idxs, labels) * -self.theta_GA
        loss_harmful_gd = get_subbatch_loss(loss_origin, harmful_gd_idxs, labels) * self.theta_GD
        loss_normal_gd = get_subbatch_loss(loss_origin, normal_gd_idxs, labels) * self.theta_KL

        finite_terms = [
            term for term in (loss_ga, loss_harmful_gd, loss_normal_gd)
            if not torch.isnan(term)
        ]
        if finite_terms:
            adjusted_loss = sum(finite_terms)
        else:
            # Graph-connected zero so backward still works on a degenerate batch.
            adjusted_loss = loss_origin.sum() * 0.0

        self.log({'loss_ga': real_loss_ga.item(), 'loss_normal_gd': loss_normal_gd.item(), 'loss_harmful_gd': loss_harmful_gd.item()})

        if return_outputs:
            return adjusted_loss, {"logits": logits}
        return adjusted_loss

    def _get_train_sampler(self, train_dataset=None) -> Optional[torch.utils.data.Sampler]:
        # Newer Trainer versions pass the dataset explicitly; older ones pass nothing.
        dataset = train_dataset if train_dataset is not None else self.train_dataset
        return GroupRandomSampler(dataset)

class GA_GD_KL_Trainer(Trainer):
    """Gradient ascent, gradient descent and a KL-style retain term.

    ``loss_type`` values: ``0`` benign query + benign response (KL term
    against ``pretrain_model``), ``1`` harmful query + harmful response
    (gradient ascent), ``2`` harmful query + benign response (gradient descent).

    Args:
        pretrain_model: reference model for ``compute_kl``; required. It is
            moved to the accelerator device.
        theta_GA, theta_GD, theta_KL: weights of the three terms.
        eos_index: if not None, targets equal to this token id are excluded
            from the GA loss.

    Raises:
        ValueError: if ``pretrain_model`` is None.
    """

    def __init__(self, pretrain_model=None, theta_GA=1.0, theta_GD=1.0, theta_KL=1.0, eos_index=None, **kwargs):
        if pretrain_model is None:
            raise ValueError("GA_GD_KL_Trainer requires a pretrain_model for the KL term")
        super().__init__(**kwargs)
        pretrain_model.to(self.accelerator.device)
        self.pretrain_model = pretrain_model
        self.theta_GA = theta_GA
        self.theta_GD = theta_GD
        self.theta_KL = theta_KL
        self.eos_index = eos_index

    def _exclude_eos(self, token_loss, labels):
        """Drop EOS targets from a per-token loss and from its token count.

        Returns the inputs unchanged when ``eos_index`` is None. Otherwise EOS
        label positions become ``-100`` and the matching shifted loss entries
        become 0, so ``get_subbatch_loss`` ignores them.
        """
        if self.eos_index is None:
            return token_loss, labels
        is_eos = labels == self.eos_index
        return token_loss.masked_fill(is_eos[..., 1:], 0.0), labels.masked_fill(is_eos, -100)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return ``-theta_GA * SFT(GA) + theta_GD * SFT(GD) + theta_KL * KL(normal)``.

        The SFT terms are token-averaged with ``get_subbatch_loss``. A group
        absent from the batch contributes 0 instead of NaN.
        ``num_items_in_batch`` is accepted for compatibility with
        ``Trainer.training_step`` and not used.
        """
        types = inputs.pop("loss_type")
        labels = inputs.get("labels")
        ga_idxs = types == 1
        gd_idxs = types == 2
        kl_idxs = types == 0

        kl_inputs = {k: v[kl_idxs] for k, v in inputs.items()}

        loss_origin, logits = compute_crossentropy_loss(model, inputs, mean=False, return_logits=True)
        ga_token_loss, ga_labels = self._exclude_eos(loss_origin, labels)
        # Graph-connected zero so backward still works when a group is absent.
        zero = loss_origin.sum() * 0.0

        loss_ga = (
            get_subbatch_loss(ga_token_loss, ga_idxs, ga_labels) * -self.theta_GA
            if ga_idxs.any() else zero
        )
        loss_gd = (
            get_subbatch_loss(loss_origin, gd_idxs, labels) * self.theta_GD
            if gd_idxs.any() else zero
        )
        if kl_idxs.any():
            loss_kl = compute_kl(self.pretrain_model, logits[kl_idxs], kl_inputs) * self.theta_KL
        else:
            loss_kl = zero

        adjusted_loss = loss_ga + loss_gd + loss_kl

        self.log({'loss_ga': loss_ga.item(), 'loss_gd': loss_gd.item(), 'loss_kl': loss_kl.item()})

        if return_outputs:
            return adjusted_loss, {"logits": logits}
        return adjusted_loss

def get_subbatch_loss(loss, idxs, labels, limit=None):
    """Token-averaged loss over the batch rows selected by ``idxs``.

    Args:
        loss: per-token losses with one row per sample, expected to be aligned
            with ``labels[..., 1:]`` (as produced by
            ``compute_crossentropy_loss(..., mean=False)``).
        idxs: selector over the batch dimension (typically a boolean mask).
        labels: unshifted labels; shifted positions equal to ``-100`` are not
            counted as tokens.
        limit: optional threshold. Selected rows whose own token-averaged loss
            is not ``<= limit`` are dropped before averaging.

    Returns:
        ``sum(loss) / counted_tokens`` over the kept rows. This is NaN (0 / 0)
        when no rows remain, and callers rely on that to detect empty groups.
    """
    sub_loss = loss[idxs]
    token_counts = (labels[idxs][..., 1:] != -100).float().sum(dim=1)
    if limit is not None:
        keep = sub_loss.sum(dim=1) / token_counts <= limit
        sub_loss, token_counts = sub_loss[keep], token_counts[keep]
    return sub_loss.sum() / token_counts.sum()

def compute_crossentropy_loss(model, batch, mean=True, return_logits=False, eos_token_id=None):
    """Run ``model(**batch)`` and compute the shifted next-token cross-entropy.

    Args:
        model: callable returning an object with a ``.logits`` attribute.
        batch: keyword arguments for ``model``; must contain ``"labels"``
            (positions equal to ``-100`` are ignored and get 0 loss).
        mean: when ``eos_token_id`` is None, return the loss averaged over
            non-ignored target tokens instead of the per-token matrix.
        return_logits: also return the model logits (last element).
        eos_token_id: if not None, additionally return an EOS-masked per-token
            loss in which targets equal to this id are zeroed. ``mean`` is
            ignored in this mode. ``is not None`` is used so that token id 0 works.

    Returns:
        * ``eos_token_id`` given: ``(loss, eos_loss)`` or
          ``(loss, eos_loss, logits)``, with per-token matrices of shape
          ``(batch, seq_len - 1)``.
        * otherwise: ``loss`` (scalar if ``mean`` else per-token matrix), or
          ``(loss, logits)`` when ``return_logits``.
    """
    outputs = model(**batch)
    logits = outputs.logits
    labels = batch.get("labels")

    # Position t predicts token t + 1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    loss = loss_fct(
        shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)
    ).view(shift_logits.size(0), -1)

    if eos_token_id is not None:
        # Apply the EOS mask to the already computed per-token loss; without
        # this, "eos_loss" was just a second identical cross-entropy pass.
        eos_loss = loss.masked_fill(shift_labels == eos_token_id, 0.0)
        if return_logits:
            return loss, eos_loss, logits
        return loss, eos_loss

    if mean:
        valid_counts = (shift_labels != -100).sum().float()
        loss = loss.sum() / valid_counts
    if return_logits:
        return loss, logits
    return loss

def compute_kl(pretrained_model, logits, batch):
    """
    Token-level cross-entropy H(P, Q) between the pretrained model P and the current model Q.

    H(P, Q) equals the forward KL(P || Q) plus the entropy of P. P does not
    depend on the current model, so both have the same gradient and this serves
    as the forward-KL utility loss.

    Args:
        pretrained_model: reference (original) model, called under
            ``torch.no_grad()`` with ``input_ids`` and ``attention_mask``.
        logits: logits given by the current unlearning model for ``batch``;
            ``logits[:, t]`` scores token ``t + 1``.
        batch: a batch of normal data with ``input_ids`` and ``labels``
            (``-100`` = ignored), and optionally ``attention_mask``.

    Returns:
        Scalar loss averaged over supervised target tokens (0 if there are none).
    """
    input_ids = batch["input_ids"]
    labels = batch["labels"]
    attention_mask = batch.get("attention_mask")

    with torch.no_grad():
        # Pass the attention mask so padding is treated as it was for `logits`.
        pretrained_logits = pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).logits

    # Align logits[:, t] with target labels[:, t + 1], matching the shifted
    # cross-entropy used for the other loss terms.
    target_mask = labels[:, 1:] != -100
    prob_p = torch.softmax(pretrained_logits[:, :-1].float(), dim=-1)
    # log_softmax avoids the log(p + eps) underflow workaround.
    log_q = torch.log_softmax(logits[:, :-1].float(), dim=-1)

    per_token = -(prob_p * log_q).sum(dim=-1)
    token_count = target_mask.sum().clamp(min=1)
    return per_token.masked_fill(~target_mask, 0.0).sum() / token_count
