import torch
from trl import SFTTrainer

try:
    from trl import SFTConfig
except ImportError:
    from transformers import TrainingArguments as SFTConfig

from losses import get_loss_fn


class DistillationTrainingArguments(SFTConfig):
    """Training arguments extended with knowledge-distillation settings.

    On top of everything the base config accepts, four extra values are
    stored as plain instance attributes for ``DistillationTrainer`` to read:

    * ``alpha`` -- weight applied to the student's own loss (default 1.0).
    * ``beta`` -- weight applied to the distillation loss (default 0.0).
    * ``temperature`` -- forwarded to the distillation loss function
      (default 1.0).
    * ``distill_loss`` -- identifier handed to ``get_loss_fn`` to select the
      distillation loss (default ``"kl"``).
    """

    def __init__(
        self,
        *args,
        alpha=1.0,
        beta=0.0,
        temperature=1.0,
        distill_loss="kl",
        **kwargs,
    ):
        # Base-class initialisation runs first; the distillation settings are
        # attached afterwards as ordinary attributes.
        super().__init__(*args, **kwargs)
        self.distill_loss = distill_loss
        self.temperature = temperature
        self.alpha = alpha
        self.beta = beta

class DistillationTrainer(SFTTrainer):
    """SFT trainer that can add a teacher-student distillation term to the loss.

    The loss returned by :meth:`compute_loss` is
    ``alpha * student_loss + beta * distill_loss``, where ``alpha``, ``beta``
    and ``temperature`` are read from ``self.args`` and the distillation loss
    function is selected through ``get_loss_fn(self.args.distill_loss)``.
    """

    def __init__(self, *args, teacher_model=None, **kwargs):
        """Build the trainer and register the optional teacher model.

        Args:
            *args: Positional arguments forwarded to the parent trainer.
            teacher_model: Model whose logits are the distillation target.
                May be ``None`` when ``beta`` is zero.
            **kwargs: Keyword arguments forwarded to the parent trainer.
        """
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        if self.teacher is not None:
            # The teacher is a fixed reference: put it in eval mode and
            # disable gradients on its parameters.
            self.teacher.eval()
            self.teacher.requires_grad_(False)
        self.distill_loss_fn = get_loss_fn(self.args.distill_loss)

    def train(self, *args, **kwargs):
        """Move the teacher to the training device, then run the parent loop."""
        if self.teacher is not None:
            self._move_model_to_device(self.teacher, self.args.device)
        return super().train(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the weighted sum of the student loss and distillation loss.

        Args:
            model: The student model; called as ``model(**inputs)``.
            inputs: Mapping of keyword arguments for the forward pass. The
                same inputs are fed to the teacher.
            return_outputs: When true, return ``(loss, student_outputs)``
                instead of just the loss.
            **kwargs: Accepted for compatibility with the parent signature;
                unused here.

        Returns:
            The scalar loss tensor, or ``(loss, student_outputs)`` when
            ``return_outputs`` is true.

        Raises:
            ValueError: If both ``alpha`` and ``beta`` are zero, if ``beta``
                is non-zero but no teacher model was supplied, if ``alpha`` is
                non-zero but the student returned no loss, or if student and
                teacher logits differ in shape.
        """
        cfg = self.args
        if not cfg.alpha and not cfg.beta:
            # Without either term there is no tensor to back-propagate.
            raise ValueError(
                "At least one of `alpha` and `beta` must be non-zero."
            )
        if cfg.beta and self.teacher is None:
            raise ValueError(
                "`beta` is non-zero but no `teacher_model` was provided."
            )

        loss_log = {}
        student_outputs = model(**inputs)
        loss = None

        if cfg.alpha:
            student_loss = student_outputs.loss
            if student_loss is None:
                raise ValueError(
                    "`alpha` is non-zero but the student model returned no loss."
                )
            loss = cfg.alpha * student_loss
            if loss.ndim > 0:
                # Reduce a non-scalar loss to a scalar.
                loss = loss.mean()
            # NOTE: the logged value is the alpha-weighted loss.
            loss_log["ce_loss"] = round(loss.item(), 4)

        if cfg.beta:
            with torch.no_grad():
                teacher_outputs = self.teacher(**inputs)

            student_shape = student_outputs.logits.size()
            teacher_shape = teacher_outputs.logits.size()
            # Explicit check instead of `assert`, which is removed under -O.
            if student_shape != teacher_shape:
                raise ValueError(
                    "Student and teacher logits must have the same shape, "
                    f"got {tuple(student_shape)} and {tuple(teacher_shape)}."
                )

            distill_loss = self.distill_loss_fn(
                logits=student_outputs.logits,
                teacher_logits=teacher_outputs.logits,
                temperature=cfg.temperature,
            )

            weighted_distill = cfg.beta * distill_loss
            # Out-of-place addition: never mutate a tensor in the graph.
            loss = weighted_distill if loss is None else loss + weighted_distill
            # The distillation term is logged unweighted.
            loss_log["distill_loss"] = round(distill_loss.item(), 4)

        self.log(loss_log)
        return (loss, student_outputs) if return_outputs else loss
