import os
from pytorch_lightning.core.optimizer import LightningOptimizer
import torch
import pickle
from typing import Any, Dict
import pytorch_lightning as L
from torch.optim.optimizer import Optimizer
from transformers import AutoModelForCausalLM
from transformers.models.llama import LlamaForCausalLM
from transformers.optimization import get_scheduler
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
import copy
from deepspeed.ops.adam import DeepSpeedCPUAdam
from bitsandbytes.optim import PagedAdam as bsAdamW
from lightning.pytorch.utilities import grad_norm
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
import numpy as np


from ..peft_util import find_all_linear_names
from ..language_models import ContrastLLM, init_small_huggingface_llm, AssistedModel
from ..utils import NameTimer

def get_dtype(data_type):
    """Map a dtype name from the training config to a ``torch.dtype``.

    Args:
        data_type: ``'bfloat16'``, ``'float16'`` or ``'float32'``. A
            ``torch.dtype`` is returned unchanged and ``None`` is returned as
            ``None`` (dtype left to the caller / library default).

    Returns:
        The matching ``torch.dtype`` (or ``None``).

    Raises:
        ValueError: for any other value, instead of silently returning ``None``
            and letting a typo fall back to the default dtype.
    """
    if data_type is None or isinstance(data_type, torch.dtype):
        return data_type
    known = {
        'bfloat16': torch.bfloat16,
        'float16': torch.float16,
        'float32': torch.float32,
    }
    try:
        return known[data_type]
    except KeyError:
        raise ValueError(
            f"Unsupported data_type {data_type!r}; expected one of {sorted(known)}"
        ) from None
 
