import re
import ast
import os
import yaml
import importlib
from pathlib import Path

from .common.config_loader import _resolve_config_path


def process_answer(answer):
    """Run string answers through model-config-driven extraction.

    Anything that is not a ``str`` (already-parsed numbers, ``None``, lists,
    ...) is passed back untouched.
    """
    if isinstance(answer, str):
        return extract_answer_with_model_config(answer)
    return answer


# Extraction callables already resolved, keyed by model-config name. Lets the
# YAML config be parsed and the extractor looked up once per process instead of
# once per answer. Failed resolutions are NOT stored, so they are retried (and
# warned about) on every call, as before.
_EXTRACTION_FN_CACHE = {}


def extract_answer_with_model_config(answer):
    """
    Extract answer using model-specific configuration.

    The config name is read from the ``LM_EVAL_MODEL_CONFIG`` environment
    variable on every call (``'default'`` when unset). The extraction function
    for that name is resolved via ``load_model_config`` and
    ``get_extraction_function``, then cached for that name.

    Falls back to ``extract_general_answer`` (after printing a warning) when the
    config cannot be loaded, its extraction function cannot be resolved, or the
    extraction function itself raises.

    Args:
        answer: Raw model output to extract from.

    Returns:
        Whatever the selected extraction function returns.
    """
    model_config_name = os.environ.get('LM_EVAL_MODEL_CONFIG', 'default')

    try:
        extraction_fn = _EXTRACTION_FN_CACHE.get(model_config_name)
        if extraction_fn is None:
            config = load_model_config(model_config_name)
            extraction_fn = get_extraction_function(config)
            _EXTRACTION_FN_CACHE[model_config_name] = extraction_fn

        return extraction_fn(answer)

    except Exception as e:
        # Deliberately broad: extraction is best-effort, and any failure
        # (bad config, missing module, buggy extractor) should degrade to the
        # generic extractor rather than abort the evaluation run.
        print(f"Warning: Failed to use model config '{model_config_name}': {e}")
        print("Falling back to default extraction")
        return extract_general_answer(answer)


def load_model_config(model_name):
    """
    Load a model-specific YAML config, or the bundled default.

    The path is resolved by ``_resolve_config_path`` (from
    ``.common.config_loader``). If that raises ``FileNotFoundError``,
    ``model_configs/default.yaml`` next to this module is used instead.

    Args:
        model_name: Config name to resolve.

    Returns:
        The document parsed with ``yaml.safe_load``. An empty file yields ``{}``
        instead of ``None``, so callers can use ``.get`` on it.
    """
    try:
        config_path = _resolve_config_path(model_name)
    except FileNotFoundError:
        config_path = Path(__file__).parent / "model_configs" / "default.yaml"

    # Explicit encoding: the platform default is locale-dependent, so a config
    # containing non-ASCII text could otherwise fail to load on some systems.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty document.
    return {} if config is None else config


def get_extraction_function(config):
    """
    Resolve the answer-extraction callable named by a model config.

    ``config['extraction_function']`` is either ``"module.path:func_name"`` or
    a bare function name (legacy format). A bare name is looked up in
    ``lm_eval.tasks.chemsets.extractors``. A module path that cannot be
    imported as given is retried with an ``lm_eval.tasks.`` prefix, so
    ``"chemsets.extractors:fn"`` also resolves inside lm_eval.

    Args:
        config: Mapping loaded from the model's YAML config. ``None`` (an empty
            YAML file) is treated as an empty mapping. A missing, null or empty
            ``extraction_function`` selects
            ``chemsets.extractors:extract_general_answer``.

    Returns:
        The resolved callable.

    Raises:
        ImportError: neither the given nor the prefixed module could be imported.
        AttributeError: the module has no attribute with that name.
        TypeError: the attribute exists but is not callable.
    """
    default_path = 'chemsets.extractors:extract_general_answer'
    if config is None:
        config = {}
    func_path = config.get('extraction_function') or default_path

    if ':' in func_path:
        module_name, func_name = func_path.rsplit(':', 1)
    else:
        # Legacy format: bare function name inside the chemsets extractors.
        module_name = 'lm_eval.tasks.chemsets.extractors'
        func_name = func_path

    try:
        # Try importing as an absolute module path.
        module = importlib.import_module(module_name)
    except ImportError:
        # Retry relative to the lm_eval.tasks package.
        module = importlib.import_module(f'lm_eval.tasks.{module_name}')

    func = getattr(module, func_name)
    if not callable(func):
        raise TypeError(
            f"Extraction function '{func_path}' resolved to a non-callable "
            f"{type(func).__name__}"
        )
    return func


