# pylint: disable=C0114, C0301, R0913, C0303, C0115, C0116, R0914, E0402
import torch
import torch.nn.functional as F


class Criterion:
    """Weighted multi-branch training loss for a speech model.

    Runs ``model`` on a batch in ``"speech"`` input mode and combines the
    losses it reports (``loss_norm``, ``loss_kd`` and the ``simul`` /
    ``full`` / ``trunc`` branch losses) into one scalar using per-loss
    weights.
    """

    # Default weight of each loss term. Copied per instance; do not mutate.
    DEFAULT_WEIGHTS = {
        "loss_simul": 0.2,
        "loss_full": 1,
        "loss_trunc": 0.2,
        "loss_kd": 0.2,
        "loss_norm": 0.01,
    }
    # Branch names whose 'loss' / 'acc' entries are read from the model output.
    BRANCHES = ('simul', 'full', 'trunc')

    def __init__(self, model, weights=None, simul_threshold=4.6) -> None:
        """
        Args:
            model: callable taking the batch and returning a dict of outputs.
                Its ``input_modality`` attribute is set to ``"speech"`` on
                every call.
            weights: optional overrides for ``DEFAULT_WEIGHTS``. Keys that
                are not given keep their default weight.
            simul_threshold: during training, if ``loss_simul`` exceeds this
                value, the ``loss_trunc`` weight is set to 0 for that step.
        """
        self.model = model
        self.weights = {**self.DEFAULT_WEIGHTS, **(weights or {})}
        self.simul_threshold = simul_threshold

    def _kld_loss(self, logits_stu, logits_tea, token_mask, reduction):
        """Masked per-token KL(teacher || student) distillation loss.

        Not called by ``__call__``; kept as a helper for distillation.

        Args:
            logits_stu: student logits of shape ``(b, t, V)``.
            logits_tea: teacher logits; must flatten to ``b * t`` rows over
                the same vocabulary as the student.
            token_mask: mask reshapeable to ``(b, t)``; positions where it is
                0 contribute nothing.
            reduction: ``'mean'`` (masked sum divided by ``token_mask.sum()``),
                ``'sum'``, or anything else to return the ``(b, t)`` tensor.

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'``, otherwise a ``(b, t)``
            tensor. Each token's loss is clamped to ``[0, 10]`` after masking.
        """
        b, t, _ = logits_stu.shape
        # Upcast to float32 (e.g. from half precision) before the softmaxes.
        logits_stu = logits_stu.reshape((-1, logits_stu.shape[-1])).float()
        logits_tea = logits_tea.reshape((-1, logits_tea.shape[-1])).float()
        # F.kl_div(log q, p) computes p * (log p - log q) element-wise, i.e.
        # KL(p || q) with p = teacher, q = student; summed over the vocabulary.
        loss = F.kl_div(
            F.log_softmax(logits_stu, dim=-1),
            F.softmax(logits_tea, dim=-1),
            reduction='none',
        ).sum(-1)
        loss = loss.reshape((b, t)) * token_mask.reshape((b, t))
        loss = loss.clamp(min=0, max=10)
        if reduction == 'mean':
            # Guard against an all-zero mask (0 / 0 = NaN). For a 0/1 mask the
            # sum is either 0 or >= 1, so this only changes the all-masked case.
            return loss.sum() / token_mask.sum().clamp(min=1)
        if reduction == 'sum':
            return loss.sum()
        return loss

    def __call__(self, batch, global_step=-1, train=True):
        """Compute the weighted training loss for one batch.

        Args:
            batch: dict passed to ``self.model``. ``batch['text'].shape[0]``
                is reported as the batch size.
            global_step: unused; kept for interface compatibility.
            train: when True, ``loss_trunc`` gets weight 0 for this step if
                ``loss_simul`` is above ``self.simul_threshold``.

        Returns:
            dict with ``'loss'`` (weighted sum of the individual losses),
            ``'log'`` (individual losses, per-branch accuracies,
            ``route_score`` and ``loss_full_selected``) and ``'bsz'``.
        """
        self.model.input_modality = "speech"
        output = self.model(batch)

        losses = {
            "loss_norm": output['loss_norm'],
            "loss_kd": output['loss_kd'],
        }
        for k in self.BRANCHES:
            losses[f'loss_{k}'] = output[k]['loss']

        # Per-call copy so the trunc gating below never persists across steps.
        weights = dict(self.weights)
        if train and losses['loss_simul'] > self.simul_threshold:
            weights['loss_trunc'] = 0.0

        # Replace NaN/inf terms with a constant 0 (which drops that term's
        # gradient) so a single bad term does not poison the total loss.
        for k, v in losses.items():
            if not torch.isfinite(v):
                print(f'{k} is non-finite ({v.item()}), replacing with 0')
                losses[k] = torch.tensor(0.0, device=v.device)

        total_loss = sum(losses[k] * weights[k] for k in losses)

        # Separate dict: log-only entries must not leak into `losses`.
        logs = dict(losses)
        for k in self.BRANCHES:
            logs[f'acc_{k}'] = output[k]['acc']
        logs['route_score'] = output['route_score']
        logs['loss_full_selected'] = output['loss_full_selected']

        return {
            'loss': total_loss,
            'log': logs,
            'bsz': batch['text'].shape[0],
        }
    