def print_trainable_parameters(model):
    """
    Print how many of ``model``'s parameter elements are trainable.

    Elements are counted with ``numel()`` over ``model.named_parameters()``; a
    parameter counts as trainable when ``requires_grad`` is set. A model with
    no parameters reports ``0.0`` percent instead of raising
    ``ZeroDivisionError``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # Guard the percentage: an empty module would otherwise divide by zero.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )

def NextTokenPredictionLoss(model, input_ids, attention_mask, labels, position_ids=None):
    """Call ``model`` with ``labels`` and return the ``.loss`` it reports.

    The model is invoked with keyword arguments ``input_ids``,
    ``attention_mask``, ``labels`` and ``position_ids``; its return value must
    expose a ``loss`` attribute.

    Raises:
        AssertionError: if the reported loss is ``None``.
    """
    call_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        position_ids=position_ids,
    )
    reported = model(**call_kwargs).loss
    assert reported is not None, "Forget loss is None"
    return reported

def TokenNextTokenPredictionLoss(output, labels):
    """Return the next-token cross-entropy summed over each sequence.

    Args:
        output: logits whose last two dims are (sequence, vocab); position
            ``t`` predicts label ``t + 1``.
        labels: token ids aligned with ``output``'s sequence dim; ``-100``
            entries are ignored.

    Returns:
        Per-sequence summed cross-entropy (the sequence dim is reduced, any
        leading dims are kept).
    """
    # Shift so logits at step t are scored against the label at step t + 1.
    logits = output[..., :-1, :].contiguous()
    targets = labels[..., 1:].contiguous()
    # cross_entropy wants the class dim second, hence the transpose.
    token_losses = F.cross_entropy(
        logits.transpose(-1, -2), targets, ignore_index=-100, reduction='none'
    )
    return token_losses.sum(dim=-1)

 
def split_forget_retain(input_ids, attention_mask, labels=None, position_ids=None, retainlabels=None):
    """Split a mixed batch into its forget and retain sub-batches.

    Args:
        input_ids, attention_mask, labels, position_ids: batch tensors indexed
            along dim 0. ``labels`` / ``position_ids`` may be ``None`` and then
            stay ``None`` in both halves.
        retainlabels: optional per-example flags; rows equal to 0 go to the
            forget split, rows equal to 1 go to the retain split.

    Returns:
        With ``retainlabels``: a pair of 4-tuples
        ``((forget_input_ids, forget_attention_mask, forget_labels, forget_position_ids),
        (retain_input_ids, retain_attention_mask, retain_labels, retain_position_ids))``.
        Without it: the four inputs unchanged, as one 4-tuple.
    """
    if retainlabels is None:
        return input_ids, attention_mask, labels, position_ids

    def _select(row_mask):
        # Boolean-mask every tensor with the same rows; None passes through.
        return tuple(
            None if tensor is None else tensor[row_mask]
            for tensor in (input_ids, attention_mask, labels, position_ids)
        )

    # Every field, position_ids included, uses the boolean `== 0` / `== 1` mask;
    # indexing with `retainlabels` itself would gather rows 0/1 by integer index.
    return _select(retainlabels == 0), _select(retainlabels == 1)

    
def init_full_model(model_path, num_layer=0 ,data_type='bfloat16', **kwargs):
    """Load the causal LM at ``model_path``, optionally shrunk to ``num_layer`` layers.

    Args:
        model_path: path or hub id passed to ``AutoModelForCausalLM.from_pretrained``
            (flash-attention-2 and ``trust_remote_code`` are enabled).
        num_layer: 0 keeps the loaded model as is; any other value returns the
            model built by ``init_small_huggingface_llm`` from the loaded one.
        data_type: dtype name resolved with ``get_dtype``.
        **kwargs: accepted for call-site compatibility; not used here.
    """
    with NameTimer("Init full model"):
        loaded = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=get_dtype(data_type),
            use_flash_attention_2=True,
            trust_remote_code=True,
        )
        if num_layer == 0:
            return loaded
        # Non-zero num_layer: build the small model from the loaded base model.
        return init_small_huggingface_llm(
            loaded.model.config,
            num_layer=num_layer,
            base_llm=loaded,
            device='cpu',
        )

def init_peft_model(model_path, Lora, baseoutdir, num_layer=0, data_type='bfloat16', **kwargs):
    """Load the base model via ``init_full_model`` and wrap it with LoRA adapters.

    Args:
        model_path, num_layer, data_type: forwarded to ``init_full_model``.
        Lora: config object read for ``r``, ``alpha``, ``dropout`` and ``bias``.
        baseoutdir: when ``num_layer != 0`` the (small) base model is saved to
            ``<baseoutdir>/fullmodel`` before wrapping.
        **kwargs: accepted for call-site compatibility; not used here.

    Returns:
        The PEFT model from ``get_peft_model``, targeting every module returned
        by ``find_all_linear_names``.
    """
    with NameTimer("Init peft model"):
        base_model = init_full_model(model_path, num_layer, data_type)
        if num_layer != 0:
            base_model.save_pretrained(os.path.join(baseoutdir, 'fullmodel'))
        lora_config = LoraConfig(
            r=Lora.r,
            lora_alpha=Lora.alpha,
            target_modules=find_all_linear_names(base_model),
            lora_dropout=Lora.dropout,
            bias=Lora.bias,
            task_type="CAUSAL_LM",
        )
        return get_peft_model(base_model, lora_config)

def init_assisted_model(model_path, num_layer=0, data_type='bfloat16', Lora=None, **kwargs):
    """Build an ``AssistedModel`` from the base model at ``model_path``.

    LoRA is enabled when a ``Lora`` config with a non-zero ``r`` is given.
    NOTE(review): ``data_type`` and ``**kwargs`` are not forwarded to
    ``AssistedModel.init_from_basellm`` — confirm that is intended.
    """
    use_lora = Lora is not None and Lora.r != 0
    return AssistedModel.init_from_basellm(
        model_path,
        assist_num_layer=num_layer,
        is_lora=use_lora,
        Lora=Lora,
    )

class BaseModule(L.LightningModule):
    """Shared Lightning plumbing for the unlearning modules in this file.

    Subclasses are expected to set ``self.model`` and implement
    ``calculate_loss`` / ``forget_loss_func`` / ``retain_loss_func``.
    ``on_before_configure_optimizers`` must be called before
    ``configure_optimizers``: it sets the warmup/total step counts the
    scheduler needs.
    """

    def __init__(self,
        learning_rate=1e-5,
        lr_scheduler_type="linear",
        weight_decay=0.0,
        warmup_ratio=0.05,
        **kwargs,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.weight_decay = weight_decay
        # Overwritten by on_before_configure_optimizers(); the default keeps
        # on_save_checkpoint / configure_optimizers from raising AttributeError
        # if that hook was never called.
        self.is_deepspeed = False

    def log_train_loss(self, name, loss : torch.Tensor):
        """Log a scalar training loss every step (shown in the progress bar)."""
        self.log(name, loss.item(), on_step=True, prog_bar=True)

    def log_val_loss(self, name, loss : torch.Tensor):
        """Log a scalar validation loss aggregated per epoch and synced across ranks."""
        self.log(name, loss.item(), on_step=False, on_epoch=True, sync_dist=True, prog_bar=True, add_dataloader_idx=False)

    def log_losses(self, loss, forget_loss=None, remember_loss=None, stage='train'):
        """Log whichever of the total / forget / retain losses are not ``None``.

        Names are ``{stage}/loss``, ``{stage}/forget`` and ``{stage}/retain``;
        when ``stage`` already contains ``/`` (e.g. ``val/forget``) the
        separator becomes ``_``.
        """
        log_func = self.log_train_loss if stage == 'train' else self.log_val_loss
        if "/" in stage:
            connector = "_"
        else:
            connector = "/"
        if loss is not None:
            log_func(f"{stage}{connector}loss", loss)
        if forget_loss is not None:
            log_func(f"{stage}{connector}forget", forget_loss)
        if remember_loss is not None:
            log_func(f"{stage}{connector}retain", remember_loss)

    def calculate_loss(self, **kwargs):
        raise NotImplementedError("calculate_loss method is not implemented.")

    def forget_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        raise NotImplementedError("forget_loss_func method is not implemented. ")

    def retain_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        raise NotImplementedError("retain_loss_func method is not implemented.")

    def on_before_optimizer_step(self, optimizer: Optimizer) -> None:
        """Log a gradient-norm diagnostic as ``train/grad_norm``.

        If the optimizer exposes ``_global_grad_norm`` (presumably a DeepSpeed
        optimizer) that value is logged. Otherwise the logged value is the
        MEAN of the per-parameter L2 gradient norms, not the global norm.
        """
        if hasattr(optimizer, '_global_grad_norm'):
            self.log("train/grad_norm", optimizer._global_grad_norm, on_step=True, prog_bar=False, sync_dist=False)
            return

        norm_type = 2
        norms = [
            p.grad.detach().norm(norm_type).item()
            for group in optimizer.param_groups
            for p in group['params']
            if p.grad is not None
        ]
        if norms:
            self.log("train/grad_norm", np.mean(norms), on_step=True, prog_bar=False, sync_dist=False)
        else:
            rank_zero_info("No parameter gradients found before optimizer step; train/grad_norm not logged.")

    def on_before_configure_optimizers(self, num_training_steps, is_deepspeed=False):
        """Record the step counts (and DeepSpeed flag) used by ``configure_optimizers``."""
        self.warmup_steps = int(num_training_steps * self.hparams.warmup_ratio)
        self.num_training_steps = num_training_steps
        self.is_deepspeed = is_deepspeed

    def configure_optimizers(self):
        """Build AdamW (or DeepSpeedCPUAdam) plus a per-step LR scheduler.

        Trainable parameters are split into two groups: weight decay applies to
        everything except biases and parameters inside ``ALL_LAYERNORM_LAYERS``
        modules, which get ``weight_decay=0``.
        """
        opt_model = self.model
        # A set makes the per-parameter membership tests below O(1).
        decay_parameters = {
            name
            for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            if "bias" not in name
        }
        trainable = [(n, p) for n, p in opt_model.named_parameters() if p.requires_grad]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in trainable if n in decay_parameters],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in trainable if n not in decay_parameters],
                "weight_decay": 0.0,
            },
        ]

        print("Total number params: ", sum(len(group['params']) for group in optimizer_grouped_parameters))
        # Drop empty groups (e.g. no bias/LayerNorm params among the trainable ones).
        optimizer_grouped_parameters = [g for g in optimizer_grouped_parameters if g["params"]]

        # Pass the grouped parameters: previously the groups were built but
        # self.model.parameters() was handed to the optimizer, so decay hit
        # biases/LayerNorm and frozen parameters were registered too.
        optimizer_cls = DeepSpeedCPUAdam if self.is_deepspeed else torch.optim.AdamW
        optimizer = optimizer_cls(
            optimizer_grouped_parameters,
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )

        scheduler = get_scheduler(
            name=self.hparams.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.num_training_steps,
        )
        return [optimizer], [{
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
            "name": "learningrate",
        }]

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Strip the frozen oracle from checkpoints; outside DeepSpeed store the module itself.

        Without DeepSpeed the ``state_dict`` entry is removed and the live
        ``self.model`` object is stored under ``checkpoint['model']``.
        """
        if hasattr(self, 'oracle_model'):
            state = checkpoint['state_dict']
            for key in [k for k in state if 'oracle_model' in k]:
                del state[key]

        if not self.is_deepspeed: # We can call transformers save_pretrained
            checkpoint.pop('state_dict')
            checkpoint['model'] = self.model

