# -*- coding: utf-8 -*-
"""
Text-to-Video Key Elements Extractor
Outputs:
- characters: People/characters (names, professions, roles)
- objects:    Objects/items (head nouns, lemmatized for English)
- actions:    Actions/verbs (lemmatized for English)
- locations:  Locations (head nouns, lemmatized for English)
- scenery:    Scene/environment/style (keywords such as sunset/cyberpunk/雨天)
"""

from typing import Dict, List, Callable, Optional
from langchain_deepseek import ChatDeepSeek
import re, os, json
# Ensure the variable exists without clobbering a key the user already
# exported (an unconditional assignment of "" would wipe it out).
# Presumably read by ChatDeepSeek, which is constructed without an explicit
# api_key in extract_with_llm — TODO confirm.
os.environ.setdefault("DEEPSEEK_API_KEY", "")


_HAS_JIEBA = False
_HAS_SPACY = False
try:
    import jieba.posseg as pseg
    _HAS_JIEBA = True
except Exception:
    _HAS_JIEBA = False

try:
    import spacy
    _nlp = spacy.load("en_core_web_sm")
    _HAS_SPACY = True
except Exception:
    _nlp = None
    _HAS_SPACY = False


# Matches one character from the CJK Unified Ideographs block (U+4E00-U+9FFF).
CN_RE = re.compile(r'[\u4e00-\u9fff]')

def is_cn(text: str) -> bool:
    """Return True if *text* contains at least one CJK Unified Ideograph."""
    first_hit = CN_RE.search(text)
    return first_hit is not None

def normalize(s: str) -> str:
    """Collapse every whitespace run to one space and trim both ends.

    A falsy input (None or "") yields "".
    """
    text = s if s else ''
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()

def uniq(seq: List[str]) -> List[str]:
    """Clean and de-duplicate *seq*, keeping first-occurrence order.

    Each item is whitespace-stripped, then stripped of surrounding ASCII/CJK
    punctuation, quotes, parentheses and brackets. Items that are falsy
    before or after cleaning are dropped.
    """
    punct = '，,。.;；!?！？"\'()[]'
    cleaned = (item.strip().strip(punct) for item in seq if item)
    # dict preserves insertion order, so the first occurrence of each value wins.
    return list(dict.fromkeys(c for c in cleaned if c))

# Chinese scene / environment / style keywords.
SCENERY_ZH = set("""
日出 日落 黎明 黄昏 夜晚 白天 雨天 下雨 雪天 下雪 暴风雨 雾 霓虹 赛博朋克 蒸汽朋克
未来风 科幻 古代 中世纪 奇幻 童话 废土 末日 水下 火山 极光 星空 银河 太空
像素风 低多边形 水彩 油画
""".split())

# English scene / environment / style keywords. Written as an explicit set
# so multi-word phrases ("oil painting", "northern lights") stay intact;
# splitting a string on whitespace would turn them into meaningless
# entries such as "art", "tale" or "lights".
SCENERY_EN = {
    "sunrise", "sunset", "dawn", "dusk", "night", "daytime", "rainy", "rain",
    "snow", "snowy", "storm", "fog", "neon", "cyberpunk", "steampunk",
    "futuristic", "sci-fi", "ancient", "medieval", "fantasy", "fairytale",
    "fairy tale", "post-apocalyptic", "underwater", "volcanic", "aurora",
    "northern lights", "starry", "galaxy", "space", "low poly", "pixel art",
    "watercolor", "oil painting",
}

# Primary English scene keywords (simplified list). Multi-word phrases are
# listed alongside their single-word pieces ("oil" / "oil painting", ...):
# a matcher that compares one token at a time can only hit the single words.
# NOTE(review): "low" and "oil" on their own look prone to false positives
# (e.g. "a low wall") — confirm they are wanted.
SCENERY_EN_PRIMARY = set("""
sunrise sunset dawn dusk night daytime rain rainy snow snowy storm fog
neon cyberpunk steampunk futuristic sci-fi medieval fantasy underwater aurora
starry galaxy space watercolor oil pixel low
""".split()) | {"oil painting", "pixel art", "low poly"}

# Role / character nouns for Chinese prompts. extract_cn additionally treats
# jieba person-name POS tags as characters.
ROLE_ZH = set("""
男人 女人 男孩 女孩 小孩 孩子 婴儿 老人 情侣 舞者
士兵 武士 法师 宇航员 飞行员 厨师 医生 护士 警察 超级英雄
导演 演员 歌手 魔术师 画家 摄影师 模特 店员 旅客 游客
冒险者 猎人 渔夫 骑士 僧人
""".split())

