
from typing import Tuple, List, Dict

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from resnet import ResNet as torchvisionResNet
from resnet import Bottleneck


def torch_accuracy(
    output: torch.Tensor, target: torch.Tensor, topk=(1,)
) -> List[float]:
    """Compute top-k accuracy percentages for a batch of predictions.

    Args:
        output: Score tensor. The k best entries are taken along dim 1, and
            dim 0 is treated as the batch dimension.
        target: Ground-truth class indices. They are flattened and compared
            against every sample's ranked predictions.
        topk: The k values to evaluate. Must be non-empty.

    Returns:
        One float per entry of ``topk``. Each is the percentage (0-100) of
        samples whose target index is among their k highest-scoring entries.
    """
    largest_k = max(topk)
    n_samples = output.size(0)

    # Indices of the largest_k best scores per sample (sorted, descending),
    # transposed so that row r holds every sample's r-th best guess.
    ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    scale = 100.0 / n_samples
    return [
        hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale).item()
        for k in topk
    ]


class LabelSmoothCrossEntropyLoss(_Loss):
    """Cross-entropy loss against label-smoothed one-hot targets.

    For every row, the smoothed target is
    ``(1 - eps) * one_hot(target) + eps / class_num``. The loss is the batch
    mean of ``-sum(smoothed * log_softmax(pred, dim=1), dim=1)``.
    """

    def __init__(self, eps=0.1, class_num=1000):
        """
        Args:
            eps: Smoothing mass. The true class keeps ``1 - eps``, and every
                class additionally receives ``eps / class_num``.
            class_num: Number of classes the smoothing mass is spread over.
                It should equal ``pred.size(1)``; otherwise the smoothed target
                does not sum to one. Pass ``None`` to use ``pred.size(1)`` on
                each call.
        """
        super(LabelSmoothCrossEntropyLoss, self).__init__()

        # Per-class smoothing value. None means "derive from pred at call time".
        self.min_value = eps / class_num if class_num is not None else None
        self.eps = eps

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean label-smoothed cross-entropy over the batch.

        Defined as ``forward`` (not ``__call__``) so that ``nn.Module.__call__``
        still runs registered hooks around it.

        Args:
            pred: Unnormalized scores. Softmax is taken over dim 1.
            target: Integer class indices. Either 1-D with one index per row
                of ``pred``, or a 2-D index tensor usable by ``scatter_``
                along dim 1. ``scatter_`` requires an int64 index tensor.

        Returns:
            A scalar tensor with the mean loss.
        """
        log_probs = F.log_softmax(pred, dim=1)

        if target.ndimension() == 1:
            # (N,) -> (N, 1) so the indices address dim 1 of log_probs.
            target = target.unsqueeze(1)

        min_value = self.min_value
        if min_value is None:
            min_value = self.eps / pred.size(1)

        # One-hot built directly in log_probs' dtype. No forced float32 cast,
        # which previously down-cast float64 targets.
        smoothed = torch.zeros_like(log_probs).scatter_(1, target, 1)
        smoothed = smoothed * (1 - self.eps) + min_value

        per_sample = torch.sum(-log_probs * smoothed, dim=1)
        return torch.mean(per_sample)


class ResNet(torchvisionResNet):
    """``resnet.ResNet`` whose forward pass also computes the loss and error rates.

    NOTE: ``loss_func`` is the first positional parameter. Construct as
    ``ResNet(loss, block, layers, ...)``. The remaining positional and keyword
    arguments are forwarded unchanged to the base-class constructor.
    """

    def __init__(self, loss_func=None, *args, **kwargs):
        """
        Args:
            loss_func: Callable ``(pred, label) -> scalar tensor``. When omitted
                or ``None``, a fresh ``LabelSmoothCrossEntropyLoss()`` is created
                for this model. The default is no longer one instance shared by
                every model.
            *args, **kwargs: Forwarded to the base ``ResNet`` constructor.

        Raises:
            TypeError: If ``loss_func`` is not callable.
        """
        if loss_func is None:
            loss_func = LabelSmoothCrossEntropyLoss()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not callable(loss_func):
            raise TypeError(
                f"loss_func must be callable, got {type(loss_func).__name__}"
            )

        super(ResNet, self).__init__(*args, **kwargs)
        # Assigned after the base __init__, because nn.Module must be
        # initialized before submodules can be registered as attributes.
        self.loss_func = loss_func

    def forward(self, x, label) -> Tuple[Dict[str, torch.Tensor], Dict[str, float]]:
        """Run the network and return the loss plus scalar logging metrics.

        Args:
            x: Input batch passed to the base ResNet's ``forward``.
            label: Ground-truth class indices, passed to both ``loss_func`` and
                ``torch_accuracy``.

        Returns:
            ``(loss_dict, output_dict)``:
            - ``loss_dict["loss"]`` is the tensor returned by ``loss_func``.
            - ``output_dict`` holds Python floats: ``"Loss"``, plus the top-1
              and top-5 error percentages ``"Err1"`` and ``"Err5"``.
        """
        pred = super(ResNet, self).forward(x)
        loss = self.loss_func(pred, label)
        # topk=(1, 5) assumes the model outputs at least 5 classes.
        top1, top5 = torch_accuracy(pred, label, topk=(1, 5))
        loss_dict = {"loss": loss}
        output_dict = {"Loss": loss.item(), "Err1": 100 - top1, "Err5": 100 - top5}

        return loss_dict, output_dict


def create_network():
    """Build the model to train.

    The model is a ``ResNet`` using ``Bottleneck`` blocks in a ``[3, 4, 6, 3]``
    stage layout, with label-smoothed cross-entropy (eps=0.1) as its loss.
    """
    # Presumably the ResNet-50 configuration (as in torchvision). Confirm
    # against the local ``resnet`` module.
    stage_depths = [3, 4, 6, 3]
    criterion = LabelSmoothCrossEntropyLoss(eps=0.1)
    return ResNet(criterion, Bottleneck, stage_depths)
