import torch
import torch.nn as nn
from typing import List, Tuple
from model.graph import GraphTripleConvNet, _init_weights, make_mlp


class Sg2ScAEModelBaeline2(nn.Module):
    """Scene-graph auto-encoder predicting boxes and shape codes for new objects.

    Each scene is split into "new" objects (to be predicted) and "old"
    objects (context). An encoder graph network embeds every node from its
    class id and box; a decoder graph network consumes those embeddings and
    two MLP heads regress, for the new objects only, a box vector and a
    shape vector.

    NOTE(review): ``GraphTripleConvNet`` is used as
    ``obj_out, _ = net(obj_vecs, pred_vecs, edges)`` and its object output is
    fed to layers sized for its object input, so it is assumed to preserve
    the object feature width -- confirm against ``model.graph``.

    Args:
        vocab: mapping with ``'object_idx_to_name'`` and
            ``'pred_idx_to_name'`` sequences; their numbers of distinct
            entries (plus one) size the embedding tables.
        node_embedding_dim: base feature width ``D``.
        gconv_pooling: forwarded to ``GraphTripleConvNet`` as ``pooling``.
        gconv_num_layers: forwarded as ``num_layers``.
        mlp_normalization: forwarded to the graph nets and to ``make_mlp``.
        num_box_params: width of a box vector (input of ``emb_box`` and
            output of ``mlp_box``).
        shape_input_dim: width of the predicted shape vector.
        residual: forwarded to ``GraphTripleConvNet``.
    """

    def __init__(self, vocab,
                 node_embedding_dim=128,
                 gconv_pooling="avg",
                 gconv_num_layers=3,
                 mlp_normalization="none",
                 num_box_params=7,
                 shape_input_dim=256,
                 residual=True):
        super().__init__()
        D = node_embedding_dim
        self.D = D
        self.shape_input_dim = shape_input_dim

        n_obj = len(set(vocab['object_idx_to_name'])) + 1
        n_rel = len(set(vocab['pred_idx_to_name'])) + 1

        # ----- Embeddings -----
        self.ec_emb_obj = nn.Embedding(n_obj, D)
        self.ec_emb_pred = nn.Embedding(n_rel, 2 * D)   # encoder predicates
        self.dc_emb_obj = nn.Embedding(n_obj, D)
        self.dc_emb_pred = nn.Embedding(n_rel, 3 * D)   # decoder predicates
        # The box embedding is applied both to input boxes and to the boxes
        # produced by ``mlp_box``, so its input width must follow
        # ``num_box_params`` rather than a hard-coded 7.
        self.emb_box = nn.Linear(num_box_params, D)

        # Settings shared by the encoder and decoder graph networks.
        gcn_common = dict(
            hidden_dim=4 * D,
            pooling=gconv_pooling,
            num_layers=gconv_num_layers,
            mlp_normalization=mlp_normalization,
            residual=residual,
        )

        # ----- Encoder GCN: node width 2D = D (class) + D (box) -----
        self.gcn_enc = GraphTripleConvNet(
            input_dim_obj=2 * D, input_dim_pred=2 * D, **gcn_common
        )

        # ----- Decoder GCN: node width 3D = D (class) + 2D (encoder) -----
        dec_in = 3 * D
        self.gcn_dec = GraphTripleConvNet(
            input_dim_obj=dec_in, input_dim_pred=3 * D, **gcn_common
        )

        # Box head: decoder node (3D) -> box parameters.
        self.mlp_box = make_mlp(
            [dec_in, 4 * D, num_box_params],
            mlp_normalization, norelu=True
        )
        # Shape head: decoder node (3D) + embedded predicted box (D) -> shape.
        self.mlp_shape = make_mlp(
            [dec_in + D, 5 * D, 3 * D, shape_input_dim],
            mlp_normalization, norelu=True
        )

        self.apply(_init_weights)

    # =========================================================
    def encoder(self,
                batch_scene_ids,           # unused; kept for interface compatibility
                new_objs_l, new_box_l,     # per scene: (M_i,), (M_i, box)
                old_objs_l, old_box_l,     # per scene: (C_i,), (C_i, box)
                tri_l) -> Tuple[torch.Tensor, List[Tuple[int, int, int]]]:
        """Embed all scenes as one batched graph and run the encoder GCN.

        Within each scene the nodes are laid out as new objects followed by
        old objects; scenes are concatenated in list order. Triples in
        ``tri_l`` are ``(subject, predicate, object)`` rows whose endpoints
        index that per-scene layout; they are shifted by the scene's node
        offset before batching.

        Returns:
            ``(node_out, split)`` where ``node_out`` has one row of width
            ``2 * D`` per node and ``split`` holds one ``(start, M, C)``
            tuple per scene (first node row, number of new objects, number
            of old objects). When there are no triples at all the GCN is
            skipped and the raw node features are returned. When there are
            no scenes, ``node_out`` has zero rows and ``split`` is empty.
        """
        # With no scenes there is no input tensor to read the device from.
        if new_objs_l:
            device = new_objs_l[0].device
        else:
            device = self.emb_box.weight.device

        node_chunks, tri_chunks, split = [], [], []
        offset = 0

        for new_obj, new_box, old_obj, old_box, tri_local in zip(
                new_objs_l, new_box_l, old_objs_l, old_box_l, tri_l):

            # New objects: class embedding || box embedding -> (M, 2D)
            node_scene = torch.cat(
                [self.ec_emb_obj(new_obj), self.emb_box(new_box)], -1)

            if old_obj.numel():
                node_old = torch.cat(
                    [self.ec_emb_obj(old_obj), self.emb_box(old_box)], -1)
                node_scene = torch.cat([node_scene, node_old], 0)   # (M+C, 2D)

            node_chunks.append(node_scene)

            if tri_local.numel():
                s, p, o = tri_local.t()
                # Shift scene-local endpoints into the batched node numbering.
                tri_chunks.append(torch.stack([s + offset, p, o + offset], 1))

            M, C = new_box.size(0), old_box.size(0)
            split.append((offset, M, C))
            offset += M + C

        if node_chunks:
            node_in = torch.cat(node_chunks, 0)
        else:
            node_in = torch.empty(0, 2 * self.D, device=device)

        if not tri_chunks:
            # No edges anywhere: nothing to propagate.
            return node_in, split

        triples = torch.cat(tri_chunks, 0)
        s, p, o = triples.t()
        edges = torch.stack([s, o], 1)
        pred = self.ec_emb_pred(p)
        node_out, _ = self.gcn_enc(node_in, pred, edges)
        return node_out, split

    # =========================================================
    def decoder(self,
                node_out,                  # (N_tot, 2D) encoder output
                split,                     # [(start, M, C), ...]
                new_objs_l,                # per scene: (M_i,)
                old_objs_l,                # per scene: (C_i,)
                tri_l) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode box and shape vectors for the new objects of every scene.

        Each node is rebuilt as decoder class embedding (D) concatenated with
        its encoder vector (2D), the decoder GCN is run over the batched
        graph (skipped when there are no triples), and the two heads are
        applied to the rows of new objects, one scene at a time.

        Returns:
            ``(boxes, shapes)`` with one row per new object, in scene order.
            If there is no new object at all, both tensors have zero rows
            (instead of ``torch.cat`` failing on an empty list).
        """
        device = node_out.device
        dec_node_chunks = []
        edge_chunks, pid_chunks = [], []
        new_ranges = []                    # (first row in batched graph, M)

        offset = 0
        for (start, M, C), new_obj, old_obj, tri_local in zip(
                split, new_objs_l, old_objs_l, tri_l):

            # 1) Slice this scene's encoder vectors: new rows, then old rows.
            new_vec = node_out[start:start + M]                      # (M, 2D)
            if C > 0:
                old_vec = node_out[start + M:start + M + C]          # (C, 2D)
            else:
                old_vec = node_out.new_empty(0, node_out.size(1))

            # 2) Decoder input per node: 3D = D (class) + 2D (encoder).
            sem_new = self.dc_emb_obj(new_obj)                       # (M, D)
            if old_obj.numel():
                sem_old = self.dc_emb_obj(old_obj)                   # (C, D)
            else:
                sem_old = sem_new.new_empty(0, sem_new.size(1))

            dec_node_chunks.append(torch.cat([
                torch.cat([sem_new, new_vec], -1),                   # (M, 3D)
                torch.cat([sem_old, old_vec], -1),                   # (C, 3D)
            ], 0))                                                   # (M+C, 3D)

            # 3) Shift scene-local triple endpoints (0..M+C-1) by the offset.
            if tri_local.numel():
                s, p, o = tri_local.t()
                edge_chunks.append(torch.stack([s + offset, o + offset], 1))
                pid_chunks.append(p)

            new_ranges.append((offset, M))
            offset += M + C

        if dec_node_chunks:
            dec_in_all = torch.cat(dec_node_chunks, 0)
        else:
            dec_in_all = torch.empty(0, 3 * self.D, device=device)

        if edge_chunks:
            edges_all = torch.cat(edge_chunks, 0)
            pred_vec = self.dc_emb_pred(torch.cat(pid_chunks, 0))
            dec_vec_all, _ = self.gcn_dec(dec_in_all, pred_vec, edges_all)
        else:
            dec_vec_all = dec_in_all

        box_chunks, shp_chunks = [], []
        for st, ln in new_ranges:
            if ln == 0:
                continue
            dec_vec_new = dec_vec_all[st:st + ln]                    # (ln, 3D)
            box_p = self.mlp_box(dec_vec_new)                        # (ln, box)
            # The shape head also sees the embedded predicted box.
            shp_in = torch.cat([dec_vec_new, self.emb_box(box_p)], -1)  # (ln, 4D)
            box_chunks.append(box_p)
            shp_chunks.append(self.mlp_shape(shp_in))                # (ln, shape)

        if not box_chunks:
            # Predicted boxes are fed to ``emb_box``, so its input width is
            # the box width.
            return (node_out.new_empty(0, self.emb_box.in_features),
                    node_out.new_empty(0, self.shape_input_dim))

        return torch.cat(box_chunks, 0), torch.cat(shp_chunks, 0)

    # =========================================================
    @staticmethod
    def _localize_triples(triples, triple_scene_ids, sid,
                          pos_all, pos_new, pos_old, device):
        """Return one scene's triples with endpoints renumbered to local order.

        Local order is new objects first (positions ``0..M-1``), then old
        objects (``M..M+C-1``). Triples with an endpoint outside the scene
        are dropped. The result is a ``(T, 3)`` long tensor, possibly empty.
        """
        empty = torch.empty(0, 3, dtype=torch.long, device=device)
        if not triples.numel():
            return empty

        # Pre-filter by scene id only when a per-triple id tensor is given
        # and matches the triple count; otherwise consider every triple.
        if (triple_scene_ids is not None
                and triple_scene_ids.numel() == triples.size(0)):
            candidates = triples[triple_scene_ids == sid]
        else:
            candidates = triples
        if not candidates.numel():
            return empty

        subj, pred, obj = candidates.t()

        # Keep triples whose both endpoints belong to this scene.
        in_scene = torch.isin(subj, pos_all) & torch.isin(obj, pos_all)
        if not in_scene.any():
            return empty
        subj, pred, obj = subj[in_scene], pred[in_scene], obj[in_scene]

        # Lookup table from object index to local node position (-1 = absent).
        order = torch.cat([pos_new, pos_old], 0)                     # (M+C,)
        lut = torch.full((int(order.max().item()) + 1,), -1,
                         dtype=torch.long, device=device)
        lut[order] = torch.arange(order.size(0), device=device)

        subj_loc, obj_loc = lut[subj], lut[obj]
        valid = (subj_loc >= 0) & (obj_loc >= 0)
        if not valid.any():
            return empty
        return torch.stack([subj_loc[valid], pred[valid], obj_loc[valid]], 1)

    def forward_batch_step(self,
                           obj_batch_scene_ids,   # (N,) scene id per object
                           objs, boxes,           # (N,), (N, box)
                           triples,               # (T, 3) subject, predicate, object
                           new_mask,              # (N,) truthy = new object
                           obj_indices,           # (N,) index used by triple endpoints
                           triple_scene_ids=None):
        """Run encoder and decoder for one step over a batch of scenes.

        Objects are grouped by scene id in order of first appearance. Scenes
        without any new object are skipped. For the others, objects are split
        into new/old by ``new_mask`` and the scene's triples are renumbered
        from ``obj_indices`` values to local node positions.

        Args:
            triple_scene_ids: optional ``(T,)`` scene id per triple. It is
                used to pre-filter triples only if its element count equals
                the number of triples; otherwise triples are matched to a
                scene purely by their endpoints.

        Returns:
            ``(box_p, shp_p)``: predicted boxes and shape vectors, one row
            per new object, ordered by scene and then by position in the
            batch. Both have zero rows when no scene has a new object.
        """
        device = objs.device
        # ``~mask`` is a logical NOT only for bool tensors; cast so integer
        # masks behave the same way.
        new_mask = new_mask.bool()

        # Distinct scene ids, preserving first-appearance order.
        scene_ids = list(dict.fromkeys(obj_batch_scene_ids.tolist()))

        new_objs_l, new_box_l = [], []
        old_objs_l, old_box_l = [], []
        tri_l = []

        for sid in scene_ids:
            sel_scene = obj_batch_scene_ids == sid
            new_loc = new_mask[sel_scene]
            if not new_loc.any():
                # Nothing to predict for this scene.
                continue
            old_loc = ~new_loc

            obj_loc = objs[sel_scene]
            box_loc = boxes[sel_scene]
            pos_all = obj_indices[sel_scene]                         # (M+C,)

            tri_l.append(self._localize_triples(
                triples, triple_scene_ids, sid,
                pos_all, pos_all[new_loc], pos_all[old_loc], device))

            new_objs_l.append(obj_loc[new_loc])                      # (M,)
            new_box_l.append(box_loc[new_loc])                       # (M, box)
            old_objs_l.append(obj_loc[old_loc])                      # (C,)
            old_box_l.append(box_loc[old_loc])                       # (C, box)

        node_out, split = self.encoder(
            None, new_objs_l, new_box_l, old_objs_l, old_box_l, tri_l)
        box_p, shp_p = self.decoder(
            node_out, split, new_objs_l, old_objs_l, tri_l)
        return box_p, shp_p
