import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BatchNorm1d, Linear
from functools import partial
from torch_geometric.nn import GATConv, GCNConv, global_mean_pool

class GATCausalAttentionRegressor(nn.Module):
    """
    GAT-based causal attention regressor with residuals and MLP readout.

    Pipeline:

    1. A stack of GAT layers encodes atoms. Each layer applies batch norm,
       ReLU and dropout, and adds a residual connection whenever the input
       and output widths match.
    2. A node-level MLP soft-splits every node representation into a
       "causal" and a "trivial" part.
    3. An edge-level linear layer soft-splits every edge in the same way.
    4. Each part is refined by its own GCN layer, which uses the edge split
       as edge weights.
    5. Each part is mean-pooled per graph and fed to its own MLP regressor.

    Args:
        max_atomic_num: Largest atom index the embedding must accept; the
            embedding table has ``max_atomic_num + 1`` rows.
        emb_dim: Atom embedding size (input width of the first GAT layer).
        hidden_dim: Width of every hidden representation. Must be divisible
            by ``heads`` because the GAT heads are concatenated.
        num_gc_layers: Number of GAT layers in the backbone.
        dropout: Dropout probability used in the backbone and in the MLPs.
        heads: Number of attention heads per GAT layer.
        lambda_unif: Stored on the module; not read by ``forward`` or
            ``loss`` (possibly used by external training code — confirm).
        lambda_caus: Stored on the module; not read by ``forward`` or
            ``loss`` (possibly used by external training code — confirm).
        global_mean: Stored on the module; not read by ``forward`` or
            ``loss``.

    Raises:
        ValueError: If ``heads`` is not positive or ``hidden_dim`` is not
            divisible by ``heads``.
    """
    def __init__(
        self,
        max_atomic_num: int,
        emb_dim: int = 64,
        hidden_dim: int = 256,
        num_gc_layers: int = 7,
        dropout: float = 0.5,
        heads: int = 4,
        lambda_unif: float = 0.5,
        lambda_caus: float = 0.5,
        global_mean: float = 0.0
    ):
        super().__init__()
        # Concatenated GAT heads produce (hidden_dim // heads) * heads features.
        # Any width other than hidden_dim would break BatchNorm1d(hidden_dim)
        # below with an opaque shape error at forward time, so fail early.
        if heads <= 0 or hidden_dim % heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by heads ({heads})"
            )

        self.global_mean = global_mean
        self.lambda_unif = lambda_unif
        self.lambda_caus = lambda_caus
        self.hidden = hidden_dim
        self.dropout = dropout

        # atom embedding
        self.atom_embedding = nn.Embedding(max_atomic_num + 1, emb_dim)

        # GAT backbone
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for i in range(num_gc_layers):
            in_dim = emb_dim if i == 0 else hidden_dim
            self.convs.append(
                GATConv(in_dim, hidden_dim // heads, heads=heads, concat=True)
            )
            self.bns.append(nn.BatchNorm1d(hidden_dim))

        # Node-level attention MLP: 2 logits per node (causal, trivial).
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 2)
        )
        # Edge-level attention: 2 logits per edge from both endpoint features.
        self.edge_att_mlp = Linear(hidden_dim * 2, 2)

        GConv = partial(GCNConv, normalize=True)
        self.bnc = BatchNorm1d(hidden_dim)
        self.bno = BatchNorm1d(hidden_dim)
        self.context_convs = GConv(hidden_dim, hidden_dim)
        self.objects_convs = GConv(hidden_dim, hidden_dim)

        # regressors for causal & trivial branches
        self.reg_causal = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )
        self.reg_trivial = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, data, return_gates: bool = False):
        """
        Run the model on a (batched) graph.

        Args:
            data: Graph object exposing:
                * ``x`` — atom indices used to look up ``atom_embedding``,
                  of shape ``[N]`` or ``[N, 1]``.
                * ``edge_index`` — shape ``[2, E]``.
                * optionally ``batch`` — graph id per node.

                Presumably a PyG ``Batch``; confirm against the caller.
            return_gates: If True, also return the per-node causal/trivial
                attention weights.

        Returns:
            ``(y_c, y_t, z_c, z_t)``:

            * ``y_c`` / ``y_t`` — causal and trivial predictions, shape ``[B]``
              (one per pooled graph, also when B == 1).
            * ``z_c`` / ``z_t`` — pooled graph embeddings, shape
              ``[B, hidden_dim]``.

            With ``return_gates=True`` two more ``[N]`` tensors follow: the
            causal and trivial node weights, which sum to 1 per node.
        """
        x_idx = data.x
        # Accept an [N, 1] feature column. Squeeze only that axis so that a
        # batch with a single node does not collapse to a 0-d tensor.
        if x_idx.dim() == 2 and x_idx.size(1) == 1:
            x_idx = x_idx.squeeze(1)
        # nn.Embedding needs integer indices; float-stored atom types would raise.
        x_idx = x_idx.long()

        edge_index = data.edge_index
        batch = getattr(data, "batch", None)

        # embedding
        h = self.atom_embedding(x_idx)

        # GAT layers with residuals. The residual is skipped when the widths
        # differ, e.g. at the first layer when emb_dim != hidden_dim.
        for conv, bn in zip(self.convs, self.bns):
            h_in = h
            h = conv(h, edge_index)
            h = bn(h)
            h = F.relu(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
            if h.size(1) == h_in.size(1):
                h = h + h_in

        # Node-level soft split: column 0 -> causal weight, column 1 -> trivial.
        alpha = F.softmax(self.node_mlp(h), dim=-1)
        alpha_c, alpha_t = alpha[:, :1], alpha[:, 1:]  # [N, 1] each

        # split features
        h_c = h * alpha_c
        h_t = h * alpha_t

        # Edge-level soft split from the concatenated endpoint features.
        row, col = edge_index
        e_rep = torch.cat([h[row], h[col]], dim=-1)
        alpha_e = F.softmax(self.edge_att_mlp(e_rep), dim=-1)
        # Index the column instead of calling squeeze(), so that a batch with a
        # single edge still yields edge weights of shape [E].
        edge_weight_c = alpha_e[:, 0]
        edge_weight_o = alpha_e[:, 1]

        # apply GCN with edge weights
        h_c = F.relu(self.context_convs(self.bnc(h_c), edge_index, edge_weight_c))
        h_t = F.relu(self.objects_convs(self.bno(h_t), edge_index, edge_weight_o))

        # global pooling
        z_c = global_mean_pool(h_c, batch)
        z_t = global_mean_pool(h_t, batch)

        # Predictions: squeeze(-1) keeps a [B] shape even when B == 1.
        y_c = self.reg_causal(z_c).squeeze(-1)
        y_t = self.reg_trivial(z_t).squeeze(-1)

        if return_gates:
            return y_c, y_t, z_c, z_t, alpha_c.squeeze(-1), alpha_t.squeeze(-1)
        return y_c, y_t, z_c, z_t

    def pearson_corr(self, x, y, eps=1e-8):
        """
        Return the Pearson correlation of two tensors with equal element counts.

        Both inputs are flattened first, so ``[B]`` vs ``[B, 1]`` compare
        element-wise instead of broadcasting to ``[B, B]``. Population
        (biased) standard deviations are used, matching the mean-based
        covariance, so the result lies in [-1, 1] up to ``eps``.

        Returns:
            0-d tensor. A zero tensor with no gradient is returned when either
            input has at most one element or (near-)zero spread, since the
            correlation is undefined there.
        """
        x = x.reshape(-1)
        y = y.reshape(-1)
        if x.numel() <= 1 or y.numel() <= 1:
            return x.new_zeros(())
        vx = x - x.mean()
        vy = y - y.mean()
        std_x = vx.std(unbiased=False)
        std_y = vy.std(unbiased=False)
        if std_x < eps or std_y < eps:
            return x.new_zeros(())
        return (vx * vy).mean() / (std_x * std_y + eps)

    def loss(self, y_c, y_t, z_c, z_t, y_true, *, rho_target=0.7,
             coarse_weight=1.0, fine_weight=1.0):
        """
        Combined training loss for the coarse (trivial) and fine (causal) paths.

        The loss is the sum of three terms:

        * **Coarse path.** The trivial head ``y_t`` gets an MSE to the target
          plus a penalty pulling its Pearson correlation with the target
          towards ``rho_target``.
        * **Fine path.** The causal head ``y_c`` is fit to the residual
          ``y_true - y_t``. ``y_t`` is detached here, so this term sends no
          gradient into the trivial path.
        * **Fused.** The sum ``y_t + y_c`` is fit to the target.

        Args:
            y_c: Causal predictions as returned by ``forward``.
            y_t: Trivial predictions as returned by ``forward``.
            z_c: Causal pooled embedding; accepted for interface compatibility,
                unused here.
            z_t: Trivial pooled embedding; accepted for interface
                compatibility, unused here.
            y_true: Targets; any shape with the same number of elements as
                ``y_t`` (e.g. ``[B]`` or ``[B, 1]``).
            rho_target: Correlation the coarse path is nudged towards.
            coarse_weight: Weight of the coarse-path loss.
            fine_weight: Weight of the residual (fine-path) loss.

        Returns:
            Scalar loss tensor.
        """
        # Align the target with the predictions. A [B, 1] target against [B]
        # predictions would otherwise broadcast to [B, B] in mse_loss and
        # silently corrupt every term below.
        y_true = y_true.reshape(y_t.shape).to(y_t.dtype)

        # 1) Coarse path: MSE plus a correlation penalty (weak supervision).
        loss_short = F.mse_loss(y_t, y_true)
        corr = self.pearson_corr(y_t, y_true)
        loss_corr = (corr - rho_target) ** 2
        loss_coarse = loss_corr + loss_short

        # 2) Fine path: learn the residual left by the (detached) coarse path.
        residual = y_true - y_t.detach()
        loss_fine = F.mse_loss(y_c, residual)

        # 3) Fused prediction (coarse + fine) against the target.
        y_final = y_t + y_c
        loss_final = F.mse_loss(y_final, y_true)

        return (
            loss_final
            + coarse_weight * loss_coarse
            + fine_weight * loss_fine
        )

