import unittest
from unittest.mock import patch, ANY
import logging
from typing import Any

import json
import re
from olym_gen.generator.json_format_generator import JsonFormatOutputMixin
from olym_gen.utils.utils import get_logger
# Module-level logger shared by every test in this file; the base fixture's
# setUp attaches a capture handler to it.
logger = get_logger()

class TestJsonFormatOutputMixin(unittest.TestCase):
    """Shared fixture for the JsonFormatOutputMixin test cases.

    Defines no tests of its own. Subclasses inherit ``setUp``, which provides:

    * ``self.mixin`` -- an instance of a minimal concrete subclass of the mixin;
    * ``self.log_capture`` -- messages emitted through the module ``logger``
      while the current test runs.
    """

    class ConcreteJsonFormatOutputMixin(JsonFormatOutputMixin):
        """Minimal concrete subclass so the mixin can be instantiated in tests."""

        def _check_json_format(self, json_object: Any, other_info: dict[str, Any] | None = None) -> None:
            pass  # Abstract method implementation

    def setUp(self):
        self.mixin = self.ConcreteJsonFormatOutputMixin()
        # Setup logger capture for testing: every record reaching the handler is
        # stored as its formatted message string.
        self.log_capture: list[str] = []
        handler = logging.Handler()
        handler.emit = lambda record: self.log_capture.append(record.getMessage())
        logger.addHandler(handler)
        # `logger` is a single module-level object shared by all tests. Without
        # this cleanup, handlers accumulate across tests and every earlier
        # test's capture list keeps receiving records.
        self.addCleanup(logger.removeHandler, handler)

class TestPreFixEscape(TestJsonFormatOutputMixin):
    """Test cases for the pre_fix_escape method.

    Each case passes a raw, backslash-laden response string to
    ``pre_fix_escape`` (optionally with ``allow_escape`` / ``match_latex``
    options) and compares the result with the exact expected string.
    """

    def _assert_pre_fixed(self, raw, want, **options):
        """Assert that ``pre_fix_escape(raw, **options)`` returns exactly *want*."""
        self.assertEqual(self.mixin.pre_fix_escape(raw, **options), want)

    def test_default_behavior_double_all_backslashes(self):
        """With default settings, backslashes that are not allowed escapes are doubled."""
        self._assert_pre_fixed(
            r'{"text": "\alpha + \beta = \gamma"}',
            r'{"text": "\\alpha + \\beta = \\gamma"}',
        )

    def test_preserve_allowed_escape_sequences(self):
        """Escapes allowed by default (newline, backslash, quote) pass through unchanged."""
        unchanged = r'{"text": "Line with \n newline and \\ backslash and \" quote"}'
        self._assert_pre_fixed(unchanged, unchanged)

    def test_latex_commands_override_allowed_escapes(self):
        """A LaTeX command wins over an allowed escape that shares its prefix."""
        # \not and \neq both begin with the otherwise-allowed \n escape.
        self._assert_pre_fixed(
            r'{"formula": "\not= and \neq symbols"}',
            r'{"formula": "\\not= and \\neq symbols"}',
        )

    def test_custom_allowed_escapes(self):
        """A caller-supplied allow_escape pattern decides which escapes survive."""
        # Only \n and \t are allowed here, so \r and \alpha are doubled.
        self._assert_pre_fixed(
            r'{"text": "\n \t \r \alpha"}',
            r'{"text": "\n \t \\r \\alpha"}',
            allow_escape=re.compile(r'\\[nt]'),
        )

    def test_custom_latex_commands(self):
        """Only commands listed in match_latex override the allowed escapes."""
        # \gamma is covered by allow_escape but not listed in match_latex, so it stays.
        self._assert_pre_fixed(
            r'{"formula": "\alpha + \beta + \gamma"}',
            r'{"formula": "\\alpha + \\beta + \gamma"}',
            allow_escape=re.compile(r'\\[abg]'),
            match_latex=[r'\alpha', r'\beta'],
        )

    def test_no_latex_matching(self):
        """With match_latex=False, allowed escapes are never reinterpreted as LaTeX."""
        unchanged = r'{"formula": "\not= and \neq symbols"}'
        self._assert_pre_fixed(unchanged, unchanged, match_latex=False)

    def test_latex_command_boundary_detection(self):
        """A listed command matches as a whole word, not as the prefix of a longer word."""
        # \neq is doubled, while \nequation is left as-is.
        self._assert_pre_fixed(
            r'{"test": "\neq and \nequation"}',
            r'{"test": "\\neq and \nequation"}',
            match_latex=[r'\neq'],
        )

    def test_longest_latex_command_first(self):
        """When listed commands share a prefix, the longer one is still recognised."""
        # Both \newcommand and the shorter \new end up doubled.
        self._assert_pre_fixed(
            r'{"test": "\newcommand vs \new"}',
            r'{"test": "\\newcommand vs \\new"}',
            match_latex=[r'\new', r'\newcommand'],
        )

    def test_end_latex_command_first(self):
        """A listed command is not matched inside a longer command that starts with it."""
        # Only \new is listed: the standalone \new is doubled, \newcommand is not.
        self._assert_pre_fixed(
            r'{"test": "\newcommand vs \new"}',
            r'{"test": "\newcommand vs \\new"}',
            match_latex=[r'\new'],
        )

    def test_backslash_at_end_of_string(self):
        """A backslash right before the closing quote is kept as an escaped quote."""
        self._assert_pre_fixed(
            r'{"path": "folder\"}',
            r'{"path": "folder\"}',
        )

    def test_consecutive_backslashes(self):
        """An existing double backslash is an allowed escape and is preserved."""
        self._assert_pre_fixed(
            r'{"path": "C:\\folder"}',
            r'{"path": "C:\\folder"}',
        )

    def test_mixed_scenarios(self):
        """Allowed escapes, LaTeX commands and invalid escapes within one string."""
        # \n and the escaped quote stay; \alpha and the invalid \xyz are doubled.
        self._assert_pre_fixed(
            r'{"complex": "Valid \n, LaTeX \alpha, invalid \xyz, quote \" end"}',
            r'{"complex": "Valid \n, LaTeX \\alpha, invalid \\xyz, quote \" end"}',
        )

    def test_empty_string(self):
        """An empty response comes back empty."""
        self._assert_pre_fixed("", "")

    def test_no_backslashes(self):
        """A response without any backslash is returned unchanged."""
        plain = '{"normal": "text without escapes"}'
        self._assert_pre_fixed(plain, plain)

    def test_unicode_with_backslashes(self):
        """Non-ASCII text around the backslashes does not disturb the doubling."""
        self._assert_pre_fixed(
            r'{"unicode": "中文 \alpha 测试 \beta"}',
            r'{"unicode": "中文 \\alpha 测试 \\beta"}',
        )

