import json
import os
import re
import shutil
import subprocess
import sys
import tempfile
from typing import Tuple, List, Dict, Any


def livecodebench_data_process(dataset):
    """
    Convert raw LiveCodeBench records into evaluation items.

    Each record is a mapping that may carry ``question_title``,
    ``question_content``, ``starter_code``, ``function_name`` and a test-case
    payload under ``test`` (preferred) or ``public_test_cases`` (fallback,
    default ``"[]"``). The payload may be a JSON string or an already
    decoded object.

    Returns one dict per record, in input order, with keys:
        - ``task``: markdown prompt made of a ``# <title>`` heading (title
          defaults to ``Problem``), the question text and the starter code,
          separated by blank lines and stripped of outer whitespace
        - ``test_cases``: the decoded payload when it is a list, otherwise
          ``[]`` (also ``[]`` when the JSON string cannot be decoded)
        - ``function_name``, ``question_id``, ``difficulty``, ``platform``,
          ``question_title``: copied through, defaulting to ``""``
    """

    def _decode_tests(payload):
        # Already-decoded payloads pass through untouched; bad JSON means "no tests".
        if not isinstance(payload, str):
            return payload
        try:
            return json.loads(payload)
        except json.JSONDecodeError:
            return []

    def _convert(record):
        fallback_tests = record.get("public_test_cases", "[]")
        decoded = _decode_tests(record.get("test", fallback_tests))

        sections = [
            f"# {record.get('question_title', 'Problem')}",
            f"{record.get('question_content', '')}",
            f"{record.get('starter_code', '')}",
        ]

        return {
            "task": "\n\n".join(sections).strip(),
            "test_cases": decoded if isinstance(decoded, list) else [],
            "function_name": record.get("function_name", ""),
            "question_id": record.get("question_id", ""),
            "difficulty": record.get("difficulty", ""),
            "platform": record.get("platform", ""),
            "question_title": record.get("question_title", ""),
        }

    return [_convert(record) for record in dataset]


# A complete fenced block: opening ``` plus an optional info string (language
# tag), a newline, the body, and the closing ```. Scanning with finditer
# consumes fences left to right in pairs, so a closing fence is never
# mistaken for the opening of the next block.
_FENCED_BLOCK_RE = re.compile(r"```([^\n`]*)\n(.*?)```", re.DOTALL)

# An opening python fence with no closing fence after it (e.g. a generation
# that was cut off by the token limit).
_UNCLOSED_PY_FENCE_RE = re.compile(
    r"```[ \t]*(?:python3?|py)[ \t\r]*\n((?:(?!```).)*)\Z",
    re.DOTALL | re.IGNORECASE,
)

# Info strings treated as Python (compared after strip() + lower()).
_PYTHON_FENCE_TAGS = frozenset({"python", "python3", "py"})


def _extract_fenced_code(text: str):
    """Return the stripped code of the best Markdown fence in *text*, or None.

    Preference order: the first complete python-tagged block, then a trailing
    unclosed python block, then the first complete untagged block. Blocks
    tagged with other languages are ignored.
    """
    first_untagged = None
    for match in _FENCED_BLOCK_RE.finditer(text):
        tag = match.group(1).strip().lower()  # strip() also drops a CRLF's '\r'
        if tag in _PYTHON_FENCE_TAGS:
            return match.group(2).strip()
        if not tag and first_untagged is None:
            first_untagged = match.group(2)

    unclosed = _UNCLOSED_PY_FENCE_RE.search(text)
    if unclosed:
        return unclosed.group(1).strip()

    if first_untagged is not None:
        return first_untagged.strip()
    return None


def livecodebench_get_predict(response: str) -> str:
    """
    Extract code from the model's response for LiveCodeBench.

    Strategy, in order:
      1. Markdown code fences anywhere in the response (python / python3 / py
         tags preferred, then an unclosed python fence, then untagged fences;
         CRLF line endings are tolerated).
      2. Text following the first matching marker phrase such as
         ``"Solution:"`` (again preferring a fenced block inside it).
      3. The whole response, stripped.

    Args:
        response: Raw model output; falsy values yield ``""``.

    Returns:
        The extracted code with surrounding whitespace removed.
    """
    if not response:
        return ""

    code = _extract_fenced_code(response)
    if code is not None:
        return code

    # No usable fence: fall back to whatever follows a common lead-in phrase.
    markers = (
        "Here is the solution:",
        "Here's the solution:",
        "Solution:",
        "Code:",
        "Answer:",
    )
    for marker in markers:
        if marker in response:
            remaining = response[response.find(marker) + len(marker):].strip()
            code = _extract_fenced_code(remaining)
            return code if code is not None else remaining

    # Last resort: treat the entire response as code.
    return response.strip()


