"""
Domain discovery and domain-path components.

This module implements three major subsystems:

1) HierarchyPruner
   ----------------
   • Builds hierarchy-aware node embeddings via upward averaging and a similarity-driven
    rectification step ("upward information flow").
   • Scores nodes with a coverage/purity/depth composite score.
   • Produces a pruned frontier (set of ancestor nodes) that partitions the leaf set.
   • Refines the frontier using a rectified beam search with Silhouette Score as the
    global objective.
   • Generates a multi-label patient-by-domain matrix M for downstream learning.

2) DomainEncoder (g_θ)
   --------------------
   • Small MLP mapping multi-label domain IDs m ∈ {0,1}^{|C'|} to a dense "domain factor"
     r = g_θ(m) ∈ ℝ^D, where D matches the dimensionality of patient embeddings p.

3) Invariant projection h(·)
   -------------------------
   • Non-parametric orthogonal projection that subtracts from p its component parallel
    to r, yielding domain-invariant features h.

Utility functions for domain prototypes and mean-matching are also included to support
the pretraining of g_θ and the mutual learning objective.
"""

from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# -------------------------------------------------------------------------
# Helper structure for the medical hierarchy
# -------------------------------------------------------------------------

@dataclass
class HierarchyStructure:
    """
    Container describing a rooted concept hierarchy over disease codes.

    This class only stores the structure; it performs no validation of its own.

    Expected fields
    ---------------
    nodes        : List[str]
        All node identifiers (internal nodes + leaves). Identifiers can be ICD codes or
        other stable IDs. HierarchyPruner uses the position in this list as the node's
        row index in its node-level tensors.
    parent       : Dict[str, Optional[str]]
        Mapping node -> immediate parent (None for the root).
    children     : Dict[str, List[str]]
        Mapping node -> list of immediate children (empty for leaves).
    level        : Dict[str, int]
        Mapping node -> depth level, starting from 1 at the root. HierarchyPruner uses
        H = max(level.values()) as its depth normalizer and processes nodes in descending
        level order, so every child is expected to have a larger level than its parent.
    leaves       : List[str]
        Identifier list of leaves (disease codes used by the dataset). This order defines
        the leaf axis of every leaf-aligned tensor built from the hierarchy.
    leaf_to_vocab_idx : Dict[str, int]
        Mapping from leaf node id to row index in the disease embedding table E; used to
        gather leaf vectors when the table is not already aligned to `leaves`.

    Notes
    -----
    • Descendant bitsets aligned to `leaves` are precomputed by HierarchyPruner, not here.
    • Nodes must form a tree (single parent per node). If it's a DAG, convert it to a tree beforehand.
    """
    # All node ids (internal + leaves); list position = node row index downstream.
    nodes: List[str]
    # node -> parent id, None for the root.
    parent: Dict[str, Optional[str]]
    # node -> immediate children (empty list for leaves).
    children: Dict[str, List[str]]
    # node -> depth, 1 at the root.
    level: Dict[str, int]
    # Leaf ids; their order defines the leaf axis.
    leaves: List[str]
    # leaf id -> row in the disease embedding table.
    leaf_to_vocab_idx: Dict[str, int]


# -------------------------------------------------------------------------
# Domain Encoder (g_theta) and invariant projection h(·)
# -------------------------------------------------------------------------

class DomainEncoder(nn.Module):
    """
    Deterministic encoder mapping multi-label (possibly soft) domain IDs to a dense domain factor r.

    Architecture: Linear(in_dim, hidden) -> ReLU -> Linear(hidden, out_dim), initialised
    with Xavier-uniform weights and zero biases.

    Input shape : [B, |C'|]   binary, multi-hot or soft weights (any numeric/bool dtype)
    Output shape: [B, out_dim]  where out_dim must match patient embedding dimension.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, out_dim),
        )
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, m: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of domain-ID vectors into domain factors r.

        The input is cast to the dtype of the first layer's weights rather than to a
        hard-coded float32, so integer/bool ID matrices work and the module keeps working
        after `.double()` / `.half()` conversions.
        """
        first_layer = self.net[0]
        return self.net(m.to(dtype=first_layer.weight.dtype))