# Role / character nouns for English prompts (matched on lower-cased words).
ROLE_EN = set("""
man woman boy girl child baby elderly couple
dancer soldier samurai wizard astronaut pilot chef
doctor nurse police policeman policewoman superhero
actor actress singer magician painter photographer model
clerk tourist traveler adventurer hunter fisher knight monk
""".split())

# Chinese location suffixes (heuristic): a word ending in one of these is
# treated as a place. Kept as an ordered tuple because str.endswith accepts
# a tuple and the regex fallback in extract_cn scans the suffixes in sequence.
LOC_SUFFIX_ZH = tuple("""
街 路 巷 市 场 公园 森林 丛林 沙漠 海滩 海边 海岸 山 山脉 谷 洞穴 河 湖 海 城
市中心 地铁 车站 机场 港口 码头 寺庙 城堡 学校 办公室 厨房 卧室 屋顶 桥
体育场 博物馆 美术馆 广场 海岸线 海湾 车库 码头区 海域 海面
""".split())

def _to_singular(word: str) -> str:
    """Return lemma as-is (spaCy already lemmatizes)."""
    return word

def _clean_token_text(t) -> str:
    """Return lowercase lemma of a spaCy token."""
    return t.lemma_.lower().strip()

def head_noun_from_chunk(chunk) -> str:
    """
    Extract the head noun (lemma) from a noun chunk.
    e.g., 'a neon umbrella' -> 'umbrella'
    """
    root = chunk.root
    if root.pos_ in {"NOUN", "PROPN"}:
        return _to_singular(_clean_token_text(root))
    for t in reversed(list(chunk)):
        if t.pos_ in {"NOUN", "PROPN"}:
            return _to_singular(_clean_token_text(t))
    return _clean_token_text(root)

def head_noun_from_prep_span(prep_token, pobj_token, doc) -> str:
    """
    Extract the head noun (lemma) from a prepositional phrase.
    e.g., 'on the beach at sunset' -> 'beach'
    """
    if pobj_token.pos_ in {"NOUN", "PROPN"}:
        return _to_singular(_clean_token_text(pobj_token))
    for t in pobj_token.subtree:
        if t.pos_ in {"NOUN", "PROPN"}:
            return _to_singular(_clean_token_text(t))
    return _clean_token_text(pobj_token)

def just_keyword(phrase: str) -> str:
    """
    Fallback keyword extractor: return the last meaningful word of *phrase*.
    e.g., 'a drone following behind' -> 'drone'

    Steps: lower-case and trim; drop a leading article; drop a few stock
    trailing phrases; then pick the right-most word that is not a preposition
    or a style adjective. If every word is filtered, the last word is
    returned. If no word is found at all, the cleaned phrase is returned.
    """
    skip = {"with", "in", "on", "at", "by", "near", "following", "behind",
            "neon", "cyberpunk", "steampunk"}

    text = (phrase or "").lower().strip()
    text = re.sub(r'^(a|an|the)\s+', '', text)
    text = re.sub(r'\b(following behind|in the background|at night|at sunset)\b', '', text).strip()

    words = re.findall(r"[a-z][a-z\-']+", text)
    if not words:
        return text
    return next((w for w in reversed(words) if w not in skip), words[-1])

# ----------------------------
# Chinese extractor
# ----------------------------
# Phrase patterns shared by both Chinese extraction paths (compiled once).
# 在/于 + text up to the next punctuation mark -> location candidate.
_CN_LOC_PREP_RE = re.compile(r'(?:在|于)([^，。；、!?！？]+)')
# Holding/carrying verbs + following text -> object candidate.
_CN_CARRY_RE = re.compile(r'(手持|拿着|带着|佩戴|使用|抱着|背着)([^，。；、!?！？\s]+)')
# Riding/driving/playing verbs -> action (the verb) + object candidate.
_CN_HANDLE_RE = re.compile(r'(骑着|骑|开着|开|握着|举着|弹奏|演奏|操作)([^，。；、!?！？\s]+)')


def _cn_pattern_elements(prompt: str, objs: List[str], acts: List[str], locs: List[str]) -> None:
    """Append pattern-based locations, objects and actions found in *prompt*.

    Locations are kept when 1-20 characters long and objects when 1-12
    characters long. The same limits apply whichever extraction path calls this.
    """
    for m in _CN_LOC_PREP_RE.finditer(prompt):
        cand = m.group(1).strip()
        if 1 <= len(cand) <= 20:
            locs.append(cand)

    for m in _CN_CARRY_RE.finditer(prompt):
        thing = m.group(2).strip()
        if 1 <= len(thing) <= 12:
            objs.append(thing)

    for m in _CN_HANDLE_RE.finditer(prompt):
        acts.append(m.group(1))
        thing = m.group(2).strip()
        if 1 <= len(thing) <= 12:
            objs.append(thing)


