from typing import *
import torch
import math
import torch.nn.functional as F
from transformers import Trainer
from ...extras.constants import IGNORE_INDEX

from torch.nn.modules.loss import _WeightedLoss
import torch.nn as nn

def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
    """Cross-entropy that can be normalised by an externally supplied item count.

    Args:
        source: Predictions, forwarded as the ``input`` argument of
            ``torch.nn.functional.cross_entropy``.
        target: Targets, forwarded to ``torch.nn.functional.cross_entropy``.
        num_items_in_batch: When given, the loss is summed and then divided by
            this value. When ``None``, the built-in ``"mean"`` reduction is
            used instead.
        ignore_index: Forwarded unchanged to ``cross_entropy``.
        **kwargs: Accepted and ignored.

    Returns:
        The reduced loss tensor.
    """
    # No explicit count: let PyTorch average the loss itself.
    if num_items_in_batch is None:
        return nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="mean")

    # Explicit count: sum first, then normalise by the caller-provided count.
    summed = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="sum")
    return summed / num_items_in_batch

class DeCE(_WeightedLoss):
    """DeCE loss: label-smoothed cross-entropy on an epoch-blended input.

    For every position whose label is not ``ignore_index`` the loss

    1. builds a smoothed one-hot target distribution,
    2. blends the model output with that target using
       ``alpha = alpha_base ** cur_epoch``, and
    3. takes the cross-entropy between the blend and the smoothed target,
       averaged over the kept positions.
    """

    def __init__(self, weight: Optional[torch.Tensor] = None, size_average=None, ignore_index: int = None,
                 reduce=None, reduction: str = 'mean', label_smoothing: float = 0.05,
                 alpha_base: float = 0.985) -> None:
        """Configure the loss.

        Args:
            weight: Forwarded to ``_WeightedLoss``; not used by ``forward``.
            size_average: Forwarded to ``_WeightedLoss``; not used by ``forward``.
            ignore_index: Label value marking positions that must not
                contribute to the loss. ``None`` falls back to the project-wide
                ``IGNORE_INDEX``. Labels must use this same value for masked
                positions.
            reduce: Forwarded to ``_WeightedLoss``; not used by ``forward``.
            reduction: Forwarded to ``_WeightedLoss``; ``forward`` always
                averages over the kept positions.
            label_smoothing: Probability mass spread over the non-target
                classes; must satisfy ``0 <= label_smoothing < 1``.
            alpha_base: Base of the blending coefficient
                ``alpha = alpha_base ** cur_epoch``.
        """
        super().__init__(weight, size_average, reduce, reduction)
        self.ignore_index = IGNORE_INDEX if ignore_index is None else ignore_index
        self.label_smoothing = label_smoothing
        self.alpha = 1
        self.alpha_base = alpha_base

    @staticmethod
    def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing=0.0,
                        ignore_index: Optional[int] = None):
        """Turn class indices into label-smoothed probability rows.

        Args:
            targets: 1-D integer tensor of class indices.
            n_classes: Number of classes (width of the returned matrix).
            smoothing: Total probability mass moved away from the target class.
            ignore_index: Label value to skip; ``None`` means ``IGNORE_INDEX``.

        Returns:
            A ``(len(targets), n_classes)`` tensor. Rows of kept targets hold
            ``1 - smoothing`` at the target index and
            ``smoothing / (n_classes - 1)`` elsewhere. Rows of ignored targets
            are all zeros so they carry no probability mass.

        Raises:
            ValueError: If ``smoothing`` is outside ``[0, 1)``.
        """
        if not 0 <= smoothing < 1:
            raise ValueError(f"smoothing must be in [0, 1), got {smoothing}")
        if ignore_index is None:
            ignore_index = IGNORE_INDEX

        with torch.no_grad():
            ignored = targets == ignore_index
            smooth_target = torch.full(
                (targets.size(0), n_classes),
                smoothing / (n_classes - 1),
                device=targets.device,
            )
            # Ignored labels are not valid column indices; point them at
            # column 0 for the scatter and wipe those rows right afterwards.
            safe_index = targets.masked_fill(ignored, 0).unsqueeze(1)
            smooth_target.scatter_(1, safe_index, 1.0 - smoothing)
            smooth_target.masked_fill_(ignored.unsqueeze(1), 0.0)
        return smooth_target

    def forward(self, input: torch.Tensor, target: torch.Tensor, cur_epoch: float) -> torch.Tensor:
        """Compute the loss for one batch.

        Args:
            input: Model outputs whose last dimension indexes the classes;
                leading dimensions are flattened.
            target: Integer class indices, one per flattened row of ``input``.
                Entries equal to ``self.ignore_index`` are skipped.
            cur_epoch: Exponent applied to ``alpha_base``.

        Returns:
            Scalar loss averaged over the non-ignored positions, or a zero that
            is still attached to ``input``'s graph when every position is
            ignored.
        """
        self.alpha = math.pow(self.alpha_base, cur_epoch)

        flat_input = input.reshape(-1, input.size(-1))
        flat_target = target.reshape(-1)

        # cross_entropy does not honour ignore_index for class-probability
        # targets, so ignored rows have to be removed before the loss is taken.
        # Otherwise they add loss and inflate the denominator of the mean.
        keep = flat_target != self.ignore_index
        if not bool(keep.any()):
            # Nothing to learn from; avoid a NaN mean over zero rows while
            # keeping the result differentiable.
            return flat_input.sum() * 0.0

        kept_input = flat_input[keep]
        kept_target = flat_target[keep]

        new_target = DeCE._smooth_one_hot(
            kept_target, kept_input.size(-1), self.label_smoothing, self.ignore_index
        )
        # NOTE(review): the blend adds a probability vector to whatever `input`
        # holds. If callers pass raw logits, this mixes logits with
        # probabilities. Confirm this is intended.
        new_input = self.alpha * kept_input + (1 - self.alpha) * new_target
        # Probability targets must share the input's dtype.
        new_target = new_target.to(new_input.dtype)

        return nn.functional.cross_entropy(new_input, new_target, reduction="mean")

