"""
Reasoner interface for the frozen reasoning component of reasoning scaffolds.

This module provides the interface to reasoning models that stay frozen during training.
These models (like DeepSeek-R1, GPT-4, etc.) perform the logical reasoning based on
VLM-generated context.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional, Any, Iterator
import torch.nn as nn
from openai import OpenAI
import os
import time


@dataclass
class ReasonerConfig:
    """Settings bundle describing which reasoner to use and how to call it.

    Only ``model_name`` is mandatory; every other field has a default.
    """

    # --- Which model to talk to ---
    model_name: str
    model_type: str = "openai"  # one of "api", "deepseek", "openai"

    # --- Endpoint / credentials ---
    api_base: Optional[str] = None
    api_key: str = "EMPTY"
    timeout: int = 900  # generous, because reasoning responses take a while

    # --- Sampling controls ---
    max_tokens: int = 100000  # large budget for long reasoning traces
    temperature: float = 0.6  # a little higher to allow reasoning creativity
    top_p: float = 0.95
    top_k: Optional[int] = None  # None means "do not send top_k"

    # --- Reasoning-specific switches ---
    enable_step_by_step: bool = True
    enable_verification: bool = True
    max_reasoning_steps: int = 7

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain ``dict`` keyed by field name."""
        as_mapping = asdict(self)
        return as_mapping

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ReasonerConfig':
        """Build a configuration from a mapping of field names to values.

        The mapping is expanded into keyword arguments, so it must contain
        ``model_name`` and may not contain keys that are not fields.
        """
        instance = cls(**data)
        return instance


class ReasonerInterface(ABC):
    """
    Base class that every reasoner backend derives from.

    A reasoner receives a textual context and returns the model's reasoning
    output. Reasoners are treated as frozen: this base class never exposes
    trainable parameters. It also holds simple usage counters that
    ``reason`` implementations are expected to keep up to date.
    """

    def __init__(self, config: ReasonerConfig):
        """
        Remember the configuration and start all usage counters at zero.

        Args:
            config: Reasoner configuration
        """
        self.config = config

        # Usage counters, reported by get_statistics().
        self.total_requests = 0
        self.successful_requests = 0
        self.total_tokens_generated = 0
        self.total_reasoning_time = 0.0

    @abstractmethod
    def reason(
        self,
        context: str,
        **kwargs
    ) -> str:
        """
        Run the reasoning model on ``context``.

        Args:
            context: Input context (usually from VLM)
            **kwargs: Additional reasoning parameters

        Returns:
            The reasoning output produced by the model
        """

    def freeze_parameters(self):
        """
        Freeze reasoner parameters.

        The base implementation does nothing, which is correct for API-backed
        reasoners. A reasoner wrapping a local model would override this to
        freeze its weights.
        """
        return None

    def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
        """
        Return an iterator over trainable parameters.

        Reasoners are always frozen, so the iterator yields nothing.
        """
        return iter([])

    def get_statistics(self) -> Dict[str, Any]:
        """Summarise the usage counters together with the model identity."""
        requests = self.total_requests
        successes = self.successful_requests
        elapsed = self.total_reasoning_time

        # Both ratios fall back to 0.0 instead of dividing by zero.
        success_rate = 0.0
        if requests > 0:
            success_rate = successes / requests

        avg_reasoning_time = 0.0
        if successes > 0:
            avg_reasoning_time = elapsed / successes

        return {
            'total_requests': requests,
            'successful_requests': successes,
            'success_rate': success_rate,
            'total_tokens_generated': self.total_tokens_generated,
            'total_reasoning_time': elapsed,
            'avg_reasoning_time': avg_reasoning_time,
            'model_name': self.config.model_name,
            'model_type': self.config.model_type,
        }

    def reset_statistics(self):
        """Put every usage counter back to its initial zero value."""
        self.total_requests = 0
        self.successful_requests = 0
        self.total_tokens_generated = 0
        self.total_reasoning_time = 0.0

    def get_config(self) -> Dict[str, Any]:
        """Return the reasoner configuration as a dictionary."""
        return self.config.to_dict()

    @staticmethod
    def create(config: ReasonerConfig) -> 'ReasonerInterface':
        """
        Build the concrete reasoner selected by ``config.model_type``.

        Args:
            config: Reasoner configuration

        Returns:
            Concrete reasoner interface implementation

        Raises:
            ValueError: If ``config.model_type`` is not a known reasoner type
        """
        kind = config.model_type
        if kind == "deepseek":
            return DeepSeekReasonerInterface(config)
        if kind == "openai":
            return OpenAIReasonerInterface(config)
        if kind == "api":
            return GenericAPIReasonerInterface(config)
        raise ValueError(f"Unknown reasoner type: {config.model_type}")


