import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque

from common.modules.classifier import Classifier

class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.

    Args:
        num_classes (int): number of classes K used in the smoothing term.
        epsilon (float): smoothing weight.
        use_gpu (bool): kept for backward compatibility. The smoothed target
            tensor is always created on the same device as ``inputs``, so the
            flag no longer has to match where the inputs live.
        reduction (bool): if True return the batch mean, otherwise the
            per-sample losses.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.reduction = reduction
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground-truth class indices with shape (batch_size,)

        Returns:
            Scalar mean loss if ``self.reduction`` is truthy, otherwise a
            (batch_size,) tensor of per-sample losses.
        """
        log_probs = self.logsoftmax(inputs)
        # Build the one-hot matrix directly on the inputs' device. The previous
        # "CPU tensor, then .cuda() if use_gpu" approach crashed without CUDA,
        # picked the current GPU rather than the inputs' GPU, and mixed devices
        # when use_gpu=False but the inputs were on a GPU. The default float
        # dtype is kept, as before.
        one_hot = torch.zeros(log_probs.size(), device=log_probs.device).scatter_(
            1, targets.unsqueeze(1).to(log_probs.device), 1
        )
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        loss = (-smoothed * log_probs).sum(dim=1)
        return loss.mean() if self.reduction else loss

class ProtoLoss(nn.Module):
    """Prototype-to-feature transport loss.

    The loss mixes two terms, weighted by ``s_par``:

    * a "true" term: the expected cost, per target feature, under an
      assignment distribution normalized over prototypes (dim 0), averaged
      over features;
    * a "fake"/balance term computed against the current features plus up to
      ``n_cache`` previously seen (detached) feature batches.

    Args:
        nav_t (float): temperature dividing the similarity logits.
        s_par (float): weight of the true term; the balance term gets ``1 - s_par``.
        n_cache (int): number of past ``f_t`` batches kept for the balance term.
        assign_type (str): ``'prob'`` (softmax of dot products), ``'hard'``
            (one-hot argmax of dot products) or ``'cos'`` (softmax of cosine
            similarities).
        cost_type (str): ``'cos'`` (cosine distance), ``'prob'``
            (1 - softmax probability) or ``'logprob'`` (negative log probability).
        balance_type (str): ``'proto'`` (expected cosine cost with a softmax
            over samples, dim 1) or ``'entropy'`` (sum of p*log(p) of the
            sample-averaged prototype-assignment distribution).
    """

    def __init__(self, nav_t=1, s_par = 0.5, n_cache=2, assign_type='prob', cost_type='cos', balance_type='proto'):
        super(ProtoLoss, self).__init__()
        self.nav_t = nav_t
        self.s_par = s_par
        self.assign_type = assign_type
        self.cost_type = cost_type
        self.balance_type = balance_type
        self.cache = deque(iterable=[], maxlen=n_cache)

    def pairwise_cosine_dist(self, x, y):
        """Return the (len(x), len(y)) matrix of 1 - cosine similarity between rows."""
        x = F.normalize(x, p=2, dim=1)
        y = F.normalize(y, p=2, dim=1)
        return 1 - torch.matmul(x, y.T)

    def forward(self, mu_s, f_t, bias=None, norm=False):
        """Compute the loss for prototypes ``mu_s`` and target features ``f_t``.

        Args:
            mu_s: prototype matrix, (K, D).
            f_t: target feature matrix, (N, D).
            bias: optional per-prototype bias, presumably (K,); it is
                unsqueezed to (K, 1) and broadcast over samples.
            norm: if True, also return the cost-normalized true term and the
                mean per-sample maximum cost.

        Returns:
            ``loss``, or ``(loss, true_loss_normed, norm_weight)`` when ``norm``.

        Raises:
            ValueError: if ``assign_type``, ``cost_type`` or ``balance_type``
                is not a supported value (raised before the cache is updated).

        Side effect: appends ``f_t.detach()`` to ``self.cache`` on success.
        """
        bias = 0 if bias is None else bias.unsqueeze(1)
        # Dot-product logits, (K, N), shared by the 'prob'/'hard' assignments
        # and the 'prob'/'logprob' costs; computed once instead of per branch.
        dot_logits = torch.matmul(mu_s, f_t.T) / self.nav_t + bias

        if self.assign_type == 'prob':
            real_dist = F.softmax(dot_logits, dim=0)
        elif self.assign_type == 'hard':
            labels = torch.argmax(dot_logits, dim=0)
            # One-hot over prototypes: real_dist[labels[j], j] = 1.
            real_dist = torch.zeros(mu_s.size(0), f_t.size(0), device=labels.device).scatter_(
                0, labels.unsqueeze(0), 1
            )
        elif self.assign_type == 'cos':
            cos_sim = 1 - self.pairwise_cosine_dist(mu_s, f_t)
            real_dist = F.softmax(cos_sim / self.nav_t + bias, dim=0)
        else:
            raise ValueError(f"unknown assign_type: {self.assign_type!r}")

        if self.cost_type == 'cos':
            cost_mat = self.pairwise_cosine_dist(mu_s, f_t)
        elif self.cost_type == 'prob':
            cost_mat = 1 - F.softmax(dot_logits, dim=0)
        elif self.cost_type == 'logprob':
            # The 1e-6 offset guards against log(0).
            cost_mat = -torch.log(F.softmax(dot_logits, dim=0) + 1e-6)
        else:
            raise ValueError(f"unknown cost_type: {self.cost_type!r}")

        expected_cost = (cost_mat * real_dist).sum(0)  # (N,)
        true_loss = expected_cost.mean()
        max_cost = torch.max(cost_mat, dim=0)[0]  # (N,) worst-case cost per sample
        true_loss_normed = (expected_cost / max_cost).mean()
        norm_weight = max_cost.mean()

        # Balance term uses the current batch plus the cached past batches.
        f_cache = torch.cat(list(self.cache) + [f_t], 0) if self.cache else f_t
        cache_logits = torch.matmul(mu_s, f_cache.T) / self.nav_t + bias
        if self.balance_type == 'proto':
            fake_dist = F.softmax(cache_logits, dim=1)
            cache_cost = self.pairwise_cosine_dist(mu_s, f_cache)
            fake_loss = (cache_cost * fake_dist).sum(1).mean()
        elif self.balance_type == 'entropy':
            mean_prob = F.softmax(cache_logits, dim=0).mean(1)
            fake_loss = (mean_prob * torch.log(mean_prob + 1e-6)).sum()
        else:
            raise ValueError(f"unknown balance_type: {self.balance_type!r}")

        loss = self.s_par * true_loss + (1 - self.s_par) * fake_loss
        self.cache.append(f_t.detach())
        if norm:
            return loss, true_loss_normed, norm_weight
        return loss

class ProtoLoss_learnable(nn.Module):
    """Linear head whose weights double as prototypes for ``ProtoLoss``.

    With ``training=True`` the head is fit with cross entropy on source
    features. With ``training=False`` the head's weight and bias (taken via
    ``.data``, so no gradient reaches the head) are passed as prototypes and
    bias to a ``ProtoLoss`` evaluated on target features.

    Args:
        init_mu: optional initial weight matrix, (num_classes, features_dim).
            It is copied, so training the head does not modify the caller's tensor.
        init_bias: optional initial bias, presumably (num_classes,). It is
            copied as well, and only used together with ``init_mu``.
        num_classes, features_dim: head size when ``init_mu`` is None.
        nav_t, s_par, n_cache, assign_type, cost_type, balance_type: forwarded
            to ``ProtoLoss``.
    """

    def __init__(self, init_mu=None, init_bias=None, num_classes=None, features_dim=None, nav_t=1, s_par = 0.5, n_cache=2, assign_type='prob', cost_type='cos', balance_type='proto'):
        super(ProtoLoss_learnable, self).__init__()
        if init_mu is None:
            self.head = nn.Linear(features_dim, num_classes, bias=True)
        else:
            self.head = nn.Linear(init_mu.size(1), init_mu.size(0), bias=True)
            # Plain `.data = init_mu` would alias the caller's storage, so
            # in-place optimizer steps would silently overwrite init_mu/init_bias.
            # clone() keeps dtype and device while breaking the alias.
            self.head.weight.data = init_mu.detach().clone()
            if init_bias is not None:
                self.head.bias.data = init_bias.detach().clone()
        self.protoloss = ProtoLoss(nav_t, s_par, n_cache, assign_type, cost_type, balance_type=balance_type)

    def forward(self, g_s, f_s, g_t, f_t, label_s=None, label_t=None, d_label=None, training=True):
        """Return the head's cross-entropy loss, or the prototype loss.

        ``g_s``, ``g_t``, ``label_t`` and ``d_label`` are accepted for interface
        compatibility but are not used here.

        Args:
            f_s: source features fed to the head (used when ``training``).
            f_t: target features fed to ``ProtoLoss`` (used otherwise).
            label_s: class indices for ``f_s``.
            training: selects which loss is returned.
        """
        if training:
            logits = self.head(f_s)
            return F.cross_entropy(logits, label_s)
        return self.protoloss(self.head.weight.data, f_t, bias=self.head.bias.data)

class LinearCrossEntropy(nn.Module):
    """Linear classification head trained with cross entropy.

    Args:
        init_mu: optional initial weight matrix, (num_classes, features_dim).
            It is copied, so training the head does not modify the caller's tensor.
        init_bias: optional initial bias, presumably (num_classes,). It is
            copied as well, and only used together with ``init_mu``.
        num_classes, features_dim: head size when ``init_mu`` is None.
    """

    def __init__(self, init_mu=None, init_bias=None, num_classes=None, features_dim=None):
        super().__init__()
        if init_mu is None:
            self.head = nn.Linear(features_dim, num_classes, bias=True)
        else:
            self.head = nn.Linear(init_mu.size(1), init_mu.size(0), bias=True)
            # Plain `.data = init_mu` would alias the caller's storage, so
            # in-place optimizer steps would silently overwrite init_mu/init_bias.
            # clone() keeps dtype and device while breaking the alias.
            self.head.weight.data = init_mu.detach().clone()
            if init_bias is not None:
                self.head.bias.data = init_bias.detach().clone()

    def forward(self, g_s, f_s, g_t=None, f_t=None, label_s=None, label_t=None, d_label=None, training=True):
        """Classify source features with the linear head.

        ``g_s``, ``g_t``, ``f_t``, ``label_t`` and ``d_label`` are accepted for
        interface compatibility but are not used here.

        Args:
            f_s: source features, (batch_size, features_dim).
            label_s: ground-truth class indices for ``f_s``, (batch_size,).
            training: if True return the cross-entropy loss, otherwise accuracy.

        Returns:
            Scalar cross-entropy loss when ``training``, otherwise the scalar
            fraction of samples whose argmax prediction equals ``label_s``.
        """
        logits = self.head(f_s)
        if training:
            return F.cross_entropy(logits, label_s)
        return (logits.argmax(dim=1) == label_s).float().mean()