"""
Code built from https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial7/GNN_overview.html
"""
import torch.nn as nn
import torch.distributions as td
from utils.likelihoods import get_likelihood
import torch




class HeterogeneousDistribution(td.Distribution):
    """Collection of per-feature distributions over one concatenated feature vector.

    ``logits`` is split column-wise (``dim=1``) into one chunk per entry of
    ``dim_list``/``distr_list``. Chunk ``i`` has
    ``get_likelihood(distr_list[i], dim_list[i]).params_size`` columns and is
    turned into a distribution by calling that likelihood object on it.

    ``sample``, ``mean`` and ``log_prob`` concatenate the per-feature results
    along the feature axis. Categorical ('cat') features are represented
    one-hot in ``sample`` and converted to class indices in ``log_prob``.
    """

    # ``td.Distribution.arg_constraints`` raises NotImplementedError, and with
    # argument validation enabled the base ``__init__`` then emits a warning on
    # every construction. No distribution-level constraints are declared here.
    arg_constraints = {}

    def __init__(self, logits=None, dim_list=None, distr_list=None, lambda_kld=1.0, norm_categorical=False, validate_args=None):
        """
        Args:
            logits: parameter tensor, split along ``dim=1`` into chunks of
                ``params_size`` columns (one chunk per feature).
            dim_list: list with the dimensionality of each feature.
            distr_list: list with the likelihood name of each feature. Must
                have the same length as ``dim_list``.
            lambda_kld: passed to ``set_lambda`` of every 'delta' likelihood.
            norm_categorical: if True, ``log_prob`` divides categorical
                log-probabilities by the number of categories.
            validate_args: forwarded to ``td.Distribution.__init__``.

        Raises:
            ValueError: if ``logits`` is None or the two lists differ in length.
        """
        if logits is None:
            raise ValueError("`logits` must be specified.")
        assert isinstance(dim_list, list)
        assert isinstance(distr_list, list)
        if len(dim_list) != len(distr_list):
            # zip() below would otherwise silently drop the extra entries.
            raise ValueError(
                f"`dim_list` and `distr_list` must have the same length "
                f"(got {len(dim_list)} and {len(distr_list)})."
            )

        self.norm_categorical = norm_categorical

        lik_list = []
        self.params_size_list = []
        for dim, distr in zip(dim_list, distr_list):
            lik = get_likelihood(distr, dim)
            if distr == 'delta':
                lik.set_lambda(lambda_kld)
            lik_list.append(lik)
            self.params_size_list.append(lik.params_size)

        # One chunk of parameter columns per feature, in dim_list order.
        logits_list = torch.split(logits, split_size_or_sections=self.params_size_list, dim=1)
        self.distr_list = [lik_i(logits_i) for lik_i, logits_i in zip(lik_list, logits_list)]

        self.distr_name_list = distr_list
        self.logits_list = logits_list
        self.lambda_kld = lambda_kld
        self.dim_list = dim_list
        self._param = logits
        # NOTE(review): batch_shape is the full logits shape (including the
        # parameter axis); kept unchanged for backward compatibility.
        batch_shape = self._param.size()
        super(HeterogeneousDistribution, self).__init__(batch_shape, validate_args=validate_args)

    @property
    def mean(self):
        """Per-feature means concatenated along dim=1.

        For 'cat', 'ber' and 'cb' features the distribution's ``probs`` are
        used instead of ``mean``.
        """
        means = []
        for distr_name, distr in zip(self.distr_name_list, self.distr_list):
            if distr_name in ('cat', 'ber', 'cb'):
                means.append(distr.probs)
            else:
                means.append(distr.mean)
        return torch.cat(means, dim=1)

    def sample(self, sample_shape=torch.Size()):
        """Draw one sample per feature and concatenate them on the feature axis.

        Categorical features are returned one-hot encoded, with the same
        device and dtype as the categorical's ``probs``.
        """
        samples = []
        for distr_name, distr in zip(self.distr_name_list, self.distr_list):
            if distr_name == 'cat':
                # Categorical samples are class indices. one_hot keeps the
                # device of the indices and supports any sample_shape, unlike
                # a freshly allocated torch.FloatTensor (always CPU float32).
                indices = distr.sample(sample_shape)
                num_classes = distr.probs.shape[-1]
                one_hot = nn.functional.one_hot(indices, num_classes).to(distr.probs.dtype)
                samples.append(one_hot)
            else:
                samples.append(distr.sample(sample_shape))
        # Features live on the last axis. For the default (empty) sample_shape
        # and 2-D per-feature samples this is the same as dim=1.
        return torch.cat(samples, dim=-1)

    def rsample(self, sample_shape=torch.Size()):
        """Reparameterized sampling is not supported."""
        raise NotImplementedError()

    def log_prob(self, value):
        '''
        Per-feature log-probabilities of ``value``.

        Args:
            value: tensor of shape [num_graphs, total_dim_nodes], split along
                dim=1 according to ``dim_list``. Categorical slices are mapped
                to class indices with ``argmax`` (e.g. one-hot encodings).

        Returns:
            Concatenation along dim=1 of one column per categorical feature
            (optionally divided by its number of categories when
            ``norm_categorical`` is set) and the ``log_prob`` output of every
            other feature.
        '''
        value_list = torch.split(value, split_size_or_sections=self.dim_list, dim=1)

        log_probs = []
        for distr_name, value_i, distr_i in zip(self.distr_name_list, value_list, self.distr_list):
            if distr_name == 'cat':
                num_categories = value_i.shape[1]
                class_index = torch.argmax(value_i, dim=-1)
                log_prob_i = distr_i.log_prob(class_index).view(-1, 1)
                if self.norm_categorical:
                    log_prob_i = log_prob_i / num_categories
                log_probs.append(log_prob_i)
            else:
                log_probs.append(distr_i.log_prob(value_i))

        return torch.cat(log_probs, dim=1)

