import string
import re
from typing import Dict, Any
from verl import DataProto


from lingua import Language, LanguageDetectorBuilder
from sudachipy import dictionary, tokenizer
#languages = [Language.ENGLISH, Language.SWAHILI, Language.INDONESIAN]
#detector = (
#        LanguageDetectorBuilder.from_languages(*languages)
#        .with_preloaded_language_models()
#        .build()
#    )



def remove_latex_math(text: str) -> str:
    r"""Remove LaTeX math expressions from a string.

    Removed, in this order:
      * ``$$ ... $$`` display math (may span lines);
      * ``$ ... $`` inline math (single line);
      * ``\[ ... \]`` display math (may span lines);
      * ``\( ... \)`` inline math, also accepting the doubled-backslash form
        ``\\( ... \\)`` seen in escaped text;
      * ``\begin{env} ... \end{env}`` for common math environments,
        including their starred variants.

    Matched spans are replaced with nothing, so surrounding whitespace stays.
    """
    text = re.sub(r"\$\$.*?\$\$", "", text, flags=re.DOTALL)
    # Single "$" not followed by another "$" (the $$ form is already gone).
    text = re.sub(r"\$(?!\$).*?\$", "", text)
    text = re.sub(r"\\\[.*?\\\]", "", text, flags=re.DOTALL)
    # One backslash is standard LaTeX; an optional second one covers \\( ... \\).
    text = re.sub(r"\\\\?\(.*?\\\\?\)", "", text)

    math_envs = (
        "equation", "equation*", "align", "align*", "cases",
        "multline", "multline*", "gather", "gather*",
    )
    for env in math_envs:
        # re.escape is required: an unescaped "*" (e.g. "align*") would act as
        # a regex quantifier and the starred environment would never match.
        name = re.escape(env)
        text = re.sub(
            rf"\\begin\{{{name}\}}.*?\\end\{{{name}\}}", "", text, flags=re.DOTALL
        )
    return text

# -----------------------------
#  Character‑level language id
# -----------------------------

def is_korean(ch: str) -> bool:
    """Return True if *ch* is a precomposed Hangul syllable (U+AC00 '가' .. U+D7A3 '힣')."""
    first_syllable, last_syllable = "\uAC00", "\uD7A3"
    return ch >= first_syllable and ch <= last_syllable

_ASCII_LETTERS = frozenset(string.ascii_letters)

def is_english(ch: str) -> bool:
    """Return True if *ch* is a single ASCII letter (a-z or A-Z).

    Set membership is used instead of ``ch in string.ascii_letters`` because
    the latter is a substring test and would also accept "" or runs like "ab".
    """
    return ch in _ASCII_LETTERS

_ASCII_DIGITS = frozenset(string.digits)

def is_digit(ch: str) -> bool:
    """Return True if *ch* is a single ASCII digit 0-9.

    Set membership is used instead of ``ch in string.digits`` because the
    latter is a substring test and would also accept "" or runs like "12".
    """
    return ch in _ASCII_DIGITS

#  Basic Japanese detection (hiragana / katakana)

def is_japanese(ch: str) -> bool:
    """Return True if *ch* lies in the Hiragana or Katakana Unicode block.

    Kanji are not covered here (they are Han characters).
    """
    hiragana = ("\u3040", "\u309F")
    katakana = ("\u30A0", "\u30FF")
    return any(low <= ch <= high for low, high in (hiragana, katakana))

#  Chinese Han characters (CJK ideographs; see is_chinese for the exact Unicode ranges)

def is_chinese(ch: str) -> bool:
    """Return True if *ch* is a CJK (Han) ideograph.

    Covered: CJK Unified Ideographs, Extension A, Compatibility Ideographs,
    and the supplementary-plane Extensions B through F (U+20000..U+2EBEF;
    the span includes a few unassigned code points between extensions).
    """
    return (
        "\u4E00" <= ch <= "\u9FFF"   # CJK Unified Ideographs
        or "\u3400" <= ch <= "\u4DBF"  # CJK Extension A
        or "\uF900" <= ch <= "\uFAFF"  # CJK Compatibility Ideographs
        # Supplementary plane: 8-digit \U escapes are required above U+FFFF.
        or "\U00020000" <= ch <= "\U0002EBEF"  # CJK Extensions B-F
    )

# -----------------------------
#  Tokenisation (multi‑bleu style)
# -----------------------------

# Curly single/double quotes and backticks -> plain ASCII quotes.
_QUOTE_NORMALISATION = str.maketrans({
    "\u2018": "'",  # left single quotation mark
    "\u2019": "'",  # right single quotation mark
    "`": "'",
    "\u201c": '"',  # left double quotation mark
    "\u201d": '"',  # right double quotation mark
})
# Brackets, ! ? ; :, commas, quotes, or a run of one or more dots.
_PUNCT_RE = re.compile(r"""([\[\](){}<>!?;:,"']|\.+)""")

