"""
Encoders for EHR sequences by feature key (disease/procedure/drug).

This module integrates *our* index-based EHR inputs with the project’s backbone
Transformer (see models/backbone.py). Each feature key uses:
  - a learnable lookup table with padding_idx=0, and
  - the provided TransformerLayer to model longitudinal dynamics across visits.

Input Format
------------
For each key k ∈ {D, P, M} (disease, procedure, medication/drug), we expect an *index tensor*:
    X_k : LongTensor of shape [B, T, E]
where:
  • B = batch size
  • T = number of visits (time steps)
  • E = number of event slots per visit (0 = padded slot)
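Each event slot holds a single integer code id (0 means "no code at this slot").
We embed per-slot ids and *sum* over events to produce per-visit vectors.
A visit whose slots are all 0 is treated as padding and flagged invalid in the
visit mask passed to the backbone.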

Output
------
A fixed-width *patient embedding* per key:
    p_k : FloatTensor [B, H]
We use the "CLS-by-convention" aggregation from the backbone's TransformerLayer:
the returned `cls_emb` (the output at the first visit position; this module does
not prepend a dedicated CLS token) is treated as the patient vector.
"""

from typing import Optional

import torch
import torch.nn as nn

# Use the project's backbone block directly
from .backbone import TransformerLayer


class CodeKeyEncoder(nn.Module):
    """
    Encoder for a single feature key (D, P, or M) using the shared backbone blocks.

    Mapping:
      index tensor [B,T,E] --(Embedding[padded]->sum over events)--> [B,T,H]
      --(TransformerLayer)--> sequence + cls --> patient embedding [B,H]

    The lookup table always has ``vocab_size + 1`` rows. Under the default
    ``padding_idx=0`` the extra row is row 0. A visit whose slots all equal the
    padding index is flagged invalid in the mask handed to the backbone.
    """

    def __init__(
        self,
        vocab_size: int,
        emb_dim: int,
        n_layers: int = 3,
        n_heads: int = 4,
        dropout: float = 0.2,
        padding_idx: Optional[int] = 0,
    ):
        """
        Parameters
        ----------
        vocab_size : int
            Number of codes in this key's vocabulary (excluding padding).
        emb_dim : int
            Embedding and hidden width (H).
        n_layers : int
            Number of Transformer blocks.
        n_heads : int
            Attention heads per block.
        dropout : float
            Dropout rate used in backbone blocks.
        padding_idx : int or None
            Index reserved for padding; defaults to 0. ``None`` disables
            padding, so every slot is treated as a real code.
        """
        super().__init__()
        # NOTE: allocate +1 row so the padding vector has its own slot
        # (row 0 under the default padding_idx).
        self.emb = nn.Embedding(vocab_size + 1, emb_dim, padding_idx=padding_idx)
        nn.init.xavier_uniform_(self.emb.weight)
        # xavier_uniform_ overwrites the zero row nn.Embedding creates for the
        # padding index, so restore it. nn.Embedding stores a normalised
        # (non-negative) padding_idx, which is used here.
        if self.emb.padding_idx is not None:
            with torch.no_grad():
                self.emb.weight[self.emb.padding_idx].zero_()

        # Temporal encoder is the project's backbone block
        self.temporal = TransformerLayer(
            feature_size=emb_dim,
            heads=n_heads,
            dropout=dropout,
            num_layers=n_layers,
        )

    def code_embedding_weight(self, include_padding: bool = True) -> torch.Tensor:
        """
        Expose the lookup table for hierarchy-driven discovery.

        Parameters
        ----------
        include_padding : bool
            If True (default), returns the full table [V+1, H], including the
            padding row.
            If False, returns weights for real codes only [V, H], with the
            padding row removed. Without a padding index, row 0 (the extra
            row) is dropped.

        Returns
        -------
        torch.Tensor
            The parameter itself or a view of it when the padding row is row 0
            (or there is no padding index). Otherwise a concatenated copy that
            keeps autograd history.
        """
        weight = self.emb.weight
        if include_padding:
            return weight

        pad = self.emb.padding_idx
        if pad is None or pad == 0:
            # Slicing keeps a view sharing storage/grad with the parameter.
            return weight[1:, :]
        # Non-default padding row: remove it wherever it sits.
        return torch.cat([weight[:pad], weight[pad + 1:]], dim=0)

    def forward(self, X_idx: torch.Tensor, register_hook: bool = False) -> torch.Tensor:
        """
        Parameters
        ----------
        X_idx : LongTensor [B, T, E]
            Index tensor; slots equal to the padding index are ignored.
        register_hook : bool
            If True, saves attention gradients inside the backbone for interpretability.

        Returns
        -------
        patient_emb : FloatTensor [B, H]

        Raises
        ------
        ValueError
            If ``X_idx`` is not a 3-D tensor.
        """
        if X_idx.dim() != 3:
            raise ValueError(
                "CodeKeyEncoder expects an index tensor of shape [B, T, E]; "
                f"got {X_idx.dim()} dims with shape {tuple(X_idx.shape)}"
            )

        # Embed events, then sum over event slots per visit:
        # [B, T, E] -> [B, T, E, H] -> [B, T, H]
        x = self.emb(X_idx).sum(dim=-2)

        # Visit mask (True = valid visit), derived from the indices rather than
        # the summed embeddings: a real visit whose code vectors cancel out to
        # zero must not be mistaken for padding.
        pad = self.emb.padding_idx
        if pad is None:
            visit_mask = torch.ones(
                X_idx.shape[:2], dtype=torch.bool, device=X_idx.device
            )
        else:
            visit_mask = (X_idx != pad).any(dim=-1)  # (B, T) bool

        # TransformerLayer returns (seq_out [B,T,H], cls_emb [B,H])
        _, cls_emb = self.temporal(x, mask=visit_mask, register_hook=register_hook)

        return cls_emb
