import re

from utils_proof.expression_parser import convert_expression
from utils_proof.find_functions import find_functions
from utils_proof.dissect_expression import dissect_expression

def is_num(s):
    """Return True if *s* parses as a number or as '/'-separated numbers.

    Anything ``float()`` accepts counts as a number (e.g. ``"3"``, ``"-2.5"``,
    ``"1e3"``; note this also includes ``"inf"`` and ``"nan"``).  A string such
    as ``"1/2"`` is also accepted when every '/'-separated part is
    float-parsable.

    Args:
        s: Candidate value, normally a string.

    Returns:
        True if *s* is numeric in the sense above, otherwise False.
    """
    try:
        float(s)
        return True
    except (TypeError, ValueError, OverflowError):
        pass
    try:
        # Fraction-like literals such as "3/4": every part must parse.
        for part in s.split('/'):
            float(part)
        return True
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: *s* has no compatible str-style split().
        return False

def is_list(s):
    """Return True when *s* looks like a list literal.

    The check is purely textual: *s* must contain '[' and must not contain
    '(' (which would indicate a call or tuple).  Always returns a bool.
    """
    return '[' in s and '(' not in s

def is_tuple(s):
    """Return True when *s* looks like a tuple literal.

    The check is purely textual: *s* must contain both '(' and ')'.
    Always returns a bool.
    """
    return '(' in s and ')' in s

def is_variable(name):
    """Return True if *name* is an ASCII identifier.

    An identifier here is a letter or underscore followed by any number of
    letters, digits or underscores (``[A-Za-z_][A-Za-z0-9_]*``).  The whole
    string must match: ``fullmatch`` is used because ``$`` would also accept
    a trailing newline.  Python keywords are not rejected.
    """
    return re.fullmatch(r'[A-Za-z_][A-Za-z0-9_]*', name) is not None

# def extract_parts(s):
#     pattern = r'^([^=]+)\s*=\s*\[([^]]+)\]\(([^)]+)\)$'
#     match = re.search(pattern, s)
#     if match:
#         return 'operation', (match.group(1).strip(), match.group(2).strip(), [m.strip() for m in match.group(3).split(',')])
#     elif '=' in s:
#         left, right = [x.strip() for x in s.split('=')]
#         if is_num(right)\
#             or is_list(right)\
#             or (right[0] == right[-1] == '"') or (right[0] == right[-1] == "'"):
#             return 'init', left
#         if is_variable(right):
#             return 'operation', (left, 'assigning a value', [right.strip()])
#         if is_tuple(right):
#             return 'operation', (left, 'assigning a tuple', [right.lstrip('(').rstrip(')')])

def extract_parts(s):
    """Classify a single assignment statement and extract its pieces.

    Args:
        s: One statement, e.g. ``"y = [add](a, b)"`` or ``"x = 3"``.

    Returns:
        * ``('operation', (target, operation, sources))`` for statements of
          the form ``target = [operation](a, b, ...)``; *sources* is the
          comma-split, stripped argument list.
        * ``('init', target)`` when the right-hand side is a number (per
          ``is_num``), a list literal (per ``is_list``) or starts and ends
          with the same quote character.
        * ``('operation', (target, 'assigning a value', [right]))`` when the
          right-hand side is a bare variable name.
        * ``('operation', (target, 'assigning a tuple', [inner]))`` when the
          right-hand side looks like a tuple; *inner* is the text with the
          surrounding parentheses stripped (it is not split on commas).
        * ``None`` for anything else, including statements without '=' or
          with an empty right-hand side.
    """
    pattern = r'^([^=]+)\s*=\s*\[([^]]+)\]\(([^)]+)\)$'
    match = re.search(pattern, s)
    if match:
        target, operation, args = match.groups()
        return 'operation', (target.strip(), operation.strip(), [m.strip() for m in args.split(',')])
    if '=' not in s:
        return None
    # Split on the first '=' only, so a '=' inside the right-hand side
    # (e.g. in a quoted string) does not break the two-way unpacking.
    left, right = (part.strip() for part in s.split('=', 1))
    if not right:
        # Nothing to classify; also avoids indexing right[0] on ''.
        return None
    if is_num(right) \
            or is_list(right) \
            or (right[0] == right[-1] == '"') or (right[0] == right[-1] == "'"):
        return 'init', left
    if is_variable(right):
        return 'operation', (left, 'assigning a value', [right])
    if is_tuple(right):
        return 'operation', (left, 'assigning a tuple', [right.lstrip('(').rstrip(')')])
    return None

def replace_variable_name(code, old_name, new_name):
    """Rename whole-word occurrences of *old_name* in *code* to *new_name*.

    Matches are delimited by regex word boundaries, so ``a`` is renamed in
    ``a + 1`` but not inside ``ab``.  *new_name* is inserted literally:
    backslashes in it are not treated as group references.

    Args:
        code: Source text to rewrite.
        old_name: Name to replace; an empty string leaves *code* unchanged.
        new_name: Replacement text.

    Returns:
        The rewritten text.
    """
    if not old_name:
        # An empty pattern would match at every word boundary.
        return code
    pattern = r'\b' + re.escape(old_name) + r'\b'
    # A callable replacement bypasses re.sub's template (backslash) handling.
    return re.sub(pattern, lambda _match: new_name, code)

