# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
YAML-based Judge Configuration Manager

This module provides a flexible configuration system for LLM judges,
supporting multiple serving frameworks and configuration options via YAML files.
"""

import copy
import os
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import yaml


class ServingFramework(Enum):
    """Serving backends through which a judge model can be reached.

    Each member's value is the lowercase string used for the ``framework``
    field of a judge model configuration.
    """

    LITELLM = "litellm"  # provider-specific API key is picked from the model prefix
    ONEAPI = "oneapi"  # OneAPI deployments
    NVDEV = "nvdev"
    OPENAI = "openai"
    CUSTOM = "custom"  # no built-in API-key or base-URL fallback


@dataclass
class JudgeModelConfig:
    """Configuration for a single judge model.

    ``framework`` may be given either as a :class:`ServingFramework` member or
    as its (case-insensitive) string value; it is normalized to the enum member
    in ``__post_init__``.
    """

    name: str
    framework: ServingFramework
    model: str  # Model identifier/name for the API
    base_url: Optional[str] = None  # Explicit endpoint; takes precedence over derived URLs
    api_key: Optional[str] = None  # Direct API key value (for local models)
    api_key_env: Optional[str] = None  # Environment variable name for API key
    api_version: Optional[str] = None
    deployment: Optional[str] = None  # For OneAPI deployments
    temperature: float = 0.0
    max_tokens: int = 8
    num_workers: int = 16
    timeout: int = 60
    max_retries: int = 3
    reasoning_mode: Optional[str] = None  # For models with reasoning modes
    custom_headers: Dict[str, str] = field(default_factory=dict)
    additional_params: Dict[str, Any] = field(default_factory=dict)
    is_closed: bool = False  # Whether this is a closed-source model (default: open source)

    def __post_init__(self):
        """Normalize ``framework`` to a :class:`ServingFramework` member.

        Raises:
            ValueError: If ``framework`` is a string that does not name a
                supported framework, or is neither a string nor a
                ``ServingFramework`` (e.g. ``None`` from an empty YAML value).
        """
        if isinstance(self.framework, ServingFramework):
            return

        if isinstance(self.framework, str):
            try:
                self.framework = ServingFramework(self.framework.lower())
            except ValueError:
                # Replace the enum's generic error with a clearer one; the
                # original exception adds nothing, so suppress the chain.
                raise ValueError(f"Unsupported framework: {self.framework}") from None
            return

        # Fail here rather than later in to_dict()/get_api_key() with an
        # obscure AttributeError.
        raise ValueError(f"Unsupported framework: {self.framework!r}")

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict copy of this config with ``framework`` as its string value."""
        result = asdict(self)
        result["framework"] = self.framework.value
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "JudgeModelConfig":
        """Create an instance from a mapping of field names to values.

        Raises:
            TypeError: If ``data`` contains keys that are not fields of this
                class, or lacks a required field.
        """
        return cls(**data)

    def get_api_key(self) -> Optional[str]:
        """Resolve the API key for this model.

        Resolution order:
          1. ``api_key``, if set to a non-empty value.
          2. The environment variable named by ``api_key_env``, if set. Its
             value is returned as-is, even when the variable is unset (``None``).
          3. For LiteLLM, a provider-specific environment variable chosen from
             the model prefix. ``local/`` models get the placeholder ``"EMPTY"``.
          4. A per-framework default environment variable for NVDEV, OpenAI and
             OneAPI.

        Returns:
            The API key, or ``None`` if no source applies or the chosen
            environment variable is unset.
        """
        # Explicit values win over any derived lookup.
        if self.api_key:
            return self.api_key

        if self.api_key_env:
            return os.getenv(self.api_key_env)

        if self.framework == ServingFramework.LITELLM:
            # Local models are given a placeholder instead of a real key.
            if self.model.startswith("local/"):
                return "EMPTY"
            litellm_provider_env = (
                ("nvidia_nim/", "NVIDIA_NIM_API_KEY"),
                ("anthropic/", "ANTHROPIC_API_KEY"),
            )
            for prefix, env_var in litellm_provider_env:
                if self.model.startswith(prefix):
                    return os.getenv(env_var)
            # "openai/" models and any unrecognized provider use the OpenAI key.
            return os.getenv("OPENAI_API_KEY")

        # Default environment variables for other frameworks (CUSTOM has none).
        default_env_vars = {
            ServingFramework.NVDEV: "NVIDIA_API_KEY",
            ServingFramework.OPENAI: "OPENAI_API_KEY",
            ServingFramework.ONEAPI: "ONE_API_KEY",
        }
        env_var = default_env_vars.get(self.framework)
        return os.getenv(env_var) if env_var else None

    def get_effective_base_url(self) -> Optional[str]:
        """Return the base URL to use for this model, or ``None`` if none applies.

        Order: an explicit ``base_url``. Then, for LiteLLM ``nvidia_nim/``
        models, the ``NVIDIA_NIM_API_BASE`` environment variable. Then, for
        OneAPI with a ``deployment``, a URL built from the deployment name.
        """
        if self.base_url:
            return self.base_url

        # For liteLLM with NVIDIA models, check for NVIDIA_NIM_API_BASE
        if self.framework == ServingFramework.LITELLM and self.model.startswith("nvidia_nim/"):
            nvidia_base = os.getenv("NVIDIA_NIM_API_BASE")
            if nvidia_base:
                return nvidia_base

        # For OneAPI with deployment, construct the URL
        if self.framework == ServingFramework.ONEAPI and self.deployment:
            return f"https://llm-proxy.perflab.nvidia.com/openai/deployments/{self.deployment}"

        return None


