#
# Relational-Conditional VAE: 
# ---------------
# A generative model that incorporates relational inductive biases through the
# use of Graph Networks.
# 
#
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
from tf_gnns.graphnet_utils import make_full_graphnet_functions, make_graph_indep_graphnet_functions, make_graph_to_graph_and_global_functions 
from tf_gnns import GraphNet

from tensorflow.keras.layers import Input, Dense

DTYPE = "float32"

def make_repar_gaussian_mlp(repar_size, noise_scale_input = None, nsamples = 10, name = None, min_std = 1e-5):
    """
    Creates a Keras model that expects the re-parametrization (mean, raw std) of a
    diagonal Gaussian as input and uses `tensorflow_probability` layers to turn it
    into a distribution and samples from it.

    The scale of the distribution is computed from the raw std input as
        scale = (softplus(std) + min_std) * noise_scale_input

    parameters:
    repar_size        : the size of the latent vector this layer is constructed for.
    noise_scale_input : multiplicative factor applied to the scale. Helps with training for
                        good reconstruction if this is small in the beginning of training.
                        If None, no scaling is applied (a factor of 1 is used).
    nsamples          : when sampling, how many samples to use
    name              : name given to the returned Keras model.
    min_std           : constant added to the softplus output so that the scale
                        (before noise scaling) is never smaller than this value.

    returns:
      a `tf.keras.Model` that is called with `[mean, std]` and has two outputs:
      the distribution converted to a tensor (`nsamples` samples, drawn through
      `convert_to_tensor_fn`) and the output of the `DistributionLambda` layer itself.
    """
    if noise_scale_input is None:
        # No noise scaling requested: multiplying by `None` would fail.
        noise_scale_input = 1.

    # Explicit tuple shapes are accepted by all Keras versions.
    mean_param = Input(shape = (repar_size,),  name = "input_mean")
    std_param  = Input(shape = (repar_size,),  name = "input_std")

    # Softplus keeps the scale positive; `min_std` bounds it away from zero.
    std_param_sp = (tf.nn.softplus(std_param) + min_std) * noise_scale_input

    # Mean and scale are concatenated on the last axis and split again inside the lambda.
    z_posterior = tfp.layers.DistributionLambda(make_distribution_fn=lambda t: tfp.distributions.MultivariateNormalDiag(
            loc=t[..., :repar_size], scale_diag=t[..., repar_size:]),
        convert_to_tensor_fn=lambda s: s.sample(nsamples), name = "node_latent")(tf.concat([mean_param, std_param_sp], axis = -1))

    # NOTE: the inputs are ordered [mean, std], which is the order in which the model
    # is called throughout this file. Listing them as [std, mean] silently feeds the
    # mean into the std input and vice versa.
    model = tf.keras.Model(inputs = [mean_param, std_param] , outputs = [tf.identity(z_posterior), z_posterior], name = name)
    return model


@tf.function
def _split_tensor_dict_to_repar(td, nlatent_nodes_or_all, nlatent_edges = None, nlatent_global = None):
    """
    Splits a tensor dictionary that corresponds to a graphTuple into two tensor
    dictionaries: one holding the mean part and one holding the std part of a
    re-parametrization.

    For each of 'edges', 'nodes' and 'global_attr' the leading columns (up to the
    corresponding latent size) go to the mean dictionary and the remaining columns
    go to the std dictionary. All other entries are copied unchanged to both outputs.

    If only `nlatent_nodes_or_all` is defined, the same split position is used for
    nodes, edges and globals.
    """
    # Column index at which every split entry is cut.
    split_at = {
        'edges'       : nlatent_nodes_or_all if nlatent_edges is None else nlatent_edges,
        'nodes'       : nlatent_nodes_or_all,
        'global_attr' : nlatent_nodes_or_all if nlatent_global is None else nlatent_global,
    }

    mean_part = {key : td[key][:, :ncols] for key, ncols in split_at.items()}
    std_part  = {key : td[key][:, ncols:] for key, ncols in split_at.items()}

    # Shallow copies keep the remaining (connectivity) entries of the input.
    td_mean = td.copy()
    td_mean.update(mean_part)

    td_std = td.copy()
    td_std.update(std_part)

    return td_mean, td_std

