import re
import random
import copy

# Lookup table from English number words to their numeric value.
# map_values() consults it to turn word-form quantities ("three", "twice",
# "half", "fifth") into numbers. All keys are lowercase single words; compound
# forms such as "twenty-five" or "hundred" have no entry.
WORD2DIGIT = {
    # cardinal numbers
    "zero": 0,
    "one": 1,
    "two": 2,
    "three": 3,
    "four": 4,
    "five": 5,
    "six": 6,
    "seven": 7,
    "eight": 8,
    "nine": 9,
    "ten": 10,
    "eleven": 11,
    "twelve": 12,
    "thirteen": 13,
    "fourteen": 14,
    "fifteen": 15,
    "sixteen": 16,
    "seventeen": 17,
    "eighteen": 18,
    "nineteen": 19,
    "twenty": 20,
    "thirty": 30,
    "forty": 40,
    "fifty": 50,
    "sixty": 60,
    "seventy": 70,
    "eighty": 80,
    "ninety": 90,
    # multiplicatives: "once" = 1 time, "twice" = 2 times
    "once": 1,
    "twice": 2,
    # fraction
    "half": 0.5,
    # ordinals (-th), mapped to their cardinal value.
    # NOTE(review): "second" is also a unit of time; the lookup is context-free.
    "first": 1,
    "second": 2,
    "third": 3,
    "fourth": 4,
    "fifth": 5,
    "sixth": 6,
    "seventh": 7,
    "eighth": 8,
    "ninth": 9,
    "tenth": 10,
    "eleventh": 11,
    "twelfth": 12,
    "thirteenth": 13,
    "fourteenth": 14,
    "fifteenth": 15,
    "sixteenth": 16,
    "seventeenth": 17,
    "eighteenth": 18,
    "nineteenth": 19,
    "twentieth": 20,
    "thirtieth": 30,
    "fortieth": 40,
    "fiftieth": 50,
    "sixtieth": 60,
    "seventieth": 70,
    "eightieth": 80,
    "ninetieth": 90,
}

# Align a templated text (numbers replaced by "{v_1}", "{v_2}", ...) with its
# concrete wording and return the values that fill the placeholders, in order.
def map_values(a: str, b: str) -> list:
    """Extract the concrete values that fill the ``{...}`` placeholders of ``a``.

    Both texts are normalised with the same tokenisation rules and then walked
    in lockstep. Wherever ``a`` holds a placeholder, the whitespace-delimited
    token at the same position in ``b`` is taken as its value. That value is
    substituted back into ``a`` so the two strings stay aligned.

    Special cases:
      * ``"{v_1}%"`` vs ``"60%"``  -> ``"60"`` (the '%' belongs to the template);
      * ``"{v_1} times"`` vs ``"twice"`` / ``"once"`` -> the word itself;
      * ``"{v_1} dollars"`` vs ``"$ 1.2"`` -> ``"1.2"``.

    Args:
        a: templated text containing placeholders such as ``{v_1}``.
        b: the original text containing the actual values.

    Returns:
        One entry per placeholder, in order of appearance. Number words found
        in ``WORD2DIGIT`` (case-insensitively) become numbers. Tokens ending in
        '%' become fractions (``"16%"`` -> ``0.16``). Everything else stays a
        string (``"16"``, ``"1/4"``, ``"0.5"`` for ``".5"``).

    Example:
        map_values("She has {v_1} eggs and {v_2} hens.",
                   "She has 16 eggs and three hens.")  ->  ['16', 3]
    """

    def normalise(text: str) -> str:
        # Identical rules for both texts so their characters line up.
        text = text.replace(" an ", " a ")
        # Drop thousands separators between digits: "1,000,000" -> "1000000".
        text = re.sub(r"(?<=\d),(?=\d)", "", text)
        # Give punctuation its own token: "Janet's $5," -> "Janet's $ 5 , ".
        text = re.sub(r"([,!?$:])", r" \1 ", text)
        text = text.replace("-", " - ")
        # "/" is its own token unless it sits between digits (a fraction).
        text = re.sub(r"(?<![0-9])/|/(?![0-9])", " / ", text)
        # Split off a sentence-final "." (followed by whitespace or the end of
        # the text) but keep decimals such as "1.2" or ".3" intact.
        text = re.sub(r"(?<=\S)\.(?=\s|$)", " . ", text)
        return re.sub(r"\s{2,}", " ", text)

    def token_end(text: str, start: int) -> int:
        # End of the token starting at `start`; the last token runs to the end.
        end = text.find(" ", start)
        return len(text) if end == -1 else end

    a = normalise(a)
    b = normalise(b)

    raw_values = []
    i = 0
    while i < len(a) and i < len(b):
        if a[i] == b[i] or a[i] != "{":
            i += 1
            continue

        end_a = a.find("}", i)
        if end_a == -1:
            break  # unterminated placeholder: nothing further can be aligned
        end_b = token_end(b, i)

        # 1. "{v_1}%" vs "60%": stop the value before the '%'.
        if a.startswith("%", end_a + 1):
            percent = b.find("%", i, end_b)
            if percent != -1:
                end_b = percent
        # 2. "{v_1} times" vs "twice": the word already means "N times".
        if a.startswith(" times", end_a + 1) and b[i:end_b].lower() in ("once", "twice"):
            a = a[:end_a + 1] + a[end_a + 7:]
        # 3. "{v_1} dollars" vs "$ 1.2": drop both currency markers.
        if a.startswith(" dollars", end_a + 1) and b[i:end_b] == "$":
            a = a[:end_a + 1] + a[end_a + 9:]
            b = b[:i] + b[i + 2:]
            end_b = token_end(b, i)

        value = b[i:end_b]
        # Substitute the value into the template so both strings stay aligned.
        a = a[:i] + value + a[end_a + 1:]
        raw_values.append("0" + value if value.startswith(".") else value)
        # The substituted characters match `b` exactly, so skip past them.
        i = end_b

    values = []
    for raw in raw_values:
        word = raw.lower()
        if word in WORD2DIGIT:
            values.append(WORD2DIGIT[word])
        elif raw.endswith("%"):
            try:
                values.append(float(raw[:-1]) / 100)
            except ValueError:
                values.append(raw)  # not a numeric percentage; keep as text
        else:
            values.append(raw)
    return values

