"""
Utility functions for multimodal grounding annotation pipeline.

This module contains shared functions used across multiple scripts:
- Dataset configurations
- Text processing (reasoning/answer extraction, cleaning)
- Citation extraction and parsing
- Sentence splitting
- File loading utilities
- Timestamp parsing
"""

import os
import re
import json
import logging
from typing import Dict, List, Optional, Tuple, Any

try:
    from datasets import load_from_disk
except ImportError:
    load_from_disk = None

# Import-time logging setup: request INFO level on the root logger and render
# each record as "<timestamp> - <message>". Per the standard library,
# basicConfig does nothing if the root logger already has handlers configured.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
# Module-level logger used by the helpers in this file.
logger = logging.getLogger(__name__)

# ==========================================
# DATASET CONFIGURATIONS
# ==========================================

# Registry of the sampled datasets, keyed "<corpus in lowercase>_<sample size>".
# Each entry provides:
#   "video_dir" - directory that holds the corpus' video files
#   "hf_path"   - on-disk location of the sampled dataset
# Sizes are iterated in the outer loop so the key order is
# videommmu_50, worldsense_50, videommmu_200, worldsense_200.
DATASET_CONFIGS = {
    f"{corpus.lower()}_{sample_size}": {
        "video_dir": f"../data/{corpus}/videos/",
        "hf_path": f"../data/{corpus}_sample_{sample_size}",
    }
    for sample_size in (50, 200)
    for corpus in ("VideoMMMU", "WorldSense")
}


# ==========================================
# TEXT PROCESSING FUNCTIONS
# ==========================================

# Matches one complete <think>...</think> span. DOTALL lets the span cross
# line breaks; IGNORECASE accepts any casing of the tags; the lazy quantifier
# stops at the first closing tag so separate blocks are matched separately.
RE_THINKING_BLOCK = re.compile(r'<think>.*?</think>', re.DOTALL | re.IGNORECASE)


def clean_model_output(text: str) -> str:
    """
    Strip thinking blocks from model output and trim the result.

    Args:
        text: Raw model output text (may be empty or None)

    Returns:
        The text with every <think>...</think> span deleted and
        leading/trailing whitespace removed; "" for falsy input
    """
    return RE_THINKING_BLOCK.sub('', text).strip() if text else ""


# Opening header of the reasoning section: the word "Reasoning" at the very
# start of the text, right after a newline, or right after a markdown asterisk,
# followed by optional colon/asterisk decoration (e.g. "Reasoning:",
# "**Reasoning:**"). Case-insensitive.
_RE_REASONING_HEADER = re.compile(r"(?:^|\n|\*)\s*Reasoning\s*[:*]*\s*", re.IGNORECASE)

# Closing marker of the reasoning section: the word "Answer" at the start of
# the text or of a line, immediately followed by a colon or whitespace.
# A mid-sentence "To answer the question..." is not at a line start, so it is
# ignored. Case-insensitive.
_RE_ANSWER_HEADER = re.compile(r"(?:\n|^)\s*Answer\s*[:\s]", re.IGNORECASE)


def extract_reasoning(text: str) -> str:
    """
    Extracts the reasoning section of a model response.

    The section starts right after the first 'Reasoning' header and ends
    just before the first line that starts with 'Answer' followed by a colon
    or whitespace. If no 'Reasoning' header is present, the section is taken
    to start at the beginning of the text instead of raising. If no 'Answer'
    line is present, the section runs to the end of the text.

    Args:
        text: Full model response text (may be empty or None)

    Returns:
        The reasoning text with surrounding whitespace removed, or "" for
        falsy input
    """
    if not text:
        return ""

    # Start after the header when there is one; otherwise keep the whole text
    # (the model may have omitted the label).
    start_match = _RE_REASONING_HEADER.search(text)
    content = text[start_match.end():] if start_match else text

    # Cut everything from the answer marker onwards, if one exists.
    end_match = _RE_ANSWER_HEADER.search(content)
    if end_match:
        content = content[:end_match.start()]

    return content.strip()