class JudgeConfigManager:
    """Manager for loading and accessing judge configurations.

    The YAML file may contain a ``defaults`` mapping and a ``models`` mapping
    from model name to :class:`JudgeModelConfig` fields. Each key in
    ``defaults`` is applied to every model that does not set it itself.
    """

    def __init__(self, config_path: Optional[Union[str, Path]] = None):
        """
        Initialize the configuration manager.

        Args:
            config_path: Path to the YAML configuration file.
                        If None, looks for 'judge_config.yaml' in standard locations.
                        If the resolved file does not exist, nothing is loaded
                        and the manager starts with no models.
        """
        self.config_path = self._resolve_config_path(config_path)
        self.models: Dict[str, JudgeModelConfig] = {}
        self.defaults: Dict[str, Any] = {}

        if self.config_path and self.config_path.exists():
            self.load_config()

    def _resolve_config_path(self, config_path: Optional[Union[str, Path]]) -> Optional[Path]:
        """Return ``config_path`` as a Path, or the first existing standard location.

        Returns ``None`` when no path was given and no standard location exists.
        """
        if config_path:
            return Path(config_path)

        # Look for config in standard locations (prioritize config/ folder)
        search_paths = [
            Path.cwd() / "config" / "judge_config.yaml",
            Path(__file__).parent.parent / "config" / "judge_config.yaml",
            Path.cwd() / "judge_config.yaml",
            Path(__file__).parent / "judge_config.yaml",
        ]

        for path in search_paths:
            if path.exists():
                return path

        return None

    def load_config(self, config_path: Optional[Union[str, Path]] = None):
        """
        Load configuration from YAML file.

        Models from the file are added to (and override same-named entries in)
        ``self.models``; ``self.defaults`` is replaced. Both are updated only
        after every model in the file has been parsed successfully.

        Args:
            config_path: Path to YAML file. If None, uses the instance's config_path.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file, its ``defaults`` section, or a model entry
                is not a mapping, or a model names an unsupported framework.
            TypeError: If a model entry contains unknown or missing fields.
        """
        if config_path:
            self.config_path = Path(config_path)

        if not self.config_path or not self.config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")

        with open(self.config_path, "r", encoding="utf-8") as f:
            # safe_load avoids constructing arbitrary Python objects. An empty
            # document parses to None, so treat it as an empty mapping.
            config = yaml.safe_load(f) or {}

        if not isinstance(config, dict):
            raise ValueError(
                f"Top level of {self.config_path} must be a mapping, got {type(config).__name__}"
            )

        # Keys present with no value (``defaults:``) parse to None.
        defaults = config.get("defaults") or {}
        if not isinstance(defaults, dict):
            raise ValueError(f"'defaults' must be a mapping, got {type(defaults).__name__}")
        models_config = config.get("models") or {}

        loaded: Dict[str, JudgeModelConfig] = {}
        for model_name, model_data in models_config.items():
            if model_data is None:
                model_data = {}
            elif not isinstance(model_data, dict):
                raise ValueError(
                    f"Configuration for model '{model_name}' must be a mapping, "
                    f"got {type(model_data).__name__}"
                )

            # Deep-copy the defaults per model so mutable values (e.g. a
            # custom_headers dict) are not shared between model configs.
            merged = copy.deepcopy(defaults)
            merged.update(model_data)
            merged["name"] = model_name
            loaded[model_name] = JudgeModelConfig.from_dict(merged)

        # Commit only after everything parsed, so a bad entry leaves state intact.
        self.defaults = defaults
        self.models.update(loaded)

    def get_model(self, model_name: str) -> Optional[JudgeModelConfig]:
        """Return the configuration for ``model_name``, or ``None`` if it is not loaded."""
        return self.models.get(model_name)

    def get_models_by_framework(self, framework: Union[str, ServingFramework]) -> List[JudgeModelConfig]:
        """Return all models using ``framework``.

        A string is matched case-insensitively. An unknown framework string
        yields an empty list rather than an error.
        """
        if isinstance(framework, str):
            try:
                framework = ServingFramework(framework.lower())
            except ValueError:
                return []

        return [model for model in self.models.values() if model.framework == framework]

    def list_models(self) -> List[str]:
        """List all available model names, in load order."""
        return list(self.models.keys())
