
# %%
import torch
from utils.probabilistic_model import HeterogeneousDistribution
from utils.likelihoods import get_likelihood
# Build one likelihood object per component. Each component later consumes
# `params_size` columns of the flat logits tensor (see the torch.split below).
dim_list = [2, 4, 3]
distr_list = ['normal', 'cat', 'ber']
lik_list = [get_likelihood(distr, dim) for dim, distr in zip(dim_list, distr_list)]

params_list = [lik.params_size for lik in lik_list]
# Random logits: 2 rows, wide enough to feed every component's parameters.
logits = torch.randn([2, sum(params_list)])

p = HeterogeneousDistribution(
    logits=logits,
    dim_list=dim_list,
    distr_list=distr_list,
    lambda_kld=1.0,
)

print(p.mean)


# %%
# Draw one sample from each component distribution, one-hot encode the
# categorical components, concatenate everything column-wise and score the
# result with the heterogeneous distribution `p` built above.
samples = []

logits_list = torch.split(logits, split_size_or_sections=params_list, dim=1)
for distr_name, logits_i, lik_i in zip(distr_list, logits_list, lik_list):
    distr_i = lik_i(logits_i)
    sample_i = distr_i.sample()
    if distr_name == 'cat':
        # The categorical sample is used as a scatter index, i.e. class
        # indices (presumably shape (batch,) — TODO confirm). Expand it to a
        # one-hot tensor as wide as this component's logits. zeros_like keeps
        # the logits' dtype/device; torch.FloatTensor always allocated
        # uninitialised CPU float32 memory.
        y_onehot = torch.zeros_like(logits_i)
        # scatter_ requires an int64 index tensor.
        y_onehot.scatter_(1, sample_i.view(-1, 1).long(), 1)
        sample_i = y_onehot
    samples.append(sample_i)

sample = torch.cat(samples, dim=1)

# Print explicitly so the value is visible when run as a plain script, not
# only as the last expression of a notebook cell.
log_prob = p.log_prob(sample)
print(log_prob)



# %%
import torch
from utils.probabilistic_model import ProbabilisticModelSCM
# Configure the SCM likelihood. x_dim_list / distr_x_list presumably hold one
# entry per node (TODO confirm against ProbabilisticModelSCM).
# NOTE(review): node 2 has matching lengths ([2, 3] vs two 'cat'), but node 3
# has a single dim entry ([3]) alongside three 'normal' distributions — confirm
# whether this mismatch is intended.
x_dim_list = [[1], [2, 3], [3]]
distr_x_list = [['normal'], ['cat'] * 2, ['normal'] * 3]
likelihood = ProbabilisticModelSCM(
    x_dim_list=x_dim_list,
    distr_x_list=distr_x_list,
    lambda_kld=0.5,
)

print(likelihood.x_dim_list)
print(likelihood.distr_x_list)

# Random logits with num_graphs * num_nodes rows, each embedding_size wide.
num_graphs = 3
logits = torch.randn([num_graphs * likelihood.num_nodes, likelihood.embedding_size])

px = likelihood(logits, return_mean=False)