def extract_general_answer(response):
    r"""
    Simplified answer extraction for chemistry CoT traces and cleanup.

    0. If a ``</think>`` (or ``<|think_end|>``) marker is present, keep only the
       text after its last occurrence.
    1. Check for <answer> and </answer> tags and extract the content if available.
    2. Check for <|answer_start|> and <|answer_end|> tags and extract the content
       if available.
    3. Otherwise take the last bold (``**x**``), ``\boxed{x}`` or short
       parenthesised ``(x)`` answer.
    4. A candidate from steps 1-3 is cleaned with ``clean_extracted_answer`` and
       converted with ``convert_to_appropriate_type``. If there is no candidate,
       or it cleans down to an empty string, fall back to
       ``extract_fallback_patterns``:
       - content in quotes (both double and single),
       - numbers with units (e.g., "5 g", "10 mL"),
       - just numbers.

    Args:
        response: Raw model output string.

    Returns:
        int / float / Python literal / str from steps 1-3. Fallback results are
        returned as str (not type-converted), or None when nothing matched.
    """
    print(f"Processing response: {response}")

    # Handle Qwen3 thinking mode: drop the reasoning before the end-of-think marker.
    if "</think>" in response:
        content = response.split("</think>")[-1].strip()
    elif "<|think_end|>" in response:
        content = response.split("<|think_end|>")[-1].strip()
    else:
        content = response.strip()

    print(f"Extracting answer from content: {content}")
    answer = None

    if "<answer>" in content and "</answer>" in content:
        answer = content.split("<answer>")[-1].split("</answer>")[0].strip()
    elif "<|answer_start|>" in content and "<|answer_end|>" in content:
        answer = content.split("<|answer_start|>")[-1].split("<|answer_end|>")[0].strip()
    else:
        # Check for boxed, bold or parenthesised answers.
        answer = extract_last_formatted_answer(content)
    print(f"Extracted answer: {answer}")

    if answer is not None:
        cleaned = clean_extracted_answer(answer)
        # An answer that cleans to "" (e.g. "<answer>.</answer>") carries no
        # information; use the fallback patterns instead of returning "".
        if cleaned:
            return convert_to_appropriate_type(cleaned)

    # Fallback: try other patterns
    return extract_fallback_patterns(content)

def _iter_boxed_contents(content):
    r"""Yield ``(start, inner)`` for each ``\boxed{...}`` with balanced braces.

    Unlike a ``[^}]+`` regex, this keeps nested groups intact, e.g.
    ``\boxed{\frac{1}{2}}`` yields ``\frac{1}{2}``. Empty (``\boxed{}``) and
    unterminated boxes are skipped.
    """
    for match in re.finditer(r'\\boxed\{', content):
        depth = 1
        pos = match.end()
        while pos < len(content):
            char = content[pos]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    break
            pos += 1
        if depth == 0 and pos > match.end():
            yield match.start(), content[match.end():pos]