# "Answer:" (any casing), optional whitespace, then everything up to the next
# newline or the end of the text.
_RE_ANSWER_LINE = re.compile(r"Answer:\s*(.*?)(?:\n|$)", re.DOTALL | re.IGNORECASE)


def extract_answer(text: str) -> str:
    """
    Extracts the text that follows the first 'Answer:' marker.

    Only the remainder of the line is returned (whitespace, including line
    breaks, directly after the marker is skipped first).

    Args:
        text: Full model response text (may be empty or None)

    Returns:
        The answer text with surrounding whitespace (including a trailing
        carriage return) removed, or "" if the input is falsy or has no
        'Answer:' marker
    """
    if not text:
        return ""

    # The pattern always succeeds when an "answer:" marker is present (the
    # capture may be empty), so no secondary fallback search is required.
    match = _RE_ANSWER_LINE.search(text)
    return match.group(1).strip() if match else ""


# ==========================================
# CITATION EXTRACTION FUNCTIONS
# ==========================================

def parse_timestamp_str(time_str: str) -> float:
    """
    Helper to convert an 'MM:SS' or 'HH:MM:SS' string to seconds (float).

    Args:
        time_str: Colon-separated timestamp; surrounding whitespace is ignored

    Returns:
        Time in seconds, or -1.0 if the input is not a string, does not have
        two or three colon-separated fields, or a field is not numeric
    """
    # Non-string input (e.g. None) cannot be split; report it with the same
    # sentinel as any other unparseable value instead of raising.
    if not isinstance(time_str, str):
        return -1.0

    parts = time_str.strip().split(':')
    try:
        if len(parts) == 2:
            minutes, seconds = map(float, parts)
            return minutes * 60 + seconds
        if len(parts) == 3:
            hours, minutes, seconds = map(float, parts)
            return hours * 3600 + minutes * 60 + seconds
    except ValueError:
        # A field that float() cannot parse.
        return -1.0
    return -1.0

# Any "( ... )" group on a single line; the inner text is captured so it can
# be split into individual citation segments.
_RE_PARENTHETICAL = re.compile(r'\((.*?)\)')

# One citation segment, matched against the whole (stripped) segment:
#   ([a-zA-Z]+)              modality, letters only
#   \s*,\s*                  comma separator with optional spaces
#   (\d+:\d+(?:-\d+:\d+)?)   a single "digits:digits" time or a
#                            "digits:digits-digits:digits" range
# Anything else, such as a comma-separated list of times, fails to match.
_RE_CITATION_SEGMENT = re.compile(r'^\s*([a-zA-Z]+)\s*,\s*(\d+:\d+(?:-\d+:\d+)?)\s*$')


def extract_citations_from_sentence(sentence: str) -> List[Dict[str, Any]]:
    """
    Extracts citation patterns from a sentence, handling multiple citations
    separated by semicolons and skipping segments in an invalid format.

    Valid formats:
    - (visual, 0:20)
    - (audio, 1:20-1:30)
    - (visual, 0:02-0:04; visual, 0:16-0:19)

    Invalid/ignored formats:
    - (visual, 0:01, 0:15, 0:26) -> contains a comma list of times

    Args:
        sentence: Sentence text to scan (may be empty or None)

    Returns:
        One dict per valid citation segment, in order of appearance, with keys:
        'raw' (the full parenthetical the segment came from),
        'citation_segment' (the individual segment wrapped in parentheses),
        'modality' (lower-cased), 'start_time' and 'end_time' (seconds as
        returned by parse_timestamp_str; equal when no range is given).
        Empty list for falsy input or when nothing valid is found.
    """
    citations: List[Dict[str, Any]] = []
    if not sentence:
        return citations

    for parenthetical in _RE_PARENTHETICAL.finditer(sentence):
        raw_full = parenthetical.group(0)  # The full "( ... )" string

        # A single parenthetical may bundle several citations with ';'.
        for segment in parenthetical.group(1).split(';'):
            segment = segment.strip()
            seg_match = _RE_CITATION_SEGMENT.match(segment)
            if not seg_match:
                continue  # Not a citation (or a malformed one): skip it.

            # "start-end" range or a single time; a single time yields an
            # empty separator and the end time mirrors the start time.
            start_str, dash, end_str = seg_match.group(2).partition('-')
            start_time = parse_timestamp_str(start_str)
            end_time = parse_timestamp_str(end_str) if dash else start_time

            citations.append({
                'raw': raw_full,  # Keeps the original full context
                'citation_segment': f"({segment})",  # The specific parsed segment
                'modality': seg_match.group(1).lower(),
                'start_time': start_time,
                'end_time': end_time
            })

    return citations