def _cn_vocab_hits(prompt: str, vocab) -> List[str]:
    """Return the entries of *vocab* that occur in *prompt*, in order of first
    appearance (ties broken by the entry itself).

    Iterating a set of strings directly would make the order depend on
    per-process hash randomization, so results are sorted explicitly.
    """
    return sorted((w for w in vocab if w in prompt), key=lambda w: (prompt.find(w), w))


def extract_cn(prompt: str) -> Dict[str, List[str]]:
    """
    Extract key elements from Chinese prompts.

    With jieba available, words are classified by POS tag and the role,
    scenery and location-suffix lexicons. Without jieba, lexicon substrings
    and regex heuristics are used instead. Both paths then apply the shared
    phrase patterns (在/于 locations, holding/riding verbs).

    Args:
        prompt: Prompt text containing Chinese.

    Returns:
        Dict with keys characters/objects/actions/locations/scenery. Each
        value is a list de-duplicated and cleaned by `uniq`, in extraction order.
    """
    chars, objs, acts, locs, scen = [], [], [], [], []

    if _HAS_JIEBA:
        for w, flag in pseg.cut(prompt):
            w = w.strip()
            if not w:
                continue

            # Scene keywords (no `continue`: the word may also be classified below).
            if w in SCENERY_ZH:
                scen.append(w)

            # Characters: role lexicon or any person-name tag (jieba uses nr, nrt, nrfg).
            if flag.startswith('nr') or w in ROLE_ZH:
                chars.append(w)
                continue

            # Locations: place-name tags (ns, ...) or a known location suffix.
            if flag.startswith('ns') or w.endswith(LOC_SUFFIX_ZH):
                locs.append(w)
                continue

            # Actions: any verb tag.
            if flag.startswith('v'):
                acts.append(w)
                continue

            # Objects: remaining nouns of reasonable length. Roles, places and
            # suffix matches were already diverted by the `continue`s above.
            if flag.startswith('n') and 1 <= len(w) <= 12:
                objs.append(w)

        _cn_pattern_elements(prompt, objs, acts, locs)

    else:
        # Lightweight fallback: substring lexicon hits in order of appearance.
        scen.extend(_cn_vocab_hits(prompt, SCENERY_ZH))
        chars.extend(_cn_vocab_hits(prompt, ROLE_ZH))

        _cn_pattern_elements(prompt, objs, acts, locs)

        # Without a tokenizer, grab up to 8 Han characters before each suffix.
        for suf in LOC_SUFFIX_ZH:
            for m in re.finditer(r'([\u4e00-\u9fff]{1,8}%s)' % re.escape(suf), prompt):
                locs.append(m.group(1))

    return {
        "characters": uniq(chars),
        "objects":    uniq(objs),
        "actions":    uniq(acts),
        "locations":  uniq(locs),
        "scenery":    uniq(scen),
    }


# "with/using/holding ... X" object pattern shared by both English paths.
_EN_WITH_RE = re.compile(
    r'\b(with|using|holding|carrying|riding|playing)\b\s+([a-z0-9 \-]+?)(?=[\.,;]| and\b| in\b|$)',
    re.I,
)
# "in/at/on ... X" location pattern used by the regex fallback.
_EN_PREP_RE = re.compile(
    r'\b(in|at|on|near|beside|under|over|by)\b\s+([a-z0-9 \-]+?)(?=[\.,;]| and\b| with\b|$)',
    re.I,
)


def _en_vocab_hits(text: str, vocab) -> List[str]:
    """Return the entries of *vocab* found in *text* as whole words/phrases.

    Matching is case-insensitive. A match must be bounded by characters that
    are neither word characters nor hyphens. So "sunset," still matches
    "sunset", while "low" does not match inside "low-poly". Results are
    ordered by first occurrence (ties broken alphabetically), so they do not
    depend on set iteration order.
    """
    low = text.lower()
    hits = []
    for w in vocab:
        m = re.search(r'(?<![\w-])' + re.escape(w) + r'(?![\w-])', low)
        if m:
            hits.append((m.start(), w))
    return [w for _, w in sorted(hits)]