@torch.no_grad()
def normalize_batchwise(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Rescale every vector (last dimension) of a batch to unit L2 length.

    Runs under ``torch.no_grad()``, so no autograd graph is recorded and the result
    never requires grad.

    Parameters
    ----------
    x : [B, D]
    eps : lower bound on each vector's norm; all-zero rows therefore stay zero
        instead of becoming NaN.

    Returns
    -------
    x_norm : [B, D]
    """
    lengths = torch.clamp(x.norm(dim=-1, keepdim=True), min=eps)
    normalized = x / lengths
    return normalized


def invariant_projection(p: torch.Tensor, r: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Strip from each patient vector its component along the matching domain vector.

    Parameters
    ----------
    p : [B, D]
        Patient-level embeddings.
    r : [B, D]
        Domain factor vectors.
    eps : float
        Floor for ||r||^2; a zero domain vector yields a zero coefficient, so p is
        returned unchanged for that row.

    Returns
    -------
    h : [B, D]
        Domain-invariant embeddings, h = p - ((p·r) / ||r||^2) * r (row-wise).
    """
    dot_pr = torch.sum(p * r, dim=-1, keepdim=True)
    r_sq = torch.sum(r * r, dim=-1, keepdim=True).clamp_min(eps)
    coefficient = dot_pr / r_sq
    return p - coefficient * r


# -------------------------------------------------------------------------
# Prototype helpers for self-supervised pretraining of g_theta
# -------------------------------------------------------------------------

@torch.no_grad()
def compute_domain_prototypes(p_all: torch.Tensor, M_all: torch.Tensor) -> torch.Tensor:
    """
    Compute domain prototypes μ_j by averaging patient embeddings among samples
    where domain j is active.

    Runs under ``torch.no_grad()``: no autograd graph is recorded.

    Parameters
    ----------
    p_all : [N, D]
        Patient embeddings.
    M_all : [N, C'] binary or multi-hot
        Any numeric or bool dtype (e.g. an int32 assignment matrix); it is cast to
        ``p_all``'s dtype/device before the matrix product, because matmul does not
        promote mixed dtypes.

    Returns
    -------
    mu : [C', D]
        Row j = sum of the embeddings of patients with domain j active, divided by the
        number of such patients (clamped to >= 1, so empty domains give a zero row).
    """
    weights = M_all.to(dtype=p_all.dtype, device=p_all.device)  # [N, C']
    counts = weights.sum(dim=0).clamp_min(1.0)                  # [C']; avoids 0/0
    return (weights.t() @ p_all) / counts.unsqueeze(-1)         # [C', D]


def target_proto_from_m(m: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
    """
    Build a per-sample target prototype by averaging active domain prototypes.

    Parameters
    ----------
    m  : [B, C'] binary or multi-hot
        Any numeric or bool dtype; it is cast to ``mu``'s dtype/device first because
        matmul does not promote mixed dtypes.
    mu : [C', D]

    Returns
    -------
    pbar : [B, D]
        Sum of the active prototypes divided by the number of active domains
        (clamped to >= 1, so rows with no active domain give zeros).
    """
    weights = m.to(dtype=mu.dtype, device=mu.device)
    denom = weights.sum(dim=1, keepdim=True).clamp_min(1.0)
    return (weights @ mu) / denom


def mean_match_penalty(r: torch.Tensor, p: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Batchwise mean-matching penalty (MMD-style L2 on means).

    Computes ||mean(r) - mean(p)||^2 / max(||mean(p)||^2, eps), with means taken over
    dim 0 (the batch), i.e. the squared gap between the means relative to the squared
    size of p's mean.

    Returns
    -------
    scalar tensor
    """
    target_center = p.mean(dim=0)
    gap = r.mean(dim=0) - target_center
    scale = torch.clamp(target_center.pow(2).sum(), min=eps)
    return gap.pow(2).sum() / scale


# -------------------------------------------------------------------------
# Hierarchy Pruning
# -------------------------------------------------------------------------

class HierarchyPruner:
    """
    Knowledge-guided pruner that selects an ancestor frontier C' partitioning the leaf set.

    Workflow
    --------
    1) Upward information flow builds node embeddings for all nodes from the disease
       embedding table:
         • Initialize leaf embeddings from the lookup table E (aligned via leaf_to_vocab_idx).
         • Set every internal node to the mean of its descendant leaves.
         • Rectify ancestors: visit leaf pairs in order of decreasing cosine similarity and,
           for every pair with similarity >= ρ, average both leaves into their Lowest
           Common Ancestor (LCA).

    2) Score nodes with a composite S(n) that balances purity, coverage, and depth.

    3) Bottom-up pruning pass creates a candidate frontier C0 and a list of flagged
       (parent, children) pairs for which local comparisons are ambiguous.

    4) Rectified beam search resolves flagged pairs globally by maximizing the silhouette
       score over the leaf embeddings clustered by the frontier.

    5) Domain assignment builds M ∈ {0,1}^{N × |C'|} by linking each patient's
       aggregated disease vector to pruned nodes that cover at least one of their leaves.

    Expected inputs
    ---------------
    • hierarchy : HierarchyStructure
      Fully specified disease hierarchy (e.g. ICD-9) forming a rooted tree.

    • leaf_table : torch.Tensor (argument of `build_node_embeddings`)
      Leaf embedding table, either aligned to `hierarchy.leaves` ([|L|, H]) or a full
      vocabulary table indexed through `hierarchy.leaf_to_vocab_idx`.

    Design guarantees
    -----------------
    * Starting from a valid frontier, the unify/split moves of the beam search keep C' a
      valid partition: no selected node is a descendant of another selected node and
      every leaf has exactly one selected ancestor.
    * Frontiers are returned sorted by position in `hierarchy.nodes`, so the column order
      of M is reproducible across runs (independent of Python hash randomization).
    * The assignment M is multi-label at the patient level (patients can hit multiple pruned nodes).
    """

    def __init__(
        self,
        hierarchy: "HierarchyStructure",
        alpha: float = 0.5,
        rho: float = 0.3,
        beam_width: int = 8,
        device: Optional[torch.device] = None,
    ):
        self.H = hierarchy
        self.alpha = float(alpha)
        self.rho = float(rho)
        self.beam_width = int(beam_width)
        # Stored for callers; the methods below follow the device of the tensors they receive.
        self.device = device if device is not None else torch.device("cpu")

        # Indexing helpers
        self.nodes = list(self.H.nodes)
        self.node_to_idx = {n: i for i, n in enumerate(self.nodes)}
        self.leaves = list(self.H.leaves)
        self.leaf_to_leafidx = {n: i for i, n in enumerate(self.leaves)}

        # Precompute: descendant bitsets per node over leaves
        self.level_max = max(self.H.level.values())
        self.desc_matrix = self._compute_desc_matrix()  # [|V_all|, |L|] bool, on CPU

        # Children index mapping for faster local passes
        self.children_idx = {self.node_to_idx[n]: [self.node_to_idx[c] for c in self.H.children.get(n, [])]
                             for n in self.nodes}

    # ---------- topology utilities ----------

    def _compute_desc_matrix(self) -> torch.Tensor:
        """Build a boolean matrix M_desc[node_idx, leaf_idx] indicating coverage."""
        num_nodes = len(self.nodes)
        num_leaves = len(self.leaves)
        M = torch.zeros(num_nodes, num_leaves, dtype=torch.bool)

        # Mark leaf coverage
        for leaf in self.leaves:
            M[self.node_to_idx[leaf], self.leaf_to_leafidx[leaf]] = True

        # Propagate upwards: sorting by level descending visits children before parents,
        # so each child's row is complete when it is OR-ed into its parent.
        for n in sorted(self.nodes, key=lambda x: self.H.level[x], reverse=True):
            idx = self.node_to_idx[n]
            for c in self.H.children.get(n, []):
                M[idx] |= M[self.node_to_idx[c]]
        return M

    def _lca(self, a: str, b: str) -> str:
        """Compute Lowest Common Ancestor (LCA) using parent pointers."""
        seen = set()
        cur = a
        while cur is not None:
            seen.add(cur)
            cur = self.H.parent.get(cur, None)
        cur = b
        while cur is not None:
            if cur in seen:
                return cur
            cur = self.H.parent.get(cur, None)
        raise RuntimeError("Hierarchy is not rooted or has cycles.")

    def _selected_ancestor_or_self(self, node: Optional[str], frontier) -> Optional[str]:
        """Return the first of `node`, parent(node), ... that is in `frontier`, else None."""
        cur = node
        while cur is not None:
            if cur in frontier:
                return cur
            cur = self.H.parent.get(cur, None)
        return None

    def _descendants_of(self, node: str) -> set:
        """Return the set of nodes in the subtree rooted at `node` (including `node` and leaves)."""
        desc = set()
        stack = [node]
        while stack:
            cur = stack.pop()
            desc.add(cur)
            stack.extend(self.H.children.get(cur, []))
        return desc

    # ---------- Step 1: Build node embeddings ----------

    def build_node_embeddings(
        self,
        leaf_table: torch.Tensor,
        oov_fallback: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Construct embeddings for all nodes using upward averaging and similarity-based rectification.

        `leaf_table` may be:
          • [|L|, H] already aligned to `self.leaves`, OR
          • [V(+1), H] full lookup table (e.g., with padding at row 0).
        The two cases are told apart by the row count only; if needed, leaf vectors are
        gathered using `HierarchyStructure.leaf_to_vocab_idx`.

        Parameters
        ----------
        leaf_table : torch.Tensor
            Either [L, H] (aligned) or [V(+1), H] (full table, possibly with padding row).
        oov_fallback : Optional[torch.Tensor], shape [H]
            Used only for an internal node with no descendant leaves (should not happen
            for a valid tree); otherwise such a node stays all-zero.

        Returns
        -------
        node_embs : [|V_all|, H], same dtype/device as `leaf_table`.
        """
        device = leaf_table.device
        hdim = leaf_table.shape[1]

        # 1) Normalize to a leaf-aligned matrix [L, H]
        if leaf_table.shape[0] == len(self.leaves):
            leaf_vecs = leaf_table
        else:
            vocab_rows = torch.tensor(
                [self.H.leaf_to_vocab_idx[l] for l in self.leaves], dtype=torch.long, device=device
            )
            leaf_vecs = leaf_table.index_select(0, vocab_rows)  # [L, H]

        node_embs = torch.zeros(len(self.nodes), hdim, dtype=leaf_vecs.dtype, device=device)
        leaf_rows = torch.tensor([self.node_to_idx[l] for l in self.leaves], dtype=torch.long, device=device)
        node_embs[leaf_rows] = leaf_vecs

        # 2) Internal nodes = mean of descendant leaves. Each mean reads leaf vectors only,
        #    so the visiting order does not matter.
        desc = self.desc_matrix.to(device)
        for n in self.nodes:
            if not self.H.children.get(n):
                continue  # leaf (already set)
            ni = self.node_to_idx[n]
            leaf_mask = desc[ni]
            if leaf_mask.any():
                node_embs[ni] = leaf_vecs[leaf_mask].mean(dim=0)
            elif oov_fallback is not None:
                node_embs[ni] = oov_fallback.to(device=device, dtype=node_embs.dtype)

        # 3) Similarity-driven rectification. Processing every pair with cosine >= rho in
        #    descending order (stable sort, row-major tie-break) reproduces the "repeatedly
        #    take the global max" procedure with a single sort instead of one full L x L
        #    scan per pair, and terminates for any rho (including rho <= -1).
        leaf_norm = F.normalize(leaf_vecs, dim=1)
        sim = leaf_norm @ leaf_norm.t()              # [L, L]
        sim = torch.maximum(sim, sim.t())            # symmetric despite float round-off
        pairs = torch.nonzero(torch.triu(sim >= self.rho, diagonal=1), as_tuple=False)
        if pairs.numel() > 0:
            pair_sims = sim[pairs[:, 0], pairs[:, 1]]
            order = torch.sort(pair_sims, descending=True, stable=True).indices
            for i, j in pairs[order].tolist():
                anc_row = self.node_to_idx[self._lca(self.leaves[i], self.leaves[j])]
                node_embs[anc_row] = (node_embs[anc_row] + leaf_vecs[i] + leaf_vecs[j]) / 3.0

        return node_embs

    # ---------- Step 2: Node scoring ----------

    def score_nodes(self, node_embs: torch.Tensor) -> torch.Tensor:
        """
        Compute composite score S(n) = α·exp(purity) + (1-α)·(coverage × depth).

        purity  : mean cosine similarity between node embedding and its descendant leaves.
        coverage: |Desc(n)| / |L|
        depth   : level(n) / H
        Nodes covering no leaf get purity = coverage = 0.

        Returns
        -------
        scores : [|V_all|] tensor
        """
        device = node_embs.device
        node_norm = F.normalize(node_embs, dim=1)
        leaf_rows = torch.tensor([self.node_to_idx[l] for l in self.leaves], dtype=torch.long, device=device)
        leaf_embs = node_norm.index_select(0, leaf_rows)  # [L, H], unit-norm
        desc = self.desc_matrix.to(device)
        # Guard against a degenerate level map (all levels 0) instead of dividing by zero.
        level_max = float(self.level_max) or 1.0

        scores = torch.empty(len(self.nodes), device=device)
        for n in self.nodes:
            ni = self.node_to_idx[n]
            desc_mask = desc[ni]  # [L]
            if desc_mask.any():
                cov = desc_mask.float().mean().item()
                pur = (leaf_embs[desc_mask] @ node_norm[ni]).mean().item()
            else:
                cov, pur = 0.0, 0.0
            dep = float(self.H.level[n]) / level_max
            scores[ni] = self.alpha * math.exp(pur) + (1.0 - self.alpha) * (cov * dep)
        return scores

    # ---------- Step 3: Bottom-up pruning (local decisions) ----------

    def initial_prune(self, scores: torch.Tensor) -> Tuple[List[str], List[Tuple[str, List[str]]]]:
        """
        Bottom-up pass that generates a candidate frontier and a list of flagged pairs.

        For each internal node p (deepest first): if S(p) beats every child, p replaces its
        whole subtree in the frontier; if S(p) is below every child, the frontier is kept;
        otherwise the pair (p, children) is flagged for the beam search.

        Returns
        -------
        C0 : List[str]
            Initial frontier (partition) nodes, sorted by position in `hierarchy.nodes`.
        flagged : List[Tuple[parent, children_list]]
            Ambiguous cases requiring global resolution, in bottom-up order.
        """
        score_of = scores.detach().cpu().tolist()  # one transfer instead of per-node .item()
        frontier = set(self.leaves)
        flagged: List[Tuple[str, List[str]]] = []

        internal_nodes = sorted(
            (n for n in self.nodes if self.H.children.get(n)),
            key=lambda x: self.H.level[x],
            reverse=True,
        )
        for p in internal_nodes:
            children = self.H.children[p]
            child_scores = [score_of[self.node_to_idx[c]] for c in children]
            s_p = score_of[self.node_to_idx[p]]

            if s_p > max(child_scores):
                # unify: the whole subtree under p collapses to p
                frontier -= self._descendants_of(p)
                frontier.add(p)
            elif s_p < min(child_scores):
                continue  # keep children
            else:
                flagged.append((p, children))

        return self._normalize_frontier(frontier), flagged

    def _normalize_frontier(self, nodes: Iterable[str]) -> List[str]:
        """
        Remove every node that has an ancestor in the set, so no selected node lies below another.

        Ancestors are kept (rather than descendants) because an ancestor already covers all
        of its descendants' leaves: dropping the descendant never uncovers a leaf. Nodes are
        visited shallowest first so that ancestors are decided before their descendants.

        Returns the kept nodes sorted by position in `hierarchy.nodes` (deterministic order).
        """
        candidates = sorted(set(nodes), key=lambda n: (self.H.level[n], self.node_to_idx[n]))
        kept: set = set()
        for n in candidates:
            if self._selected_ancestor_or_self(self.H.parent.get(n, None), kept) is None:
                kept.add(n)
        return sorted(kept, key=self.node_to_idx.__getitem__)

    # ---------- Step 4: Rectified beam search with silhouette objective ----------

    def _unify(self, frontier: frozenset, p: str) -> frozenset:
        """
        Merge p's subtree into p.

        If p or one of its ancestors is already selected, p's leaves are already merged at
        or above p and the frontier is returned unchanged; otherwise every selected
        descendant of p is replaced by p. Both cases keep a valid partition valid.
        """
        if self._selected_ancestor_or_self(p, frontier) is not None:
            return frontier
        return frozenset((frontier - self._descendants_of(p)) | {p})

    def _split(self, frontier: frozenset, p: str, children: Sequence[str]) -> frozenset:
        """
        Ensure p's leaves are separated at least down to p's children.

        If the selected node `top` covering p is p itself or an ancestor of p, `top` is
        expanded along the path to p: each node on the path contributes its off-path
        children, and p contributes all its children. Leaf coverage is therefore unchanged.
        If p's subtree is already partitioned strictly below p, nothing changes.
        """
        kids = list(self.H.children.get(p, [])) or list(children)
        if not kids:
            return frontier  # p is a leaf; splitting would uncover it
        top = self._selected_ancestor_or_self(p, frontier)
        if top is None:
            return frontier
        path = [p]
        while path[-1] != top:
            path.append(self.H.parent[path[-1]])
        path.reverse()  # top, ..., p

        result = set(frontier)
        result.discard(top)
        for node, next_on_path in zip(path, path[1:]):
            result.update(c for c in self.H.children.get(node, []) if c != next_on_path)
        result.update(kids)
        return frozenset(result)

    def beam_search_refinement(
        self,
        flagged_pairs: List[Tuple[str, List[str]]],
        C0: List[str],
        node_embs: torch.Tensor,
        max_pairs_eval: Optional[int] = None,
    ) -> List[str]:
        """
        Resolve flagged parent-children pairs by beam search maximizing global silhouette.

        Each flagged pair expands every beam into a "unify" and a "split" candidate (see
        `_unify` / `_split`); identical frontiers are deduplicated and the best
        `beam_width` (at least 1) survive.

        Parameters
        ----------
        flagged_pairs : list of (parent, children_list)
        C0            : initial frontier from local pruning (normalized before use)
        node_embs     : [|V_all|, H_dim]  (only used for leaf embeddings extraction)
        max_pairs_eval: if set, restrict number of flagged pairs processed (for speed)

        Returns
        -------
        Cprime : List[str]  final refined frontier, sorted by position in `hierarchy.nodes`
        """
        device = node_embs.device
        leaf_rows = torch.tensor([self.node_to_idx[l] for l in self.leaves], dtype=torch.long, device=device)
        leaf_embs = F.normalize(node_embs.index_select(0, leaf_rows), dim=1)  # [L, H]
        dist = self._cosine_distance_matrix(leaf_embs)  # computed once for all candidates

        cache: Dict[frozenset, float] = {}

        def objective(frontier: frozenset) -> float:
            if frontier not in cache:
                cache[frontier] = self._silhouette_for_frontier(frontier, leaf_embs, dist=dist)
            return cache[frontier]

        start = frozenset(self._normalize_frontier(C0))
        beams: List[Tuple[float, frozenset]] = [(objective(start), start)]
        pairs = flagged_pairs if max_pairs_eval is None else flagged_pairs[:max_pairs_eval]
        width = max(1, self.beam_width)

        for p, children in pairs:
            # dict keeps insertion order -> deterministic ranking among equal scores
            candidates: Dict[frozenset, float] = {}
            for _, frontier in beams:
                for cand in (self._unify(frontier, p), self._split(frontier, p, children)):
                    if cand not in candidates:
                        candidates[cand] = objective(cand)
            ranked = sorted(candidates.items(), key=lambda kv: kv[1], reverse=True)
            beams = [(score, cand) for cand, score in ranked[:width]]

        best = max(beams, key=lambda b: b[0])[1]
        return sorted(best, key=self.node_to_idx.__getitem__)

    @staticmethod
    def _cosine_distance_matrix(leaf_embs: torch.Tensor) -> torch.Tensor:
        """Pairwise 1 - cosine distance for unit-norm rows, with an exact-zero diagonal."""
        sim = (leaf_embs @ leaf_embs.t()).clamp(-1.0, 1.0)
        dist = 1.0 - sim
        dist.fill_diagonal_(0.0)  # self-distance must not leak into intra-cluster means
        return dist

    def _silhouette_for_frontier(
        self,
        frontier: Iterable[str],
        leaf_embs: torch.Tensor,
        *,
        dist: Optional[torch.Tensor] = None,
    ) -> float:
        """
        Compute mean silhouette over leaves under the clustering induced by `frontier`.

        Implementation details
        ----------------------
        • Distance metric: 1 - cosine_similarity (`dist` may be passed precomputed).
        • Each leaf is assigned to the deepest selected ancestor that covers it.
        • Leaves in singleton clusters get s = 0 (Rousseeuw's convention, as in
          scikit-learn); otherwise the finest partition would trivially score 1.
        • Returns -1.0 for degenerate cases: fewer than two frontier nodes or leaves,
          an uncovered leaf, or fewer than two non-empty clusters.
        • Per-cluster distance sums come from one index_add over the L x L distance
          matrix: O(L^2) time and memory, so consider sub-sampling for very large L.
        """
        nodes = list(frontier)
        num_leaves = leaf_embs.shape[0]
        if len(nodes) <= 1 or num_leaves < 2:
            return -1.0
        device = leaf_embs.device

        # Coverage of the selected nodes; desc_matrix lives on CPU, so index it there.
        rows = torch.tensor([self.node_to_idx[n] for n in nodes], dtype=torch.long)
        cover = self.desc_matrix.index_select(0, rows).to(device)  # [C', L] bool

        # For each leaf, choose the deepest covering node.
        depth = torch.tensor([self.H.level[n] for n in nodes], dtype=torch.long, device=device)
        masked_depth = depth.unsqueeze(1).expand_as(cover).masked_fill(~cover, -1)
        if bool((masked_depth.max(dim=0).values < 0).any()):
            return -1.0  # some leaf is not covered by the frontier
        owner = masked_depth.argmax(dim=0)  # [L]

        # Compact to labels 0..k-1 (frontier nodes shadowed by deeper ones get no label).
        _, labels = torch.unique(owner, return_inverse=True)
        counts = torch.bincount(labels)
        k = counts.numel()
        if k < 2:
            return -1.0

        if dist is None:
            dist = self._cosine_distance_matrix(leaf_embs)
        sums = torch.zeros(num_leaves, k, dtype=dist.dtype, device=device)
        sums.index_add_(1, labels, dist)  # sums[i, c] = sum_{j in c} dist[i, j]

        counts_f = counts.to(dist.dtype)
        own = labels.unsqueeze(1)
        own_size = counts_f[labels]  # [L]
        a = sums.gather(1, own).squeeze(1) / (own_size - 1.0).clamp_min(1.0)
        mean_to = (sums / counts_f.unsqueeze(0)).scatter(1, own, float("inf"))
        b = mean_to.min(dim=1).values  # nearest other cluster

        s = (b - a) / torch.maximum(a, b).clamp_min(1e-8)
        s = torch.where(own_size > 1, s, torch.zeros_like(s))
        return float(s.clamp(-1.0, 1.0).mean().item())

    # ---------- Step 5: Domain assignment ----------

    def assign_domains(self, X_all: torch.Tensor, pruned_nodes: List[str]) -> torch.Tensor:
        """
        Build multi-label domain assignment matrix for all patients.

        Parameters
        ----------
        X_all : [N, |L|] binary or count (aggregated disease codes per patient across visits)
               Aligned to self.leaves; an entry > 0 means the patient has that leaf.
        pruned_nodes : List[str]
            Selected frontier C' (may be empty). Its order defines the column order of M.

        Returns
        -------
        M : [N, |C'|] binary int32 tensor on X_all's device
            M[i, j] = 1 if patient i has at least one leaf under pruned node j; else 0.
        """
        device = X_all.device
        # Explicit long dtype: an empty list would otherwise give a float index tensor.
        rows = torch.tensor([self.node_to_idx[n] for n in pruned_nodes], dtype=torch.long)
        cover = self.desc_matrix.index_select(0, rows).to(device=device, dtype=torch.float32)  # [C', L]
        present = (X_all > 0).to(torch.float32)  # [N, L]
        return ((present @ cover.t()) > 0.0).to(dtype=torch.int32)