# Tokens that matter for splitting, tried in this order at each position:
#   cit      - a non-empty parenthetical "( ... )"; consumed whole, so
#              punctuation inside it is never seen as a sentence end
#   ellipsis - two or more consecutive dots; treated as content
#   punct    - a run of "!"/"?", or a "." that is not sitting between two
#              digits (so "5.50" is kept intact, while "1. item" splits)
_RE_SENTENCE_TOKENS = re.compile(
    r'(?P<cit>\([^)]+\))|(?P<ellipsis>\.{2,})|(?P<punct>(?:[!?]+|\.(?!\d)|(?<!\d)\.))'
)

# After sentence punctuation: optional whitespace/quotes/markdown markers and
# then an opening parenthesis, i.e. a citation trailing the sentence.
_RE_CITATION_FOLLOWS = re.compile(r'[\s\'"*_]*\(')


def split_text_into_sentences(text: str) -> List[str]:
    """
    Splits text into sentences but keeps 'floating citations' attached to the
    preceding sentence. Handles ellipses (...) without splitting.

    Logic:
    - Does NOT split on decimals (e.g., "5.50").
    - DOES split on numbered lists (e.g., "1. First item").
    - When sentence punctuation is followed by a parenthetical, the split is
      deferred until the end of that parenthetical.
    - Drops empty sentences resulting from extra whitespace.

    Args:
        text: Text to split (may be empty or None)

    Returns:
        Sentences in order, each stripped of surrounding whitespace; any text
        left after the last split point is returned as a final sentence
    """
    if not text:
        return []

    sentences: List[str] = []
    current_pos = 0                # Start offset of the sentence being built
    pending_punct_split = False    # Punctuation seen; waiting for its citation

    for match in _RE_SENTENCE_TOKENS.finditer(text):
        token_kind = match.lastgroup

        if token_kind == 'ellipsis':
            # Ellipsis is content, not a boundary; it also cancels any
            # deferred split.
            pending_punct_split = False
            continue

        if token_kind == 'punct':
            # Anchored match at the current offset: avoids copying the rest
            # of the text for every punctuation mark.
            if _RE_CITATION_FOLLOWS.match(text, match.end()):
                pending_punct_split = True
                continue
        elif not pending_punct_split:
            # Citation in the middle of a sentence: nothing to do.
            continue

        # Either plain punctuation or the citation that closes a deferred
        # split: the sentence ends at the end of this token.
        candidate = text[current_pos:match.end()].strip()
        if candidate:
            sentences.append(candidate)
        current_pos = match.end()
        pending_punct_split = False

    # Flush whatever follows the last split point.
    remainder = text[current_pos:].strip()
    if remainder:
        sentences.append(remainder)

    return sentences


def parse_llm_list(text: str) -> List[str]:
    """
    Collects the bulleted lines of a newline-separated text.

    A line counts as a bullet when, after trimming surrounding whitespace,
    it starts with "-".

    Args:
        text: Text containing bulleted list items (may be empty or None)

    Returns:
        The bullet lines in order, trimmed of surrounding whitespace. The
        leading "-" is retained in each item. Empty list for falsy input.
    """
    if not text:
        return []

    bullets = []
    for raw_line in text.split('\n'):
        trimmed = raw_line.strip()
        if trimmed.startswith("-"):
            bullets.append(trimmed)
    return bullets


# ==========================================
# FILE AND DATA LOADING FUNCTIONS
# ==========================================

