"""
Vision-Language Model interface for the tunable component of reasoning scaffolds.

This module provides the interface to VLMs that can be trained. It supports both
vLLM API-based models (for inference) and HuggingFace models (for training).
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional, Any, Union, Iterator
from PIL import Image
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
from openai import OpenAI
import base64
import io
import os
import json


@dataclass
class VLMConfig:
    """Settings bundle consumed by the VLM interface classes.

    Only ``model_name`` is mandatory; every other field carries a default.
    Generation parameters left at ``None`` are meant to defer to the
    model/server defaults.
    """
    # --- Model identification ---
    model_name: str
    model_type: str = "vllm"  # backend selector: "vllm", "huggingface" or "openai"

    # --- API access (vLLM / OpenAI backends) ---
    api_base: Optional[str] = None
    api_key: str = "EMPTY"
    timeout: int = 300  # handed to the API client as its timeout

    # --- Local checkpoint location (HuggingFace backend) ---
    model_path: Optional[str] = None

    # --- Generation parameters (None -> rely on model defaults / context) ---
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = -1  # Match legacy

    # --- Training: which sub-modules stay frozen ---
    freeze_vision_tower: bool = False
    freeze_language_model: bool = False

    # --- Device placement and numeric precision ---
    device: str = "auto"
    torch_dtype: str = "bfloat16"

    def to_dict(self) -> Dict[str, Any]:
        """Return every field of this configuration as a plain dictionary."""
        field_values = asdict(self)
        return field_values

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'VLMConfig':
        """Build a configuration from a mapping of field names to values.

        Args:
            data: Keyword arguments for the dataclass constructor.

        Returns:
            A new ``VLMConfig`` populated from ``data``.
        """
        instance = cls(**data)
        return instance


class VLMInterface(ABC):
    """
    Common base class for every Vision-Language Model backend.

    Concrete subclasses either talk to a remote API (inference) or wrap a
    HuggingFace model (training). The base class owns the shared state:
    the configuration, optional model/tokenizer/processor handles, and a set
    of usage counters. The focus is VLM-only tuning while reasoner components
    stay frozen.
    """

    def __init__(self, config: VLMConfig):
        """
        Store the configuration and reset handles and counters.

        Args:
            config: VLM configuration
        """
        self.config = config

        # Backend handles; they stay None unless a subclass assigns them.
        self.model = None
        self.tokenizer = None
        self.processor = None

        # Usage counters reported by get_statistics().
        self.total_requests = 0
        self.successful_requests = 0
        self.total_tokens_generated = 0

    @abstractmethod
    def generate(
        self,
        image: Union[str, Image.Image],
        prompt: str,
        **kwargs
    ) -> str:
        """
        Produce a text response for an image/prompt pair.

        Args:
            image: Input image (path or PIL Image)
            prompt: Text prompt
            **kwargs: Additional generation parameters

        Returns:
            Generated response text
        """

    def generate_with_logprobs(
        self,
        image: Union[str, Image.Image],
        prompt: str,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Produce a response together with token log-probabilities.

        The base implementation delegates to ``generate`` and reports no
        log-probabilities; backends that can supply them override this.

        Args:
            image: Input image (path or PIL Image)
            prompt: Text prompt
            **kwargs: Additional generation parameters

        Returns:
            Dict containing 'text' and 'logprobs' keys
        """
        return {'text': self.generate(image, prompt, **kwargs), 'logprobs': None}

    @abstractmethod
    def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
        """
        Expose the parameters that VLM-only tuning should optimise.

        Returns:
            Iterator of trainable parameters
        """

    def freeze_parameters(self):
        """Mark every model parameter as non-trainable (no-op without a model)."""
        if self.model is not None:
            for weight in self.model.parameters():
                weight.requires_grad = False

    def unfreeze_parameters(self):
        """Make parameters trainable, then re-freeze what the config asks for.

        Does nothing when no model is loaded. A sub-module is re-frozen only
        if its config flag is set and the model actually has that attribute.
        """
        model = self.model
        if model is None:
            return

        # Start from a fully trainable model.
        for weight in model.parameters():
            weight.requires_grad = True

        # (attribute on the model, config flag requesting it stay frozen)
        freeze_requests = (
            ('vision_tower', self.config.freeze_vision_tower),
            ('language_model', self.config.freeze_language_model),
        )
        for attr_name, keep_frozen in freeze_requests:
            if keep_frozen and hasattr(model, attr_name):
                for weight in getattr(model, attr_name).parameters():
                    weight.requires_grad = False

    def get_statistics(self) -> Dict[str, Any]:
        """Report the usage counters plus the derived success rate.

        The success rate is 0.0 while no request has been recorded.
        """
        attempted = self.total_requests
        if attempted > 0:
            success_rate = self.successful_requests / attempted
        else:
            success_rate = 0.0

        return {
            'total_requests': attempted,
            'successful_requests': self.successful_requests,
            'success_rate': success_rate,
            'total_tokens_generated': self.total_tokens_generated,
            'model_name': self.config.model_name,
            'model_type': self.config.model_type,
        }

    def reset_statistics(self):
        """Zero all usage counters."""
        self.total_requests = 0
        self.successful_requests = 0
        self.total_tokens_generated = 0

    def get_config(self) -> Dict[str, Any]:
        """Return the configuration as a dictionary."""
        return self.config.to_dict()

    @staticmethod
    def create(config: VLMConfig) -> 'VLMInterface':
        """
        Instantiate the backend selected by ``config.model_type``.

        Args:
            config: VLM configuration

        Returns:
            Concrete VLM interface implementation

        Raises:
            ValueError: If ``config.model_type`` names no known backend.
        """
        backend = config.model_type
        if backend == "vllm":
            return VLLMInterface(config)
        if backend == "huggingface":
            return HuggingFaceVLMInterface(config)
        if backend == "openai":
            return OpenAIVLMInterface(config)
        raise ValueError(f"Unknown VLM type: {config.model_type}")


