# This script generates datasets of step-by-step arithmetic traces.

import numpy as np
np.random.seed(42)
import random
# random.seed(42)
from prompt import Prompt
from copy import deepcopy
import re
from number_utils import generate_integer, generate_float, generate_fraction, generate_scientific
# from ..task import Task

    
def find_dot(result_int: str, pos: int, name_a: str) -> str:
    """Insert a decimal point ``pos`` digits from the right of a digit string.

    The trace is assembled from the ``find_dot`` prompt templates, one step
    per digit moved behind the point.

    Args:
        result_int: digit string without a decimal point.
        pos: number of trailing digits that end up after the point; once
            ``result_int`` runs out of digits, '0' is used instead.
        name_a: name substituted for the ``_var_res_`` placeholder.

    Returns:
        A ``(trace, value)`` tuple (despite the ``str`` annotation). In
        ``value`` the integer part has leading zeros stripped and the
        fractional part has trailing zeros stripped; each falls back to '0'.
    """
    prompt = Prompt('find_dot')
    pieces = [prompt.initialize.format(result_int, pos)]
    int_part, frac_part = result_int, ''
    step = 0
    while step < pos:
        digit = int_part[-1] if int_part else '0'
        shifted = digit + frac_part
        pieces.append(prompt.enter)
        pieces.append(prompt.last_digit.format(digit))
        pieces.append(prompt.update_result.format(digit, frac_part, shifted, int_part[:-1]))
        pieces.append(prompt.update_cnt.format(step, step + 1))
        frac_part, int_part = shifted, int_part[:-1]
        step += 1
    pieces.append(prompt.not_enter)
    int_part = int_part.lstrip('0') or '0'
    frac_part = frac_part.rstrip('0') or '0'
    value = f'{int_part}.{frac_part}'
    pieces.append(prompt.strip.format(int_part, frac_part))
    pieces.append(prompt.merge.format(int_part, frac_part, value))
    pieces.append(prompt.ret.format(value))
    return ''.join(pieces).replace('_var_res_', name_a), value

def get_digit_Integer_int_int(num: str, pos: str) -> int:
    """Return the character of ``num`` at 0-based index ``pos`` from the left.

    The trace drops one leading character per step, using the
    ``get_digit_Integer_int_int`` prompt templates.

    Args:
        num: digit string.
        pos: index as a string (converted with ``int``).

    Returns:
        A ``(trace, digit)`` tuple where ``digit`` is a one-character string
        (despite the ``int`` annotation). ``IndexError`` propagates when
        ``pos`` is at or past the end of ``num``.
    """
    index = int(pos)
    templates = Prompt('get_digit_Integer_int_int')
    trace = templates.initialize.format(num, index)
    remaining = num
    for step in range(index):
        remaining = remaining[1:]
        trace += (templates.enter
                  + templates.update_num.format(remaining)
                  + templates.update_cnt.format(step, step + 1))
    digit = remaining[0]
    trace += templates.not_enter + templates.ret.format(digit)
    return trace, digit

def get_digit_Float_int_int(num: str, pos: str) -> int:
    """Return the digit at 0-based index ``pos`` of ``num`` with its '.' removed.

    The decimal point is deleted first, then one leading character is dropped
    per step, using the ``get_digit_Float_int_int`` prompt templates.

    Args:
        num: decimal string such as '12.34'.
        pos: index as a string (converted with ``int``).

    Returns:
        A ``(trace, digit)`` tuple where ``digit`` is a one-character string
        (despite the ``int`` annotation). ``IndexError`` propagates when
        ``pos`` is past the last digit.
    """
    index = int(pos)
    templates = Prompt('get_digit_Float_int_int')
    trace = templates.initialize.format(num, index)
    remaining = num.replace('.', '')
    trace += templates.remove_dot.format(remaining)
    for step in range(index):
        remaining = remaining[1:]
        trace += (templates.enter
                  + templates.update_num.format(remaining)
                  + templates.update_cnt.format(step, step + 1))
    digit = remaining[0]
    trace += templates.not_enter + templates.ret.format(digit)
    return trace, digit
    
    