def perl_style_tokenize(text: str) -> str:
    """Roughly replicate multi-bleu/Moses-style punctuation tokenisation.

    Curly quotes and backticks are normalised to ASCII quotes. Each bracket,
    ``! ? ; :``, comma, quote and run of dots is surrounded by spaces. All
    whitespace (newlines included) is then collapsed to single spaces and
    the result is stripped.

    Example: ``perl_style_tokenize("Hi, (you)...")`` -> ``"Hi , ( you ) ..."``
    """
    text = text.translate(_QUOTE_NORMALISATION)
    text = _PUNCT_RE.sub(r" \1 ", text)
    return re.sub(r"\s+", " ", text).strip()

# -----------------------------
#  Math / LaTeX token checks
# -----------------------------

# Digits plus the arithmetic symbols a bare math token may contain.
_MATH_CHARS = frozenset("0123456789+-*/^%=().")

def is_math_expression(token: str) -> bool:
    """Return True if *token* is non-empty and made only of digits and ``+-*/^%=().``.

    Always returns a real ``bool``; an empty token yields ``False``.
    """
    return bool(token) and _MATH_CHARS.issuperset(token)

def is_latex_expression(token: str) -> bool:
    """Return True if *token* is at least two characters long and both starts and ends with ``$``.

    This covers both ``$...$`` and ``$$...$$`` tokens.
    """
    return len(token) > 1 and token[0] == "$" and token[-1] == "$"

# -----------------------------
#  Character presence helpers
# -----------------------------

def has_korean(token: str) -> bool:
    """True when at least one character of *token* satisfies ``is_korean``."""
    return any(map(is_korean, token))

def has_english(token: str) -> bool:
    """True when at least one character of *token* satisfies ``is_english``."""
    for character in token:
        if is_english(character):
            return True
    return False

def has_chinese(token: str) -> bool:
    """True when at least one character of *token* satisfies ``is_chinese``."""
    return any(map(is_chinese, token))

# Inclusive (first, last) code-point bounds of the Cyrillic Unicode blocks.
_CYRILLIC_BLOCKS = [
    ("\u0400", "\u04FF"),  # Basic Cyrillic
    ("\u0500", "\u052F"),  # Cyrillic Supplement
    ("\u2DE0", "\u2DFF"),  # Cyrillic Extended-A
    ("\uA640", "\uA69F"),  # Cyrillic Extended-B
    ("\u1C80", "\u1C8F"),  # Cyrillic Extended-C
]

def is_cyrillic(ch: str) -> bool:
    """Tell whether the character *ch* lies inside any Cyrillic block."""
    for low, high in _CYRILLIC_BLOCKS:
        if low <= ch <= high:
            return True
    return False

def has_cyrillic(text: str) -> bool:
    """Tell whether *text* holds at least one Cyrillic character."""
    return any(map(is_cyrillic, text))


# Inclusive (first, last) code-point bounds of the Thai Unicode block.
_THAI_BLOCKS = [
    ("\u0E00", "\u0E7F"),  # Thai (consonants, vowels, tone marks and digits)
]

def is_thai(ch: str) -> bool:
    """Tell whether the character *ch* belongs to the Thai script block."""
    for low, high in _THAI_BLOCKS:
        if low <= ch <= high:
            return True
    return False

def has_thai(text: str) -> bool:
    """Tell whether *text* holds at least one Thai character."""
    return any(map(is_thai, text))


# Inclusive (first, last) code-point bounds of the Devanagari Unicode blocks.
_DEVANAGARI_BLOCKS = [
    ("\u0900", "\u097F"),  # Basic Devanagari
    ("\uA8E0", "\uA8FF"),  # Devanagari Extended
    # Code points above U+FFFF need the 8-digit \U escape: "\u11B00" would
    # parse as "\u11B0" + "0" and match Hangul jamo U+11B1..U+11B5 instead.
    ("\U00011B00", "\U00011B5F"),  # Devanagari Extended-A
]

def is_devanagari(ch: str) -> bool:
    """Return True if the character *ch* is Devanagari (the script used for Nepali)."""
    return any(start <= ch <= end for start, end in _DEVANAGARI_BLOCKS)

def has_devanagari(text: str) -> bool:
    """Return True if *text* contains at least one Devanagari character."""
    return any(is_devanagari(ch) for ch in text)

# Token-level check helpers
def has_nepali(token: str) -> bool:
    """Report whether *token* contains any Devanagari character.

    Devanagari presence is used as the signal for Nepali text.
    """
    return any(is_devanagari(symbol) for symbol in token)

