import os
import time
from typing import Any, Dict, List, Optional, Union
from dataclasses import dataclass
from transformers import AutoTokenizer
import logging

# Module-level logger named after this module; no handlers or levels are configured here.
logger = logging.getLogger(__name__)

@dataclass
class TokenUsage:
    """Running totals of tokens consumed by LLM calls.

    ``total_tokens`` is maintained incrementally by :meth:`add_usage`; it is
    not recomputed from the other two fields.
    """

    input_tokens: int = 0   # accumulated input-side token count
    output_tokens: int = 0  # accumulated output-side token count
    total_tokens: int = 0   # accumulated input + output token count

    def add_usage(self, input_tokens: int, output_tokens: int) -> None:
        """Add the token counts of one call to the running totals.

        Args:
            input_tokens: Input-side tokens to add.
            output_tokens: Output-side tokens to add.
        """
        call_total = input_tokens + output_tokens
        self.input_tokens = self.input_tokens + input_tokens
        self.output_tokens = self.output_tokens + output_tokens
        self.total_tokens = self.total_tokens + call_total

class BaseLLMClient:
    """Shared base class for LLM clients.

    Subclasses provide the backend call by overriding ``_complete_impl``.
    This class wraps that call with token counting, cumulative usage
    tracking and a timing log line.
    """

    def __init__(self, model: str, **kwargs):
        """Store the model name, create usage counters and load a tokenizer.

        Args:
            model: Model identifier. Also used as the tokenizer name when no
                ``tokenizer_path`` is given or loading from that path fails.
            **kwargs: Only ``tokenizer_path`` is read here; other keys are
                ignored by this base class.
        """
        self.model = model
        self.tokenizer_path = kwargs.get("tokenizer_path")
        self.tokenizer = None
        self.token_usage = TokenUsage()
        self._init_tokenizer()

    def _init_tokenizer(self):
        """Load a tokenizer, preferring ``tokenizer_path`` over the model name.

        Failures are logged as warnings and never raised. If both attempts
        fail, ``self.tokenizer`` stays ``None`` and token counts fall back to
        a character-based estimate.
        """
        if self.tokenizer_path:
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
                logger.info(f"Loaded tokenizer from path: {self.tokenizer_path}")
                return
            except Exception as e:
                # Fall through and try the model name instead.
                logger.warning(f"Failed to load tokenizer from path: {e}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model)
            logger.info(f"Loaded tokenizer from model: {self.model}")
        except Exception as e2:
            logger.warning(f"Failed to load tokenizer from model: {e2}")
            self.tokenizer = None

    def _count_tokens(self, text: Optional[str]) -> int:
        """Count tokens for a text string.

        Args:
            text: Text to measure. ``None`` (for example a backend response
                without text) counts as zero tokens instead of raising.

        Returns:
            The length of the tokenizer's encoding, or ``len(text) // 4`` as
            an estimate when no tokenizer is loaded or encoding fails.
        """
        if text is None:
            return 0
        if self.tokenizer is not None:
            try:
                return len(self.tokenizer.encode(text))
            except Exception:
                # Best effort: a tokenizer failure must not break the LLM
                # call, so fall back to the character-based estimate below.
                pass
        return len(text) // 4

    def complete(self, prompt: str, **kwargs) -> str:
        """Generate completion text and record token usage and timing.

        Args:
            prompt: Prompt passed to the backend.
            **kwargs: Forwarded unchanged to ``_complete_impl``.

        Returns:
            The value returned by ``_complete_impl``.
        """
        # perf_counter is monotonic, so the elapsed value cannot be skewed by
        # system clock adjustments during a long-running call.
        started = time.perf_counter()

        input_tokens = self._count_tokens(prompt)

        # Invoke the concrete LLM implementation.
        response = self._complete_impl(prompt, **kwargs)

        output_tokens = self._count_tokens(response)
        self.token_usage.add_usage(input_tokens, output_tokens)

        # The elapsed time covers token counting as well as the backend call.
        elapsed_time = time.perf_counter() - started
        logger.info(
            "LLM call completed - model: %s, input tokens: %s, output tokens: %s, elapsed: %.2fs",
            self.model,
            input_tokens,
            output_tokens,
            elapsed_time,
        )

        return response

    def _complete_impl(self, prompt: str, **kwargs) -> str:
        """Concrete LLM implementation (override in subclasses).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_token_usage(self) -> "TokenUsage":
        """Return the live usage counter object (not a copy)."""
        return self.token_usage

    def reset_token_usage(self):
        """Reset token usage counters.

        The counter object is replaced, so an object obtained earlier from
        ``get_token_usage`` keeps its old values.
        """
        self.token_usage = TokenUsage()

    def test_connection(self) -> bool:
        """Test API connectivity by sending one fixed prompt.

        The probe goes through ``complete``, so it is counted in the token
        usage totals.

        Returns:
            ``True`` if the call returned, ``False`` if it raised; the error
            is logged rather than propagated.
        """
        try:
            response = self.complete("List prime numbers under 10.")
            logger.info(f"API connection test succeeded - model: {self.model}, response: {response}")
            return True
        except Exception as e:
            logger.error(f"API connection test failed - model: {self.model}, error: {e}")
            return False

class AnthropicClient(BaseLLMClient):
    """Anthropic Claude client built on ``llama_index.llms.anthropic``."""

    def __init__(self, model: str = "claude-3-7-sonnet-20250219", **kwargs):
        """Initialize the base client and construct the Anthropic backend.

        Args:
            model: Model identifier passed to the backend.
            **kwargs: Forwarded to ``BaseLLMClient.__init__``; ``max_tokens``
                (default 16384) is also passed to the backend.

        Raises:
            ImportError: If ``llama_index.llms.anthropic`` cannot be imported.
        """
        super().__init__(model, **kwargs)
        # Only the import is guarded, so an error raised while constructing
        # the backend is not mislabelled as a missing package.
        try:
            from llama_index.llms.anthropic import Anthropic
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError("Please install llama-index-llms-anthropic.") from err
        self._client = Anthropic(
            model=model,
            max_tokens=kwargs.get("max_tokens", 16384),
        )

    def _complete_impl(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` to the backend and return the response text.

        ``kwargs`` are accepted for signature compatibility but are not
        forwarded to the backend. If the response has no ``text`` attribute,
        ``str(response)`` is returned instead.
        """
        response = self._client.complete(prompt)
        return getattr(response, "text", str(response))

class OpenAIClient(BaseLLMClient):
    """OpenAI client built on ``llama_index.llms.openai``.

    The backend object is constructed, but the completion path is
    deliberately disabled: ``_complete_impl`` always raises.
    """

    def __init__(self, model: str = "gpt-4-turbo", **kwargs):
        """Initialize the base client and construct the OpenAI backend.

        Args:
            model: Model identifier passed to the backend.
            **kwargs: Forwarded to ``BaseLLMClient.__init__``; ``max_tokens``
                (default 8192) is also passed to the backend.

        Raises:
            ImportError: If ``llama_index.llms.openai`` cannot be imported.
        """
        super().__init__(model, **kwargs)
        # Only the import is guarded, so an error raised while constructing
        # the backend is not mislabelled as a missing package.
        try:
            from llama_index.llms.openai import OpenAI
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError("Please install llama-index-llms-openai.") from err
        self._client = OpenAI(
            model=model,
            max_tokens=kwargs.get("max_tokens", 8192),
        )

    def _complete_impl(self, prompt: str, **kwargs) -> str:
        """Reject the call: completion through this client is disabled.

        An explicit ``raise`` is used instead of an ``assert`` so the guard
        still holds when Python runs with ``-O``, which strips assert
        statements. To enable this client, return the text of
        ``self._client.complete(prompt)`` here.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("OpenAIClient not implemented.")

class VLLMClient(BaseLLMClient):
    """vLLM client built on ``llama_index.llms.vllm``."""

    def __init__(self, model: str, **kwargs):
        """Initialize the base client and construct the vLLM backend.

        Args:
            model: Model identifier passed to the backend.
            **kwargs: Forwarded to ``BaseLLMClient.__init__``. The backend
                also receives ``max_new_tokens`` (default 32768),
                ``temperature`` (default 0.1), ``top_p`` (default 0.95) and
                ``tensor_parallel_size`` (default 2).

        Raises:
            ImportError: If ``llama_index.llms.vllm`` cannot be imported.
        """
        super().__init__(model, **kwargs)
        # Only the import is guarded, so an error raised while constructing
        # the backend is not mislabelled as a missing package.
        try:
            from llama_index.llms.vllm import Vllm
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError("Please install llama-index-llms-vllm.") from err
        self._client = Vllm(
            model=model,
            max_new_tokens=kwargs.get("max_new_tokens", 32768),
            temperature=kwargs.get("temperature", 0.1),
            top_p=kwargs.get("top_p", 0.95),
            tensor_parallel_size=kwargs.get("tensor_parallel_size", 2),
        )

    def _complete_impl(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` to the backend and return the response text.

        ``kwargs`` are accepted for signature compatibility but are not
        forwarded to the backend. If the response has no ``text`` attribute,
        ``str(response)`` is returned instead.
        """
        response = self._client.complete(prompt)
        return getattr(response, "text", str(response))

class DeepSeekClient(BaseLLMClient):
    """DeepSeek client built on ``llama_index.llms.deepseek``."""

    def __init__(self, model: str = "deepseek-v3-250324", **kwargs):
        """Initialize the base client and construct the DeepSeek backend.

        Args:
            model: Model identifier passed to the backend.
            **kwargs: Forwarded to ``BaseLLMClient.__init__``. The backend
                also receives ``timeout`` (default 300) and ``max_tokens``
                (default 16384).

        Raises:
            ImportError: If ``llama_index.llms.deepseek`` cannot be imported.
        """
        super().__init__(model, **kwargs)
        # Only the import is guarded, so an error raised while constructing
        # the backend is not mislabelled as a missing package.
        try:
            from llama_index.llms.deepseek import DeepSeek
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError("Please install llama-index-llms-deepseek.") from err
        self._client = DeepSeek(
            model=model,
            timeout=kwargs.get("timeout", 300),
            max_tokens=kwargs.get("max_tokens", 16384),
        )

    def _complete_impl(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` to the backend and return answer plus reasoning.

        Args:
            prompt: Prompt passed to the backend.
            **kwargs: Forwarded to the backend's ``complete`` call.

        Returns:
            The answer text. When the first choice of the raw response
            carries a non-empty ``reasoning_content``, it is prepended inside
            a ``<think>...</think>`` block, so it is part of the returned
            string.
        """
        # Call complete (pass kwargs to support additional params).
        response = self._client.complete(prompt, **kwargs)
        # Final answer text; falls back to str(response) without a .text.
        answer_text = getattr(response, "text", str(response))
        # Extract reasoning_content defensively: any missing attribute along
        # the way simply leaves the reasoning empty.
        reasoning = ""
        raw = getattr(response, "raw", None)
        if raw and getattr(raw, "choices", None):
            message = getattr(raw.choices[0], "message", None)
            reasoning = getattr(message, "reasoning_content", "") or ""
        # Combine reasoning and answer.
        if reasoning:
            return f"<think>\n{reasoning}\n</think>\n\n{answer_text}"
        return answer_text

def create_llm_client(provider: str, model: str, **kwargs) -> BaseLLMClient:
    """Build the client class registered for ``provider``.

    Args:
        provider: Provider name, matched case-insensitively against
            ``anthropic``, ``openai``, ``vllm`` and ``deepseek``.
        model: Model identifier passed to the client constructor.
        **kwargs: Forwarded unchanged to the client constructor.

    Returns:
        A new client instance for the requested provider.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    provider = provider.lower()

    # Lookup table from the lowercased provider name to its client class.
    client_classes = {
        "anthropic": AnthropicClient,
        "openai": OpenAIClient,
        "vllm": VLLMClient,
        "deepseek": DeepSeekClient,
    }
    client_cls = client_classes.get(provider)
    if client_cls is None:
        raise ValueError(f"Unsupported provider: {provider}")
    return client_cls(model, **kwargs)

def test_all_providers():
    """Smoke-test every enabled provider and print one result line for each.

    A failure while creating or probing a client is caught and reported on
    stdout, so one broken provider does not stop the rest of the run.
    """
    # Entries left commented out are skipped.
    providers = (
        # ("anthropic", "claude-3-7-sonnet-20250219"),
        # ("openai", "gpt-4-turbo"),
        ("deepseek", "deepseek-r1-250528"),
    )

    for provider, model in providers:
        try:
            print(f"\nTesting {provider} provider...")
            reachable = create_llm_client(provider, model).test_connection()
            print(
                f"✓ {provider} connection succeeded"
                if reachable
                else f"✗ {provider} connection failed"
            )
        except Exception as e:
            print(f"✗ {provider} initialization failed: {e}")

# Run the provider smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test_all_providers() 