def calculate(formula: str) -> float:
    """Evaluate an arithmetic expression and return its value as a float.

    Supports ``+ - * /`` with the usual precedence (``*`` and ``/`` bind
    tighter than ``+`` and ``-``; all are left-associative), parentheses,
    unary ``+``/``-`` and number literals such as ``12``, ``1.5``, ``.5`` or
    ``1e-05``. Whitespace is ignored. The text is parsed by hand and never
    passed to ``eval``.

    Args:
        formula: the expression, e.g. ``"(12+8)*2"``.

    Returns:
        The value of the expression.

    Raises:
        ValueError: if the formula is empty or malformed.
        ZeroDivisionError: on division by zero.
    """
    # Numbers (optionally with an exponent), operators/parentheses, and any
    # other single non-space character (rejected later by the parser).
    tokens = re.findall(
        r"\d+\.?\d*(?:[eE][+-]?\d+)?|\.\d+(?:[eE][+-]?\d+)?|[-+*/()]|\S",
        formula,
    )
    pos = 0

    def peek():
        return tokens[pos] if pos < len(tokens) else None

    def advance():
        nonlocal pos
        token = tokens[pos]
        pos += 1
        return token

    def expression() -> float:
        # expression := term (('+' | '-') term)*
        value = term()
        while peek() in ("+", "-"):
            if advance() == "+":
                value += term()
            else:
                value -= term()
        return value

    def term() -> float:
        # term := factor (('*' | '/') factor)*
        value = factor()
        while peek() in ("*", "/"):
            if advance() == "*":
                value *= factor()
            else:
                value /= factor()
        return value

    def factor() -> float:
        # factor := ('+' | '-') factor | '(' expression ')' | number
        token = peek()
        if token in ("+", "-"):
            advance()
            operand = factor()
            return operand if token == "+" else -operand
        if token == "(":
            advance()
            value = expression()
            if peek() != ")":
                raise ValueError(f"missing ')' in formula: {formula!r}")
            advance()
            return value
        if token is None:
            raise ValueError(f"unexpected end of formula: {formula!r}")
        try:
            value = float(token)
        except ValueError:
            raise ValueError(
                f"unexpected token {token!r} in formula: {formula!r}"
            ) from None
        advance()
        return value

    result = expression()
    if pos != len(tokens):
        raise ValueError(f"unexpected token {tokens[pos]!r} in formula: {formula!r}")
    return float(result)