class ForgetModule(BaseModule):
    """Base module that applies the training loss on forget data only.

    ``self.model`` is built in ``__init__`` in one of three ways:
    ``init_assisted_model`` when ``kwargs['is_assist']`` is truthy,
    ``init_peft_model`` when a ``Lora`` config with ``r != 0`` is given, and
    ``init_full_model`` otherwise.
    """
    def __init__(
        self,
        model_path,
        num_layer=0,
        data_type = 'bfloat16', # Use bfloat16 as the default training
        learning_rate=1e-5,
        lr_scheduler_type="linear",
        weight_decay=0.0,
        warmup_ratio=0.1,
        **kwargs
    ):
        super().__init__(learning_rate, lr_scheduler_type, weight_decay, warmup_ratio)
        self.data_type = data_type

        loraconf = kwargs.get('Lora', None)
        if kwargs.get('is_assist', False):
            self.model = init_assisted_model(model_path, num_layer, data_type, **kwargs)
        elif loraconf is not None and loraconf.r != 0:
            baseoutdir = kwargs.get('baseoutdir')
            self.model = init_peft_model(model_path, loraconf, baseoutdir, num_layer, data_type)
        else:
            self.model = init_full_model(model_path, num_layer, data_type, **kwargs)

        self.learning_rate = learning_rate
        self.lr_scheduler_type = lr_scheduler_type
        self.weight_decay = weight_decay
        self.warmup_ratio = warmup_ratio
        # Metric prefixes for the validation dataloaders, keyed by dataloader index.
        self.loader_names = {
            0: 'val/forget',
            1: 'val/retain',
            2: 'val/perturb',
            3: 'val/paraphrase',
        }

    def forward(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """Delegate to ``self.model`` with the given tensors (extra kwargs ignored)."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            labels=labels,
        )
        return outputs

    def calculate_loss(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """The whole batch is forget data, so the loss is ``forget_loss_func``."""
        loss = self.forget_loss_func(input_ids, attention_mask, labels, position_ids)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute the forget loss on the batch and log it with the current LR."""
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        loss = self.calculate_loss(input_ids, attention_mask, labels)
        self.log_losses(loss, forget_loss=loss, stage='train')
        # LR logging is best-effort: fall back to 0.0 when no usable scheduler
        # is attached (e.g. lr_schedulers() returns None).
        try:
            last_lr = self.lr_schedulers().get_last_lr()[0]
        except Exception:
            last_lr = 0.0
        self.log(
            "train/learning_rate",
            last_lr, prog_bar=False, sync_dist=False
        )
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        """Log next-token NLL for one validation batch.

        ``dataloader_idx`` defaults to ``None`` so a single validation
        dataloader (for which Lightning passes no index) works; that case is
        logged under ``val``. Indices missing from ``self.loader_names`` fall
        back to ``val/<idx>``.
        """
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        forgetloss = NextTokenPredictionLoss(
            self, input_ids, attention_mask, labels
        ) #! We log the NLL loss for validation to check how the model remembers different data
        if dataloader_idx is not None:
            val_name = self.loader_names.get(dataloader_idx, f"val/{dataloader_idx}")
        else:
            val_name = "val"
        self.log_losses(loss=None, forget_loss=forgetloss, stage=val_name)

    def test_step(self, *args: Any, **kwargs: Any):
        """Reuse the validation logic for testing."""
        return self.validation_step(*args, **kwargs)


class ForgetRetainModule(ForgetModule):
    """Base module that trains on forget data and retain data together.

    Rows of a batch are routed by ``batch['retainlabels']`` (0 = forget,
    1 = retain; every row is forget data when the key is absent). The total
    loss is ``forget_loss + remember_weight * retain_loss``.
    """
    def __init__(self, model_path, num_layer=0, remember_weight=1., data_type='bfloat16', learning_rate=0.00001, lr_scheduler_type="linear", weight_decay=0, warmup_ratio=0.1, **kwargs):
        super().__init__(model_path, num_layer, data_type, learning_rate, lr_scheduler_type, weight_decay, warmup_ratio, **kwargs)

        self.remember_weight = remember_weight

    def calculate_loss(self, input_ids, attention_mask, labels=None, position_ids=None, retainlabels=None, **kwargs):
        """Return ``(total, forget_loss, retain_loss)`` for a mixed batch.

        ``retainlabels`` must be given (the split result is unpacked as a
        forget/retain pair). A side with no rows contributes a constant 0.
        """
        (
            (forget_input_ids, forget_attention_mask, forget_labels, forget_position_ids,),
            (remember_input_ids, remember_attention_mask, remember_labels, remember_position_ids)
        ) = split_forget_retain(
            input_ids, attention_mask, labels, position_ids, retainlabels
        )

        if forget_input_ids.shape[0] == 0:
            forget_loss = torch.tensor(0.).to(input_ids.device)
        else:
            forget_loss = self.forget_loss_func(forget_input_ids, forget_attention_mask, forget_labels, forget_position_ids)
        if remember_input_ids.shape[0] == 0:
            remember_loss = torch.tensor(0.).to(input_ids.device)
        else:
            remember_loss = self.retain_loss_func(remember_input_ids, remember_attention_mask, remember_labels, remember_position_ids)

        loss = forget_loss + self.remember_weight * remember_loss
        return loss, forget_loss, remember_loss

    def training_step(self, batch, batch_idx, dataloader_idx=None):
        """Compute and log total/forget/retain losses plus the current LR."""
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        if 'retainlabels' in batch:
            retainlabels = batch['retainlabels']
        else:
            # No routing info: treat every row as forget data.
            retainlabels = torch.zeros(input_ids.shape[0]).to(input_ids.device)

        loss, forget_loss, remember_loss = self.calculate_loss(
            input_ids, attention_mask, labels, retainlabels=retainlabels
        )
        # LR logging is best-effort: fall back to 0.0 without a usable scheduler.
        try:
            last_lr = self.lr_schedulers().get_last_lr()[0]
        except Exception:
            last_lr = 0.0
        self.log("train/learning_rate",
            last_lr, prog_bar=False, sync_dist=False
        )
        # The per-step debug print (four .item() host syncs) was removed; the
        # same values are already logged here.
        self.log_losses(loss, forget_loss, remember_loss, stage='train')
        return loss
  
## Below are the useful modules

################################################################################################################
### Remember modules: these implement our method, which uses an assistant model
################################################################################################################
class UniformRememberModule(ForgetRetainModule):
    """ This module applies gradient descent on forget data and KL-to-uniform on retain data. """

    def forget_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """Plain next-token NLL on forget data (gradient descent)."""
        return NextTokenPredictionLoss(
            self, input_ids, attention_mask, labels, position_ids
        )

    def retain_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """KL(uniform || model) over every flattened position of the retain batch.

        ``kl_div`` with target = uniform probabilities and input = model
        log-probabilities; ``'batchmean'`` divides by the number of flattened
        positions.
        """
        outputs = self(input_ids=input_ids, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
        logits = outputs.logits
        num_labels = logits.shape[-1]
        # log_softmax instead of softmax(...).log(): in bf16/fp16 tiny
        # probabilities underflow to 0, whose log is -inf and makes the KL inf.
        log_probs = nn.functional.log_softmax(logits, dim=-1).reshape(-1, num_labels)
        uniform_dist = torch.full_like(log_probs, 1.0 / num_labels)
        kl_div = torch.nn.functional.kl_div(log_probs, uniform_dist, reduction='batchmean')
        return kl_div
   
class GDAsentRememberModule(UniformRememberModule):
    """ This module applies gradient descent on forget data and gradient ascent on retain data. """

    def retain_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """Negated next-token NLL on retain data (gradient ascent).

        This must override ``retain_loss_func``, the hook that
        ``ForgetRetainModule.calculate_loss`` calls. Under the old name
        ``remember_loss`` it was never invoked, so the class silently trained
        with the inherited KL-to-uniform retain loss.
        """
        next_token_loss = NextTokenPredictionLoss(
            self, input_ids, attention_mask, labels, position_ids
        )
        return -1. * next_token_loss

    # Backward-compatible alias for the old method name.
    remember_loss = retain_loss_func


################################################################################################################
### Forget modules: these are the baseline methods
################################################################################################################

class GradAscentModule(ForgetModule):
    """Gradient ascent on forget data: minimizes the negated next-token NLL."""

    def forget_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """Return ``-NLL`` so optimizer steps increase the forget-data loss."""
        nll = NextTokenPredictionLoss(self, input_ids, attention_mask, labels, position_ids)
        return -nll

class GradDiffModule(ForgetRetainModule):
    """Gradient ascent on forget data and gradient descent on retain data."""

    def forget_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """Negated next-token NLL (ascent) on forget data."""
        nll = NextTokenPredictionLoss(self, input_ids, attention_mask, labels, position_ids)
        return -nll

    def retain_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """Plain next-token NLL (descent) on retain data."""
        return NextTokenPredictionLoss(self, input_ids, attention_mask, labels, position_ids)

class IDKModule(ForgetModule):
    """Gradient descent on forget data whose responses were replaced with "I don't know"."""

    def forget_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """Plain next-token NLL on the relabelled forget batch."""
        return NextTokenPredictionLoss(self, input_ids, attention_mask, labels, position_ids)

class KLModule(ForgetRetainModule):
    """Gradient ascent on forget data and KL to a frozen oracle on retain data.

    ``self.oracle_model`` is a second, frozen copy of the model at
    ``model_path`` (loaded on CPU, ``eval()`` mode, ``requires_grad_(False)``).
    """

    def __init__(self, model_path, num_layer=0, remember_weight=1, data_type='bfloat16', learning_rate=0.00001, lr_scheduler_type="linear", weight_decay=0, warmup_ratio=0.1, **kwargs):
        super().__init__(model_path, num_layer, remember_weight, data_type, learning_rate, lr_scheduler_type, weight_decay, warmup_ratio, **kwargs)

        self.oracle_model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=get_dtype(data_type),
            use_flash_attention_2=True, trust_remote_code=True, device_map='cpu'
        )
        self.oracle_model.eval()
        self.oracle_model.requires_grad_(False)

    def forget_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """Negated next-token NLL (gradient ascent) on forget data."""
        next_token_loss = NextTokenPredictionLoss(
            self, input_ids, attention_mask, labels, position_ids
        )
        return -1 * next_token_loss

    def retain_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """KL(oracle || model) over every flattened position of the retain batch.

        Both models receive the same ``attention_mask`` and ``position_ids``
        (previously only the oracle got ``position_ids``, so the two
        distributions could be computed at different positions).
        """
        with torch.no_grad():
            oracle_outputs = self.oracle_model(
                input_ids, labels=labels, attention_mask=attention_mask, position_ids=position_ids
            )
            oracle_log_probs = F.log_softmax(oracle_outputs.logits, dim=-1)
            oracle_log_probs = oracle_log_probs.view(-1, oracle_outputs.logits.shape[-1])

        outputs = self(input_ids, labels=labels, attention_mask=attention_mask, position_ids=position_ids)
        log_probs = F.log_softmax(outputs.logits, dim=-1)
        log_probs = log_probs.view(-1, outputs.logits.shape[-1])
        # log_target=True: both arguments are log-probabilities.
        retain_loss = nn.functional.kl_div(log_probs, oracle_log_probs, reduction='batchmean', log_target=True)
        return retain_loss
            

class IDKGradDiffModule(ForgetRetainModule):
    """Gradient descent on IDK-relabelled forget data and on retain data."""

    def forget_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """Plain next-token NLL on the IDK-relabelled forget batch."""
        return NextTokenPredictionLoss(self, input_ids, attention_mask, labels, position_ids)

    def retain_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """Plain next-token NLL on the retain batch."""
        return NextTokenPredictionLoss(self, input_ids, attention_mask, labels, position_ids)

class IDKKLModule(KLModule):
    """Gradient descent on IDK-relabelled forget data; KL to the oracle (from KLModule) on retain data."""

    def forget_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """Plain next-token NLL, replacing KLModule's gradient-ascent forget loss."""
        return NextTokenPredictionLoss(self, input_ids, attention_mask, labels, position_ids)

class DPOModule(ForgetModule):
    """DPO loss on forget data, preferring the ``prefer_*`` response (the IDK answer).

    ``self.oracle_model`` is a frozen copy of the model at ``model_path`` used
    as the DPO reference policy; ``self.beta`` is the DPO temperature.
    """

    def __init__(self, model_path, num_layer=0, data_type='bfloat16', learning_rate=0.00001, lr_scheduler_type="linear", weight_decay=0, warmup_ratio=0.1, **kwargs):
        super().__init__(model_path, num_layer, data_type, learning_rate, lr_scheduler_type, weight_decay, warmup_ratio, **kwargs)

        self.oracle_model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=get_dtype(data_type),
            use_flash_attention_2=True, trust_remote_code=True,
            device_map='cpu'
        )
        self.oracle_model.eval()
        self.oracle_model.requires_grad_(False)
        self.beta = 0.1

    def forget_loss_func(self, input_ids, labels, label_attention_mask, prefer_input_ids, prefer_labels, prefer_label_attention_mask, **kwargs):
        """ DPO Loss

        Sequence log-probabilities are ``-TokenNextTokenPredictionLoss`` (summed
        over label tokens). The returned value is
        ``-logsigmoid(beta * (pi_logratio - ref_logratio)).mean() * 2 / beta``.
        """
        policy_origin_out = self(input_ids, attention_mask=label_attention_mask)
        policy_prefer_out = self(prefer_input_ids, attention_mask=prefer_label_attention_mask)
        with torch.no_grad():
            ref_origin_out = self.oracle_model(
                input_ids, attention_mask=label_attention_mask,
            )
            ref_prefer_out = self.oracle_model(
                prefer_input_ids, attention_mask=prefer_label_attention_mask,
            )
            ref_origin_logp = -1 * TokenNextTokenPredictionLoss(ref_origin_out.logits, labels)
            ref_prefer_logp = -1 * TokenNextTokenPredictionLoss(ref_prefer_out.logits, prefer_labels)

        policy_origin_logp = -1 * TokenNextTokenPredictionLoss(policy_origin_out.logits, labels)
        policy_prefer_logp = -1 * TokenNextTokenPredictionLoss(policy_prefer_out.logits, prefer_labels)

        pi_logratios = policy_prefer_logp - policy_origin_logp
        ref_logratios = ref_prefer_logp - ref_origin_logp

        loss = -F.logsigmoid(
            self.beta * (pi_logratios - ref_logratios)
        ).mean() * 2 / self.beta
        return loss

    def calculate_loss(self, input_ids, labels, attention_mask, prefer_input_ids, prefer_labels, prefer_attention_mask, **kwargs
    ):
        """The whole batch is forget data, so the loss is the DPO loss."""
        return self.forget_loss_func(
            input_ids, labels, attention_mask, prefer_input_ids, prefer_labels, prefer_attention_mask
        )

    def training_step(self, batch, batch_idx):
        """Compute the DPO loss for the batch and log it with the current LR."""
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels'] # origin response
        prefer_input_ids = batch['prefer_input_ids']
        prefer_attention_mask = batch['prefer_attention_mask']
        prefer_labels = batch['prefer_labels'] # prefer response

        loss = self.calculate_loss(
            input_ids, labels, attention_mask,
            prefer_input_ids, prefer_labels, prefer_attention_mask
        )
        # Best-effort LR logging, matching the other modules: no usable
        # scheduler no longer crashes the step.
        try:
            last_lr = self.lr_schedulers().get_last_lr()[0]
        except Exception:
            last_lr = 0.0
        self.log("train/learning_rate",
            last_lr,
            prog_bar=True, sync_dist=False
        )
        self.log_losses(loss, forget_loss=loss, stage='train')
        return loss