def load_prompt(path: str) -> str:
    """
    Reads a prompt template from disk.

    Args:
        path: Path to the prompt file

    Returns:
        The file's full contents decoded as UTF-8, or an empty string (after
        logging a warning) when the path does not exist
    """
    if os.path.exists(path):
        with open(path, "r", encoding='utf-8') as prompt_file:
            return prompt_file.read()

    logger.warning(f"Prompt file not found: {path}")
    return ""


def load_metadata(ds_name: str) -> Dict[str, Dict[str, str]]:
    """
    Builds a per-video metadata lookup for a configured dataset.

    Args:
        ds_name: Dataset name (must be a key of DATASET_CONFIGS)

    Returns:
        Dictionary keyed by each row's "video" value. Every entry holds
        "question" (the row's question, '' when absent) and "path" (the
        dataset's video_dir joined with "<video>.mp4"). An empty dictionary
        is returned, with an error logged, when the dataset name is unknown,
        the datasets library is not importable, or loading/iterating fails.
    """
    config = DATASET_CONFIGS.get(ds_name)
    if not config:
        logger.error(f"Unknown dataset: {ds_name}")
        return {}

    # load_from_disk is None when the optional `datasets` import failed.
    if load_from_disk is None:
        logger.error("datasets library not available")
        return {}

    metadata = {}
    try:
        for row in load_from_disk(config["hf_path"]):
            video_id = row["video"]
            metadata[video_id] = {
                "question": row.get('question', ''),
                "path": os.path.join(config["video_dir"], f"{video_id}.mp4"),
            }
    except Exception as e:
        # Any failure discards partial results and yields an empty mapping.
        logger.error(f"Error loading metadata for {ds_name}: {e}")
        return {}
    return metadata


def get_video_path(ds_name: str, video_id: str) -> Optional[str]:
    """
    Builds the path of a video file for a configured dataset.

    The path is composed only; the file is not checked for existence.

    Args:
        ds_name: Dataset name (a key of DATASET_CONFIGS)
        video_id: Video identifier (file name without the ".mp4" extension)

    Returns:
        The dataset's video_dir joined with "<video_id>.mp4", or None if the
        dataset name is not configured
    """
    config = DATASET_CONFIGS.get(ds_name)
    return os.path.join(config["video_dir"], f"{video_id}.mp4") if config else None


# ==========================================
# TIMESTAMP PARSING FUNCTIONS
# ==========================================

def parse_timestamp(time_str: Any) -> float:
    """
    Parses a timestamp string into seconds (float).

    Supports formats:
    - Number (int/float) or single numeric string: "5" -> 5.0
    - MM:SS: "1:30" -> 90.0
    - HH:MM:SS: "1:05:30" -> 3930.0

    Args:
        time_str: Timestamp string or number; other values are converted
            with str() before parsing

    Returns:
        Time in seconds, or -1.0 if parsing fails (non-numeric field, or more
        than three colon-separated fields)
    """
    try:
        if isinstance(time_str, (int, float)):
            return float(time_str)
        parts = [float(p) for p in str(time_str).strip().split(':')]
    except (TypeError, ValueError, OverflowError):
        # Non-numeric field, or an int too large to represent as a float.
        return -1.0

    if len(parts) == 1:
        return parts[0]
    if len(parts) == 2:
        return parts[0] * 60 + parts[1]
    if len(parts) == 3:
        return parts[0] * 3600 + parts[1] * 60 + parts[2]
    # Four or more fields is not a recognised format; do not silently
    # interpret the first three.
    return -1.0


# ==========================================
# FILE UTILITY FUNCTIONS
# ==========================================

def is_valid_file(filepath: str) -> bool:
    """
    Returns True if the path exists and has content.

    The size is queried directly rather than after a separate existence
    check, so a file that disappears between two calls cannot raise here.

    Args:
        filepath: Path to check

    Returns:
        True if the path exists and its size is > 0; False if it is missing,
        inaccessible, or empty
    """
    try:
        return os.path.getsize(filepath) > 0
    except (OSError, ValueError):
        # OSError: missing or inaccessible path. ValueError: a path the OS
        # layer rejects (e.g. embedded null character). Both mean "not valid".
        return False
