import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F

import logging

@contextlib.contextmanager
def _disable_tracking_bn_stats(model):
    """Temporarily set ``track_running_stats`` to ``False`` inside ``model``.

    Every sub-module of ``model`` (including ``model`` itself) that exposes a
    ``track_running_stats`` attribute has it forced to ``False`` for the
    duration of the ``with`` block. The previous value of each module is
    restored on exit, even if the body raises.

    :param model: an ``nn.Module`` whose sub-modules are inspected
    """
    # Remember the original flag per module instead of XOR-toggling it:
    # toggling would *enable* tracking on modules that had it off, and would
    # cancel itself out on a module that is reachable twice.
    saved = [
        (module, module.track_running_stats)
        for module in model.modules()
        if hasattr(module, 'track_running_stats')
    ]
    for module, _ in saved:
        module.track_running_stats = False
    try:
        yield
    finally:
        # Always restore, so an exception in the body cannot leave the
        # model with statistics tracking silently switched off.
        for module, previous in saved:
            module.track_running_stats = previous


def _l2_normalize(d):
    """Scale each sample of ``d`` to unit L2 norm, in place.

    The norm is taken over all non-batch elements of a sample (dim 0 is the
    batch dimension). ``d`` itself is modified and also returned.
    """
    # Singleton trailing dims make the per-sample norm broadcast against d.
    trailing_ones = (1,) * (d.dim() - 2)
    flattened = d.view(d.shape[0], -1, *trailing_ones)
    per_sample_norm = torch.norm(flattened, dim=1, keepdim=True)
    # The small constant avoids a division by zero for all-zero samples.
    d /= per_sample_norm + 1e-8
    return d

def target_distribution(batch):
    """
    Compute the target distribution p_ij from the soft assignments q_ij, as in
    section 3.1.3, Equation 3 of Xie/Girshick/Farhadi; this is used as the
    target of the KL-divergence loss function.

    :param batch: [batch size, number of clusters] Tensor of dtype float
    :return: [batch size, number of clusters] Tensor of dtype float
    """
    # q_ij^2 divided by the per-cluster total (column sum).
    squared = batch.pow(2)
    cluster_mass = batch.sum(dim=0)
    weight = squared / cluster_mass
    # Normalise every row; the transposes line the row sums up for broadcasting.
    row_mass = weight.sum(dim=1)
    return weight.t().div(row_mass).t()


def init_weights(layer):
    """Initialise ``layer.weight`` according to the layer's class name.

    Intended to be passed to ``nn.Module.apply``. The scheme is chosen by a
    substring match on the class name:

    * ``Conv2d`` / ``ConvTranspose2d``: Kaiming-uniform
    * ``BatchNorm``: normal with mean 1.0 and std 0.02
    * ``Linear``: Xavier-normal

    Layers that match none of these, or that have no weight tensor (for
    example a batch-norm layer built with ``affine=False``), are left
    untouched. Biases are never modified.

    :param layer: the module to initialise in place
    """
    weight = getattr(layer, "weight", None)
    if weight is None:
        # Nothing to initialise; without this guard nn.init would be handed
        # None and abort the whole ``model.apply(init_weights)`` pass.
        return

    layer_name = layer.__class__.__name__
    if "Conv2d" in layer_name or "ConvTranspose2d" in layer_name:
        nn.init.kaiming_uniform_(weight)
    elif "BatchNorm" in layer_name:
        nn.init.normal_(weight, 1.0, 0.02)
    elif "Linear" in layer_name:
        nn.init.xavier_normal_(weight)

