import re, json
from typing import List, Dict, Optional, Any, Tuple
import re, string, unicodedata

import math

# Canonical type labels. canon_type() only returns a value that is a member
# of this set; anything else is reported as None.
ALLOWED_TYPES = {"syntactic", "semantic", "generalize"}

def _strip_to_json(txt: str) -> str:
    txt = re.sub(r"```(?:json)?\s*|\s*```", "", txt, flags=re.I).strip()
    m = re.search(r"\{.*\}", txt, flags=re.S)
    if not m:
        raise ValueError("JSON block not found")
    return m.group(0)

def load_jsonl(path: str) -> List[Dict]:
    """Read a UTF-8 JSON Lines file and return its records in file order.

    Lines that are empty or contain only whitespace are skipped; every other
    line is parsed with ``json.loads``, so a malformed line raises.
    """
    records: List[Dict] = []
    with open(path, encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records

def save_jsonl(data: List[Dict], path: str):
    """Write *data* to *path* as JSON Lines, one serialized item per line.

    The file is opened in write mode (existing content is replaced) with
    UTF-8 encoding, and non-ASCII characters are written as-is rather than
    as ``\\uXXXX`` escapes.
    """
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in data
        )
            
def extract_json(text: str):
    """Best-effort extraction of a JSON value from *text*.

    The whole input is tried first, so any valid JSON document (object,
    array, number, ...) is returned as parsed. Otherwise the text is scanned
    left to right and the first ``{...}`` object that parses is returned;
    anything after that object is ignored, so surrounding prose, additional
    objects or stray braces do not defeat the extraction.

    Returns:
        The decoded value, or None when nothing parseable is found or when
        *text* is not a string.
    """
    # Errors that mean "this candidate is not JSON" for a best-effort parser.
    parse_errors = (TypeError, ValueError, RecursionError)

    try:
        return json.loads(text)
    except parse_errors:
        pass

    if not isinstance(text, str):
        return None

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode stops at the end of the first complete value, so
            # trailing text after the object is tolerated.
            value, _end = decoder.raw_decode(text, start)
        except parse_errors:
            start = text.find("{", start + 1)
        else:
            return value
    return None

def canon_type(t: Optional[str]) -> Optional[str]:
    """Map a type label or one of its aliases to its canonical name.

    Matching ignores letter case, surrounding whitespace, and any internal
    whitespace, underscores or hyphens (so ``"Generali-zation"`` and
    ``" SEM "`` are both recognized).

    Returns:
        ``'syntactic'``, ``'semantic'`` or ``'generalize'`` -- the labels in
        ``ALLOWED_TYPES`` -- or None when *t* is empty, not a string, an
        unknown alias, or resolves to a label missing from ``ALLOWED_TYPES``.
    """
    if not t or not isinstance(t, str):
        return None
    key = re.sub(r"[\s_\-]+", "", t.strip().lower())
    # The generalization family must resolve to "generalize", the spelling
    # used by ALLOWED_TYPES; any other spelling is rejected by the membership
    # check below and the whole family would canonicalize to None.
    alias = {
        "syntactic": "syntactic", "syntax": "syntactic", "syn": "syntactic", "sy": "syntactic",
        "semantic": "semantic", "semantics": "semantic", "sem": "semantic",
        "general": "generalize", "generalize": "generalize",
        "generalization": "generalize", "generalisation": "generalize",
    }
    v = alias.get(key)
    return v if v in ALLOWED_TYPES else None


# def extract_tokens_and_entropies(lp: Optional[Dict[str, Any]],
#                                  in_bits: bool = True,
#                                  fill_value: Optional[float] = None) -> Tuple[List[str], List[Optional[float]]]:
#     steps = (lp or {}).get("content", [])
#     toks = [s.get("token", "") for s in steps]
#     ents = token_entropies(steps, include_rest=True, in_bits=in_bits)
#     # Normalize lengths: truncate both lists to the shorter one
#     L = min(len(toks), len(ents))
#     toks, ents = toks[:L], ents[:L]
#     if fill_value is not None:
#         ents = [fill_value if e is None else e for e in ents]
#     return toks, ents

# English stop-word set. scikit-learn's list is used when it can be imported;
# any failure during that import (not only ImportError) falls back to the
# hand-written set below, so EN_STOP's contents depend on the environment.
try:
    from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS as _SK_EN_SW
    EN_STOP = set(_SK_EN_SW)
