# Generates datasets of step-by-step arithmetic traces (add, subtract, multiply, floor-divide, mod, ...) built from Prompt templates.

import numpy as np
np.random.seed(42)
import random
# random.seed(42)
from prompt import Prompt
from copy import deepcopy
import re
from number_utils import generate_integer, generate_float, generate_fraction, generate_scientific
# from ..task import Task

    
def find_dot(result_int: str, pos: int, name_a: str) -> "tuple[str, str]":
    """Place a decimal point ``pos`` digits from the right of ``result_int``.

    Emits a step-by-step trace (``find_dot`` prompt templates) that moves the
    last ``pos`` digits into the fractional part one at a time, then strips
    leading zeros from the integer part and trailing zeros from the
    fractional part.

    Args:
        result_int: Digit string to split.
        pos: Number of digits that go after the decimal point. When it exceeds
            ``len(result_int)`` the missing digits are taken as ``'0'``
            (e.g. ``'5'`` with ``pos=3`` gives ``'0.005'``).
        name_a: Name substituted for the ``_var_res_`` placeholder in the trace.

    Returns:
        ``(trace, result)``. ``result`` always contains a ``'.'`` and both of
        its parts are at least ``'0'`` (e.g. ``'12'`` with ``pos=0`` gives
        ``'12.0'``).
    """
    P = Prompt('find_dot')
    output = P.initialize.format(result_int, pos)
    cnt = 0
    result_dec = ''
    while cnt < pos:
        output += P.enter
        # Once the integer digits run out, keep shifting in zeros.
        last_digit = result_int[-1] if result_int else '0'
        output += P.last_digit.format(last_digit)
        output += P.update_result.format(last_digit, result_dec, last_digit + result_dec, result_int[:-1])
        result_dec = last_digit + result_dec
        result_int = result_int[:-1]
        output += P.update_cnt.format(cnt, cnt + 1)
        cnt += 1
    output += P.not_enter
    # Normalise: drop leading zeros of the integer part and trailing zeros of the fraction.
    result_int = result_int.lstrip('0') or '0'
    result_dec = result_dec.rstrip('0') or '0'
    output += P.strip.format(result_int, result_dec)
    result = result_int + '.' + result_dec
    output += P.merge.format(result_int, result_dec, result)
    output += P.ret.format(result)
    output = output.replace('_var_res_', name_a)
    return output, result

def get_digit_Integer_int_int(num: str, pos: str) -> "tuple[str, str]":
    """Return the digit of ``num`` at 0-based position ``pos`` from the left.

    The trace (``get_digit_Integer_int_int`` templates) drops one leading
    digit per step until the requested digit is at the front.

    Args:
        num: Digit string.
        pos: Position as ``str`` or ``int``. Values <= 0 select the first digit.

    Returns:
        ``(trace, digit)`` where ``digit`` is a one-character string.

    Raises:
        IndexError: If ``pos >= len(num)``.
    """
    pos = int(pos)
    if pos >= len(num):
        # Fail fast with a clear message instead of an opaque ``num[0]`` error
        # after the whole trace has already been built.
        raise IndexError('digit position {} out of range for {!r}'.format(pos, num))
    P = Prompt('get_digit_Integer_int_int')
    output = P.initialize.format(num, pos)
    cnt = 0
    while cnt < pos:
        output += P.enter
        output += P.update_num.format(num[1:])
        num = num[1:]
        output += P.update_cnt.format(cnt, cnt + 1)
        cnt += 1
    output += P.not_enter
    output += P.ret.format(num[0])
    return output, num[0]