class DPORetainModule(DPOModule):
    """DPO on forget rows plus a subclass-provided ``retain_loss_func`` on retain rows.

    Rows are routed by ``retainlabels`` (0 = forget, 1 = retain); the total is
    ``forget_loss + remember_weight * retain_loss``.
    """
    def __init__(self, model_path, num_layer=0, remember_weight=1, data_type='bfloat16', learning_rate=0.00001, lr_scheduler_type="linear", weight_decay=0, warmup_ratio=0.1, **kwargs):
        super().__init__(model_path, num_layer, data_type, learning_rate, lr_scheduler_type, weight_decay, warmup_ratio, **kwargs)

        self.remember_weight = remember_weight

    def calculate_loss(self, input_ids, labels, label_attention_mask, prefer_input_ids, prefer_labels, prefer_label_attention_mask, retainlabels, **kwargs):
        """Return ``(total, forget_loss, retain_loss)``; an empty side contributes 0."""
        forget_rows = retainlabels == 0
        retain_rows = retainlabels == 1

        # The preferred (IDK) responses are only needed for the forget rows.
        forget_input_ids = input_ids[forget_rows]
        forget_attention_mask = label_attention_mask[forget_rows]
        forget_labels = labels[forget_rows]
        forget_prefer_input_ids = prefer_input_ids[forget_rows]
        forget_prefer_attention_mask = prefer_label_attention_mask[forget_rows]
        forget_prefer_labels = prefer_labels[forget_rows]

        retain_input_ids = input_ids[retain_rows]
        retain_labels = labels[retain_rows]
        retain_attention_mask = label_attention_mask[retain_rows]

        if forget_input_ids.shape[0] == 0:
            forget_loss = torch.tensor(0.).to(input_ids.device)
        else:
            forget_loss = self.forget_loss_func(
                forget_input_ids,
                forget_labels,
                forget_attention_mask,
                forget_prefer_input_ids,
                forget_prefer_labels,
                forget_prefer_attention_mask
            )
        if retain_input_ids.shape[0] == 0:
            remember_loss = torch.tensor(0.).to(input_ids.device)
        else:
            # retain_loss_func's signature is (input_ids, attention_mask, labels, ...):
            # pass by keyword; the old positional call swapped labels and mask.
            remember_loss = self.retain_loss_func(
                retain_input_ids,
                attention_mask=retain_attention_mask,
                labels=retain_labels,
            )

        loss = forget_loss + self.remember_weight * remember_loss
        return loss, forget_loss, remember_loss

    def training_step(self, batch, batch_idx, dataloader_idx=None):
        """Compute and log total/forget/retain losses for a mixed DPO batch."""
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels'] # origin response
        prefer_input_ids = batch['prefer_input_ids']
        prefer_attention_mask = batch['prefer_attention_mask']
        prefer_labels = batch['prefer_labels'] # prefer response
        if 'retainlabels' in batch:
            retainlabels = batch['retainlabels']
        else:
            # Same default as ForgetRetainModule: every row is forget data.
            retainlabels = torch.zeros(input_ids.shape[0]).to(input_ids.device)

        loss, forget_loss, remember_loss = self.calculate_loss(
            input_ids, labels, attention_mask, prefer_input_ids, prefer_labels, prefer_attention_mask, retainlabels
        )
        self.log_losses(loss, forget_loss, remember_loss, stage='train')
        return loss

