import torch
from trl import SFTTrainer

try:
    from trl import SFTConfig
except ImportError:
    from transformers import TrainingArguments as SFTConfig

# from losses import get_loss_fn



class RegularizedTrainingArguments(SFTConfig):
    """Training configuration that carries an extra ``alpha`` coefficient.

    All positional and keyword arguments other than ``alpha`` are forwarded
    unchanged to the parent configuration class. ``alpha`` is stored as a
    plain instance attribute; ``RegularizedTrainer.compute_loss`` reads it as
    the weight of the regularization term (a falsy value disables the term).
    """

    def __init__(self, *args, alpha: float = 0.0, **kwargs) -> None:
        """Initialize the parent configuration, then record ``alpha``.

        Args:
            *args: Positional arguments forwarded to the parent initializer.
            alpha: Regularization weight stored as ``self.alpha``.
                Defaults to ``0.0``.
            **kwargs: Keyword arguments forwarded to the parent initializer.
        """
        # The parent is initialized first so that its own setup has finished
        # before the extra attribute is attached.
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        


class RegularizedTrainer(SFTTrainer):
    """SFT trainer that adds a Frobenius-norm penalty on adapter parameters.

    The penalty is enabled when ``self.args.alpha`` is truthy and is applied
    to every parameter whose name contains the substring ``"adapt"``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to the parent trainer."""
        super().__init__(*args, **kwargs)

    def train(self, *args, **kwargs):
        """Run training by delegating directly to the parent implementation."""
        return super().train(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the model loss plus an optional adapter regularizer.

        The base loss is ``model(**inputs).loss``. When ``self.args.alpha`` is
        truthy, each parameter whose name contains ``"adapt"`` contributes the
        squared Frobenius norm taken over its dimensions 1 and 2, averaged over
        the remaining entries. The per-parameter values are averaged and the
        result is added to the loss scaled by ``alpha``.

        The rounded loss components are sent to ``self.log`` on every call
        under the keys ``"CE Loss"`` and, when a regularizer was computed,
        ``"Reg Loss"``.

        Args:
            model: Model called as ``model(**inputs)``; its output must expose
                a ``loss`` attribute.
            inputs: Mapping of keyword arguments passed to the model.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just ``loss``.
            **kwargs: Accepted for signature compatibility and ignored.
                NOTE(review): confirm that ignoring extra arguments supplied
                by the calling trainer is intended.

        Returns:
            The total loss tensor, or a ``(loss, outputs)`` tuple when
            ``return_outputs`` is true.
        """
        outputs_student = model(**inputs)
        loss = outputs_student.loss
        loss_log = {"CE Loss": round(loss.item(), 4)}

        alpha = self.args.alpha
        if alpha:
            # Prefer the wrapped model stored on the trainer when it exists,
            # falling back to the model handed to this call.
            model_for_params = getattr(self, "model_wrapped", model)

            # Assumes every matching parameter has at least three dimensions,
            # since dims 1 and 2 are reduced -- TODO confirm for all adapters.
            penalties = [
                (torch.norm(param, p="fro", dim=[1, 2]) ** 2).mean()
                for name, param in model_for_params.named_parameters()
                if "adapt" in name
            ]

            # torch.stack rejects an empty sequence; with no adapter
            # parameters there is nothing to penalize, so the loss is left
            # unchanged instead of failing with an opaque RuntimeError.
            if penalties:
                reg_loss = torch.stack(penalties).mean()
                loss = loss + alpha * reg_loss
                loss_log["Reg Loss"] = round(reg_loss.item(), 4)

        self.log(loss_log)
        return (loss, outputs_student) if return_outputs else loss