def get_digit_Float_int_int(num: str, pos: str) -> "tuple[str, str]":
    """Return the digit at 0-based position ``pos`` of a decimal string, ignoring its '.'.

    The trace first removes every ``'.'`` from ``num`` and then drops one
    leading digit per step until the requested digit is at the front.

    Args:
        num: Decimal string such as ``'12.5'``.
        pos: Position as ``str`` or ``int``, counted over the digits only.
            Values <= 0 select the first digit.

    Returns:
        ``(trace, digit)`` where ``digit`` is a one-character string.

    Raises:
        IndexError: If ``pos`` is not smaller than the number of digits.
    """
    pos = int(pos)
    if pos >= len(num.replace('.', '')):
        # Fail fast instead of indexing an empty string after building the trace.
        raise IndexError('digit position {} out of range for {!r}'.format(pos, num))
    P = Prompt('get_digit_Float_int_int')
    output = P.initialize.format(num, pos)
    cnt = 0
    num = num.replace('.', '')
    output += P.remove_dot.format(num)
    while cnt < pos:
        output += P.enter
        output += P.update_num.format(num[1:])
        num = num[1:]
        output += P.update_cnt.format(cnt, cnt + 1)
        cnt += 1
    output += P.not_enter
    output += P.ret.format(num[0])
    return output, num[0]
    
    
