import json
import os
import functools
from pathlib import Path
from typing import Dict, Optional, List, Literal, Union
from dataclasses import dataclass, field
from datasets import Dataset
from openai import OpenAI
import requests
from tenacity import retry, stop_after_attempt, wait_exponential
from src.utils import parse_args, update_args, prompt_cache
# HuggingFace backend removed due to multi-threading issues

# Debug mode: enabled by any common truthy spelling of the $DEBUG variable.
DEBUG = os.environ.get('DEBUG', '').lower() in {'true', '1', 'yes', 'on'}

# --- Argument Dataclass ---
@dataclass
class Arguments:
    """Configuration for the LLM-based prediction script.

    Credential/endpoint fields are read from environment variables each time an
    instance is created (via ``default_factory``). ``inference_backend`` selects
    which ``call_*`` helper ``get_prediction_for_sample`` dispatches to.
    """

    # Endpoint and credentials, resolved from the environment at construction.
    base_url: str = field(
        default_factory=lambda: os.environ.get(
            "OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"
        )
    )
    api_key: Optional[str] = field(
        default_factory=lambda: os.environ.get("OPENROUTER_API_KEY")
    )
    google_api_key: Optional[str] = field(
        default_factory=lambda: os.environ.get("GOOGLE_API_KEY")
    )

    # Model and prompt selection.
    model_name: str = "google/gemini-2.0-flash-001"
    prompt_template_path: str = "prompts/COT_D.txt"
    output_column: str = "predicted_answer"

    # Worker processes for Dataset.map (not used by the vllm_offline batch path).
    num_proc: int = 10

    # Sampling parameters forwarded to every backend.
    max_tokens: int = 1024
    temperature: float = 0.0
    top_p: float = 1.0

    verbose: bool = False
    # NOTE(review): not read anywhere in this module; the call_* retry
    # decorators hard-code stop_after_attempt(5).
    retry_attempts: int = 5
    # Appended to each formatted prompt after a single space when non-empty.
    suffix: str = ""
    # Reasoning-token budget; only the openai and google backends forward it.
    max_thinking_tokens: Optional[int] = None
    inference_backend: Literal["openai", "google", "vllm", "vllm_offline"] = "openai"

    # Address of the vLLM OpenAI-compatible server (the "vllm" backend).
    vllm_host: str = "localhost"
    vllm_port: int = 8001

# NOTE: the inference backend is selected explicitly via `Arguments.inference_backend`, so no separate backend-detection helper is needed.