class DeepSeekReasonerInterface(ReasonerInterface):
    """Reasoner interface for DeepSeek-R1 models.

    Talks to the endpoint at ``config.api_base`` through the ``OpenAI`` client
    and sends the caller's context verbatim as a single user message.
    """

    def __init__(self, config: ReasonerConfig):
        """
        Build the client for a DeepSeek endpoint.

        Args:
            config: Reasoner configuration. ``api_base`` must be set. The API
                key is taken from ``config.api_key`` or, when that is empty or
                the ``"EMPTY"`` placeholder, from the ``DEEPSEEK_API_KEY``
                environment variable.

        Raises:
            ValueError: If no API key can be resolved or ``api_base`` is missing.
        """
        super().__init__(config)

        # Set up DeepSeek client
        api_key = config.api_key
        if api_key == "EMPTY" or not api_key:
            api_key = os.getenv("DEEPSEEK_API_KEY")
            if not api_key:
                raise ValueError("DeepSeek API key required")

        if not config.api_base:
            raise ValueError("DeepSeek API base URL is required in configuration")

        self.client = OpenAI(
            api_key=api_key,
            base_url=config.api_base,
            timeout=config.timeout,
            max_retries=0  # Disable automatic retries to avoid duplicate requests
        )

    def _build_generation_params(self, overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Combine config defaults with per-call overrides into request kwargs.

        ``top_k`` is sent inside ``extra_body`` (required for vLLM-compatible
        APIs). ``extra_body`` is merged rather than replaced, in increasing
        priority: ``config.top_k``, the caller's ``extra_body``, then an
        explicit ``top_k`` override. Overrides whose value is ``None`` are
        dropped.
        """
        params: Dict[str, Any] = {
            'max_tokens': self.config.max_tokens,
            'temperature': self.config.temperature,
            'top_p': self.config.top_p,
        }

        remaining = dict(overrides)  # work on a copy; the caller's dict is untouched

        extra_body: Dict[str, Any] = {}
        if self.config.top_k is not None:
            extra_body['top_k'] = self.config.top_k

        # Merge the caller's extra_body so it cannot silently discard top_k.
        caller_extra_body = remaining.pop('extra_body', None)
        if caller_extra_body is not None:
            extra_body.update(caller_extra_body)

        top_k_override = remaining.pop('top_k', None)
        if top_k_override is not None:
            extra_body['top_k'] = top_k_override

        # Filter out None values to avoid issues with the OpenAI client
        params.update({k: v for k, v in remaining.items() if v is not None})

        if extra_body:
            params['extra_body'] = extra_body
        return params

    def reason(self, context: str, **kwargs) -> str:
        """
        Perform reasoning using DeepSeek-R1.

        Legacy behaviour is preserved: no system prompt is added and
        ``context`` is sent unchanged as the only (user) message, because it
        already contains the full, formatted prompt.

        Args:
            context: Complete prompt to send to the model.
            **kwargs: Per-call generation overrides forwarded to
                ``chat.completions.create``. ``top_k`` and ``extra_body`` are
                merged into the request's ``extra_body``; ``None`` values are
                ignored.

        Returns:
            The message content of the first choice, as returned by the API.

        Raises:
            RuntimeError: If anything fails while building or sending the
                request or reading the response; the original exception is
                attached as ``__cause__``.
        """
        try:
            # perf_counter is monotonic, so durations survive clock adjustments.
            start_time = time.perf_counter()
            self.total_requests += 1

            response = self.client.chat.completions.create(
                model=self.config.model_name,
                messages=[
                    {"role": "user", "content": context}
                ],
                **self._build_generation_params(kwargs)
            )

            result = response.choices[0].message.content

            # Update statistics
            self.successful_requests += 1
            # Some OpenAI-compatible servers omit usage; count 0 tokens then
            # instead of failing a request that actually succeeded.
            usage = getattr(response, 'usage', None)
            if usage is not None:
                self.total_tokens_generated += getattr(usage, 'completion_tokens', None) or 0
            self.total_reasoning_time += time.perf_counter() - start_time

            return result

        except Exception as e:
            raise RuntimeError(f"DeepSeek reasoning failed: {e}") from e


class OpenAIReasonerInterface(ReasonerInterface):
    """Reasoner interface for OpenAI models (GPT-4, etc.).

    Sends the caller's context verbatim as a single user message through the
    ``OpenAI`` client; ``config.api_base`` may point at a compatible server.
    """

    def __init__(self, config: ReasonerConfig):
        """
        Build the OpenAI client.

        Args:
            config: Reasoner configuration. The API key is taken from
                ``config.api_key`` or, when that is empty or the ``"EMPTY"``
                placeholder, from the ``OPENAI_API_KEY`` environment variable.
                ``config.api_base`` is passed through as the client's base URL
                and may be ``None``.

        Raises:
            ValueError: If no API key can be resolved.
        """
        super().__init__(config)

        # Set up OpenAI client
        api_key = config.api_key
        if api_key == "EMPTY" or not api_key:
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                raise ValueError("OpenAI API key required")

        self.client = OpenAI(
            api_key=api_key,
            base_url=config.api_base,
            timeout=config.timeout,
            max_retries=0  # Disable automatic retries to avoid duplicate requests
        )

    def _build_generation_params(self, overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Combine config defaults with per-call overrides into request kwargs.

        ``top_k`` is sent inside ``extra_body`` (for vLLM-compatible OpenAI
        servers). ``extra_body`` is merged rather than replaced, in increasing
        priority: ``config.top_k``, the caller's ``extra_body``, then an
        explicit ``top_k`` override. Overrides whose value is ``None`` are
        dropped.
        """
        params: Dict[str, Any] = {
            'max_tokens': self.config.max_tokens,
            'temperature': self.config.temperature,
            'top_p': self.config.top_p,
        }

        remaining = dict(overrides)  # work on a copy; the caller's dict is untouched

        extra_body: Dict[str, Any] = {}
        if self.config.top_k is not None:
            extra_body['top_k'] = self.config.top_k

        # Merge the caller's extra_body so that setting top_k no longer
        # overwrites whatever else the caller asked to send.
        caller_extra_body = remaining.pop('extra_body', None)
        if caller_extra_body is not None:
            extra_body.update(caller_extra_body)

        top_k_override = remaining.pop('top_k', None)
        if top_k_override is not None:
            extra_body['top_k'] = top_k_override

        # Remaining overrides can be merged directly, skipping None values
        params.update({k: v for k, v in remaining.items() if v is not None})

        if extra_body:
            params['extra_body'] = extra_body
        return params

    def reason(self, context: str, **kwargs) -> str:
        """
        Perform reasoning using OpenAI models.

        Legacy behaviour is preserved: no system prompt is prepended and
        ``context`` is sent unchanged as the only (user) message.

        Args:
            context: Complete prompt to send to the model.
            **kwargs: Per-call generation overrides forwarded to
                ``chat.completions.create``. ``top_k`` and ``extra_body`` are
                merged into the request's ``extra_body``; ``None`` values are
                ignored.

        Returns:
            The message content of the first choice, as returned by the API.

        Raises:
            RuntimeError: If anything fails while building or sending the
                request or reading the response; the original exception is
                attached as ``__cause__``.
        """
        try:
            # perf_counter is monotonic, so durations survive clock adjustments.
            start_time = time.perf_counter()
            self.total_requests += 1

            response = self.client.chat.completions.create(
                model=self.config.model_name,
                messages=[
                    {"role": "user", "content": context}
                ],
                **self._build_generation_params(kwargs)
            )

            result = response.choices[0].message.content

            # Update statistics
            self.successful_requests += 1
            # Some OpenAI-compatible servers omit usage; count 0 tokens then
            # instead of failing a request that actually succeeded.
            usage = getattr(response, 'usage', None)
            if usage is not None:
                self.total_tokens_generated += getattr(usage, 'completion_tokens', None) or 0
            self.total_reasoning_time += time.perf_counter() - start_time

            return result

        except Exception as e:
            raise RuntimeError(f"OpenAI reasoning failed: {e}") from e


class GenericAPIReasonerInterface(ReasonerInterface):
    """Generic API reasoner interface for custom reasoning models.

    Uses the ``OpenAI`` client against the endpoint at ``config.api_base`` and
    sends the caller's context verbatim as a single user message.
    """

    def __init__(self, config: ReasonerConfig):
        """
        Build the client for a custom endpoint.

        Args:
            config: Reasoner configuration. ``api_base`` must be set;
                ``api_key`` is passed to the client as-is.

        Raises:
            ValueError: If ``config.api_base`` is missing.
        """
        super().__init__(config)

        if not config.api_base:
            raise ValueError("api_base required for generic API reasoner")

        self.client = OpenAI(
            api_key=config.api_key,
            base_url=config.api_base,
            timeout=config.timeout,
            max_retries=0  # Disable automatic retries to avoid duplicate requests
        )

    def _build_generation_params(self, overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Combine config defaults with per-call overrides into request kwargs.

        ``top_k`` is sent inside ``extra_body`` (required for vLLM-compatible
        APIs). ``extra_body`` is merged rather than replaced, in increasing
        priority: ``config.top_k``, the caller's ``extra_body``, then an
        explicit ``top_k`` override. Overrides whose value is ``None`` are
        dropped.
        """
        params: Dict[str, Any] = {
            'max_tokens': self.config.max_tokens,
            'temperature': self.config.temperature,
            'top_p': self.config.top_p,
        }

        remaining = dict(overrides)  # work on a copy; the caller's dict is untouched

        extra_body: Dict[str, Any] = {}
        if self.config.top_k is not None:
            extra_body['top_k'] = self.config.top_k

        # Merge the caller's extra_body so it cannot silently discard top_k.
        caller_extra_body = remaining.pop('extra_body', None)
        if caller_extra_body is not None:
            extra_body.update(caller_extra_body)

        top_k_override = remaining.pop('top_k', None)
        if top_k_override is not None:
            extra_body['top_k'] = top_k_override

        # Filter out None values to avoid issues with the OpenAI client
        params.update({k: v for k, v in remaining.items() if v is not None})

        if extra_body:
            params['extra_body'] = extra_body
        return params

    def reason(self, context: str, **kwargs) -> str:
        """
        Perform reasoning using generic API.

        Legacy behaviour is preserved: no extra wrapping is applied and
        ``context`` is sent unchanged as the only (user) message.

        Args:
            context: Complete prompt to send to the model.
            **kwargs: Per-call generation overrides forwarded to
                ``chat.completions.create``. ``top_k`` and ``extra_body`` are
                merged into the request's ``extra_body``; ``None`` values are
                ignored.

        Returns:
            The message content of the first choice, as returned by the API.

        Raises:
            RuntimeError: If anything fails while building or sending the
                request or reading the response; the original exception is
                attached as ``__cause__``.
        """
        try:
            # perf_counter is monotonic, so durations survive clock adjustments.
            start_time = time.perf_counter()
            self.total_requests += 1

            response = self.client.chat.completions.create(
                model=self.config.model_name,
                messages=[
                    {"role": "user", "content": context}
                ],
                **self._build_generation_params(kwargs)
            )

            result = response.choices[0].message.content

            # Update statistics
            self.successful_requests += 1
            # Some OpenAI-compatible servers omit usage; count 0 tokens then
            # instead of failing a request that actually succeeded.
            usage = getattr(response, 'usage', None)
            if usage is not None:
                self.total_tokens_generated += getattr(usage, 'completion_tokens', None) or 0
            self.total_reasoning_time += time.perf_counter() - start_time

            return result

        except Exception as e:
            raise RuntimeError(f"Generic API reasoning failed: {e}") from e