class VLLMInterface(VLMInterface):
    """VLM interface for models served behind a vLLM OpenAI-compatible API.

    Requests are sent through an ``OpenAI`` client pointed at
    ``config.api_base``. This backend is inference-only and exposes no
    trainable parameters.
    """

    # Progressively smaller bounding boxes tried when the server rejects a
    # request because the context length was exceeded.
    _RETRY_IMAGE_SIZES = ((1024, 1024), (768, 768), (512, 512), (384, 384))

    def __init__(self, config: VLMConfig):
        """Create the API client.

        Args:
            config: VLM configuration; ``api_base`` must be set.

        Raises:
            ValueError: If ``config.api_base`` is empty.
        """
        super().__init__(config)

        if not config.api_base:
            raise ValueError("api_base required for vLLM interface")

        self.client = OpenAI(
            api_key=config.api_key,
            base_url=config.api_base,
            timeout=config.timeout,
            max_retries=0  # Disable automatic retries to avoid duplicate requests
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_remote_url(image: Any) -> bool:
        """Return True when ``image`` is an http(s) URL rather than a local path."""
        return isinstance(image, str) and image.startswith(('http://', 'https://'))

    @staticmethod
    def _is_context_length_error(message: str) -> bool:
        """Return True when an error message looks like a context-length rejection.

        The "8192" check is inherited from the legacy captioner logic.
        """
        lowered = message.lower()
        return (
            "maximum model length" in lowered
            or "too long" in lowered
            or "8192" in message
        )

    def _encode_image(self, image: Union[str, Image.Image]) -> str:
        """Turn an image argument into a URL usable in an ``image_url`` part.

        Remote URLs are passed through. Local files are embedded as base64
        data URLs labelled with the MIME type guessed from the file name
        (falling back to ``image/jpeg``). PIL images are encoded as JPEG.
        """
        if isinstance(image, str):
            if self._is_remote_url(image):
                return image

            import mimetypes

            with open(image, 'rb') as f:
                image_data = base64.b64encode(f.read()).decode('utf-8')
            mime_type, _ = mimetypes.guess_type(image)
            if not mime_type or not mime_type.startswith('image/'):
                mime_type = 'image/jpeg'
            return f"data:{mime_type};base64,{image_data}"

        buffer = io.BytesIO()
        try:
            image.save(buffer, format='JPEG')
        except OSError:
            # Modes such as RGBA or P cannot be written as JPEG; retry on an
            # RGB copy with a fresh buffer (the failed save may have written
            # partial data).
            buffer = io.BytesIO()
            image.convert('RGB').save(buffer, format='JPEG')
        image_data = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return f"data:image/jpeg;base64,{image_data}"

    def _build_gen_params(self, overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Merge config defaults with per-call overrides into client kwargs.

        ``None`` values (from the config or from ``overrides``) are omitted so
        the server default applies. ``top_k`` is not a parameter of the OpenAI
        client, so it travels in ``extra_body``; a per-call ``top_k`` takes
        precedence over the configured one, and a caller-supplied
        ``extra_body`` is merged rather than replacing it.
        """
        params: Dict[str, Any] = {}
        for name in ('temperature', 'top_p', 'max_tokens'):
            value = getattr(self.config, name)
            if value is not None:
                params[name] = value

        supplied = {k: v for k, v in overrides.items() if v is not None}

        extra_body: Dict[str, Any] = {}
        if self.config.top_k is not None:
            extra_body['top_k'] = self.config.top_k
        extra_body.update(supplied.pop('extra_body', {}))
        if 'top_k' in supplied:
            extra_body['top_k'] = supplied.pop('top_k')

        params.update(supplied)
        if extra_body:
            params['extra_body'] = extra_body
        return params

    def _chat(self, prompt: str, image_url: str, gen_params: Dict[str, Any]):
        """Send one single-turn user message (text + image) and return the raw response."""
        return self.client.chat.completions.create(
            model=self.config.model_name,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": image_url}}
                    ]
                }
            ],
            **gen_params
        )

    def _record_success(self, response) -> None:
        """Update the success and token counters from a completed response."""
        self.successful_requests += 1
        usage = getattr(response, 'usage', None)
        # Some OpenAI-compatible servers may omit usage; don't fail a
        # successful request over missing bookkeeping data.
        if usage is not None and usage.completion_tokens is not None:
            self.total_tokens_generated += usage.completion_tokens

    def _complete(self, prompt: str, image_url: str, gen_params: Dict[str, Any]) -> str:
        """Run one request, record it as successful and return the reply text."""
        response = self._chat(prompt, image_url, gen_params)
        result = response.choices[0].message.content
        self._record_success(response)
        return result

    def _retry_with_reduced_context(
        self,
        image: Union[str, Image.Image],
        prompt: str,
        gen_params: Dict[str, Any]
    ) -> str:
        """Retry a context-length failure with smaller images, then shorter prompts.

        Order of attempts: the original prompt with each size in
        ``_RETRY_IMAGE_SIZES``; a truncated prompt with the smallest image;
        finally a fixed generic prompt with the smallest image.

        Raises:
            RuntimeError: If the final attempt fails as well.
        """
        # Remote URLs cannot be resized locally, so re-sending them at each
        # "size" would just repeat the identical failing request.
        if not self._is_remote_url(image):
            last_index = len(self._RETRY_IMAGE_SIZES) - 1
            for i, max_size in enumerate(self._RETRY_IMAGE_SIZES):
                try:
                    print(f"Retrying with image size limit: {max_size}")
                    resized_image_url = self._resize_image_for_api(image, max_size)
                    return self._complete(prompt, resized_image_url, gen_params)
                except Exception as resize_error:
                    if i == last_index:
                        # Last resize attempt failed, try shorter prompt
                        print("All image resize attempts failed, trying shorter prompt")
                        break
                    if self._is_context_length_error(str(resize_error)):
                        # Still too long: try the next smaller size
                        continue
                    # Different error, propagate it
                    raise

        # If image resizing didn't work, try shorter prompts
        if len(prompt) > 3000:
            shortened_prompt = prompt[:2500] + "\n\nPlease provide a detailed description of the key visual elements relevant to answering the question above."
        else:
            shortened_prompt = prompt[:2000] + "\n\nPlease provide a focused description based on the question above."

        # Computed outside the try so the final fallback can always reuse it.
        smallest_image_url = self._resize_image_for_api(image, (384, 384))

        try:
            print(f"Trying with shorter prompt ({len(shortened_prompt)} chars) and smallest image size")
            return self._complete(shortened_prompt, smallest_image_url, gen_params)
        except Exception:
            # Final attempt with very short prompt
            very_short_prompt = "Describe this image briefly."
            try:
                print("Final attempt with very short prompt and smallest image")
                return self._complete(very_short_prompt, smallest_image_url, gen_params)
            except Exception as e3:
                raise RuntimeError(
                    f"VLM API call failed even with shortest prompt and smallest image: {e3}"
                ) from e3

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def generate(
        self,
        image: Union[str, Image.Image],
        prompt: str,
        **kwargs
    ) -> str:
        """Generate a response through the vLLM API.

        If the server rejects the request for exceeding the context length,
        the call is retried with smaller images and then shorter prompts.

        Args:
            image: http(s) URL, local file path, or PIL image.
            prompt: Text prompt.
            **kwargs: Per-call generation parameters; ``None`` values are
                ignored and ``top_k`` is routed through ``extra_body``.

        Returns:
            The text content of the first choice.

        Raises:
            RuntimeError: If the request (including all fallbacks) fails.
        """
        try:
            self.total_requests += 1

            image_url = self._encode_image(image)
            gen_params = self._build_gen_params(kwargs)

            # Try the original request first
            try:
                return self._complete(prompt, image_url, gen_params)
            except Exception as e:
                error_str = str(e)
                if not self._is_context_length_error(error_str):
                    raise RuntimeError(f"VLM API call failed: {e}") from e

                print(f"Warning: Context length exceeded. Error: {error_str}")
                return self._retry_with_reduced_context(image, prompt, gen_params)

        except Exception as e:
            raise RuntimeError(f"VLM generation failed: {e}") from e

    def generate_with_logprobs(
        self,
        image: Union[str, Image.Image],
        prompt: str,
        **kwargs
    ) -> Dict[str, Any]:
        """Generate a response through the vLLM API and return its logprobs.

        ``logprobs`` is always enabled; ``top_logprobs`` defaults to 10 unless
        supplied. Unlike ``generate``, no context-length retries are made.

        Args:
            image: http(s) URL, local file path, or PIL image.
            prompt: Text prompt.
            **kwargs: Per-call generation parameters; ``None`` values are
                ignored and ``top_k`` is routed through ``extra_body``.

        Returns:
            Dict with 'text' (reply text) and 'logprobs' (the first choice's
            ``logprobs`` object as returned by the client).

        Raises:
            RuntimeError: If the request fails.
        """
        try:
            self.total_requests += 1

            image_url = self._encode_image(image)
            gen_params = self._build_gen_params(kwargs)

            # Force logprobs to be enabled
            gen_params['logprobs'] = True
            gen_params.setdefault('top_logprobs', 10)

            response = self._chat(prompt, image_url, gen_params)

            result_text = response.choices[0].message.content
            logprobs_data = response.choices[0].logprobs
            self._record_success(response)

            return {
                'text': result_text,
                'logprobs': logprobs_data
            }

        except Exception as e:
            raise RuntimeError(f"VLM generation with logprobs failed: {e}") from e

    def _resize_image_for_api(self, image: Union[str, Image.Image], max_size: tuple) -> str:
        """Shrink an image to fit ``max_size`` and return it as a JPEG data URL.

        The aspect ratio is preserved. The caller's PIL image is never
        modified: resizing happens on an RGB copy. Remote URLs cannot be
        resized and are returned unchanged.

        Args:
            image: http(s) URL, local file path, or PIL image.
            max_size: (width, height) bounding box.

        Returns:
            A base64 JPEG data URL, or the original URL for remote images.

        Raises:
            ValueError: If the image cannot be loaded, resized or encoded.
        """
        try:
            if isinstance(image, str):
                if self._is_remote_url(image):
                    # For HTTP URLs, we can't resize, so just return the URL
                    return image
                # convert() loads the pixel data, so the file can be closed
                # as soon as the with-block ends.
                with Image.open(image) as source:
                    pil_image = source.convert('RGB')
            else:
                # convert() returns a new image, so the in-place thumbnail()
                # below leaves the caller's object untouched.
                pil_image = image.convert('RGB')

            # Resize maintaining aspect ratio
            pil_image.thumbnail(max_size, Image.Resampling.LANCZOS)

            # Convert to base64
            buffer = io.BytesIO()
            pil_image.save(buffer, format='JPEG', quality=85, optimize=True)
            image_data = base64.b64encode(buffer.getvalue()).decode('utf-8')

            return f"data:image/jpeg;base64,{image_data}"

        except Exception as e:
            raise ValueError(f"Failed to resize image: {e}") from e

    def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
        """vLLM models are not directly trainable; returns an empty iterator."""
        return iter([])


class HuggingFaceVLMInterface(VLMInterface):
    """VLM interface for HuggingFace models (trainable).

    The model, tokenizer and (when available) processor are loaded in the
    constructor, and the freezing options from the configuration are applied
    immediately afterwards.
    """

    def __init__(self, config: VLMConfig):
        """Load the model described by ``config``.

        Args:
            config: VLM configuration. When ``model_path`` is empty it is set
                to ``model_name`` on the passed-in config object.

        Raises:
            RuntimeError: If loading fails.
        """
        super().__init__(config)

        if not config.model_path:
            config.model_path = config.model_name

        # Load model and tokenizer
        self._load_model()

    def _load_model(self):
        """Load tokenizer, optional processor and model, then apply freezing.

        ``config.torch_dtype`` values other than "bfloat16" and "float16"
        fall back to float32.

        Raises:
            RuntimeError: If any required component fails to load.
        """
        try:
            dtype_by_name = {
                "bfloat16": torch.bfloat16,
                "float16": torch.float16,
            }
            torch_dtype = dtype_by_name.get(self.config.torch_dtype, torch.float32)

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.model_path,
                trust_remote_code=True
            )

            # The processor is optional: without one, generate() falls back
            # to the tokenizer. Catch Exception (not a bare except) so that
            # KeyboardInterrupt/SystemExit still propagate.
            try:
                self.processor = AutoProcessor.from_pretrained(
                    self.config.model_path,
                    trust_remote_code=True
                )
            except Exception:
                self.processor = None

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config.model_path,
                torch_dtype=torch_dtype,
                device_map=self.config.device,
                trust_remote_code=True
            )

            # Apply freezing configuration
            self.unfreeze_parameters()

        except Exception as e:
            raise RuntimeError(f"Failed to load HuggingFace model: {e}") from e

    def _build_generation_kwargs(self, overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Translate config defaults and per-call overrides into ``generate`` kwargs.

        * ``max_tokens`` (the name used by the config and the API backends)
          is sent as ``max_new_tokens``, the name ``model.generate`` expects.
          An explicit ``max_new_tokens`` override wins over ``max_tokens``.
        * ``None`` values are dropped so the model's own generation defaults
          apply.
        * A non-positive ``top_k`` (the config default is -1) is dropped,
          since HuggingFace top-k sampling requires a positive integer.
        """
        params: Dict[str, Any] = {
            'temperature': self.config.temperature,
            'top_p': self.config.top_p,
            'max_new_tokens': self.config.max_tokens,
            'top_k': self.config.top_k,
        }

        supplied = {k: v for k, v in overrides.items() if v is not None}
        if 'max_tokens' in supplied:
            token_budget = supplied.pop('max_tokens')
            supplied.setdefault('max_new_tokens', token_budget)
        params.update(supplied)

        params = {k: v for k, v in params.items() if v is not None}

        top_k = params.get('top_k')
        if isinstance(top_k, int) and top_k <= 0:
            del params['top_k']
        return params

    def generate(
        self,
        image: Union[str, Image.Image],
        prompt: str,
        **kwargs
    ) -> str:
        """Generate a response with the local HuggingFace model.

        Args:
            image: Local file path or PIL image.
            prompt: Text prompt.
            **kwargs: Per-call generation parameters forwarded to
                ``model.generate``; ``max_tokens`` is accepted as an alias
                for ``max_new_tokens`` and ``None`` values are ignored.

        Returns:
            The decoded continuation, without the prompt tokens.

        Raises:
            RuntimeError: If input preparation, generation or decoding fails.
        """
        try:
            self.total_requests += 1

            # Prepare image
            if isinstance(image, str):
                image = Image.open(image)

            # Prepare inputs
            if self.processor:
                inputs = self.processor(
                    text=prompt,
                    images=image,
                    return_tensors="pt"
                ).to(self.model.device)
            else:
                # Fallback: use tokenizer only (the image is not passed to
                # the model on this path).
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt"
                ).to(self.model.device)

            # NOTE(review): temperature/top_p/top_k presumably only take
            # effect when sampling is enabled (e.g. do_sample=True passed via
            # kwargs) - confirm callers do so if sampling is intended.
            gen_params = self._build_generation_kwargs(kwargs)

            # Generate
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    **gen_params
                )

            # Keep only the tokens produced after the prompt
            prompt_length = inputs['input_ids'].shape[1]
            generated_tokens = outputs[0][prompt_length:]
            response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

            self.successful_requests += 1
            self.total_tokens_generated += len(generated_tokens)

            return response

        except Exception as e:
            raise RuntimeError(f"VLM generation failed: {e}") from e

    def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
        """Yield the model parameters whose ``requires_grad`` flag is set.

        Returns:
            An empty iterator when no model is loaded.
        """
        if self.model is None:
            return iter([])

        return (p for p in self.model.parameters() if p.requires_grad)


