"""
UdonCare: backbone encoders + domain pathway with mutual forward learning.

Backbone pathway:
  For k ∈ {D, P, M}, encode per-key index tensors [B,T,E] with the project's TransformerLayer
  (via CodeKeyEncoder) to get patient vectors p_k. Concatenate into p=[p_D;p_P;p_M] and decode
  with a label head d_η.

Domain pathway:
  Given m (multi-label domain IDs) -> r=g_θ(m) -> h=p - proj_r(p) -> label head q_ξ(h).
"""

from __future__ import annotations

from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn

from .encoders import CodeKeyEncoder
from .heads import LabelHead
from .domain import DomainEncoder, invariant_projection


class UdonCare(nn.Module):
    """
    End-to-end model with mutual forward learning and hierarchy-driven domain discovery.

    The model has two pathways that share the patient embedding ``p``:

    * Backbone: one ``CodeKeyEncoder`` per code key (``'D'``, ``'P'``, ``'M'``); the three
      outputs are concatenated along the last dimension and decoded by ``backbone_head``.
    * Domain: ``domain_encoder`` maps the domain indicator ``m`` to ``r``,
      ``invariant_projection(p, r)`` produces ``h``, and ``invariant_head`` decodes ``h``.

    The domain encoder does not exist after construction; it must be created with
    :meth:`build_domain_encoder` before the domain pathway can run.

    Parameters
    ----------
    vocab_D, vocab_P, vocab_M : int
        Vocabulary sizes for diseases, procedures, and drugs (exclude padding).
    d_out : int
        Number of output labels for the downstream task.
    h : int
        Hidden width for each per-key encoder; concatenated p has size 3*h.
    n_layers, n_heads, dropout : Transformer hyperparameters for the per-key encoders.
    """

    # Maps each public code key to the attribute holding its encoder. Attribute names
    # (not modules) are stored so the lookup stays lazy and the table can live on the class.
    _ENCODER_ATTR_BY_KEY = {"D": "enc_D", "P": "enc_P", "M": "enc_M"}

    def __init__(
        self,
        vocab_D: int,
        vocab_P: int,
        vocab_M: int,
        d_out: int,
        h: int = 64,
        n_layers: int = 3,
        n_heads: int = 4,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.h = int(h)
        self.d_out = int(d_out)

        # Use the coerced widths everywhere so the sub-modules and the stored
        # attributes can never disagree (e.g. when ``h`` arrives as a NumPy scalar).
        patient_dim = 3 * self.h
        encoder_kwargs = dict(emb_dim=self.h, n_layers=n_layers, n_heads=n_heads, dropout=dropout)

        # --- Backbone encoders (per key) using project's backbone blocks ---
        self.enc_D = CodeKeyEncoder(vocab_size=vocab_D, **encoder_kwargs)
        self.enc_P = CodeKeyEncoder(vocab_size=vocab_P, **encoder_kwargs)
        self.enc_M = CodeKeyEncoder(vocab_size=vocab_M, **encoder_kwargs)

        self.backbone_head = LabelHead(in_dim=patient_dim, out_dim=self.d_out)

        # --- Domain path components (lazily created once |C'| is known) ---
        self.domain_encoder: Optional[DomainEncoder] = None   # g_theta
        self.invariant_head = LabelHead(in_dim=patient_dim, out_dim=self.d_out)

    # ------------------------------ Utilities ------------------------------

    def build_domain_encoder(self, pruned_vocab_size: int) -> None:
        """
        (Re)create the domain encoder when the pruned vocabulary C' changes size.

        Any previously built encoder is replaced. The new encoder is moved to the device
        and floating-point dtype of the model's existing parameters and put in the same
        train/eval mode as the model, so it can be used immediately even when it is built
        after ``model.to(...)`` / ``model.eval()``.

        Parameters
        ----------
        pruned_vocab_size : int
            Input width of the domain encoder (|C'|).

        Notes
        -----
        The new parameters are unknown to any optimizer created before this call; the
        caller is responsible for registering them.
        """
        encoder = DomainEncoder(in_dim=pruned_vocab_size, out_dim=3 * self.h)

        # A freshly constructed module sits on the CPU in the default dtype; align it
        # with whatever the rest of the model currently uses.
        reference = next(self.parameters(), None)
        if reference is not None:
            encoder.to(device=reference.device)
            # Module.to only accepts floating-point/complex dtypes.
            if reference.is_floating_point():
                encoder.to(dtype=reference.dtype)

        # New modules default to training mode; mirror the parent's current mode.
        encoder.train(self.training)
        self.domain_encoder = encoder

    def code_embedding_weight(self, key: str = "D", include_padding: bool = True) -> torch.Tensor:
        """
        Expose a per-key embedding table for hierarchy discovery.

        Delegates to ``code_embedding_weight`` of the encoder selected by ``key``.

        Parameters
        ----------
        key : {'D','P','M'}
        include_padding : bool
            If True, returns table with row 0 reserved for padding.

        Returns
        -------
        torch.Tensor

        Raises
        ------
        KeyError
            If ``key`` is not one of 'D', 'P', 'M'.
        """
        try:
            encoder_attr = self._ENCODER_ATTR_BY_KEY[key]
        except (KeyError, TypeError):
            # TypeError covers unhashable keys, which are just as unknown.
            raise KeyError(f"Unknown key: {key!r}. Expected one of 'D','P','M'.") from None
        encoder = getattr(self, encoder_attr)
        return encoder.code_embedding_weight(include_padding=include_padding)

    # ------------------------------ Forward passes ------------------------------

    def forward_backbone(self, batch: Dict[str, torch.Tensor], register_hook: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode by key, concatenate, and decode with the backbone head.

        Parameters
        ----------
        batch : dict with keys {'D','P','M'}
            Each element is a LongTensor [B, T, E] of *indices* with 0 = padding.
        register_hook : bool
            Forwarded unchanged to every per-key encoder. If True, enables attention
            gradient hooks in the backbone (interpretability).

        Returns
        -------
        p         : [B, 3h] patient embedding
        logits_p  : [B, d_out] raw logits from the backbone head

        Raises
        ------
        KeyError
            If ``batch`` lacks one of the keys 'D', 'P', 'M'.
        """
        # Order matters: the concatenation layout is [p_D ; p_P ; p_M].
        per_key = [
            self.enc_D(batch["D"], register_hook=register_hook),
            self.enc_P(batch["P"], register_hook=register_hook),
            self.enc_M(batch["M"], register_hook=register_hook),
        ]
        p = torch.cat(per_key, dim=-1)
        return p, self.backbone_head(p)

    def forward_domain_path(self, p: torch.Tensor, m: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Domain pathway: r = gθ(m) -> h = invariant_projection(p, r) -> logits_h.

        The module-level description gives the projection as h = p - proj_r(p); the
        actual computation is delegated to ``invariant_projection``.

        Parameters
        ----------
        p : [B, 3h]
        m : [B, |C'|]
            NOTE(review): the dtype expected by the domain encoder is not visible
            here -- confirm against ``DomainEncoder``.

        Returns
        -------
        r         : [B, 3h]
        h         : [B, 3h]
        logits_h  : [B, d_out]

        Raises
        ------
        RuntimeError
            If :meth:`build_domain_encoder` has not been called yet.
        """
        domain_encoder = self.domain_encoder
        if domain_encoder is None:
            raise RuntimeError("Domain encoder is not initialized. Call build_domain_encoder(|C'|) first.")
        r = domain_encoder(m)
        h = invariant_projection(p, r)
        return r, h, self.invariant_head(h)

    def forward(self, batch: Dict[str, torch.Tensor], m: Optional[torch.Tensor] = None, register_hook: bool = False) -> Dict[str, torch.Tensor]:
        """
        Joint forward. If m is None, only the backbone path is executed.

        Parameters
        ----------
        batch : dict with keys {'D','P','M'}
            Passed to :meth:`forward_backbone`.
        m : optional tensor
            Domain indicator passed to :meth:`forward_domain_path`; skip the domain
            pathway when omitted.
        register_hook : bool
            Passed to :meth:`forward_backbone`.

        Returns
        -------
        out : dict with keys
            'p', 'logits_p' always present;
            'r', 'h', 'logits_h' present only if m is provided.
        """
        p, logits_p = self.forward_backbone(batch, register_hook=register_hook)
        out: Dict[str, torch.Tensor] = {"p": p, "logits_p": logits_p}

        if m is None:
            return out

        r, h, logits_h = self.forward_domain_path(p, m)
        out["r"] = r
        out["h"] = h
        out["logits_h"] = logits_h
        return out