class VATLoss(nn.Module):
    """Produce virtual-adversarially perturbed copies of an input batch.

    Despite the name, ``forward`` does not return a scalar loss: it returns
    the perturbed inputs ``x + r_adv`` stacked along dim 0.
    """

    def __init__(self, xi=10.0, eps=1e-1, ip=1, repeat=2, num_aug=2):
        """VAT loss
        :param xi: scale of the probe perturbation used while estimating the
            adversarial direction (default: 10.0)
        :param eps: scale of the final adversarial perturbation (default: 0.1)
        :param ip: iteration times of computing adv noise (default: 1)
        :param repeat: number of refinement rounds per augmentation; each
            round contributes perturbed inputs to the output (default: 2)
        :param num_aug: number of independent random starting directions
            (default: 2)
        """
        super(VATLoss, self).__init__()
        self.xi = xi
        self.eps = eps
        self.ip = ip
        # Honour the constructor argument (it used to be hard-coded to 2).
        self.repeat = repeat
        self.num_aug = num_aug

    def forward(self, encoder, assign, x, sens):
        """Build adversarially perturbed versions of ``x``.

        :param encoder: module called as ``encoder(x)`` that returns a
            3-tuple; only the first element is used here
        :param assign: module called as ``assign(z, sens)``; its output is
            treated as probabilities (its log is fed to ``F.kl_div``)
        :param x: input batch
        :param sens: forwarded unchanged to ``assign``
        :return: concatenation along dim 0 of all collected perturbed
            batches. Every refinement round appends the *cumulative*
            concatenation built so far, so the result has
            ``num_aug * repeat * (repeat + 1) / 2 * x.shape[0]`` rows.

        Side effect: ``encoder.zero_grad()`` and ``assign.zero_grad()`` are
        called after every direction-estimation step.
        """
        # Reference prediction on the clean input; no gradient needed.
        with torch.no_grad():
            z, _, _ = encoder(x)
            pred = assign(z, sens)

        x_aug_list = []

        for _aug in range(self.num_aug):
            # prepare random unit tensor
            d = torch.rand(x.shape).sub(0.5).to(x.device)
            d = _l2_normalize(d)

            x_aug = None

            for _round in range(self.repeat):
                with _disable_tracking_bn_stats(encoder):
                    # calc adversarial direction
                    for _step in range(self.ip):
                        d.requires_grad_()
                        z_hat, _, _ = encoder(x + self.xi * d)
                        pred_hat = assign(z_hat, sens)

                        logp_hat = torch.log(pred_hat)
                        adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean')
                        adv_distance.backward(retain_graph=True)
                        # The gradient w.r.t. d becomes the next direction.
                        d = _l2_normalize(d.grad)

                        # Drop the parameter gradients produced by the
                        # backward pass above; only d.grad was wanted.
                        encoder.zero_grad()
                        assign.zero_grad()

                    r_adv = d * self.eps

                if x_aug is None:
                    x_aug = x + r_adv
                else:
                    x_aug = torch.cat([x_aug, x + r_adv], dim=0)

                # NOTE(review): the running concatenation is appended on
                # every round, so earlier perturbations appear more than once
                # in the output. Kept as-is because callers may rely on the
                # resulting row count; confirm this is intended.
                x_aug_list.append(x_aug)

        return torch.cat(x_aug_list, 0)
    
def JSD(p_x):
    """Pairwise Jensen-Shannon divergence between the rows of ``p_x``.

    For rows ``p`` and ``q`` with mixture ``m = (p + q) / 2`` this returns
    ``0.5 * KL(p || m) + 0.5 * KL(q || m)`` in nats. The result is symmetric,
    has a zero diagonal and is bounded above by ``ln 2``.

    The computation runs under ``torch.no_grad()``, so the result carries no
    gradient.

    :param p_x: [N, K] tensor whose rows are probability distributions
    :return: [N, N] tensor (default float dtype) on the device of ``p_x``
    """
    n_rows = p_x.shape[0]
    debias = torch.zeros((n_rows, n_rows), device=p_x.device)

    with torch.no_grad():
        # sum_k p_k * log p_k for every row. xlogy defines 0 * log 0 as 0, so
        # zero probabilities do not produce inf/nan.
        self_term = torch.xlogy(p_x, p_x).sum(-1)

        # One row at a time keeps peak memory at O(N * K).
        for row in range(n_rows):
            p_anchor = p_x[row]
            m = 0.5 * (p_anchor + p_x)

            # KL(p || m) = sum p log p - sum p log m. Note that
            # F.kl_div(log p, m) would yield KL(m || p) instead, which is not
            # the Jensen-Shannon term and diverges when p has zeros.
            kl_anchor = self_term[row] - torch.xlogy(p_anchor, m).sum(-1)
            kl_other = self_term - torch.xlogy(p_x, m).sum(-1)

            # Clamp away tiny negative values caused by rounding.
            debias[row] = (0.5 * (kl_anchor + kl_other)).clamp_min(0)

    return debias

