import torch.nn as nn
import torch
import torch.nn.functional as F
import math

class Bitparm(nn.Module):
    '''
    One layer of the learned bit estimator, with per-(token, channel) parameters.

    Parameters ``h``, ``b`` and (unless ``final``) ``a`` each have shape
    ``(1, num_latent_tokens, channel)`` and are initialised from N(0, 0.01**2).
    The layer output is non-decreasing in ``x``:
    - ``softplus(h) > 0`` makes the affine step increasing.
    - ``|tanh(a)| < 1`` keeps ``z + tanh(z) * tanh(a)`` non-decreasing.
    - The final layer squashes the result into (0, 1) with a sigmoid.

    Args:
        num_latent_tokens: size of the token axis of every parameter.
        channel: size of the channel axis of every parameter.
        final: if True, omit ``a`` (left as ``None``) and apply a sigmoid.
    '''
    def __init__(self, num_latent_tokens, channel, final=False):
        super().__init__()
        shape = (1, num_latent_tokens, channel)

        def _fresh_param():
            # Small Gaussian init, matching every parameter of this layer.
            return nn.Parameter(torch.nn.init.normal_(torch.empty(shape), 0, 0.01))

        self.final = final
        # Creation order h -> b -> a is kept so seeded initialisation is unchanged.
        self.h = _fresh_param()
        self.b = _fresh_param()
        self.a = None if final else _fresh_param()

    def forward(self, x):
        # x presumably broadcasts against (1, num_latent_tokens, channel),
        # e.g. (batch, num_latent_tokens, channel) -- TODO confirm with callers.
        z = x * F.softplus(self.h) + self.b
        if self.final:
            return torch.sigmoid(z)
        return z + torch.tanh(z) * torch.tanh(self.a)

class BitEstimator(nn.Module):
    '''
    Learned entropy model that returns a total bit cost for a latent tensor.

    The model chains three non-final ``Bitparm`` layers (``f1``-``f3``) and one
    final layer (``f4``) that ends in a sigmoid. Each element's cost is
    ``-log2(p + 1e-5)`` clamped to [0, 50] bits. The forward pass returns the
    sum over all elements as a 0-dim tensor.

    Args:
        num_latent_tokens: token-axis size shared by every layer's parameters.
        channel: channel-axis size shared by every layer's parameters.
    '''
    def __init__(self, num_latent_tokens, channel):
        super().__init__()
        self.f1 = Bitparm(num_latent_tokens, channel)
        self.f2 = Bitparm(num_latent_tokens, channel)
        self.f3 = Bitparm(num_latent_tokens, channel)
        self.f4 = Bitparm(num_latent_tokens, channel, final=True)

    def forward(self, x):
        # Adapted from:
        # https://github.com/XingtongGe/PreprocessingICM/blob/main/mmdetection_toward/mmdet/models/backbones/towards/ImageCompression/model.py#L47
        # In compression the latents would normally be quantized. Here the
        # continuous values are used directly, the chained layers' output is
        # treated as a probability, and the summed bit cost acts as an NLL.
        # NOTE(review): each Bitparm layer is non-decreasing in x and f4 ends in
        # a sigmoid, so `prob` behaves like a CDF value (in (0, 1), increasing in
        # x), not a density. Factorized priors typically use f(x + 0.5) - f(x - 0.5);
        # verify against the linked reference that the continuous form is intended.
        for layer in (self.f1, self.f2, self.f3):
            x = layer(x)
        prob = self.f4(x)
        # The 1e-5 floor caps -log2 at about 16.6 bits; the clamp bounds each element to [0, 50].
        bits = -torch.log(prob + 1e-5) / math.log(2.0)
        return torch.sum(torch.clamp(bits, 0, 50))

class GaussianBitEstimator(nn.Module):
    '''
    Bit cost of a tensor under a fixed, non-learned Normal(mu, sigma) prior.

    Each element costs ``-log2(density + 1e-5)`` clamped to [0, 50] bits. The
    density can exceed 1 when sigma is small, and those elements clamp to 0 bits.
    The forward pass returns the sum over all elements as a 0-dim tensor.

    Args:
        mu: mean of the prior.
        sigma: standard deviation of the prior.
    '''
    def __init__(self, mu=0, sigma=1):
        super().__init__()
        self.gaussian = torch.distributions.Normal(mu, sigma)

    def forward(self, x):
        density = self.gaussian.log_prob(x).exp()
        # The 1e-5 floor keeps the log finite when the density underflows to 0.
        per_element_bits = torch.log(density + 1e-5).neg() / math.log(2.0)
        return per_element_bits.clamp(0, 50).sum()