# phase1_parsing.py
# Implements Phase I parser components: AffectNet, NeedClassifier, risk projection, demo embedding, and PhaseI parser wrapper.

from typing import Dict, Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from config import TEXT_EMBED_DIM, DEMO_EMBED_DIM, RISK_PROJECTION_DIM

class AffectNet(nn.Module):
    """
    Two-layer MLP that scores a text embedding with a single affect value.

    Architecture: Linear(in_dim -> hidden) -> ReLU -> Linear(hidden -> 1);
    ``forward`` applies a sigmoid so every score lies in [0, 1].
    """
    def __init__(self, in_dim: int = TEXT_EMBED_DIM, hidden: int = 256):
        super().__init__()
        # Layers are created in the same order as before so parameter init
        # consumes the RNG identically; kept under ``self.net`` so state_dict
        # keys (net.0.*, net.2.*) are unchanged.
        layers = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, text_embed: torch.Tensor) -> torch.Tensor:
        """
        Score ``text_embed`` of shape (batch, in_dim) or (in_dim,).

        Returns a (batch,) tensor for batched input, or a 0-d tensor for a
        single unbatched vector, with values in [0, 1].
        """
        raw_score = self.net(text_embed)          # (..., 1)
        return raw_score.sigmoid().squeeze(-1)    # drop the trailing singleton dim

class NeedClassifier(nn.Module):
    """
    MLP that scores a text embedding against ``n_classes`` discrete need labels.

    Architecture: Linear(in_dim -> hidden) -> ReLU -> Linear(hidden -> n_classes).
    """
    def __init__(self, in_dim: int = TEXT_EMBED_DIM, n_classes: int = 4, hidden: int = 256):
        super().__init__()
        # Hidden block is built before the output head, matching the original
        # creation order (and therefore parameter initialisation).
        hidden_block = [nn.Linear(in_dim, hidden), nn.ReLU()]
        head = nn.Linear(hidden, n_classes)
        self.net = nn.Sequential(*hidden_block, head)

    def forward(self, text_embed: torch.Tensor) -> torch.Tensor:
        """
        Return raw, unnormalised class logits of shape (..., n_classes).

        No softmax is applied; callers take argmax or normalise themselves.
        """
        logits = self.net(text_embed)
        return logits

class DemoEmbedder(nn.Module):
    """
    Map a numeric demographic/meta feature tensor to a dense embedding.

    Architecture: Linear(input_dim -> emb_dim) -> ReLU, so every output
    component is non-negative. ``forward`` handles tensor input only.
    """
    def __init__(self, input_dim: int = 8, emb_dim: int = DEMO_EMBED_DIM):
        super().__init__()
        # Expected feature width; PhaseIParser.parse reads it to build a
        # zero-filled default input.
        self.input_dim = input_dim
        self.net = nn.Sequential(nn.Linear(input_dim, emb_dim), nn.ReLU())

    def forward(self, meta_tensor: torch.Tensor) -> torch.Tensor:
        """
        Embed ``meta_tensor`` of shape (batch, input_dim).

        Returns a tensor of shape (batch, emb_dim).
        """
        embedding = self.net(meta_tensor)
        return embedding

