from dataclasses import dataclass, field

import torch
from transformers import Trainer as SFTTrainer
from transformers import TrainingArguments as SFTConfig

from losses import get_loss_fn

@dataclass
class DistillationTrainingArguments(SFTConfig):
    """Training arguments extended with knowledge-distillation hyperparameters.

    The extra hyperparameters are declared as dataclass fields rather than
    attributes assigned in a hand-written ``__init__``. This means they are
    picked up by the parent dataclass machinery, including
    ``to_dict()`` / ``to_json_string()``, ``dataclasses.replace`` and
    ``HfArgumentParser``. With the previous custom ``__init__`` they were
    silently dropped or reset to their defaults by those paths.

    Construction by keyword (e.g. ``DistillationTrainingArguments(
    output_dir=..., alpha=0.5)``) behaves exactly as before.
    """

    # Weight of the student's own supervised loss term.
    alpha: float = field(
        default=0.0,
        metadata={"help": "Weight of the student's supervised (CE) loss."},
    )
    # Weight of the student-vs-teacher distillation loss term.
    beta: float = field(
        default=1.0,
        metadata={"help": "Weight of the distillation loss."},
    )
    # NOTE(review): not read by DistillationTrainer.compute_loss; confirm
    # whether anything else consumes it.
    gamma: float = field(
        default=0.0,
        metadata={"help": "Additional loss weight (unused by DistillationTrainer.compute_loss)."},
    )
    # Temperature forwarded to the distillation loss function.
    temperature: float = field(
        default=1.0,
        metadata={"help": "Temperature passed to the distillation loss function."},
    )
    # Name of the distillation loss, resolved via `losses.get_loss_fn`.
    distill_loss: str = field(
        default="kl",
        metadata={"help": "Name of the distillation loss (resolved by get_loss_fn)."},
    )

class DistillationTrainer(SFTTrainer):
    """Trainer that mixes a supervised loss with a teacher-distillation loss.

    The training objective is ``alpha * ce_loss + beta * distill_loss``.

    - ``ce_loss`` is the student's own ``outputs.loss``.
    - ``distill_loss`` is produced by ``get_loss_fn(args.distill_loss)`` from
      the student and teacher logits.

    A term whose weight is falsy (e.g. ``0``) is skipped entirely; for
    ``beta`` this also skips the teacher forward pass.

    ``args`` must provide ``alpha``, ``beta``, ``temperature`` and
    ``distill_loss``. ``DistillationTrainingArguments`` does.
    """

    def __init__(self, *args, teacher_model=None, **kwargs):
        """Initialise the base trainer and register the frozen teacher.

        Args:
            *args, **kwargs: forwarded unchanged to the base trainer.
            teacher_model: model whose logits the student is distilled from.
                Required despite the ``None`` default; it is switched to
                eval mode here.

        Raises:
            ValueError: if ``teacher_model`` is ``None``, or if both
                ``args.alpha`` and ``args.beta`` are zero. In that case no
                loss term would be produced and ``compute_loss`` would return
                a plain int.
        """
        if teacher_model is None:
            raise ValueError("DistillationTrainer requires a `teacher_model`.")
        super().__init__(*args, **kwargs)
        if not (self.args.alpha or self.args.beta):
            raise ValueError(
                "At least one of `alpha` (CE weight) or `beta` "
                "(distillation weight) must be non-zero."
            )
        self.teacher = teacher_model
        # The teacher is only ever run under torch.no_grad(); eval mode makes
        # its outputs deterministic (e.g. disables dropout).
        self.teacher.eval()
        self.distill_loss_fn = get_loss_fn(self.args.distill_loss)

    def train(self, *args, **kwargs):
        """Move the teacher to the training device, then run the base loop."""
        # NOTE(review): the base trainer presumably only places `self.model`,
        # so the teacher is moved explicitly here. Evaluation run without a
        # prior `train()` call does not go through this path.
        self._move_model_to_device(self.teacher, self.args.device)
        return super().train(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the weighted CE + distillation loss for one batch.

        Args:
            model: the student model being trained.
            inputs: batch mapping, passed as keyword arguments to both the
                student and the teacher. Must contain ``"attention_mask"``
                when ``beta`` is non-zero. Presumably must also contain
                labels when ``alpha`` is non-zero, since
                ``student_outputs.loss`` is read.
            return_outputs: if true, also return the student's outputs.
            **kwargs: accepted for signature compatibility and ignored.

        Returns:
            The scalar loss, or ``(loss, student_outputs)`` when
            ``return_outputs`` is true.

        Raises:
            ValueError: if student and teacher logits differ in shape.
        """
        student_outputs = model(**inputs)
        loss_log = {}
        loss = 0

        if self.args.alpha:
            # Reduce to a scalar (no-op if the loss is already a scalar).
            ce_loss = student_outputs.loss.mean()
            loss += self.args.alpha * ce_loss
            loss_log["ce_loss"] = round(ce_loss.item(), 5)

        if self.args.beta:
            with torch.no_grad():
                teacher_outputs = self.teacher(**inputs)
            student_logits = student_outputs.logits
            teacher_logits = teacher_outputs.logits
            # Explicit check instead of `assert`, which is stripped under -O.
            if student_logits.size() != teacher_logits.size():
                raise ValueError(
                    f"Student logits shape {tuple(student_logits.size())} does "
                    f"not match teacher logits shape {tuple(teacher_logits.size())}."
                )
            distill_loss = self.distill_loss_fn(
                logits=student_logits,
                teacher_logits=teacher_logits,
                attention_mask=inputs["attention_mask"],
                temperature=self.args.temperature,
            )
            loss += self.args.beta * distill_loss
            loss_log["distill_loss"] = round(distill_loss.item(), 5)

        self.log(loss_log)
        # Previously returned the undefined name `outputs` (NameError).
        return (loss, student_outputs) if return_outputs else loss