class DPOKLModule(DPORetainModule):
    """DPO on forget data and KL(oracle || model) on retain data."""

    def retain_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """KL between the frozen oracle and the trained model over all flattened positions.

        Both models receive the same ``attention_mask`` and ``position_ids``.
        """
        with torch.no_grad():
            oracle_outputs = self.oracle_model(
                input_ids, labels=labels, attention_mask=attention_mask, position_ids=position_ids
            )
            oracle_log_probs = F.log_softmax(oracle_outputs.logits, dim=-1)
            oracle_log_probs = oracle_log_probs.view(-1, oracle_outputs.logits.shape[-1])

        outputs = self(input_ids, labels=labels, attention_mask=attention_mask, position_ids=position_ids)
        log_probs = F.log_softmax(outputs.logits, dim=-1)
        log_probs = log_probs.view(-1, outputs.logits.shape[-1])
        # log_target=True: both arguments are log-probabilities.
        retain_loss = nn.functional.kl_div(log_probs, oracle_log_probs, reduction='batchmean', log_target=True)
        return retain_loss
 
class DPOGradDiffModule(DPORetainModule):
    """DPO on forget data and gradient descent (plain NLL) on retain data."""

    def retain_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """Plain next-token NLL on retain data.

        Not negated: as in GradDiffModule / IDKGradDiffModule, the retain term
        must be minimized. The previous ``-1 *`` made this gradient ASCENT on
        retain data, actively destroying the knowledge to be kept.
        """
        next_token_loss = NextTokenPredictionLoss(
            self, input_ids, attention_mask, labels, position_ids
        )
        return next_token_loss


