from pixyz.distributions import Normal, Laplace, Categorical, Bernoulli
import torch
from torch import nn
from torch.nn import ModuleList, functional as F
from pixyz.utils import epsilon


class Inference(Normal):
    """Gaussian inference network q(var | cond_var...) built from per-input encoders.

    Each encoder in ``enc`` is applied to the conditioning variable at the same
    position in ``cond_var`` (paired with ``zip``, so extra entries on either
    side are ignored). The encoder outputs are concatenated along dim 1 and
    fed to two linear heads that produce the Normal's ``loc`` and ``scale``.

    Args:
        enc: Iterable of encoder modules, one per conditioning variable.
        hidden_dim: Iterable of encoder output widths; their sum is used as
            the input width of the two linear heads.
        z_dim: Output width of the heads (dimensionality of the latent).
        cond_var: Names of the conditioning variables.
        var: Names of the random variable(s). Defaults to ``["z"]``.
    """

    def __init__(self, enc, hidden_dim, z_dim, cond_var, var=None):
        # Use a None sentinel instead of a mutable default list. A default list
        # would be one object shared by every instance, and it is passed on to
        # the pixyz base class (presumably stored by reference — confirm).
        if var is None:
            var = ["z"]
        super().__init__(cond_var=cond_var, var=var, name="q")
        self.enc = ModuleList(enc)
        in_features = sum(hidden_dim)
        self.mu = nn.Linear(in_features=in_features, out_features=z_dim, bias=True)
        # NOTE: despite its name, this head's output goes through softplus in
        # forward() and becomes the Normal's *scale*, not a log-variance. The
        # attribute name is kept so existing state_dict keys remain loadable.
        self.logvar = nn.Linear(in_features=in_features, out_features=z_dim, bias=True)

    def forward(self, **x):
        """Return ``{"loc": ..., "scale": ...}`` for the given conditioning inputs.

        Args:
            **x: Mapping from conditioning-variable name to tensor. Each tensor
                is passed through its paired encoder.

        Returns:
            dict with ``loc`` (output of ``self.mu``) and ``scale``
            (``softplus`` of ``self.logvar`` output plus ``epsilon()``).
        """
        h = torch.cat(
            [_enc(x[_cond_var]) for _enc, _cond_var in zip(self.enc, self.cond_var)],
            dim=1,
        )
        mu = self.mu(h)
        raw_scale = self.logvar(h)
        # Flatten everything after the batch axis; this is a no-op when h is 2-D.
        mu = mu.view(mu.size(0), -1)
        raw_scale = raw_scale.view(raw_scale.size(0), -1)
        # softplus keeps the scale positive; epsilon() keeps it away from zero.
        return {"loc": mu, "scale": F.softplus(raw_scale) + epsilon()}

class GenerationLaplace(Laplace):
    """Laplace generator p(var | cond_var...) with a learned location and fixed scale.

    Args:
        dec: Module mapping the conditioning variables (concatenated along
            dim 1) to the Laplace location.
        var: Names of the generated variable(s). Defaults to ``["x"]``.
        cond_var: Names of the conditioning variables. Defaults to ``["z"]``.
        scale: Constant scale of the Laplace distribution (default 0.75,
            the value previously hard-coded).
    """

    def __init__(self, dec, var=None, cond_var=None, scale=0.75):
        # None sentinels avoid mutable default lists that every instance would share.
        if var is None:
            var = ["x"]
        if cond_var is None:
            cond_var = ["z"]
        super().__init__(cond_var=cond_var, var=var, name="p")
        self.dec = dec
        # Private name so it cannot collide with the distribution's own
        # "scale" parameter/buffer handling in the base class.
        self._fixed_scale = float(scale)

    def forward(self, **z):
        """Return ``{"loc": dec(z), "scale": <constant tensor>}``.

        The scale tensor matches the dtype and device of ``loc``, so models
        running in double or half precision get a consistent parameter.
        """
        z = torch.cat([z[_cond_var] for _cond_var in self.cond_var], dim=1)
        loc = self.dec(z)
        scale = torch.tensor(self._fixed_scale, dtype=loc.dtype, device=loc.device)
        return {"loc": loc, "scale": scale}