def Contrastive_Loss(z, z_aug, p_x, temperature = 1):
    """Contrastive loss whose negatives are re-weighted by ``JSD(p_x)``.

    Embeddings are L2-normalised along dim 1. For sample ``i`` the positives
    are rows ``i, i + B, i + 2B, ...`` of ``z_aug`` (``B = z.shape[0]``); the
    negatives are all other rows of ``z``, each weighted by the pairwise
    divergence returned by ``JSD(p_x)`` scaled so its maximum is 1.

    :param z: [B, D] embeddings of the original samples
    :param z_aug: [N_aug * B, D] embeddings of the augmented samples, stacked
        in blocks of ``B`` rows
    :param p_x: tensor handed to ``JSD``; it must yield a [B, B] matrix
    :param temperature: divisor applied to the similarities before ``exp``
    :return: scalar tensor, the mean loss over the batch
    """
    z = F.normalize(z, dim=1)
    z_aug = F.normalize(z_aug, dim=1)

    batch_size = z.shape[0]
    n_aug = z_aug.shape[0] // batch_size
    # Boolean identity, built once on the target device.
    eye = torch.eye(batch_size, device=z.device).bool()

    out_grid = torch.exp(torch.mm(z, z.t().contiguous()) / temperature)

    # Keep only off-diagonal (negative) pairs. An overflowed ``inf`` times the
    # masked-out 0 gives NaN, so NaNs are zeroed afterwards.
    neg = out_grid * ~eye
    neg[torch.isnan(neg)] = 0

    # Tiling the identity mask picks, for each sample, its own augmentations.
    pos_grid = torch.exp(torch.mm(z, z_aug.t() / temperature))
    pos = (pos_grid * eye.repeat(1, n_aug)).sum(-1)

    debias = JSD(p_x).detach()
    # Scale to a maximum of 1. When every pairwise divergence is 0 (a batch
    # of one sample, or identical rows in p_x) dividing would give 0 / 0 = NaN
    # and poison the loss, so the all-zero weights are kept instead.
    peak = debias.max()
    debias = torch.where(peak > 0, debias / peak, debias)

    # contrastive loss
    weighted_neg = (debias * neg).sum(-1)
    loss = (-torch.log((pos + 1e-6) / (pos + weighted_neg + 1e-6))).mean()

    return loss


def setup_logger(logger_name, log_file, level=logging.INFO):
    """Configure a named logger that writes to ``log_file`` and to the console.

    Both handlers use the format ``[<time>]  <message>``. The file is opened
    in ``'w'`` mode, i.e. truncated when its handler is created.

    Calling this function again for the same logger does not stack duplicate
    handlers: a file handler is only added if none already targets
    ``log_file``, and a console handler only if the logger has none yet.

    :param logger_name: name passed to ``logging.getLogger``
    :param log_file: path of the log file
    :param level: logging level set on the logger (default: ``logging.INFO``)
    :return: the configured ``logging.Logger``
    """
    import os

    logger = logging.getLogger(logger_name)
    formatter = logging.Formatter('[%(asctime)s]  %(message)s')
    logger.setLevel(level)

    # FileHandler stores the absolute path in ``baseFilename``.
    target_path = os.path.abspath(os.fspath(log_file))
    has_file_handler = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == target_path
        for handler in logger.handlers
    )
    # FileHandler subclasses StreamHandler, so exclude it explicitly.
    has_stream_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        for handler in logger.handlers
    )

    if not has_file_handler:
        file_handler = logging.FileHandler(log_file, mode='w')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    if not has_stream_handler:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    return logger