class RiskProjector(nn.Module):
    """
    Project concatenated features (hlex + hmeta + affect + demo) to a
    probability distribution over risk categories.

    Architecture: Linear(in_dim -> hidden) -> ReLU -> Linear(hidden -> out_dim).
    ``forward`` applies a softmax over the last dimension, so it returns
    probabilities, not logits.
    """
    def __init__(self, in_dim: int, out_dim: int = RISK_PROJECTION_DIM, hidden: int = 128):
        super().__init__()
        first = nn.Linear(in_dim, hidden)
        act = nn.ReLU()
        second = nn.Linear(hidden, out_dim)
        self.net = nn.Sequential(first, act, second)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Map ``features`` of shape (batch, in_dim) to (batch, out_dim)
        probabilities; each row sums to 1.
        """
        return self.net(features).softmax(dim=-1)

class PhaseIParser:
    """
    High-level Phase I parser that combines AffectNet, NeedClassifier, risk projector, and demo embedder.
    parse(text_embed, hlex, hmeta, meta_tensor) -> dict c = {'e','n','p','d'}

    All sub-modules live on ``self.device``. Inputs are moved there on demand
    and results are returned on the CPU (as Python scalars / numpy arrays),
    so the parser works the same on CPU and CUDA.
    """
    def __init__(self,
                 text_embed_dim: int = TEXT_EMBED_DIM,
                 hlex_dim: int = 4,
                 hmeta_dim: int = 3,
                 demo_input_dim: int = 8,
                 risk_out_dim: int = RISK_PROJECTION_DIM,
                 device: Optional[str] = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.affect_net = AffectNet(in_dim=text_embed_dim).to(self.device)
        # NeedClassifier is built with its own default class count (n_classes=4).
        self.need_clf = NeedClassifier(in_dim=text_embed_dim).to(self.device)
        self.demo_embedder = DemoEmbedder(input_dim=demo_input_dim).to(self.device)
        # Risk projector input = hlex + hmeta + affect score (1) + demo embedding.
        # DemoEmbedder uses its default emb_dim (DEMO_EMBED_DIM), hence that width here.
        in_dim = hlex_dim + hmeta_dim + 1 + DEMO_EMBED_DIM
        self.risk_proj = RiskProjector(in_dim=in_dim, out_dim=risk_out_dim).to(self.device)

    def compute_risk(self, hlex: np.ndarray, hmeta: np.ndarray, affect_score: float, demo_emb: torch.Tensor) -> np.ndarray:
        """
        Compute the risk probability vector p using the softmax projection.

        Args:
            hlex: array-like of shape (hlex_dim,).
            hmeta: array-like of shape (hmeta_dim,).
            affect_score: scalar affect score (float or 0-d tensor).
            demo_emb: demographic embedding of shape (emb_dim,); a tensor on
                any device, or a numpy array.

        Returns:
            np.ndarray of shape (risk_dim,) holding softmax probabilities.
        """
        def _row(x) -> torch.Tensor:
            # Every feature becomes a float32 (1, k) row on the model device, so
            # torch.cat never sees mixed devices (demo_emb arrives on the CPU
            # from embed_demo) or mixed dtypes (hlex/hmeta may be float64).
            return torch.as_tensor(x, dtype=torch.float32, device=self.device).reshape(1, -1)

        # no_grad lets the result be converted to numpy even when this method
        # is called directly rather than from inside parse().
        with torch.no_grad():
            concat = torch.cat(
                [_row(hlex), _row(hmeta), _row(float(affect_score)), _row(demo_emb)],
                dim=-1,
            )  # (1, in_dim)
            p = self.risk_proj(concat)  # (1, risk_dim), already softmaxed
        return p.squeeze(0).cpu().numpy()

    def embed_demo(self, meta_tensor: torch.Tensor) -> torch.Tensor:
        """
        Embed demographic features via demo_embedder.

        meta_tensor: (input_dim,) or (batch, input_dim). It is cast to float32,
        so integer-coded or float64 features are accepted.
        Returns the embedding on the CPU; a leading batch dim of size 1 is removed.
        """
        if meta_tensor.dim() == 1:
            meta_tensor = meta_tensor.unsqueeze(0)
        # nn.Linear requires input dtype to match its float32 weights.
        meta_tensor = meta_tensor.to(device=self.device, dtype=torch.float32)
        return self.demo_embedder(meta_tensor).squeeze(0).cpu()

    def parse(self, text_embed: torch.Tensor, hlex: np.ndarray, hmeta: np.ndarray, meta_tensor: Optional[torch.Tensor] = None) -> Dict[str, Any]:
        """
        Main API: parse a single example.

        text_embed: torch.Tensor of shape (embed_dim,) or (1, embed_dim)
        hlex/hmeta: numpy arrays produced from preprocess_text.compute_hlex_hmeta
        meta_tensor: optional numeric tensor for demo features, fallback zeros if None
        Returns c = {'e': float, 'n': int, 'p': np.array, 'd': np.array}
        """
        # Normalise text_embed to a (1, embed_dim) float tensor on the model device.
        if text_embed.dim() == 1:
            text_embed = text_embed.unsqueeze(0)
        text_embed = text_embed.to(self.device).float()

        if meta_tensor is None:
            # Default: all-zero demographic features.
            meta_tensor = torch.zeros(self.demo_embedder.input_dim)

        with torch.no_grad():
            # .item() requires a single example (batch size 1).
            affect = self.affect_net(text_embed).cpu().item()  # scalar in [0,1]
            need_logits = self.need_clf(text_embed).cpu().squeeze(0).numpy()  # (n_classes,)
            need_label = int(np.argmax(need_logits))

            demo_emb = self.embed_demo(meta_tensor)  # CPU tensor, (emb_dim,)
            # compute_risk moves demo_emb to the model device itself.
            p = np.asarray(self.compute_risk(hlex, hmeta, affect, demo_emb), dtype=float)

        return {"e": float(affect), "n": need_label, "p": p, "d": demo_emb.numpy()}