def solve(v: list, answer: str):
    """Re-evaluate every ``<<...>>`` formula in ``answer`` using the values ``v``.

    Each formula looks like ``{v_1}+{v_2}/10={r_1}``. On its left-hand side,
    ``{v_k}`` is replaced by ``str(v[k-1])`` and ``{r_k}`` by the result of the
    k-th formula evaluated so far. The result is computed with ``calculate``.
    The right-hand side is ignored, because it holds the result for the
    original numbers. A formula without '=' is evaluated as a whole.

    Args:
        v: values substituted for ``{v_1}``, ``{v_2}``, ... (1-based).
        answer: reasoning text containing ``<<...>>`` formulas.

    Returns:
        dict with
          'v': the ``v`` passed in,
          'r': the result of each formula, in order,
          'process': one ``"<substituted lhs> = <result>"`` string per formula,
          'result': the last formula's result, or None if ``answer`` contains
                    no formulas.
    """
    results = []
    process = []

    for formula in re.findall(r'<<(.+?)>>', answer):
        left = formula.partition('=')[0]

        for idx, value in enumerate(v, start=1):
            left = left.replace(f'{{v_{idx}}}', str(value))
        # Only results of earlier formulas can be referenced.
        for idx, value in enumerate(results, start=1):
            left = left.replace(f'{{r_{idx}}}', str(value))

        result = calculate(left)
        results.append(result)
        process.append(f'{left} = {result}')

    return {
        'v': v,
        'r': results,
        'process': process,
        'result': results[-1] if results else None,
    }

