from __future__ import absolute_import
from __future__ import print_function

import logging

import math
import torch
import torch.nn as nn
from torch.nn import functional as F


def _accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1))
    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class ClassificationLoss(nn.Module):
    """Mean cross-entropy loss that also reports top-k accuracies.

    Args:
        args: Accepted for interface compatibility; not used.
        topk: k values for which top-k accuracy (percent) is reported.
        size_average: Accepted for interface compatibility; not used.
        reduce: Accepted for interface compatibility; not used.
    """

    def __init__(self, args, topk=(1, 2, 3), size_average=True, reduce=True):
        super(ClassificationLoss, self).__init__()
        self._cross_entropy = torch.nn.CrossEntropyLoss(reduction='mean')
        self._topk = topk

    def forward(self, output_dict, target_dict):
        """Return ``{'xe': cross-entropy, 'top<k>': accuracy, ...}``.

        Reads logits from ``output_dict['output1']`` and labels from
        ``target_dict['target1']``.
        """
        logits = output_dict['output1']
        labels = target_dict['target1']
        result = {"xe": self._cross_entropy(logits, labels)}
        # Accuracies are diagnostics only, so they stay out of the autograd graph.
        with torch.no_grad():
            accuracies = _accuracy(logits, labels, topk=self._topk)
            result.update(
                ("top%i" % k, acc) for acc, k in zip(accuracies, self._topk)
            )
        return result


class ClassificationLossNCVIBN(nn.Module):
    """Multi-sample classification loss with annealed VIB, noise and decoder-KLD terms.

    The auxiliary weights are ramped linearly from 0 to their configured values
    over ``annealed_iterations`` forward passes made in training mode.

    Args:
        topk: k values for which top-k accuracy (percent) is reported.
        nce_weight: Final weight (``beta``) of the ``l_noise`` term.
        vibn_weight: Final weight (``alpha``) of ``output_dict['vibn_loss']``.
        decoder_kld_weight: Final weight (``delta``) of
            ``output_dict['decoder_kld']``.
        annealed_iterations: Length of the warm-up in training forward
            passes; values <= 0 disable the warm-up.
    """

    def __init__(self, topk=(1,), nce_weight=1.0, vibn_weight=1e0, decoder_kld_weight=0.0, annealed_iterations=1):
        super(ClassificationLossNCVIBN, self).__init__()
        self._topk = topk
        # NOTE(review): plain attribute, not a registered buffer, so it is not
        # saved in state_dict(); the warm-up restarts after reloading a checkpoint.
        self._iteration_counter = 0
        self._alpha = vibn_weight
        self._beta = nce_weight
        self._delta = decoder_kld_weight
        self._annealed_iterations = annealed_iterations

    def forward(self, output_dict, target_dict):
        """Compute the weighted loss terms and accuracy diagnostics.

        Args:
            output_dict: Must provide ``'prediction'`` (logits with
                ``batch * samples`` rows, the rows of each example adjacent),
                ``'vibn_loss'``, ``'l_noise'`` (logits over dim 1) and
                ``'decoder_kld'``.
            target_dict: Must provide ``'target1'``, integer labels with
                ``batch`` elements.

        Returns:
            Dict with the weighted ``'vibn_loss'``, ``'nce_loss'`` and
            ``'decoder_kld'``; ``'total_loss'`` (those plus the per-sample
            cross-entropy); and the no-grad diagnostics ``'xe'`` and
            ``'top<k>'``.

        Raises:
            ValueError: If ``target1`` is empty or the number of prediction
                rows is not a multiple of its length.
        """
        losses = {}
        # Order matters: _get_alpha advances the iteration counter in training
        # mode, and beta/delta read the updated counter.
        alpha = self._get_alpha()
        beta = self._get_beta()
        delta = self._get_delta()

        target = target_dict['target1']
        prediction = output_dict['prediction']
        batch_size = target.shape[0]
        if batch_size == 0 or prediction.shape[0] % batch_size != 0:
            raise ValueError(
                "prediction has %d rows, which is not a positive multiple of "
                "the %d targets" % (prediction.shape[0], batch_size))
        samples = prediction.shape[0] // batch_size
        # Each label is repeated for its `samples` consecutive prediction rows.
        target_expanded = target.repeat_interleave(samples, dim=0)

        # Mean cross-entropy over all batch * samples rows.
        loss_xe = F.cross_entropy(prediction, target_expanded)
        losses['vibn_loss'] = alpha * output_dict['vibn_loss']

        # sum(p * log p) + log(C) = log(C) - H(p) = KL(p || uniform) >= 0.
        l_noise = output_dict['l_noise']
        neg_entropy = (F.softmax(l_noise, dim=1) * F.log_softmax(l_noise, dim=1)).sum(dim=1)
        losses['nce_loss'] = beta * (neg_entropy + math.log(l_noise.shape[1])).mean()

        losses['decoder_kld'] = delta * output_dict['decoder_kld']

        losses['total_loss'] = loss_xe + losses['vibn_loss'] + losses['decoder_kld'] + losses['nce_loss']

        with torch.no_grad():
            # Cross-entropy of the sample-averaged predictive distribution
            # (unlike loss_xe, which averages per-sample cross-entropies).
            probs = F.softmax(prediction.reshape(batch_size, samples, -1), dim=2).mean(dim=1)
            rows = torch.arange(batch_size, device=probs.device)
            # The epsilon keeps log() finite if a probability underflows to 0.
            losses['xe'] = -torch.log(probs[rows, target] + 1e-24).mean()
            acc_k = _accuracy(probs, target, topk=self._topk)
            for acc, k in zip(acc_k, self._topk):
                losses["top%i" % k] = acc
        return losses

    def _annealing_factor(self):
        """Return the linear warm-up factor ``min(counter / annealed_iterations, 1)``."""
        if self._annealed_iterations <= 0:
            # Warm-up disabled: use the full weights immediately.
            return 1.0
        # float() keeps this a true division under Python 2 as well.
        return min(float(self._iteration_counter) / self._annealed_iterations, 1.0)

    def _get_alpha(self):
        """Advance the counter (training mode only) and return the annealed VIB weight."""
        if self.training:
            self._iteration_counter += 1
        return self._alpha * self._annealing_factor()

    def _get_beta(self):
        """Return the annealed weight of the ``l_noise`` term."""
        return self._beta * self._annealing_factor()

    def _get_delta(self):
        """Return the annealed weight of the decoder KLD term."""
        return self._delta * self._annealing_factor()


