"""LLM clients for remote API providers and locally hosted vLLM models."""
import requests
import json
from pathlib import Path
from typing import Optional, Dict, Tuple
import torch
from transformers import AutoTokenizer
import threading

try:
    from vllm import LLM, SamplingParams
    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False
    LLM = None
    SamplingParams = None

# Lock guarding both caches; get_or_load_model holds it across check-and-load so
# concurrent callers in this process do not load the same model_name twice.
_cache_lock = threading.Lock()
# model_name -> loaded vLLM engine.
_model_cache: Dict[str, LLM] = dict()
# model_name -> tokenizer loaded from the same model path as the engine.
_tokenizer_cache: Dict[str, AutoTokenizer] = dict()


def _resolve_tensor_parallel_size(tensor_parallel_size: Optional[int]) -> int:
    """Return the tensor-parallel degree, defaulting to every visible CUDA device (1 without CUDA)."""
    if tensor_parallel_size is not None:
        return tensor_parallel_size
    return torch.cuda.device_count() if torch.cuda.is_available() else 1


def _resolve_vllm_dtype(dtype: Optional[str]) -> str:
    """Map a user-facing dtype name onto the string passed as vLLM's ``dtype``.

    ``None`` selects bfloat16 on GPUs that support it, float16 on other GPUs and
    float32 when CUDA is unavailable. vLLM's own spellings (``"auto"``,
    ``"half"``, ``"float"``) are honoured. Any other unrecognised name keeps the
    historical fallback of ``"bfloat16"`` but now emits a warning.
    """
    if dtype is None:
        if not torch.cuda.is_available():
            return "float32"
        return "bfloat16" if torch.cuda.is_bf16_supported() else "float16"

    known = {
        "float16": "float16",
        "half": "float16",
        "bfloat16": "bfloat16",
        "float32": "float32",
        "float": "float32",
        "auto": "auto",
    }
    if dtype in known:
        return known[dtype]

    import warnings
    warnings.warn(
        f"Unrecognised dtype {dtype!r}; falling back to 'bfloat16'.",
        stacklevel=3,
    )
    return "bfloat16"


def get_or_load_model(
    model_name: str,
    model_path: str,
    tensor_parallel_size: Optional[int] = None,
    dtype: Optional[str] = None,
    max_model_len: int = 16384
) -> Tuple[LLM, AutoTokenizer]:
    """Get or load a vLLM model and its tokenizer using a process-wide cache.

    The cache is keyed by ``model_name`` alone. Once a model is cached, later
    calls return the cached pair regardless of ``model_path``,
    ``tensor_parallel_size``, ``dtype`` or ``max_model_len``.

    Args:
        model_name: Cache key for the model.
        model_path: Path (or identifier) passed to both ``LLM`` and
            ``AutoTokenizer.from_pretrained``.
        tensor_parallel_size: Tensor-parallel degree; ``None`` uses all visible
            CUDA devices, or 1 without CUDA.
        dtype: Weight dtype name; ``None`` auto-selects based on the hardware.
        max_model_len: Maximum context length forwarded to vLLM.

    Returns:
        A ``(llm, tokenizer)`` tuple. The tokenizer's ``pad_token`` is set to
        its ``eos_token`` if it had none.

    Raises:
        ImportError: If vLLM is not installed.
    """
    if not VLLM_AVAILABLE:
        raise ImportError(
            "vLLM is not installed. Please install it with: pip install vllm"
        )

    # The lock is held across the (slow) load so that concurrent callers never
    # build two engines for the same model_name.
    with _cache_lock:
        if model_name in _model_cache:
            return _model_cache[model_name], _tokenizer_cache[model_name]

        # Load the tokenizer first. It is typically much cheaper than the engine,
        # so a bad path fails before a heavy engine is built and left uncached.
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        llm = LLM(
            model=model_path,
            trust_remote_code=True,
            dtype=_resolve_vllm_dtype(dtype),
            tensor_parallel_size=_resolve_tensor_parallel_size(tensor_parallel_size),
            max_model_len=max_model_len,
        )

        _model_cache[model_name] = llm
        _tokenizer_cache[model_name] = tokenizer

        return llm, tokenizer


