import re
import io
import sys
import traceback
from utils.sonnet_eval import sonnet_errors

def clean_output_for_GameOf24(output: str) -> str:
    """
    Clean the output for GameOf24 problems.

    Strip surrounding prose from a model answer so that only the arithmetic
    expression remains, e.g. ``"The answer is (4+4)*3 = 24"`` becomes
    ``"(4+4)*3"``.

    The reductions are applied in this order:

    * keep the text before the first ``=``;
    * keep the text after the first standalone word ``is`` (up to the next
      standalone ``is``, if any);
    * keep the text before the first ``equals``;
    * keep the text before the first ``evaluates to``.

    Args:
        output: Raw model output.

    Returns:
        The whitespace-stripped candidate expression. Text containing none of
        the markers is returned unchanged.
    """
    if "=" in output:
        output = output.split("=")[0].strip()
    # Match "is" only as a whole word. A plain substring test also fires on
    # words such as "This" or "list" and cuts the text in the wrong place
    # ("This is 4*6" used to become "").
    parts = re.split(r"\bis\b", output)
    if len(parts) > 1:
        output = parts[1].strip()
    if "equals" in output:
        output = output.split("equals")[0].strip()
    if "evaluates to" in output:
        output = output.split("evaluates to")[0].strip()
    return output

def eval_for_GameOf24(input: str, output: str, d) -> bool:
    """
    Given an input and output, check if the output is correct and follows the rules of the game.

    The answer is first reduced to a bare expression with
    ``clean_output_for_GameOf24``. It is accepted when it evaluates to 24
    (within 1e-3) and uses exactly the input numbers, each as many times as
    it appears in ``input``.

    Args:
        input: The puzzle numbers separated by whitespace, e.g. ``"4 4 6 8"``.
        output: Raw model answer.
        d: Mutable mapping whose ``'error'`` entry is incremented whenever
            the expression is disallowed or cannot be evaluated.

    Returns:
        True if the answer is a valid solution, False otherwise.
    """
    clean_output = clean_output_for_GameOf24(output)
    # Map the typographic operators models sometimes emit to Python's own.
    # Previously "÷" was treated as an operator by the digit check but made
    # eval fail, so such answers were always rejected.
    expression = clean_output.replace("÷", "/").replace("×", "*")
    try:
        # The expression is untrusted model output. Only digits, decimal
        # points, whitespace, + - * / and parentheses may reach eval. "**" is
        # refused: huge powers (e.g. 9**9**9) can hang the evaluator, and
        # exponentiation is not a Game of 24 operation anyway.
        if not re.fullmatch(r"[0-9.\s+\-*/()]+", expression) or "**" in expression:
            raise ValueError(f"disallowed expression: {expression!r}")
        value = eval(expression, {"__builtins__": {}}, {})
        if not abs(value - 24) < 1e-3:
            return False
        # Every input number must be used exactly once. Blank out operators
        # and parentheses, then compare the remaining tokens as sorted lists
        # of strings. split() (not split(" ")) tolerates repeated spaces,
        # tabs and trailing newlines on either side.
        operators = str.maketrans({symbol: " " for symbol in "+-*/()"})
        output_digits = expression.translate(operators).split()
        return sorted(input.split()) == sorted(output_digits)
    except Exception:
        d['error'] += 1
        return False

def remove_punctuation(output: str) -> str:
    """Return *output* with every ``,`` ``;`` ``:`` ``.`` and ``"`` character deleted."""
    # A three-argument maketrans builds a table that deletes the listed characters.
    deletion_table = str.maketrans("", "", ',;:."')
    return output.translate(deletion_table)


def convert_newline_to_space(output: str) -> str:
    """Return *output* with every newline character replaced by a single space."""
    return " ".join(output.split("\n"))


def eval_for_exact_matching_with_no_punctuation(
    input: str, output: str, target: str
) -> bool:
    """
    Compare a normalised *output* with *target* for exact equality.

    *output* has its punctuation removed by ``remove_punctuation`` and its
    newlines turned into spaces by ``convert_newline_to_space``. *target* is
    compared as given, and *input* is not used.
    """
    normalised = convert_newline_to_space(remove_punctuation(output))
    return bool(target == normalised)