@retry(wait=wait_exponential(multiplier=1, min=2, max=60), stop=stop_after_attempt(5), reraise=True)
def call_gemini_api(
    prompt: str,
    google_api_key: str,
    model_name: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    max_thinking_tokens: Optional[int] = None,
    request_timeout: Optional[float] = 600.0,
) -> Optional[str]:
    """Call the Google Gemini ``generateContent`` REST endpoint via ``requests``.

    Args:
        prompt: Sent as a single text part of a single content entry.
        google_api_key: Gemini API key, sent in the ``x-goog-api-key`` header.
        model_name: Gemini model id, interpolated into the endpoint path.
        max_tokens: Sent as ``generationConfig.maxOutputTokens``.
        temperature: Sent as ``generationConfig.temperature``.
        top_p: Sent as ``generationConfig.topP``.
        max_thinking_tokens: If not ``None``, sets
            ``thinkingConfig.thinkingBudget`` and requests thought summaries
            (``includeThoughts``).
        request_timeout: Seconds to wait for the HTTP response; ``None`` waits
            forever. A timeout raises a ``RequestException`` and is retried.

    Returns:
        The stripped answer text. The text is built from all non-thought parts
        of the first candidate. When ``usageMetadata.thoughtsTokenCount`` is
        positive, the answer is prefixed by a placeholder ``<think>`` block
        holding one ``"think "`` per reported thought token. Returns ``None``
        when the response has no candidates.

    Raises:
        The last HTTP / JSON-decoding error, once tenacity's 5 attempts are
        exhausted (``reraise=True``).
    """
    url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent"
    headers = {
        "Content-Type": "application/json",
        # Authenticate via header rather than a ``?key=`` query parameter:
        # requests includes the full URL in HTTPError messages, which are
        # printed below, so a query-string key would leak into the logs.
        "x-goog-api-key": google_api_key,
    }

    generation_config = {
        "maxOutputTokens": max_tokens,
        "temperature": temperature,
        "topP": top_p,
    }
    if max_thinking_tokens is not None:
        generation_config["thinkingConfig"] = {
            "thinkingBudget": max_thinking_tokens,
            "includeThoughts": True,
        }
    payload = {
        "contents": [{"parts": [{"text": prompt}]}],
        "generationConfig": generation_config,
    }

    response = None  # bound before the try so the except blocks can inspect it
    try:
        response = requests.post(
            url, headers=headers, data=json.dumps(payload), timeout=request_timeout
        )
        response.raise_for_status()
        data = response.json()

        candidates = data.get("candidates")
        if not candidates:
            print("API call successful but response format unexpected or content missing.")
            print(f"Response JSON: {data}")
            return None

        parts = candidates[0].get("content", {}).get("parts") or []
        # With includeThoughts=True, thought-summary parts (flagged
        # "thought": true) are returned alongside the answer, and may come
        # first. Keep only the answer parts.
        text = "".join(
            part.get("text", "") for part in parts if not part.get("thought", False)
        )

        thoughts_token_count = data.get("usageMetadata", {}).get("thoughtsTokenCount") or 0
        if thoughts_token_count > 0:
            # The thought content itself is not kept. Emit a placeholder whose
            # length tracks the reported thinking-token count.
            # NOTE(review): the closing tag is "<\think>" (backslash), not the
            # conventional "</think>"; confirm downstream parsers expect this.
            thinking_text = "think " * thoughts_token_count
            return f"<think>{thinking_text}<\\think>\n{text}".strip()
        return text.strip()

    except json.JSONDecodeError as e:
        # Checked before RequestException: requests' own JSONDecodeError may
        # subclass both, and would otherwise never reach this handler.
        print(f"Failed to decode Gemini JSON response: {str(e)}")
        if response is not None:
            print(f"Response text: {response.text}")
        raise
    except requests.exceptions.RequestException as e:
        print(f"Gemini API request failed with error: {str(e)}")
        if response is not None:
            print(f"Response status code: {response.status_code}")
            try:
                print(f"Response body: {response.text}")
            except Exception:
                print("Could not read response body.")
        raise
    except Exception as e:
        print(f"An unexpected error occurred during Gemini API call: {str(e)}")
        if response is not None:
            try:
                print(f"Response body: {response.text}")
            except Exception:
                print("Could not read response body.")
        raise

