import torch
import torch.nn as nn

# Implements the CLUB method for estimating an upper bound on mutual information.
class CLUBSample(nn.Module):
    """Sampled CLUB estimator of an upper bound on mutual information.

    A variational distribution q(y | x) is modelled as a diagonal Gaussian
    whose mean and log-variance are each produced by a two-layer MLP over
    ``x``. ``learning_loss`` fits that distribution by maximum likelihood;
    ``forward`` contrasts the log-density of the given (x, y) pairs against
    randomly re-paired ones to obtain the bound.

    Args:
        args: Configuration object. Only ``args.hidden_dim`` (input and output
            width of every linear layer) and ``args.device`` (device the
            module is moved to on construction) are read.
    """

    def __init__(self, args):
        super().__init__()
        # Mean head of q(y | x).
        self.p_mu = nn.Sequential(nn.Linear(args.hidden_dim, args.hidden_dim),
                                  nn.ReLU(),
                                  nn.Linear(args.hidden_dim, args.hidden_dim))

        # Log-variance head of q(y | x); the final Tanh bounds the
        # log-variance to [-1, 1], i.e. the variance to [1/e, e].
        self.p_logvar = nn.Sequential(nn.Linear(args.hidden_dim, args.hidden_dim),
                                      nn.ReLU(),
                                      nn.Linear(args.hidden_dim, args.hidden_dim),
                                      nn.Tanh())
        self.to(args.device)

    @staticmethod
    def _gaussian_log_density(mu, logvar, y_samples):
        """Return the per-row diagonal-Gaussian log-density of ``y_samples``.

        The constant ``-0.5 * log(2 * pi)`` term is omitted. Features are
        summed over ``dim=1``, so the inputs are expected to be 2-D.
        """
        return (-(mu - y_samples) ** 2 / logvar.exp() / 2. - logvar / 2.).sum(dim=1)

    def get_mu_logvar(self, x_samples):
        """Return ``(mu, logvar)`` of q(y | x) predicted for ``x_samples``."""
        mu = self.p_mu(x_samples)
        logvar = self.p_logvar(x_samples)
        return mu, logvar

    def forward(self, x_samples, y_samples):
        """Estimate the (clamped) CLUB bound for a batch of paired samples.

        Args:
            x_samples: Conditioning samples; the last dimension must equal
                ``hidden_dim``.
            y_samples: Target samples paired row-by-row with ``x_samples``.

        Returns:
            A scalar tensor that is never negative.
        """
        mu, logvar = self.get_mu_logvar(x_samples)
        sample_size = x_samples.shape[0]
        # Build the shuffle directly on the device of the tensor it indexes so
        # no host-to-device copy is needed; randperm already yields int64.
        random_index = torch.randperm(sample_size, device=y_samples.device)
        # Log-density of the true pairs vs. randomly re-paired targets.
        positive = self._gaussian_log_density(mu, logvar, y_samples)
        negative = self._gaussian_log_density(mu, logvar, y_samples[random_index])
        bound = (positive - negative).mean()
        # NOTE(review): the log-densities above already carry the 1/2 factor,
        # so this extra halving returns half of their plain difference. Kept
        # unchanged because callers may have tuned loss weights around it.
        # The clamp floors negative estimates at zero (which also zeroes the
        # gradient in that case).
        return torch.clamp(bound / 2., min=0.0)

    # Used to train the CLUB network, i.e. to fit the conditional Gaussian.
    # “we leverage a neural network qϕ(e_s | e_z) to approximate p(e_s | e_z) by minimizing KL-divergence”
    def loglikeli(self, x_samples, y_samples):
        """Return the batch-mean log-likelihood of ``y_samples`` under q(y | x)."""
        mu, logvar = self.get_mu_logvar(x_samples)
        llh = self._gaussian_log_density(mu, logvar, y_samples).mean()
        return llh

    def learning_loss(self, x_samples, y_samples):
        """Return the negative log-likelihood, to be minimised when fitting q."""
        return -self.loglikeli(x_samples, y_samples)
    