def eval_for_Sonnet(output: str, rhyme_scheme: str, d) -> bool:
    """
    Check a generated sonnet against its rhyme scheme and required words.

    Args:
        output: The generated poem.
        rhyme_scheme: A string containing exactly one comma. The text after
            the comma is a whitespace-separated list of words that must all
            occur (as substrings) in ``output``. The full string is passed
            unchanged to ``sonnet_errors``.
        d: Mutable mapping whose ``'error'`` entry is incremented when
            parsing or checking fails.

    Returns:
        True if ``sonnet_errors`` reports nothing and every required word is
        present, False otherwise.
    """
    try:
        # Parse inside the try so that a malformed rhyme_scheme (no comma, or
        # more than one) is counted in d['error'] like any other failure
        # instead of escaping from the evaluator.
        _scheme, words = rhyme_scheme.split(',')
        required_words = words.split()
        # NOTE(review): sonnet_errors receives the whole string, including the
        # word list, not just the scheme part. Confirm this is what it expects.
        errors = sonnet_errors(output, rhyme_scheme)
        return not errors and all(w in output for w in required_words)
    except Exception:
        d['error'] += 1
        return False

def extract_field(text, field_name):
    """
    Return everything after the first ``"<field_name>:"`` marker in *text*.

    The captured text runs to the end of *text* (``re.DOTALL``), so it may
    span several lines. Surrounding whitespace is stripped.

    Args:
        text: Text to search, e.g. a model response.
        field_name: Literal label to look for, e.g. ``"Answer"`` or ``"SAT"``.
            It is matched verbatim; regex metacharacters are escaped.

    Returns:
        The stripped text following the marker, or ``""`` if it is absent.
    """
    # Escape the label so names such as "Score (final)" are matched literally
    # instead of being parsed as regex syntax. Unescaped, "(final)" became a
    # capturing group and shifted group(1).
    match = re.search(rf"{re.escape(field_name)}:\s*(.*?)\s*$", text, re.DOTALL)
    return match.group(1).strip() if match else ""
    
def peel_py_mk_wrapper(text, return_num=1):
    """
    Extract fenced code blocks from Markdown-like *text*.

    Each search starts at the current position. The opening markers
    ``"```python"``, ``"```Python"`` and ``"```"`` are tried in that order,
    case-insensitively, and the first marker that occurs anywhere after the
    position wins. NOTE: this means a later ``"```python"`` fence is
    preferred over an earlier bare ``"```"`` fence, which is then skipped.
    A block ends at the next ``"```"``, or at the end of *text* if the fence
    is never closed. Each snippet is stripped, stray fence markers inside it
    are removed, and it is prefixed with ``from typing import *`` and a
    blank line.

    Args:
        text: Text possibly containing fenced code.
        return_num: ``1`` (default) returns only the first snippet as a
            string (``""`` if none). ``-1`` returns ``(snippets, num_blocks)``
            with every snippet. For any other non-negative ``n``, returns
            ``(snippets, num_blocks)`` with ``snippets`` truncated to, or
            padded with ``""`` up to, length ``n``.

    Returns:
        See *return_num*.
    """
    code_start_identifiers = ["```python", "```Python", "```"]
    code_end_identifier = "```"

    # Search case-insensitively on `text` itself. Finding an index in
    # `text.lower()` and slicing `text` with it is wrong when lower-casing
    # changes the length ("İ".lower() is two code points). It also
    # re-lowercased the whole text for every marker on every pass.
    start_patterns = [
        re.compile(re.escape(marker), re.IGNORECASE)
        for marker in code_start_identifiers
    ]

    code_snippets = []
    search_start = 0

    while True:
        # Locate the next opening fence, honouring marker priority.
        start_match = None
        for pattern in start_patterns:
            start_match = pattern.search(text, search_start)
            if start_match is not None:
                break
        if start_match is None:
            break

        content_start = start_match.end()
        # An unclosed fence runs to the end of the text.
        code_block_end = text.find(code_end_identifier, content_start)
        if code_block_end == -1:
            code_block_end = len(text)

        code_content = text[content_start:code_block_end].strip()
        # Drop any stray fence markers left inside the snippet.
        for marker in code_start_identifiers:
            code_content = code_content.replace(marker, "")
        code_content = code_content.replace(code_end_identifier, "").strip()

        code_snippets.append(f"from typing import *\n\n{code_content}")

        # Resume after the closing fence.
        search_start = code_block_end + len(code_end_identifier)

    num_blocks = len(code_snippets)

    if return_num == 1:  # backward-compatible single-snippet form
        return code_snippets[0] if code_snippets else ""
    if return_num == -1:
        return code_snippets, num_blocks
    if num_blocks > return_num:
        return code_snippets[:return_num], num_blocks
    return code_snippets + [""] * (return_num - num_blocks), num_blocks

