#!/usr/bin/env python3
"""
Shared config-based function loading for all chemistry RPC servers.
Loads the answer-extraction and problem-preprocessing functions named in a model config YAML, falling back to built-in defaults when they cannot be loaded.
"""
import os
from pathlib import Path
import yaml
import importlib
import logging

# Base directory for shared resources: the fourth ancestor directory of this
# file (parents[0] is the directory containing this file).
# NOTE(review): this depends on where the module sits in the repo layout —
# update the index if the file is moved.
BASE_DIR = Path(__file__).resolve().parents[3]
# Top-level directory searched for '<name>.yaml' by _resolve_config_path.
CONFIGS_DIR = BASE_DIR / 'configs'

logger = logging.getLogger(__name__)

def load_functions_from_config():
    """Load the extraction function and prompt processor named by the active model config.

    The config reference comes from the ``LM_EVAL_MODEL_CONFIG`` environment
    variable (default ``'default'``) and is resolved by ``_resolve_config_path``.
    The YAML may name an ``extraction_function`` and a ``preprocessing_function``
    as ``module:function`` specs, and may carry an ``inline_prompt``.

    Returns:
        tuple: ``(extraction_function, prompt_processor)``. If resolving, reading
        or interpreting the config fails, a warning is logged and
        ``(default_fallback_extraction, default_fallback_preprocessing)`` is
        returned instead, so callers always receive two callables.
    """
    model_config_ref = os.environ.get('LM_EVAL_MODEL_CONFIG', 'default')
    logger.info("Loading functions for model config: %s", model_config_ref)

    try:
        config_path = _resolve_config_path(model_config_ref)
        logger.debug("Resolved chemistry config to: %s", config_path)

        # Explicit encoding so non-ASCII inline prompts do not depend on the locale.
        with config_path.open('r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        # An empty YAML document parses to None; treat it as "no settings" so
        # every entry falls back to its default instead of failing obscurely.
        if config is None:
            config = {}
        if not isinstance(config, dict):
            raise TypeError(
                f"expected a mapping at the top level of {config_path}, "
                f"got {type(config).__name__}"
            )

        # Load extraction function
        extraction_function = load_function_from_spec(
            config.get('extraction_function'),
            'extraction_function',
            default_fallback_extraction,
        )

        # Load preprocessing function
        preprocessing_function = load_function_from_spec(
            config.get('preprocessing_function'),
            'preprocessing_function',
            default_fallback_preprocessing,
        )

        # Create prompt processor that handles both preprocessing and inline prompts
        prompt_processor = create_prompt_processor(
            config,
            preprocessing_function,
            config_source=str(config_path),
        )

        return extraction_function, prompt_processor

    except Exception as e:
        # Deliberate best-effort boundary: a broken config must not stop the
        # server, so degrade to the built-in fallbacks but keep the traceback
        # available at debug level.
        logger.warning("Failed to load functions from config '%s': %s", model_config_ref, e)
        logger.debug("Config load failure details", exc_info=True)
        return default_fallback_extraction, default_fallback_preprocessing


def _resolve_config_path(config_ref: str) -> Path:
    """Resolve a config reference to an on-disk YAML path.

    The reference can be:
        - An absolute or relative filesystem path to a YAML file (``~`` is
          expanded). If it does not end in ``.yaml``, the path itself is used
          when it is a regular file; otherwise its suffix is replaced by
          ``.yaml`` and that file is tried.
        - A bare config name (with or without a ``.yaml`` suffix, any case),
          looked up first as ``CONFIGS_DIR/<name>.yaml`` and then in the legacy
          ``model_configs`` directory that sits next to this file's directory.

    Raises:
        FileNotFoundError: ``config_ref`` points to an existing directory, or
            no matching YAML file was found.
    """
    yaml_suffix = '.yaml'
    candidate = Path(config_ref).expanduser()
    if candidate.suffix.lower() != yaml_suffix:
        if candidate.is_file():
            return candidate
        # with_suffix() raises ValueError when the final path component is
        # empty (e.g. '' or '/'), so only derive a '.yaml' sibling when possible.
        if candidate.name:
            candidate_yaml = candidate.with_suffix(yaml_suffix)
            # is_file(), not exists(): a directory named '<ref>.yaml' is not a config.
            if candidate_yaml.is_file():
                return candidate_yaml

    if candidate.exists():
        if candidate.is_dir():
            raise FileNotFoundError(f"Provided config reference '{config_ref}' points to a directory, expected a YAML file.")
        return candidate

    # Drop a single trailing '.yaml' (case-insensitively, matching the suffix
    # check above) instead of every '.yaml' occurrence inside the reference.
    if config_ref.lower().endswith(yaml_suffix):
        config_name = config_ref[:-len(yaml_suffix)]
    else:
        config_name = config_ref

    search_dirs = (
        CONFIGS_DIR,  # top-level configs directory
        Path(__file__).resolve().parent.parent / 'model_configs',  # legacy location
    )
    for directory in search_dirs:
        path = directory / f'{config_name}.yaml'
        if path.is_file():
            return path

    raise FileNotFoundError(
        f"Unable to locate chemistry model config for reference '{config_ref}'."
    )

def load_function_from_spec(func_spec, func_type, fallback_func):
    """Load a callable from a ``module:function`` specification.

    Args:
        func_spec: Spec such as ``'pkg.module:func'``; the part after the colon
            may be a dotted attribute path (``'pkg.module:Class.method'``).
            Surrounding whitespace is ignored, and a leading ``'chemsets.'`` is
            removed from the module path before importing. A falsy value
            selects ``fallback_func``.
        func_type: Label used only in log messages (e.g. ``'extraction_function'``).
        fallback_func: Returned when no spec is given or it cannot be loaded.

    Returns:
        The resolved callable, or ``fallback_func``. Load errors are logged as
        warnings and never raised.
    """
    if not func_spec:
        logger.info("No %s specified, using fallback", func_type)
        return fallback_func

    try:
        if not isinstance(func_spec, str):
            raise TypeError(
                f"expected a 'module:function' string, got {type(func_spec).__name__}"
            )
        module_name, sep, attr_path = func_spec.partition(':')
        module_name = module_name.strip()
        attr_path = attr_path.strip()
        if not sep or not module_name or not attr_path or ':' in attr_path:
            raise ValueError(f"expected 'module:function', got {func_spec!r}")

        # Configs may name modules with a 'chemsets.' package prefix, which is
        # stripped before importing (presumably the modules are importable
        # without it — confirm against the deployment's sys.path).
        module_name = module_name.removeprefix('chemsets.')

        target = importlib.import_module(module_name)
        # Walk a dotted attribute path so 'module:Class.method' also works.
        for attr in attr_path.split('.'):
            target = getattr(target, attr)
        if not callable(target):
            raise TypeError(f"'{module_name}:{attr_path}' is not callable")

        logger.info("Successfully loaded %s: %s from %s", func_type, attr_path, module_name)
        return target

    except Exception as e:
        # Best-effort by design: a bad spec degrades to the fallback instead
        # of aborting start-up.
        logger.warning("Failed to load %s '%s': %s", func_type, func_spec, e)
        logger.warning("Using fallback %s", func_type)
        return fallback_func

def create_prompt_processor(config, preprocessing_func, config_source: str | None = None):
    """Create a prompt processor that handles both preprocessing and inline prompts."""
    inline_prompt = config.get('inline_prompt', '').strip()
    disable_inline = os.environ.get('LM_EVAL_DISABLE_INLINE_PROMPT', '').lower() == 'true'
    inline_logged = False
    source_label = config_source or '<unknown-config>'
    
    def process_prompt(problem_text):
        nonlocal inline_logged
        # First apply model-specific preprocessing
        processed = preprocessing_func(problem_text)
        
        # Then apply inline prompt if configured and not disabled
        if inline_prompt and not disable_inline:
            if not inline_logged:
                logger.info(
                    "Applying inline prompt from %s (length=%d characters)",
                    source_label,
                    len(inline_prompt),
                )
                inline_logged = True
            # Apply inline prompt template
            processed = f"{inline_prompt}\n\n{processed}"
            
        return processed
    
    return process_prompt

def default_fallback_extraction(answer):
    """Return the text between the last answer-start marker and the next answer-end marker.

    Only strings containing both markers are parsed; the extracted span is
    whitespace-stripped. Any other value is returned as ``str(answer)``, and
    ``None`` becomes the empty string.
    """
    if answer is None:
        return ""
    start_tag, end_tag = '<|answer_start|>', '<|answer_end|>'
    has_markers = isinstance(answer, str) and start_tag in answer and end_tag in answer
    if not has_markers:
        return str(answer)
    after_last_start = answer.rsplit(start_tag, 1)[-1]
    body, _, _ = after_last_start.partition(end_tag)
    return body.strip()

def default_fallback_preprocessing(text):
    """Identity preprocessing used when no model-specific function is configured.

    The input object is handed back untouched (same object, no copying).
    """
    unchanged = text
    return unchanged

# Convenience functions for servers to use
def create_config_based_functions():
    """
    Build the answer/problem processing callables that RPC servers use.

    Both callables delegate to whatever ``load_functions_from_config`` returned
    at the time this function was called.

    Returns:
        tuple: ``(process_answer, process_problem)``.
    """
    extract_fn, prompt_fn = load_functions_from_config()

    def process_answer(answer):
        """Run *answer* through the config-selected extraction function."""
        return extract_fn(answer)

    def process_problem(problem):
        """Run *problem* through the config-selected preprocessing and inline prompt."""
        return prompt_fn(problem)

    processors = (process_answer, process_problem)
    return processors