class OpenAIVLMInterface(VLMInterface):
    """VLM interface for OpenAI API models (GPT-4V, etc.).

    Inference-only backend; it exposes no trainable parameters.
    """

    def __init__(self, config: VLMConfig):
        """Create the OpenAI client.

        The API key comes from ``config.api_key``; when that is empty or the
        placeholder "EMPTY", the ``OPENAI_API_KEY`` environment variable is
        used instead.

        Args:
            config: VLM configuration.

        Raises:
            ValueError: If no API key can be found.
        """
        super().__init__(config)

        # Set up OpenAI client
        api_key = config.api_key
        if api_key == "EMPTY" or not api_key:
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                raise ValueError("OpenAI API key required")

        self.client = OpenAI(
            api_key=api_key,
            base_url=config.api_base,
            timeout=config.timeout,
            max_retries=0  # Disable automatic retries to avoid duplicate requests
        )

    def _encode_image(self, image: Union[str, Image.Image]) -> str:
        """Turn an image argument into a URL usable in an ``image_url`` part.

        Remote http(s) URLs are passed through. Local files are embedded as
        base64 data URLs labelled with the MIME type guessed from the file
        name (falling back to ``image/jpeg``). PIL images are encoded as JPEG.
        """
        if isinstance(image, str):
            if image.startswith(('http://', 'https://')):
                return image

            import mimetypes

            with open(image, 'rb') as f:
                image_data = base64.b64encode(f.read()).decode('utf-8')
            mime_type, _ = mimetypes.guess_type(image)
            if not mime_type or not mime_type.startswith('image/'):
                mime_type = 'image/jpeg'
            return f"data:{mime_type};base64,{image_data}"

        buffer = io.BytesIO()
        try:
            image.save(buffer, format='JPEG')
        except OSError:
            # Modes such as RGBA or P cannot be written as JPEG; retry on an
            # RGB copy with a fresh buffer (the failed save may have written
            # partial data).
            buffer = io.BytesIO()
            image.convert('RGB').save(buffer, format='JPEG')
        image_data = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return f"data:image/jpeg;base64,{image_data}"

    def generate(
        self,
        image: Union[str, Image.Image],
        prompt: str,
        **kwargs
    ) -> str:
        """Generate a response through the OpenAI chat completions API.

        Args:
            image: http(s) URL, local file path, or PIL image.
            prompt: Text prompt.
            **kwargs: Per-call generation parameters; they override the
                configured values, and ``None`` values are ignored.

        Returns:
            The text content of the first choice.

        Raises:
            RuntimeError: If image preparation or the API call fails.
        """
        try:
            self.total_requests += 1

            image_url = self._encode_image(image)

            # Config defaults first, per-call overrides second; None means
            # "not set", so such entries are left out of the request.
            gen_params = {
                'temperature': self.config.temperature,
                'top_p': self.config.top_p,
                'max_tokens': self.config.max_tokens,
            }
            gen_params.update(kwargs)
            gen_params = {k: v for k, v in gen_params.items() if v is not None}

            # Make API call
            response = self.client.chat.completions.create(
                model=self.config.model_name,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": image_url}}
                        ]
                    }
                ],
                **gen_params
            )

            result = response.choices[0].message.content
            self.successful_requests += 1

            # Some OpenAI-compatible servers may omit usage; don't fail a
            # successful request over missing bookkeeping data.
            usage = getattr(response, 'usage', None)
            if usage is not None and usage.completion_tokens is not None:
                self.total_tokens_generated += usage.completion_tokens

            return result

        except Exception as e:
            raise RuntimeError(f"VLM generation failed: {e}") from e

    def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
        """OpenAI models are not directly trainable; returns an empty iterator."""
        return iter([])