import torch
import torch.nn as nn
import torch.nn.functional as F

class client_model_VIB(nn.Module):
    """Variational Information Bottleneck (VIB) classifier.

    An MLP encoder maps each input to the mean and (softplus-transformed)
    standard deviation of a diagonal Gaussian over a latent ``Z`` of size
    ``dimZ``. Latent samples are drawn with the reparameterization trick
    (``mu + sigma * eps``) and decoded to class logits by one linear layer.

    Args:
        args: object exposing a ``device`` attribute; sampled noise is
            created on this device.
        dimZ: dimension of the latent variable ``Z``.
        alpha: offset subtracted from the raw scale output before softplus;
            a larger ``alpha`` gives a smaller ``sigma`` for the same output.
        dataset: dataset name selecting the architecture. Only ``'fmnist'``
            (28*28 inputs, 10 classes) is implemented; other names build no
            layers and ``forward`` raises ``ValueError``.
    """

    def __init__(self, args, dimZ=256, alpha=0, dataset='EMNIST'):
        super().__init__()

        self.alpha = alpha
        self.dimZ = dimZ  # dimension of the latent Z
        self.device = args.device
        self.dataset = dataset

        if self.dataset == 'fmnist':
            self.n_cls = 10
            self.fc1 = nn.Linear(1 * 28 * 28, 1024)
            self.fc2 = nn.Linear(1024, 1024)
            # Encoder head: first dimZ outputs are mu, last dimZ parametrize sigma.
            self.fc3 = nn.Linear(1024, 2 * self.dimZ)
            self.fc4 = nn.Linear(self.dimZ, self.n_cls)
            # Parameter names grouped per layer.
            self.weight_keys = [['fc1.weight', 'fc1.bias'],
                                ['fc2.weight', 'fc2.bias'],
                                ['fc3.weight', 'fc3.bias'],
                                ['fc4.weight', 'fc4.bias']]

    def gaussian_noise(self, num_samples, K):
        """Return standard-normal noise (mean 0, variance 1) of shape ``(*num_samples, K)``.

        ``num_samples`` may be an int or a tuple / ``torch.Size`` of leading
        dimensions. The noise is drawn directly on ``self.device`` to avoid a
        host-to-device copy on every call.
        """
        if isinstance(num_samples, int):
            num_samples = (num_samples,)
        return torch.randn(*num_samples, K, device=self.device)

    def sample_prior_Z(self, num_samples):
        """Sample ``Z`` from the standard-normal prior; shape ``(*num_samples, dimZ)``."""
        return self.gaussian_noise(num_samples=num_samples, K=self.dimZ)

    def encoder_result(self, encoder_output):
        """Split the encoder output of shape ``(B, 2*dimZ)`` into ``(mu, sigma)``.

        ``sigma = softplus(raw - alpha)`` keeps the standard deviation positive.
        """
        mu = encoder_output[:, :self.dimZ]
        sigma = F.softplus(encoder_output[:, self.dimZ:] - self.alpha)
        return mu, sigma

    def sample_encoder_Z(self, batch_size, encoder_Z_distr, num_samples):
        """Reparameterized samples ``mu + sigma * eps`` of shape ``(num_samples, batch_size, dimZ)``.

        ``mu`` and ``sigma`` are ``(batch_size, dimZ)`` and broadcast against
        the noise.
        """
        mu, sigma = encoder_Z_distr
        noise = self.gaussian_noise(num_samples=(num_samples, batch_size), K=self.dimZ)
        return mu + sigma * noise

    def forward(self, batch_x, num_samples=1):
        """Encode ``batch_x``, sample ``Z`` and decode it to class logits.

        Args:
            batch_x: input batch whose examples each hold 28*28 values
                (e.g. ``(B, 1, 28, 28)`` or ``(B, 784)``).
            num_samples: number of latent samples drawn per example.

        Returns:
            ``(encoder_Z_distr, decoder_logits, regL2R)``:
            ``encoder_Z_distr`` is the ``(mu, sigma)`` pair, each ``(B, dimZ)``;
            ``decoder_logits`` is ``(num_samples, B, n_cls)``; ``regL2R`` is the
            scalar L2 norm over all sampled latents.

        Raises:
            ValueError: if the model was built for an unsupported dataset.
        """
        if self.dataset != 'fmnist':
            # __init__ creates no layers for other datasets; fail with a clear
            # message instead of an UnboundLocalError further down.
            raise ValueError(
                f"unsupported dataset {self.dataset!r}; only 'fmnist' is implemented")

        # reshape (not view) also accepts non-contiguous inputs; the batch size
        # is taken after flattening so unbatched inputs broadcast correctly.
        x = batch_x.reshape(-1, 1 * 28 * 28)
        batch_size = x.size(0)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        encoder_output = self.fc3(x)
        encoder_Z_distr = self.encoder_result(encoder_output)
        to_decoder = self.sample_encoder_Z(batch_size=batch_size,
                                           encoder_Z_distr=encoder_Z_distr,
                                           num_samples=num_samples)
        decoder_logits = self.fc4(to_decoder)

        regL2R = torch.norm(to_decoder)

        return encoder_Z_distr, decoder_logits, regL2R

    def weight_init(self):
        """Xavier-initialize every Linear/Conv2d layer of the model and zero its bias."""
        # xavier_init iterates over its argument, so hand it the module
        # iterator; a bare nn.Linear is not iterable.
        xavier_init(self.modules())


