import re
import math
import numpy as np
import random



def extract_code_blocks(text, block_type):
    """
    Extract the contents of fenced blocks (```<block_type> ... ```) from text.

    Parameters:
    text (str): The text to search.
    block_type (str): The tag that immediately follows the opening fence,
        e.g. 'python', 'hypothesis', 'json', 'obs' or 'answer'. Any tag is
        accepted and matched literally. The empty string is special: it
        matches any opening fence and captures up to the first following
        backtick (so a tag, if present, is part of the captured text). If
        that finds nothing, spans opened with only two backticks (``...`)
        are tried instead.

    Returns:
    list of str: The block contents in order of appearance, with surrounding
    whitespace and backticks stripped.
    """
    if block_type == '':
        # Lenient mode. Note: capture stops at the FIRST backtick after the
        # opening fence, so content containing inline backticks is truncated.
        code_block_pattern = re.compile(r'```(.*?)`', re.DOTALL)
    else:
        # Escape the tag so any value (e.g. 'c++') is matched literally rather
        # than being limited to a fixed list of supported tags.
        code_block_pattern = re.compile(
            r'```' + re.escape(block_type) + r'(.*?)```', re.DOTALL
        )

    code_blocks = [block.strip() for block in code_block_pattern.findall(text)]

    if block_type == '' and not code_blocks:
        # Fallback for spans opened with only two backticks.
        fallback_pattern = re.compile(r'``(.*?)`', re.DOTALL)
        code_blocks = [block.strip() for block in fallback_pattern.findall(text)]

    # Remove any stray backticks left at either end of a captured block.
    return [block.strip('`') for block in code_blocks]

def flatten_list(nested_list):
    """
    Flattens arbitrarily nested lists into a single flat list.

    Only `list` instances are descended into; every other element (including
    tuples and strings) is kept as-is. Element order is preserved.

    Parameters:
    nested_list (iterable): The possibly nested list to flatten.

    Returns:
    list: A new flat list of the non-list elements.
    """
    flattened = []
    for element in nested_list:
        if not isinstance(element, list):
            flattened.append(element)
            continue
        flattened += flatten_list(element)
    return flattened

def reshape_list(input_list, dimensions):
    """
    Reshapes a flat list into a nested list with the given dimensions.

    Elements are laid out row-major (the last dimension varies fastest), e.g.
    reshape_list([1, 2, 3, 4, 5, 6], [2, 3]) -> [[1, 2, 3], [4, 5, 6]].

    Parameters:
    input_list (list): The list to be reshaped.
    dimensions (list): A list of non-negative dimensions specifying the new shape.

    Returns:
    list: A reshaped nested list.

    Raises:
    ValueError: If `dimensions` is empty, contains a negative value, or its
        product does not equal len(input_list).
    """
    if not dimensions:
        raise ValueError("Dimensions must be a non-empty list")
    # Without this check, e.g. [-2, -3] passes the product test for 6 items
    # and silently yields [].
    if any(dim < 0 for dim in dimensions):
        raise ValueError("Dimensions must be non-negative")

    total_elements = 1
    for dim in dimensions:
        total_elements *= dim
    if total_elements != len(input_list):
        raise ValueError("The total number of elements does not match the product of the dimensions")

    def build_nested_list(flat_list, dims):
        if len(dims) == 1:
            return flat_list[:dims[0]]
        # Each sub-block holds product(dims[1:]) elements. Deriving this from
        # the dims (instead of len(flat_list) // dims[0]) avoids a
        # ZeroDivisionError when a leading dimension is 0.
        size = 1
        for dim in dims[1:]:
            size *= dim
        return [build_nested_list(flat_list[i * size:(i + 1) * size], dims[1:]) for i in range(dims[0])]

    return build_nested_list(input_list, dimensions)

def mbpp_assertion(out, exp, atol=0):
    """
    Checks whether a program's output matches the expected value.

    An exact `out == exp` match always passes. Otherwise, if a tolerance
    applies, the two values are compared numerically with `np.allclose`
    (rtol=1e-07). When `atol` is 0 and `is_floats(exp)` is true, a default
    absolute tolerance of 1e-6 is used.

    Parameters:
    out: The value produced by the program under test.
    exp: The expected value.
    atol (float): Absolute tolerance for the numeric comparison; 0 means
        exact comparison only (unless the float default above applies).

    Returns:
    int: 1 if the output matches, 0 otherwise.
    """
    # NOTE(review): `is_floats` is not defined in this chunk; presumably it is
    # provided elsewhere in the module — verify.
    if atol == 0 and is_floats(exp):
        atol = 1e-6
    if out == exp:
        return 1
    if atol == 0:
        return 0
    try:
        # np.allclose broadcasts its arguments, so e.g. [1.0, 1.0] vs [1.0]
        # would otherwise be reported as a match; require identical shapes.
        if np.shape(out) != np.shape(exp):
            return 0
        return 1 if np.allclose(out, exp, rtol=1e-07, atol=atol) else 0
    except (TypeError, ValueError):
        # Outputs that cannot be compared numerically (None, strings, ragged
        # nested lists, ...) count as a mismatch instead of crashing the checker.
        return 0