import os
import json
import logging
import asyncio
from typing import List, Dict, Any
from tqdm import tqdm
import httpx

logger = logging.getLogger(__name__)

# Import OpenRouter API Key if available
try:
    from openrouter_key import API_KEY as OPENROUTER_API_KEY
    print("OpenRouter API key found: " + OPENROUTER_API_KEY)
except ImportError:
    try:
        # Try to import from the inference directory
        from inference.openrouter_key import API_KEY as OPENROUTER_API_KEY
    except ImportError:
        OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
        if not OPENROUTER_API_KEY:
            logger.warning("OpenRouter API key not found. Set OPENROUTER_API_KEY environment variable or create openrouter_key.py")

from inference.base import BaseInference

class OpenRouterInference(BaseInference):
    """OpenRouter-based inference engine"""
    
    def __init__(self, model: str, max_tokens: int = 2048, temperature: float = 0.7):
        """
        Initialize OpenRouter inference engine
        
        Args:
            model: Model name on OpenRouter
            max_tokens: Maximum tokens for generation
            temperature: Temperature for generation
        """
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        
        if not OPENROUTER_API_KEY:
            raise ValueError("OpenRouter API key not found")
            
        self.api_url = "XXXX"
        
    async def generate(self, prompts: List[str], **kwargs) -> List[str]:
        """
        Generate completions for prompts using OpenRouter
        
        Args:
            prompts: List of prompts to generate completions for
            **kwargs: Additional arguments for generation
            
        Returns:
            List of generated completions
        """
        batch_size = kwargs.get('batch_size', 5)
        results = []
        
        for i in tqdm(range(0, len(prompts), batch_size)):
            batch = prompts[i:i+batch_size]
            tasks = []
            
            for prompt in batch:
                headers = {
                    "Authorization": f"Bearer {OPENROUTER_API_KEY}",
                    "Content-Type": "application/json",
                }
                
                data = {
                    "model": self.model,
                    "messages": [{"role": "user", "content": prompt}],
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                }
                
                tasks.append(self._make_request(headers, data, prompt))
            
            batch_results = await asyncio.gather(*tasks)
            results.extend(batch_results)
            
            # Sleep to avoid rate limits
            await asyncio.sleep(1)
    
        return results
    
    async def _make_request(self, headers, data, prompt = None):
        """Make a request to the OpenRouter API with retries."""
        max_retries = 5
        retry_delay = 5  # seconds

        for attempt in range(max_retries):
            try:
                # Create a client for each request with an appropriate timeout
                async with httpx.AsyncClient(timeout=900.0) as client:
                    response = await client.post(self.api_url, headers=headers, json=data)
                    response.raise_for_status()

                    try:
                        resp_json = response.json()
                    except json.decoder.JSONDecodeError:
                        logger.error(f"Failed to decode JSON. Response text: {response.text}")
                        # Consider delaying before retry if JSON decode fails
                        await asyncio.sleep(retry_delay)
                        continue  # Go to next attempt
                    message = resp_json.get("choices", [{}])[0].get("message", {})
                    content = message.get("content")
                    answer_tags_required = False
                    if prompt:
                        answer_tags_required = "<answer>" in prompt
                    
                    if content is not None and len(content) > 0:  # Check if content is actually present
                        reasoning = message.get("reasoning", "")
                        finish_reason = resp_json.get("choices", [{}])[0].get("finish_reason", "")
                        usage = resp_json.get("usage", {})
                        prompt_tokens = usage.get("prompt_tokens", 0)
                        completion_tokens = usage.get("completion_tokens", 0)
                        # print("Response: ", content)
                        result = {
                            "response": content,
                            "finish_reason": finish_reason,
                            "prompt_tokens": prompt_tokens,
                            "reasoning": reasoning,
                            "completion_tokens": completion_tokens
                        }
                        
                        if answer_tags_required:
                            if "<answer>" in content and "</answer>" in content:
                                return result
                            else:
                                if attempt >= 1:
                                    return result
                                logger.warning(f"Received successful response but no <answer> tags for model {data.get('model')}. Response: {resp_json}")
                        else:
                            return result
                    else:
                        logger.warning(f"Received successful response but no content for model {data.get('model')}. Response: {resp_json}")
                        if resp_json["error"]["code"] == 429:
                            logger.warning(f"Rate limit hit. Retrying after 60 seconds...")
                            await asyncio.sleep(60)
                        # Treat as failure if content is None or empty after success
                        # Fall through to retry logic

            except httpx.HTTPStatusError as e:
                # Specific handling for rate limits or server errors if needed
                if e.response.status_code == 429:
                    logger.warning(f"Rate limit hit. Retrying after longer delay...")
                    await asyncio.sleep(retry_delay * (attempt + 2))  # Exponential backoff
                elif e.response.status_code >= 500:
                    logger.warning(f"Server error ({e.response.status_code}). Retrying...")
                    await asyncio.sleep(retry_delay * (attempt + 1))
                else:
                    logger.error(f"HTTP error: {e.response.status_code} - {e.response.text}")
                    # Don't retry for client errors like 400 Bad Request immediately
                    if 400 <= e.response.status_code < 500 and e.response.status_code != 429:
                        logger.error("Client error detected, not retrying.")
                        break  # Exit retry loop for client errors

            except httpx.RequestError as e:
                logger.error(f"Request failed: {e}")
            except Exception as e:
                logger.error(f"Unexpected error during API request: {str(e)}")

            # Wait before retrying, unless it was the last attempt or a non-retriable error
            if attempt < max_retries - 1:
                logger.info(f"Retrying request (attempt {attempt + 2}/{max_retries})...")
                await asyncio.sleep(retry_delay * (attempt + 1))  # Increase delay slightly

        # If all retries failed or a non-retriable error occurred, return None
        logger.error(f"Failed to get completion for prompt after {max_retries} attempts.")
        return None 