# CGMN.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class GCNLayer(nn.Module):
    """One graph-convolution step: ``adj @ (x @ weight)``, plus an optional bias row.

    Args:
        in_features: width of each input node feature vector.
        out_features: width of each output node feature vector.
        bias: if True, a learnable ``[1, out_features]`` row is added to every
            node's output; otherwise ``self.bias`` is registered as ``None``.
    """

    def __init__(self, in_features, out_features, bias: bool = False):
        super().__init__()
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
        # Always register "bias" so the attribute exists even when disabled.
        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(torch.Tensor(1, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize weight (then bias, when present) with Xavier-uniform values."""
        for param in (self.weight, self.bias):
            if param is not None:
                nn.init.xavier_uniform_(param.data)

    def forward(self, x, adj):
        projected = x @ self.weight        # [N, out]
        aggregated = adj @ projected       # [N, out]
        if self.bias is None:
            return aggregated
        return aggregated + self.bias


class GCN(nn.Module):
    """Two stacked GCN layers: ``g2(dropout(LeakyReLU(g1(x, A))), A)``.

    When ``adj`` is a square 2-D matrix it is preprocessed into ``A``.
    A copy is taken and its diagonal is *set* to 1, overwriting any existing
    diagonal values. Each row is then divided by its sum, with sums clamped
    to at least 1. Any other ``adj`` (non-square or non-2-D) is used unchanged.
    """

    def __init__(self, in_dim, hid_dim, out_dim):
        super().__init__()
        self.g1 = GCNLayer(in_dim, hid_dim)
        self.g2 = GCNLayer(hid_dim, out_dim)
        self.act = nn.LeakyReLU()
        self.drop = nn.Dropout()

    @staticmethod
    def _row_normalize(adj):
        """Return a copy of ``adj`` with unit diagonal, divided by its row sums."""
        normed = adj.clone()
        normed.fill_diagonal_(1.0)
        row_sums = normed.sum(dim=1, keepdim=True).clamp_min(1.0)
        return normed / row_sums

    def forward(self, x, adj):
        is_square_matrix = adj.dim() == 2 and adj.size(0) == adj.size(1)
        A = self._row_normalize(adj) if is_square_matrix else adj
        hidden = self.drop(self.act(self.g1(x, A)))
        return self.g2(hidden, A)


class _BaseCGMN(nn.Module):
    """Shared backbone for the CGMN variants.

    Owns the object-tag embedding (91 ids), the LeakyReLU/Dropout modules, and
    two label-graph GCN branches (``gcn1`` on the semantic edges, ``gcn2`` on
    the co-occurrence edges), each mapping ``in_label`` -> 512.

    Subclasses must define ``self.fc_fuse`` (fused visual width -> 1024).
    Subclasses that call ``_classify`` must also define ``self.fc_cat``
    (1024 -> ``out_cat``).

    Args:
        in_context, in_body, in_depth: flattened feature widths per modality.
        in_obj: visual feature width of each object patch.
        d_obj: width of the object-tag embedding.
        in_label: width of each label-node embedding fed to the GCNs.
        out_cat: number of output categories.
        in_head: optional head-feature width; ``None`` or 0 disables it.
    """

    def __init__(self, in_context, in_body, in_depth, in_obj, d_obj, in_label=300, out_cat=26, in_head=None):
        super().__init__()
        self.in_context = in_context
        self.in_body = in_body
        self.in_depth = in_depth
        self.in_obj = in_obj
        self.d_obj = d_obj
        self.in_head = int(in_head or 0)
        self.in_label = in_label
        self.out_cat = out_cat

        self.obj_emb = nn.Embedding(91, self.d_obj)
        self.relu = nn.LeakyReLU()
        self.drop = nn.Dropout()

        # Two GCN branches (512->512 each), concatenated to 1024.
        self.gcn1 = GCN(self.in_label, 512, 512)
        self.gcn2 = GCN(self.in_label, 512, 512)

        # Subclasses define: self.fc_fuse, self.fc_cat

    def _pool_objects(self, f_obj, tag_obj, dist_obj):
        """
        Distance-aware soft pooling (DASOR) over the object patches.

        Each patch is represented by its tag embedding concatenated with its
        visual feature. Patches are combined with weights
        ``softmax(1 / (dist_obj + 1e-6))`` taken over the patch axis, so for
        positive distances nearer objects get larger weights.

        Args:
            f_obj:   [B, K, in_obj] object visual features (K = 4 in the
                     original design), or any tensor that reshapes to that,
                     e.g. a flattened [B, K * in_obj].
            tag_obj: [B, K] object class ids (indices into a 91-row embedding).
            dist_obj: [B, K] human-object distances; any numeric dtype.

        Returns:
            [B, d_obj + in_obj] pooled object representation, in the dtype of
            the concatenated object features.
        """
        B = f_obj.size(0)
        obj_tag = self.obj_emb(tag_obj)                         # [B, K, d_obj]
        # Reshape like the standalone variants do, so flattened inputs work too.
        obj_vis = f_obj.reshape(B, -1, self.in_obj)             # [B, K, in_obj]
        obj_all = torch.cat([obj_tag, obj_vis], dim=2)          # [B, K, d_obj+in_obj]
        # Cast so both einsum operands share one dtype even when distances
        # arrive as e.g. float64 while the features are float32.
        w = torch.softmax(1.0 / (dist_obj + 1e-6), dim=1).to(obj_all.dtype)  # [B, K]
        return torch.einsum("bkd,bk->bd", obj_all, w)           # [B, d_obj+in_obj]

    def _maybe_head(self, f_head, B, device):
        """Return head features as [B, in_head], or zeros when unavailable.

        When ``in_head`` is 0 the result is an empty [B, 0] tensor, so callers
        can concatenate it unconditionally.
        """
        if self.in_head > 0 and (f_head is not None):
            # reshape (not view) so non-contiguous inputs are accepted as well.
            return f_head.reshape(B, self.in_head)
        return torch.zeros(B, self.in_head, device=device)

    def _gcn_concat(self, label_emb, edge_sem, edge_cooccur):
        """Run both label GCNs and return ``(sem, coo, cat([sem, coo], dim=1))``."""
        sem = self.gcn1(label_emb, edge_sem)           # [26, 512]
        coo = self.gcn2(label_emb, edge_cooccur)       # [26, 512]
        cat = torch.cat([sem, coo], dim=1)             # [26, 1024]
        return sem, coo, cat

    def _classify(self, img_feat, label_cat):
        """Average an MLP head and a label-graph dot-product head.

        The MLP branch applies dropout, then LeakyReLU, then ``fc_cat``. The
        graph branch scores the un-dropped, un-activated fused features against
        each label embedding row of ``label_cat``.
        """
        x = self.fc_fuse(img_feat)                     # [B, 1024]
        x_do = self.drop(x)
        x_ = self.relu(x_do)
        logits_onehot = self.fc_cat(x_)                # [B, 26]
        # Use x (pre-activation) for label-graph dot product to match original design.
        logits_label = x @ label_cat.t()               # [B, 26]
        return (logits_onehot + logits_label) / 2.0


class CGMN(_BaseCGMN):
    """Full model: context + body + depth + pooled objects (+ optional head).

    ``forward`` returns ``(logits, sem, coo, cat)`` when head features are both
    enabled (``in_head > 0``) and supplied; otherwise ``(logits, sem, coo)``.
    """

    def __init__(self, in_context, in_body, in_depth, in_obj, d_obj, in_label=300, out_cat=26, in_head=None):
        super().__init__(in_context, in_body, in_depth, in_obj, d_obj, in_label, out_cat, in_head)
        # Every modality contributes; objects add their [tag ; visual] width.
        fused_width = sum((in_context, in_body, in_depth, in_obj + d_obj, self.in_head))
        self.fc_fuse = nn.Linear(fused_width, 1024)
        self.fc_cat = nn.Linear(1024, out_cat)

    def forward(self, f_context, f_body, f_depth, f_obj, tag_obj, dist_obj,
                label_emb, edge_sem, edge_cooccur, f_head=None):
        n = f_body.size(0)
        pieces = [
            f_context.view(n, self.in_context),
            f_body.view(n, self.in_body),
            f_depth.view(n, self.in_depth),
            self._pool_objects(f_obj, tag_obj, dist_obj),
            self._maybe_head(f_head, n, f_body.device),
        ]
        img = torch.cat(pieces, dim=1)
        sem, coo, cat = self._gcn_concat(label_emb, edge_sem, edge_cooccur)
        logits = self._classify(img, cat)
        head_used = self.in_head > 0 and f_head is not None
        return (logits, sem, coo, cat) if head_used else (logits, sem, coo)


class CGMN_fs(_BaseCGMN):
    """Body ablation: fuse context + depth + pooled objects (+ optional head).

    ``f_body`` is accepted for signature compatibility but not used. Returns
    ``(logits, sem, coo, cat)`` when head features are enabled and supplied,
    else ``(logits, sem, coo)``.
    """

    def __init__(self, in_context, in_body, in_depth, in_obj, d_obj, in_label=300, out_cat=26, in_head=None):
        super().__init__(in_context, in_body, in_depth, in_obj, d_obj, in_label, out_cat, in_head)
        fused_width = sum((in_context, in_depth, in_obj + d_obj, self.in_head))
        self.fc_fuse = nn.Linear(fused_width, 1024)
        self.fc_cat = nn.Linear(1024, out_cat)

    def forward(self, f_context, f_body, f_depth, f_obj, tag_obj, dist_obj,
                label_emb, edge_sem, edge_cooccur, f_head=None):
        # Batch size comes from context, since body is dropped in this variant.
        n = f_context.size(0)
        pieces = [
            f_context.view(n, self.in_context),
            f_depth.view(n, self.in_depth),
            self._pool_objects(f_obj, tag_obj, dist_obj),
            self._maybe_head(f_head, n, f_context.device),
        ]
        img = torch.cat(pieces, dim=1)
        sem, coo, cat = self._gcn_concat(label_emb, edge_sem, edge_cooccur)
        logits = self._classify(img, cat)
        head_used = self.in_head > 0 and f_head is not None
        return (logits, sem, coo, cat) if head_used else (logits, sem, coo)


class CGMN_fo(_BaseCGMN):
    """Object ablation: fuse context + body + depth (+ optional head).

    ``f_obj``, ``tag_obj`` and ``dist_obj`` are accepted but unused. Returns
    ``(logits, sem, coo, cat)`` when head features are enabled and supplied,
    else ``(logits, sem, coo)``.
    """

    def __init__(self, in_context, in_body, in_depth, in_obj, d_obj, in_label=300, out_cat=26, in_head=None):
        super().__init__(in_context, in_body, in_depth, in_obj, d_obj, in_label, out_cat, in_head)
        fused_width = sum((in_context, in_body, in_depth, self.in_head))
        self.fc_fuse = nn.Linear(fused_width, 1024)
        self.fc_cat = nn.Linear(1024, out_cat)

    def forward(self, f_context, f_body, f_depth, f_obj, tag_obj, dist_obj,
                label_emb, edge_sem, edge_cooccur, f_head=None):
        n = f_body.size(0)
        pieces = [
            f_context.view(n, self.in_context),
            f_body.view(n, self.in_body),
            f_depth.view(n, self.in_depth),
            self._maybe_head(f_head, n, f_body.device),
        ]
        img = torch.cat(pieces, dim=1)
        sem, coo, cat = self._gcn_concat(label_emb, edge_sem, edge_cooccur)
        logits = self._classify(img, cat)
        head_used = self.in_head > 0 and f_head is not None
        return (logits, sem, coo, cat) if head_used else (logits, sem, coo)


class CGMN_fgs(_BaseCGMN):
    """Context ablation: fuse body + depth + pooled objects (+ optional head).

    ``f_context`` is accepted but unused. Returns ``(logits, sem, coo, cat)``
    when head features are enabled and supplied, else ``(logits, sem, coo)``.
    """

    def __init__(self, in_context, in_body, in_depth, in_obj, d_obj, in_label=300, out_cat=26, in_head=None):
        super().__init__(in_context, in_body, in_depth, in_obj, d_obj, in_label, out_cat, in_head)
        fused_width = sum((in_body, in_depth, in_obj + d_obj, self.in_head))
        self.fc_fuse = nn.Linear(fused_width, 1024)
        self.fc_cat = nn.Linear(1024, out_cat)

    def forward(self, f_context, f_body, f_depth, f_obj, tag_obj, dist_obj,
                label_emb, edge_sem, edge_cooccur, f_head=None):
        n = f_body.size(0)
        pieces = [
            f_body.view(n, self.in_body),
            f_depth.view(n, self.in_depth),
            self._pool_objects(f_obj, tag_obj, dist_obj),
            self._maybe_head(f_head, n, f_body.device),
        ]
        img = torch.cat(pieces, dim=1)
        sem, coo, cat = self._gcn_concat(label_emb, edge_sem, edge_cooccur)
        logits = self._classify(img, cat)
        head_used = self.in_head > 0 and f_head is not None
        return (logits, sem, coo, cat) if head_used else (logits, sem, coo)


class CGMN_fgd(_BaseCGMN):
    """Depth ablation: fuse context + body + pooled objects (+ optional head).

    ``f_depth`` is accepted but unused. Returns ``(logits, sem, coo, cat)``
    when head features are enabled and supplied, else ``(logits, sem, coo)``.
    """

    def __init__(self, in_context, in_body, in_depth, in_obj, d_obj, in_label=300, out_cat=26, in_head=None):
        super().__init__(in_context, in_body, in_depth, in_obj, d_obj, in_label, out_cat, in_head)
        fused_width = sum((in_context, in_body, in_obj + d_obj, self.in_head))
        self.fc_fuse = nn.Linear(fused_width, 1024)
        self.fc_cat = nn.Linear(1024, out_cat)

    def forward(self, f_context, f_body, f_depth, f_obj, tag_obj, dist_obj,
                label_emb, edge_sem, edge_cooccur, f_head=None):
        n = f_body.size(0)
        pieces = [
            f_context.view(n, self.in_context),
            f_body.view(n, self.in_body),
            self._pool_objects(f_obj, tag_obj, dist_obj),
            self._maybe_head(f_head, n, f_body.device),
        ]
        img = torch.cat(pieces, dim=1)
        sem, coo, cat = self._gcn_concat(label_emb, edge_sem, edge_cooccur)
        logits = self._classify(img, cat)
        head_used = self.in_head > 0 and f_head is not None
        return (logits, sem, coo, cat) if head_used else (logits, sem, coo)


class CGMN_wo_sem_1024_noloss(nn.Module):
    """
    w/o label similarity: remove the semantic-similarity GCN, keep only co-occur GCN,
    set the GCN output dim to 1024 and fuse via dot product with image features.
    Returns (logits, None, None) so label loss is not computed.

    ``edge_sem`` is accepted for signature compatibility with the other
    variants but is not used.
    """
    def __init__(self, in_context, in_body, in_depth, in_obj, d_obj,
                 in_label=300, out_cat=26, in_head=None):
        super().__init__()
        self.in_context, self.in_body, self.in_depth = in_context, in_body, in_depth
        self.in_obj, self.d_obj = in_obj, d_obj
        self.in_head = int(in_head or 0)
        self.out_cat = out_cat
        self.in_label = in_label

        self.obj_emb = nn.Embedding(91, self.d_obj)
        self.relu = nn.LeakyReLU()
        self.drop = nn.Dropout()

        fuse_in = in_context + in_body + in_depth + (in_obj + d_obj) + self.in_head
        self.fc_fuse = nn.Linear(fuse_in, 1024)
        self.fc_cat  = nn.Linear(1024, out_cat)

        # Co-occur GCN only, directly outputs 1024-d.
        self.gcn_coo = GCN(self.in_label, 512, 1024)

    def forward(self, f_context, f_body, f_depth, f_obj, tag_obj, dist_obj,
                label_emb, edge_sem, edge_cooccur, f_head=None):
        B = f_body.size(0)
        device = f_body.device
        c = f_context.view(B, self.in_context)
        b = f_body.view(B, self.in_body)
        d = f_depth.view(B, self.in_depth)

        # DASOR: softmax(1 / distance) weighted pooling over the object patches.
        obj_tag = self.obj_emb(tag_obj)                         # [B, K, d_obj]
        # reshape (not view): also accepts non-contiguous / flattened inputs.
        obj_vis = f_obj.reshape(B, -1, self.in_obj)             # [B, K, in_obj]
        obj_all = torch.cat([obj_tag, obj_vis], dim=2)          # [B, K, d_obj+in_obj]
        # Cast so both einsum operands share one dtype even when distances
        # arrive as e.g. float64 while the features are float32.
        w = torch.softmax(1.0 / (dist_obj + 1e-6), dim=1).to(obj_all.dtype)
        o = torch.einsum("bkd,bk->bd", obj_all, w)              # [B, d_obj+in_obj]

        if self.in_head > 0 and f_head is not None:
            h = f_head.reshape(B, self.in_head)
        else:
            h = torch.zeros(B, self.in_head, device=device)

        img = torch.cat([c, b, d, o, h], dim=1)         # [B, fuse_in]
        x_lin = self.fc_fuse(img)                       # [B, 1024]
        x_do  = self.drop(x_lin)
        x_act = self.relu(x_do)
        logits_one_hot = self.fc_cat(x_act)             # [B, 26]

        # Graph branch scores the pre-dropout, pre-activation features.
        emb_coo = self.gcn_coo(label_emb, edge_cooccur) # [26, 1024]
        logits_label = x_lin @ emb_coo.t()              # [B, 26]

        logits = (logits_one_hot + logits_label) / 2.0
        return logits, None, None


class CGMN_wo_cooccur_1024_noloss(nn.Module):
    """
    w/o label co-occur: remove the co-occurrence GCN, keep only semantic-similarity GCN,
    set the GCN output dim to 1024 and fuse via dot product with image features.
    Returns (logits, None, None) so label loss is not computed.

    ``edge_cooccur`` is accepted for signature compatibility with the other
    variants but is not used.
    """
    def __init__(self, in_context, in_body, in_depth, in_obj, d_obj,
                 in_label=300, out_cat=26, in_head=None):
        super().__init__()
        self.in_context, self.in_body, self.in_depth = in_context, in_body, in_depth
        self.in_obj, self.d_obj = in_obj, d_obj
        self.in_head = int(in_head or 0)
        self.out_cat = out_cat
        self.in_label = in_label

        self.obj_emb = nn.Embedding(91, self.d_obj)
        self.relu = nn.LeakyReLU()
        self.drop = nn.Dropout()

        fuse_in = in_context + in_body + in_depth + (in_obj + d_obj) + self.in_head
        self.fc_fuse = nn.Linear(fuse_in, 1024)
        self.fc_cat  = nn.Linear(1024, out_cat)

        # Semantic-similarity GCN only, directly outputs 1024-d.
        self.gcn_sem = GCN(self.in_label, 512, 1024)

    def forward(self, f_context, f_body, f_depth, f_obj, tag_obj, dist_obj,
                label_emb, edge_sem, edge_cooccur, f_head=None):
        B = f_body.size(0)
        device = f_body.device
        c = f_context.view(B, self.in_context)
        b = f_body.view(B, self.in_body)
        d = f_depth.view(B, self.in_depth)

        # DASOR: softmax(1 / distance) weighted pooling over the object patches.
        obj_tag = self.obj_emb(tag_obj)                         # [B, K, d_obj]
        # reshape (not view): also accepts non-contiguous / flattened inputs.
        obj_vis = f_obj.reshape(B, -1, self.in_obj)             # [B, K, in_obj]
        obj_all = torch.cat([obj_tag, obj_vis], dim=2)          # [B, K, d_obj+in_obj]
        # Cast so both einsum operands share one dtype even when distances
        # arrive as e.g. float64 while the features are float32.
        w = torch.softmax(1.0 / (dist_obj + 1e-6), dim=1).to(obj_all.dtype)
        o = torch.einsum("bkd,bk->bd", obj_all, w)              # [B, d_obj+in_obj]

        if self.in_head > 0 and f_head is not None:
            h = f_head.reshape(B, self.in_head)
        else:
            h = torch.zeros(B, self.in_head, device=device)

        img = torch.cat([c, b, d, o, h], dim=1)         # [B, fuse_in]
        x_lin = self.fc_fuse(img)                       # [B, 1024]
        x_do  = self.drop(x_lin)
        x_act = self.relu(x_do)
        logits_one_hot = self.fc_cat(x_act)             # [B, 26]

        # Graph branch scores the pre-dropout, pre-activation features.
        emb_sem = self.gcn_sem(label_emb, edge_sem)     # [26, 1024]
        logits_label = x_lin @ emb_sem.t()              # [B, 26]

        logits = (logits_one_hot + logits_label) / 2.0
        return logits, None, None


# Alias: ``CGMN_Basic`` is the same class object as the full-modality ``CGMN``.
CGMN_Basic = CGMN


class CGMN_darn(_BaseCGMN):
    """DASOR ablation: object patches are summed instead of distance-weighted.

    ``dist_obj`` is accepted for signature compatibility but ignored. Unlike
    ``CGMN``, this always returns ``(logits, sem, coo)``, even when head
    features are used.
    """

    def __init__(self, in_context, in_body, in_depth, in_obj, d_obj,
                 in_label=300, out_cat=26, in_head=None):
        super().__init__(in_context, in_body, in_depth, in_obj, d_obj, in_label, out_cat, in_head)
        fused_width = sum((in_context, in_body, in_depth, in_obj + d_obj, self.in_head))
        self.fc_fuse = nn.Linear(fused_width, 1024)
        self.fc_cat  = nn.Linear(1024, out_cat)

    def _sum_objects(self, f_obj, tag_obj):
        """Concatenate tag embeddings with visual features and sum over the patch axis."""
        tagged = torch.cat([self.obj_emb(tag_obj), f_obj], 2)   # [B, 4, d_obj+in_obj]
        return tagged.sum(dim=1)                                # [B, d_obj+in_obj]

    def forward(self, f_context, f_body, f_depth, f_obj, tag_obj, dist_obj,
                label_emb, edge_sem, edge_cooccur, f_head=None):
        n = f_body.size(0)
        img = torch.cat(
            [
                f_context.view(n, self.in_context),
                f_body.view(n, self.in_body),
                f_depth.view(n, self.in_depth),
                self._sum_objects(f_obj, tag_obj),
                self._maybe_head(f_head, n, f_body.device),
            ],
            dim=1,
        )
        sem, coo, cat = self._gcn_concat(label_emb, edge_sem, edge_cooccur)
        # _classify: average of the MLP head and the fused-feature / label-graph dot product.
        return self._classify(img, cat), sem, coo


class CGMN_hegrold(nn.Module):
    """
    Legacy HEGR ablation: remove GCN branches and keep only MLP head.
    Returns (logits, None, None) to disable label loss in training.

    ``label_emb``, ``edge_sem`` and ``edge_cooccur`` are accepted for signature
    compatibility but are not used.
    """
    def __init__(self, in_context, in_body, in_depth, in_obj, d_obj,
                 in_label=300, out_cat=26, in_head=None):
        super().__init__()
        self.in_context, self.in_body, self.in_depth = in_context, in_body, in_depth
        self.in_obj, self.d_obj = in_obj, d_obj
        self.in_head = int(in_head or 0)
        self.out_cat = out_cat

        self.obj_emb = nn.Embedding(91, self.d_obj)
        self.relu = nn.LeakyReLU()
        self.drop = nn.Dropout()

        fuse_in = in_context + in_body + in_depth + (in_obj + d_obj) + self.in_head
        self.fc_fuse = nn.Linear(fuse_in, 1024)
        self.fc_cat  = nn.Linear(1024, out_cat)

    def forward(self, f_context, f_body, f_depth, f_obj, tag_obj, dist_obj,
                label_emb, edge_sem, edge_cooccur, f_head=None):
        B = f_body.size(0)
        device = f_body.device
        c = f_context.view(B, self.in_context)
        b = f_body.view(B, self.in_body)
        d = f_depth.view(B, self.in_depth)

        # DASOR distance-weighted pooling.
        obj_tag = self.obj_emb(tag_obj)                         # [B, K, d_obj]
        # reshape (not view): also accepts non-contiguous / flattened inputs.
        obj_vis = f_obj.reshape(B, -1, self.in_obj)             # [B, K, in_obj]
        obj_all = torch.cat([obj_tag, obj_vis], dim=2)          # [B, K, d_obj+in_obj]
        # Cast so both einsum operands share one dtype even when distances
        # arrive as e.g. float64 while the features are float32.
        w = torch.softmax(1.0 / (dist_obj + 1e-6), dim=1).to(obj_all.dtype)
        o = torch.einsum("bkd,bk->bd", obj_all, w)              # [B, d_obj+in_obj]

        if self.in_head > 0 and f_head is not None:
            h = f_head.reshape(B, self.in_head)
        else:
            h = torch.zeros(B, self.in_head, device=device)

        img = torch.cat([c, b, d, o, h], dim=1)
        x = self.fc_fuse(img)
        x = self.drop(x)
        x = self.relu(x)
        logits = self.fc_cat(x)
        return logits, None, None


class CGMN_hegr(nn.Module):
    """
    HEGR ablation: remove GCN branches entirely; classify using only visual features
    (context, body, depth, object, optional head).
    Returns (logits, None, None) to disable label loss in training.

    The label/graph arguments default to ``None`` and are ignored; ``dropout_p``
    sets the dropout probability of the MLP head.
    """
    def __init__(self, in_context, in_body, in_depth, in_obj, d_obj,
                 in_label=300, out_cat=26, in_head=None, dropout_p: float = 0.3):
        super().__init__()
        self.in_context, self.in_body, self.in_depth = in_context, in_body, in_depth
        self.in_obj, self.d_obj = in_obj, d_obj
        self.in_head = int(in_head or 0)
        self.out_cat = out_cat

        self.obj_emb = nn.Embedding(91, self.d_obj)
        self.relu = nn.LeakyReLU()
        self.drop = nn.Dropout(dropout_p)

        # Visual fusion input dim = context + body + depth + (object_vis + object_tag) + head
        fuse_in = in_context + in_body + in_depth + (in_obj + d_obj) + self.in_head
        self.fc_fuse = nn.Linear(fuse_in, 1024)
        self.fc_cat  = nn.Linear(1024, out_cat)

    @torch.no_grad()
    def _dasor_weight(self, dist_obj):
        # DASOR: softmax over inverse distances, computed without autograd tracking.
        inv = 1.0 / (dist_obj + 1e-6)      # [B, K]
        return torch.softmax(inv, dim=1)   # [B, K]

    def _pool_objects(self, f_obj, tag_obj, dist_obj):
        """
        Distance-weighted pooling of [tag embedding ; visual feature] per object.

            f_obj:   [B, K, in_obj]   object visual features (K = 4 in the original
                                      design); a flattened [B, K*in_obj] also works
            tag_obj: [B, K]           object class ids
            dist_obj:[B, K]           human-object distances (any numeric dtype)
            return:  [B, d_obj + in_obj]
        """
        B = f_obj.size(0)
        obj_tag = self.obj_emb(tag_obj)                  # [B, K, d_obj]
        # reshape (not view): also accepts non-contiguous inputs.
        obj_vis = f_obj.reshape(B, -1, self.in_obj)      # [B, K, in_obj]
        obj_all = torch.cat([obj_tag, obj_vis], dim=2)   # [B, K, d_obj+in_obj]
        # Cast so both einsum operands share one dtype even when distances
        # arrive as e.g. float64 while the features are float32.
        w = self._dasor_weight(dist_obj).to(obj_all.dtype)  # [B, K]
        return torch.einsum("bkd,bk->bd", obj_all, w)    # [B, d_obj+in_obj]

    def forward(self, f_context, f_body, f_depth, f_obj, tag_obj, dist_obj,
                label_emb=None, edge_sem=None, edge_cooccur=None, f_head=None):
        B = f_body.size(0)
        device = f_body.device
        c = f_context.view(B, self.in_context)
        b = f_body.view(B, self.in_body)
        d = f_depth.view(B, self.in_depth)
        o = self._pool_objects(f_obj, tag_obj, dist_obj)

        if self.in_head > 0 and (f_head is not None):
            h = f_head.reshape(B, self.in_head)
        else:
            # Zero-filled (or zero-width when in_head == 0) so the concat width matches fc_fuse.
            h = torch.zeros(B, self.in_head, device=device)

        img = torch.cat([c, b, d, o, h], dim=1)         # [B, fuse_in]
        x = self.fc_fuse(img)                           # [B, 1024]
        x = self.drop(x)
        x = self.relu(x)
        logits = self.fc_cat(x)                         # [B, 26]
        return logits, None, None


class CGMN_mvec(_BaseCGMN):
    """MLP ablation: no classification head; logits come from the graph branch alone.

    Logits are the dot product between the fused visual features and each row
    of the concatenated label-GCN embeddings. Always returns
    ``(logits, sem, coo)``.
    """

    def __init__(self, in_context, in_body, in_depth, in_obj, d_obj,
                 in_label=300, out_cat=26, in_head=None):
        super().__init__(in_context, in_body, in_depth, in_obj, d_obj, in_label, out_cat, in_head)
        fused_width = sum((in_context, in_body, in_depth, in_obj + d_obj, self.in_head))
        self.fc_fuse = nn.Linear(fused_width, 1024)
        # fc_cat is deliberately absent so no parameter goes unused.

    def forward(self, f_context, f_body, f_depth, f_obj, tag_obj, dist_obj,
                label_emb, edge_sem, edge_cooccur, f_head=None):
        n = f_body.size(0)
        img = torch.cat(
            [
                f_context.view(n, self.in_context),
                f_body.view(n, self.in_body),
                f_depth.view(n, self.in_depth),
                self._pool_objects(f_obj, tag_obj, dist_obj),
                self._maybe_head(f_head, n, f_body.device),
            ],
            dim=1,
        )
        sem, coo, cat = self._gcn_concat(label_emb, edge_sem, edge_cooccur)
        fused = self.fc_fuse(img)           # [B, 1024]
        return fused @ cat.t(), sem, coo    # scores: [B, 26]
