# templates.py
# Storage and utilities for templates and constraint checking.
# Provides get_templates(risk_level, need_label) and ConstraintSet for banned phrases / constraint penalties.

from typing import List, Dict, Optional, Tuple
import re


# Example response templates, keyed first by risk level ("low" | "medium" | "high")
# and then by an integer need_label; each leaf is a list of candidate template strings.
# These are placeholders: in practice, replace them with evidence-based templates
# loaded from files under TEMPLATE_PATH.
# NOTE(review): TEMPLATE_PATH is not defined in the visible part of this module —
# confirm where it is meant to come from.
_DEFAULT_TEMPLATES: Dict[str, Dict[int, List[str]]] = {
    "low": {
        0: ["I hear you. Could you tell me a bit more about that?"],
        1: ["Thanks for sharing — would you like some resources on this topic?"],
    },
    "medium": {
        0: ["I understand this feels difficult. You might consider contacting X service."],
        1: ["I can help find relevant resources and steps you can take right now."],
    },
    "high": {
        0: ["This sounds urgent. If you are in danger, please call your local emergency number or go to the nearest emergency room."],
        1: ["I am worried about your safety. Can you tell me if you have immediate plans to harm yourself?"]
    }
}


def get_templates(risk_level: str, need_label: int) -> List[str]:
    """
    Return the templates registered for a risk level and need label.

    Args:
        risk_level: 'low', 'medium' or 'high'. Matching ignores case and
            surrounding whitespace. Any other value falls back to 'low'.
        need_label: Integer key of the wanted template group within the
            risk level.

    Returns:
        A new list of template strings. If ``need_label`` is not registered
        for the risk level, the first template group of that risk level is
        used instead; an empty list is returned if the level has no groups.
        The result is a copy, so callers may modify it without affecting the
        module-level template table.
    """
    # Strip before matching so inputs such as "high\n" are not silently
    # treated as an unknown level and downgraded to "low".
    level = risk_level.strip().lower()
    if level not in _DEFAULT_TEMPLATES:
        # NOTE(review): unknown levels resolve to the least urgent bucket;
        # confirm this is the intended default for unrecognised input.
        level = "low"
    bucket = _DEFAULT_TEMPLATES[level]
    if need_label in bucket:
        chosen = bucket[need_label]
    else:
        # Fall back to the first template group of this risk level.
        chosen = next(iter(bucket.values()), [])
    # Copy so the shared defaults cannot be mutated through the return value.
    return list(chosen)


class ConstraintSet:
    """
    Hold banned phrases and banned n-grams and score text against them.

    Banned phrases are matched literally (regex metacharacters are escaped),
    case-insensitively, and anywhere in the text; no word-boundary check is
    applied. Banned n-grams are matched against the lower-cased ``\\w+`` tokens
    of the text. The penalty is meant to be consumed by a decoder that
    down-weights violating outputs (e.g. a ConstrainedDecoder, not shown here).

    Note that ``is_allowed`` only considers banned phrases, whereas
    ``constraint_penalty`` counts both phrases and n-grams.
    """

    # Used only when the caller does not pass banned_phrases at all.
    _DEFAULT_BANNED_PHRASES = (
        "I cannot help with suicide",  # example
    )

    def __init__(self, banned_phrases: Optional[List[str]] = None, banned_ngrams: Optional[List[Tuple[str, ...]]] = None):
        """
        Build the constraint set.

        Args:
            banned_phrases: Literal phrases to ban. ``None`` selects the
                built-in example phrase; an explicit empty list means
                "no banned phrases".
            banned_ngrams: Token sequences to ban. ``None`` means none.
        """
        # Distinguish "not provided" from "explicitly empty": only None
        # selects the defaults.
        if banned_phrases is None:
            banned_phrases = self._DEFAULT_BANNED_PHRASES
        # Copy so later changes to the caller's list cannot desynchronise
        # banned_phrases from the compiled patterns.
        self.banned_phrases = list(banned_phrases)
        # Compile once for repeated checks. Blank phrases are skipped because
        # an empty pattern matches every text.
        self._compiled = [
            re.compile(re.escape(phrase), flags=re.IGNORECASE)
            for phrase in self.banned_phrases
            if phrase.strip()
        ]
        # Banned n-grams as sequences of tokens.
        self.banned_ngrams = list(banned_ngrams) if banned_ngrams is not None else []

    def load_constraints(self, path: Optional[str] = None) -> None:
        """
        Placeholder for loading constraints from a file.

        Currently does nothing: ``path`` is ignored and the constraint set is
        left unchanged.
        """
        # left as exercise: read json/csv and update banned_phrases, banned_ngrams
        return

    def is_allowed(self, text: str) -> bool:
        """
        Return True if ``text`` contains none of the banned phrases.

        Banned n-grams are not consulted here.
        """
        return not any(rx.search(text) for rx in self._compiled)

    def constraint_penalty(self, text: str) -> float:
        """
        Count constraint violations in ``text``.

        Each non-overlapping occurrence of a banned phrase and each occurrence
        of a banned n-gram in the token stream adds 1 to the score.

        Returns:
            A non-negative penalty score (higher => worse).
        """
        penalty = 0.0
        for rx in self._compiled:
            penalty += len(rx.findall(text))

        # n-gram penalty: compare against lower-cased word tokens.
        tokens = re.findall(r"\w+", text.lower())
        for ngram in self.banned_ngrams:
            # Normalise once per n-gram rather than once per window.
            target = tuple(tok.lower() for tok in ngram)
            size = len(target)
            if size == 0:
                continue
            for start in range(len(tokens) - size + 1):
                if tuple(tokens[start:start + size]) == target:
                    penalty += 1.0
        return float(penalty)
