"""
Utility for parsing JSON from LLM responses with automatic retries and validation.
"""

import json
import logging
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar

from llm_utils.openai_api.chat import Chat
from llm_utils.openai_api.text_message_content import TextMessageContent
from llm_utils.textgen_api.textgen_api import TextGenApi
from python_utils.string_utils import get_markup_from_text

logger = logging.getLogger(__name__)

T = TypeVar('T')


class LLMJsonParser:
    """
    A utility class for parsing JSON from LLM responses with automatic retries,
    validation, and error handling.

    Each attempt sends the current chat to the text-generation API, appends the
    assistant reply to the chat, extracts a JSON body from the reply text and
    decodes it. When decoding, validation or processing fails with a
    ``ValueError`` (``json.JSONDecodeError`` is a subclass), the error is fed
    back to the model as a user message and the call is retried.
    """

    def __init__(self, textgen_api: TextGenApi):
        # API client used to produce every assistant response.
        self.textgen_api = textgen_api

    def parse_json_response(
        self,
        chat: Chat,
        call_id: str,
        validator: Optional[Callable[[Dict[str, Any]], bool]] = None,
        processor: Optional[Callable[[Dict[str, Any]], T]] = None,
        connection: Optional[str] = None,
        max_retries: int = 5,
        chat_file: Optional[Path] = None,
        custom_error_message: Optional[str] = None,
    ) -> Tuple[T, Chat]:
        """
        Parse JSON from LLM response with automatic retries on failure.

        Args:
            chat: The chat context sent to the model.
            call_id: Identifier forwarded to ``TextGenApi.do_call``.
            validator: Optional predicate on the decoded JSON; a falsy result
                counts as a failed attempt.
            processor: Optional function transforming the decoded (and
                validated) JSON into the returned result. A ``ValueError``
                raised here counts as a failed attempt; any other exception
                propagates immediately without retrying.
            connection: Optional connection ID, forwarded as ``connection_id``.
            max_retries: Total number of model calls allowed (including the
                first one). With a value below 1 no call is made.
            chat_file: Optional file overwritten with the full chat after every
                model response.
            custom_error_message: Message sent back to the model after a failed
                attempt; the concrete error text is appended to it.

        Returns:
            Tuple of (processed_data, updated_chat). ``processed_data`` is the
            processor's result, or the decoded JSON when no processor is given.
            ``updated_chat`` includes every exchanged message, including failed
            attempts and the error feedback sent after them.

        Raises:
            AssertionError: If a response does not consist of exactly one text
                content part, or if no valid JSON response is obtained within
                ``max_retries`` attempts (the last parse/validation error is
                chained as ``__cause__``).
        """
        error_msg = custom_error_message or "Failed to parse JSON. Please provide a valid JSON response."
        last_error: Optional[ValueError] = None

        for attempt in range(max_retries):
            response = self.textgen_api.do_call(chat, connection_id=connection, call_id=call_id)
            chat = chat.add_message(response)

            if chat_file:
                chat_file.write_text(str(chat))

            content = self._get_response_text(response)
            json_body = self._extract_json_from_content(content)

            try:
                data = json.loads(json_body)

                # Validate if validator is provided
                if validator and not validator(data):
                    raise ValueError("JSON validation failed")

                # Process if processor is provided
                result = processor(data) if processor else data
            except ValueError as e:  # json.JSONDecodeError is a subclass of ValueError
                last_error = e
                logger.error("Failed to process JSON (attempt %d/%d): %s", attempt + 1, max_retries, e)
                # Only feed the error back to the model when another attempt follows.
                if attempt < max_retries - 1:
                    chat = chat.add_user_text(f"{error_msg}\nError: {e}")
            else:
                return result, chat

        # Chain the last failure so callers and tracebacks see the actual cause.
        raise AssertionError(
            f"Failed to get valid JSON response after {max_retries} attempts."
        ) from last_error

    @staticmethod
    def _get_response_text(response: Any) -> str:
        """Return the text of a response that must contain exactly one text part."""
        # Explicit raises instead of ``assert`` so the checks still run under ``python -O``.
        if len(response.content) != 1:
            raise AssertionError(
                f"Expected exactly one content part in LLM response, got {len(response.content)}."
            )
        part = response.content[0]
        if not isinstance(part, TextMessageContent):
            raise AssertionError(
                f"Expected TextMessageContent in LLM response, got {type(part).__name__}."
            )
        return part.text

    def _extract_json_from_content(self, content: str) -> str:
        """
        Extract the JSON body from a response text, trying markup extraction first.

        If the text mentions "json" (case-insensitive), ``get_markup_from_text``
        is asked for ``json`` markup; when it yields exactly one body, that body
        is returned. Otherwise (no mention, or zero/several matches) the whole
        text is returned unchanged and left to ``json.loads``.
        """
        # NOTE(review): presumably get_markup_from_text returns the bodies of
        # ```json fenced blocks — confirm against python_utils.string_utils.
        if "json" in content.lower():
            json_bodies = get_markup_from_text(text=content, markup=["json"])
            if len(json_bodies) == 1:
                return json_bodies[0]
        return content


# Convenience function for one-off usage
def parse_json_from_llm(
    textgen_api: TextGenApi,
    chat: Chat,
    call_id: str,
    validator: Optional[Callable[[Dict[str, Any]], bool]] = None,
    processor: Optional[Callable[[Dict[str, Any]], T]] = None,
    **kwargs
) -> Tuple[T, Chat]:
    """
    Parse JSON from an LLM reply in a single call, without keeping a parser around.

    A throwaway ``LLMJsonParser`` is built around ``textgen_api`` and its
    ``parse_json_response`` is invoked with the given chat, call id, validator
    and processor. Any extra keyword arguments (e.g. ``connection``,
    ``max_retries``, ``chat_file``, ``custom_error_message``) are forwarded
    unchanged.

    Returns:
        Whatever ``parse_json_response`` returns: (processed_data, updated_chat).
    """
    one_shot_parser = LLMJsonParser(textgen_api)
    # The first four parameters of parse_json_response are chat, call_id,
    # validator, processor — in this order.
    return one_shot_parser.parse_json_response(chat, call_id, validator, processor, **kwargs)
