import torch
import numpy as np
import ipdb

# infonce
def infonce_lower_bound(scores):
  """InfoNCE lower bound on mutual information from a matrix of scores.

  Row i of ``scores`` is treated as logits over candidates, with the positive
  candidate at column i (the diagonal). The returned estimate is
  ``-cross_entropy(scores, arange(B)) + log(B)``, where B = ``scores.shape[0]``.

  Args:
    scores: 2-D tensor of logits; assumed square [B, B] (positives on the
      diagonal) -- TODO confirm against callers.

  Returns:
    0-dim tensor holding the InfoNCE estimate.
  """
  batch_size = scores.shape[0]
  pseudo_labels = torch.arange(batch_size, device=scores.device)
  # torch's cross_entropy is a *negative* log-likelihood, while the InfoNCE
  # bound is the negative cross-entropy, hence the sign flip.
  # log(B) is taken as a Python float: building an int64 tensor and calling
  # torch.log on it allocates a CPU tensor per call and fails on older
  # PyTorch releases that have no integer log kernel.
  mi_3 = (
      -torch.nn.functional.cross_entropy(input=scores, target=pseudo_labels)
      + float(np.log(batch_size))
  )
  return mi_3


def club_upper_bound_my(scores):
  """vCLUB-style estimate from a 2-D score matrix.

  For every row i, takes the gap between the positive score ``scores[i, i]``
  and the mean of that row, then averages the gaps over rows.

  Args:
    scores: 2-D tensor; assumed square [B, B] with positives on the
      diagonal -- TODO confirm against callers.

  Returns:
    0-dim tensor with the estimate.
  """
  positive = torch.diag(scores)
  row_average = scores.mean(dim=1)
  gap = positive - row_average
  return gap.mean()


def clamp(x, max=1e8, min=-1e8):
  """Return ``x`` with every element clamped into ``[min, max]``.

  Defaults bound values to +-1e8 (this also maps +-inf onto the bounds).
  The ``min``/``max`` parameter names shadow the builtins inside this
  function; they are kept for caller compatibility.
  """
  return torch.clamp(x, min=min, max=max)

# log-probabilities
def critic(mu, sigma, y):
  """Pairwise diagonal-Gaussian log-densities.

  Entry ``[i, j]`` of the result is
  ``log N(y[i]; mu[j], diag(sigma[j] ** 2))``, expanded as
  ``normalizer_j - sum(y_i**2 / (2 sigma_j**2)) - sum(mu_j**2 / (2 sigma_j**2))
  + sum(y_i * mu_j / sigma_j**2)``.
  The diagonal holds each sample's log-likelihood under its own Gaussian.

  Args:
    mu: Gaussian means, [B, d].
    sigma: Gaussian standard deviations, [B, d]; must be positive because
      ``torch.log(sigma)`` is taken.
    y: samples, [B, d].

  Returns:
    [B, B] tensor of log-probabilities. Each of the four terms is clamped
    with ``clamp``'s default bounds before summing.
  """
  sigma2 = sigma**2  # [B, d]

  # Terms that depend only on the Gaussian j are built as [B, 1] and
  # transposed to [1, B], so they broadcast along columns (index j). Leaving
  # them as [B, 1] would attach Gaussian i's terms to row i (sample y_i),
  # which makes off-diagonal entries invalid log-densities.
  normalizer_term = torch.sum(
      -0.5 * (np.log(2. * np.pi) + 2. * torch.log(sigma)), dim=1, keepdim=True
  ).T  # [1, B]
  mu2_term = -torch.sum(mu**2 / (2 * sigma2), dim=1, keepdim=True).T  # [1, B]

  # Terms that pair sample i (rows) with Gaussian j (columns).
  x2_term = -torch.matmul(y**2, (1.0 / (2 * sigma2)).T)  # [B, B]
  cross_term = torch.matmul(y, (mu / sigma2).T)  # [B, B]

  # Bound each term separately; presumably this keeps extreme values
  # (e.g. very small sigma) from overflowing -- NOTE(review): confirm intent.
  normalizer_term = clamp(normalizer_term)
  x2_term = clamp(x2_term)
  mu2_term = clamp(mu2_term)
  cross_term = clamp(cross_term)

  log_prob = normalizer_term + x2_term + mu2_term + cross_term  # [B, B]

  return log_prob