def execute_py_code(code_content, max_line_no=10000):
    """
    Run *code_content* with ``exec`` and capture what it prints.

    Args:
        code_content: Python source to execute. It runs against this module's
            globals (``globals()``), so names it defines or rebinds persist in
            this module and may overwrite existing ones.
        max_line_no: If the failing line of *code_content* is beyond this
            line number, the error message omits the line number and code.

    Returns:
        A ``(message, code_content)`` tuple:

        * on success, ``message`` is everything written to ``sys.stdout``;
        * on an ``Exception``, ``message`` describes the error. It uses the
          innermost frame of the executed code and its source line when that
          can be located, and otherwise just the exception text;
        * for empty *code_content*,
          ``("No Python code block found in the provided input.", "")``.

    ``sys.stdout`` is always restored, including when a ``BaseException``
    such as ``SystemExit`` (from ``exit()`` / ``sys.exit()``) escapes the
    executed code. Such exceptions are re-raised.

    NOTE(review): the code runs in-process without sandboxing. Only call this
    on code you are prepared to execute.
    """
    if len(code_content) == 0:
        return "No Python code block found in the provided input.", ""

    # Redirect stdout so that print() output of the executed code is captured.
    saved_stdout = sys.stdout
    output_stream = io.StringIO()
    sys.stdout = output_stream
    try:
        exec(code_content, globals())
    except Exception as e:
        codes = code_content.split('\n')
        # Walk from the innermost frame outwards and report the first frame
        # belonging to executed source (exec compiles strings as "<string>").
        for frame in reversed(traceback.extract_tb(e.__traceback__)):
            if frame.filename != "<string>" or frame.lineno is None:
                continue
            if frame.lineno > max_line_no:
                return "An error occurred", code_content
            # A "<string>" frame can also come from a nested exec/eval inside
            # the code. Skip line numbers that do not exist in code_content
            # instead of raising IndexError from inside this handler.
            if not 1 <= frame.lineno <= len(codes):
                continue
            return f"An error occurred at line {frame.lineno} of code \"{codes[frame.lineno - 1]}\": {str(e)}", code_content

        # No usable frame: fall back to the bare exception text.
        return f"An error occurred: {str(e)}", code_content
    finally:
        # Restore stdout on every exit path. Previously SystemExit or
        # KeyboardInterrupt left it redirected to the StringIO permanently.
        sys.stdout = saved_stdout

    return output_stream.getvalue(), code_content


def extract_answer(text):
    """
    Return the text after the first ``Answer:`` marker, or ``None``.

    Whitespace around the marker is flexible, and the captured text runs to
    the end of *text* (it may span several lines). The result is stripped.
    """
    found = re.search(r"Answer:\s*(.*?)\s*$", text, re.DOTALL)
    if found is None:
        return None
    return found.group(1).strip()


if __name__ == '__main__':
    # Manual smoke test: two fenced Python blocks; print the first snippet
    # (default mode), then all snippets with their count (return_num=-1).
    sample_markdown = """
abc
```python
1+1
```
abc
```python
666
```    
"""
    for extra_args in ((), (-1,)):
        print(peel_py_mk_wrapper(sample_markdown, *extra_args))