from typing import Dict, List, Optional, Tuple, Union
from mmhug.registry import TRANSFORMS
from mmcv import BaseTransform
from mmhug.models.custom_transformers.auto_avsr.datamodule import TextTransform

import re

# Heuristics for degenerate English transcripts (all case-insensitive):
# - one character repeated 10 or more times in a row, e.g. "aaaaaaaaaa";
RE_LONG_CHAR = re.compile(r"(.)\1{9,}", re.IGNORECASE)
# - a two-character unit repeated 8 or more times in a row, e.g. "ha" * 8;
RE_LONG_BIGRAM = re.compile(r"(.{2})\1{7,}", re.IGNORECASE)
# - text made up solely of whitespace and filler sounds such as
#   uh / um / er / ah / mmm / h / "…".
RE_FILLERS_ONLY = re.compile(r"^(?:\s|uh+|um+|er+|ah+|mmm+|h+|…+)+$", re.IGNORECASE)


def is_bad_text_en(s: str, max_total_len: int = 2000, max_token_len: int = 64) -> bool:
    """Return True if an English transcript looks unusable for training.

    The text is stripped of surrounding whitespace first. It is rejected when it
    is empty, longer than ``max_total_len`` characters, contains a
    whitespace-separated token longer than ``max_token_len`` characters,
    contains a long single-character or two-character repetition, or consists
    only of filler sounds.
    """
    text = s.strip()
    if not text or len(text) > max_total_len:
        return True
    if any(len(word) > max_token_len for word in text.split()):
        return True
    # Pattern-based degeneracy probes, evaluated lazily in this order.
    degeneracy_checks = (
        RE_LONG_CHAR.search,
        RE_LONG_BIGRAM.search,
        RE_FILLERS_ONLY.fullmatch,
    )
    return any(check(text) is not None for check in degeneracy_checks)


def sanitize_en(s: str) -> str:
    """Lightly normalise an English transcript.

    Strips the ends and collapses every run of whitespace to a single space.
    It then shortens any run of three or more identical non-digit characters
    (compared case-insensitively) to two copies of the run's first character,
    e.g. ``"sooooo"`` -> ``"soo"`` and ``"!!!"`` -> ``"!!"``.

    Digits are deliberately left untouched: squeezing them would silently
    change numbers such as ``"1000"`` into ``"100"``.
    """
    s = re.sub(r"\s+", " ", s.strip())
    # \D rather than "." so numeric runs ("1000", "2000000") are preserved.
    s = re.sub(r"(\D)\1{2,}", r"\1\1", s, flags=re.I)
    return s


# Individual clips excluded by hand. Entries are compared by exact string
# equality against ``results["video_metadata"]["video_path"]`` in
# ``ScriptFilterTransform.transform``, so each must match that path verbatim.
manual_filter_list = [
    "data/hallo3/hallo3_training_data/videos_cropped_new/dd6565780447013d01a431b35ce4eca6.mp4",
    "data/celebv-text/video_resampled/4bG2LSzaVkQ_12_0.mp4",
]


@TRANSFORMS.register_module()
class ScriptFilterTransform(BaseTransform):
    """Drop samples whose source video or speech script is unusable for training.

    The transform never modifies ``results``. A sample that passes every check
    is returned as-is; a rejected sample raises ``ValueError`` describing the
    reason.

    Args:
        script_key: Key in ``results`` holding the speech script text.
        language_key: Key in ``results`` holding the sample's language label.
        langauge_pool: Allowed language labels, or ``None`` to accept any
            language. A single string is treated as a one-element pool. (The
            misspelt parameter name is kept for compatibility with existing
            configs.)
    """

    # (path substring, source name used in the error message) for datasets
    # that are excluded wholesale.
    # NOTE(review): the public dataset is normally spelled "RAVDESS"; confirm
    # the on-disk directory really is "RAVDNESS", otherwise this entry never
    # matches anything.
    _EXCLUDED_SOURCES = (
        ("/processed/", "douyin living"),
        ("/RAVDNESS/", "RAVDNESS"),
        ("/MEAD/", "MEAD"),
    )

    def __init__(
        self,
        script_key: str = "speech_script",
        language_key: str = "language",
        langauge_pool: Optional[Union[str, List[str], Tuple[str, ...]]] = ("English",),
    ):
        super().__init__()
        self.script_key = script_key
        self.language_key = language_key
        # Store a fresh list so instances never share (or mutate) one default
        # object. A bare string is wrapped, otherwise `in` would do substring
        # matching against it.
        if langauge_pool is None:
            self.langauge_pool = None
        elif isinstance(langauge_pool, str):
            self.langauge_pool = [langauge_pool]
        else:
            self.langauge_pool = list(langauge_pool)
        self.text_transform = TextTransform()

    def transform(self, results: Dict) -> Dict | Tuple[List, List] | None:
        """Validate one sample and return it unchanged if it passes every check.

        Checks run cheapest-first, so excluded samples are rejected before the
        script is tokenized.

        Args:
            results: Pipeline sample dict. It must contain
                ``results["video_metadata"]["video_path"]`` and
                ``results["video_metadata"]["num_frames"]``. The language and
                script are read from ``self.language_key`` and
                ``self.script_key``.

        Returns:
            ``results`` itself, unmodified.

        Raises:
            ValueError: if the sample is filtered out for any reason, including
                a missing language or script.
            KeyError: if ``video_metadata`` or its ``video_path`` /
                ``num_frames`` entries are missing.
        """
        video_path = results["video_metadata"]["video_path"]
        num_frames = results["video_metadata"]["num_frames"]

        # Path-based exclusions: plain string tests, no tokenization needed.
        if video_path in manual_filter_list:
            raise ValueError(f"Video {video_path} is in manual filter list, skip")
        for marker, source in self._EXCLUDED_SOURCES:
            if marker in video_path:
                raise ValueError(f"Video {video_path} is from {source}, skip")

        # Presence checks must precede any use of the script. Previously the
        # script was tokenized first, so a missing/None script surfaced as a
        # KeyError or a tokenizer error instead of this ValueError.
        language = results.get(self.language_key, None)
        speech_script = results.get(self.script_key, None)
        if language is None:
            raise ValueError(f"No language found in {video_path}")
        if speech_script is None:
            raise ValueError(f"No speech found in {video_path}")
        if self.langauge_pool is not None and language not in self.langauge_pool:
            raise ValueError(
                f"Language {language} not in {self.langauge_pool} for {video_path}"
            )
        if is_bad_text_en(speech_script):
            raise ValueError(f"Bad script {speech_script} for {video_path}")

        # Tokenize last (presumably the costliest check): the script must not
        # produce more tokens than the clip has frames.
        tokens = self.text_transform.tokenize(speech_script)
        if len(tokens) > num_frames:
            raise ValueError(
                f"Script {speech_script} is longer than {num_frames} frames for {video_path}"
            )
        return results