def get_local_model_path(model_name: str) -> str:
    """Resolve the local path for a configured model.

    Args:
        model_name: Name of the model. Presumably a key of ``VLLM_MODEL_CONFIG``,
            since the error message lists those keys.

    Returns:
        Whatever ``get_model_path_with_snapshot(model_name)`` returns.

    Raises:
        ValueError: If ``get_model_path_with_snapshot`` raises ``ValueError``.
            The new message lists the configured model names, and the original
            error is chained as the cause.
    """
    from .vllm_model_config import get_model_path_with_snapshot, VLLM_MODEL_CONFIG
    try:
        return get_model_path_with_snapshot(model_name)
    except ValueError as err:
        available = ", ".join(VLLM_MODEL_CONFIG.keys())
        raise ValueError(
            f"Unknown local model: {model_name}. Available models: {available}"
        ) from err


def load_api_key(key_name: str, keys_file: str = None) -> str:
    """Look up a single API key in a JSON key file.

    Args:
        key_name: Top-level key to read from the JSON object.
        keys_file: Path to the JSON file. When omitted, ``api_keys.json`` in
            the same directory as this module is used.

    Returns:
        The stored value, or ``""`` when ``key_name`` is absent.
    """
    path = Path(__file__).parent / "api_keys.json" if keys_file is None else keys_file

    with open(path, 'r', encoding='utf-8') as handle:
        stored_keys = json.loads(handle.read())

    return stored_keys.get(key_name, "")


class InternS1Client:
    """Client for the Intern S1 chat-completions HTTP API."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "intern-s1",
        thinking_mode: bool = True
    ):
        """Initialize Intern S1 client.

        Args:
            api_key: Bearer token. When ``None``, it is read from the key file
                via ``load_api_key("intern_s1")``.
            base_url: API root URL. A trailing ``/`` is stripped and
                ``/chat/completions`` is appended. Required.
            model: Model identifier sent in every request.
            thinking_mode: Forwarded as ``thinking_mode`` in the request body.
                Also selects a longer request timeout.

        Raises:
            ValueError: If ``base_url`` is not provided.
        """
        # Previously a missing base_url crashed with an opaque AttributeError on .rstrip.
        if base_url is None:
            raise ValueError("base_url is required for InternS1Client")
        if api_key is None:
            api_key = load_api_key("intern_s1")

        self.api_key = api_key
        self.base_url = base_url.rstrip('/')
        self.model = model
        self.thinking_mode = thinking_mode
        self.url = f"{self.base_url}/chat/completions"
        self.headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {api_key}'
        }

    def generate(
        self,
        prompt: str,
        *,
        seed: Optional[int] = None,
        temperature: float = 0.7,
        top_p: float = 0.9,
        max_tokens: int = 2048
    ) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: User message content.
            seed: Optional sampling seed; only sent when not ``None``.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            ``choices[0].message.content`` from the JSON response.

        Raises:
            Exception: On transport errors, non-200 responses, or a response
                body that does not have the expected shape. The underlying
                error, when there is one, is chained as the cause.
        """
        data = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "thinking_mode": self.thinking_mode,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens
        }

        if seed is not None:
            data["seed"] = seed

        # Thinking mode presumably yields much longer generations, so it gets
        # a longer timeout (seconds).
        timeout = 1200 if self.thinking_mode else 300

        try:
            response = requests.post(
                self.url,
                headers=self.headers,
                data=json.dumps(data),
                timeout=timeout
            )

            if response.status_code != 200:
                raise Exception(
                    f"API call failed: {response.status_code}, {response.text}"
                )

            result = response.json()
            return result["choices"][0]["message"]["content"]

        except requests.exceptions.RequestException as e:
            raise Exception(f"Request failed: {e}") from e
        except (KeyError, IndexError, TypeError, ValueError) as e:
            # TypeError covers null/non-container fields; ValueError covers
            # non-JSON bodies on requests versions whose decode error is not
            # a RequestException.
            raise Exception(f"Failed to parse response: {e}") from e


