# metaphor_detector.py
# Minimal metaphor detection module: takes a text embedding plus optional knowledge-graph (KG) features
# and returns a score m in [0, 1] interpreted as the metaphor likelihood/ratio (Eq. 13 in the paper).

from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class MetaphorDetector(nn.Module):
    """
    Small MLP that consumes a text embedding and optional KG features and outputs
    a scalar in [0,1] indicating the degree of metaphorical expression.

    The text embedding and KG features are concatenated along the last dimension
    and fed through Linear(text_dim + kg_dim, hidden) -> ReLU ->
    Linear(hidden, hidden // 2) -> ReLU -> Linear(hidden // 2, 1), followed by a
    sigmoid.
    """

    def __init__(self, text_dim: int = 768, kg_dim: int = 64, hidden: int = 256):
        """
        Args:
          text_dim: size of the last dimension of ``text_embed``.
          kg_dim: size of the last dimension of ``kg_features``; may be 0, in
            which case only the text embedding contributes.
          hidden: width of the first hidden layer (the second is ``hidden // 2``).
        """
        super().__init__()
        self.text_dim = text_dim
        self.kg_dim = kg_dim
        self.net = nn.Sequential(
            nn.Linear(text_dim + kg_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Linear(hidden // 2, 1),
        )

    def forward(self, text_embed: torch.Tensor, kg_features: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Score text (and optional KG) features for metaphoricity.

        Args:
          text_embed: (text_dim,) or (B, text_dim); any number of leading batch
            dimensions is accepted.
          kg_features: (kg_dim,) or (B, kg_dim) or None (filled with zeros).
            A single KG vector is broadcast to every row of a batched
            ``text_embed``. It is moved to ``text_embed``'s device and dtype.
        Returns:
          Tensor of shape ``text_embed.shape[:-1]`` (``(1,)`` for 1-D input)
          with values in [0,1].
        """
        if text_embed.dim() == 1:
            text_embed = text_embed.unsqueeze(0)
        lead_shape = text_embed.shape[:-1]

        if kg_features is None:
            # new_zeros inherits BOTH device and dtype from text_embed, so a
            # .double()/.half() model does not hit a dtype mismatch in Linear.
            kg = text_embed.new_zeros((*lead_shape, self.kg_dim))
        else:
            kg = kg_features.to(device=text_embed.device, dtype=text_embed.dtype)
            if kg.dim() == 1:
                kg = kg.unsqueeze(0)
            # Broadcast a (1, kg_dim) row across the batch. This is a no-op when the
            # leading shapes already agree; conflicting batch sizes still raise.
            kg = kg.expand(*lead_shape, kg.shape[-1])

        x = torch.cat([text_embed, kg], dim=-1)
        logits = self.net(x).squeeze(-1)
        return torch.sigmoid(logits)  # in [0,1]


def detect_metaphor(text_embed: torch.Tensor, kg_features: Optional[torch.Tensor] = None, model: Optional[MetaphorDetector] = None) -> float:
    """
    Convenience function: returns a python float score in [0,1].

    Args:
      text_embed: (text_dim,) or (B, text_dim) text embedding.
      kg_features: optional KG features, (kg_dim,) or (B, kg_dim).
      model: detector to use. If None, a default lightweight model is created
        (for prototype) with dimensions inferred from the inputs. Its weights are
        untrained (PyTorch default initialisation), so the score is only a
        placeholder and changes between calls unless the RNG is seeded.

    Returns:
      The score of the first sample when the input is batched.

    The model is run in eval mode without gradients. A caller-supplied model's
    previous train/eval mode is restored afterwards.
    """
    if model is None:
        kg_dim = kg_features.shape[-1] if (kg_features is not None and kg_features.dim() > 0) else 0
        model = MetaphorDetector(text_dim=text_embed.shape[-1], kg_dim=kg_dim)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            s = model(text_embed, kg_features)
    finally:
        # Do not leak the eval() switch into the caller's model.
        model.train(was_training)

    # Flatten and take the first element. Unlike squeeze(0).item(), this also
    # works for batches with B > 1.
    return float(s.reshape(-1)[0].cpu().item())
