# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Direct RAGAS wrapper for liteLLM - bypasses LangChain entirely.

This wrapper implements the exact interface RAGAS needs, using liteLLM's
completion API directly without any LangChain involvement.
"""

import os
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from litellm import completion
import litellm


@dataclass
class LiteLLMResult:
    """
    Minimal result format that RAGAS expects.

    Plain data holder pairing a list of generations with an optional
    dictionary of extra LLM output.

    NOTE(review): nothing in the visible part of this module instantiates
    this class - confirm whether external callers still rely on it.
    """
    # Generated outputs; the element type is intentionally left open (Any).
    generations: List[Any]
    # Optional extra output reported alongside the generations; defaults to None.
    llm_output: Optional[Dict[str, Any]] = None


class RagasLiteLLMWrapper:
    """
    Direct RAGAS wrapper for liteLLM - no LangChain needed.

    Exposes ``generate``, ``agenerate_text``, ``agenerate_prompt`` and
    ``set_run_config``. Every request is sent as a single user message
    through liteLLM's completion API, and the response is returned as a
    lightweight result object with ``generations`` and ``llm_output``.
    """

    class _Generation:
        """One generated completion, shaped as a one-element sequence of itself.

        ``gen.text`` and ``gen[0].text`` both yield the generated text, so
        callers may treat the object either as a generation or as a list
        holding one generation.
        """

        def __init__(self, text):
            self.text = text

        def __getitem__(self, i):
            if i == 0:
                # Return self so that .text can be accessed on the result.
                return self
            # Raising (instead of returning None) is what terminates
            # iteration over an object that only defines __getitem__.
            raise IndexError(i)

        def __len__(self):
            return 1

    class _Result:
        """Container returned by the generation methods.

        ``result.generations`` is the list of generations (one per response
        choice); ``result[0]`` is an alias for that same list.
        """

        def __init__(self, generations, llm_output):
            self.generations = generations
            self.llm_output = llm_output

        def __getitem__(self, i):
            if i == 0:
                return self.generations
            # See _Generation.__getitem__: IndexError keeps iteration finite.
            raise IndexError(i)

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        max_tokens: int = 8,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        timeout: Optional[int] = 60,
        max_retries: Optional[int] = 3,
        **kwargs
    ):
        """
        Store the request defaults and configure liteLLM.

        Args:
            model: liteLLM model identifier. Substrings of the lower-cased
                name decide which provider environment variable receives
                ``api_key`` and whether ``temperature`` is sent.
            temperature: Default sampling temperature.
            max_tokens: Default completion token limit.
            api_key: Optional key. When truthy it is exported to the matching
                provider environment variable and sent with each request.
            api_base: Optional endpoint URL sent with each request.
            api_version: Optional API version sent with each request.
            timeout: Value passed as ``timeout`` on each request.
            max_retries: Value passed as ``num_retries`` on each request.
            **kwargs: Kept on ``self.additional_params``; the methods of this
                class do not forward them to liteLLM.
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.api_key = api_key
        self.api_base = api_base
        self.api_version = api_version
        self.timeout = timeout
        self.max_retries = max_retries
        self.additional_params = kwargs

        # Module-wide liteLLM settings: these affect every liteLLM user in
        # the process, not only this instance.
        litellm.drop_params = True
        litellm.set_verbose = False

        # Export the key under the provider-specific environment variable.
        if api_key:
            model_lower = model.lower()
            if "nvidia" in model_lower:
                os.environ["NVIDIA_NIM_API_KEY"] = api_key
            elif any(tag in model_lower for tag in ("gpt", "o1", "o3")):
                os.environ["OPENAI_API_KEY"] = api_key
            elif "claude" in model_lower or "anthropic" in model_lower:
                os.environ["ANTHROPIC_API_KEY"] = api_key

    @staticmethod
    def _prompt_to_text(prompt: Any) -> str:
        """Return the prompt text from a ``.text`` object, a ``{'text': ...}`` dict, or ``str(prompt)``."""
        if hasattr(prompt, "text"):
            return prompt.text
        if isinstance(prompt, dict) and "text" in prompt:
            return prompt["text"]
        return str(prompt)

    def _build_completion_kwargs(self, prompt: Any, n: int, overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble the keyword arguments for one liteLLM completion call.

        ``overrides`` are the per-call keyword arguments; only ``max_tokens``
        and ``temperature`` are read from it, everything else is ignored.
        """
        model_lower = self.model.lower()
        completion_kwargs = {
            "model": self.model,
            "messages": [{"role": "user", "content": self._prompt_to_text(prompt)}],
            "max_tokens": overrides.get("max_tokens", self.max_tokens),
            "n": n,
            "timeout": self.timeout,
            "num_retries": self.max_retries,
        }

        # Temperature is omitted for model names containing "o1" or "o3".
        if not any(tag in model_lower for tag in ("o1", "o3")):
            completion_kwargs["temperature"] = overrides.get("temperature", self.temperature)

        if self.api_base:
            completion_kwargs["api_base"] = self.api_base
        elif "nvidia" in model_lower:
            # For NVIDIA models, fall back to NVIDIA_NIM_API_BASE from the environment.
            nvidia_api_base = os.environ.get("NVIDIA_NIM_API_BASE")
            if nvidia_api_base:
                completion_kwargs["api_base"] = nvidia_api_base

        if self.api_key:
            # "EMPTY" (any casing) is the placeholder for local servers that
            # need no auth; it is normalised to upper case.
            is_placeholder = self.api_key.upper() == "EMPTY"
            completion_kwargs["api_key"] = "EMPTY" if is_placeholder else self.api_key
        if self.api_version:
            completion_kwargs["api_version"] = self.api_version

        return completion_kwargs

    @classmethod
    def _format_response(cls, response: Any) -> "RagasLiteLLMWrapper._Result":
        """Convert a completion response into a ``_Result`` with one generation per choice."""
        generations = [
            # A missing or None content becomes "" so that .text is always a string.
            cls._Generation(getattr(choice.message, "content", None) or "")
            for choice in response.choices
        ]
        # Token usage tracking removed, so llm_output is always empty.
        return cls._Result(generations, {})

    def generate(self, prompt: str, n: int = 1, **kwargs):
        """
        Run one synchronous liteLLM completion for ``prompt``.

        Args:
            prompt: Prompt text, an object with a ``text`` attribute, or a
                dict with a ``'text'`` key; anything else is passed through
                ``str()``.
            n: Number of completions requested from the model.
            **kwargs: ``max_tokens`` and ``temperature`` override the instance
                defaults for this call; other keys are ignored.

        Returns:
            A result object whose ``generations`` list holds one generation
            (with a ``text`` attribute) per returned choice and whose
            ``llm_output`` is an empty dict.
        """
        completion_kwargs = self._build_completion_kwargs(prompt, n, kwargs)
        return self._format_response(completion(**completion_kwargs))

    def set_run_config(self, run_config):
        """
        Set runtime configuration from Ragas.

        The configuration is only stored on ``self.run_config``; no liteLLM
        setting is derived from it.
        """
        self.run_config = run_config

    async def agenerate_text(self, prompt: str, n: int = 1, **kwargs):
        """
        Async counterpart of ``generate`` with the same arguments and result.

        Awaits liteLLM's ``acompletion`` so the event loop stays free while
        the request is in flight, instead of blocking on the sync call.
        """
        completion_kwargs = self._build_completion_kwargs(prompt, n, kwargs)
        response = await litellm.acompletion(**completion_kwargs)
        return self._format_response(response)

    async def agenerate_prompt(self, prompts: List[Any], **kwargs):
        """
        Generate for each prompt in order, one request at a time.

        Returns:
            The single result object when exactly one prompt is given;
            otherwise a list of result objects (empty for no prompts).
        """
        results = [await self.agenerate_text(prompt, **kwargs) for prompt in prompts]
        return results[0] if len(results) == 1 else results


def create_ragas_litellm_wrapper(
    model: str,
    temperature: float = 0.0,
    max_tokens: int = 8,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    **kwargs
) -> RagasLiteLLMWrapper:
    """
    Build a RagasLiteLLMWrapper from the given settings.

    Convenience factory: every named argument, plus any extra keyword
    arguments, is handed unchanged to the RagasLiteLLMWrapper constructor,
    which talks to liteLLM's completion API directly (no LangChain).
    """
    wrapper_options = {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "api_key": api_key,
        "api_base": api_base,
        "api_version": api_version,
    }
    # The named parameters above can never reappear inside kwargs, so this
    # merge only adds keys and never overwrites one.
    wrapper_options.update(kwargs)
    return RagasLiteLLMWrapper(**wrapper_options)