class GenerationCategorical(Categorical):
    """Categorical generator p(var | cond_var...) whose probabilities come from ``dec``.

    Args:
        dec: Module mapping the conditioning variables (concatenated along
            dim 1) to class probabilities. It is expected to already output
            normalized probabilities (no softmax is applied here).
        var: Names of the generated variable(s). Defaults to ``["x"]``.
        cond_var: Names of the conditioning variables. Defaults to ``["z"]``.
    """

    def __init__(self, dec, var=None, cond_var=None):
        # None sentinels avoid mutable default lists that every instance would share.
        if var is None:
            var = ["x"]
        if cond_var is None:
            cond_var = ["z"]
        super().__init__(cond_var=cond_var, var=var, name="p")
        self.dec = dec

    def forward(self, **z):
        """Return ``{"probs": dec(concatenated conditioning inputs)}``."""
        z = torch.cat([z[_cond_var] for _cond_var in self.cond_var], dim=1)
        return {"probs": self.dec(z)}

class GenerationBernoulli(Bernoulli):
    """Bernoulli generator p(var | cond_var...) whose probabilities come from ``dec``.

    Args:
        dec: Module mapping the conditioning variables (concatenated along
            dim 1) to Bernoulli probabilities. It is expected to output values
            in [0, 1] already (no sigmoid is applied here).
        var: Names of the generated variable(s). Defaults to ``["x"]``.
        cond_var: Names of the conditioning variables. Defaults to ``["z"]``.
    """

    def __init__(self, dec, var=None, cond_var=None):
        # None sentinels avoid mutable default lists that every instance would share.
        if var is None:
            var = ["x"]
        if cond_var is None:
            cond_var = ["z"]
        super().__init__(cond_var=cond_var, var=var, name="p")
        self.dec = dec

    def forward(self, **z):
        """Return ``{"probs": dec(concatenated conditioning inputs)}``."""
        z = torch.cat([z[_cond_var] for _cond_var in self.cond_var], dim=1)
        return {"probs": self.dec(z)}

class CLF(Categorical):
    """Linear classifier f(labels | z) producing categorical probabilities.

    Args:
        z_dim: Width of the input ``z``.
        num_category: Number of output classes.
    """

    def __init__(self, z_dim, num_category):
        super().__init__(var=["labels"], cond_var=["z"], name="f")
        self.clf = nn.Linear(z_dim, num_category)

    def forward(self, z):
        """Return ``{"probs": softmax over dim 1 of the linear logits}``."""
        logits = self.clf(z)
        probabilities = logits.softmax(dim=1)
        return {"probs": probabilities}

class CLFBernoulli(Bernoulli):
    """Linear multi-label classifier f(labels | z) with independent Bernoulli outputs.

    Args:
        z_dim: Width of the input ``z``.
        num_category: Number of independent binary outputs.
    """

    def __init__(self, z_dim, num_category):
        super().__init__(var=["labels"], cond_var=["z"], name="f")
        self.clf = nn.Linear(z_dim, num_category)

    def forward(self, z):
        """Return ``{"probs": elementwise sigmoid of the linear logits}``."""
        logits = self.clf(z)
        return {"probs": logits.sigmoid()}

class CLFCategorical(Categorical):
    """Linear classifier producing a grid of categorical distributions from ``z``.

    The linear layer emits ``num_category * y_dim`` logits per sample. These
    are reshaped to ``(-1, num_category, y_dim)`` and normalized with a
    softmax over the last axis (size ``y_dim``). Presumably this means
    ``num_category`` independent labels with ``y_dim`` classes each — confirm
    against callers.

    Args:
        z_dim: Width of the input ``z``.
        num_category: Size of the middle axis of the output.
        y_dim: Size of the last (normalized) axis of the output.
    """

    def __init__(self, z_dim, num_category, y_dim):
        super().__init__(var=["labels"], cond_var=["z"], name="f")
        self.clf = nn.Linear(z_dim, num_category * y_dim)
        self.num_category = num_category
        self.y_dim = y_dim

    def forward(self, z):
        """Return ``{"probs": tensor of shape (-1, num_category, y_dim)}``."""
        flat_logits = self.clf(z)
        grid_logits = flat_logits.reshape(-1, self.num_category, self.y_dim)
        return {"probs": grid_logits.softmax(dim=-1)}