def add_integer_integer_integer(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> "tuple[str, str]":
    """Add two non-negative integer digit strings column by column with a trace.

    Digits are consumed from the right, a carry is propagated, and a final
    carry (if any) is prepended. Leading zeros are stripped from the result.

    Args:
        num1, num2: Operand digit strings (may differ in length).
        name_a, name_b, name_c: Names substituted for the ``_var1_``,
            ``_var2_`` and ``_var_res_`` placeholders in the trace.

    Returns:
        ``(trace, sum)`` where ``sum`` is a digit string (``'0'`` if empty).
    """
    P = Prompt('add_Integer_Integer_Integer')
    output = P.initialize.format(num1, num2)
    result = ''
    carry = 0
    while num1 or num2:
        output += P.enter.format(num1, num2)
        output += P.last_digit.format(num1, num2, num1[-1] if num1 else 0, num2[-1] if num2 else 0)
        # An exhausted operand contributes 0 to the column.
        digit1 = int(num1[-1]) if num1 else 0
        digit2 = int(num2[-1]) if num2 else 0
        total = digit1 + digit2 + carry
        output += P.sum.format(digit1, digit2, carry, digit1, digit2, carry, total)
        digit, new_carry = total % 10, total // 10
        output += P.update_result.format(result, total, digit, str(digit), result, str(digit) + result, carry, new_carry)
        output += P.update_carry.format(total, new_carry, new_carry)
        result = str(digit) + result
        carry = new_carry
        # Slicing '' yields '', so no guard is needed for an exhausted operand.
        output += P.update_nums.format(num1, num2, num1[:-1], num2[:-1])
        num1 = num1[:-1]
        num2 = num2[:-1]
    output += P.not_enter

    if carry:
        output += P.carry_true.format(carry, result, str(carry), result, str(carry) + result)
        result = str(carry) + result
    else:
        output += P.carry_false.format(carry)
    stripped = result.lstrip('0') or '0'
    output += P.lstrip.format(result, stripped)
    result = stripped
    output += P.ret.format(result)
    output = output.replace('_var1_', name_a).replace('_var2_', name_b).replace('_var_res_', name_c)
    return output, result


def add_float_float_float(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> "tuple[str, str]":
    """Add two decimal strings (each with exactly one '.') with a step-by-step trace.

    The shorter fractional part is right-padded with zeros so both operands
    have ``len1`` fractional digits. The concatenated digit strings are then
    added with ``add_integer_integer_integer``, and ``find_dot`` re-inserts the
    decimal point ``len1`` digits from the right.

    Args:
        num1, num2: Operands such as ``'1.25'`` and ``'3.5'``.
        name_a, name_b, name_c: Names substituted for the ``_var1_``,
            ``_var2_`` and ``_var_res_`` placeholders in this function's trace.

    Returns:
        ``(trace, result)`` where ``result`` is normalised by ``find_dot``.
    """
    P = Prompt('add_Float_Float_Float')
    int1, dec1 = num1.split('.')
    int2, dec2 = num2.split('.')
    len1 = len(dec1)
    len2 = len(dec2)
    output = P.initialize.format(num1, num2, int1, dec1, int2, dec2, len1, len2)
    output += P.condition1.format(len1, '<' if len1 < len2 else '>' if len1 > len2 else '==', len2)
    # Align the fractional parts by right-padding the shorter one with zeros.
    if len1 < len2:
        output += P.if1
        while len1 < len2:
            output += P.enter1
            output += P.update_vars1.format(dec1, dec1 + '0', len1, len1 + 1)
            dec1 += '0'
            len1 += 1
        output += P.out1
    elif len1 > len2:
        output += P.if2
        while len1 > len2:
            output += P.enter2
            output += P.update_vars2.format(dec2, dec2 + '0', len2, len2 + 1)
            dec2 += '0'
            len2 += 1
        output += P.out2
    else:
        output += P.not_branch
    full1 = int1 + dec1
    full2 = int2 + dec2
    output += P.full.format(int1, dec1, full1, int2, dec2, full2)

    output += P.add_integer
    new_output, result = add_integer_integer_integer(full1, full2, 'add1', 'add2', 'result')
    output += new_output
    output += P.exit_function.format(result, result)

    # After padding, len1 == len2 is the fractional-digit count of the sum.
    output += P.find_dot.format(result, len1)
    new_output, new_result = find_dot(result, len1, 'new_result')
    output += new_output
    output += P.exit_function2.format(new_result, new_result)
    output += P.ret.format(new_result)
    output = output.replace('_var1_', name_a).replace('_var2_', name_b).replace('_var_res_', name_c)
    return output, new_result
    
def sub_float_float_float(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> "tuple[str, str]":
    """Subtract two decimal strings (each with exactly one '.') with a step-by-step trace.

    The shorter fractional part is right-padded with zeros. The concatenated
    digit strings are then subtracted with ``sub_integer_integer_integer``,
    and ``find_dot`` re-inserts the decimal point ``len1`` digits from the right.

    ``num1`` is assumed to be >= ``num2``: ``sub_integer_integer_integer``
    performs unsigned chunk-wise subtraction with no sign handling.

    Args:
        num1, num2: Operands such as ``'3.5'`` and ``'1.25'``.
        name_a, name_b, name_c: Names substituted for the ``_var1_``,
            ``_var2_`` and ``_var_res_`` placeholders in this function's trace.

    Returns:
        ``(trace, result)`` where ``result`` is normalised by ``find_dot``.
    """
    P = Prompt('sub_Float_Float_Float')
    int1, dec1 = num1.split('.')
    int2, dec2 = num2.split('.')
    len1 = len(dec1)
    len2 = len(dec2)
    output = P.initialize.format(num1, num2, int1, dec1, int2, dec2, len1, len2)
    output += P.condition1.format(len1, '<' if len1 < len2 else '>' if len1 > len2 else '==', len2)
    # Align the fractional parts by right-padding the shorter one with zeros.
    if len1 < len2:
        output += P.if1
        while len1 < len2:
            output += P.enter1
            output += P.update_vars1.format(dec1, dec1 + '0', len1, len1 + 1)
            dec1 += '0'
            len1 += 1
        output += P.out1
    elif len1 > len2:
        output += P.if2
        while len1 > len2:
            output += P.enter2
            output += P.update_vars2.format(dec2, dec2 + '0', len2, len2 + 1)
            dec2 += '0'
            len2 += 1
        output += P.out2
    else:
        output += P.not_branch
    full1 = int1 + dec1
    full2 = int2 + dec2
    output += P.full.format(int1, dec1, full1, int2, dec2, full2)

    output += P.sub_integer
    new_output, result = sub_integer_integer_integer(full1, full2, 'full1', 'full2', 'result')
    output += new_output
    output += P.exit_function.format(result, result)

    # After padding, len1 == len2 is the fractional-digit count of the difference.
    output += P.find_dot.format(result, len1)
    new_output, new_result = find_dot(result, len1, 'new_result')
    output += new_output
    output += P.exit_function2.format(new_result, new_result)
    output += P.ret.format(new_result)
    output = output.replace('_var1_', name_a).replace('_var2_', name_b).replace('_var_res_', name_c)
    return output, new_result

def sub_integer_integer_integer(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> "tuple[str, str]":
    """Subtract ``num2`` from ``num1`` in base-1000 chunks with a step-by-step trace.

    Three-digit groups are processed from the right. When a group difference
    is negative, 1000 is borrowed from the next group. Leading zeros are
    stripped from the result.

    Args:
        num1, num2: Non-negative integer digit strings with ``num1 >= num2``.
        name_a, name_b, name_c: Replacements for the bare words ``'num1'``,
            ``'num2'`` and ``'result'`` in the trace.

    Returns:
        ``(trace, difference)`` where ``difference`` is a digit string.

    Raises:
        ValueError: If ``num1 < num2``. The algorithm has no sign handling:
            the final borrow would be dropped and a wrong non-negative value
            returned (e.g. ``'998'`` for ``5 - 7``).
    """
    if int(num1 or '0') < int(num2 or '0'):
        raise ValueError('sub_integer_integer_integer requires num1 >= num2, got {} - {}'.format(num1, num2))
    P = Prompt('sub_Integer_Integer_Integer')
    output = P.initialize.format(num1, num2)
    result = ''
    borrow = 0
    while num1 or num2:
        output += P.enter
        output += P.last_digit.format(num1, num2, num1[-3:] if num1 else 0, num2[-3:] if num2 else 0)
        digit1 = int(num1[-3:]) if num1 else 0
        digit2 = int(num2[-3:]) if num2 else 0
        total = digit1 - digit2 - borrow
        output += P.sub.format(borrow, digit1, digit2, borrow, total)
        if total < 0:
            output += P.borrow_true.format(total, total, total + 1000)
            total += 1000
            borrow = 1
        else:
            output += P.borrow_false.format(total)
            borrow = 0
        chunk = ('00' + str(total))[-3:]  # zero-pad the group to three digits
        output += P.update_result.format(total, result, chunk + result)
        result = chunk + result
        # Slicing '' yields '', so no guard is needed for an exhausted operand.
        output += P.update_nums.format(num1[:-3], num2[:-3])
        num1 = num1[:-3]
        num2 = num2[:-3]
    output += P.not_enter
    result = result.lstrip('0') or '0'
    output += P.lstrip.format(result)
    output += P.ret.format(result)
    # NOTE: unlike the '_var*_' placeholders used elsewhere, this replaces the
    # bare words 'num1', 'num2' and 'result' wherever they appear in the trace.
    output = output.replace('num1', name_a).replace('num2', name_b).replace('result', name_c)
    return output, result

def multiply_integer_integer_integer(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> "tuple[str, str]":
    """Multiply two integer strings by shift-and-add with a step-by-step trace.

    For each digit of ``num1`` (least significant first), the trace computes
    ``int(num2) * digit``, shifts it by the digit's place value, and adds it to
    a running total.

    Args:
        num1: Digit string whose digits drive the loop.
        num2: Integer string multiplied by each digit.
        name_a, name_b, name_c: Replacements for the bare words ``'num1'``,
            ``'num2'`` and ``'result'`` in the trace.

    Returns:
        ``(trace, product)`` where ``product`` is ``str`` of the integer
        product (``'0'`` when ``num1`` is empty).
    """
    P = Prompt('multiply_Integer_Integer_Integer')
    output = P.initialize.format(num1, num2)
    result = 0
    base = 0  # place value (power of ten) of the current digit of num1
    while num1:
        output += P.enter
        output += P.last_digit.format(num1[-1])
        digit1 = int(num1[-1])
        temp = int(num2) * digit1
        output += P.multiply.format(num2, digit1, temp)

        # Shift the partial product into place, then add it to the total.
        shifted = temp * 10 ** base
        output += P.update_temp.format(temp, base, shifted, base, base + 1)
        temp = shifted
        base += 1
        output += P.add_temp.format(result, temp, result + temp)
        result += temp
        output += P.update_result_num.format(num1[:-1])
        num1 = num1[:-1]
    output += P.not_enter
    output += P.ret.format(result)
    # NOTE: replaces the bare words 'num1', 'num2' and 'result' anywhere in the trace.
    output = output.replace('num1', name_a).replace('num2', name_b).replace('result', name_c)
    return output, str(result)

def multiply_float_float_float(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> "tuple[str, str]":
    """Multiply two decimal strings (each with exactly one '.') with a step-by-step trace.

    The digit strings ``int1 + dec1`` and ``int2 + dec2`` are multiplied with
    ``multiply_integer_integer_integer``. ``find_dot`` then places the decimal
    point ``len(dec1) + len(dec2)`` digits from the right.

    Args:
        num1, num2: Operands such as ``'1.5'`` and ``'2.25'``.
        name_a, name_b, name_c: Names substituted for the ``_var1_``,
            ``_var2_`` and ``_var_res_`` placeholders in this function's trace.

    Returns:
        ``(trace, result)`` where ``result`` is normalised by ``find_dot``.
    """
    P = Prompt('multiply_Float_Float_Float')
    int1, dec1 = num1.split('.')
    int2, dec2 = num2.split('.')
    len1 = len(dec1)
    len2 = len(dec2)
    output = P.initialize.format(num1, num2, int1, dec1, int2, dec2, len1, len2)
    full1 = int1 + dec1
    full2 = int2 + dec2
    output += P.full.format(int1, dec1, full1, int2, dec2, full2)

    output += P.multiply_integer
    new_output, result = multiply_integer_integer_integer(full1, full2, 'mul1', 'mul2', 'result')
    output += new_output
    output += P.exit_function.format(result, result)

    # The product has len1 + len2 fractional digits. The trace must show the
    # same position that is actually passed to find_dot.
    dot_pos = len1 + len2
    output += P.find_dot.format(result, dot_pos)
    new_output, new_result = find_dot(result, dot_pos, 'new_result')
    output += new_output
    output += P.exit_function2.format(new_result, new_result)
    output += P.ret.format(new_result)
    output = output.replace('_var1_', name_a).replace('_var2_', name_b).replace('_var_res_', name_c)
    return output, new_result


def digit_max_integer_integer_integer(num1: str, num2: str) -> "tuple[str, str]":
    """Combine two digit strings by taking the larger digit at each aligned position.

    Digits are compared right-aligned. Once the shorter string is exhausted,
    the remaining leading digits of the longer one are prepended unchanged
    (e.g. ``'1529'`` and ``'83'`` give ``'1589'``).

    Returns:
        ``(trace, result)`` where ``result`` is a digit string.
    """
    P = Prompt('digit_max_Integer_Integer_Integer')
    # Keep the initialize header at the start of the trace, as the other trace
    # generators in this module do (it must not be reset to '').
    output = P.initialize.format(num1, num2)
    result = ''
    while num1 and num2:
        output += P.enter
        digit1, digit2 = num1[-1], num2[-1]
        output += P.last_digit.format(num1, num2, digit1, digit2)
        # Single characters '0'-'9': lexicographic max equals numeric max.
        bigger = max(digit1, digit2)
        output += P.max.format(bigger, digit1, digit2, bigger + result)
        result = bigger + result
        output += P.update_nums.format(num1[:-1], num2[:-1])
        num1, num2 = num1[:-1], num2[:-1]
    output += P.not_enter
    # At most one of num1/num2 is non-empty here; its leftover prefix is kept as-is.
    output += P.update_rest.format(num1, num2, result, num1 + num2 + result)
    result = num1 + num2 + result
    output += P.ret.format(result)
    return output, result

def length_Integer_none_int(num: str) -> "tuple[str, int]":
    """Count the characters of ``num`` by removing one leading character per step.

    Returns:
        ``(trace, length)`` where ``length`` is an ``int``.
    """
    P = Prompt('length_Integer_none_int')
    # Keep the initialize header at the start of the trace, as the other trace
    # generators in this module do (it must not be reset to '').
    output = P.initialize.format(num)
    result = 0
    while num:
        output += P.enter
        output += P.update_result.format(result, result + 1)
        result += 1
        output += P.update_nums.format(num[1:])
        num = num[1:]
    output += P.not_enter
    output += P.ret.format(result)
    return output, result

def floordiv_integer_integer_integer(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> "tuple[str, str]":
    """Floor-divide a digit string by a positive integer using long division.

    Digits of ``num1`` are brought down one at a time into a running
    remainder ``num_now``. Each quotient digit is found by repeatedly
    subtracting the divisor, and the trace spells out every subtraction.

    Args:
        num1: Dividend digit string.
        num2: Divisor as a string; must represent a positive integer.
        name_a, name_b, name_c: Names substituted for the ``_var1_``,
            ``_var2_`` and ``_var_res_`` placeholders in the trace.

    Returns:
        ``(trace, quotient)`` with leading zeros stripped (``'0'`` if empty).

    Raises:
        ZeroDivisionError: If ``num2`` is zero.
        ValueError: If ``num2`` is negative or not an integer string.
    """
    divisor = int(num2)
    # Validate up front: with divisor <= 0 the repeated-subtraction loop below
    # never terminates once the running remainder is positive.
    if divisor == 0:
        raise ZeroDivisionError('integer division by zero')
    if divisor < 0:
        raise ValueError('divisor must be positive, got {}'.format(num2))
    P = Prompt('floordiv_Integer_Integer_Integer')
    output = P.initialize.format(num1, num2)
    result = ''
    num_now = 0
    while num1:
        output += P.enter
        next_now = num_now * 10 + int(num1[0])
        output += P.update_num_now.format(num_now, num1[0], next_now)
        num_now = next_now
        output += P.digit_init
        digit = 0
        # divisor > 0, so this loop also stops immediately when num_now == 0.
        while num_now >= divisor:
            output += P.enter_digit
            output += P.update_num_now2.format(num_now, divisor, num_now - divisor)
            num_now -= divisor
            output += P.update_digit.format(digit, digit + 1)
            digit += 1
        output += P.not_enter_digit
        output += P.update_result.format(result, str(digit), result + str(digit))
        result += str(digit)
        output += P.update_nums.format(num1[1:])
        num1 = num1[1:]
    output += P.not_enter
    result = result.lstrip('0') or '0'
    output += P.lstrip.format(result)
    output += P.ret.format(result)
    output = output.replace('_var1_', name_a).replace('_var2_', name_b).replace('_var_res_', name_c)
    return output, result

def brief_floordiv_integer_integer_integer(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> "tuple[str, str]":
    """Floor-divide a digit string by a positive integer with a condensed long-division trace.

    This performs the same long division as ``floordiv_integer_integer_integer``
    and uses the same prompt templates. Each step's quotient digit and
    remainder are computed directly with ``//`` and ``%`` (the ``brief``
    template) instead of by repeated subtraction.

    Args:
        num1: Dividend digit string.
        num2: Divisor as a string; must represent a positive integer.
        name_a, name_b, name_c: Names substituted for the ``_var1_``,
            ``_var2_`` and ``_var_res_`` placeholders in the trace.

    Returns:
        ``(trace, quotient)`` with leading zeros stripped (``'0'`` if empty).

    Raises:
        ZeroDivisionError: If ``num2`` is zero.
        ValueError: If ``num2`` is negative or not an integer string.
    """
    divisor = int(num2)
    # A negative divisor would yield negative "digits"; reject it like zero.
    if divisor == 0:
        raise ZeroDivisionError('integer division by zero')
    if divisor < 0:
        raise ValueError('divisor must be positive, got {}'.format(num2))
    P = Prompt('floordiv_Integer_Integer_Integer')
    output = P.initialize.format(num1, num2)
    result = ''
    num_now = 0
    while num1:
        output += P.enter
        next_now = num_now * 10 + int(num1[0])
        output += P.update_num_now.format(num_now, num1[0], next_now)
        num_now = next_now
        output += P.brief.format(num_now, divisor, num_now // divisor, num_now, divisor, num_now % divisor)
        digit, num_now = divmod(num_now, divisor)
        output += P.update_result.format(result, str(digit), result + str(digit))
        result += str(digit)
        output += P.update_nums.format(num1[1:])
        num1 = num1[1:]
    output += P.not_enter
    result = result.lstrip('0') or '0'
    output += P.lstrip.format(result)
    output += P.ret.format(result)
    output = output.replace('_var1_', name_a).replace('_var2_', name_b).replace('_var_res_', name_c)
    return output, result

def mod_integer_integer_integer(num1: str, num2: str, name_a: str, name_b: str) -> "tuple[str, str]":
    """Compute ``num1 mod num2`` as ``num1 - (num1 // num2) * num2`` with nested traces.

    Three sub-traces are embedded in order:
      1. ``brief_floordiv_integer_integer_integer`` -> quotient ``q``
      2. ``multiply_integer_integer_integer``       -> ``temp = q * num2``
      3. ``sub_integer_integer_integer``            -> ``r = num1 - temp``

    Args:
        num1: Dividend digit string.
        num2: Divisor digit string.
        name_a, name_b: Names substituted for the ``_var1_`` and ``_var2_``
            placeholders in the outer trace.

    Returns:
        ``(trace, remainder)`` where ``remainder`` is a digit string.
    """
    P = Prompt('mod_Integer_Integer_Integer')
    output = P.initialize.format(num1, num2)

    # Step 1: quotient.
    output += P.calc_q
    new_output, q = brief_floordiv_integer_integer_integer(num1, num2, 'div1', 'div2', 'q')
    output += new_output
    output += P.exit_function1.format(q)

    # Step 2: largest multiple of num2 not exceeding num1.
    output += P.calc_temp
    new_output, temp = multiply_integer_integer_integer(q, num2, 'mul1', 'mul2', 'temp')
    output += new_output
    output += P.exit_function2.format(temp)

    # Step 3: remainder (temp <= num1, so the difference is non-negative).
    output += P.calc_r
    new_output, r = sub_integer_integer_integer(num1, temp, 'sub1', 'sub2', 'r')
    output += new_output
    output += P.exit_function3.format(r)

    output += P.ret.format(r)
    output = output.replace('_var1_', name_a).replace('_var2_', name_b)
    return output, r
    
def gcd_integer_integer_integer(num1: str, num2: str, name_a: str, name_b: str) -> "tuple[str, str]":
    """Return a one-line statement of ``gcd(num1, num2)`` and the gcd as a string.

    Unlike the other helpers, this emits no step-by-step trace. ``name_a`` and
    ``name_b`` are accepted for signature compatibility and are unused.

    ``math.gcd`` is used because it works on arbitrarily large Python ints.
    ``np.gcd`` operates on fixed-width integer dtypes, so very long operands
    may overflow or be rejected depending on the NumPy version.
    """
    import math  # local import keeps this helper self-contained

    divisor = math.gcd(int(num1), int(num2))
    output = '''the greatest common divisor of {} and {} is {}
'''.format(num1, num2, divisor)
    return output, str(divisor)

class Dataset_Generator:
    """Builds integer-addition questions and their training targets.

    ``rfft_IO`` produces a rule-following trace from the
    ``add_Integer_Integer_Integer`` prompt templates. ``cot_IO`` produces a
    compact digit-by-digit chain-of-thought rationale.
    """

    def __init__(self) -> None:
        # Prompt key used by rfft_IO to load the rule and answer templates.
        self.name = 'add_Integer_Integer_Integer'

    def gen_data_from_len(self, length: int) -> str:
        """Return an addition question string such as ``'Add two numbers: 12 + 345'``.

        One operand comes from ``generate_integer(length)`` and the other from
        ``generate_integer(k)`` for a random ``k`` in ``[1, length]``. Their
        order is swapped with probability 0.5.
        """
        num1 = generate_integer(length)
        num2 = generate_integer(random.randint(1, length))
        if random.random() < 0.5:
            num1, num2 = num2, num1
        return 'Add two numbers: {} + {}'.format(num1, num2)

    def rfft_IO(self, q_str) -> dict:
        """Return the rule-following (RFFT) input/output pair for ``q_str``.

        Args:
            q_str: Question such as ``'Add two numbers: 12 + 34'``. The two
                operands are parsed from the text after the last ``':'``.

        Returns:
            dict with keys ``"input"`` (instruction + rule + question),
            ``"output"`` (execution trace followed by the answer line) and
            ``"answer"`` (the sum as a string).
        """
        P = Prompt(self.name)
        instruction = "Follow the given rule to solve the question.\nrule:\n"
        num1, num2 = re.findall(r'(-?\d+)', q_str.split(':')[-1])
        # Named prompt_input to avoid shadowing the builtin input().
        prompt_input = instruction + P.rule + "\n\nQ: " + q_str + '\n'
        prompt_input = prompt_input.replace('_var1_', 'num1').replace('_var2_', 'num2').replace('_var_res_', 'result')
        output, answer = add_integer_integer_integer(num1, num2, 'num1', 'num2', 'result')
        output += P.answer.format(answer)
        return {"input": prompt_input,
                "output": output,
                "answer": answer}

    def cot_IO(self, q_str) -> dict:
        """Return a compact chain-of-thought rationale for an addition question.

        Each rationale line has the form
        ``'<remaining a> + <remaining b>, <partial answer>, C: <carry>'``.
        Every line except the last is followed by a ``'# added ...'`` comment
        for the column being added. The final partial answer is checked
        against ``int(a) + int(b)``.

        Returns:
            dict with keys ``"input"`` (``q_str`` unchanged), ``"output"``
            (rationale plus ``"So the answer is N"``) and ``"answer"``.
        """
        def extract_answer(rationale):
            # The last rationale line is "<a> + <b>, <answer>, C: <carry>".
            return rationale.split("\n")[-1].split(",")[1].replace(" ", "")

        def count_c(a_digit, b_digit, c):
            # Carry out of a single column.
            return 1 if int(a_digit) + int(b_digit) + int(c) >= 10 else 0

        a, b = re.findall(r'(-?\d+)', q_str.split(':')[-1])
        gt = int(a) + int(b)
        n_steps = len(str(gt))
        a_digits = list(a)
        b_digits = list(b)
        rationale = ""
        answer = ""
        c = 0
        # One line per output digit, plus a final line showing the full answer.
        for i in range(n_steps + 1):
            line = f"{''.join(a_digits)} + {''.join(b_digits)}, {answer}, C: {c}\n"
            a_digit = a_digits[-1] if a_digits else 0
            b_digit = b_digits[-1] if b_digits else 0
            if i != n_steps:
                line += f"# added {a_digit} + {b_digit} + {c} = {str(int(a_digit) + int(b_digit) + c)[-1]}\n"
            rationale += line
            if a_digits or b_digits:
                # An exhausted operand contributes 0 to the column.
                answer = str(int(a_digit) + int(b_digit) + c)[-1] + answer
                c = count_c(a_digit, b_digit, c)
                if a_digits:
                    a_digits.pop()
                if b_digits:
                    b_digits.pop()
            elif c:
                # Both operands exhausted: emit the final carry.
                answer = str(c) + answer
                c = 0
        rationale = rationale.strip()
        # Self-check that the rationale ends with the correct sum.
        assert int(extract_answer(rationale)) == int(gt)
        return {"input": q_str,
                "output": f"{rationale}\nSo the answer is {gt}",
                "answer": str(gt)}
    
if __name__ == "__main__":
    pass