def _reshape_sampled_multi_graphtuples_td(node_prob_out, 
                                         edge_prob_out,
                                         glob_prob_out,
                                         tdict = None,
                                         nreps = None,
                                         n_nodes = None,
                                         n_edges = None,
                                         n_graphs = None,
                                         senders = None,
                                         receivers = None,
                                         _global_reps_for_nodes = None, 
                                         _global_reps_for_edges = None,
                                         **kwargs):
    """
    ####################################################################
    # -> Operates on tensor dictionaries                               #
    ####################################################################
    After the sampling of the VAE encoder, there is an additional preceding (sample)
    dimension. With this function the sample dimension is merged into the batch, so
    that every sample becomes its own set of graphs, for easier use in downstream layers.

    The output is laid out sample-major: all graphs of sample 0 come first, then
    all graphs of sample 1, and so on. Index-valued entries are shifted per sample
    accordingly.

    parameters:
      node_prob_out : rank-3 tensor of node attributes with the samples on the first axis.
                      (the second axis is expected to index the nodes of the whole batch,
                      since sender/receiver indices are shifted by sum(n_nodes) per sample)
      edge_prob_out : rank-3 tensor of edge attributes with the samples on the first axis.
      glob_prob_out : tensor of global attributes with the attribute size on the last axis.
      tdict         : a dictionary containing what a GraphTuple normally contains. When given,
                      its 'senders', 'receivers', 'n_graphs', 'n_edges', 'n_nodes',
                      '_global_reps_for_nodes' and '_global_reps_for_edges' entries
                      override the corresponding keyword arguments.
      nreps         : number of samples. Needs to be a python integer; if None, the
                      static size of the first axis of `node_prob_out` is used.
      **kwargs      : ignored.

    returns:
      a tensor dictionary containing `nreps` copies of the input graphs with the
      sampled node, edge and global attributes.
    """
    if nreps is None:
        nreps = node_prob_out.shape[0] 

    if tdict is not None:
        senders = tdict['senders']
        receivers = tdict['receivers']
        n_graphs = tdict['n_graphs']
        n_edges = tdict['n_edges']
        n_nodes = tdict['n_nodes']
        _global_reps_for_nodes = tdict['_global_reps_for_nodes']
        _global_reps_for_edges = tdict['_global_reps_for_edges']

    nnodes = tf.reduce_sum(n_nodes)

    def _concat_shifted(indices, stride):
        # Copy k of the indices is shifted by k * stride, so that it points into
        # the k-th sample of the merged batch.
        return tf.concat([indices + stride * k for k in range(nreps)], axis = 0)

    def _merge_sample_axis(samples):
        # [nreps, n, d] -> [nreps * n, d], using dynamic shapes to stay traceable.
        dims = tf.shape(samples)
        return tf.reshape(samples, [dims[0] * dims[1], dims[2]])

    def _tile_counts(counts):
        # The per-graph counts have to follow the same sample-major layout as everything
        # else ([a, b, a, b, ...]); `tf.repeat` would give [a, a, ..., b, b, ...] instead,
        # which is wrong when the graphs of a batch have different sizes.
        return tf.tile(tf.reshape(counts, [-1]), [nreps])

    new_senders             = _concat_shifted(senders, nnodes)
    new_receivers           = _concat_shifted(receivers, nnodes)
    new_reps_for_edges      = _concat_shifted(_global_reps_for_edges, n_graphs)
    new_reps_for_nodes      = _concat_shifted(_global_reps_for_nodes, n_graphs)

    new_nodes               = _merge_sample_axis(node_prob_out)
    new_edges               = _merge_sample_axis(edge_prob_out)

    global_state_size_      = glob_prob_out.shape[-1]
    reshaped_global_latent  = tf.reshape(glob_prob_out,[-1, global_state_size_])

    out_dict = {'edges' : new_edges ,'nodes' : new_nodes,
                'senders' : new_senders, 'receivers' : new_receivers,
                'n_edges' : _tile_counts(n_edges),'n_nodes' : _tile_counts(n_nodes), 
                'n_graphs' : n_graphs*nreps , '_global_reps_for_edges' : new_reps_for_edges, 
                '_global_reps_for_nodes' : new_reps_for_nodes,'global_attr' : reshaped_global_latent}
    return out_dict
    
    


# Graph Conditional VAE layer:
# takes a graph as input and outputs samples according to a VAE.

