import torch
from torch import nn

from hvae_backbone.utils import SerializableModule

class ProductOfExperts(SerializableModule):
    """Fuse two Gaussian experts packed along dim 1 into a single Gaussian.

    The input is split along dim 1 into four equal chunks, read in the order
    ``mu_0, sigma_0, mu_1, sigma_1``. The sigmas are squared to form
    variances, and the two experts are combined with the closed-form product
    of two Gaussians (a precision-weighted mean and the matching standard
    deviation).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the fused mean and standard deviation concatenated on dim 1.

        Args:
            x: tensor whose size along dim 1 is a multiple of 4.

        Returns:
            Tensor with half the size of ``x`` along dim 1: the fused mean
            followed by the fused standard deviation.

        Raises:
            ValueError: if the size of dim 1 is not divisible by 4.
        """
        channels = x.size(1)
        # chunk() does not require an even split: e.g. 7 channels yields
        # chunks of 2/2/2/1 that would silently broadcast into a wrong
        # result, so uneven inputs are rejected explicitly.
        if channels % 4 != 0:
            raise ValueError(
                "ProductOfExperts expects the size of dim 1 to be divisible "
                f"by 4 (mu_0, sigma_0, mu_1, sigma_1), got {channels}"
            )
        mu_0, sigma_0, mu_1, sigma_1 = x.chunk(4, dim=1)
        var_0 = sigma_0 ** 2
        var_1 = sigma_1 ** 2
        # NOTE: positions where both sigmas are zero give 0/0, i.e. NaN.
        den = var_0 + var_1
        mu = (mu_0 * var_1 + mu_1 * var_0) / den
        sigma = torch.sqrt(var_0 * var_1 / den)
        return torch.cat([mu, sigma], dim=1)

    def serialize(self) -> dict:
        """Return the base-class serialization; this module adds no fields."""
        serialized = super().serialize()
        return serialized

    @staticmethod
    def deserialize(serialized: dict):
        """Build a fresh instance; ``serialized`` is not read."""
        return ProductOfExperts()
    

class ProductOfExpertsNoStd(SerializableModule):
    """Fuse two Gaussian experts packed along dim 1, returning only the mean.

    The input is split along dim 1 into four equal chunks, read in the order
    ``mu_0, sigma_0, mu_1, sigma_1``. The sigmas are squared to form
    variances and the means are combined with precision weighting, as in the
    product of two Gaussians. No fused standard deviation is produced.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the precision-weighted mean of the two experts.

        Args:
            x: tensor whose size along dim 1 is a multiple of 4.

        Returns:
            Tensor with a quarter of the size of ``x`` along dim 1.

        Raises:
            ValueError: if the size of dim 1 is not divisible by 4.
        """
        channels = x.size(1)
        # chunk() does not require an even split: e.g. 7 channels yields
        # chunks of 2/2/2/1 that would silently broadcast into a wrong
        # result, so uneven inputs are rejected explicitly.
        if channels % 4 != 0:
            raise ValueError(
                "ProductOfExpertsNoStd expects the size of dim 1 to be "
                f"divisible by 4 (mu_0, sigma_0, mu_1, sigma_1), got {channels}"
            )
        mu_0, sigma_0, mu_1, sigma_1 = x.chunk(4, dim=1)
        var_0 = sigma_0 ** 2
        var_1 = sigma_1 ** 2
        # NOTE: positions where both sigmas are zero give 0/0, i.e. NaN.
        return (mu_0 * var_1 + mu_1 * var_0) / (var_0 + var_1)

    def serialize(self) -> dict:
        """Return the base-class serialization; this module adds no fields."""
        serialized = super().serialize()
        return serialized

    @staticmethod
    def deserialize(serialized: dict):
        """Build a fresh instance; ``serialized`` is not read."""
        return ProductOfExpertsNoStd()
    
    
class ImgChLayerNorm(nn.Module):
    """Layer normalization over the channel axis of channel-first tensors.

    ``nn.LayerNorm`` normalizes over trailing dimensions, so the channel axis
    (dim 1) is moved to the end, normalized, and moved back. Inputs with any
    number of dimensions after the channel axis are accepted, e.g.
    ``(N, C, L)``, ``(N, C, H, W)`` or ``(N, C, D, H, W)``.

    Args:
        ch: size of the channel axis (dim 1) of the inputs.
        eps: value passed to ``nn.LayerNorm`` as ``eps``.
    """

    def __init__(self, ch, eps=1e-03):
        super().__init__()
        self.norm = nn.LayerNorm(ch, eps=eps)

    def forward(self, x):
        """Normalize ``x`` over dim 1 and return a tensor of the same shape."""
        # For 4-D input, movedim(1, -1) is the same view as
        # permute(0, 2, 3, 1), and movedim(-1, 1) as permute(0, 3, 1, 2).
        x = x.movedim(1, -1)
        x = self.norm(x)
        return x.movedim(-1, 1)
    
class SimpleGate(SerializableModule):
    """Blend two inputs with a fixed weight ``alpha``.

    The result is ``alpha * tensor1 + (1 - alpha) * tensor2``. ``alpha`` is
    stored as a plain attribute and is included in the serialized form.
    """

    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha

    def forward(self, tensor1, tensor2):
        """Return the ``alpha``-weighted combination of the two inputs."""
        weight = self.alpha
        first_term = weight * tensor1
        second_term = (1 - weight) * tensor2
        return first_term + second_term

    def serialize(self) -> dict:
        """Return the base-class serialization extended with ``'alpha'``."""
        state = super().serialize()
        state['alpha'] = self.alpha
        return state

    @staticmethod
    def deserialize(serialized: dict):
        """Rebuild a gate from the ``'alpha'`` entry of ``serialized``."""
        alpha = serialized['alpha']
        return SimpleGate(alpha)