import torch, torch.nn as nn
from typing import List, Tuple
from collections import defaultdict
from model.graph import GraphTripleConvNet, _init_weights, make_mlp


class Sg2ScAEModelIncremental3D(nn.Module):
    """Scene-graph auto-encoder that processes every scene step by step.

    Each scene is turned into a small graph whose local node 0 is a per-scene
    CLS node, followed by the nodes added in the current step ("new") and the
    nodes carried over from the previous step ("old").  The CLS node is linked
    to every other node in both directions with predicate id ``GLOBAL_PID``.

    After each :meth:`forward_batch_step` the encoder output at the CLS
    position is cached per scene id (detached) and used as the CLS input the
    next time that scene id appears; scenes without a cached entry start from
    the learned ``cls_tok``.  The cache is a plain dict that is only emptied
    by :meth:`reset_all_scene_states`.

    The decoder predicts, for new nodes only, ``num_box_params`` box values
    and a shape vector of size ``shape_input_dim``.
    """

    # ------------------------- init ------------------------- #
    def __init__(self, vocab,
                 node_embedding_dim=128,
                 gconv_pooling="avg",
                 gconv_num_layers=6,
                 mlp_normalization="none",
                 num_box_params=7,
                 shape_input_dim=256,
                 residual=True):
        """Build embeddings, the two graph networks and the output heads.

        Args:
            vocab: mapping with ``'object_idx_to_name'`` and
                ``'pred_idx_to_name'``; the embedding tables get one row per
                distinct name plus one extra row.
            node_embedding_dim: base width ``D`` used for every embedding.
            gconv_pooling, gconv_num_layers, mlp_normalization, residual:
                forwarded unchanged to ``GraphTripleConvNet`` / ``make_mlp``.
            num_box_params: number of box parameters consumed by ``emb_box``
                and produced by ``mlp_box``.
            shape_input_dim: size of the predicted shape vector.
        """
        super().__init__()
        self.D, self.shape_input_dim = node_embedding_dim, shape_input_dim
        self.GLOBAL_PID = 0          # predicate id used for CLS <-> node edges
        self._scene_states = {}      # scene_id -> cached CLS latent, (1, 2D)

        n_obj = len(set(vocab['object_idx_to_name'])) + 1
        n_rel = len(set(vocab['pred_idx_to_name'])) + 1

        # --- embeddings ---
        self.ec_emb_obj  = nn.Embedding(n_obj, node_embedding_dim)
        self.ec_emb_pred = nn.Embedding(n_rel, 2 * node_embedding_dim)
        self.dc_emb_obj  = nn.Embedding(n_obj, node_embedding_dim)
        self.dc_emb_pred = nn.Embedding(n_rel, 3 * node_embedding_dim)
        # Sized from num_box_params (was hard-coded to 7): the decoder feeds
        # mlp_box's output, which has num_box_params columns, back in here.
        self.emb_box     = nn.Linear(num_box_params, node_embedding_dim)
        self.cls_tok     = nn.Parameter(torch.randn(1, 2 * node_embedding_dim))

        # --- encoder GCN: nodes are [semantic D ; box D] = 2D ---
        enc_cfg = dict(input_dim_obj=2 * node_embedding_dim,
                       input_dim_pred=2 * node_embedding_dim,
                       hidden_dim=4 * node_embedding_dim,
                       pooling=gconv_pooling,
                       num_layers=gconv_num_layers,
                       mlp_normalization=mlp_normalization,
                       residual=residual)
        self.gcn_enc = GraphTripleConvNet(**enc_cfg)

        # --- decoder GCN: nodes are [semantic D ; encoder latent 2D] = 3D ---
        dec_in = 3 * node_embedding_dim
        dec_cfg = dict(input_dim_obj=dec_in,
                       input_dim_pred=3 * node_embedding_dim,
                       hidden_dim=4 * node_embedding_dim,
                       pooling=gconv_pooling,
                       num_layers=gconv_num_layers,
                       mlp_normalization=mlp_normalization,
                       residual=residual)
        self.gcn_dec = GraphTripleConvNet(**dec_cfg)

        # --- output heads ---
        self.mlp_box = make_mlp(
            [dec_in, 4 * node_embedding_dim, num_box_params],
            mlp_normalization, norelu=True)

        # Input is the decoded node (3D) concatenated with the embedded
        # predicted box (D), i.e. 4D.
        self.mlp_shape = make_mlp(
            [dec_in + node_embedding_dim,
             5 * node_embedding_dim,
             3 * node_embedding_dim,
             shape_input_dim],
            mlp_normalization, norelu=True)

        self.apply(_init_weights)

    # ------------------------ helpers ----------------------- #
    @staticmethod
    def _cls_edges(first, count, device):
        """Return CLS<->node edges for local nodes ``first..first+count-1``.

        The result is a ``(2*count, 2)`` long tensor: first all ``0 -> node``
        edges, then all ``node -> 0`` edges.  ``count == 0`` yields ``(0, 2)``.
        """
        dst = torch.arange(first, first + count, device=device)
        src = torch.zeros_like(dst)
        return torch.cat([torch.stack([src, dst], 1),
                          torch.stack([dst, src], 1)], 0)

    def _local_graph(self, n_new, n_old, tri_local, device):
        """Build one scene's edge list and predicate ids (CLS is node 0).

        Order: CLS<->new edges, CLS<->old edges, then the scene's own triples
        with subject/object shifted by one to make room for the CLS node.
        Returns ``(edges (E, 2), pids (E,))``.
        """
        edges = torch.cat([self._cls_edges(1, n_new, device),
                           self._cls_edges(1 + n_new, n_old, device)], 0)
        pids = torch.full((edges.size(0),), self.GLOBAL_PID,
                          dtype=torch.long, device=device)
        if tri_local.numel():
            # Fancy indexing copies, so the caller's triples stay untouched.
            edges = torch.cat([edges, tri_local[:, [0, 2]] + 1], 0)
            pids = torch.cat([pids, tri_local[:, 1]], 0)
        return edges, pids

    @staticmethod
    def _localize_triples(triples, pos_new, pos_old, device):
        """Remap triples to scene-local node positions.

        Keeps the triples whose subject and object both occur in
        ``pos_new``/``pos_old`` and rewrites those two columns to positions in
        the ``[new ; old]`` ordering (0-based, CLS not counted).  Returns a
        ``(T, 3)`` long tensor, possibly with ``T == 0``.
        """
        empty = torch.empty((0, 3), dtype=torch.long, device=device)
        if not triples.numel():
            return empty

        s, p, o = triples.t()
        order = torch.cat([pos_new, pos_old], 0)               # (M+C,)
        in_scene = torch.isin(s, order) & torch.isin(o, order)
        if not in_scene.any():
            return empty
        s_k, p_k, o_k = s[in_scene], p[in_scene], o[in_scene]

        # Lookup table: external object index -> local position, -1 if absent.
        lut = torch.full((int(order.max().item()) + 1,), -1,
                         dtype=torch.long, device=device)
        lut[order] = torch.arange(order.size(0), device=device)

        s_loc, o_loc = lut[s_k], lut[o_k]
        valid = (s_loc >= 0) & (o_loc >= 0)
        if not valid.any():
            return empty
        return torch.stack([s_loc[valid], p_k[valid], o_loc[valid]], 1)

    # ================================================= #
    # 1. encoder  (all scenes merged into one graph)
    # ================================================= #
    def encoder(self,
        batch_scene_ids,                 # (S,)
        new_objs_l, new_box_l,           # List[Tensor], one entry per scene
        old_objs_l, old_box_l,           # nodes of the previous step only
        tri_l                            # List[Tensor], scene-local triples
    ):
        """Encode all scenes of the batch in a single merged graph.

        Per scene the node block is ``[CLS ; new ; old]`` where every object
        node is ``[ec_emb_obj(obj) ; emb_box(box)]`` (width 2D) and the CLS
        row is the cached scene state, or ``cls_tok`` for an unseen scene id.

        Args:
            batch_scene_ids: 1-D tensor of scene ids, one per list entry.
            new_objs_l, new_box_l: per-scene new object ids and boxes.
            old_objs_l, old_box_l: per-scene old object ids and boxes
                (may be empty tensors).
            tri_l: per-scene ``(T, 3)`` triples whose subject/object index
                into the ``[new ; old]`` ordering (may be empty).

        Returns:
            ``(node_out, split)``.  ``node_out`` has one row per node of the
            merged graph; it is the GCN output, or the raw node input when
            the batch has no edges at all.  ``split`` holds one
            ``(first_object_row, n_new, n_old)`` tuple per scene; the scene's
            CLS row is ``first_object_row - 1``.

        The lists must be non-empty (the device is read from
        ``new_objs_l[0]``).
        """
        device = new_objs_l[0].device
        node_chunks, edge_chunks, pid_chunks, split = [], [], [], []
        offset = 0

        scenes = zip(batch_scene_ids.tolist(), new_objs_l, new_box_l,
                     old_objs_l, old_box_l, tri_l)
        for sid, new_objs, new_box, old_objs, old_box, tri_local in scenes:
            # 1) node block: CLS (1,2D), new (M,2D), old (C,2D)
            nodes = [
                self._scene_states.get(sid, self.cls_tok),
                torch.cat([self.ec_emb_obj(new_objs), self.emb_box(new_box)], -1),
            ]
            if old_objs.numel():  # the first step of a scene has no old nodes
                nodes.append(torch.cat([self.ec_emb_obj(old_objs),
                                        self.emb_box(old_box)], -1))
            node_chunks.append(torch.cat(nodes, 0))            # (1+M+C, 2D)

            # 2) local edges, then shift into the merged graph's numbering
            n_new, n_old = new_box.size(0), old_box.size(0)
            edges, pids = self._local_graph(n_new, n_old, tri_local, device)
            if edges.numel():
                edge_chunks.append(edges + offset)
                pid_chunks.append(pids)

            split.append((offset + 1, n_new, n_old))
            offset += 1 + n_new + n_old

        node_in = torch.cat(node_chunks, 0)
        if not edge_chunks:
            # No edges anywhere: nothing for the GCN to propagate.
            return node_in, split

        pred = self.ec_emb_pred(torch.cat(pid_chunks, 0))
        node_out, _ = self.gcn_enc(node_in, pred, torch.cat(edge_chunks, 0))
        return node_out, split

    # ================================================= #
    # 2. decoder
    # ================================================= #
    def decoder(self,
            node_out,
            split,
            new_objs_l,
            old_objs_l,
            tri_l
        ):
        """Decode box parameters and shape vectors for the new nodes.

        Args:
            node_out: encoder output, one row per merged-graph node.
            split: per-scene ``(first_object_row, n_new, n_old)`` tuples as
                returned by :meth:`encoder`.
            new_objs_l, old_objs_l: per-scene object ids.
            tri_l: per-scene local triples (same as given to the encoder).

        Returns:
            ``(box_p, shp_p)``: predictions for the new nodes of all scenes,
            concatenated in scene order; ``box_p`` comes from ``mlp_box`` and
            ``shp_p`` from ``mlp_shape``.
        """
        device = node_out.device
        dec_node_chunks, edge_chunks, pid_chunks, new_ranges = [], [], [], []

        offset = 0
        scenes = zip(split, new_objs_l, old_objs_l, tri_l)
        for (cls_idx, n_new, n_old), new_objects, old_objects, tri_local in scenes:
            # 1) encoder latents of this scene: [CLS ; new ; old]
            latent = node_out[cls_idx - 1: cls_idx + n_new + n_old]  # (1+M+C, 2D)

            # 2) semantic part; the CLS node gets an all-zero semantic row.
            #    new_zeros (instead of zeros_like(sem_new[:1])) keeps that row
            #    present even when the scene has no new nodes.
            sem_new = self.dc_emb_obj(new_objects)                   # (M, D)
            sem_parts = [sem_new.new_zeros(1, sem_new.size(1)), sem_new]
            if old_objects.numel():
                sem_parts.append(self.dc_emb_obj(old_objects))       # (C, D)
            dec_node_chunks.append(
                torch.cat([torch.cat(sem_parts, 0), latent], -1))    # (1+M+C, 3D)

            # 3) same edge layout as in the encoder
            edges, pids = self._local_graph(n_new, n_old, tri_local, device)
            if edges.numel():
                edge_chunks.append(edges + offset)
                pid_chunks.append(pids)

            new_ranges.append((offset + 1, n_new))
            offset += 1 + n_new + n_old

        dec_in_all = torch.cat(dec_node_chunks, 0)                   # (sum(1+M+C), 3D)
        if edge_chunks:
            pred_vec = self.dc_emb_pred(torch.cat(pid_chunks, 0))
            dec_vec_all, _ = self.gcn_dec(dec_in_all, pred_vec,
                                          torch.cat(edge_chunks, 0))
        else:
            dec_vec_all = dec_in_all

        # Heads are applied scene by scene so that any batch-dependent
        # normalisation inside the MLPs sees the same groups as before.
        box_chunks, shp_chunks = [], []
        for start, n in new_ranges:
            dec_vec_new = dec_vec_all[start: start + n]   # (n, 3D) -- mlp_box expects 3D
            box_p = self.mlp_box(dec_vec_new)             # (n, num_box_params)
            shp_in = torch.cat([dec_vec_new, self.emb_box(box_p)], -1)  # (n, 4D)
            box_chunks.append(box_p)
            shp_chunks.append(self.mlp_shape(shp_in))     # (n, shape_input_dim)

        return torch.cat(box_chunks, 0), torch.cat(shp_chunks, 0)

    def forward_batch_step(self,
        obj_batch_scene_ids,   # (N_step_obj,)
        objs,                  # (N_step_obj,)
        boxes,                 # (N_step_obj, num_box_params)
        triples,               # (T_step, 3)
        new_mask,              # (N_step_obj,)
        obj_indices,           # (N_step_obj,)
        triple_scene_ids=None  # (T_step,)
    ):
        """Run one incremental step for a flat batch of objects.

        Objects are grouped by scene id (first-seen order).  Scenes without
        any new object in this step are skipped entirely.  For the remaining
        scenes the encoder and decoder are run and the encoder's CLS output
        is cached as that scene's state for the next step.

        Args:
            obj_batch_scene_ids: scene id of every object.
            objs: object class ids.
            boxes: box parameters per object.
            triples: ``(subject, predicate, object)`` rows whose subject and
                object refer to values of ``obj_indices``.
            new_mask: truthy where the object is new in this step; boolean
                and 0/1 integer masks are both accepted.
            obj_indices: index of each object as referenced by ``triples``.
            triple_scene_ids: optional scene id per triple.  When missing or
                of the wrong length, every triple is considered for every
                scene and filtered by membership in the scene's
                ``obj_indices``.  NOTE(review): that fallback presumably
                relies on ``obj_indices`` being unique across scenes --
                confirm with the data loader.

        Returns:
            ``(box_p, shp_p)`` for the new objects of all processed scenes,
            in scene order.  If no scene has a new object, two empty tensors
            with the usual column counts are returned and no state changes.
        """
        device = objs.device

        new_objs_l, new_box_l = [], []
        old_objs_l, old_box_l = [], []
        tri_l, scene_id_l = [], []

        # dict.fromkeys de-duplicates while keeping first-seen order.
        for sid in dict.fromkeys(obj_batch_scene_ids.tolist()):
            sel_scene = obj_batch_scene_ids == sid

            # Cast to bool so that an integer 0/1 mask is used as a mask and
            # not as row indices, and so that `~` is a logical NOT.
            new_loc = new_mask[sel_scene].bool()
            if not new_loc.any():
                continue
            old_loc = ~new_loc

            obj_loc = objs[sel_scene]
            box_loc = boxes[sel_scene]
            pos_all = obj_indices[sel_scene]                    # (N_scene,)

            scene_triples = triples
            if (triples.numel()
                    and triple_scene_ids is not None
                    and triple_scene_ids.numel() == triples.size(0)):
                scene_triples = triples[triple_scene_ids == sid]

            new_objs_l.append(obj_loc[new_loc])                 # (M,)
            new_box_l.append(box_loc[new_loc])                  # (M, P)
            old_objs_l.append(obj_loc[old_loc])                 # (C,)
            old_box_l.append(box_loc[old_loc])                  # (C, P)
            tri_l.append(self._localize_triples(
                scene_triples, pos_all[new_loc], pos_all[old_loc], device))
            scene_id_l.append(sid)

        if not scene_id_l:
            # Nothing new anywhere in this step: return empty predictions
            # instead of failing inside the encoder on an empty list.
            ref = self.emb_box.weight
            return (ref.new_zeros((0, self.emb_box.in_features)),
                    ref.new_zeros((0, self.shape_input_dim)))

        batch_scene_ids = torch.tensor(scene_id_l, device=device, dtype=torch.long)

        node_out, split = self.encoder(batch_scene_ids, new_objs_l, new_box_l,
                                       old_objs_l, old_box_l, tri_l)
        box_p, shp_p = self.decoder(node_out, split, new_objs_l, old_objs_l, tri_l)

        for (cls_idx, _, _), sid in zip(split, scene_id_l):
            # clone() so the cache holds only this row rather than a view
            # that keeps the whole batch's node_out storage alive.
            self._scene_states[sid] = node_out[cls_idx - 1: cls_idx].detach().clone()

        return box_p, shp_p

    def reset_all_scene_states(self):
        """Forget every cached per-scene CLS latent."""
        self._scene_states.clear()