class RelationalGraphVAE:
    """
    Relational Graph VAE:
    --------------------
    A GraphNet with latent variables on edges, nodes and global variables, with the 
    possibility of conditioning the VAE on other graphs (of the same connectivity)
    
    It operates on tensor dictionaries so that it is possible to trace the graph for 
    performance with @tf.function.
    
    Usage:
        rvae = RelationalGraphVAE(units, gn_state_size, n_iwae = n_iwae)
        # The networks are created in `_build`, which has to be called before the
        # object is evaluated:
        rvae._build(node_in_size, edge_in_size, global_in_size,
                    cond_node_in_size, cond_edge_in_size, cond_global_in_size)
        out  = rvae(tensor_dict, cond_input) # <- this will contain a set of graphs.
                                             # The conditioning is used to augment the
                                             # latent space (after appropriate re-shaping,
                                             # different for nodes, edges and globals)
    
    """
    def __init__(self, units, gn_state_size, n_iwae = 10, nlatent = 3,noise_scale = 0.1, graph_indep = True,min_std = 1e-5,**kwargs):
        """
        parameters:
          units            : global MLP size for the involved GNs
          gn_state_size    : the output size for the GN (all blocks: edge/node/global) have 
                             the same output.
          n_iwae           : number of IWAE samples
          nlatent          : size of all latents (global/edges/nodes)
          nlatent_nodes    : overrides nlatent - number of latent variables for nodes.
          nlatent_edges    : overrides nlatent - number of latent variables for edges.
          nlatent_globs    : overrides nlatent - number of latent variables for globals.
          noise_scale      : [.1] a number to multiply the STD of all the involved base distributions. 
                             can be changed with assigning the `self.noise_scale_input` variable.
          graph_indep      : if True, graph-independent blocks (no message passing) are used
                             for both the encoder and the decoder.
          min_std          : the minimum value of the standard deviation that is used for the latent space. 
          use_global_input : [True] only relevant when `graph_indep` is False; selects whether the
                             encoder is built with a global input.
        """
        self.gn_state_size = gn_state_size 
        self.min_std = min_std
        self.graph_indep = graph_indep

        self.full_gn = None
        self.units  = units
        self.n_iwae = n_iwae

        # Component-specific latent sizes fall back to the common `nlatent`.
        self.nlatent_nodes = kwargs.get('nlatent_nodes', nlatent)
        self.nlatent_edges = kwargs.get('nlatent_edges', nlatent)
        self.nlatent_globs = kwargs.get('nlatent_globs', nlatent)

        # Non-trainable so that it can be scheduled from the outside with `.assign`.
        self.noise_scale_input = tf.Variable(noise_scale, trainable = False, dtype = DTYPE)
        self.cond_edges, self.cond_nodes, self.cond_globals = [False, False, False]

        # Always define the attribute, also when the caller provides the value.
        self.use_global_input = kwargs.get('use_global_input', True)

        self.is_built = False
        self.weights = []
        
    def _build(self,
               node_in_size,
               edge_in_size,
               global_in_size, 
               cond_node_in_size,
               cond_edge_in_size, 
               cond_global_in_size):
        """
        Creates the encoder GraphNet (`self.input_gnn`), the decoder GraphNet
        (`self.output_gnn`) and the three probabilistic models that turn the
        encoder output into latent samples and posterior distributions.

        parameters:
          node_in_size, edge_in_size, global_in_size : 
              sizes of the node, edge and global inputs of the encoder.
          cond_node_in_size, cond_edge_in_size, cond_global_in_size :
              sizes of the conditioning that is appended to the node, edge and global
              latents before decoding. `None` or 0 disables the conditioning for the
              corresponding component.
        """
        def _dec_in(nlatent,cond_size):
            # Returns the decoder input size and whether conditioning is appended.
            if cond_size is None or (cond_size == 0):
                return nlatent , False 
            return nlatent + cond_size, True

        dec_input_sizes = {}
        dec_input_sizes['node'], self.cond_nodes   = _dec_in(self.nlatent_nodes, cond_node_in_size)
        dec_input_sizes['edge'], self.cond_edges   = _dec_in(self.nlatent_edges, cond_edge_in_size)
        dec_input_sizes['glob'], self.cond_globals = _dec_in(self.nlatent_globs, cond_global_in_size)

        # This takes the input GraphTuple and computes a GraphTuple with sizes to be used 
        # in the re-parametrization of a Gaussian (mean and std, hence the factor of 2).
        with tf.name_scope("to_latent"):
            if not self.graph_indep:
                if self.use_global_input:
                    # Full message passing when computing the latent:
                    graph_fcn_to_latent = make_full_graphnet_functions(self.units,
                                                                       node_or_core_input_size = node_in_size,
                                                                       node_or_core_output_size = self.nlatent_nodes * 2,
                                                                       edge_input_size = edge_in_size,
                                                                       edge_output_size = self.nlatent_edges * 2,
                                                                       global_input_size = global_in_size,
                                                                       global_output_size=self.nlatent_globs * 2)
                else:
                    # When we want the node-to-global and edge-to-global message passing but we have no 
                    # global input ("graph" -> "graph_and_global"). Do not rely on this - probably is 
                    # going to be removed in the future.
                    graph_fcn_to_latent = make_graph_to_graph_and_global_functions(self.units,
                                                                       node_or_core_input_size = node_in_size,
                                                                       edge_input_size=edge_in_size,
                                                                       node_or_core_output_size = self.nlatent_nodes * 2,
                                                                       edge_output_size = self.nlatent_edges * 2,
                                                                       global_output_size=self.nlatent_globs * 2)
            else:
                # Graph independent block: no message passing:
                graph_fcn_to_latent = make_graph_indep_graphnet_functions(self.units, 
                                                                        node_or_core_input_size=node_in_size,
                                                                        node_or_core_output_size = self.nlatent_nodes * 2,
                                                                        edge_input_size = edge_in_size,
                                                                        edge_output_size = self.nlatent_edges * 2,
                                                                        global_input_size = global_in_size,
                                                                        global_output_size = self.nlatent_globs * 2)

        # The decoder takes the augmented input (latent samples plus, if present, the conditioning).
        with tf.name_scope("gnn_decoder"):
            if not self.graph_indep:
                graph_fcn_output_from_aug = make_full_graphnet_functions(self.units,
                                                        node_or_core_input_size  = dec_input_sizes['node'] ,
                                                        edge_input_size          = dec_input_sizes['edge'],
                                                        node_or_core_output_size = self.gn_state_size,
                                                        global_input_size        = dec_input_sizes['glob'])

            else:
                graph_fcn_output_from_aug = make_graph_indep_graphnet_functions(self.units,
                                                        node_or_core_input_size  = dec_input_sizes['node'] ,
                                                        edge_input_size          = dec_input_sizes['edge'],
                                                        node_or_core_output_size = self.gn_state_size,
                                                        global_input_size        = dec_input_sizes['glob'])

        with tf.name_scope("rvae_post"):
            self.edge_prob_mlp   = make_repar_gaussian_mlp(self.nlatent_edges,self.noise_scale_input,
                                              nsamples = self.n_iwae,
                                              name = "edge_prob_model"  ,
                                              min_std = self.min_std)

            self.node_prob_mlp   = make_repar_gaussian_mlp(self.nlatent_nodes,self.noise_scale_input,
                                                      nsamples = self.n_iwae,
                                                      name = "node_prob_model"  ,
                                                      min_std = self.min_std)

            self.glob_prob_mlp   = make_repar_gaussian_mlp(self.nlatent_globs,self.noise_scale_input,
                                                      nsamples = self.n_iwae,
                                                      name = "global_prob_model",
                                                      min_std = self.min_std)

        self.input_gnn  = GraphNet(**graph_fcn_to_latent)
        self.output_gnn = GraphNet(**graph_fcn_output_from_aug)
        self.is_built   = True
        self.weights = [*self.input_gnn.weights, *self.output_gnn.weights,*self.glob_prob_mlp.weights , *self.edge_prob_mlp.weights, *self.node_prob_mlp.weights]

    def _get_posterior_repar(self,tensordict):
        """
        Returns the reparametrization computed with "tensordict" as input, i.e. the
        output of the encoder GraphNet split into a "mean" and a "std" tensor dictionary.
        """
        encoded = self.input_gnn.eval_tensor_dict(tensordict)
        c_mean, c_std = _split_tensor_dict_to_repar(encoded, self.nlatent_nodes,
                                                    nlatent_edges  = self.nlatent_edges,
                                                    nlatent_global = self.nlatent_globs)
        return c_mean, c_std 

    def _sample_posteriors(self, c_mean, c_std):
        """
        Evaluates the three probabilistic models on a reparametrization.

        Returns:
          (node_samples, edge_samples, global_samples), (post_nodes, post_edges, post_globs)
        """
        edge_latent_samples  , post_edges = self.edge_prob_mlp([c_mean['edges'],      c_std['edges']]      )
        node_latent_samples  , post_nodes = self.node_prob_mlp([c_mean['nodes'],      c_std['nodes']]      )
        global_latent_samples, post_globs = self.glob_prob_mlp([c_mean['global_attr'],c_std['global_attr']])

        samples    = (node_latent_samples, edge_latent_samples, global_latent_samples)
        posteriors = (post_nodes, post_edges, post_globs)
        return samples, posteriors

    def get_posteriors(self, tensordict):
        """
        Simply returns the posterior distributions from the encoder. Note that 
        conditioning is not needed for computing the posteriors (the conditioning is repeated and appended 
        appropriately in the __call__ method). 

        Returns the second outputs of the node/edge/global probabilistic models.
        Arguments:
          tensordict
        Returns:
          post_nodes, post_edges, post_globs
        """
        c_mean, c_std = self._get_posterior_repar(tensordict)
        _, (post_nodes, post_edges, post_globs) = self._sample_posteriors(c_mean, c_std)
        return post_nodes, post_edges, post_globs
        
    def __call__(self, tensordict, tensordict_cond= None, return_post = False):
        """
        Params:
            tensordict      : a tensor dictionary representing a graph tuple
            tensordict_cond : conditioning tensor dictionary (contains edge,global,
                              node conditioning) also representing a GraphTuple 
                              (it is assumed the connectivity is the same). Required
                              if the object was built with any conditioning size.
            return_post     : if True, the posteriors are returned together with 
                              the decoded graphs.
        outputs:
            A tensordict representing a GT, or, if `return_post` is set, the tuple
            (tensordict, post_nodes, post_edges, post_globs)
        """
        # computing the re-parametrization, splitting the output and sampling:
        c_mean, c_std = self._get_posterior_repar(tensordict)
        latent_samples, posteriors = self._sample_posteriors(c_mean, c_std)
        node_latent_samples, edge_latent_samples, global_latent_samples = latent_samples
        post_nodes, post_edges, post_globs = posteriors

        # collecting the conditioning that was declared when building:
        needs_cond = self.cond_nodes or self.cond_edges or self.cond_globals
        if needs_cond and tensordict_cond is None:
            raise ValueError("This RelationalGraphVAE was built with conditioning inputs, "
                             "but no `tensordict_cond` was provided.")

        td_cond = [None,None,None]
        if self.cond_nodes:
            td_cond[0] = tensordict_cond['nodes']
        if self.cond_edges:
            td_cond[1] = tensordict_cond['edges']
        if self.cond_globals:
            td_cond[2] = tensordict_cond['global_attr']
            
        # If there is any conditioning information (also a graph), append it:
        aug_node,aug_edge, aug_glob = self._augment_graph_latent_with_static_vars(node_latent_samples,
                                                                                  edge_latent_samples,
                                                                                  global_latent_samples,*td_cond)
        
        # Merge the sample axis into the batch so that the decoder sees ordinary graphs:
        gt_to_decoder = _reshape_sampled_multi_graphtuples_td(aug_node,
                                                              aug_edge,
                                                              aug_glob,
                                                              tdict = tensordict,
                                                              nreps = self.n_iwae)

        o2 = self.output_gnn.eval_tensor_dict(gt_to_decoder)
        if return_post:
            return o2, post_nodes, post_edges, post_globs
        return o2

    def _append_conditioning(self, latent_samples, cond):
        """
        Repeats `cond` along a new leading axis `self.n_iwae` times and concatenates
        it to `latent_samples` on the last axis. Both are cast to DTYPE.
        """
        cond_rep = tf.repeat(tf.cast(cond[tf.newaxis,...], DTYPE), self.n_iwae, axis = 0)
        return tf.concat([tf.cast(latent_samples, DTYPE), cond_rep], axis = -1)
    
    @tf.function
    def _augment_graph_latent_with_static_vars(self,
                                               node_latent_samples,
                                               edge_latent_samples, 
                                               global_latent_samples,
                                               cond_nodes, 
                                               cond_edges, 
                                               cond_global):
        """
        A convenience function for sampling the latent space which takes care of repeating and appending 
        correctly conditioning (for instance, static variables) to the latent space (to implement the graph CVAE).

        Repeats the conditioning variables and concatenates them with the corresponding *_latent_samples
        variables. The number of IWAE samples from the relational VAE is used as the number of repetitions.
        Components that were built without conditioning are returned unchanged.

        parameters:
          node_latent_samples   : samples from the latent space for nodes. Either from the posterior or the prior. 
          edge_latent_samples   : samples from the latent space for edges.
          global_latent_samples : samples from the latent space for global attributes.
          cond_nodes            : the conditioning (non-repeated) variables
          cond_edges            : the conditioning (non-repeated) edges
          cond_global           : the conditioning (non-repeated) global attributes.

        outputs:
          aug_latent_node_samples, aug_latent_edge_samples, aug_global
        """
        if self.cond_edges:
            aug_latent_edge_samples = self._append_conditioning(edge_latent_samples, cond_edges)
        else:
            aug_latent_edge_samples = edge_latent_samples
        
        if self.cond_nodes:
            aug_latent_node_samples = self._append_conditioning(node_latent_samples, cond_nodes)
        else:
            aug_latent_node_samples = node_latent_samples
            
        if self.cond_globals:
            aug_global = self._append_conditioning(global_latent_samples, cond_global)
        else:
            aug_global = global_latent_samples
            
        return aug_latent_node_samples, aug_latent_edge_samples, aug_global