def extract_last_formatted_answer(content):
    r"""
    Extract the last formatted answer from the content.

    Looks for ``**content**``, ``\boxed{content}`` (with balanced nested braces)
    and ``(content)`` patterns, and takes whichever occurrence starts last.
    Parenthesised text only counts when, after stripping, it is 1-3 characters
    long and is not purely numeric (ignoring '-', '+', '.'). This is meant to
    catch multiple-choice letters while skipping numeric or math groups.

    Args:
        content: Text to search.

    Returns:
        The stripped last match, or None if no pattern matched.
    """
    # (start position, captured text) for every candidate of every kind.
    all_matches = []

    # Bold matches **content**
    for match in re.finditer(r'\*\*(.+?)\*\*', content):
        all_matches.append((match.start(), match.group(1)))

    # Boxed matches \boxed{content}, nested braces allowed
    all_matches.extend(_iter_boxed_contents(content))

    # Parentheses matches (content) - for multiple choice
    for match in re.finditer(r'\(([^)]+)\)', content):
        candidate = match.group(1).strip()
        # Skip empty "( )" so it cannot shadow a real earlier answer.
        if not candidate or len(candidate) > 3:
            continue
        if candidate.replace('-', '').replace('+', '').replace('.', '').isdigit():
            continue
        all_matches.append((match.start(), candidate))

    if not all_matches:
        return None

    # Start positions are unique (each kind begins with a different character),
    # so the maximum is the single last occurrence.
    return max(all_matches, key=lambda item: item[0])[1].strip()

def extract_fallback_patterns(content):
    """
    Fallback extraction methods, used when no tagged/formatted answer exists.

    Tries these patterns in order and returns the LAST match of the first one
    that hits:
      1. text in double quotes (passed through ``clean_extracted_answer``),
      2. text in single quotes (passed through ``clean_extracted_answer``),
      3. a number followed by a word, e.g. "5 g" / "10 mL", returned as
         "<number> <word>",
      4. a bare number.

    A number is a digit run with an optional decimal part, or a leading-dot
    decimal such as ".5". A lone "." (e.g. a sentence-ending period) is not a
    number.

    Args:
        content: Text to search.

    Returns:
        A string, or None when nothing matched.
    """
    # Content in quotes
    quoted = re.findall(r'"([^"]+)"', content)
    if quoted:
        return clean_extracted_answer(quoted[-1])

    single_quoted = re.findall(r"'([^']+)'", content)
    if single_quoted:
        return clean_extracted_answer(single_quoted[-1])

    # Numbers with units; the number must contain at least one digit.
    numbers_with_units = re.findall(
        r'([0-9]+(?:\.[0-9]+)?|\.[0-9]+)\s*([a-zA-Z]+)', content
    )
    if numbers_with_units:
        number, unit = numbers_with_units[-1]
        return f"{number} {unit}"

    # Just numbers
    numbers = re.findall(r'[0-9]+(?:\.[0-9]+)?|\.[0-9]+', content)
    if numbers:
        return numbers[-1]

    return None

def clean_extracted_answer(answer):
    r"""
    Clean extracted answers.

    Steps, in order:
      1. strip whitespace and surrounding ``$`` runs,
      2. unwrap ``\text{...}``, ``\mathrm{...}`` and ``\mathbf{...}``,
      3. drop a leading ``{``/``[`` and trailing ``}``/``]``, then trailing ``.,;:``,
      4. map Unicode sub/superscript digits to ASCII,
      5. remove a trailing chemistry unit from a numeric answer.

    Args:
        answer: Candidate answer string.

    Returns:
        None for an empty/falsy input, otherwise the cleaned (possibly empty)
        string.
    """
    if not answer:
        return None

    answer = answer.strip()

    # Remove math-mode dollar delimiters
    answer = re.sub(r'^\$+|\$+$', '', answer)

    # Unwrap LaTeX commands BEFORE stripping stray braces: otherwise the
    # trailing "}" of e.g. "\text{NaCl}" is removed first, leaving
    # "\text{NaCl" which the unwrapping patterns can no longer match.
    answer = re.sub(r'\\text\{([^}]+)\}', r'\1', answer)
    answer = re.sub(r'\\mathrm\{([^}]+)\}', r'\1', answer)
    answer = re.sub(r'\\mathbf\{([^}]+)\}', r'\1', answer)

    # Remove remaining wrapping artifacts and trailing punctuation
    answer = re.sub(r'^\{|\}$', '', answer)
    answer = re.sub(r'^\[|\]$', '', answer)
    answer = re.sub(r'[.,;:]+$', '', answer)

    # Normalize subscript and superscript digits
    answer = normalize_sub_super_scripts(answer)

    # Remove chemistry units from numerical answers
    answer = remove_chemistry_units(answer)

    return answer.strip()

