"""
Base reasoning scaffold class that defines the interface for all reasoning scaffolds.

This module provides the abstract base class that all reasoning scaffolds must inherit from,
along with common data structures and utilities.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Any, Union
from PIL import Image
import json
import time


@dataclass
class ReasoningResult:
    """
    Outcome of a single reasoning session run through a scaffold.

    Bundles the final answer with the intermediate reasoning trace, the
    outputs attributed to the VLM and to the reasoner, and bookkeeping
    metadata about the run.
    """
    # --- Required core fields ---
    final_answer: str
    reasoning_steps: List[str]
    success: bool

    # --- Outputs of the VLM (the tunable component) ---
    vlm_initial_response: Optional[str] = None
    vlm_intermediate_responses: Optional[List[str]] = None

    # --- Outputs of the reasoner (the frozen component) ---
    reasoner_analysis: Optional[str] = None
    reasoner_steps: Optional[List[str]] = None

    # --- Run metadata ---
    scaffold_type: str = "unknown"
    total_iterations: int = 1
    processing_time: float = 0.0

    # --- Diagnostics ---
    debug_info: Optional[Dict[str, Any]] = None
    error_message: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Return a plain-dictionary view of this result for serialization.

        Keys appear in field-declaration order. Values are the attribute
        objects themselves; nothing is copied.

        Returns:
            Mapping from field name to the current attribute value.
        """
        serialized_names = (
            'final_answer',
            'reasoning_steps',
            'success',
            'vlm_initial_response',
            'vlm_intermediate_responses',
            'reasoner_analysis',
            'reasoner_steps',
            'scaffold_type',
            'total_iterations',
            'processing_time',
            'debug_info',
            'error_message',
        )
        return {name: getattr(self, name) for name in serialized_names}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ReasoningResult':
        """
        Build an instance from a dictionary such as one made by ``to_dict``.

        Every entry of ``data`` is forwarded to the constructor as a keyword
        argument, so unknown keys or missing required keys make the
        constructor raise ``TypeError``.

        Args:
            data: Mapping of field names to values.

        Returns:
            A new instance populated from ``data``.
        """
        restored = cls(**data)
        return restored


