import tensorflow as tf
from tf_gnns import make_mlp_graphnet_functions, make_graph_to_graph_and_global_functions, make_full_graphnet_functions, make_graph_indep_graphnet_functions
from tf_gnns import GraphNet
import tensorflow_probability as tfp 
from relational_vae import RelationalGraphVAE
import numpy as np
def _assign_add_tensor_dict(d_,od):
    """
    add nodes/edges/globals of d2 to d1 and return.
    """
    d_['nodes']       = d_['nodes']       + od['nodes']
    d_['edges']       = d_['edges']       + od['edges']
    d_['global_attr'] = d_['global_attr'] + od['global_attr'] 
    return d_

class SCADAGraphNetVAE(tf.keras.Model):
    """
    A VAE for graph structured data, that includes a message-passing graph encoder,
    a message passing decoder and a latent space for node, edge and global variables.

    The model is assembled from `GraphNet` blocks (encoder input network, encoder
    core networks, decoder core networks, final decoder network) and a
    `RelationalGraphVAE` that sits between the encoder and the decoder.
    """

    def __init__(self,
                 N_IWAE = 10,
                 NLATENT = 5,
                 NLATENT_GLOBAL = 5,
                 GLOBAL_OUT_SIZE = 3,
                 ncore_enc_networks = 3,
                 ncore_dec_networks = 3,
                 activation = "gelu",
                 agg_fcn = "mean_max_min",
                 gnn_size = 128,
                 core_state_size = 64,
                 node_size = 10,
                 edge_size = 3,
                 **kwargs):
        """
        Build all sub-networks and the latent-space priors.

        parameters:
          N_IWAE             : the number of samples to be taken from the latent space.
          NLATENT            : the size of the node and edge latent spaces
          NLATENT_GLOBAL     : the size of the global latent space
          GLOBAL_OUT_SIZE    : the global output size of the final decoder network
          ncore_enc_networks : the message passing steps of the encoder networks
          ncore_dec_networks : the message passing steps of the decoder networks
          activation         : the activation used in the GNNs all over.
                               NOTE(review): this argument is accepted but not
                               referenced anywhere in this constructor - confirm
                               whether it should be forwarded to the network factories.
          agg_fcn            : the aggregation function used for all the GNs
          gnn_size           : the width of the internal layers of the MLPs used in the GNNs
          core_state_size    : the size of the core networks - the core networks have all the same size
                               (convenient for residual connections)
          node_size          : the feature size of the nodes of the auto-encoded graph
          edge_size          : the feature size of the edges of the auto-encoded graph
          **kwargs           : forwarded to `tf.keras.Model.__init__`
        """
        super().__init__(**kwargs)

        self.N_IWAE = N_IWAE
        self.NLATENT_GLOBAL = NLATENT_GLOBAL
        self.NLATENT = NLATENT

        core_enc_networks = []  # <- to hold the core networks for the encoder.
        core_dec_networks = []  # <- to hold the core networks for the decoder

        ######################################
        # For conditioning the decoder.      #
        ######################################
        cond_node_size   = 2
        cond_edge_size   = 3
        cond_global_size = 3  # cos/sin of mean wind orientation, and mean wind speed

        print("Edge size: %s, Nodes size: %s. Global to decoder: %i\n gnn_size: %i, core_size: %i , latents: node %i, edge %i, global %i "%(
            str(edge_size), str(node_size), cond_global_size + NLATENT_GLOBAL,
            gnn_size, core_state_size, NLATENT, NLATENT, NLATENT_GLOBAL))

        ## --- Encoder ---
        with tf.name_scope("gnn_encoder"):
            with tf.name_scope("enc_in_gnn"):
                # Input network: takes the raw node/edge features and also
                # produces a global output of size `core_state_size`.
                graph_fcn_enc = make_graph_to_graph_and_global_functions(
                    gnn_size,
                    node_size,
                    global_output_size = core_state_size,
                    edge_input_size = edge_size,
                    aggregation_function = agg_fcn)

            for n_ in range(ncore_enc_networks):
                with tf.name_scope("core_enc_%i"%n_):
                    graph_fcn_core_i = make_full_graphnet_functions(
                        gnn_size,
                        node_or_core_input_size = core_state_size,
                        global_input_size       = core_state_size,
                        global_output_size      = core_state_size,
                        aggregation_function    = agg_fcn)
                    core_enc_networks.append(graph_fcn_core_i)

        rvae = RelationalGraphVAE(gnn_size, core_state_size, n_iwae = N_IWAE,
                                  nlatent_nodes = NLATENT,
                                  nlatent_edges = NLATENT,
                                  nlatent_globs = NLATENT_GLOBAL,
                                  aggregation_function = agg_fcn,
                                  use_graph_indep = True)

        ## --- Decoder ----
        with tf.name_scope("gnn_decoder"):
            for n_ in range(ncore_dec_networks):
                with tf.name_scope("core_dec_%i"%n_):
                    graph_fcn_core_i = make_full_graphnet_functions(
                        gnn_size,
                        node_or_core_input_size = core_state_size,
                        aggregation_function = agg_fcn)
                    core_dec_networks.append(graph_fcn_core_i)

            with tf.name_scope("gnn_final"):
                # The final network maps back to the feature sizes of the
                # encoder input (node_size / edge_size).
                graph_fcn_dec_final = make_full_graphnet_functions(
                    gnn_size,
                    node_or_core_input_size  = core_state_size,
                    node_or_core_output_size = node_size,
                    edge_output_size         = edge_size,
                    global_output_size       = GLOBAL_OUT_SIZE,
                    aggregation_function     = agg_fcn)

        gn_enc = GraphNet(**graph_fcn_enc)
        gn_core_enc = [GraphNet(**nw_params) for nw_params in core_enc_networks]
        gn_core_dec = [GraphNet(**nw_params) for nw_params in core_dec_networks]
        gn_dec_final = GraphNet(**graph_fcn_dec_final)

        # Attribute assignment order is kept as in the original implementation.
        self.gn_enc = gn_enc
        self.gn_core_dec = gn_core_dec
        self.gn_core_enc = gn_core_enc
        self.gn_dec_final = gn_dec_final

        ## The prior distributions for the latent space (zero mean, unit scale):
        edge_prior = tfp.distributions.Normal(loc = np.zeros(NLATENT, dtype="float32"),
                                              scale = np.ones(NLATENT, dtype="float32"))
        node_prior = tfp.distributions.Normal(loc = np.zeros(NLATENT, dtype="float32"),
                                              scale = np.ones(NLATENT, dtype="float32"))
        glob_prior = tfp.distributions.Normal(loc = np.zeros(NLATENT_GLOBAL, dtype="float32"),
                                              scale = np.ones(NLATENT_GLOBAL, dtype="float32"))
        self.edge_prior = edge_prior
        self.node_prior = node_prior
        self.glob_prior = glob_prior

        rvae._build(core_state_size,
                    core_state_size,
                    core_state_size,
                    cond_node_size,
                    cond_edge_size,
                    cond_global_size)

        self.rvae = rvae

        ## All networks, in evaluation order, and a flat list of their weights:
        all_networks = [gn_enc, *gn_core_enc, rvae, *gn_core_dec, gn_dec_final]
        self.all_networks = all_networks
        self.all_weights_flat = [w for nw in all_networks for w in nw.weights]

    def encoder_gn(self,od__, return_resid_outputs = False):
        """
        Run the encoder input network followed by the residual encoder core networks.

        params:
            od__                 : the input tensor dict handed to `gn_enc.eval_tensor_dict`
                                   (presumably an ordered dict that represents a GraphTuple,
                                   with 'nodes', 'edges' and 'global_attr' entries).
            return_resid_outputs : [False] if True, return a list holding a copy of the
                                   output of every core step instead of the final output.
        returns:
            the tensor dict after the last core step (the output of the input network
            if there are no core networks), or the list of per-step outputs.
        """
        state = self.gn_enc.eval_tensor_dict(od__)
        resid_outputs = []
        for gn_ in self.gn_core_enc:
            # Residual connection: the step output is added to its input.
            step_out = _assign_add_tensor_dict(gn_.eval_tensor_dict(state), state)
            resid_outputs.append(step_out.copy())
            state = step_out

        if return_resid_outputs:
            return resid_outputs

        return state

    @tf.function
    def encode_to_latent(self,gt_in, cond_inputs):
        """
        Encode `gt_in` and pass the result, with the conditioning inputs, through `self.rvae`.

        params:
            gt_in       : the input tensor dict for `encoder_gn`.
            cond_inputs : a sequence of three entries used as the 'nodes', 'edges'
                          and 'global_attr' conditioning inputs (in that order).
        returns:
            the four values returned by `self.rvae(..., return_post = True)`:
            td_out, post_nodes, post_edges, post_globs.
        """
        gt_out_z = self.encoder_gn(gt_in)
        h = {'nodes' : cond_inputs[0], 'edges' : cond_inputs[1], 'global_attr' : cond_inputs[2]}
        td_out, post_nodes, post_edges, post_globs = self.rvae(gt_out_z, h, return_post = True)
        return td_out, post_nodes, post_edges, post_globs

    @tf.function
    def decode_from_latent(self,td_to_decoder):
        """
        Run the residual decoder core networks followed by the final decoder network.

        params:
            td_to_decoder : the tensor dict to decode (per the original notes: nodes and
                            edges sampled from the latent space - several graphs - and an
                            appropriately re-shaped global attr. tensor).
        returns:
            the tensor dict produced by `gn_dec_final.eval_tensor_dict`.
        """
        state = td_to_decoder
        for gn_ in self.gn_core_dec:
            # Residual connection: the step output is added to its input.
            state = _assign_add_tensor_dict(gn_.eval_tensor_dict(state), state)

        # Use the running state rather than a loop-local variable, so that a
        # model built with zero decoder core networks still decodes.
        return self.gn_dec_final.eval_tensor_dict(state)

    @tf.function
    def eval_tot(self,gt_in, static_node_chans, static_edge_chans,
                 global_determ_decoder_input = None, eval_kl_div = False):
        """
        evaluation of the whole computational graph (encode to latent, then decode).

        parameters:
          gt_in                       : the input tensor dict (`GraphTuple` representation) to compute on
          static_node_chans           : conditioning input used for the 'nodes' entry
          static_edge_chans           : conditioning input used for the 'edges' entry
          global_determ_decoder_input : [None] conditioning input used for the 'global_attr' entry
          eval_kl_div                 : [False] whether to return the KL divergence (for nodes, edges and global latents) or not.
        returns:
          the decoder output, or `(decoder output, [kl_nodes, kl_edges, kl_globals])`
          when `eval_kl_div` is set.
        """
        cond_inputs = [static_node_chans, static_edge_chans, global_determ_decoder_input]

        gt_out, post_nodes, post_edges, post_globs = self.encode_to_latent(gt_in, cond_inputs)
        g_out = self.decode_from_latent(gt_out)

        if not eval_kl_div:
            return g_out

        # KL divergence of each posterior w.r.t. the corresponding prior.
        kln = post_nodes.kl_divergence(self.node_prior)
        kle = post_edges.kl_divergence(self.edge_prior)
        klg = post_globs.kl_divergence(self.glob_prior)
        return g_out, [kln, kle, klg]
