import random
from typing import Dict, List, Optional, Tuple, Union
from mmcv import BaseTransform

from mmhug.registry import TRANSFORMS

from pathlib import Path

# Sentence subjects that LLM-generated captions commonly open with.
# Each subject noun is listed with both determiners ("This" / "The") so that
# either variant is recognised; combined with COMMON_CONTINUATION_WORDS these
# form the boilerplate openings collected in COMMON_LLM_START_PHRASES.
COMMON_BEGINNING_PHRASES: tuple[str, ...] = (
    "This video",
    "The video",
    "This clip",
    "The clip",
    # "This animation" was previously missing, so captions starting with
    # "This animation shows ..." were not treated like "The animation shows ...".
    "This animation",
    "The animation",
    "This image",
    "The image",
    "This picture",
    "The picture",
)

# Verbs that commonly follow one of the subject phrases at the start of an
# LLM-generated caption (e.g. "<subject> shows ...").
COMMON_CONTINUATION_WORDS: tuple[str, ...] = tuple(
    "shows depicts features captures highlights introduces presents".split()
)

# All caption openings treated as LLM boilerplate: a handful of fixed
# lead-ins, followed by every "<beginning phrase> <continuation word>"
# combination (beginning phrases vary slowest, continuation words fastest).
COMMON_LLM_START_PHRASES: tuple[str, ...] = (
    "In the video,",
    "In this video,",
    "In this video clip,",
    "In the clip,",
    "Caption:",
) + tuple(
    " ".join((subject, verb))
    for subject in COMMON_BEGINNING_PHRASES
    for verb in COMMON_CONTINUATION_WORDS
)


@TRANSFORMS.register_module()
class LoadText(BaseTransform):
    """Load the caption of a sample and optionally strip LLM boilerplate.

    The caption is taken from the first applicable source:

    1. A random entry of ``dummy_captions`` when dummy captions are configured.
    2. The UTF-8 text file whose path is stored under ``text_path_key``.
    3. The pre-loaded value stored under ``"caption"`` in ``results``.
    4. The empty string.
    """

    def __init__(
        self,
        text_path_key: Optional[str] = "caption_path",
        remove_llm_prefixes: bool = False,
        dummy_captions: Optional[Union[str, List[str]]] = None,
    ):
        """Configure where the caption comes from and how it is cleaned.

        Args:
            text_path_key (Optional[str], optional): Key of text path in
                results. Defaults to "caption_path". If the key is not in
                results (or its value is empty), the key "caption" in results
                is used directly. If neither is available, an empty string
                "" is used as the caption.
            remove_llm_prefixes (bool, optional): Since the caption is often
                generated by LLM, it may contain common LLM prefixes.
                Whether to remove common LLM prefixes. Defaults to False.
            dummy_captions (Optional[Union[str, List[str]]], optional): A
                single caption or a non-empty list of captions. When given,
                every sample receives a randomly chosen entry instead of a
                loaded caption, and ``text_path_key`` must be None.
                Defaults to None (dummy captions disabled).

        Raises:
            ValueError: If ``dummy_captions`` is an empty sequence.
            AssertionError: If ``dummy_captions`` is given while
                ``text_path_key`` is not None.
        """
        self.remove_llm_prefixes = remove_llm_prefixes
        self.text_path_key = text_path_key

        self.use_dummy_captions = dummy_captions is not None
        if dummy_captions is None:
            # Always define the attribute so introspection never hits an
            # AttributeError when dummy captions are disabled.
            self.dummy_captions: List[str] = []
        elif isinstance(dummy_captions, str):
            self.dummy_captions = [dummy_captions]
        else:
            if not dummy_captions:
                # Fail at construction time instead of with an IndexError
                # from random.choice on the first transform() call.
                raise ValueError(
                    "dummy_captions must contain at least one caption"
                )
            self.dummy_captions = dummy_captions

        if self.use_dummy_captions:
            assert (
                self.text_path_key is None
            ), "text_path_key should be None when use_dummy_captions is True"

    def transform(
        self, results: Dict[str, Union[str, Path]]
    ) -> Dict[str, Union[str, Dict[str, Optional[str]]]]:
        """
        Load and clean text for a single sample.

        Args:
            results (Dict):
                A dict containing sample metadata. May include:
                  - self.text_path_key: path to a text file to read. The key
                    is removed from ``results`` when dummy captions are not
                    in use.
                  - "caption": pre-loaded caption string.

        Returns:
            The same ``results`` dict, updated in place with:
              - "caption": the loaded and cleaned text (never None).
              - "caption_metadata": a dict containing:
                    "caption_path": original file path (or None if not used).
        """
        # 1. Pick a dummy caption, or extract a file path and load from disk
        caption_path: Optional[Path] = None
        text: Optional[str] = None
        if self.use_dummy_captions:
            text = random.choice(self.dummy_captions)

        elif self.text_path_key:
            # Remove the key so downstream transforms don't see the raw path
            raw_path = results.pop(self.text_path_key, None)
            if raw_path:
                caption_path = Path(raw_path)
                # Read the file once; strip whitespace/newlines
                text = caption_path.read_text(encoding="utf-8").strip()

        # 2. Fall back to the in-memory caption when no text was obtained
        #    above (path key missing or its value empty). A file that exists
        #    but is empty yields "" and does NOT trigger this fallback.
        if text is None:
            text = results.get("caption", "") or ""

        # 3. Optionally strip out known LLM prompt prefixes
        if self.remove_llm_prefixes:
            for prefix in COMMON_LLM_START_PHRASES:
                if text.startswith(prefix):
                    # Remove only the first matching prefix
                    text = text[len(prefix) :].strip()
                    break

        # 4. Update results dict for downstream use
        results["caption"] = text
        results["caption_metadata"] = {
            "caption_path": str(caption_path) if caption_path else None
        }

        return results

    def __repr__(self) -> str:
        """Return a short description of the transform's configuration."""
        return (
            f"{self.__class__.__name__}("
            f"text_path_key={self.text_path_key!r}, "
            f"remove_llm_prefixes={self.remove_llm_prefixes}, "
            f"use_dummy_captions={self.use_dummy_captions})"
        )