def check_livecodebench_correctness(
    predicted_code: str,
    test_cases: List[Dict[str, Any]],
    function_name: str = "",
    timeout: float = 3.0
) -> Tuple[float, str]:
    """
    Check if the predicted code passes all test cases.
    
    Args:
        predicted_code: The code generated by the model
        test_cases: List of test case dicts with 'input' and 'output' keys
        function_name: Name of the function to test (for functional tests)
        timeout: Maximum time allowed for each test case
    
    Returns:
        (score, result_string): score is proportion of tests passed (0.0 to 1.0)
    """
    if not predicted_code or not predicted_code.strip():
        return 0.0, "No code generated"
    
    if not test_cases:
        return 0.0, "No test cases provided"
    
    passed = 0
    total = len(test_cases)
    error_messages = []
    
    for idx, test_case in enumerate(test_cases):
        test_type = test_case.get("testtype", "stdin")
        test_input = test_case.get("input", "")
        expected_output = test_case.get("output", "").strip()
        
        if test_type == "functional":
            # LeetCode-style functional test
            # Parse the input string to extract function arguments
            test_script = f"""
from typing import List, Optional, Dict, Any
import json

# User's code
{predicted_code}

# Test execution
try:
    # Parse input - handle LeetCode-style input format
    test_input_str = {repr(test_input)}
    expected_output = {repr(expected_output)}
    
    # Parse the input arguments
    parsed = json.loads(test_input_str)
    
    # Determine if we should unpack:
    # If parsed is a list where the first element is also a list/dict, it's likely multi-arg
    # Otherwise, if all elements are primitives (str/int/etc), it's a single list argument
    if isinstance(parsed, list) and len(parsed) > 0:
        # Check if this looks like multiple arguments (e.g., [[1,2,3], 5])
        # vs a single list argument (e.g., ["str1", "str2"])
        has_nested_structure = any(isinstance(item, (list, dict)) for item in parsed)
        
        if has_nested_structure:
            # Multiple arguments: [[1,2,3], 5] -> two args
            args_to_pass = parsed
            should_unpack = True
        else:
            # Single list argument: ["str1", "str2"] -> one arg
            args_to_pass = [parsed]
            should_unpack = False
    else:
        args_to_pass = [parsed]
        should_unpack = False
    
    # Create instance if it's a class method
    if 'class Solution' in {repr(predicted_code)} or 'def {function_name}(self' in {repr(predicted_code)}:
        solution = Solution()
        result = solution.{function_name}(*args_to_pass) if should_unpack else solution.{function_name}(*args_to_pass)
    else:
        # Standalone function
        result = {function_name}(*args_to_pass) if should_unpack else {function_name}(*args_to_pass)
    
    # Compare result
    result_str = str(result)
    if result_str == expected_output:
        print("PASS")
    else:
        print(f"FAIL: Expected {{expected_output}}, got {{result_str}}")
except Exception as e:
    import traceback
    print(f"ERROR: {{type(e).__name__}}: {{str(e)}}")
    # traceback.print_exc()
"""
        else:
            # stdin/stdout test
            test_script = f"""
import sys
from io import StringIO

# Redirect stdin
sys.stdin = StringIO('''{test_input}''')

# User code
{predicted_code}
"""
        
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
                f.write(test_script)
                temp_file = f.name
            
            try:
                result = subprocess.run(
                    ['python', temp_file],
                    capture_output=True,
                    text=True,
                    timeout=timeout
                )
                
                actual_output = result.stdout.strip()
                
                if test_type == "functional":
                    # Check for "PASS" in output
                    if "PASS" in actual_output:
                        passed += 1
                    else:
                        error_msg = actual_output[:200] if actual_output else result.stderr[:200]
                        error_messages.append(f"Test {idx+1}: {error_msg}")
                else:
                    # stdin/stdout test
                    if result.returncode == 0 and actual_output == expected_output:
                        passed += 1
                    else:
                        error_msg = result.stderr[:100] if result.stderr else f"Expected '{expected_output}', got '{actual_output}'"
                        error_messages.append(f"Test {idx+1}: {error_msg}")
                    
            finally:
                import os
                if os.path.exists(temp_file):
                    os.unlink(temp_file)
                    
        except subprocess.TimeoutExpired:
            error_messages.append(f"Test {idx+1}: Timeout")
        except Exception as e:
            error_messages.append(f"Test {idx+1}: {str(e)[:100]}")
    
    score = passed / total if total > 0 else 0.0
    result_str = f"Passed {passed}/{total} tests"
    if error_messages and len(error_messages) <= 3:
        result_str += f". Errors: {'; '.join(error_messages)}"
    
    return score, result_str