def eliminate_loop(unpacked_statement_list):
    """Rename self-referencing assignment targets so no statement reads its own target.

    For every statement that ``extract_parts`` classifies as ``'operation'``
    and whose target also appears among its sources (e.g.
    ``x = [add](x, 1)``):

    * the target is renamed to ``<target>_NeW`` in all following statements
      (via ``replace_variable_name``);
    * the statement itself is rebuilt as
      ``<target>_NeW = [<operation>](<sources>)``.

    Later statements are rewritten before they are visited, so repeated
    updates of the same name chain (``x_NeW``, ``x_NeW_NeW``, ...).
    Statements ``extract_parts`` cannot classify (it returns None) are left
    unchanged.

    Args:
        unpacked_statement_list: List of statement strings; modified in place.

    Returns:
        The same list object, after rewriting.
    """
    for i, statement in enumerate(unpacked_statement_list):
        parts = extract_parts(statement)
        if parts is None:
            # Unrecognised statement: leave it untouched.
            continue
        flag, output = parts
        if flag != 'operation':
            continue
        target, operation, sources = output
        if target not in sources:
            continue
        new_target = target + '_NeW'
        for j in range(i + 1, len(unpacked_statement_list)):
            unpacked_statement_list[j] = replace_variable_name(
                unpacked_statement_list[j], target, new_target)
        unpacked_statement_list[i] = f"{new_target} = [{operation}]({', '.join(sources)})"
    return unpacked_statement_list

def merge_lines(statements):
    """Join continuation lines onto the statement they belong to.

    Processing rules, applied to every line after the first:

    * a line containing '=' closes the current statement and starts a new one;
    * a line starting with ``'for '`` is dropped;
    * any other line is stripped and appended to the current statement.

    The first line always starts the first statement, whatever it contains.

    Args:
        statements: Sequence of source lines.

    Returns:
        A new list of merged statements; ``[]`` for empty input.
    """
    if not statements:
        return []
    merged = []
    current = statements[0]
    for line in statements[1:]:
        if '=' in line:
            merged.append(current)
            current = line
        elif line.startswith('for '):
            continue
        else:
            current += line.strip()
    merged.append(current)
    return merged

def unpack_statements(statement_list):
    """Expand each statement into the pieces ``dissect_expression`` yields.

    Every statement is passed to ``dissect_expression``.  When that returns a
    truthy value, its items are appended to the result; otherwise the
    statement itself is appended unchanged.

    Args:
        statement_list: Iterable of statement strings.

    Returns:
        A new flat list of statements.
    """
    result = []
    for statement in statement_list:
        pieces = dissect_expression(statement)
        result.extend(pieces if pieces else [statement])
    return result

def convert_to_function(statement_list):
    """Rewrite arithmetic right-hand sides using ``convert_expression``.

    Handling of each statement in *statement_list*:

    * statements without '=' are skipped;
    * numeric or list-literal right-hand sides (per ``is_num``/``is_list``)
      are skipped;
    * when the statement contains any of ``+ - / *``:
      1. every sub-expression reported by ``find_functions`` is swapped for a
         placeholder ``temp_value_<j>``;
      2. if arithmetic remains, the right-hand side goes through
         ``convert_expression``;
      3. the placeholders are restored.

    NOTE(review): when a statement contains a second '=', only the text
    between the first and second '=' is passed to ``convert_expression`` and
    anything after it is dropped — confirm this is intended.

    Args:
        statement_list: List of statement strings; modified in place.

    Returns:
        The same list object, after rewriting.
    """
    def has_operator(text):
        return any(op in text for op in '+-/*')

    for i, s in enumerate(statement_list):
        if '=' not in s:
            continue
        right = s.split('=')[1].strip()
        if is_num(right) or is_list(right):
            continue
        if not has_operator(s):
            continue
        # Materialise so the restore pass below sees the same items even if
        # find_functions returns an iterator.
        function_exps = list(find_functions(s))
        for j, exp in enumerate(function_exps):
            s = s.replace(exp, f'temp_value_{j}')
        if has_operator(s):
            s = s.split('=')[0] + '=' + convert_expression(s.split('=')[1])
        # Restore in reverse so 'temp_value_1' is not matched as the prefix
        # of 'temp_value_10', 'temp_value_11', ...
        for j in reversed(range(len(function_exps))):
            s = s.replace(f'temp_value_{j}', function_exps[j])
        statement_list[i] = s
    return statement_list

def convert_function_calls(input_string):
    """Rewrite ``name(args)`` call syntax as ``[name](args)``.

    The argument part is matched non-greedily up to the first ')', so nested
    calls are not rewritten recursively: ``f(g(x))`` becomes ``[f](g(x))``.

    Args:
        input_string: Text containing call expressions.

    Returns:
        The rewritten text.
    """
    return re.sub(
        r'(\w+)\((.*?)\)',
        lambda m: '[' + m.group(1) + '](' + m.group(2) + ')',
        input_string,
    )

def replace_equations(input_string):
    """Replace every double-quoted substring with a numbered placeholder.

    Each substring matching ``"[^"]*"`` (quotes included) is replaced, left to
    right, by ``temp_equation_1``, ``temp_equation_2``, ...  An unmatched
    trailing quote is left as-is.

    Args:
        input_string: Text that may contain double-quoted equations.

    Returns:
        A ``(modified_string, placeholder_map)`` tuple.  *placeholder_map* is a
        dict mapping each placeholder to the original quoted substring
        (quotes included), in order of appearance.
    """
    placeholder_map = {}

    def _substitute(match):
        # Placeholders are numbered from 1 in order of appearance.
        placeholder = f"temp_equation_{len(placeholder_map) + 1}"
        placeholder_map[placeholder] = match.group(0)
        return placeholder

    # Single pass instead of repeated str.replace over the growing string.
    modified_string = re.sub(r'"[^"]*"', _substitute, input_string)
    return modified_string, placeholder_map