class CausalDeCETrainer(Trainer):
    """``Trainer`` subclass that trains a causal LM with the ``DeCE`` loss.

    The model's own loss is bypassed: labels are removed from the inputs, the
    logits and labels are shifted by one position for next-token prediction,
    and the loss is computed by ``self.lossFn``.
    """

    def __init__(
        self,
        finetuning_args,
        processor,
        label_smoothing: Optional[float] = 0.05,
        alpha_base: Optional[float] = 0.99,
        ignore_index: int = IGNORE_INDEX,
        **kwargs
    ):
        """Build the trainer and its loss module.

        Args:
            finetuning_args: Stored on the instance as ``finetuning_args``.
            processor: Stored on the instance as ``processor``.
            label_smoothing: Forwarded to ``DeCE``.
            alpha_base: Forwarded to ``DeCE``.
            ignore_index: Forwarded to ``DeCE``.
            **kwargs: Forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(**kwargs)
        self.processor = processor
        self.finetuning_args = finetuning_args
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        self.lossFn = DeCE(label_smoothing=label_smoothing, alpha_base=alpha_base, ignore_index=ignore_index)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the DeCE loss for one batch.

        Args:
            model: Model to run; called as ``model(**inputs)`` after the labels
                have been removed.
            inputs: Batch mapping; its ``"labels"`` entry is popped (the
                mapping is mutated).
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just ``loss``.
            num_items_in_batch: Accepted to match the ``Trainer`` call
                signature; not used.

        Returns:
            ``loss`` or ``(loss, outputs)`` depending on ``return_outputs``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        # Take the class dimension from the logits themselves: it can differ
        # from (or be absent on) `model.config.vocab_size`, which would break
        # or misalign the reshape.
        shift_logits = logits[..., :-1, :].reshape(-1, logits.size(-1))
        shift_labels = labels[..., 1:].reshape(-1).to(shift_logits.device)

        # `state.epoch` is None until training has started (e.g. a standalone
        # evaluation run); treat that as epoch 0.
        epoch = self.state.epoch if self.state.epoch is not None else 0
        loss = self.lossFn(shift_logits, shift_labels, epoch + 1)

        return (loss, outputs) if return_outputs else loss