import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import math
import torch.nn.functional as F
import pdb
import utils_func
import ot
from typing import Optional

def Entropy(input_, epsilon=1e-5):
    """Return the per-row entropy ``-sum(p * log(p + epsilon))`` along dim 1.

    Args:
        input_: tensor with at least two dimensions; the sum is taken over
            dim 1. Presumably rows are probability vectors (e.g. softmax
            output) -- the function itself does not check or normalise.
        epsilon: small constant added inside the logarithm so that zero
            entries do not produce ``log(0)``. Defaults to the previously
            hard-coded ``1e-5``.

    Returns:
        Tensor with dim 1 reduced away, holding one entropy value per row.
    """
    pointwise = -input_ * torch.log(input_ + epsilon)
    return torch.sum(pointwise, dim=1)

class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.
    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.
    Args:
        num_classes (int): number of classes (K in the equation above).
        epsilon (float): smoothing weight.
        use_gpu (bool): retained for backward compatibility only. The smoothed
            targets are now always built on the device of ``inputs``, so this
            flag no longer influences the computation.
        reduction (bool): if True return the batch mean, otherwise return the
            per-sample loss vector.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.reduction = reduction
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth class indices with shape (batch_size)
        Returns:
            Scalar mean loss when ``self.reduction`` is truthy, otherwise a
            tensor of shape (batch_size) with one loss value per sample.
        """
        log_probs = self.logsoftmax(inputs)
        device = log_probs.device
        # Build the one-hot matrix directly on the logits' device. The previous
        # CPU round trip plus unconditional ``.cuda()`` failed whenever
        # ``use_gpu`` disagreed with where ``inputs`` actually lived.
        one_hot = torch.zeros(log_probs.size(), device=device)
        one_hot.scatter_(1, targets.unsqueeze(1).to(device), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        loss = (- smoothed * log_probs).sum(dim=1)
        if self.reduction:
            return loss.mean()
        return loss

class ProtoMultiDomainLoss(nn.Module):
    """Optimal-transport loss between per-domain features and prototypes.

    For each domain present in a batch, ``forward`` builds a cosine-distance
    cost matrix between that domain's features and the prototypes, solves a
    transport problem with the ``ot`` package (uniform mass over the samples,
    the domain's running class proportions over the prototypes) and adds
    ``sum(T * cost)``. The result is averaged over the domains that had at
    least one sample.

    Args:
        num_domains: domain ids ``0 .. num_domains - 1`` are iterated.
        num_classes: length of every class-proportion vector.
        ot_reg: regularisation handed to ``ot.sinkhorn``; a falsy value
            (e.g. ``0``) selects ``ot.emd`` instead.
        args: object exposing ``beta``, ``epochs`` and ``iters_per_epoch``,
            which are forwarded to ``utils_func.cosine_scheduler``.
    """

    def __init__(self, num_domains: int, num_classes: int, ot_reg: float, args):
        super(ProtoMultiDomainLoss, self).__init__()
        self.num_domains = num_domains
        self.num_classes = num_classes
        self.ot_reg = ot_reg
        # Running class proportions, one row per domain, initialised uniform.
        # This is a plain tensor attribute (not a registered buffer), so it
        # stays on CPU and is not part of ``state_dict``.
        self.global_prop = F.normalize(torch.ones((num_domains, num_classes)), p=1, dim=1)
        # Per-iteration momentum values, indexed by ``cur_iter`` in ``forward``.
        # NOTE(review): presumably ramps from args.beta to 1 over
        # epochs * iters_per_epoch steps -- confirm against utils_func.
        self.momentum_schedule = utils_func.cosine_scheduler(args.beta, 1, args.epochs, args.iters_per_epoch)
        self.cur_iter = 0
        print("NUM_DOMAINS", self.num_domains) 
        print("NUM_CLASSES", self.num_classes)

    def get_local_prop(self, labels_d, num_classes):
        """Return the L1-normalised histogram of ``labels_d`` over ``num_classes`` bins.

        Args:
            labels_d: 1-D integer tensor of class indices in ``[0, num_classes)``.
            num_classes: number of histogram bins.

        Returns:
            Float tensor of length ``num_classes`` on ``labels_d``'s device.
            All zeros when ``labels_d`` is empty.
        """
        counts = torch.zeros(num_classes).to(labels_d.device)
        # One vectorised accumulation instead of a Python loop over labels;
        # index_add_ sums repeated indices, so duplicates are counted.
        increments = torch.ones(labels_d.shape[0], device=labels_d.device)
        counts.index_add_(0, labels_d.long(), increments)
        return F.normalize(counts, p=1, dim=0)

    def pairwise_cosine_dist(self, x, y):
        """Return the matrix ``1 - cos(x_i, y_j)`` for the rows of ``x`` and ``y``."""
        x = F.normalize(x, p=2, dim=1)
        y = F.normalize(y, p=2, dim=1)
        return 1 - torch.matmul(x, y.T)

    def normalize(self, x):
        """Divide ``x`` by its sum along the first axis (array or tensor)."""
        # ``x.sum(0)`` matches what the builtin ``sum(x)`` produced (a
        # reduction over the first axis) without a Python-level loop.
        return x / x.sum(0)

    def forward(self, prototypes: torch.Tensor, f: torch.Tensor, labels: torch.Tensor, domain_ids: torch.Tensor) -> torch.Tensor:
        """Compute the transport loss averaged over the domains in the batch.

        Side effects: updates ``self.global_prop`` for every domain present
        and advances ``self.cur_iter`` by one per call.

        Args:
            prototypes: 2-D tensor of prototype vectors. Presumably one row
                per class, since the column marginal has ``num_classes``
                entries -- confirm against the caller.
            f: 2-D tensor of features, one row per sample.
            labels: class index of every sample.
            domain_ids: domain index of every sample.

        Returns:
            Scalar tensor. A zero scalar (with no gradient history) when no
            sample belongs to any of the ``num_domains`` domains; previously
            this case raised ``ZeroDivisionError``.
        """
        total = None
        active_domains = 0
        for domain in range(self.num_domains):
            mask = domain_ids == domain
            if not mask.any():
                continue
            f_d = f[mask]
            labels_d = labels[mask]
            local_prop = self.get_local_prop(labels_d, self.num_classes).detach().cpu()
            # NOTE(review): raises IndexError once cur_iter runs past the end
            # of the schedule; confirm the schedule covers all training steps.
            cur_beta = self.momentum_schedule[self.cur_iter]
            # Momentum update of this domain's running class proportions.
            self.global_prop[domain] = cur_beta * self.global_prop[domain] + (1-cur_beta) * local_prop
            costs = self.pairwise_cosine_dist(f_d, prototypes)
            # The OT solver is fed detached float64 NumPy data, so the plan
            # is a constant; gradients flow only through ``costs`` below.
            costs_np = costs.detach().cpu().to(torch.float64).numpy()
            a = self.normalize(torch.ones(costs.shape[0]).to(torch.float64).numpy())
            b = self.normalize(self.global_prop[domain].detach().cpu().to(torch.float64).numpy())
            if self.ot_reg:
                plan = ot.sinkhorn(a, b, costs_np, self.ot_reg)
            else:
                plan = ot.emd(a, b, costs_np)
            T = torch.Tensor(plan).to(costs.device)
            term = (T * costs).sum()
            total = term if total is None else total + term
            active_domains += 1
        self.cur_iter += 1
        if active_domains == 0:
            # Nothing to transport: return a well-defined zero instead of
            # dividing by zero.
            return f.new_zeros(())
        return total.mul(1.0 / active_domains)