@retry(wait=wait_exponential(multiplier=1, min=2, max=60), stop=stop_after_attempt(5), reraise=True)
# max_thinking_tokens changes the request (the reasoning budget), so it is part
# of the cache key. Without it, runs with different budgets would share cached
# answers. NOTE(review): depending on how prompt_cache builds keys, extending
# the key may invalidate previously cached entries.
@prompt_cache(param_names=["prompt", "model_name", "max_tokens", "temperature", "top_p", "max_thinking_tokens"])
def call_openai_api(
    prompt: str,
    client: OpenAI,
    model_name: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    max_thinking_tokens: Optional[int] = None,
    request_timeout: Optional[float] = 600.0,
) -> Optional[str]:
    """Call an OpenAI-compatible (OpenRouter) ``/chat/completions`` endpoint via ``requests``.

    Only ``client.api_key`` and ``client.base_url`` are used. The request is
    sent with ``requests`` and always carries a ``reasoning`` object.

    Args:
        prompt: Sent as a single user message.
        client: Supplies the bearer token and base URL.
        model_name: Sent as ``model``.
        max_tokens: Sent as ``max_tokens``.
        temperature: Sent as ``temperature``.
        top_p: Sent as ``top_p``.
        max_thinking_tokens: If not ``None``, sent as ``reasoning.max_tokens``;
            otherwise an empty ``reasoning`` object is sent.
        request_timeout: Seconds to wait for the HTTP response; ``None`` waits
            forever. A timeout raises a ``RequestException`` and is retried.

    Returns:
        The stripped message content. When the response also carries a
        ``reasoning`` string, the content is prefixed by a ``<think>`` block
        holding it. Returns ``None`` when the content is empty or missing.

    Raises:
        The last request / parsing error, once tenacity's 5 attempts are
        exhausted (``reraise=True``).
    """
    base_url = str(client.base_url).rstrip('/')  # avoid a double slash in the URL
    url = f"{base_url}/chat/completions"
    headers = {
        "Authorization": f"Bearer {client.api_key}",
        "Content-Type": "application/json",
    }

    reasoning_config = {}
    if max_thinking_tokens is not None:
        reasoning_config["max_tokens"] = max_thinking_tokens
    payload = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        # NOTE(review): an empty reasoning object is sent even without a
        # budget; confirm how the provider interprets that.
        "reasoning": reasoning_config,
    }

    # Bound before the try so the except blocks can report what was received.
    response = None
    data = None
    try:
        response = requests.post(
            url, headers=headers, data=json.dumps(payload), timeout=request_timeout
        )
        response.raise_for_status()  # HTTPError for 4xx / 5xx responses
        data = response.json()

        message = data.get('choices', [{}])[0].get('message', {})
        content = message.get('content', '')
        reasoning = message.get('reasoning')

        if not content:
            print("API call successful but response format unexpected or content missing.")
            print(f"Response JSON: {data}")
            return None
        if reasoning:
            # NOTE(review): the closing tag is "<\think>" (backslash), not the
            # conventional "</think>"; confirm downstream parsers expect this.
            return f"<think>{reasoning}<\\think>\n{content}".strip()
        return content.strip()

    except json.JSONDecodeError as e:
        # Checked before RequestException: requests' own JSONDecodeError may
        # subclass both, and would otherwise never reach this handler.
        print(f"Failed to decode JSON response: {str(e)}")
        if response is not None:
            print(f"Response text: {response.text}")
        raise  # let tenacity retry
    except requests.exceptions.RequestException as e:
        # Connection errors, timeouts, HTTP error statuses, ...
        print(f"API request failed with error: {str(e)}")
        if response is not None:
            print(f"Response status code: {response.status_code}")
            try:
                print(f"Response body: {response.text}")
            except Exception:
                print("Could not read response body.")
        raise  # let tenacity retry
    except (KeyError, IndexError) as e:
        # Unexpected JSON structure, e.g. an empty "choices" list.
        print(f"Failed to extract data from response JSON: {str(e)}")
        if data is not None:
            print(f"Response JSON structure: {data}")
        elif response is not None:
            print(f"Response text: {response.text}")
        raise  # let tenacity retry
    except Exception as e:
        print(f"An unexpected error occurred during API call: {str(e)}")
        if response is not None:
            print(f"Response status code: {response.status_code}")
            try:
                print(f"Response body: {response.text}")
            except Exception:
                print("Could not read response body.")
        raise  # let tenacity retry

