import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BatchNorm1d, Linear
from functools import partial
from torch_geometric.nn import GATConv, GCNConv, global_mean_pool
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from torch_geometric.nn import GATConv, GCNConv, BatchNorm, global_mean_pool
from models.egnn import EGNN   


# ---------- Causal/Trivial split block ----------
class CausalSplitBlock(nn.Module):
    """
    One causal/trivial split: node gating + edge gating + two GCN refiners.

    Node features are softly divided between a "causal" and a "trivial"
    branch by a two-way softmax over per-node scores. Edges get their own
    two-way softmax weights. Each branch is then refined by its own GCN,
    using the corresponding edge weights.

    Args:
        dim: node feature width (both input and output of the block).
        dropout: accepted for interface compatibility; currently unused.
    """

    def __init__(self, dim: int, dropout: float = 0.0):
        super().__init__()
        # Node gate: produces one (causal, trivial) score pair per node.
        # NOTE(review): there is no nonlinearity between these Linear layers,
        # so the stack is equivalent to a single affine map dim -> 2. Inserting
        # activations would shift the Sequential indices and break checkpoints.
        self.node_mlp = nn.Sequential(
            nn.Linear(dim, dim // 2),
            nn.Linear(dim // 2, dim),
            nn.Linear(dim, 2),
        )
        # Edge gate: scores each edge from its concatenated endpoint features.
        self.edge_mlp = Linear(dim * 2, 2)

        GConv = partial(GCNConv, normalize=True)
        self.bnc = BatchNorm1d(dim)  # input norm for the causal GCN
        self.bno = BatchNorm1d(dim)  # input norm for the trivial GCN
        self.gnn_c = GConv(dim, dim)
        self.gnn_t = GConv(dim, dim)
        # Learnable softmax temperature; clamped to [0.6, 2.5] in forward.
        self.tau = nn.Parameter(torch.tensor(1.5))
        self.ln = nn.LayerNorm(dim)  # normalizes the node-gate input

    @staticmethod
    def _standardize(logits: torch.Tensor) -> torch.Tensor:
        """Standardize each score column over all rows (nodes) of ``logits``.

        With fewer than two rows the unbiased std is undefined (NaN), and
        the NaN would propagate through ``clamp_min`` into every gate. In
        that case the scores are only centered.
        """
        mu = logits.mean(dim=0, keepdim=True)
        if logits.size(0) < 2:
            return logits - mu
        sd = logits.std(dim=0, keepdim=True).clamp_min(1e-3)
        return (logits - mu) / sd

    def _edge_weights(self, h: torch.Tensor, edge_index: torch.Tensor):
        """Return the ``(causal, trivial)`` per-edge weights, each of shape ``[E]``."""
        row, col = edge_index
        e_rep = torch.cat([h[row], h[col]], dim=-1)  # [E, 2*dim]
        beta = F.softmax(self.edge_mlp(e_rep), dim=-1)  # [E, 2]
        # Index rather than squeeze(): a single-edge batch must stay [1], not 0-dim.
        return beta[:, 0], beta[:, 1]

    def forward(self, h, edge_index):
        """
        Split ``h`` into causal/trivial parts and refine each with a GCN.

        Args:
            h: node features whose last dimension is ``dim``.
            edge_index: edge index tensor unpackable as ``(row, col)``.

        Returns:
            ``(h_c, h_t, alpha_t)``: the ReLU-activated outputs of the causal
            and trivial GCNs, and the per-node trivial gate weight (kept for
            visualization).
        """
        # ----- node gating -----
        # Scores are standardized across every node in the batch (all graphs).
        logits = self._standardize(self.node_mlp(self.ln(h)))  # [N, 2]

        tau = torch.clamp(self.tau, 0.6, 2.5)
        alpha = F.softmax(logits / tau, dim=-1)  # [N, 2]
        alpha_c, alpha_t = alpha[:, :1], alpha[:, 1:]

        h_c = h * alpha_c  # causal part
        h_t = h * alpha_t  # trivial part

        # ----- edge gating -----
        edge_weight_c, edge_weight_t = self._edge_weights(h, edge_index)

        # ----- refine C/T by two GCNs -----
        h_c = F.relu(self.gnn_c(self.bnc(h_c), edge_index, edge_weight_c))
        h_t = F.relu(self.gnn_t(self.bno(h_t), edge_index, edge_weight_t))

        return h_c, h_t, alpha_t.squeeze(-1)


# ---------- Unified model ----------
class CausalAttentionRegressor(nn.Module):
    """
    Graph-level regressor: GNN backbone followed by stacked causal/trivial splits.

    backbone = "gat"  ->  GAT + CausalSplit
    backbone = "egnn" ->  EGNN + CausalSplit

    ``forward`` returns ``(y_c, y_t, z_c_layers, z_t_layers)``, plus the last
    split block's trivial gate when ``return_gates=True``.

    Notes:
        ``mode`` and ``dropout`` are stored/accepted but not used by the
        backbone here; ``dropout`` is only used inside the regression head.
        ``lambda_unif``, ``lambda_caus`` and ``rho_target`` are stored for use
        elsewhere (presumably the training loss — confirm at the call site).

    Raises:
        ValueError: if ``num_causal_blocks < 1`` (forward needs at least one
            split block to build ``y_t`` and the gate).
    """
    def __init__(self,
                 mode: str,  # "smiles", "peptide", "geometry", "fusion"
                 backbone:          str,         # "gat" or "egnn"
                 max_atomic_num:    int,
                 emb_dim:           int = 64,
                 hidden_dim:        int = 256,
                 edge_dim:          int = 32,
                 num_backbone_layers: int = 3,
                 heads:             int = 4,
                 num_causal_blocks: int = 3,
                 dropout:           float = 0.5,
                 rho_target:        float = 0.7,
                 lambda_unif:       float = 0.5,
                 lambda_caus:       float = 0.5):
        super().__init__()
        assert backbone in ("gat", "egnn")
        if num_causal_blocks < 1:
            # With zero blocks forward would hit torch.stack([]) and an
            # unbound gate; fail early with a clear message instead.
            raise ValueError(
                f"num_causal_blocks must be >= 1, got {num_causal_blocks}"
            )
        self.backbone_type = backbone
        self.lambda_unif = lambda_unif
        self.lambda_caus = lambda_caus
        self.rho_target  = rho_target
        self.dropout     = dropout
        self.hidden      = hidden_dim

        # --- Embedding ---
        self.atom_embed = nn.Embedding(max_atomic_num + 1, emb_dim)
        self.bond_embed = nn.Embedding(10, edge_dim)  # edge_attr expected as ids in [0, 10)

        # --- Backbone ---
        if backbone == "gat":
            self.convs = nn.ModuleList()
            self.bns   = nn.ModuleList()
            for i in range(num_backbone_layers):
                in_dim = emb_dim if i == 0 else hidden_dim
                # heads * (hidden_dim // heads) outputs after concatenation
                self.convs.append(
                    GATConv(in_dim, hidden_dim // heads,
                            heads=heads, concat=True, edge_dim=edge_dim)
                )
                self.bns.append(nn.BatchNorm1d(hidden_dim))
        else:  # EGNN
            self.egnn = EGNN(in_node_nf=emb_dim,
                             out_node_nf=hidden_dim,
                             in_edge_nf=edge_dim,
                             hidden_nf=hidden_dim,
                             n_layers=num_backbone_layers,
                             attention=True,
                             normalize=True,
                             tanh=False)

        # --- Causal-split hierarchy ---
        self.split_blocks = nn.ModuleList([
            CausalSplitBlock(hidden_dim, dropout) for _ in range(num_causal_blocks)
        ])

        # --- Regression heads ---
        self.reg_causal = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )
        # NOTE(review): reg_trivial is not used in forward (y_t comes from
        # self.heads); kept so existing state_dicts still load strictly.
        self.reg_trivial = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

        # One linear head per split block, applied to its pooled trivial part.
        self.heads = nn.ModuleList([nn.Linear(hidden_dim, 1) for _ in range(num_causal_blocks)])

    # ---------- forward ----------
    def forward(self, data, return_gates: bool = False):
        """
        Run backbone, causal/trivial hierarchy and regression heads.

        Args:
            data: a graph batch exposing ``z``, ``edge_index``, ``edge_attr``,
                ``batch`` and (EGNN only) ``pos``.
            return_gates: also return the trivial gate of the LAST split block.

        Returns:
            ``y_c``: causal prediction, shape ``[B]``.
            ``y_t``: sum of per-block trivial predictions, shape ``[B, 1]``
                (note: not squeezed like ``y_c``; kept for compatibility).
            ``z_c_layers``: per-block pooled embeddings of ``h`` after the
                residual update.
            ``z_t_layers``: per-block pooled trivial embeddings.
            ``gate`` (only if ``return_gates``): last block's trivial gate.
        """
        batch = data.batch

        h = self.atom_embed(data.z).flatten(1)  # (N, emb_dim)

        # 1) backbone
        if self.backbone_type == "gat":
            conv_kwargs = {"edge_index": data.edge_index}
            if data.edge_attr is not None:
                # The bond embedding is the same for every layer: compute it once.
                conv_kwargs["edge_attr"] = self.bond_embed(data.edge_attr)
            for conv, bn in zip(self.convs, self.bns):
                h_in = h
                h = F.relu(bn(conv(h, **conv_kwargs)))
                if h.shape[1] == h_in.shape[1]:
                    h = h + h_in  # residual (skipped when the width changes)
        else:  # EGNN
            # NOTE(review): edge_attr is passed raw here, while the GAT path
            # embeds it with bond_embed and EGNN was built with
            # in_edge_nf=edge_dim — confirm edge_attr already has edge_dim
            # features for this backbone.
            h, _ = self.egnn(h=h, edges=data.edge_index, x=data.pos,
                             edge_attr=data.edge_attr)

        # 2) causal / trivial hierarchy
        z_t_layers = []
        z_c_layers = []
        gate = None
        for blk in self.split_blocks:
            h_c, h_t, gate = blk(h, data.edge_index)
            h = h + h_c  # only the causal part is fed back into the stream
            z_t_layers.append(global_mean_pool(h_t, batch))
            z_c_layers.append(global_mean_pool(h, batch))

        # 3) pooling & heads
        # The last entry is exactly the pooled final h (at least one block is
        # guaranteed by __init__), so reuse it instead of pooling again.
        z_c = z_c_layers[-1]
        y_c = self.reg_causal(z_c).squeeze(-1)
        per_layer_scalar = [head(z) for head, z in zip(self.heads, z_t_layers)]  # list of [B,1]
        y_t = torch.stack(per_layer_scalar, dim=0).sum(dim=0)

        if return_gates:
            return y_c, y_t, z_c_layers, z_t_layers, gate
        return y_c, y_t, z_c_layers, z_t_layers

