# pylint: disable=C0114, C0301, R0913, C0303, C0115, C0116, R0914, E0402
import torch
import torch.nn.functional as F


class MTCriterion:
    """Text-only training criterion.

    Runs the wrapped model in "text" mode and reports the loss attribute of
    its output, together with a logging dict and the batch size.
    """

    def __init__(self, model) -> None:
        # The model's ``input_modality`` attribute is overwritten on every call.
        self.model = model

    def __call__(self, batch):
        """Compute the loss for one text batch.

        Args:
            batch: dict whose 'inputs' entry is unpacked as keyword arguments
                into the model and also exposes an ``input_ids`` attribute.

        Returns:
            dict with 'loss' (the model output's ``.loss``), 'log'
            ({'mt_loss': same loss}) and 'bsz' (``input_ids.shape[0]``).
        """
        net = self.model
        net.input_modality = "text"
        model_inputs = batch['inputs']
        mt_loss = net(**model_inputs).loss
        n_examples = model_inputs.input_ids.shape[0]
        return dict(loss=mt_loss, log=dict(mt_loss=mt_loss), bsz=n_examples)


class STCriterion:
    """Speech criterion: weighted multi-task loss plus teacher distillation.

    On each call the student ``model`` is run in "speech" mode and the
    ``teacher`` is run (under ``torch.inference_mode``) on the same inputs.
    A masked token-level KL distillation loss is combined with the student's
    own loss terms using per-term weights.
    """

    # Default weight of each loss term in the total loss.
    DEFAULT_WEIGHTS = {
        "loss_kd": 0.5,
        "loss_simul": 0.2,
        "loss_full": 1,
        "loss_trunc": 0.2,
        "loss_norm": 0.01,
    }

    def __init__(self, model, teacher, weights=None) -> None:
        """
        Args:
            model: student model; called with ``**batch['inputs']`` and read
                for 'logits_full', 'ast_mask', 'loss_simul', 'loss_full',
                'loss_trunc', 'loss_norm' and 'route_score'.
            teacher: teacher model; called with the same inputs and read for
                'logits' and 'loss'.
            weights: optional mapping overriding entries of DEFAULT_WEIGHTS.
        """
        self.model = model
        self.teacher = teacher
        # Build a fresh dict so overrides never mutate the class-level defaults.
        self.weights = {**self.DEFAULT_WEIGHTS, **(weights or {})}

    def _kld_loss(self, logits_stu, logits_tea, token_mask, reduction):
        """Masked per-token KL(teacher || student) over the last dimension.

        Args:
            logits_stu: student logits of shape (b, t, V).
            logits_tea: teacher logits, reshapeable to (b * t, V).
            token_mask: mask reshapeable to (b, t); zero entries drop a token.
            reduction: 'mean' (masked sum / mask sum), 'sum', or anything
                else to get the (b, t) per-token loss.

        Each per-token value is clamped to [0, 10] after masking.
        """
        b, t, _ = logits_stu.shape
        logits_stu = logits_stu.reshape((-1, logits_stu.shape[-1])).float()
        logits_tea = logits_tea.reshape((-1, logits_tea.shape[-1])).float()
        loss = F.kl_div(
            F.log_softmax(logits_stu, dim=-1),
            F.softmax(logits_tea, dim=-1),
            reduction='none',
        ).sum(-1)
        loss = loss.reshape((b, t)) * token_mask.reshape((b, t))
        loss = loss.clamp(max=10, min=0)
        if reduction == 'mean':
            n_tokens = token_mask.sum()
            # An all-zero mask would give 0 / 0 = NaN; every masked term is
            # already zero, so fall back to the (zero) sum instead.
            loss = loss.sum() / n_tokens if n_tokens > 0 else loss.sum()
        elif reduction == 'sum':
            loss = loss.sum()
        return loss

    @staticmethod
    def _finite_or_zero(name, value):
        """Return ``value`` if every element is finite, else a zero tensor like it."""
        if torch.isfinite(value).all():
            return value
        print(f'{name} is nan or inf')
        return torch.zeros_like(value)

    def __call__(self, batch, global_step=-1):
        """Compute the weighted training loss for one speech batch.

        Args:
            batch: dict whose 'inputs' entry is unpacked as keyword arguments
                into both models and must contain 'ast_ids'.
            global_step: unused; kept for interface compatibility.

        Returns:
            dict with 'loss' (weighted sum of the finite terms), 'log' (each
            term plus 'route_score' and 'loss_teacher') and 'bsz'
            (``ast_ids.shape[0]``).
        """
        inputs = batch['inputs']
        self.model.input_modality = "speech"
        speech_output = self.model(**inputs)
        # The teacher only supplies distillation targets; no graph is needed.
        with torch.inference_mode():
            teacher_output = self.teacher(**inputs)

        kd_loss = self._kld_loss(
            speech_output['logits_full'],
            teacher_output['logits'].detach(),
            speech_output['ast_mask'],
            reduction='mean',
        )

        raw_losses = {
            "loss_kd": kd_loss,
            "loss_simul": speech_output['loss_simul'],
            "loss_full": speech_output['loss_full'],
            "loss_trunc": speech_output['loss_trunc'],
            "loss_norm": speech_output['loss_norm'],
        }
        # Non-finite terms are zeroed so a single bad term does not poison the total.
        losses = {k: self._finite_or_zero(k, v) for k, v in raw_losses.items()}

        total_loss = sum(losses[k] * self.weights[k] for k in losses)

        # Logged only; these do not contribute to total_loss.
        losses['route_score'] = speech_output['route_score']
        losses['loss_teacher'] = teacher_output['loss']

        return {
            'loss': total_loss,
            'log': losses,
            'bsz': inputs['ast_ids'].shape[0],
        }
    