# router.py
# Routing module: G(x) = softmax(Wg * phi(x) + bg).
# phi(x) is a small feature MLP that transforms Phase I context vector c into a routing vector.
# route(c) returns pathway id and probability distribution.

from typing import Dict, Tuple, List, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# Human-readable pathway ids, listed in the same order as the router's output
# probabilities. Router takes the first `n_pathways` entries of this list as its
# pathway names and indexes them with the argmax of the probabilities.
DEFAULT_PATHWAYS: List[str] = ["prompt_only", "constrained", "clinician"]

class Router(nn.Module):
    """
    Router computes routing probabilities over discrete pipelines (pathways).

    The gate is G(x) = softmax(Wg * phi(x) + bg), where phi is a two-layer
    ReLU MLP applied to a flat encoding of the context dict.

    Input `c` is expected to be a dict with keys: 'e' (affect scalar),
    'n' (need label scalar), 'p' (risk probs vector), 'd' (demo vector).
    Missing keys (or keys whose value is None) fall back to zeros / label 0.
    """

    # Widths of the fixed (non-demographic) segments of the context vector.
    AFFECT_DIM = 1
    NEED_DIM = 8
    RISK_DIM = 3

    def __init__(self, in_dim: int = 128, hidden: int = 64, n_pathways: int = 3, phi_hidden: int = 64):
        """
        Build the feature MLP `phi` and the gating layer `Wg`.

        Args:
            in_dim: length of the encoded context vector; must be at least
                AFFECT_DIM + NEED_DIM + RISK_DIM, the remainder is the demo width.
            hidden: unused; kept so existing callers keep working.
            n_pathways: number of routing targets (size of the softmax).
            phi_hidden: width of both hidden layers of `phi`.

        Raises:
            ValueError: if `in_dim` is too small for the fixed segments or
                `n_pathways` is not positive.
        """
        super().__init__()
        fixed_dim = self.AFFECT_DIM + self.NEED_DIM + self.RISK_DIM
        if in_dim < fixed_dim:
            raise ValueError(
                f"in_dim must be >= {fixed_dim} (affect + need one-hot + risk), got {in_dim}"
            )
        if n_pathways < 1:
            raise ValueError(f"n_pathways must be >= 1, got {n_pathways}")

        # phi: transform concatenated context into a fixed-size vector
        self.phi = nn.Sequential(
            nn.Linear(in_dim, phi_hidden),
            nn.ReLU(),
            nn.Linear(phi_hidden, phi_hidden),
            nn.ReLU()
        )
        # gating linear layer Wg, bg applied to phi(x)
        self.Wg = nn.Linear(phi_hidden, n_pathways)

        # One name per output logit: the defaults first, then generated names so
        # that route() can always map an argmax index to a name.
        names = list(DEFAULT_PATHWAYS[:n_pathways])
        names.extend(f"pathway_{i}" for i in range(len(names), n_pathways))
        self.pathway_names = names

    @staticmethod
    def _fit_length(values, length: int) -> np.ndarray:
        """Flatten `values` to 1-D floats, then truncate or zero-pad to `length` (None -> zeros)."""
        if length < 0:
            raise ValueError(f"segment length must be non-negative, got {length}")
        if values is None:
            return np.zeros(length, dtype=float)
        flat = np.asarray(values, dtype=float).ravel()
        if flat.shape[0] >= length:
            return flat[:length]
        return np.pad(flat, (0, length - flat.shape[0]), mode="constant")

    @staticmethod
    def context_to_vector(c: Dict, expect_demo_dim: int = 64, expect_risk_dim: int = 3, text_embed_dim: int = 768) -> torch.Tensor:
        """
        Convert context dict c into a numeric vector.

        We concatenate: [affect (1), need_one_hot (8), risk_probs (r), demo_vector (d)].

        Args:
            c: context dict; keys 'e', 'n', 'p', 'd' are read, all optional.
            expect_demo_dim: length of the demo segment; 'd' is flattened and
                truncated or zero-padded to this length.
            expect_risk_dim: length of the risk segment; 'p' is flattened and
                truncated or zero-padded to this length (a scalar lands in slot 0).
            text_embed_dim: unused; kept so existing callers keep working.

        Returns:
            1-D float32 tensor of length 1 + 8 + expect_risk_dim + expect_demo_dim.

        Raises:
            ValueError: if `expect_risk_dim` or `expect_demo_dim` is negative.
        """
        # A key that is present but None is treated like a missing key.
        raw_affect = c.get("e", 0.0)
        affect = 0.0 if raw_affect is None else float(raw_affect)
        raw_need = c.get("n", 0)
        need_label = 0 if raw_need is None else int(raw_need)

        risk_vec = Router._fit_length(c.get("p", None), expect_risk_dim)
        demo_vec = Router._fit_length(c.get("d", None), expect_demo_dim)

        # Need one-hot: labels are clamped into [0, NEED_DIM - 1] so that a
        # negative label cannot wrap around to the end of the vector (or raise)
        # and labels past the cap share the last slot.
        need_dim = Router.NEED_DIM
        need_one_hot = np.zeros(need_dim, dtype=float)
        need_one_hot[min(max(need_label, 0), need_dim - 1)] = 1.0

        vec = np.concatenate([[affect], need_one_hot, risk_vec, demo_vec]).astype("float32")
        return torch.from_numpy(vec)

    def _input_dim(self) -> int:
        """Return in_features of the first Linear inside `phi` (128 if there is none)."""
        first_linear = next((m for m in self.phi.modules() if isinstance(m, nn.Linear)), None)
        return first_linear.in_features if first_linear is not None else 128

    def forward(self, c: Dict) -> torch.Tensor:
        """
        Compute routing probabilities for a single context dict c.

        Returns:
            Probability tensor over pathways, shape (n_pathways,).
        """
        # Whatever phi's input width leaves after the fixed segments is the demo width.
        demo_dim = self._input_dim() - (self.AFFECT_DIM + self.NEED_DIM + self.RISK_DIM)
        ctx_vec = self.context_to_vector(c, expect_demo_dim=demo_dim, expect_risk_dim=self.RISK_DIM)
        device = next(self.parameters()).device
        ctx_vec = ctx_vec.to(device).unsqueeze(0)  # (1, in_dim)
        phi_x = self.phi(ctx_vec)  # (1, phi_hidden)
        logits = self.Wg(phi_x).squeeze(0)  # (n_pathways,)
        return F.softmax(logits, dim=-1)

    def route(self, c: Dict) -> Tuple[str, np.ndarray]:
        """
        Convenience: compute probabilities and return the argmax pathway id
        together with the probabilities as a numpy array.

        Inference runs in eval mode without gradients; the module's previous
        train/eval mode is restored afterwards so calling route() in the middle
        of training does not silently leave the module in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probs = self.forward(c).cpu().numpy()
        finally:
            self.train(was_training)
        best = int(probs.argmax())
        return self.pathway_names[best], probs
