import torch
import torch.nn as nn
import torch.nn.functional as F

from model.graph import GraphTripleConvNet, _init_weights, make_mlp


class Sg2ScAEModelBaseline1(nn.Module):
    """Scene-graph auto-encoder baseline built from two triple-convolution GCNs.

    Pipeline implemented by this class:

    * ``encoder`` embeds every object as ``[class embedding (D), box
      embedding (D)]`` and, when the batch contains at least one triple,
      passes the nodes through ``gcn_enc``.
    * ``decoder`` prepends a second (decoder-side) class embedding to each
      encoded node, optionally passes the result through ``gcn_dec`` and
      regresses a box vector and a shape code for every node.
    * ``forward_batch_step`` splits a flat batch into scenes, orders each
      scene as ``[new objects, old objects]``, rewrites the triples to
      those scene-local positions and runs encoder + decoder.

    ``D`` denotes ``node_embedding_dim`` throughout.  Both GCNs are assumed
    to return node features of the same width they receive (the decoder
    concatenation and the MLP input sizes depend on it) -- confirm against
    ``GraphTripleConvNet``.
    """

    def __init__(self, vocab,
                 node_embedding_dim=128,
                 gconv_pooling="avg",
                 gconv_num_layers=6,
                 mlp_normalization="none",
                 num_box_params=7,
                 shape_input_dim=256,
                 residual=True):
        """Build embeddings, the two GCNs and the two regression heads.

        Args:
            vocab: mapping with ``'object_idx_to_name'`` and
                ``'pred_idx_to_name'`` sequences; their numbers of distinct
                entries (plus one) size the embedding tables.
            node_embedding_dim: base feature width ``D``.
            gconv_pooling: forwarded to both ``GraphTripleConvNet``s.
            gconv_num_layers: forwarded to both ``GraphTripleConvNet``s.
            mlp_normalization: forwarded to the GCNs and to ``make_mlp``.
            num_box_params: width of a box vector, used both as the input
                width of the box embedding and the output width of the box
                head.
            shape_input_dim: width of the predicted shape code.
            residual: forwarded to both ``GraphTripleConvNet``s.
        """
        super().__init__()
        D = node_embedding_dim
        self.D = D
        self.shape_input_dim = shape_input_dim

        # One extra row on top of the distinct names (presumably a reserved
        # dummy/padding id -- confirm against the dataset's id convention).
        n_obj = len(set(vocab['object_idx_to_name'])) + 1
        n_rel = len(set(vocab['pred_idx_to_name'])) + 1

        # ----- embeddings -----
        self.ec_emb_obj = nn.Embedding(n_obj, D)
        self.ec_emb_pred = nn.Embedding(n_rel, 2 * D)   # matches encoder node width
        self.dc_emb_obj = nn.Embedding(n_obj, D)
        self.dc_emb_pred = nn.Embedding(n_rel, 3 * D)   # matches decoder node width
        # The decoder feeds mlp_box output back through this layer, so its
        # input width must follow num_box_params rather than a literal 7.
        self.emb_box = nn.Linear(num_box_params, D)

        gcn_shared = dict(
            hidden_dim=4 * D,
            pooling=gconv_pooling,
            num_layers=gconv_num_layers,
            mlp_normalization=mlp_normalization,
            residual=residual,
        )

        # ----- encoder GCN: nodes are [class(D), box(D)] = 2D -----
        self.gcn_enc = GraphTripleConvNet(input_dim_obj=2 * D,
                                          input_dim_pred=2 * D,
                                          **gcn_shared)

        # ----- decoder GCN: nodes are [class(D), encoded(2D)] = 3D -----
        dec_in = 3 * D
        self.gcn_dec = GraphTripleConvNet(input_dim_obj=dec_in,
                                          input_dim_pred=3 * D,
                                          **gcn_shared)

        # ----- regression heads -----
        self.mlp_box = make_mlp([dec_in, 4 * D, 4 * D, 2 * D, num_box_params],
                                mlp_normalization, norelu=True)
        # Shape head input: decoder feature (3D) + embedded predicted box (D).
        self.mlp_shape = make_mlp([dec_in + D, 5 * D, 5 * D, 3 * D, 2 * D,
                                   shape_input_dim],
                                  mlp_normalization, norelu=True)

        self.apply(_init_weights)

    # ================================================= #
    def encoder(self,
                batch_scene_ids,
                new_objs_l, new_box_l,
                old_objs_l, old_box_l,
                tri_l
                ):
        """Encode a list of scenes into one flat node-feature tensor.

        Args:
            batch_scene_ids: accepted for interface compatibility; not used.
            new_objs_l, new_box_l: per-scene class ids and boxes of the new
                objects.
            old_objs_l, old_box_l: per-scene class ids and boxes of the old
                objects (may be empty tensors).
            tri_l: per-scene ``(E, 3)`` triples ``(subject, predicate,
                object)`` whose subject/object index into that scene's
                ``[new, old]`` node order.

        Returns:
            ``(node_out, split)``.  ``node_out`` stacks all scenes' nodes,
            each scene laid out as new objects followed by old objects.
            ``split`` holds one ``(start, M, C)`` tuple per scene: row
            offset, number of new nodes, number of old nodes.  If the whole
            batch has no triple the raw ``(N, 2D)`` embeddings are returned
            without running the GCN.  With no scenes at all an empty
            ``(0, 2D)`` tensor and an empty list are returned.
        """
        node_chunks, tri_chunks, split = [], [], []
        offset = 0

        for new_objs, new_box, old_objs, old_box, tri_local in zip(
                new_objs_l, new_box_l, old_objs_l, old_box_l, tri_l):

            node_new = torch.cat([self.ec_emb_obj(new_objs),
                                  self.emb_box(new_box)], -1)        # (M, 2D)
            if old_objs.numel():
                node_old = torch.cat([self.ec_emb_obj(old_objs),
                                      self.emb_box(old_box)], -1)    # (C, 2D)
                nodes = torch.cat([node_new, node_old], 0)           # (M+C, 2D)
            else:
                nodes = node_new
            node_chunks.append(nodes)

            # Shift scene-local node indices to rows of the flat batch.
            if tri_local.numel():
                tri_chunks.append(torch.stack([tri_local[:, 0] + offset,
                                               tri_local[:, 1],
                                               tri_local[:, 2] + offset], 1))

            # Counts are taken from the rows actually emitted so that
            # `split` always agrees with `node_out`.
            n_new = node_new.size(0)
            n_old = nodes.size(0) - n_new
            split.append((offset, n_new, n_old))
            offset += n_new + n_old

        if not node_chunks:
            # Nothing to encode (e.g. no scene contained a new object).
            return self.emb_box.weight.new_empty(0, 2 * self.D), split

        node_in = torch.cat(node_chunks, 0)
        if not tri_chunks:
            # No edges anywhere in the batch: skip the GCN entirely.
            return node_in, split

        triples = torch.cat(tri_chunks, 0)
        edges = torch.stack([triples[:, 0], triples[:, 2]], 1)
        pred = self.ec_emb_pred(triples[:, 1])
        node_out, _ = self.gcn_enc(node_in, pred, edges)
        return node_out, split

    # ================================================= #
    def decoder(self,
            node_out,              # (N_tot, 2D)
            split,                 # [(start, M, C), ...]
            new_objs_l,            # List[Tensor]
            old_objs_l,            # List[Tensor]
            tri_l                  # List[Tensor]
            ):
        """Decode encoded nodes into box vectors and shape codes.

        Args:
            node_out: encoded node features as returned by ``encoder``.
            split: per-scene ``(start, M, C)`` tuples locating each scene's
                new and old rows inside ``node_out``.
            new_objs_l, old_objs_l: per-scene class ids, in the same order
                as the rows of ``node_out``.
            tri_l: per-scene ``(E, 3)`` scene-local triples.

        Returns:
            ``(box_all, shp_all)``: the outputs of ``mlp_box`` and
            ``mlp_shape`` for every node (new and old) of every scene, in
            ``split`` order.
        """
        dec_node_chunks, edge_chunks, pid_chunks = [], [], []

        # Decoder rows are numbered independently of `start`, so the edge
        # indices stay correct even if `split` is not contiguous.
        offset = 0
        for (start, n_new, n_old), new_objects, old_objects, tri_local in zip(
                split, new_objs_l, old_objs_l, tri_l):

            # New rows are immediately followed by old rows in node_out.
            enc_nodes = node_out[start:start + n_new + n_old]        # (M+C, 2D)

            if old_objects.numel():
                obj_ids = torch.cat([new_objects, old_objects], 0)
            else:
                obj_ids = new_objects
            sem_all = self.dc_emb_obj(obj_ids)                       # (M+C, D)

            dec_node_chunks.append(torch.cat([sem_all, enc_nodes], -1))  # (M+C, 3D)

            if tri_local.numel():
                edge_chunks.append(
                    torch.stack([tri_local[:, 0], tri_local[:, 2]], 1) + offset)
                pid_chunks.append(tri_local[:, 1])

            offset += n_new + n_old

        if dec_node_chunks:
            dec_in_all = torch.cat(dec_node_chunks, 0)               # (sum(M+C), 3D)
        else:
            dec_in_all = node_out.new_empty(0, 3 * self.D)

        if edge_chunks:
            edges_all = torch.cat(edge_chunks, 0)                    # (E, 2)
            pred_vec = self.dc_emb_pred(torch.cat(pid_chunks, 0))    # (E, 3D)
            dec_vec_all, _ = self.gcn_dec(dec_in_all, pred_vec, edges_all)
        else:
            # No edges anywhere in the batch: skip the GCN entirely.
            dec_vec_all = dec_in_all

        box_all = self.mlp_box(dec_vec_all)
        # The shape head is conditioned on the predicted box.
        shp_in = torch.cat([dec_vec_all, self.emb_box(box_all)], -1)
        shp_all = self.mlp_shape(shp_in)

        return box_all, shp_all

    # ================================================= #
    @staticmethod
    def _scene_local_triples(scene_triples, pos_new, pos_old, device):
        """Rewrite triples from batch object indices to ``[new, old]`` positions.

        Triples with a subject or object outside ``pos_new`` / ``pos_old``
        are dropped.  Returns a ``(E, 3)`` long tensor (possibly empty).
        """
        empty = torch.empty(0, 3, dtype=torch.long, device=device)
        if not scene_triples.numel():
            return empty

        subj, pred, obj = (scene_triples[:, 0], scene_triples[:, 1],
                           scene_triples[:, 2])
        order = torch.cat([pos_new, pos_old], 0)                     # (M+C,)

        in_scene = torch.isin(subj, order) & torch.isin(obj, order)
        if not in_scene.any():
            return empty
        subj, pred, obj = subj[in_scene], pred[in_scene], obj[in_scene]

        # Lookup table: batch object index -> position in `order`
        # (-1 marks indices that do not belong to this scene).
        lut = -torch.ones(int(order.max().item()) + 1,
                          dtype=torch.long, device=device)
        lut[order] = torch.arange(order.size(0), device=device)

        s_loc, o_loc = lut[subj], lut[obj]
        valid = (s_loc >= 0) & (o_loc >= 0)
        if not valid.any():
            return empty
        return torch.stack([s_loc[valid], pred[valid], o_loc[valid]], 1)

    def forward_batch_step(self,
                           obj_batch_scene_ids,  # (N_step_obj,)
                           objs,                 # (N_step_obj,)
                           boxes,                # (N_step_obj, box params)
                           triples,              # (T_step,3): new-new/new-old/old-old
                           new_mask,             # (N_step_obj,)
                           obj_indices,          # (N_step_obj,)
                         triple_scene_ids=None  ):
        """Run encoder + decoder on one flat batch of objects and triples.

        Args:
            obj_batch_scene_ids: scene id of every object.
            objs: class id of every object.
            boxes: box vector of every object.
            triples: ``(T, 3)`` rows ``(subject, predicate, object)`` whose
                subject/object are values of ``obj_indices``.
            new_mask: truthy for objects that are new in this step; any
                dtype is accepted and interpreted as boolean.
            obj_indices: the index by which ``triples`` refer to each object.
            triple_scene_ids: optional scene id per triple.  It is only used
                when its length matches ``triples``; otherwise every scene is
                matched against all triples via ``obj_indices``.

        Returns:
            ``(box_p, shp_p)`` from ``decoder``.  Rows cover only scenes that
            contain at least one new object, in order of first appearance of
            the scene id, each scene laid out as new objects then old objects.
            Both tensors have zero rows when no scene qualifies.
        """
        device = objs.device
        # `~mask` below must be a logical NOT, which requires a bool tensor.
        new_mask = new_mask.to(torch.bool)

        has_triple_ids = (triple_scene_ids is not None
                          and triple_scene_ids.numel() == triples.size(0))

        new_objs_l, new_box_l = [], []
        old_objs_l, old_box_l = [], []
        tri_l, scene_id_l = [], []

        # dict.fromkeys de-duplicates while keeping first-appearance order.
        for sid in dict.fromkeys(obj_batch_scene_ids.tolist()):
            sel_scene = (obj_batch_scene_ids == sid)

            new_loc = new_mask[sel_scene]
            if not new_loc.any():
                # Scenes without any new object are skipped entirely.
                continue
            old_loc = ~new_loc

            obj_loc = objs[sel_scene]
            box_loc = boxes[sel_scene]
            pos_all = obj_indices[sel_scene]                         # (N_scene,)

            scene_triples = (triples[triple_scene_ids == sid]
                             if has_triple_ids else triples)
            tri_local = self._scene_local_triples(
                scene_triples, pos_all[new_loc], pos_all[old_loc], device)

            new_objs_l.append(obj_loc[new_loc])
            new_box_l.append(box_loc[new_loc])
            old_objs_l.append(obj_loc[old_loc])
            old_box_l.append(box_loc[old_loc])
            tri_l.append(tri_local)
            scene_id_l.append(sid)

        batch_scene_ids = torch.tensor(scene_id_l, device=device, dtype=torch.long)

        node_out, split = self.encoder(batch_scene_ids, new_objs_l, new_box_l,
                                       old_objs_l, old_box_l, tri_l)
        box_p, shp_p = self.decoder(node_out, split, new_objs_l, old_objs_l, tri_l)
        return box_p, shp_p