class ProbabilisticModelSCM(nn.Module):
    """Decoder mapping per-node embeddings to a HeterogeneousDistribution.

    Each node has its own bias-free linear layer. The layer maps an
    ``embedding_size``-dimensional input to the likelihood parameters of all of
    that node's features. The per-node outputs are concatenated and wrapped in a
    HeterogeneousDistribution over the flattened list of features.
    """

    def __init__(self, x_dim_list,
                 distr_x_list,
                 lambda_kld,
                 norm_categorical=False):
        """
        Args:
            x_dim_list: per node, the list of feature dimensionalities, e.g.
                ``[[2], [3, 4], [3, 4, 5]]``. Node order must match the dataset.
            distr_x_list: per node, the likelihood name of each feature, with
                the same nested lengths as ``x_dim_list``, e.g.
                ``[['normal'], ['cat', 'cat'], ['normal'] * 3]``.
            lambda_kld: forwarded to HeterogeneousDistribution.
            norm_categorical: forwarded to HeterogeneousDistribution.

        Raises:
            ValueError: if the nested lengths of ``x_dim_list`` and
                ``distr_x_list`` differ.
        """
        super().__init__()

        # zip() below would silently drop unmatched entries, so check up front.
        if len(x_dim_list) != len(distr_x_list):
            raise ValueError(
                f"`x_dim_list` and `distr_x_list` must have the same number of nodes "
                f"(got {len(x_dim_list)} and {len(distr_x_list)})."
            )
        for node_idx, (dims, distrs) in enumerate(zip(x_dim_list, distr_x_list)):
            if len(dims) != len(distrs):
                raise ValueError(
                    f"Node {node_idx}: {len(dims)} feature dims but {len(distrs)} distributions."
                )

        self.num_nodes = len(x_dim_list)
        self.node_dim_list = [sum(x_dim_i) for x_dim_i in x_dim_list]
        self.total_x_dim = sum(self.node_dim_list)

        self.lambda_kld = lambda_kld
        self.norm_categorical = norm_categorical

        # Number of likelihood parameters each node needs (summed over its features).
        likelihood_node_params_size_list = [
            sum(get_likelihood(distr_x_i, x_dim_i).params_size
                for x_dim_i, distr_x_i in zip(node_i_list, distr_node_i_list))
            for node_i_list, distr_node_i_list in zip(x_dim_list, distr_x_list)
        ]

        # All nodes share one input width, sized for the node needing the most parameters.
        self.embedding_size = max(likelihood_node_params_size_list)
        self._decoder_embeddings = nn.ModuleList([
            nn.Linear(self.embedding_size, params_size_i, bias=False)
            for params_size_i in likelihood_node_params_size_list
        ])

        self.x_dim_list = [dim for node_dims in x_dim_list for dim in node_dims]
        self.logits_dim = sum(likelihood_node_params_size_list)
        self.distr_x_list = [distr for node_distrs in distr_x_list for distr in node_distrs]

    def forward(self, logits, return_mean=False):
        '''
        Decode node embeddings into a HeterogeneousDistribution.

        Args:
            logits: tensor of shape [num_graphs * num_nodes, embedding_size].
                Consecutive blocks of ``num_nodes`` rows belong to the same
                graph, and row ``i`` of a block is node ``i``.
            return_mean: if True, also return the distribution's mean.

        Returns:
            ``p`` or, when ``return_mean`` is True, ``(p.mean, p)``. Here ``p``
            is a HeterogeneousDistribution built from parameters of shape
            [num_graphs, logits_dim].

        Raises:
            ValueError: if the feature width of ``logits`` is not ``embedding_size``.
        '''
        d = logits.shape[1]
        if d != self.embedding_size:
            raise ValueError(
                f"Expected node embeddings of width {self.embedding_size}, got {d}."
            )
        # reshape (not view) so non-contiguous inputs are accepted too.
        per_node = logits.reshape(-1, self.num_nodes, d)  # [num_graphs, num_nodes, d]

        logits_list = [
            embed_i(per_node[:, i, :])
            for i, embed_i in enumerate(self._decoder_embeddings)
        ]

        logits = torch.cat(logits_list, dim=-1)
        assert logits.shape[1] == self.logits_dim
        p = HeterogeneousDistribution(logits,
                                      dim_list=self.x_dim_list,
                                      distr_list=self.distr_x_list,
                                      lambda_kld=self.lambda_kld,
                                      norm_categorical=self.norm_categorical)
        if return_mean:
            return p.mean, p
        return p