# Draw random 4-digit numbers until one has proper factors of every requested
# digit length; levelup() uses the result to pick scaling multipliers.
def get_a_4_digit_number(demand: list):
    """Return a random 4-digit number whose proper factors satisfy ``demand``.

    Args:
        demand: digit counts (any of 1, 2, 3, 4). For each one, the number
            must have at least one factor other than 1 and itself with that
            many digits. The list is not modified.

    Returns:
        dict with 'num' (1000-9999) and keys '1', '2', '3', '4'. Each key holds
        the ascending proper factors of 'num' with that many digits. Buckets
        not requested in ``demand`` may be empty.

    Raises:
        ValueError: if ``demand`` contains anything other than 1, 2, 3 or 4.
    """
    import math

    digit_ranges = {1: (2, 9), 2: (10, 99), 3: (100, 999), 4: (1000, 9999)}
    required = set(demand)
    unsupported = required - set(digit_ranges)
    if unsupported:
        # The original loop never terminated for such demands.
        raise ValueError(f"unsupported digit counts in demand: {unsupported}")

    while True:
        num = random.randint(1000, 9999)

        # Proper factors (excluding 1 and num) in O(sqrt(num)): each divisor
        # i <= sqrt(num) pairs with num // i.
        factors = set()
        for i in range(2, math.isqrt(num) + 1):
            if num % i == 0:
                factors.add(i)
                factors.add(num // i)
        ordered = sorted(factors)

        # Recomputed for every candidate, so the returned lists always belong
        # to the returned number.
        by_digits = {
            digits: [f for f in ordered if low <= f <= high]
            for digits, (low, high) in digit_ranges.items()
        }
        if all(by_digits[d] for d in required):
            return {
                'num': num,
                '1': by_digits[1],
                '2': by_digits[2],
                '3': by_digits[3],
                '4': by_digits[4],
            }



def _placeholder_mask(question: str, count: int) -> list:
    """Return ``count`` flags, one per ``{...}`` placeholder of ``question``.

    A flag is 0 when the placeholder is immediately followed by '%' (such
    values must not be enlarged) and 1 otherwise. The list is truncated, or
    padded with 1, so it always lines up with the ``count`` values.
    """
    mask = []
    for i, char in enumerate(question):
        if char != '{':
            continue
        end = question.find('}', i)
        followed_by_percent = end != -1 and question.startswith('%', end + 1)
        mask.append(0 if followed_by_percent else 1)
    mask = mask[:count]
    mask.extend([1] * (count - len(mask)))
    return mask


def _to_float_or_same(value):
    """Return ``float(value)``, or ``value`` unchanged if it is not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return value


def _power_fraction(value, exponent: int):
    """Raise a fraction string component-wise ("1/6" -> "1/36" for exponent 2).

    Anything that is not an integer fraction string is returned unchanged.
    """
    if not isinstance(value, str) or '/' not in value:
        return value
    parts = value.split('/')
    try:
        return f"{int(parts[0]) ** exponent}/{int(parts[1]) ** exponent}"
    except ValueError:
        return value


def _raise_to_power(values: list, mask: list, exponent: int) -> list:
    """Strategies 2/3: raise every unmasked value to ``exponent``.

    Masked values are only converted to float (when numeric). Values that
    cannot be powered are kept as-is, so positions always stay aligned with
    the ``{v_i}`` placeholders.
    """
    powered = []
    for flag, value in zip(mask, values):
        number = _to_float_or_same(value)
        if not flag:
            powered.append(number)
        elif isinstance(number, float):
            powered.append(number ** exponent)
        else:
            powered.append(_power_fraction(value, exponent))
    return powered


def _scale_by_factors(values: list, mask: list) -> list:
    """Strategy 1: multiply values by a random 4-digit number and its factors.

    The largest scalable value is multiplied by the 4-digit number itself.
    Every other scalable value is multiplied by a random proper factor of it
    whose digit count grows with the value (>100: 4 digits, >10: 3, else 2).
    Values that are non-numeric, below 1 or masked out keep their value
    (numeric ones are returned as floats).
    """
    factors = get_a_4_digit_number([1, 2, 3, 4])

    numbers = list(values)
    mask = list(mask)
    for i, value in enumerate(numbers):
        try:
            numbers[i] = float(value)
        except (TypeError, ValueError, OverflowError):
            mask[i] = 0
            continue
        if numbers[i] < 1:
            mask[i] = 0

    # Largest numeric value first; non-numeric entries sort as 0.
    order = sorted(
        range(len(numbers)),
        key=lambda k: numbers[k] if isinstance(numbers[k], float) else 0,
        reverse=True,
    )

    scaled = list(numbers)
    first = True
    for i in order:
        if not mask[i]:
            continue
        if first:
            scaled[i] = numbers[i] * factors['num']
            first = False
        elif numbers[i] > 100:
            scaled[i] = numbers[i] * random.choice(factors['4'])
        elif numbers[i] > 10:
            scaled[i] = numbers[i] * random.choice(factors['3'])
        else:
            scaled[i] = numbers[i] * random.choice(factors['2'])
    return scaled


def _tidy_number(value):
    """Turn whole floats into ints and round other numbers to 2 decimals.

    Fraction strings such as "1/36" and other non-numeric values are returned
    unchanged.
    """
    text = str(value)
    try:
        if text.endswith('.0') and '/' not in text:
            return int(float(value))
        return round(float(value), 2)
    except (TypeError, ValueError, OverflowError):
        return value


def levelup(dic: dict, strategy=1):
    """Create a harder variant of a templated math problem by enlarging its numbers.

    Args:
        dic: record with 'question' (text with ``{v_i}`` placeholders),
            'answer' (reasoning with ``<<...>>`` formulas over ``{v_i}`` and
            ``{r_i}``) and 'v' (the original placeholder values).
        strategy: 1 - multiply values by a random 4-digit number and its
            factors; 2 - square each value; 3 - cube each value. Placeholders
            directly followed by '%' are never enlarged. Under strategies 2
            and 3, fraction strings such as "1/6" are powered component-wise.

    Returns:
        ``dic`` itself, updated with 'enhanced_v' (the new values),
        'enhanced_process' and 'enhanced_result' (the re-evaluated formulas).
        ``dic['v']`` is left untouched.

    Raises:
        ValueError: for an unknown ``strategy``.
    """
    question = dic['question']
    answer = dic['answer']

    # Work on a copy so the caller's original values are not overwritten.
    original_v = list(dic['v'])
    mask_v = _placeholder_mask(question, len(original_v))

    if strategy == 1:
        levelup_v = _scale_by_factors(original_v, mask_v)
    elif strategy == 2:
        levelup_v = _raise_to_power(original_v, mask_v, 2)
    elif strategy == 3:
        levelup_v = _raise_to_power(original_v, mask_v, 3)
    else:
        raise ValueError(f"unknown levelup strategy: {strategy!r}")

    levelup_v = [_tidy_number(value) for value in levelup_v]

    # Re-run the reasoning formulas with the new values.
    new_answer = solve(levelup_v, answer)

    dic['enhanced_v'] = levelup_v
    dic['enhanced_process'] = new_answer['process']
    dic['enhanced_result'] = new_answer['result']

    return dic