class NPOModule(KLModule):
    """Negative Preference Optimization on forget data.

    The retain loss is KLModule's KL to the frozen oracle; ``beta`` is the NPO
    temperature.
    """
    def __init__(self, model_path, num_layer=0, remember_weight=1, beta=0.1,data_type='bfloat16', learning_rate=0.00001, lr_scheduler_type="linear", weight_decay=0, warmup_ratio=0.1, **kwargs):
        super().__init__(model_path, num_layer, remember_weight, data_type, learning_rate, lr_scheduler_type, weight_decay, warmup_ratio, **kwargs)
        self.beta = beta

    def forget_loss_func(self, input_ids, attention_mask, labels, position_ids=None, **kwargs):
        """Return ``-logsigmoid(beta * (nll_policy - nll_oracle)).mean() * 2 / beta``.

        ``nll_*`` are per-sequence summed next-token losses; the oracle side is
        computed without gradients.
        """
        with torch.no_grad():
            reference_nll = TokenNextTokenPredictionLoss(
                self.oracle_model(input_ids, attention_mask=attention_mask).logits,
                labels,
            )

        with torch.enable_grad():
            policy_nll = TokenNextTokenPredictionLoss(
                self(input_ids, attention_mask=attention_mask).logits,
                labels,
            )
            scaled_ratio = self.beta * (policy_nll - reference_nll)
            return -F.logsigmoid(scaled_ratio).mean() * 2 / self.beta