class BaseReasoningScaffold(ABC):
    """
    Abstract base class for all reasoning scaffolds.

    This class defines the interface that all reasoning scaffolds must implement.
    The key design principle is the separation of tunable VLM components from
    frozen reasoner components.

    Architecture:
        VLM (tunable) → processes image/question → generates analysis
        Reasoner (frozen) → performs logical reasoning → produces final answer
    """

    def __init__(
        self,
        vlm_interface: 'VLMInterface',
        reasoner_interface: 'ReasonerInterface',
        scaffold_name: str = "base_scaffold"
    ):
        """
        Initialize the reasoning scaffold.

        Args:
            vlm_interface: Interface to the vision-language model (tunable)
            reasoner_interface: Interface to the reasoning model (frozen)
            scaffold_name: Name identifier for this scaffold type
        """
        self.vlm = vlm_interface
        self.reasoner = reasoner_interface
        self.scaffold_name = scaffold_name

        # Usage counters. This base class only initializes and resets them;
        # presumably concrete scaffolds increment them inside reason().
        self.total_queries = 0
        self.successful_queries = 0

    @abstractmethod
    def reason(
        self,
        image: Union[str, Image.Image],
        question: str,
        **kwargs
    ) -> ReasoningResult:
        """
        Perform reasoning on the given image and question.

        This is the main entry point for the reasoning process. Each scaffold
        implements its own reasoning strategy while maintaining the same interface.

        Args:
            image: Input image (path or PIL Image)
            question: Question to answer about the image
            **kwargs: Additional scaffold-specific parameters

        Returns:
            ReasoningResult containing the complete reasoning trace
        """
        pass

    def batch_reason(
        self,
        inputs: List[Dict[str, Any]],
        **kwargs
    ) -> List[ReasoningResult]:
        """
        Process a batch of reasoning requests.

        Default implementation processes each item sequentially.
        Subclasses can override for parallel processing.

        A failure on one item never aborts the batch: any ``Exception``
        raised while handling an item (including a missing ``'image'`` or
        ``'question'`` key) is converted into a failed ``ReasoningResult``
        placed at the same position as the offending input.

        Args:
            inputs: List of dicts with 'image', 'question' keys
            **kwargs: Additional parameters forwarded to ``reason``

        Returns:
            List of ReasoningResult objects, one per input, in input order.
            Error results have ``success=False``, an empty answer, a
            non-empty ``error_message``, the elapsed ``processing_time`` and
            ``debug_info`` holding the exception type and batch index.
        """
        results: List[ReasoningResult] = []
        for index, item in enumerate(inputs):
            started = time.perf_counter()
            try:
                result = self.reason(
                    image=item['image'],
                    question=item['question'],
                    **kwargs
                )
            except Exception as e:
                # Best-effort boundary: record the failure and keep going.
                exception_type = type(e).__name__
                result = ReasoningResult(
                    final_answer="",
                    reasoning_steps=[],
                    success=False,
                    scaffold_type=self.scaffold_name,
                    processing_time=time.perf_counter() - started,
                    debug_info={
                        'exception_type': exception_type,
                        'batch_index': index,
                    },
                    # str(e) is empty for exceptions raised without a
                    # message; fall back to the type name so the error
                    # message is never blank.
                    error_message=str(e) or exception_type,
                )
            results.append(result)

        return results

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get scaffold usage statistics.

        Returns:
            Dict with the scaffold name, the query counters, the derived
            ``success_rate`` (0.0 when no queries were recorded), and the
            statistics reported by the VLM and reasoner interfaces.
        """
        success_rate = (
            self.successful_queries / self.total_queries
            if self.total_queries > 0 else 0.0
        )

        return {
            'scaffold_name': self.scaffold_name,
            'total_queries': self.total_queries,
            'successful_queries': self.successful_queries,
            'success_rate': success_rate,
            'vlm_stats': self.vlm.get_statistics(),
            'reasoner_stats': self.reasoner.get_statistics(),
        }

    def reset_statistics(self) -> None:
        """Zero the scaffold counters and reset VLM and reasoner statistics."""
        self.total_queries = 0
        self.successful_queries = 0
        self.vlm.reset_statistics()
        self.reasoner.reset_statistics()

    def get_trainable_parameters(self):
        """
        Get parameters that should be trained (VLM only).

        This is the key method that defines which parameters are tunable.
        Only the VLM interface is consulted; the reasoner is never queried.

        Returns:
            Whatever ``self.vlm.get_trainable_parameters()`` returns
            (presumably an iterable of parameters; confirm against the
            VLM interface).
        """
        return self.vlm.get_trainable_parameters()

    def freeze_reasoner(self) -> None:
        """Ask the reasoner interface to freeze its parameters."""
        self.reasoner.freeze_parameters()

    def unfreeze_vlm(self) -> None:
        """Ask the VLM interface to make its parameters trainable."""
        self.vlm.unfreeze_parameters()

    def save_config(self, config_path: str) -> None:
        """
        Save scaffold configuration as indented JSON.

        The file contains the scaffold name plus the configurations returned
        by the VLM and reasoner interfaces, which must be JSON-serializable.

        Args:
            config_path: Destination file path; an existing file is overwritten.
        """
        config = {
            'scaffold_name': self.scaffold_name,
            'vlm_config': self.vlm.get_config(),
            'reasoner_config': self.reasoner.get_config(),
        }

        # Explicit encoding so the file does not depend on the platform default.
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

    @classmethod
    def from_config(cls, config_path: str) -> 'BaseReasoningScaffold':
        """
        Load scaffold from configuration file.

        The base class cannot build the VLM and reasoner interfaces, so
        concrete scaffolds must override this method.

        Args:
            config_path: Path to a configuration file (unused here).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # Raise before touching the filesystem: reading a file whose content
        # would be discarded lets a missing or malformed file mask the fact
        # that loading is simply not implemented at this level.
        raise NotImplementedError("Subclasses must implement from_config")

    def __str__(self) -> str:
        """Return the concrete class name with its VLM and reasoner."""
        return f"{self.__class__.__name__}(vlm={self.vlm}, reasoner={self.reasoner})"

    def __repr__(self) -> str:
        """Return the same text as ``__str__``."""
        return self.__str__()