@retry(wait=wait_exponential(multiplier=1, min=2, max=60), stop=stop_after_attempt(5), reraise=True)
# top_p is forwarded to the model, so it belongs in the cache key. Otherwise
# calls differing only in top_p would return each other's cached output.
@prompt_cache(param_names=["prompt", "model_name", "max_tokens", "temperature", "top_p"])
def call_google_api(
    prompt: str,
    api_key: str,
    model_name: str,
    max_tokens: int,
    temperature: float,
    top_p: float = 0.95,
) -> Optional[str]:
    """Generate a completion with the ``google.generativeai`` SDK.

    Args:
        prompt: Text passed to ``GenerativeModel.generate_content``.
        api_key: Passed to ``genai.configure``.
        model_name: Passed to ``genai.GenerativeModel``.
        max_tokens: ``max_output_tokens`` of the generation config.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling ``top_p`` (default 0.95).

    Returns:
        ``response.text`` with surrounding whitespace stripped.

    Raises:
        Any SDK error, printed and re-raised so tenacity can retry it.
    """
    try:
        # Imported lazily so the SDK is only required when this path is used.
        import google.generativeai as genai
        # NOTE(review): genai.configure appears to set process-wide SDK state.
        genai.configure(api_key=api_key)

        model = genai.GenerativeModel(model_name)
        generation_config = genai.GenerationConfig(
            max_output_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        response = model.generate_content(
            prompt,
            generation_config=generation_config,
        )
        return response.text.strip()
    except Exception as e:
        print(f"Google API call failed with error: {str(e)}")
        raise


@retry(wait=wait_exponential(multiplier=1, min=2, max=60), stop=stop_after_attempt(5), reraise=True)
@prompt_cache(param_names=["prompt", "model_name", "max_tokens", "temperature", "top_p"])
def call_vllm_api(
    prompt: str,
    model_name: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    vllm_host: str,
    vllm_port: int,
) -> Optional[str]:
    """Query a vLLM server through its OpenAI-compatible chat-completions endpoint.

    Args:
        prompt: Sent as a single user message.
        model_name: Model name the server is expected to serve.
        max_tokens, temperature, top_p: Sampling parameters forwarded as-is.
        vllm_host, vllm_port: Server address; requests go to
            ``http://{vllm_host}:{vllm_port}/v1``.

    Returns:
        The stripped content of the first choice.

    Raises:
        Any client/server error, printed and re-raised so tenacity can retry it.
    """
    try:
        # A client is built per call. The ``with`` block closes its HTTP
        # connection pool instead of leaving it for the garbage collector.
        # The "EMPTY" key is a placeholder; presumably the server is not
        # started with --api-key (TODO confirm for your deployment).
        with OpenAI(
            api_key="EMPTY",
            base_url=f"http://{vllm_host}:{vllm_port}/v1",
        ) as client:
            response = client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            )
            return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"vLLM API call failed with error: {str(e)}")
        raise




# Module-level cache of vLLM engines, keyed by the settings used to build
# them, so repeated calls with the same configuration reuse one engine.
_vllm_model_cache = {}

def get_or_create_vllm_model(
    model_name: str,
    gpu_memory_utilization: float = 0.9,
    max_model_len: int = 8192,
):
    """Return a cached ``vllm.LLM`` for *model_name*, building it on first use.

    Args:
        model_name: Model id or path passed to ``vllm.LLM``.
        gpu_memory_utilization: Forwarded to ``vllm.LLM`` (default 0.9).
        max_model_len: Forwarded to ``vllm.LLM`` (default 8192).

    ``tensor_parallel_size`` is the number of non-blank entries in
    ``CUDA_VISIBLE_DEVICES``, or 1 when it is unset or lists nothing.
    ``VLLM_ATTENTION_BACKEND=FLASHINFER`` (case-insensitive) additionally
    selects the ``"mp"`` distributed executor backend. All of these settings
    form the cache key, so changing any of them builds a new engine.
    """
    # One tensor-parallel shard per visible GPU. Blank entries (e.g. from a
    # trailing comma in "0,1,") are not devices and must not inflate the count.
    visible_devices = [
        dev for dev in os.environ.get('CUDA_VISIBLE_DEVICES', '').split(',') if dev.strip()
    ]
    tensor_parallel_size = len(visible_devices) or 1

    use_flashinfer = os.environ.get('VLLM_ATTENTION_BACKEND', '').upper() == 'FLASHINFER'

    cache_key = (model_name, gpu_memory_utilization, tensor_parallel_size, max_model_len, use_flashinfer)
    if cache_key not in _vllm_model_cache:
        # Imported lazily: vllm is only needed when an engine is actually built.
        from vllm import LLM

        config = {
            "model": model_name,
            "gpu_memory_utilization": gpu_memory_utilization,
            "tensor_parallel_size": tensor_parallel_size,
            "max_model_len": max_model_len,
        }
        if use_flashinfer:
            # Original author's note: the multiprocessing executor is required
            # for FlashInfer. NOTE(review): confirm for the vLLM version in use.
            config["distributed_executor_backend"] = "mp"

        _vllm_model_cache[cache_key] = LLM(**config)
    return _vllm_model_cache[cache_key]


