import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from model.graph import GraphTripleConvNet, _init_weights, make_mlp

class Sg2ScAEModel(nn.Module):
    """Scene-graph autoencoder that reconstructs per-object boxes and shapes.

    Encoder: each object's class embedding, box embedding and shape embedding
    are concatenated and passed, together with predicate embeddings, through a
    ``GraphTripleConvNet``.

    Decoder: a second class embedding is concatenated with the encoded node
    vectors and passed through another ``GraphTripleConvNet``; two MLP heads
    then regress the box parameters and the shape feature of every object.
    The shape head additionally sees an embedding of the predicted box.

    Args:
        vocab: Mapping providing ``'object_idx_to_name'`` and
            ``'pred_idx_to_name'``. The number of distinct entries in each
            sizes the object / predicate embedding tables (the object tables
            get one extra row).
        node_embedding_dim: Width of the encoder class embedding and of the
            box embedding. Several other widths are multiples of it.
        gconv_pooling: Forwarded as ``pooling`` to both graph networks.
        gconv_num_layers: Forwarded as ``num_layers`` to both graph networks.
        mlp_normalization: Forwarded as ``mlp_normalization`` to both graph
            networks and as ``batch_norm`` to ``make_mlp``.
        num_box_params: Number of box parameters per object (input width of
            the box embeddings, output width of the box head).
        shape_input_dim: Width of the per-object shape feature (input and
            output width of the shape path).
        residual: Forwarded as ``residual`` to the *decoder* graph network
            only. NOTE(review): the encoder graph network is always built
            with ``residual=True`` regardless of this argument; kept as-is
            to preserve existing behaviour.
    """

    def __init__(self,
                 vocab,
                 node_embedding_dim=64,
                 gconv_pooling='avg',
                 gconv_num_layers=3,
                 mlp_normalization='none',
                 num_box_params=7,
                 shape_input_dim=256,
                 residual=False):
        super().__init__()

        # Derived widths; the trailing numbers are the values for the defaults.
        gconv_dim = node_embedding_dim                       # 64
        box_embedding_dim = int(node_embedding_dim)          # 64
        shape_embedding_dim = shape_input_dim                # 256
        # Encoder node feature = semantic + box + shape embeddings.
        node_vec_dim = gconv_dim + box_embedding_dim + shape_embedding_dim  # 384
        # Decoder node feature = decoder class embedding + encoded node vector.
        input_dim_obj_dc = gconv_dim * 2 + node_vec_dim      # 128 + 384 = 512
        gconv_hidden_dim = node_embedding_dim * 8            # 512

        self.num_box_params = num_box_params
        self.shape_input_dim = shape_input_dim
        self.mlp_normalization = mlp_normalization

        num_objs = len(set(vocab['object_idx_to_name']))
        num_preds = len(set(vocab['pred_idx_to_name']))

        # NOTE: submodules are created in the same order as before so that
        # parameter ordering (and construction-time RNG use) is unchanged.

        # ---- encoder embeddings ----
        self.obj_embeddings_ec = nn.Embedding(num_objs + 1, node_embedding_dim)      # 64
        self.pred_embeddings_ec = nn.Embedding(num_preds, 6 * node_embedding_dim)    # 384

        # Built and initialised, but not called by any method of this class.
        # Kept so existing checkpoints keep loading.
        self.readout_mlp = nn.Sequential(
            nn.Linear(node_vec_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, node_vec_dim),
        )

        # ---- decoder embeddings ----
        self.obj_embeddings_dc = nn.Embedding(num_objs + 1, 2 * node_embedding_dim)  # 128
        # The decoder graph network below is configured with
        # input_dim_pred=input_dim_obj_dc, so the predicate embedding must have
        # exactly that width. The previous hard-coded 8*node_embedding_dim only
        # matched when shape_input_dim == 4*node_embedding_dim (true for the
        # defaults, where both are 512).
        self.pred_embeddings_dc = nn.Embedding(num_preds, input_dim_obj_dc)          # 512

        # ---- box / shape feature embeddings ----
        self.box_embeddings = nn.Linear(num_box_params, box_embedding_dim)           # 64
        self.shape_embeddings = nn.Linear(shape_input_dim, shape_embedding_dim)      # 256
        # Embeds the *predicted* box before it is fed to the shape head.
        self.box_embeddings_dc = nn.Linear(num_box_params, box_embedding_dim)        # 64

        # ---- graph networks ----
        gconv_kwargs_ec = {
            'input_dim_obj': gconv_dim + gconv_dim + shape_embedding_dim,  # 384
            'input_dim_pred': gconv_dim * 6,                               # 384
            'hidden_dim': gconv_hidden_dim,                                # 512
            'pooling': gconv_pooling,
            'num_layers': gconv_num_layers,
            'mlp_normalization': mlp_normalization,
            # Hard-coded: the `residual` argument is not used for the encoder.
            'residual': True,
        }
        gconv_kwargs_dc = {
            'input_dim_obj': input_dim_obj_dc,                             # 512
            'input_dim_pred': input_dim_obj_dc,                            # 512
            'hidden_dim': gconv_hidden_dim * 2,                            # 1024
            'pooling': gconv_pooling,
            'num_layers': gconv_num_layers,
            'mlp_normalization': mlp_normalization,
            'residual': residual,
        }
        self.gconv_net_ec = GraphTripleConvNet(**gconv_kwargs_ec)
        self.gconv_net_dc = GraphTripleConvNet(**gconv_kwargs_dc)

        # ---- output heads ----
        box_net_dim = num_box_params
        # Layer widths: 512 -> 1024 -> 7 with the defaults.
        box_net_layers = [input_dim_obj_dc, gconv_hidden_dim * 2, box_net_dim]
        # Layer widths: 576 -> 1024 -> 256 with the defaults; the input is the
        # decoded node vector concatenated with the predicted-box embedding.
        shape_net_layers = [input_dim_obj_dc + box_embedding_dim,
                            gconv_hidden_dim * 2,
                            shape_input_dim]

        self.decoder_mlp_box = make_mlp(box_net_layers,
                                        batch_norm=mlp_normalization, norelu=True)
        self.decoder_mlp_shape = make_mlp(shape_net_layers,
                                          batch_norm=mlp_normalization, norelu=True)

        self._init_all()

    def _init_all(self):
        """Apply ``_init_weights`` to every learnable submodule."""
        submodules = (
            self.obj_embeddings_ec,
            self.pred_embeddings_ec,
            self.box_embeddings,
            self.shape_embeddings,
            self.obj_embeddings_dc,
            self.pred_embeddings_dc,
            self.gconv_net_ec,
            self.gconv_net_dc,
            self.decoder_mlp_box,
            self.decoder_mlp_shape,
            self.readout_mlp,
            # Previously omitted, unlike its encoder twin `box_embeddings`.
            # Placed last so the initialisation order of the others is unchanged.
            self.box_embeddings_dc,
        )
        for module in submodules:
            module.apply(_init_weights)

    @staticmethod
    def _split_triples(triples):
        """Split ``triples`` into predicate ids and a (subject, object) edge list.

        Assumes ``triples`` is a 2-D integer tensor whose columns are
        [subject, predicate, object] -- TODO confirm against the data loader.

        Returns:
            ``(p, edges)`` where ``p`` holds the predicate column and ``edges``
            stacks the subject and object columns along ``dim=1``.
        """
        s, p, o = triples.chunk(3, dim=1)
        s, p, o = [x.squeeze(1) for x in (s, p, o)]
        edges = torch.stack([s, o], dim=1)
        return p, edges

    def encode_scene_graph(self, objs, triples, boxes, shapes):
        """Encode a scene graph into per-object node vectors.

        Args:
            objs: Object class indices, looked up in ``obj_embeddings_ec``.
            triples: Scene-graph triples (see ``_split_triples``).
            boxes: Per-object box parameters fed to ``box_embeddings``.
            shapes: Per-object shape features fed to ``shape_embeddings``.

        Returns:
            The node vectors produced by the encoder graph network.
        """
        p, edges = self._split_triples(triples)

        # Per-object input features.
        sematic_ft = self.obj_embeddings_ec(objs)   # node_embedding_dim (64)
        box_ft = self.box_embeddings(boxes)         # box_embedding_dim (64)
        shape_ft = self.shape_embeddings(shapes)    # shape_embedding_dim (256)
        node_in = torch.cat([sematic_ft, box_ft, shape_ft], dim=1)  # 64+64+256=384

        # Per-edge input features.
        pred_vec = self.pred_embeddings_ec(p)       # 6*node_embedding_dim (384)

        # Only the node output is used; the predicate output is discarded.
        node_vecs, _ = self.gconv_net_ec(node_in, pred_vec, edges)
        return node_vecs

    # --------------------------------------------------
    # Encoder: => node_vecs
    # --------------------------------------------------
    def encoder(self, objs, triples, boxes, shapes, enc_obj_to_scene, enc_triple_to_scene):
        """Return the encoded node vectors for the given scene graph.

        ``enc_obj_to_scene`` and ``enc_triple_to_scene`` are accepted for
        interface compatibility but are not used.
        """
        return self.encode_scene_graph(objs, triples, boxes, shapes)

    # --------------------------------------------------
    # Decoder: => boxes_pred, shapes_pred
    # --------------------------------------------------
    def decoder(self, objs, triples, node_vecs):
        """Decode box parameters and shape features from encoded node vectors.

        Args:
            objs: Object class indices, looked up in ``obj_embeddings_dc``.
            triples: Scene-graph triples (see ``_split_triples``).
            node_vecs: Per-object vectors, e.g. the output of ``encoder``.

        Returns:
            ``(boxes_pred, shapes_pred)``: outputs of the box head and of the
            shape head respectively.
        """
        p, edges = self._split_triples(triples)

        obj_sem = self.obj_embeddings_dc(objs)               # 2*node_embedding_dim (128)
        node_in = torch.cat([obj_sem, node_vecs], dim=1)     # 128+384=512
        pred_vec = self.pred_embeddings_dc(p)                # 512

        # Only the node output is used; the predicate output is discarded.
        final_node_vecs, _ = self.gconv_net_dc(node_in, pred_vec, edges)

        # Box head first: its prediction conditions the shape head.
        boxes_pred = self.decoder_mlp_box(final_node_vecs)   # -> num_box_params (7)
        box_feat = self.box_embeddings_dc(boxes_pred)        # -> box_embedding_dim (64)

        shape_in = torch.cat([final_node_vecs, box_feat], dim=1)  # 512+64=576
        shapes_pred = self.decoder_mlp_shape(shape_in)       # -> shape_input_dim (256)

        return boxes_pred, shapes_pred

    # --------------------------------------------------
    # forward => AE style
    # --------------------------------------------------
    def forward(self, objs, triples, boxes_gt, shapes_gt, enc_obj_to_scene, enc_triple_to_scene):
        """Autoencode a scene graph.

        1) encode => node_vecs
        2) decode => boxes_pred, shapes_pred

        Returns:
            ``(boxes_pred, shapes_pred)`` as produced by ``decoder``.
        """
        node_vecs = self.encoder(objs, triples, boxes_gt, shapes_gt,
                                 enc_obj_to_scene, enc_triple_to_scene)
        boxes_pred, shapes_pred = self.decoder(objs, triples, node_vecs)
        return boxes_pred, shapes_pred