def has_japan(token: str) -> bool:
    """Return True if *token* contains at least one hiragana or katakana character.

    Every character is checked individually, like the other ``has_*``
    helpers, so kana anywhere in the token counts. Han characters (kanji)
    alone do not make a token Japanese here.
    """
    return any(is_japanese(ch) for ch in token)

# -----------------------------
#  Core statistics function
# -----------------------------

# Sudachi dictionary, loaded on the first Japanese call and reused afterwards.
_SUDACHI_DICTIONARY = None

# Per-call input budget for Sudachi. SudachiPy rejects inputs above a fixed
# byte limit (about 49 KB in recent releases -- confirm for the pinned
# version), so long responses are fed to it in chunks below this size.
_SUDACHI_MAX_CHUNK_BYTES = 40_000

def _get_sudachi_dictionary():
    """Return a module-wide Sudachi ``Dictionary``, loading it on first use.

    Tokenizers are still created per call via ``.create()``; only the
    dictionary (presumably the costly part to load) is shared.
    """
    global _SUDACHI_DICTIONARY
    if _SUDACHI_DICTIONARY is None:
        _SUDACHI_DICTIONARY = dictionary.Dictionary()
    return _SUDACHI_DICTIONARY

def _sudachi_chunks(text: str, max_bytes: int = _SUDACHI_MAX_CHUNK_BYTES):
    """Yield space-joined runs of *text*'s words, each at most *max_bytes* UTF-8 bytes.

    A single word longer than *max_bytes* is cut into character slices of
    ``max_bytes // 4`` characters, which always fit because a UTF-8
    character is at most 4 bytes.
    """
    max_chars = max(1, max_bytes // 4)
    pending = []
    pending_bytes = 0
    for word in text.split():
        word_bytes = len(word.encode("utf-8"))
        if word_bytes > max_bytes:
            if pending:
                yield " ".join(pending)
                pending, pending_bytes = [], 0
            for start in range(0, len(word), max_chars):
                yield word[start:start + max_chars]
            continue
        added = word_bytes + (1 if pending else 0)  # +1 for the joining space
        if pending and pending_bytes + added > max_bytes:
            yield " ".join(pending)
            pending, pending_bytes = [word], word_bytes
        else:
            pending.append(word)
            pending_bytes += added
    if pending:
        yield " ".join(pending)

def _tokenize_for_stats(text: str, lang: str):
    """Split pre-tokenised *text* into tokens (Sudachi split mode C for ``lang == "ja"``)."""
    if lang != "ja":
        return text.split()
    sudachi = _get_sudachi_dictionary().create()
    mode = tokenizer.Tokenizer.SplitMode.C  # C=coarse, B=middle, A=fine
    tokens = []
    for chunk in _sudachi_chunks(text):
        tokens.extend(m.surface() for m in sudachi.tokenize(chunk, mode))
    # Drop whitespace-only morphemes (e.g. the spaces between words) so they
    # are not counted as non-English words.
    return [tok for tok in tokens if tok.strip()]

def get_statistics(text: str, lang: str) -> Dict[str, float]:
    """Return token-level language statistics for a single string.

    Pipeline:
      1. ``remove_latex_math`` strips LaTeX math.
      2. ``perl_style_tokenize`` splits off punctuation and collapses whitespace.
      3. The text is tokenised: Sudachi (split mode C) when ``lang == "ja"``,
         otherwise whitespace splitting.

    Tokens that are pure arithmetic (``is_math_expression``) or ``$``-wrapped
    (``is_latex_expression``) are skipped entirely.

    A kept token containing no ASCII letter counts as non-English. It is also
    assigned to at most one script bucket (Korean, Chinese, Japanese, Thai,
    Cyrillic, Nepali), following the conditions in the loop. Thai is only
    checked when ``lang == "th"`` and Nepali only when ``lang == "ne"``. A
    token that has ASCII letters plus Hangul, Han or Cyrillic is counted as
    code-switching.

    Args:
        text: Raw response text.
        lang: Language code of the example (e.g. ``"ja"``, ``"ne"``); may be None.

    Returns:
        Mapping ``"response/<metric>" -> ratio``. Word ratios are divided by the
        number of kept tokens and character ratios by their total length; a
        ratio is 0.0 when its denominator is 0.
    """
    text = remove_latex_math(text)
    # Also collapses every whitespace run (newlines included) to one space.
    text = perl_style_tokenize(text)
    tokens = _tokenize_for_stats(text, lang)

    total_word_cnt = 0
    total_char_cnt = 0
    non_eng_word_cnt = 0
    non_eng_char_cnt = 0
    code_switch_word_cnt = 0
    # Per-script counters. Thai is counted but not reported (its metrics are
    # currently disabled).
    script_word_cnt = {
        "korean": 0, "chinese": 0, "japanese": 0,
        "thai": 0, "cyrillic": 0, "nepali": 0,
    }
    script_char_cnt = dict(script_word_cnt)

    for token in tokens:
        if is_math_expression(token) or is_latex_expression(token):
            continue
        token_len = len(token)
        total_word_cnt += 1
        total_char_cnt += token_len

        eng = has_english(token)
        kor = has_korean(token)
        chi = has_chinese(token)
        cyri = has_cyrillic(token)
        jap = has_japan(token)
        thai = lang == "th" and has_thai(token)
        nep = lang == "ne" and has_nepali(token)

        if eng:
            # Code-switch token: English letters mixed with Hangul, Han or Cyrillic.
            if kor or chi or cyri:
                code_switch_word_cnt += 1
            continue

        non_eng_word_cnt += 1
        non_eng_char_cnt += token_len

        script = None
        if kor and not (chi or cyri or jap):
            script = "korean"
        elif chi and not (kor or cyri or jap):
            script = "chinese"
        elif jap and not (kor or cyri):
            # A token flagged by has_japan is Japanese even if it also holds Han
            # characters (kanji mixed with kana is normal Japanese, e.g. 食べる);
            # Han-only tokens still fall into "chinese" above.
            script = "japanese"
        elif thai and not (kor or chi or cyri):
            script = "thai"
        elif cyri and not (kor or chi or jap):
            script = "cyrillic"
        elif nep and not (kor or chi or cyri):
            script = "nepali"
        if script is not None:
            script_word_cnt[script] += 1
            script_char_cnt[script] += token_len

    def div(num: int, den: int) -> float:
        """Safe divide: 0.0 when the denominator is 0."""
        return num / den if den else 0.0

    # Swahili / Indonesian / Thai ratios are currently not reported.
    return {
        "response/non_english_word_ratio": div(non_eng_word_cnt, total_word_cnt),
        "response/non_english_character_ratio": div(non_eng_char_cnt, total_char_cnt),
        "response/korean_word_ratio": div(script_word_cnt["korean"], total_word_cnt),
        "response/korean_character_ratio": div(script_char_cnt["korean"], total_char_cnt),
        "response/chinese_word_ratio": div(script_word_cnt["chinese"], total_word_cnt),
        "response/chinese_character_ratio": div(script_char_cnt["chinese"], total_char_cnt),
        "response/code_switching_word_ratio": div(code_switch_word_cnt, total_word_cnt),
        "response/cyrillic_word_ratio": div(script_word_cnt["cyrillic"], total_word_cnt),
        "response/cyrillic_character_ratio": div(script_char_cnt["cyrillic"], total_char_cnt),
        "response/nepali_word_ratio": div(script_word_cnt["nepali"], total_word_cnt),
        "response/nepali_character_ratio": div(script_char_cnt["nepali"], total_char_cnt),
        "response/japanese_word_ratio": div(script_word_cnt["japanese"], total_word_cnt),
        "response/japanese_character_ratio": div(script_char_cnt["japanese"], total_char_cnt),
    }

# -----------------------------
#  Batch‑level aggregator
# -----------------------------

def compute_language_statics_metrics(batch: DataProto, tokenizer) -> Dict[str, Any]:
    """Average ``get_statistics`` over the decoded responses of a batch.

    Responses come from ``tokenizer.batch_decode(batch.batch["responses"],
    skip_special_tokens=True)``. Examples whose ``extra_info`` dict has
    ``"lang" == "en"`` are skipped. Every other example, including one with no
    ``"lang"`` key, is scored with its own language code. Each metric is the
    mean over the scored examples; all metrics stay 0.0 when nothing was
    scored.

    Args:
        batch: Provides ``batch["responses"]`` and
            ``non_tensor_batch["extra_info"]`` (one dict per example, assumed
            to be index-aligned with the responses -- confirm with the caller).
        tokenizer: Object with a ``batch_decode`` method (presumably a
            HuggingFace tokenizer).
    """
    decoded = tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)

    metric_names = (
        "response/non_english_word_ratio",
        "response/non_english_character_ratio",
        "response/korean_word_ratio",
        "response/korean_character_ratio",
        "response/chinese_word_ratio",
        "response/chinese_character_ratio",
        "response/code_switching_word_ratio",
        "response/cyrillic_word_ratio",
        "response/cyrillic_character_ratio",
        # Swahili / Indonesian / Thai ratios are currently disabled.
        "response/nepali_word_ratio",
        "response/nepali_character_ratio",
        "response/japanese_word_ratio",
        "response/japanese_character_ratio",
    )
    totals: Dict[str, float] = dict.fromkeys(metric_names, 0.0)

    scored = 0
    for position, response in enumerate(decoded):
        lang = batch.non_tensor_batch["extra_info"][position].get("lang")
        if lang == "en":
            continue
        scored += 1
        for name, ratio in get_statistics(response, lang).items():
            totals[name] += ratio

    if scored:
        for name in totals:
            totals[name] /= scored
    return totals