def KL_between_normals(q_distr, p_distr, eps=1e-8):
    """Return KL(q || p) between two diagonal Gaussians, reduced over the last axis.

    Args:
        q_distr: ``(mu_q, sigma_q)`` — mean and standard deviation of ``q``.
        p_distr: ``(mu_p, sigma_p)`` — mean and standard deviation of ``p``;
            must broadcast against ``q`` (e.g. a ``(dimZ,)`` prior vs. a
            ``(B, dimZ)`` posterior).
        eps: lower bound applied to both standard deviations before they are
            squared or logged, so zero scales give finite results.

    Returns:
        Tensor with the last (latent) axis summed out, e.g. ``(B,)`` for
        ``(B, dimZ)`` inputs.
    """
    mu_q, sigma_q = q_distr
    mu_p, sigma_p = p_distr
    k = mu_q.size(-1)

    # Clamp once and use the same values in every term (log-dets and ratios).
    sigma_q = torch.clamp(sigma_q, min=eps)
    sigma_p = torch.clamp(sigma_p, min=eps)
    var_q = sigma_q ** 2
    var_p = sigma_p ** 2

    trace_term = torch.sum(var_q / var_p, dim=-1)
    mahalanobis = torch.sum((mu_p - mu_q) ** 2 / var_p, dim=-1)
    logdet_sigma_q = torch.sum(2 * torch.log(sigma_q), dim=-1)
    logdet_sigma_p = torch.sum(2 * torch.log(sigma_p), dim=-1)

    two_kl = trace_term + mahalanobis - k + logdet_sigma_p - logdet_sigma_q
    return two_kl * 0.5


def xavier_init(ms):
    """Xavier-uniform initialize Linear/Conv2d layers (ReLU gain) and zero their biases.

    Args:
        ms: an iterable of modules (a list, ``nn.Sequential``,
            ``model.modules()``, ...) or a single non-iterable module such as
            an ``nn.Linear``. Items that are not ``nn.Linear``/``nn.Conv2d``
            are skipped; containers in the iterable are not recursed into.
    """
    try:
        modules = iter(ms)
    except TypeError:
        # A single layer (e.g. nn.Linear) is not iterable; treat it as a one-item list.
        modules = iter((ms,))
    for m in modules:
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            # In-place (underscore) initializer; the un-suffixed name is deprecated.
            nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
            # Layers built with bias=False have ``bias is None``.
            if m.bias is not None:
                nn.init.zeros_(m.bias)