class LocalVLLMClient:
    """Local vLLM client using the process-wide model cache.

    The engine and tokenizer are obtained through ``get_or_load_model``, so
    clients constructed with the same ``model_name`` share one cached pair.
    """

    # Context length used when neither the caller nor the model config gives one.
    _DEFAULT_MAX_MODEL_LEN = 16384

    def __init__(
        self,
        model_name: str,
        model_path: Optional[str] = None,
        tensor_parallel_size: Optional[int] = None,
        dtype: Optional[str] = None,
        max_model_len: Optional[int] = None,
        thinking_mode: bool = True
    ):
        """Initialize a local vLLM client.

        Args:
            model_name: Cache key, and the lookup name in ``vllm_model_config``
                when ``model_path`` is omitted.
            model_path: Explicit model location. When ``None``, the path and
                any unset ``tensor_parallel_size`` / ``dtype`` /
                ``max_model_len`` are taken from the model config.
            tensor_parallel_size: Tensor-parallel degree (``None`` = auto).
            dtype: Weight dtype name (``None`` = auto).
            max_model_len: Maximum context length. ``None`` uses the config
                value when available, otherwise 16384.
            thinking_mode: Stored on the instance; ``generate`` does not read it.
        """
        self.model_name = model_name
        self.thinking_mode = thinking_mode

        if model_path is None:
            from .vllm_model_config import get_model_path_with_snapshot, get_model_config
            model_path = get_model_path_with_snapshot(model_name)
            config = get_model_config(model_name)
            if tensor_parallel_size is None:
                tensor_parallel_size = config.get("tensor_parallel_size")
            if dtype is None:
                dtype = config.get("dtype")
            # With the old default of 16384 this lookup was unreachable, so the
            # configured max_model_len was never applied.
            if max_model_len is None:
                max_model_len = config.get("max_model_len", self._DEFAULT_MAX_MODEL_LEN)
        if max_model_len is None:
            max_model_len = self._DEFAULT_MAX_MODEL_LEN

        self.llm, self.tokenizer = get_or_load_model(
            model_name=model_name,
            model_path=model_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            max_model_len=max_model_len
        )

    def generate(
        self,
        prompt: str,
        *,
        seed: Optional[int] = None,
        temperature: float = 0.7,
        top_p: float = 0.9,
        max_tokens: int = 2048,
        max_input_tokens: Optional[int] = 8192
    ) -> str:
        """Generate text for ``prompt`` using vLLM.

        Args:
            prompt: Raw prompt text. No chat template is applied here.
            seed: Optional sampling seed.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            max_tokens: Maximum number of tokens to generate.
            max_input_tokens: Prompts that tokenize to more than this many
                tokens keep only their first ``max_input_tokens`` tokens.
                ``None`` disables truncation.

        Returns:
            The first completion's text with surrounding whitespace stripped,
            or ``""`` if vLLM returned no output.
        """
        if max_input_tokens is not None:
            input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
            if input_ids.shape[1] > max_input_tokens:
                # Keeps the head of the prompt. Special tokens are dropped on
                # decode, presumably so they are not duplicated when vLLM
                # re-tokenizes the text.
                prompt = self.tokenizer.decode(
                    input_ids[0, :max_input_tokens],
                    skip_special_tokens=True
                )

        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            seed=seed,
        )

        outputs = self.llm.generate([prompt], sampling_params)

        if not outputs or not outputs[0].outputs:
            return ""
        return outputs[0].outputs[0].text.strip()


# Alias for LocalVLLMClient: "local_hf" configs are served by the same vLLM-backed
# client (see create_llm_client). Presumably kept so code importing the older
# name keeps working — confirm before removing.
LocalHuggingFaceClient = LocalVLLMClient


def create_llm_client(llm_config):
    """Build the client matching ``llm_config.llm_type``.

    Supported types:
        * ``"intern_s1"``: ``InternS1Client``. Reads ``base_url``, ``model`` and
          ``thinking_mode`` from the config; the API key comes from the key file.
        * ``"local_vllm"`` / ``"local_hf"``: ``LocalVLLMClient``. Reads
          ``model_name``; optional attributes fall back to defaults when absent.

    Raises:
        ValueError: If ``llm_type`` is not one of the above.
    """
    kind = llm_config.llm_type

    if kind == "intern_s1":
        return InternS1Client(
            api_key=None,
            base_url=llm_config.base_url,
            model=llm_config.model,
            thinking_mode=llm_config.thinking_mode,
        )

    if kind in ("local_vllm", "local_hf"):
        optional = lambda attr, fallback: getattr(llm_config, attr, fallback)
        return LocalVLLMClient(
            model_name=llm_config.model_name,
            model_path=optional("model_path", None),
            tensor_parallel_size=optional("tensor_parallel_size", None),
            dtype=optional("dtype", None),
            max_model_len=optional("max_model_len", 16384),
            thinking_mode=optional("thinking_mode", False),
        )

    raise ValueError(f"Unknown llm_type: {kind}")
