"""
=============
Dequantizers
=============

"""
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import global_add_pool

from _utils import *
from _convs import EGNN_output_h

class UniformDequantizer(nn.Module):
    """Dequantize discrete node features by adding uniform noise.

    Every entry ``k`` of the ``'categorical'`` and ``'integer'`` tensors is
    replaced by ``k + u`` with ``u ~ U[-0.5, 0.5)``. The transform has no
    learnable parameters and contributes a log-determinant of zero.
    ``reverse`` maps every value in ``[k - 0.5, k + 0.5)`` back to ``k``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, tensor, node_mask, edge_mask, context):
        """Add ``U[-0.5, 0.5)`` noise to both feature groups.

        Args:
            tensor: Dict with ``'categorical'`` and ``'integer'`` tensors. Both
                must be floating point (``torch.rand_like`` requires it).
            node_mask: Optional mask multiplied into both outputs; ``None``
                disables masking.
            edge_mask: Unused by this dequantizer.
            context: Unused by this dequantizer.

        Returns:
            ``(out, ldj)``: ``out`` is a dict with the same keys as ``tensor``.
            ``ldj`` is a zero tensor of length ``integer.size(0)``, because a
            pure shift has zero log-determinant.
        """
        category, integer = tensor['categorical'], tensor['integer']
        zeros = torch.zeros(integer.size(0), device=integer.device)

        # rand_like samples from [0, 1); shifting by 0.5 centres the noise.
        out_category = category + torch.rand_like(category) - 0.5
        out_integer = integer + torch.rand_like(integer) - 0.5

        if node_mask is not None:
            out_category = out_category * node_mask
            out_integer = out_integer * node_mask

        out = {'categorical': out_category, 'integer': out_integer}
        return out, zeros

    def reverse(self, tensor):
        """Quantize dequantized features back to the nearest integer.

        Uses ``floor(v + 0.5)`` rather than ``torch.round``. ``forward`` can
        emit exactly ``k - 0.5`` (the noise sample may be 0), and
        ``torch.round`` rounds half to even. For odd ``k`` that would return
        ``k - 1`` instead of ``k``.

        Args:
            tensor: Dict with ``'categorical'`` and ``'integer'`` tensors.

        Returns:
            Dict with the same keys holding the quantized (float) tensors.
        """
        categorical, integer = tensor['categorical'], tensor['integer']
        return {
            'categorical': torch.floor(categorical + 0.5),
            'integer': torch.floor(integer + 0.5),
        }


class VariationalDequantizer(nn.Module):
    """Variational dequantizer conditioned on the graph.

    An ``EGNN_output_h`` network predicts a mean and log-scale for the
    dequantization noise ``u``. The noise is squashed with ``sigmoid_no_mask``
    and combined with the discrete features by
    ``transform_to_hypercube_partition``. ``reverse`` undoes this with
    ``floor``, which presumes each dequantized value lies in ``[k, k + 1)``
    (TODO confirm against ``transform_to_hypercube_partition``).

    Args:
        node_nf: Width of the node features fed to the network, i.e. of
            ``cat([categorical, integer], dim=1)``. The network outputs twice
            this many channels (mean and log-scale).
        agg: Aggregation mode forwarded to ``EGNN_output_h``.
    """

    def __init__(self, node_nf, agg='sum'):
        super().__init__()
        self.net_fn = EGNN_output_h(
            in_node_nf=node_nf, out_node_nf=node_nf*2, agg=agg
        )

    def sample_qu_xh(self, h, x, edges, batch):
        """Sample noise ``u`` from the conditional Gaussian and score it.

        Args:
            h: Node features passed to the network.
            x: Forwarded to ``EGNN_output_h`` (presumably node positions).
            edges: Forwarded to ``EGNN_output_h``.
            batch: Node-to-graph assignment used by the ``*_tobatch`` helpers.

        Returns:
            ``(u, log_qu)``: the sample and its log-density. The log-density is
            presumably reduced per graph by the ``*_tobatch`` helpers.
        """
        net_out = self.net_fn(h, x, edges, batch)
        # First half of the output channels is the mean, second half the log-scale.
        mu, log_sigma = torch.chunk(net_out, chunks=2, dim=-1)

        # Reparameterisation: draw standard-normal noise and score it.
        eps = sample_gaussian(mu.size(), mu.device)
        log_q_eps = standard_gaussian_log_likelihood_tobatch(eps, batch)

        # TODO(jb): revisit masking of mu / log_sigma. The former sanity check
        # (per-graph sums of mu and log_sigma close to 0) is currently disabled.
        u, ldj = affine_tobatch(eps, mu, log_sigma, batch)
        # Change of variables: log q(u) = log q(eps) - log|det du/deps|.
        log_qu = log_q_eps - ldj

        return u, log_qu

    def transform_to_partition_v(self, h_category, h_integer, u_category, u_integer, batch):
        """Squash the noise and place it in the partition cell of each ``h``.

        Returns:
            ``(v_category, v_integer, ldj)``. ``ldj`` is the sum of the
            log-determinants of the two sigmoid transforms.
        """
        u_category, ldj_category = sigmoid_no_mask(u_category, batch)
        u_integer, ldj_integer = sigmoid_no_mask(u_integer, batch)
        ldj = ldj_category + ldj_integer

        v_category = transform_to_hypercube_partition(h_category, u_category)
        v_integer = transform_to_hypercube_partition(h_integer, u_integer)
        return v_category, v_integer, ldj

    def forward(self, tensor, x, edges, batch):
        """Dequantize ``tensor`` and return its variational log-density.

        Args:
            tensor: Dict with ``'categorical'`` and ``'integer'`` node features.
                They are concatenated along dim 1, so they must agree on every
                other dimension.
            x: Forwarded to the conditional network.
            edges: Forwarded to the conditional network.
            batch: Node-to-graph assignment.

        Returns:
            ``(v, log_qv_xh)``: ``v`` is a dict with the same keys as
            ``tensor``, and ``log_qv_xh = log q(u) - ldj(squash)``.
        """
        categorical, integer = tensor['categorical'], tensor['integer']
        h = torch.cat([categorical, integer], dim=1)
        n_categorical = categorical.size(1)

        u, log_qu_xh = self.sample_qu_xh(h, x, edges, batch)

        # Split the joint noise back into categorical / integer channels.
        u_categorical = u[:, :n_categorical]
        u_integer = u[:, n_categorical:]

        v_categorical, v_integer, ldj = self.transform_to_partition_v(
            categorical, integer, u_categorical, u_integer, batch
        )
        log_qv_xh = log_qu_xh - ldj

        v = {'categorical': v_categorical, 'integer': v_integer}
        return v, log_qv_xh

    def reverse(self, tensor):
        """Quantize back by flooring both feature groups.

        Args:
            tensor: Dict with ``'categorical'`` and ``'integer'`` tensors.

        Returns:
            Dict with the same keys holding ``torch.floor`` of each tensor.
        """
        categorical, integer = tensor['categorical'], tensor['integer']
        return {'categorical': torch.floor(categorical), 'integer': torch.floor(integer)}


class ArgmaxAndVariationalDequantizer(VariationalDequantizer):
    """Variational dequantizer with an argmax partition for categorical features.

    Integer features are dequantized as ``h + sigmoid(u)`` and recovered with
    ``floor``. Categorical features are mapped with
    ``transform_to_argmax_partition_no_mask`` and recovered as a one-hot of
    their ``argmax``.

    The constructor is inherited unchanged: ``(node_nf, agg='sum')``.
    """

    def transform_to_partition_v(self, h_category, h_integer, u_category, u_integer, batch):
        """Map categorical noise to the argmax partition and shift the integers.

        Returns:
            ``(v_category, v_integer, ldj)``. ``ldj`` is the sum of the argmax
            transform's and the integer sigmoid's log-determinants.
        """
        u_integer, ldj_integer = sigmoid_no_mask(u_integer, batch)
        v_category, ldj_category = transform_to_argmax_partition_no_mask(h_category, u_category, batch)
        ldj = ldj_category + ldj_integer
        # Shift each integer by its squashed noise. Assuming the sigmoid output
        # lies in (0, 1), floor() in reverse() recovers h_integer.
        v_integer = h_integer + u_integer
        return v_category, v_integer, ldj

    def reverse(self, tensor):
        """Recover discrete features from their dequantized form.

        Categorical features become a one-hot of the argmax over the last
        dimension. The number of classes is read from ``size(-1)``, so both
        ``(nodes, K)`` and ``(batch, nodes, K)`` layouts work. The one-hot
        tensor has the ``int64`` dtype that ``F.one_hot`` returns. Integer
        features are floored.

        Args:
            tensor: Dict with ``'categorical'`` and ``'integer'`` tensors.

        Returns:
            Dict with the same keys holding the quantized tensors.
        """
        categorical, integer = tensor['categorical'], tensor['integer']
        num_classes = categorical.size(-1)
        integer = torch.floor(integer)

        categorical = F.one_hot(torch.argmax(categorical, dim=-1), num_classes)
        return {'categorical': categorical, 'integer': integer}