class ClassificationLossVIBN(nn.Module):
    """Multi-sample classification loss with annealed VIB and noise terms.

    The auxiliary weights are ramped linearly from 0 to their configured values
    over ``annealed_iterations`` forward passes made in training mode.

    Args:
        topk: k values for which top-k accuracy (percent) is reported.
        nce_weight: Final weight (``beta``) of the ``l_noise`` term.
        vibn_weight: Final weight (``alpha``) of ``output_dict['vibn_loss']``.
        annealed_iterations: Length of the warm-up in training forward
            passes; values <= 0 disable the warm-up.
    """

    def __init__(self, topk=(1,), nce_weight=1.0, vibn_weight=1e0, annealed_iterations=1):
        super(ClassificationLossVIBN, self).__init__()
        self._topk = topk
        # NOTE(review): plain attribute, not a registered buffer, so it is not
        # saved in state_dict(); the warm-up restarts after reloading a checkpoint.
        self._iteration_counter = 0
        self._alpha = vibn_weight
        self._beta = nce_weight
        self._annealed_iterations = annealed_iterations

    def forward(self, output_dict, target_dict):
        """Compute the weighted loss terms and accuracy diagnostics.

        Args:
            output_dict: Must provide ``'prediction'`` (logits with
                ``batch * samples`` rows, the rows of each example adjacent),
                ``'vibn_loss'`` and ``'l_noise'`` (logits over dim 1).
            target_dict: Must provide ``'target1'``, integer labels with
                ``batch`` elements.

        Returns:
            Dict with the weighted ``'vibn_loss'`` and ``'nce_loss'``;
            ``'total_loss'`` (those plus the per-sample cross-entropy); and
            the no-grad diagnostics ``'xe'`` and ``'top<k>'``.

        Raises:
            ValueError: If ``target1`` is empty or the number of prediction
                rows is not a multiple of its length.
        """
        losses = {}
        # Order matters: _get_alpha advances the iteration counter in training
        # mode, and beta reads the updated counter.
        alpha = self._get_alpha()
        beta = self._get_beta()

        target = target_dict['target1']
        prediction = output_dict['prediction']
        batch_size = target.shape[0]
        if batch_size == 0 or prediction.shape[0] % batch_size != 0:
            raise ValueError(
                "prediction has %d rows, which is not a positive multiple of "
                "the %d targets" % (prediction.shape[0], batch_size))
        samples = prediction.shape[0] // batch_size
        # Each label is repeated for its `samples` consecutive prediction rows.
        target_expanded = target.repeat_interleave(samples, dim=0)

        # Mean cross-entropy over all batch * samples rows.
        loss_xe = F.cross_entropy(prediction, target_expanded)
        losses['vibn_loss'] = alpha * output_dict['vibn_loss']

        # sum(p * log p) + log(C) = log(C) - H(p) = KL(p || uniform) >= 0.
        l_noise = output_dict['l_noise']
        neg_entropy = (F.softmax(l_noise, dim=1) * F.log_softmax(l_noise, dim=1)).sum(dim=1)
        losses['nce_loss'] = beta * (neg_entropy + math.log(l_noise.shape[1])).mean()

        losses['total_loss'] = loss_xe + losses['vibn_loss'] + losses['nce_loss']

        with torch.no_grad():
            # Cross-entropy of the sample-averaged predictive distribution
            # (unlike loss_xe, which averages per-sample cross-entropies).
            probs = F.softmax(prediction.reshape(batch_size, samples, -1), dim=2).mean(dim=1)
            rows = torch.arange(batch_size, device=probs.device)
            # The epsilon keeps log() finite if a probability underflows to 0.
            losses['xe'] = -torch.log(probs[rows, target] + 1e-24).mean()
            acc_k = _accuracy(probs, target, topk=self._topk)
            for acc, k in zip(acc_k, self._topk):
                losses["top%i" % k] = acc
        return losses

    def _annealing_factor(self):
        """Return the linear warm-up factor ``min(counter / annealed_iterations, 1)``."""
        if self._annealed_iterations <= 0:
            # Warm-up disabled: use the full weights immediately.
            return 1.0
        # float() keeps this a true division under Python 2 as well.
        return min(float(self._iteration_counter) / self._annealed_iterations, 1.0)

    def _get_alpha(self):
        """Advance the counter (training mode only) and return the annealed VIB weight."""
        if self.training:
            self._iteration_counter += 1
        return self._alpha * self._annealing_factor()

    def _get_beta(self):
        """Return the annealed weight of the ``l_noise`` term."""
        return self._beta * self._annealing_factor()