class TestFixEscape(TestJsonFormatOutputMixin):
    """Test cases for the fix_escape method.

    fix_escape receives a response string together with the text of the
    JSONDecodeError that ``json.loads`` raised for it, and is expected to
    repair the escape that the error message points at.
    """

    def _decode_error_message(self, text):
        """Return ``str()`` of the JSONDecodeError that ``json.loads(text)`` raises.

        Fails the current test if *text* unexpectedly parses as valid JSON.
        """
        try:
            json.loads(text)
        except json.JSONDecodeError as err:
            return str(err)
        self.fail("Should have raised JSONDecodeError")

    def test_with_real_json_decode_error(self):
        """A single invalid escape reported by json.loads gets its backslash doubled."""
        broken = r'{"test": "\invalid"}'
        repaired = self.mixin.fix_escape(broken, self._decode_error_message(broken))
        self.assertEqual(repaired, r'{"test": "\\invalid"}')

    def test_with_real_error_valid_and_invalid_escapes(self):
        """A valid escape is left alone; the reported invalid one is fixed."""
        # \b (backspace) is a legal JSON escape, so the first error json.loads
        # reports is the \a in the second value.
        broken = r'{"test1": "\bad", "test2": "\also_bad"}'
        repaired = self.mixin.fix_escape(broken, self._decode_error_message(broken))
        self.assertEqual(repaired, r'{"test1": "\bad", "test2": "\\also_bad"}')

    def test_with_real_error_multiple_invalid_escapes(self):
        """With several invalid escapes, only the first (the reported one) is fixed."""
        # json.loads reports only the first error position, so \also_bad stays as-is.
        broken = r'{"test1": "\ahabad", "test2": "\also_bad"}'
        repaired = self.mixin.fix_escape(broken, self._decode_error_message(broken))
        self.assertEqual(repaired, r'{"test1": "\\ahabad", "test2": "\also_bad"}')

    def test_with_real_error_preserve_valid_escapes(self):
        """Valid escapes elsewhere in the string survive the repair."""
        # \n is kept; the invalid \x escape is doubled.
        broken = r'{"valid": "\n valid newline", "invalid": "\xyz escape"}'
        repaired = self.mixin.fix_escape(broken, self._decode_error_message(broken))
        self.assertEqual(repaired, r'{"valid": "\n valid newline", "invalid": "\\xyz escape"}')

    def test_malformed_error_message(self):
        """An error message without position information raises ValueError."""
        broken = r'{"test": "\invalid"}'
        bogus_message = "Some random error without position info"
        with self.assertRaises(ValueError):
            self.mixin.fix_escape(broken, bogus_message)

# Integration tests
class TestIntegration(TestJsonFormatOutputMixin):
    """End-to-end checks that repaired responses load as JSON with the original text."""

    def test_pre_fix_then_json_loads_success(self):
        """After pre_fix_escape, json.loads succeeds and recovers the LaTeX verbatim."""
        source = r'{"formula": "\alpha + \beta = \gamma"}'
        prepared = self.mixin.pre_fix_escape(source)
        # Keep only json.loads inside the try so assertion failures are not masked.
        try:
            parsed = json.loads(prepared)
        except json.JSONDecodeError:
            self.fail("JSON should be valid after pre_fix_escape")
        else:
            self.assertEqual(parsed["formula"], r"\alpha + \beta = \gamma")

    def test_response_to_json_full_flow(self):
        """_response_to_json turns a response with LaTeX and invalid escapes into a dict."""
        source = r'{"formula": "\alpha + \beta", "note": "test \xyz"}'
        parsed = self.mixin._response_to_json(source)
        self.assertIsNotNone(parsed)
        formula_value = parsed["formula"]  # type: ignore
        self.assertEqual(formula_value, r"\alpha + \beta")
        note_value = parsed["note"]  # type: ignore
        self.assertEqual(note_value, r"test \xyz")

# Allow running this module directly as a script, in addition to test-runner discovery.
if __name__ == '__main__':
    unittest.main()