# "<number><optional space><unit><optional trailing whitespace/punctuation>".
# Group 1 is the number (optional sign, decimal, optional exponent). Compiled
# once at import time instead of being rebuilt on every call. The trailing part
# is a flat character class: the previous nested "(?:\s*[.,;:]?)*" accepted the
# same strings but backtracked exponentially on long whitespace runs that are
# followed by other text.
_CHEM_UNIT_RE = re.compile(
    r'^([+-]?(?:\d+\.?\d*|\d*\.\d+)(?:[eE][+-]?\d+)?)\s*(?:'
    r'Ų|Å²|Å\^2|A²|A\^2|'
    r'g/mol|Da|kDa|amu|'
    r'logP|cLogP|'
    r'mol/L|mmol/L|'
    r'kJ/mol|kcal/mol|'
    r'°C|°F|'
    r'mmHg|'
    r'μM|μL|μg|μm|μs|μA|'
    r'uM|uL|ug|um|us|uA|'
    r'mM|mL|mg|mm|ms|mV|mA|'
    r'nM|nm|ns|'
    r'pM|Pa|'
    r'kPa|kDa|kJ|kg|km|kHz|'
    r'MPa|MHz|'
    r'GHz|THz|'
    r'eV|'
    r'atm|bar|torr|'
    r'cal|kcal|'
    r'min|hr|hours|'
    r'percent|units|'
    r'M|L|K|J|V|A|C|F|g|m|s|h|%'
    r')[\s.,;:]*$',
    re.IGNORECASE | re.UNICODE,
)


def remove_chemistry_units(answer):
    """
    Remove common chemistry units from numerical answers.

    The whole string must be a number followed by one listed unit, optionally
    followed by whitespace or ``.,;:``. Only then is the number text returned.
    Unit matching is case-insensitive. Anything else is returned unchanged.

    Args:
        answer: Cleaned answer string.

    Returns:
        The numeric part as a string, or ``answer`` unchanged (including falsy
        inputs) when no unit pattern applies.
    """
    if not answer:
        return answer

    match = _CHEM_UNIT_RE.search(answer)
    if match:
        return match.group(1)

    # Return original if no units found
    return answer

def convert_to_appropriate_type(answer):
    """
    Convert answer to appropriate type (int, float, Python literal, or string).

    Very lenient. The steps are:
      * ``int`` when the text contains no '.' (signs and underscores accepted),
      * otherwise, or if that fails, ``float`` (this also accepts scientific
        notation and 'inf'/'nan'),
      * then ``ast.literal_eval`` (lists, tuples, quoted strings, True/None, ...),
      * and finally the stripped string itself.

    Args:
        answer: Value to convert. It is stringified first.

    Returns:
        The converted value. Falsy inputs are returned unchanged.
    """
    if not answer:
        return answer

    answer_str = str(answer).strip()

    # Try integer first (including negative numbers)
    if '.' not in answer_str:
        try:
            return int(answer_str)
        except ValueError:
            pass

    # Try float (including negative numbers and scientific notation)
    try:
        return float(answer_str)
    except ValueError:
        pass

    # Try to evaluate as a Python literal (e.g., lists, tuples). These are the
    # exceptions literal_eval documents for malformed input; anything else
    # (e.g. KeyboardInterrupt) is no longer swallowed by a bare except.
    try:
        return ast.literal_eval(answer_str)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        pass

    # Return as string if conversion fails
    return answer_str

# Single translation table mapping Unicode subscript AND superscript digits and
# +/- signs to their ASCII equivalents. The two source alphabets are disjoint,
# so one pass is equivalent to translating subscripts and then superscripts.
_SUB_SUPER_TO_ASCII = str.maketrans(
    "₀₁₂₃₄₅₆₇₈₉₊₋" + "⁰¹²³⁴⁵⁶⁷⁸⁹⁺⁻",
    "0123456789+-" * 2,
)


def normalize_sub_super_scripts(text):
    """
    Return ``text`` with Unicode sub-/superscript digits and signs made ASCII.

    For example "SO₄²⁻" becomes "SO42-". All other characters are left as-is.
    """
    return text.translate(_SUB_SUPER_TO_ASCII)