def add_integer_integer_integer(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> tuple:
    """Add two non-negative integer strings three digits at a time, with a trace.

    Each iteration takes the last three digits of each operand, adds them
    plus the carry, prepends the low three digits of the sum to the result,
    and carries ``sum // 1000`` into the next chunk.

    Args:
        num1, num2: decimal digit strings (no sign, no decimal point).
        name_a, name_b, name_c: names substituted for the ``_var1_``,
            ``_var2_`` and ``_var_res_`` placeholders in the trace.

    Returns:
        A ``(trace, sum)`` tuple; ``sum`` has leading zeros stripped
        ('0' if nothing remains).
    """
    P = Prompt('add_Integer_Integer_Integer')
    output = P.initialize.format(num1, num2)
    result = ''
    carry = 0
    while num1 or num2:
        output += P.enter
        output += P.last_digit.format(num1, num2, num1[-3:] if num1 else 0, num2[-3:] if num2 else 0)
        digit1 = int(num1[-3:]) if num1 else 0
        digit2 = int(num2[-3:]) if num2 else 0
        total = digit1 + digit2 + carry
        output += P.sum.format(carry, digit1, digit2, carry, total)
        low = total % 1000
        chunk = str(low).zfill(3)
        new_carry = total // 1000
        # The traced "new result" must use the same chunk that is actually
        # prepended. It previously used total % 100, so e.g. a chunk sum of
        # 1234 was traced as '034' while the real result got '234'.
        output += P.update_result.format(total, low, chunk, result, chunk + result, carry, new_carry)
        result = chunk + result
        carry = new_carry
        output += P.update_nums.format(num1[:-3], num2[:-3])
        num1 = num1[:-3]
        num2 = num2[:-3]
    output += P.not_enter

    if carry:
        output += P.carry_true.format(carry, str(carry), result, str(carry) + result)
        result = str(carry) + result
    else:
        output += P.carry_false.format(carry)
    result = result.lstrip('0') or '0'
    output += P.lstrip.format(result)
    output += P.ret.format(result)
    output = output.replace('_var1_', name_a).replace('_var2_', name_b).replace('_var_res_', name_c)
    return output, result


def add_float_float_float(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> str:
    """Add two decimal strings by aligning their fractional parts, with a trace.

    Both operands must contain exactly one '.'. The shorter fractional part
    is right-padded with '0' until both widths match. The points are then
    dropped, the digit strings are summed with
    ``add_integer_integer_integer``, and ``find_dot`` puts the point back at
    the common width.

    Args:
        num1, num2: decimal strings such as '12.5'.
        name_a, name_b, name_c: names substituted for the ``_var1_``,
            ``_var2_`` and ``_var_res_`` placeholders in the trace.

    Returns:
        A ``(trace, sum)`` tuple (despite the ``str`` annotation).
    """
    tpl = Prompt('add_Float_Float_Float')
    whole1, frac1 = num1.split('.')
    whole2, frac2 = num2.split('.')
    width1, width2 = len(frac1), len(frac2)
    parts = [tpl.initialize.format(num1, num2, whole1, frac1, whole2, frac2, width1, width2)]

    if width1 < width2:
        relation = '<'
    elif width1 > width2:
        relation = '>'
    else:
        relation = '=='
    parts.append(tpl.condition1.format(width1, relation, width2))

    # Pad whichever fractional part is shorter, one traced '0' at a time.
    if relation == '<':
        parts.append(tpl.if1)
        while width1 < width2:
            parts.append(tpl.enter1)
            parts.append(tpl.update_vars1.format(frac1, frac1 + '0', width1, width1 + 1))
            frac1, width1 = frac1 + '0', width1 + 1
        parts.append(tpl.out1)
    elif relation == '>':
        parts.append(tpl.if2)
        while width2 < width1:
            parts.append(tpl.enter2)
            parts.append(tpl.update_vars2.format(frac2, frac2 + '0', width2, width2 + 1))
            frac2, width2 = frac2 + '0', width2 + 1
        parts.append(tpl.out2)
    else:
        parts.append(tpl.not_branch)

    digits1, digits2 = whole1 + frac1, whole2 + frac2
    parts.append(tpl.full.format(whole1, frac1, digits1, whole2, frac2, digits2))

    parts.append(tpl.add_integer)
    inner_trace, digit_sum = add_integer_integer_integer(digits1, digits2, 'add1', 'add2', 'result')
    parts.append(inner_trace)
    parts.append(tpl.exit_function.format(digit_sum, digit_sum))

    parts.append(tpl.find_dot.format(digit_sum, width1))
    dot_trace, total = find_dot(digit_sum, int(width1), 'new_result')
    parts.append(dot_trace)
    parts.append(tpl.exit_function2.format(total, total))
    parts.append(tpl.ret.format(total))

    trace = ''.join(parts)
    for placeholder, name in (('_var1_', name_a), ('_var2_', name_b), ('_var_res_', name_c)):
        trace = trace.replace(placeholder, name)
    return trace, total
    
def sub_float_float_float(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> str:
    """Subtract two decimal strings by aligning their fractional parts, with a trace.

    Both operands must contain exactly one '.'. The shorter fractional part
    is right-padded with '0' until both widths match. The points are then
    dropped, the digit strings are subtracted with
    ``sub_integer_integer_integer``, and ``find_dot`` puts the point back at
    the common width. Presumably ``num1 >= num2`` is required, since
    ``sub_integer_integer_integer`` is unsigned (TODO confirm with callers).

    Args:
        num1, num2: decimal strings such as '12.5'.
        name_a, name_b, name_c: names substituted for the ``_var1_``,
            ``_var2_`` and ``_var_res_`` placeholders in the trace.

    Returns:
        A ``(trace, difference)`` tuple (despite the ``str`` annotation).
    """
    tpl = Prompt('sub_Float_Float_Float')
    whole1, frac1 = num1.split('.')
    whole2, frac2 = num2.split('.')
    width1, width2 = len(frac1), len(frac2)
    parts = [tpl.initialize.format(num1, num2, whole1, frac1, whole2, frac2, width1, width2)]

    if width1 < width2:
        relation = '<'
    elif width1 > width2:
        relation = '>'
    else:
        relation = '=='
    parts.append(tpl.condition1.format(width1, relation, width2))

    # Pad whichever fractional part is shorter, one traced '0' at a time.
    if relation == '<':
        parts.append(tpl.if1)
        while width1 < width2:
            parts.append(tpl.enter1)
            parts.append(tpl.update_vars1.format(frac1, frac1 + '0', width1, width1 + 1))
            frac1, width1 = frac1 + '0', width1 + 1
        parts.append(tpl.out1)
    elif relation == '>':
        parts.append(tpl.if2)
        while width2 < width1:
            parts.append(tpl.enter2)
            parts.append(tpl.update_vars2.format(frac2, frac2 + '0', width2, width2 + 1))
            frac2, width2 = frac2 + '0', width2 + 1
        parts.append(tpl.out2)
    else:
        parts.append(tpl.not_branch)

    digits1, digits2 = whole1 + frac1, whole2 + frac2
    parts.append(tpl.full.format(whole1, frac1, digits1, whole2, frac2, digits2))

    parts.append(tpl.sub_integer)
    inner_trace, digit_diff = sub_integer_integer_integer(digits1, digits2, 'full1', 'full2', 'result')
    parts.append(inner_trace)
    parts.append(tpl.exit_function.format(digit_diff, digit_diff))

    parts.append(tpl.find_dot.format(digit_diff, width1))
    dot_trace, difference = find_dot(digit_diff, int(width1), 'new_result')
    parts.append(dot_trace)
    parts.append(tpl.exit_function2.format(difference, difference))
    parts.append(tpl.ret.format(difference))

    trace = ''.join(parts)
    for placeholder, name in (('_var1_', name_a), ('_var2_', name_b), ('_var_res_', name_c)):
        trace = trace.replace(placeholder, name)
    return trace, difference

def sub_integer_integer_integer(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> tuple:
    """Subtract ``num2`` from ``num1`` three digits at a time, with a trace.

    Each iteration subtracts the last three digits of ``num2`` plus any
    pending borrow from the last three digits of ``num1``. A negative chunk
    borrows 1000 from the next chunk.

    Args:
        num1, num2: non-negative decimal digit strings; ``num1`` must be
            numerically greater than or equal to ``num2``.
        name_a, name_b, name_c: replacement names for every literal
            occurrence of 'num1', 'num2' and 'result' in the trace.

    Returns:
        A ``(trace, difference)`` tuple; ``difference`` has leading zeros
        stripped ('0' if nothing remains).

    Raises:
        ValueError: if ``num1 < num2``, detected as a borrow left over after
            the last chunk. Previously this returned a silently wrong value,
            e.g. '5' - '7' -> '998'.
    """
    P = Prompt('sub_Integer_Integer_Integer')
    output = P.initialize.format(num1, num2)
    # Keep the original operands for the error message below.
    minuend, subtrahend = num1, num2
    result = ''
    borrow = 0
    while num1 or num2:
        output += P.enter
        output += P.last_digit.format(num1, num2, num1[-3:] if num1 else 0, num2[-3:] if num2 else 0)
        digit1 = int(num1[-3:]) if num1 else 0
        digit2 = int(num2[-3:]) if num2 else 0
        total = digit1 - digit2 - borrow
        output += P.sub.format(borrow, digit1, digit2, borrow, total)
        if total < 0:
            output += P.borrow_true.format(total, total, total + 1000)
            total += 1000
            borrow = 1
        else:
            output += P.borrow_false.format(total)
            borrow = 0
        # total is now in [0, 999]; left-pad it to a full three-digit chunk.
        chunk = str(total).zfill(3)
        output += P.update_result.format(total, result, chunk + result)
        result = chunk + result

        output += P.update_nums.format(num1[:-3], num2[:-3])
        num1 = num1[:-3]
        num2 = num2[:-3]
    if borrow:
        raise ValueError(
            'sub_integer_integer_integer requires num1 >= num2, got {} - {}'.format(minuend, subtrahend))
    output += P.not_enter
    result = result.lstrip('0') or '0'
    output += P.lstrip.format(result)
    output += P.ret.format(result)
    # NOTE(review): unlike the '_var*_' placeholders used by sibling functions,
    # this renames literal 'num1'/'num2'/'result' text. Presumably the
    # sub_Integer templates are written that way; confirm before changing.
    output = output.replace('num1', name_a).replace('num2', name_b).replace('result', name_c)
    return output, result

def multiply_integer_integer_integer(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> str:
    """Multiply two integer strings by summing shifted single-digit partial products.

    Walks ``num1`` from its last digit to its first. Each digit is multiplied
    by ``int(num2)``, scaled by ``10 ** position``, and added to a running
    total.

    Args:
        num1, num2: decimal integer strings.
        name_a, name_b, name_c: replacement names for every literal
            occurrence of 'num1', 'num2' and 'result' in the trace.

    Returns:
        A ``(trace, product)`` tuple (despite the ``str`` annotation);
        ``product`` is the decimal string of the total ('0' when ``num1``
        is empty).
    """
    templates = Prompt('multiply_Integer_Integer_Integer')
    trace = templates.initialize.format(num1, num2)
    running_total = 0
    for shift, ch in enumerate(reversed(num1)):
        digit = int(ch)
        partial = int(num2) * digit
        scaled = partial * 10 ** shift
        remaining = num1[:len(num1) - shift - 1]
        trace += templates.enter
        trace += templates.last_digit.format(ch)
        trace += templates.multiply.format(num2, digit, partial)
        trace += templates.update_temp.format(partial, shift, scaled, shift, shift + 1)
        trace += templates.add_temp.format(running_total, scaled, running_total + scaled)
        trace += templates.update_result_num.format(remaining)
        running_total += scaled
    trace += templates.not_enter
    trace += templates.ret.format(running_total)
    for literal, name in (('num1', name_a), ('num2', name_b), ('result', name_c)):
        trace = trace.replace(literal, name)
    return trace, str(running_total)

def multiply_float_float_float(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> tuple:
    """Multiply two decimal strings via their digit strings, with a trace.

    Both operands must contain exactly one '.'. The points are removed, the
    digit strings are multiplied with ``multiply_integer_integer_integer``,
    and ``find_dot`` places the point ``len(dec1) + len(dec2)`` digits from
    the right.

    Args:
        num1, num2: decimal strings such as '1.25'.
        name_a, name_b, name_c: names substituted for the ``_var1_``,
            ``_var2_`` and ``_var_res_`` placeholders in the trace.

    Returns:
        A ``(trace, product)`` tuple.
    """
    P = Prompt('multiply_Float_Float_Float')
    int1, dec1 = num1.split('.')
    int2, dec2 = num2.split('.')
    len1 = len(dec1)
    len2 = len(dec2)
    output = P.initialize.format(num1, num2, int1, dec1, int2, dec2, len1, len2)
    full1 = int1 + dec1
    full2 = int2 + dec2
    output += P.full.format(int1, dec1, full1, int2, dec2, full2)

    output += P.multiply_integer
    new_output, result = multiply_integer_integer_integer(full1, full2, 'mul1', 'mul2', 'result')
    output += new_output
    output += P.exit_function.format(result, result)

    # A product has as many fractional digits as both factors combined. The
    # trace previously announced only len1 while find_dot actually shifted by
    # len1 + len2. NOTE(review): assumes the second placeholder of
    # P.find_dot is the shift amount; confirm against the template.
    shift = len1 + len2
    output += P.find_dot.format(result, shift)
    new_output, new_result = find_dot(result, shift, 'new_result')
    output += new_output
    output += P.exit_function2.format(new_result, new_result)
    output += P.ret.format(new_result)
    output = output.replace('_var1_', name_a).replace('_var2_', name_b).replace('_var_res_', name_c)
    return output, new_result


def digit_max_integer_integer_integer(num1: str, num2: str) -> str:
    """Build the digit-wise maximum of two right-aligned digit strings, with a trace.

    While both operands still have digits, their last characters are
    compared and the larger is prepended to the result. Any leading digits
    left over on the longer operand are then prepended unchanged.

    Returns:
        A ``(trace, result)`` tuple (despite the ``str`` annotation).
    """
    tpl = Prompt('digit_max_Integer_Integer_Integer')
    chunks = [tpl.initialize.format(num1, num2)]
    left, right, merged = num1, num2, ''
    while left and right:
        a, b = left[-1], right[-1]
        bigger = max(a, b)
        chunks += [
            tpl.enter.format(left, right),
            tpl.last_digit.format(left, right, a, b),
            tpl.max.format(a, b, bigger, merged, bigger, merged, bigger + merged),
            tpl.update_nums.format(left, right, left[:-1], right[:-1]),
        ]
        merged = bigger + merged
        left, right = left[:-1], right[:-1]
    tail = left + right
    chunks.append(tpl.not_enter.format(left, right))
    chunks.append(tpl.update_rest.format(left, right, left, right, merged, tail + merged))
    merged = tail + merged
    chunks.append(tpl.ret.format(merged))
    return ''.join(chunks), merged

def length_Integer_none_int(num: str) -> tuple:
    """Count the characters of ``num`` one at a time, with a trace.

    Returns:
        A ``(trace, length)`` tuple; ``length`` is an ``int``.
    """
    P = Prompt('length_Integer_none_int')
    # The trace starts with the initialization step. It used to be discarded
    # by an ``output = ''`` immediately after being built.
    output = P.initialize.format(num)
    result = 0
    while num:
        output += P.enter
        output += P.update_result.format(result, result + 1)
        result += 1
        output += P.update_nums.format(num[1:])
        num = num[1:]
    output += P.not_enter
    output += P.ret.format(result)
    return output, result

def floordiv_integer_integer_integer(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> tuple:
    """Long-divide ``num1`` by ``num2`` digit by digit, with a trace.

    For each digit of ``num1``, the running remainder is extended by that
    digit and ``num2`` is subtracted repeatedly. The number of subtractions
    is the next quotient digit.

    Args:
        num1: non-negative decimal digit string (the dividend).
        num2: decimal string of a positive integer (the divisor).
        name_a, name_b, name_c: names substituted for the ``_var1_``,
            ``_var2_`` and ``_var_res_`` placeholders in the trace.

    Returns:
        A ``(trace, quotient)`` tuple; ``quotient`` has leading zeros
        stripped ('0' if nothing remains).

    Raises:
        ZeroDivisionError: if ``num2`` is zero.
        ValueError: if ``num2`` is negative or not an integer literal.
    """
    divisor = int(num2)
    # Repeated subtraction never terminates for a zero or negative divisor,
    # because the remainder never drops below it. Reject both before tracing.
    if divisor == 0:
        raise ZeroDivisionError('integer division or modulo by zero')
    if divisor < 0:
        raise ValueError('divisor must be positive, got {}'.format(num2))
    P = Prompt('floordiv_Integer_Integer_Integer')
    output = P.initialize.format(num1, num2)
    result = ''
    num_now = 0
    while num1:
        output += P.enter
        output += P.update_num_now.format(num_now, num1[0], num_now * 10 + int(num1[0]))
        num_now = num_now * 10 + int(num1[0])
        output += P.digit_init
        digit = 0
        while num_now >= divisor:
            output += P.enter_digit
            output += P.update_num_now2.format(num_now, divisor, num_now - divisor)
            num_now -= divisor
            output += P.update_digit.format(digit, digit + 1)
            digit += 1
        output += P.not_enter_digit
        output += P.update_result.format(result, str(digit), result + str(digit))
        result += str(digit)
        output += P.update_nums.format(num1[1:])
        num1 = num1[1:]
    output += P.not_enter
    result = result.lstrip('0') or '0'
    output += P.lstrip.format(result)
    output += P.ret.format(result)
    output = output.replace('_var1_', name_a).replace('_var2_', name_b).replace('_var_res_', name_c)
    return output, result

def brief_floordiv_integer_integer_integer(num1: str, num2: str, name_a: str, name_b: str, name_c: str) -> str:
    """Long-divide ``num1`` by ``num2``, computing each quotient digit in one step.

    This is the compact form of ``floordiv_integer_integer_integer``. Each
    quotient digit and remainder is produced with a single floor-division and
    modulo step instead of repeated subtraction. A zero divisor raises
    ``ZeroDivisionError`` from that step.

    Args:
        num1: decimal digit string (the dividend).
        num2: decimal integer string (the divisor).
        name_a, name_b, name_c: names substituted for the ``_var1_``,
            ``_var2_`` and ``_var_res_`` placeholders in the trace.

    Returns:
        A ``(trace, quotient)`` tuple (despite the ``str`` annotation);
        ``quotient`` has leading zeros stripped ('0' if nothing remains).
    """
    tpl = Prompt('floordiv_Integer_Integer_Integer')
    trace = tpl.initialize.format(num1, num2)
    divisor = int(num2)
    quotient = ''
    carry_in = 0
    for idx, ch in enumerate(num1):
        current = carry_in * 10 + int(ch)
        q_digit, carry_out = divmod(current, divisor)
        q_text = str(q_digit)
        trace += tpl.enter
        trace += tpl.update_num_now.format(carry_in, ch, current)
        trace += tpl.brief.format(current, divisor, q_digit, current, divisor, carry_out)
        trace += tpl.update_result.format(quotient, q_text, quotient + q_text)
        trace += tpl.update_nums.format(num1[idx + 1:])
        quotient += q_text
        carry_in = carry_out
    trace += tpl.not_enter
    quotient = quotient.lstrip('0') or '0'
    trace += tpl.lstrip.format(quotient)
    trace += tpl.ret.format(quotient)
    for placeholder, name in (('_var1_', name_a), ('_var2_', name_b), ('_var_res_', name_c)):
        trace = trace.replace(placeholder, name)
    return trace, quotient

def mod_integer_integer_integer(num1: str, num2: str, name_a: str, name_b: str) -> str:
    """Compute ``num1 mod num2`` as ``num1 - (num1 // num2) * num2``, with a trace.

    The three steps reuse ``brief_floordiv_integer_integer_integer``,
    ``multiply_integer_integer_integer`` and ``sub_integer_integer_integer``,
    and their traces are embedded in this one.

    Args:
        num1, num2: decimal integer strings.
        name_a, name_b: names substituted for the ``_var1_`` and ``_var2_``
            placeholders in the trace.

    Returns:
        A ``(trace, remainder)`` tuple (despite the ``str`` annotation).
    """
    tpl = Prompt('mod_Integer_Integer_Integer')
    pieces = [tpl.initialize.format(num1, num2), tpl.calc_q]
    q_trace, quotient = brief_floordiv_integer_integer_integer(num1, num2, 'div1', 'div2', 'q')
    pieces += [q_trace, tpl.exit_function1.format(quotient), tpl.calc_temp]
    m_trace, product = multiply_integer_integer_integer(quotient, num2, 'mul1', 'mul2', 'temp')
    pieces += [m_trace, tpl.exit_function2.format(product), tpl.calc_r]
    s_trace, remainder = sub_integer_integer_integer(num1, product, 'sub1', 'sub2', 'r')
    pieces += [s_trace, tpl.exit_function3.format(remainder), tpl.ret.format(remainder)]
    trace = ''.join(pieces).replace('_var1_', name_a).replace('_var2_', name_b)
    return trace, remainder
    
def gcd_integer_integer_integer(num1: str, num2: str, name_a: str, name_b: str) -> tuple:
    """Return the greatest common divisor of two integer strings.

    Unlike the other operations, this emits a one-line statement rather than
    a step-by-step trace. ``name_a`` and ``name_b`` are accepted for
    signature parity with the other generators but are not used.

    Returns:
        A ``(text, gcd)`` tuple where ``gcd`` is a decimal string.
    """
    import math
    # math.gcd works on arbitrary-precision Python ints. np.gcd routes them
    # through NumPy dtypes: a value >= 2**63 next to a small int promotes to
    # float64, which has no gcd loop, and large values can overflow int64.
    # Compute the gcd once and reuse it.
    divisor = math.gcd(int(num1), int(num2))
    output = 'the greatest common divisor of {} and {} is {}\n'.format(num1, num2, divisor)
    return output, str(divisor)

class Dataset_Generator:
    """Generates questions and answers for the digit-wise maximum task.

    ``gen_data_from_len`` writes a question, ``rfft_IO`` answers it with the
    templated step-by-step trace from ``digit_max_integer_integer_integer``,
    and ``cot_IO`` answers it with a short free-form rationale.
    """

    def __init__(self) -> None:
        # Key used to look up the prompt templates in rfft_IO.
        self.name = 'digit_max_Integer_Integer_Integer'

    def gen_data_from_len(self, length: int) -> str:
        """Return a question text about two integers from ``generate_integer(length)``."""
        num1 = generate_integer(length)
        num2 = generate_integer(length)
        return 'Compare two numbers digit by digit and return the larger digit at each position, treating any missing digits as 0: {} and {}'.format(num1, num2)

    def rfft_IO(self, q_str) -> dict:
        """Return the rule-following input/output pair for question ``q_str``.

        The operands are the last two integers found in ``q_str``. Taking the
        last two skips the literal '0' in the question wording produced by
        ``gen_data_from_len``.

        Returns:
            ``{"input": instruction + rule + question, "output": trace,
            "answer": answer}``.
        """
        P = Prompt(self.name)
        rule = P.rule
        instruction = "Follow the given rule to solve the question.\nrule:\n"
        num1, num2 = re.findall(r'(-?\d+)', q_str)[-2:]
        # Named prompt_text to avoid shadowing the builtin ``input``.
        prompt_text = instruction + rule + "\n\nQ: " + q_str + '\n'
        output, answer = digit_max_integer_integer_integer(num1, num2)
        output += P.answer.format(answer)
        return {"input": prompt_text,
                "output": output,
                "answer": answer}

    def cot_IO(self, q_str):
        """Return a chain-of-thought style answer for question ``q_str``.

        Digits are compared from the right while both operands still have
        digits. Any remaining leading digits of the longer operand are then
        prepended unchanged.

        Returns:
            ``{"input": q_str, "output": rationale, "answer": result}``.
        """
        num1, num2 = re.findall(r'(-?\d+)', q_str)[-2:]
        rationale = ""
        result = ""
        while num1 and num2:
            digit1 = num1[-1]
            digit2 = num2[-1]
            # Single digit characters compare in the same order as their values.
            bigger = max(digit1, digit2)
            result = bigger + result
            num1 = num1[:-1]
            num2 = num2[:-1]
            rationale += f"# Comparing {digit1} and {digit2}, adding {bigger} to result\n{num1}, {num2}, {result}\n"
        if num1 or num2:
            rationale += f"# Adding the left digits {num1 + num2}\n , , {num1 + num2 + result}\n"
        result = num1 + num2 + result
        rationale += 'So the answer is ' + result
        return {"input": q_str,
                "output": rationale,
                "answer": result}

    
if __name__ == "__main__":
    # Running this file directly performs no work; it only defines the
    # generator functions and the Dataset_Generator class above.
    pass