def extract_en(prompt: str) -> Dict[str, List[str]]:
    """
    Extract key elements from English prompts.

    If spaCy loaded (`_HAS_SPACY` and `_nlp`), lemmas, POS tags, entities,
    dependencies and noun chunks are used. Otherwise regex heuristics run
    over the raw text.

    Args:
        prompt: English prompt text.

    Returns:
        Dict with keys characters/objects/actions/locations/scenery. Each
        value is a list de-duplicated and cleaned by `uniq`, in extraction order.
    """
    chars, objs, acts, locs, scen = [], [], [], [], []

    if _HAS_SPACY and _nlp:
        doc = _nlp(prompt)

        # Scenery: single-token lemmas first, then multi-word / hyphenated
        # entries ("oil painting", "sci-fi"). A per-token lemma comparison
        # cannot see those (spaCy may split hyphenated words into several tokens).
        for t in doc:
            tt = t.lemma_.lower()
            if tt in SCENERY_EN_PRIMARY:
                scen.append(tt)
        scen.extend(_en_vocab_hits(prompt, {w for w in SCENERY_EN_PRIMARY if not w.isalpha()}))

        # Named entities: PERSON -> character. GPE/LOC/FAC -> location: the
        # lemma of the entity root when it is a noun, else the entity text.
        for ent in doc.ents:
            if ent.label_ == "PERSON":
                chars.append(ent.text)
            elif ent.label_ in {"GPE", "LOC", "FAC"}:
                if ent.root.pos_ in {"NOUN", "PROPN"}:
                    locs.append(_clean_token_text(ent.root))
                else:
                    locs.append(ent.text)

        # Roles from the lexicon. The surface form is kept when it is listed
        # directly; otherwise the lemma is tried, so plurals ("soldiers") count.
        for t in doc:
            if t.text.lower() in ROLE_EN:
                chars.append(t.text)
            elif t.lemma_.lower() in ROLE_EN:
                chars.append(t.lemma_.lower())

        # Actions: verb lemmas.
        for t in doc:
            if t.pos_ == "VERB":
                lemma = t.lemma_.lower()
                if 1 <= len(lemma) <= 30:
                    acts.append(lemma)

        # Locations: head noun of the object of a spatial preposition.
        for token in doc:
            if token.dep_ == "prep" and token.text.lower() in {"in", "at", "on", "near", "beside", "under", "over", "by"}:
                pobj = next((c for c in token.children if c.dep_ in {"pobj", "pcomp"}), None)
                if pobj:
                    loc_head = head_noun_from_prep_span(token, pobj, doc)
                    if loc_head:
                        locs.append(loc_head)

        # Objects: head nouns of noun chunks, skipping personal pronouns.
        for chunk in doc.noun_chunks:
            key = head_noun_from_chunk(chunk)
            if key and key not in {"i", "you", "we", "they", "he", "she"}:
                objs.append(key)

        # "with/using/holding ..." patterns: first noun of the tail, re-parsed
        # on its own, else the regex keyword fallback.
        for m in _EN_WITH_RE.finditer(prompt):
            tail = m.group(2).strip()
            cand = next((t.lemma_.lower() for t in _nlp(tail) if t.pos_ in {"NOUN", "PROPN"}), None)
            objs.append(cand or just_keyword(tail))

    else:
        # Regex-based fallback. Lexicon hits are whole-word matches, in order of appearance.
        scen.extend(_en_vocab_hits(prompt, SCENERY_EN_PRIMARY))
        chars.extend(_en_vocab_hits(prompt, ROLE_EN))

        for m in _EN_PREP_RE.finditer(prompt):
            locs.append(just_keyword(m.group(2)))

        for m in _EN_WITH_RE.finditer(prompt):
            objs.append(just_keyword(m.group(2)))

        for m in re.finditer(r'\b(a|an|the)\s+([a-z0-9 \-]+)', prompt, flags=re.I):
            objs.append(just_keyword(m.group(2)))

        # Crude verb guess: "-ing" forms and "-s" forms (this also catches
        # plural nouns). Case-insensitive so capitalized words are lower-cased.
        for m in re.finditer(r'\b([a-z]{3,}ing|[a-z]{3,}s)\b', prompt, flags=re.I):
            acts.append(m.group(1).lower())

    return {
        "characters": uniq(chars),
        "objects":    uniq(objs),
        "actions":    uniq(acts),
        "locations":  uniq(locs),
        "scenery":    uniq(scen),
    }