@prompt_cache(param_names=["prompt", "model_name", "max_tokens", "temperature", "top_p"])
def call_vllm_offline(
    prompt: str,
    model_name: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> Optional[str]:
    """Generate one completion for *prompt* with an in-process vLLM engine.

    The engine comes from ``get_or_create_vllm_model``, which caches it per
    configuration. Returns the stripped text of the first completion; any
    error is printed and then re-raised.
    """
    try:
        from vllm import SamplingParams

        engine = get_or_create_vllm_model(model_name)
        params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        request_outputs = engine.generate([prompt], params)
        first_completion = request_outputs[0].outputs[0]
        return first_completion.text.strip()
    except Exception as err:
        print(f"vLLM offline inference failed with error: {str(err)}")
        raise


# --- Prediction Function for a Single Sample ---
def get_prediction_for_sample(
    sample: Dict,
    args: Arguments,
    prompt_template: str
) -> Dict[str, Union[Optional[str], List[float]]]:
    """Produce the model prediction for one dataset row.

    The row's fields fill ``prompt_template`` via ``str.format``. A non-empty
    ``args.suffix`` is appended after a single space. The request is routed to
    the ``call_*`` helper named by ``args.inference_backend``. Any failure
    yields ``None`` as the prediction instead of raising; this includes an
    unknown backend or a missing Google key. The error is printed only when
    ``args.verbose`` is set.

    Returns:
        ``{args.output_column: prediction}``, the shape ``Dataset.map`` expects.
    """
    prompt = prompt_template.format(**sample)
    if args.suffix:
        prompt = prompt + " " + args.suffix

    try:
        backend = args.inference_backend
        # Sampling settings shared by every backend, forwarded as keywords.
        sampling = {
            "model_name": args.model_name,
            "max_tokens": args.max_tokens,
            "temperature": args.temperature,
            "top_p": args.top_p,
        }

        if backend == "openai":
            client = OpenAI(api_key=args.api_key, base_url=args.base_url)
            prediction = call_openai_api(
                prompt=prompt,
                client=client,
                **sampling,
                max_thinking_tokens=args.max_thinking_tokens,
            )
        elif backend == "google":
            if not args.google_api_key:
                raise ValueError("Google API key is required for Google backend")
            prediction = call_gemini_api(
                prompt=prompt,
                google_api_key=args.google_api_key,
                **sampling,
                max_thinking_tokens=args.max_thinking_tokens,
            )
        elif backend == "vllm":
            prediction = call_vllm_api(
                prompt=prompt,
                **sampling,
                vllm_host=args.vllm_host,
                vllm_port=args.vllm_port,
            )
        elif backend == "vllm_offline":
            prediction = call_vllm_offline(prompt=prompt, **sampling)
        else:
            raise ValueError(f"Unknown inference backend: {backend}")
    except Exception as err:
        if args.verbose:
            print(f"Error getting prediction for sample: {err}. Prompt: {prompt[:50]}... Returning None.")
        return {args.output_column: None}

    return {args.output_column: prediction}

# --- Dataset-Level Prediction Functions ---
def predict_dataset_vllm_batch(
    dataset: Dataset,
    args: Arguments,
    prompt_template: str
) -> Dataset:
    """Predict every row of *dataset* in one native vLLM batch.

    Prompts are built as in ``get_prediction_for_sample``: the row's fields
    fill ``prompt_template``, then a non-empty ``args.suffix`` is appended
    after a space.

    Returns:
        *dataset* with ``args.output_column`` holding the stripped first
        completion for each row. An existing column of that name is replaced,
        which matches the ``Dataset.map`` path; ``add_column`` alone would
        reject the duplicate name.
    """
    from vllm import SamplingParams

    suffix = " " + args.suffix if args.suffix else ""
    prompts = [prompt_template.format(**sample) + suffix for sample in dataset]

    if args.verbose:
        print(f"Prepared {len(prompts)} prompts for batch inference")

    if prompts:
        llm = get_or_create_vllm_model(args.model_name)
        sampling_params = SamplingParams(
            temperature=args.temperature,
            top_p=args.top_p,
            max_tokens=args.max_tokens,
        )
        # Assumes llm.generate returns outputs in prompt order (vLLM documents
        # this), so predictions line up with dataset rows.
        outputs = llm.generate(prompts, sampling_params)
        predictions = [output.outputs[0].text.strip() for output in outputs]
    else:
        # Nothing to generate: skip building the engine entirely.
        predictions = []

    if args.output_column in dataset.column_names:
        dataset = dataset.remove_columns(args.output_column)
    return dataset.add_column(args.output_column, predictions)


def predict_dataset(
    dataset: Dataset,
    args: Optional[Arguments] = None,
    **kwargs
) -> Dataset:
    """Run LLM predictions over every row of *dataset*.

    Args:
        dataset: Rows whose fields fill the prompt template.
        args: Base configuration; a default ``Arguments()`` is used when omitted.
        **kwargs: Passed to ``update_args`` together with *args*, presumably
            overriding the matching fields.

    Returns:
        The dataset with ``args.output_column`` holding predictions. On the
        per-sample (``Dataset.map``) path, a sample whose prediction failed
        gets ``None``.

    ``vllm_offline`` uses native vLLM batching. Every other backend runs
    ``get_prediction_for_sample`` through ``Dataset.map`` with
    ``args.num_proc`` worker processes.
    """
    args = update_args(args if args is not None else Arguments(), **kwargs)
    prompt_template = Path(args.prompt_template_path).read_text(encoding="utf-8")
    use_batch = args.inference_backend == "vllm_offline"

    if args.verbose:
        print(f"Starting prediction for column '{args.output_column}'...")
        print(f"Using backend: {args.inference_backend}")
        print(f"Using model: {args.model_name}")
        if args.inference_backend == "openai":
            print(f"Base URL: {args.base_url}")
        elif args.inference_backend == "google":
            print("Using Google API directly")
        elif args.inference_backend == "vllm":
            print(f"vLLM Server: http://{args.vllm_host}:{args.vllm_port}")
        elif use_batch:
            print("vLLM Offline: GPU memory utilization=0.9, tensor_parallel_size=auto (based on CUDA_VISIBLE_DEVICES), max_model_len=8192")
            print("Using native vLLM batch inference")
        # Every non-batch backend goes through Dataset.map, so the worker
        # count is relevant for all of them.
        if not use_batch:
            print(f"Using {args.num_proc} processes.")

    if use_batch:
        predicted_dataset = predict_dataset_vllm_batch(dataset, args, prompt_template)
    else:
        map_function = functools.partial(
            get_prediction_for_sample,
            args=args,
            prompt_template=prompt_template
        )
        predicted_dataset = dataset.map(
            map_function,
            num_proc=args.num_proc,
            load_from_cache_file=False,  # Disable dataset's caching, rely on our prompt_cache
        )

    if args.verbose:
        try:
            print("Prediction Results (first 10):")
            for idx in range(min(len(predicted_dataset), 10)):
                # Fetch one row at a time instead of materializing whole
                # columns on every iteration.
                row = predicted_dataset[idx]
                print(f"  Q: {row['question']}")
                print(f"  Predicted Answer: {row[args.output_column]}")
                print("-" * 10)
        except Exception as e:
            print(f"Could not print prediction results: {e}")

    return predicted_dataset

if __name__ == "__main__":
    args = parse_args(Arguments)
    if args.verbose:
        print(f"Running prediction with arguments: {args}")

    # A handful of sample questions for smoke-testing the selected backend.
    sample_questions = [
        "What is the boiling point of water?",
        "How many light bulbs are in the Empire State Building?",
        "What is the capital of the moon?",
        "Calculate 254687 * 56778 / 45923 + 2344",
        "What is the 3th decimal digit of answer to sqrt(5649283 * 7899 / 234567 + 42786)",
    ]
    dummy_dataset = Dataset.from_dict({"question": sample_questions})

    predicted_dataset = predict_dataset(dummy_dataset, args=args)
    print("Final Predicted Dataset:")
    print(predicted_dataset)