except Exception:
    EN_STOP = {
        "a","an","the","and","or","but","if","then","that","this","these","those",
        "of","to","in","on","for","from","by","with","about","as","at","into","over","after","before",
        "is","am","are","was","were","be","been","being","do","does","did","doing","have","has","had","having",
        "can","could","may","might","must","shall","should","will","would",
        "i","you","he","she","it","we","they","me","him","her","us","them",
        "my","your","his","its","our","their","not","no","nor","so","too","very","just","only","also","such",
        "which","who","whom","whose","what","when","where","why","how"
    }

_ALPHA = re.compile(r'[a-z]')  # matches one lowercase ASCII letter; used to test for the presence of English letters

def _only_punct_or_symbol(s: str) -> bool:
    if not s:
        return True
    for ch in s:
        cat = unicodedata.category(ch)
        if not (cat.startswith('P') or cat.startswith('S')):
            return False
    return True

def _normalize_token_en(tok: str) -> str:
    t = tok.lstrip()
    t = t.strip(string.punctuation)
    return t.lower()

def stopword_mask_en(tokens, extra_stop=None, remove_stop=None, digits_as_stop=True) -> List[int]:
    """Build a 0/1 mask over *tokens*: 0 = stop word / noise, 1 = content word.

    Rules, applied in order to each token:
      1. Digit-only token (after normalization) -> 0 if ``digits_as_stop``
         else 1.
      2. Token without any English (ASCII) letter -> 0.
      3. Normalized token in ``extra_stop`` -> 0.
      4. Normalized token in ``remove_stop`` -> 1 (overrides ``EN_STOP``;
         ``extra_stop`` wins when a word is in both).
      5. Normalized token in ``EN_STOP`` -> 0, otherwise 1.

    Args:
        tokens: iterable of token strings.
        extra_stop: additional words to treat as stop words. Entries are
            compared against the normalized (lowercased) token.
        remove_stop: words to always keep, even if they are in ``EN_STOP``.
        digits_as_stop: whether digit-only tokens count as stop words.

    Returns:
        A list of 0/1 ints, one per token, in input order.
    """
    extra_stop = set(extra_stop or [])
    remove_stop = set(remove_stop or [])
    mask = []
    for tok in tokens:
        norm = _normalize_token_en(tok)

        # 1) Digit-only tokens are decided before the "no English letter"
        #    rule; otherwise that rule would always mask them and
        #    digits_as_stop=False could never take effect.
        if norm.isdigit():
            mask.append(0 if digits_as_stop else 1)
            continue

        # 2) No English letter at all: punctuation, symbols, emoji, mixed
        #    digit+symbol strings, empty tokens. Every token that passes this
        #    check keeps that letter through normalization, so `norm` is
        #    non-empty from here on.
        if not _ALPHA.search(tok.lower()):
            mask.append(0)
            continue

        # 3) / 4) User-supplied dictionaries take precedence over EN_STOP.
        if norm in extra_stop:
            mask.append(0)
            continue
        if norm in remove_stop:
            mask.append(1)
            continue

        # 5) English stop-word dictionary.
        mask.append(0 if norm in EN_STOP else 1)
    return mask


def token_entropies(lp_content, include_rest=True, in_bits=False):
    """Compute a Shannon entropy for each step from its top-k log-probabilities.

    Args:
        lp_content: iterable of step dicts. A step's ``"top_logprobs"`` entry,
            when present and non-empty, is a list of dicts each holding a
            ``"logprob"`` value, which is exponentiated with ``math.exp``
            (i.e. treated as a natural-log probability).
        include_rest: when True, the probability mass not covered by the
            listed candidates (``1 - sum(p)``, floored at 0) is treated as one
            extra outcome and its ``-p * ln(p)`` term is added.
        in_bits: when True the result is converted from nats to bits.

    Returns:
        A list aligned with *lp_content*: a float entropy per step, or None
        for steps whose ``"top_logprobs"`` is missing or empty.
    """
    entropies = []
    for step in lp_content:
        tops = step.get("top_logprobs", [])
        if not tops:
            entropies.append(None)
            continue
        logprobs = [cand["logprob"] for cand in tops]
        probs = [math.exp(lp) for lp in logprobs]
        # Skip zero-probability candidates: for logprob == -inf the product
        # 0.0 * -inf is NaN, whereas the term's limit (p * ln p as p -> 0) is 0.
        entropy = -sum(p * lp for p, lp in zip(probs, logprobs) if p > 0.0)
        if include_rest:
            rest = max(0.0, 1.0 - sum(probs))
            if rest > 0:
                entropy -= rest * math.log(rest)
        if in_bits:
            entropy /= math.log(2.0)
        entropies.append(entropy)
    return entropies