# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
YAML-based Judge Configuration Manager

This module provides a flexible configuration system for Judge's Verdict,
supporting multiple serving frameworks and configuration options via YAML files.
"""

import os
import yaml
from pathlib import Path
from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass, field, asdict
from enum import Enum


class ServingFramework(Enum):
    """Serving back-ends through which a judge model can be reached.

    Each member's value is the lowercase string used in the ``framework``
    field of a model entry in the YAML configuration.
    """

    # LiteLLM client; the provider is picked from the model-name prefix.
    LITELLM = "litellm"
    # OneAPI proxy; may use a named ``deployment``.
    ONEAPI = "oneapi"
    # NVIDIA endpoints (LangChain ``ChatNVIDIA``).
    NVDEV = "nvdev"
    # OpenAI-compatible endpoint (LangChain ``ChatOpenAI``).
    OPENAI = "openai"
    # User-defined; no default API key or LLM builder in this module.
    CUSTOM = "custom"


@dataclass
class JudgeModelConfig:
    """Configuration for a single judge model.

    ``framework`` may be given either as a ``ServingFramework`` member or as
    its string value (case-insensitive, surrounding whitespace ignored). It
    is normalized to the enum in ``__post_init__``.
    """
    name: str
    framework: ServingFramework
    model: str  # Model identifier/name for the API
    base_url: Optional[str] = None  # Explicit endpoint; overrides any derived URL
    api_key: Optional[str] = None  # Direct API key value (for local models)
    api_key_env: Optional[str] = None  # Environment variable name for API key
    api_version: Optional[str] = None
    deployment: Optional[str] = None  # For OneAPI deployments
    temperature: float = 0.0
    max_tokens: int = 8
    num_workers: int = 16
    timeout: int = 60
    max_retries: int = 3
    reasoning_mode: Optional[str] = None  # For models with reasoning modes
    custom_headers: Dict[str, str] = field(default_factory=dict)
    additional_params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Normalize ``framework`` to a ``ServingFramework`` member.

        Raises:
            ValueError: if ``framework`` is a string that names no supported
                framework, or is neither a string nor a ``ServingFramework``
                (e.g. ``None`` coming from an empty YAML field).
        """
        if isinstance(self.framework, str):
            try:
                self.framework = ServingFramework(self.framework.strip().lower())
            except ValueError:
                raise ValueError(f"Unsupported framework: {self.framework}") from None
        elif not isinstance(self.framework, ServingFramework):
            # Without this check a bad value survives construction and only
            # fails later, when ``.value`` is accessed.
            raise ValueError(f"Unsupported framework: {self.framework!r}")

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain-dict representation.

        The result is suitable for YAML dumping: ``framework`` is stored as
        its string value, and nested containers are copies made by ``asdict``.
        """
        result = asdict(self)
        result['framework'] = self.framework.value
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'JudgeModelConfig':
        """Create an instance from a dictionary of field values.

        Raises:
            TypeError: if ``data`` has keys that are not fields, or lacks a
                required field (``name``, ``framework``, ``model``).
            ValueError: propagated from ``__post_init__`` for a bad framework.
        """
        return cls(**data)

    def get_api_key(self) -> Optional[str]:
        """Resolve the API key for this model.

        Resolution order:

        1. ``api_key`` when it is non-empty (useful for local models).
        2. ``api_key_env`` when it is set: the value of that environment
           variable, which may be None. No further fallback is attempted.
        3. For LiteLLM: ``"EMPTY"`` for ``local/`` models. Otherwise a
           provider variable is chosen by model-name prefix:
           ``nvidia_nim/`` -> NVIDIA_NIM_API_KEY,
           ``anthropic/`` -> ANTHROPIC_API_KEY,
           anything else -> OPENAI_API_KEY.
        4. A per-framework default variable for NVDEV, OPENAI and ONEAPI.

        Returns:
            The key, or None if nothing applies (e.g. the CUSTOM framework).
        """
        if self.api_key:
            return self.api_key

        if self.api_key_env:
            return os.getenv(self.api_key_env)

        if self.framework == ServingFramework.LITELLM:
            # Local models do not need a real key, but clients expect a value.
            if self.model.startswith('local/'):
                return "EMPTY"
            provider_env_vars = (
                ('nvidia_nim/', "NVIDIA_NIM_API_KEY"),
                ('anthropic/', "ANTHROPIC_API_KEY"),
            )
            for prefix, env_var in provider_env_vars:
                if self.model.startswith(prefix):
                    return os.getenv(env_var)
            # 'openai/' models and unrecognized providers use OpenAI's key.
            return os.getenv("OPENAI_API_KEY")

        default_env_vars = {
            ServingFramework.NVDEV: "NVIDIA_API_KEY",
            ServingFramework.OPENAI: "OPENAI_API_KEY",
            ServingFramework.ONEAPI: "ONE_API_KEY",
        }
        env_var = default_env_vars.get(self.framework)
        return os.getenv(env_var) if env_var else None

    def get_effective_base_url(self) -> Optional[str]:
        """Return the base URL to use for this model, or None.

        An explicit ``base_url`` always wins. Otherwise:

        - LiteLLM ``nvidia_nim/`` models use ``NVIDIA_NIM_API_BASE`` if it is set.
        - OneAPI models with a ``deployment`` get the llm-proxy deployment URL.
        """
        if self.base_url:
            return self.base_url

        if self.framework == ServingFramework.LITELLM and self.model.startswith('nvidia_nim/'):
            nvidia_base = os.getenv("NVIDIA_NIM_API_BASE")
            if nvidia_base:
                return nvidia_base

        if self.framework == ServingFramework.ONEAPI and self.deployment:
            return f"https://llm-proxy.perflab.nvidia.com/openai/deployments/{self.deployment}"

        return None





class JudgeConfigManager:
    """Manager for loading and accessing judge configurations.

    Holds the YAML ``defaults`` mapping and one ``JudgeModelConfig`` per
    model name. It can also build LangChain chat models for the NVDEV,
    OPENAI and ONEAPI frameworks.
    """

    def __init__(self, config_path: Optional[Union[str, Path]] = None):
        """
        Initialize the configuration manager.

        Args:
            config_path: Path to the YAML configuration file.
                        If None, looks for 'judge_config_litellm.yaml' in standard locations.

        The file is loaded immediately if it exists. A missing file leaves the
        manager empty instead of raising.
        """
        self.config_path = self._resolve_config_path(config_path)
        self.models: Dict[str, JudgeModelConfig] = {}
        self.defaults: Dict[str, Any] = {}

        if self.config_path and self.config_path.exists():
            self.load_config()

    def _resolve_config_path(self, config_path: Optional[Union[str, Path]]) -> Optional[Path]:
        """Resolve the configuration file path.

        An explicit path is returned as-is (even if it does not exist).
        Otherwise the first existing file among the standard locations is
        returned, or None if there is none.
        """
        if config_path:
            return Path(config_path)

        # Look for config in standard locations (prioritize config/ folder)
        search_paths = [
            Path.cwd() / "config" / "judge_config_litellm.yaml",
            Path(__file__).parent.parent / "config" / "judge_config_litellm.yaml",
            Path.cwd() / "judge_config_litellm.yaml",
            Path(__file__).parent / "judge_config_litellm.yaml",
        ]

        for path in search_paths:
            if path.exists():
                return path

        return None

    @staticmethod
    def _mapping_or_empty(value: Any, what: str) -> Dict[str, Any]:
        """Return *value* if it is a dict, ``{}`` if it is None, otherwise raise ValueError."""
        if value is None:
            return {}
        if not isinstance(value, dict):
            raise ValueError(f"Expected a mapping for {what}, got {type(value).__name__}")
        return value

    def load_config(self, config_path: Optional[Union[str, Path]] = None):
        """
        Load configuration from YAML file.

        Replaces ``self.defaults`` with the file's ``defaults`` section. Each
        entry of the ``models`` section is added to ``self.models``, replacing
        any entry with the same name.

        Every model is filled in with the default keys it does not set itself.
        Values are deep-copied per model, so models never share mutable
        objects (such as ``custom_headers`` dicts) with each other or with
        ``self.defaults``.

        An empty file, or an empty ``defaults``/``models`` section, is treated
        as an empty mapping.

        Args:
            config_path: Path to YAML file. If None, uses the instance's config_path.

        Raises:
            FileNotFoundError: if no configuration file exists at the path.
            ValueError: if the document, its ``defaults``/``models`` section,
                or a model entry is present but is not a mapping.
        """
        import copy

        if config_path:
            self.config_path = Path(config_path)

        if not self.config_path or not self.config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")

        with open(self.config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            raw_config = yaml.safe_load(f)

        config = self._mapping_or_empty(raw_config, "configuration file")
        defaults = self._mapping_or_empty(config.get('defaults'), "'defaults' section")
        models_config = self._mapping_or_empty(config.get('models'), "'models' section")

        loaded: Dict[str, JudgeModelConfig] = {}
        for model_name, model_data in models_config.items():
            model_data = self._mapping_or_empty(model_data, f"model entry '{model_name}'")
            # Model-specific keys (including explicit nulls) win over defaults.
            # Deep-copy so no mutable value is shared between models or with
            # self.defaults (YAML anchors/aliases would share objects too).
            merged = copy.deepcopy({**defaults, **model_data})
            merged['name'] = model_name
            loaded[model_name] = JudgeModelConfig.from_dict(merged)

        # Commit only after every entry parsed, so a bad entry leaves the
        # previously loaded state untouched.
        self.defaults = defaults
        self.models.update(loaded)

    def get_model(self, model_name: str) -> Optional[JudgeModelConfig]:
        """Get configuration for a specific model, or None if it is unknown."""
        return self.models.get(model_name)

    def get_models_by_framework(self, framework: Union[str, ServingFramework]) -> List[JudgeModelConfig]:
        """Get all models using a specific framework.

        A string is matched case-insensitively. An unknown framework string
        yields an empty list instead of raising.
        """
        if isinstance(framework, str):
            try:
                framework = ServingFramework(framework.lower())
            except ValueError:
                return []

        return [model for model in self.models.values() if model.framework == framework]

    def list_models(self) -> List[str]:
        """List all available model names, in insertion order."""
        return list(self.models.keys())

    def add_model(self, model_config: JudgeModelConfig):
        """Add or update a model configuration, keyed by ``model_config.name``."""
        self.models[model_config.name] = model_config

    def save_config(self, config_path: Optional[Union[str, Path]] = None):
        """
        Save current configuration to YAML file.

        Args:
            config_path: Path to save the configuration. If None, uses the instance's config_path.

        Raises:
            ValueError: if no path is given and the instance has none.
        """
        save_path = Path(config_path) if config_path else self.config_path
        if not save_path:
            raise ValueError("No configuration path specified")

        config = {
            'defaults': self.defaults,
            'models': {name: model.to_dict() for name, model in self.models.items()},
        }

        save_path.parent.mkdir(parents=True, exist_ok=True)
        with open(save_path, 'w', encoding='utf-8') as f:
            yaml.dump(config, f, default_flow_style=False, sort_keys=False)

    def create_langchain_llm(self, model_name: str):
        """
        Create a LangChain LLM instance for a given model.

        Args:
            model_name: Name of the model to create LLM for.

        Returns:
            LangChain LLM instance configured according to the model settings.

        Raises:
            ValueError: if the model is unknown, its API key cannot be
                resolved, or its framework has no LangChain builder here.
            NotImplementedError: for LiteLLM models, which are wrapped
                elsewhere.
        """
        model_config = self.get_model(model_name)
        if not model_config:
            raise ValueError(f"Model not found: {model_name}")

        api_key = model_config.get_api_key()

        if model_config.framework == ServingFramework.NVDEV:
            # Imported lazily so the dependency is only needed when used.
            from langchain_nvidia_ai_endpoints.chat_models import ChatNVIDIA

            if not api_key:
                raise ValueError(f"API key not found for NVDEV model: {model_name}")

            return ChatNVIDIA(
                model=model_config.model,
                temperature=model_config.temperature,
                max_tokens=model_config.max_tokens,
                api_key=api_key,
                **model_config.additional_params
            )

        elif model_config.framework in [ServingFramework.OPENAI, ServingFramework.ONEAPI]:
            from langchain_openai import ChatOpenAI

            if not api_key:
                raise ValueError(f"API key not found for {model_config.framework.value} model: {model_name}")

            kwargs = {
                'model': model_config.model,
                'max_tokens': model_config.max_tokens,
                'api_key': api_key,
            }

            # Add temperature only if not an o1/o3 model (they don't support it).
            # NOTE(review): this is a substring test, so any model name that
            # merely contains "o1"/"o3" is treated the same way.
            if not any(x in model_config.model for x in ['o1', 'o3']):
                kwargs['temperature'] = model_config.temperature

            base_url = model_config.get_effective_base_url()
            if base_url:
                kwargs['base_url'] = base_url

            # Copy so the model's configured headers are never mutated here.
            headers = dict(model_config.custom_headers)
            if model_config.api_version:
                headers['api-version'] = model_config.api_version

            if headers:
                kwargs['default_headers'] = headers

            # additional_params are applied last and may override the above.
            kwargs.update(model_config.additional_params)

            return ChatOpenAI(**kwargs)

        elif model_config.framework == ServingFramework.LITELLM:
            # For liteLLM, we don't create a LangChain LLM
            # The scoring script will create a direct RAGAS wrapper instead
            raise NotImplementedError(
                "LiteLLM models should not use create_langchain_llm(). "
                "Use create_ragas_litellm_wrapper() directly in the scoring script."
            )

        else:
            raise ValueError(f"Unsupported framework: {model_config.framework.value}")


def load_default_config() -> JudgeConfigManager:
    """Create a manager from the auto-discovered judge configuration file.

    ``JudgeConfigManager()`` on its own stays empty when no file is found.
    This function instead treats a missing configuration as an error.

    Returns:
        A ``JudgeConfigManager`` whose configuration file was located.

    Raises:
        FileNotFoundError: if no configuration file exists in any of the
            standard search locations.
    """
    discovered = JudgeConfigManager()
    located = discovered.config_path
    if located and located.exists():
        return discovered
    raise FileNotFoundError(
        "No config/judge_config_litellm.yaml found. Please create one in the config/ folder.\n"
        "You can use one of the example configurations:\n"
        "  - config/judge_config_litellm.yaml (main LiteLLM configuration)\n"
        "  - config/judge_config_litellm_example.yaml (minimal example)\n"
        "  - config/judge_config_internal.yaml (internal configuration - for internal use only)"
    )
