import math
import re
import unicodedata

import torch
import torch.nn as nn

# Letters that have no Unicode (NFKD) decomposition and would otherwise be
# silently dropped by the ASCII-ignore step below (e.g. "œil" -> "il",
# "cœur" -> "cur"); fold them to their usual ASCII spelling first.
_ASCII_FOLDS = str.maketrans({
    "œ": "oe", "Œ": "OE",
    "æ": "ae", "Æ": "AE",
    "ß": "ss",
    "ø": "o", "Ø": "O",
    "ł": "l", "Ł": "L",
})


def normalize_medical_str(s):
    """Normalize a label or document for case/accent/punctuation-insensitive matching.

    Steps: fold ligatures/special letters to ASCII, strip accents (NFKD then
    drop non-ASCII), lowercase, replace every character other than word
    characters, whitespace, "-" and "'" with a space, then collapse runs of
    whitespace and trim.

    Args:
        s: text to normalize. ``None`` and float NaN (pandas' missing-value
            marker) yield ``""``; other non-string values are converted with
            ``str()``.

    Returns:
        The normalized ASCII string (possibly empty).
    """
    if s is None:
        return ""
    if not isinstance(s, str):
        # Missing cells read from a DataFrame arrive as float NaN.
        if isinstance(s, float) and math.isnan(s):
            return ""
        s = str(s)
    s = s.translate(_ASCII_FOLDS)
    s = unicodedata.normalize("NFKD", s).encode("ascii", "ignore").decode("ascii")
    s = re.sub(r"[^\w\s\-']", " ", s.lower())
    return re.sub(r"\s+", " ", s).strip()

class SymbolicViewExtractor:
    """Extract concept ids from free text using lexicons (SNOMED-CT or others).

    Both the lexicon labels and the input text go through
    ``normalize_medical_str``, so matching ignores case, accents and most
    punctuation. Each label is matched as a whole word sequence.
    """

    def __init__(self, lexicon_df):
        """Compile one whole-word regex per lexicon entry.

        Args:
            lexicon_df: DataFrame that must have the columns ``'id'`` (concept
                id) and ``'label'`` (surface form). Presumably several rows can
                share an id (synonyms) — each row gets its own pattern.
        """
        self.patterns = []
        # Iterating the two columns directly avoids building a Series per row.
        for cid, label in zip(lexicon_df['id'], lexicon_df['label']):
            norm_label = normalize_medical_str(label)
            if not norm_label:
                # An empty pattern would match in every non-empty document and
                # tag it with this concept; skip labels that normalize to "".
                continue
            # Non-word lookarounds instead of \b: a normalized label may start
            # or end with "-" or "'", where \b would wrongly demand an
            # adjacent word character.
            pat = re.compile(rf"(?<!\w){re.escape(norm_label)}(?!\w)")
            self.patterns.append((cid, pat))

    def extract(self, text):
        """Return the ids of every lexicon entry whose label occurs in ``text``.

        Ids come back in lexicon order, once per matching label; if several
        labels share an id, that id can appear more than once.
        """
        text_norm = normalize_medical_str(text)
        return [cid for cid, pat in self.patterns if pat.search(text_norm)]

class ConceptEncoder(nn.Module):
    """Encode a bag of concept ids as the mean of their embeddings.

    The embedding table has ``vocab_size + 1`` rows. Concept ids are expected
    in ``[0, vocab_size)``; the extra last row (index ``vocab_size``) is a
    learned "no entity" vector used when a document yields no concept.
    NOTE(review): raw lexicon ids (e.g. SNOMED codes) presumably need to be
    remapped to this range by the caller — confirm upstream.

    Args:
        vocab_size: number of distinct concept ids.
        embed_dim: dimensionality of each embedding vector.
    """

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size + 1, embed_dim)
        # Reserved index of the extra row used when no concept was found.
        self.no_entity_id = vocab_size

    def forward(self, ids_list):
        """Mean-pool the embeddings of the concepts found in one document.

        Args:
            ids_list: sequence (list, tuple, tensor, array) of integer concept
                ids, or ``None``/empty when no concept was found.

        Returns:
            A tensor of shape ``(embed_dim,)`` for a flat sequence of ids.
        """
        # Build indices on the same device as the weights so the module keeps
        # working after .to(device)/.cuda().
        device = self.embedding.weight.device
        # Explicit length check: truthiness of a multi-element tensor/array raises.
        if ids_list is None or len(ids_list) == 0:
            ids_tensor = torch.tensor([self.no_entity_id], dtype=torch.long, device=device)
        else:
            ids_tensor = torch.as_tensor(ids_list, dtype=torch.long, device=device)
        # Mean pooling over the entities found in the document.
        return self.embedding(ids_tensor).mean(dim=0)