import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BatchNorm1d, Linear
from functools import partial
from torch_geometric.nn import GATConv, GCNConv, global_mean_pool


class SequenceCausalAttentionRegressor(nn.Module):
    """Transformer sequence regressor with a gated two-branch head.

    Token ids are embedded, optionally projected to ``hidden_dim`` and passed
    through a Transformer encoder.  Three 1-D convolutions (kernel sizes 1, 3
    and 9) over the encoded sequence are summed and squashed with a sigmoid to
    form a per-position, per-channel gate.  The gate splits the encoded
    sequence into a "causal" part (``h * gate``) and a "trivial" part
    (``h * (1 - gate)``); each part is mean-pooled over non-padding positions
    and fed to its own small MLP regressor.

    Token id ``0`` is treated as padding throughout.
    """

    def __init__(
        self,
        vocab_size: int,
        emb_dim: int = 96,
        hidden_dim: int = 256,
        nhead: int = 4,
        nlayers: int = 3,
        dropout: float = 0.1,
        lambda_unif: float = 0.5,
        lambda_caus: float = 0.5,
        global_mean: float = 0.0,
    ):
        """Build the embedding, encoder, gate convolutions and regressors.

        Args:
            vocab_size: Number of rows in the embedding table (id 0 = padding).
            emb_dim: Embedding width.
            hidden_dim: Model width used by the encoder, convs and regressors.
            nhead: Number of attention heads per encoder layer.
            nlayers: Number of stacked encoder layers.
            dropout: Dropout probability for the encoder and regressor heads.
            lambda_unif: Stored on the instance; not read inside this class.
            lambda_caus: Stored on the instance; not read inside this class.
            global_mean: Stored on the instance; not read inside this class.
        """
        super().__init__()
        self.lambda_unif, self.lambda_caus = lambda_unif, lambda_caus
        self.global_mean = global_mean

        # NOTE: sub-modules are created in a fixed order so that seeded
        # parameter initialisation stays reproducible.

        # embedding & backbone
        self.embed = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        # Only project when the embedding width differs from the model width.
        self.proj = (
            nn.Linear(emb_dim, hidden_dim)
            if emb_dim != hidden_dim
            else nn.Identity()
        )
        enc_layer = nn.TransformerEncoderLayer(
            hidden_dim, nhead, hidden_dim * 2, dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(enc_layer, nlayers)

        # region convs: paddings keep the sequence length unchanged
        self.conv1 = nn.Conv1d(hidden_dim, hidden_dim, 1)
        self.conv3 = nn.Conv1d(hidden_dim, hidden_dim, 3, padding=1)
        self.conv9 = nn.Conv1d(hidden_dim, hidden_dim, 9, padding=4)

        # regressors
        self.reg_causal = self._make_head(hidden_dim, dropout)
        self.reg_trivial = self._make_head(hidden_dim, dropout)

        # Not used by ``forward``; kept so existing state dicts still load.
        self.split_ways = nn.Linear(hidden_dim, 2)

    @staticmethod
    def _make_head(hidden_dim: int, dropout: float) -> nn.Sequential:
        """Return a two-layer MLP mapping ``hidden_dim`` features to a scalar."""
        return nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, batch, return_gates: bool = False):
        """Run the model on a batch of token ids.

        Args:
            batch: Mapping with key ``'input_ids'`` holding an integer tensor
                of shape [B, L]; entries equal to 0 are padding.
            return_gates: When true, also return the gate tensors.

        Returns:
            ``(y_c, y_t, z_c, z_t)`` where ``y_*`` are [B] predictions and
            ``z_*`` are [B, H] pooled features of the causal / trivial
            branches.  With ``return_gates=True`` the tuple is extended by
            ``alpha_c.squeeze()`` and ``alpha_t.squeeze()``.
        """
        ids = batch['input_ids']
        pad_mask = ids == 0                                  # [B,L], True = padding

        # contextual reps
        h = self.proj(self.embed(ids))                       # [B,L,H]
        h = self.encoder(h, src_key_padding_mask=pad_mask)   # [B,L,H]

        # The encoder only masks padded *keys*; its outputs at padded
        # positions are not guaranteed to be zero.  Zero them so that the
        # width-3/9 convolutions below cannot mix padding content into the
        # gates of real tokens (i.e. a sequence yields the same result no
        # matter how much padding the batch adds), matching the zero padding
        # the convolutions already apply at the sequence borders.
        h = h.masked_fill(pad_mask.unsqueeze(-1), 0.0)

        # region convs operate channel-first
        h_chan = h.permute(0, 2, 1)                          # [B,H,L]
        gate_logits = (
            self.conv1(h_chan) + self.conv3(h_chan) + self.conv9(h_chan)
        )
        gate = torch.sigmoid(gate_logits).permute(0, 2, 1)   # [B,L,H]

        alpha_c = gate
        alpha_t = 1 - gate

        h_c = h * alpha_c      # causal path
        h_t = h * alpha_t      # trivial path

        # Masked mean-pool: padded rows of ``h`` are already zero, so a plain
        # sum divided by the number of real tokens is the masked mean.  The
        # clamp avoids a division by zero for all-padding sequences.
        n_tokens = (~pad_mask).sum(dim=1, keepdim=True).clamp(min=1)  # [B,1]
        denom = n_tokens.to(h.dtype)
        z_c = h_c.sum(dim=1) / denom                         # [B,H]
        z_t = h_t.sum(dim=1) / denom                         # [B,H]

        y_c = self.reg_causal(z_c).squeeze(-1)               # [B]
        y_t = self.reg_trivial(z_t).squeeze(-1)              # [B]

        if return_gates:
            # NOTE(review): ``squeeze()`` removes *every* size-1 dimension, so
            # the batch dimension disappears when B == 1.  Left as-is because
            # existing callers may rely on that shape - confirm before changing.
            return y_c, y_t, z_c, z_t, alpha_c.squeeze(), alpha_t.squeeze()

        return y_c, y_t, z_c, z_t

    def pearson_corr(self, x, y, eps=1e-8):
        """Return the Pearson correlation of ``x`` and ``y`` as a 0-dim tensor.

        Both inputs are flattened first so that, e.g., a [B] tensor and a
        [B, 1] tensor are compared element by element rather than broadcast
        against each other.  Returns 0 when either input has at most one
        element or a (population) standard deviation below ``eps``.
        """
        if x.numel() <= 1 or y.numel() <= 1:
            return x.new_zeros(())
        x = x.reshape(-1)
        y = y.reshape(-1)
        vx = x - x.mean()
        vy = y - y.mean()
        std_x = vx.std(unbiased=False)
        std_y = vy.std(unbiased=False)
        # A (near-)constant vector has no defined correlation.
        if std_x < eps or std_y < eps:
            return x.new_zeros(())
        return (vx * vy).mean() / (std_x * std_y + eps)

    def loss(self, y_c, y_t, z_c, z_t, y_true, rho_target: float = 0.7):
        """Combine the coarse, fine and fused regression losses.

        Args:
            y_c: Causal-branch predictions, fitted to the residual
                ``y_true - y_t``.
            y_t: Trivial-branch (coarse) predictions.
            z_c: Unused; accepted for interface compatibility.
            z_t: Unused; accepted for interface compatibility.
            y_true: Regression targets.  A tensor with the same number of
                elements as ``y_t`` but a different shape (e.g. [B, 1]) is
                reshaped to match instead of being silently broadcast.
            rho_target: Correlation that ``y_t`` is pulled towards w.r.t.
                ``y_true``.

        Returns:
            Scalar tensor: ``loss_final + loss_coarse + loss_fine``.
        """
        # Align the targets with the predictions: mismatched shapes would make
        # ``F.mse_loss`` broadcast to [B, B], and mismatched dtypes can fail.
        if y_true.shape != y_t.shape and y_true.numel() == y_t.numel():
            y_true = y_true.reshape(y_t.shape)
        y_true = y_true.to(dtype=y_t.dtype)

        # 1) Coarse-grained path loss (weak supervision): MSE plus a penalty
        #    on the squared distance of the correlation from ``rho_target``.
        loss_short = F.mse_loss(y_t, y_true)
        corr = self.pearson_corr(y_t, y_true)
        loss_corr = (corr - rho_target) ** 2
        loss_coarse = loss_corr + loss_short

        # 2) Fine-grained path loss: the causal branch learns the residual.
        #    The coarse prediction is detached so this term does not send
        #    gradients into the trivial branch.
        residual = y_true - y_t.detach()
        loss_fine = F.mse_loss(y_c, residual)

        # 3) Final fused prediction loss (coarse + fine)
        y_final = y_t + y_c
        loss_final = F.mse_loss(y_final, y_true)

        # 4) Equal-weight sum of the three terms
        total_loss = (
            loss_final
            + 1.0 * loss_coarse
            + 1.0 * loss_fine
        )

        return total_loss