class NPOKLModule(NPOModule):
    """NPO on forget data with KL to the oracle on retain data.

    Adds nothing to NPOModule: the retain loss it needs is the
    ``retain_loss_func`` inherited from KLModule.
    """

class NPOGradDiffModule(NPOModule):
    """NPO on forget data and gradient descent (plain NLL) on retain data."""

    def retain_loss_func(self, input_ids, attention_mask, labels=None, position_ids=None, **kwargs):
        """Plain next-token NLL on retain data.

        Not negated: as in GradDiffModule / IDKGradDiffModule, the retain term
        must be minimized. The previous ``-1 *`` turned it into gradient ascent
        on the data the model is supposed to keep.
        """
        next_token_loss = NextTokenPredictionLoss(
            self, input_ids, attention_mask, labels, position_ids
        )
        return next_token_loss

# Registry from a config's loss-type string to the module class implementing it.
loss_type_modules = dict(
    # Our method (assistant model)
    remember_uniform=UniformRememberModule,
    remember_gddiff=GDAsentRememberModule,
    # Baselines
    grad_ascent=GradAscentModule,
    grad_diff=GradDiffModule,
    kl=KLModule,
    idk=IDKModule,
    idk_kl=IDKKLModule,
    idk_gddiff=IDKGradDiffModule,
    dpo=DPOModule,
    dpo_kl=DPOKLModule,
    npo=NPOModule,
    npo_kl=NPOKLModule,
    npo_gddiff=NPOGradDiffModule,
)

def get_parameter_names(model, forbidden_layer_types):
    """
    Return dotted names of ``model``'s parameters that are not inside a forbidden layer.

    Children whose type is in ``forbidden_layer_types`` are skipped entirely.
    Names from child modules come first (in ``named_children`` order),
    followed by the keys of ``model._parameters`` registered directly on
    ``model``.
    """
    excluded = tuple(forbidden_layer_types)
    names = []
    for child_name, child in model.named_children():
        if isinstance(child, excluded):
            continue
        names.extend(
            f"{child_name}.{sub_name}"
            for sub_name in get_parameter_names(child, forbidden_layer_types)
        )
    # Parameters defined with nn.Parameter directly on this module belong to no child.
    names.extend(model._parameters.keys())
    return names

# Module types whose parameters BaseModule.configure_optimizers excludes from
# weight decay (passed to get_parameter_names as the forbidden layer types).
ALL_LAYERNORM_LAYERS = [
    nn.LayerNorm,
]