# ----------------------------
# Optional: LLM integration
# ----------------------------
def extract_with_llm(prompt: str) -> Dict[str, List[str]]:
    """
    Extract key elements with the DeepSeek chat model (via langchain_deepseek).

    The model is asked for a JSON object with the keys characters, objects,
    actions, locations and scenery. The outermost {...} span of the reply is
    parsed and then normalized:
    - every one of the five keys is present;
    - each value is a list of strings, de-duplicated and cleaned by `uniq`;
    - a bare string value becomes a one-element list, while null and other
      non-list values become an empty list;
    - any extra keys are dropped.

    Args:
        prompt: Text to analyze.

    Returns:
        Dict mapping each of the five keys to a list of strings.

    Raises:
        ValueError: if the reply contains no {...} span or that span is not
            valid JSON (json.JSONDecodeError is a ValueError subclass).
        Any exception raised by the ChatDeepSeek client itself (not caught here).
    """
    llm = ChatDeepSeek(
        model="deepseek-chat",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
    )

    sys_prompt = """
        Extract key elements for text-to-video generation.\n
        Return strict JSON with keys: characters, objects, actions, locations, scenery.\n
        Each value is a list of single-word or short head nouns/verbs (<=2 words). No commentary.
        """

    query = [
        ("system", sys_prompt),
        ("human", prompt),
    ]
    raw = llm.invoke(query)
    raw_text = raw.content.strip()
    # Greedy + DOTALL: take everything from the first "{" to the last "}" so
    # surrounding prose or code fences are ignored.
    match = re.search(r'\{.*\}', raw_text, re.S)
    if not match:
        raise ValueError("No valid JSON object found in the LLM response.")

    # The matched text starts with "{", so a successful parse yields a dict.
    data = json.loads(match.group(0))

    result: Dict[str, List[str]] = {}
    for key in ("characters", "objects", "actions", "locations", "scenery"):
        values = data.get(key)
        if isinstance(values, str):
            values = [values]
        elif not isinstance(values, list):
            values = []
        result[key] = uniq([str(v) for v in values if v is not None])
    return result

def merge_lists(*lists: List[str]) -> List[str]:
    """Concatenate the given lists in order and clean/de-duplicate via `uniq`.

    A None (or otherwise falsy) argument is treated as an empty list.
    """
    combined = [item for ls in lists for item in (ls or [])]
    return uniq(combined)

def extract_atoms(prompt: str,
                  use_llm: bool = False,
                  *,
                  merge_with_rules: bool = False) -> Dict[str, List[str]]:
    """
    Main API for extracting key elements from prompts.

    - The prompt is whitespace-normalized first. An empty prompt yields five
      empty lists.
    - Prompts containing a CJK ideograph go to extract_cn (jieba or regex).
      All others go to extract_en (spaCy or regex).
    - With use_llm=True, the LLM result (extract_with_llm) is returned instead
      of the rule-based one. Pass merge_with_rules=True to union both, with
      rule-based items first.

    Args:
        prompt: Prompt text (Chinese or English).
        use_llm: Use the LLM extractor.
        merge_with_rules: Only with use_llm=True. Also run the rule-based
            extractor and merge its results with the LLM's, per key.

    Returns:
        Dict with keys characters/objects/actions/locations/scenery, each a
        list of strings.

    Raises:
        Whatever extract_with_llm raises when use_llm is True.
    """
    keys = ("characters", "objects", "actions", "locations", "scenery")
    prompt = normalize(prompt)
    if not prompt:
        return {k: [] for k in keys}

    if not use_llm:
        return extract_cn(prompt) if is_cn(prompt) else extract_en(prompt)

    llm_res = extract_with_llm(prompt)
    if not merge_with_rules:
        return llm_res

    base = extract_cn(prompt) if is_cn(prompt) else extract_en(prompt)
    # .get(): tolerate a missing key in either result (merge_lists treats None as empty).
    return {k: merge_lists(base.get(k), llm_res.get(k)) for k in keys}

# ----------------------------
# Quick tests
# ----------------------------
if __name__ == "__main__":
    samples = [
        # English
        "An astronaut riding a bicycle on the beach at sunset, with a drone following behind.",
        # "A woman in a medieval castle at night, holding a lantern with neon reflections.",
    ]
    # The LLM path needs a real DeepSeek key. Skip it, rather than crash,
    # when none is configured.
    have_key = bool(os.environ.get("DEEPSEEK_API_KEY"))
    for s in samples:
        print("TEXT:", s)
        print("RULES:", extract_atoms(s))
        if have_key:
            print("LLM:", extract_atoms(s, use_llm=True))
        else:
            print("LLM: skipped (DEEPSEEK_API_KEY not set)")
