import numpy as np
from numpy.random import Generator
from tasks import register

def generate_operands(rng: Generator, la, lb=None):
    """Draw two random decimal operands as digit strings.

    ``la`` (and ``lb`` when given) are ``(low, high)`` bounds unpacked into
    ``rng.integers``, so ``high`` is exclusive. When ``lb`` is None the second
    operand reuses the length sampled for the first one.

    Returns ``(la, lb, a, b)``: the two sampled lengths followed by the operand
    strings. The leading digit of each operand is drawn from 1-9, so neither
    operand starts with '0'.
    """
    def draw_digits(n_digits):
        # Nonzero leading digit first, then n_digits - 1 unrestricted digits.
        head = rng.integers(1, 10)
        tail = rng.integers(0, 10, size=n_digits - 1)
        return str(head) + ''.join(str(d) for d in tail)

    len_a = rng.integers(*la)
    len_b = len_a if lb is None else rng.integers(*lb)
    # Tuple elements evaluate left to right, so a's digits are drawn before b's.
    return len_a, len_b, draw_digits(len_a), draw_digits(len_b)

@register()
def synthetic_COT_mult(rng: Generator, la, lb=None):
    """Build a natural-language multiplication prompt ending in ``<think>\\n``.

    Returns ``(prompt, target, None)`` where ``target`` is the exact decimal
    product of the two sampled operands (most-significant digit first).
    """
    _, _, a, b = generate_operands(rng, la, lb)
    product = str(int(a) * int(b))
    prompt = (
        f'Compute the integer multiplication {a} * {b}. '
        'Please reason step by step, but do not overthink. '
        'Put your final answer within \\boxed{}. <think>\n'
    )
    return prompt, product, None

@register()
def reverse_mult(rng: Generator, la, lb=None):
    """Digit-reversed multiplication task whose target is the bare product.

    Operands and product are all written least-significant digit first;
    the prompt has the form ``"<a reversed>*<b reversed>="``.
    """
    *_, a, b = generate_operands(rng, la, lb)
    reversed_product = str(int(a) * int(b))[::-1]
    prompt = a[::-1] + '*' + b[::-1] + '='
    return prompt, reversed_product, None

@register()
def reverse_mult_COT(rng: Generator, la, lb=None):
    """Digit-reversed multiplication with a partial-product chain of thought.

    For each digit of ``b`` (least significant first), the partial product
    ``a * digit`` is left-padded with zeros to ``la + 1`` digits and shifted by
    appending one '0' per digit position. The first partial product is emitted
    alone; each later one is followed by ``=`` and the running total so far.
    Steps are joined with ``+`` and every number is written reversed
    (least-significant digit first). When ``b`` has more than one digit, the
    text after the final ``=`` is therefore the reversed full product.

    Returns ``(prompt, target, None)`` with prompt ``"<a reversed>*<b reversed>="``.
    """
    la, lb, a, b = generate_operands(rng, la, lb)
    a_val = int(a)
    steps = []
    running_total = 0
    for shift, digit in enumerate(reversed(b)):
        partial = a_val * int(digit)
        running_total += partial * 10**shift
        # a times a single digit has at most la + 1 digits, so this pad is fixed-width.
        shifted = str(partial).rjust(la + 1, '0') + '0' * shift
        if shift == 0:
            steps.append(shifted[::-1])
        else:
            steps.append(shifted[::-1] + '=' + str(running_total)[::-1])
    return f'{a[::-1]}*{b[::-1]}=', '+'.join(steps), None

@register()
def reverse_mult_with_padding(rng: Generator, la, lb=None, lp=0):
    """Digit-reversed multiplication on zero-padded operands, with a CoT target.

    Both operands are padded to a common width ``p`` digits, where
    ``p = max(lp, la, lb)``. Each operand gets a random number of trailing zeros
    (0 .. p - max(la, lb), which scales its value by a power of ten), and is
    then filled with leading zeros up to width ``p``.

    The target lists, for each digit of padded ``b`` (least significant first),
    the unshifted partial product ``a * digit`` padded to ``p + 1`` digits,
    followed by the running total (shifted sum so far) padded to ``2 * p``
    digits in parentheses. Steps are joined with ``+`` and every number is
    written reversed.

    Returns ``(prompt, target, None)`` with prompt ``"<a reversed>*<b reversed>="``.
    """
    la, lb, a, b = generate_operands(rng, la, lb)

    # Clamp the width so it always fits the longer operand. Using lp directly
    # made the default lp=0 (or any lp < max(la, lb)) pass an empty range to
    # rng.integers below, which raises ValueError. Calls with lp >= max(la, lb)
    # are unaffected.
    p = max(lp, la, lb)
    max_offset = p - max(la, lb)
    offset_a = rng.integers(0, max_offset + 1)
    offset_b = rng.integers(0, max_offset + 1)
    a = '0' * (p - offset_a - la) + a + '0' * offset_a
    b = '0' * (p - offset_b - lb) + b + '0' * offset_b

    a_val = int(a)  # leading zeros are ignored by int()
    steps = []
    running_total = 0
    for shift, digit in enumerate(reversed(b)):
        partial = a_val * int(digit)
        running_total += partial * 10**shift
        # a * digit fits in p + 1 digits; a product of two p-digit numbers fits in 2p.
        partial_str = str(partial).rjust(p + 1, '0')
        total_str = str(running_total).rjust(p * 2, '0')
        steps.append(partial_str[::-1] + '(' + total_str[::-1] + ')')
    return f'{a[::-1]}*{b[::-1]}=', '+'.join(steps), None

@register()
def reverse_mult_skills(rng: Generator, la, lb=None, skill=0):
    """Digit-reversed multiplication with letter-tagged partial products.

    For each digit of ``b`` (least significant first), the partial product
    ``a * digit`` is prefixed with a position letter ('A', 'B', ...), padded to
    ``la + 1`` digits and shifted by one '0' per position. The whole tagged
    string is then reversed, so the letter becomes its last character. Steps
    are joined with ``+``, followed by ``=`` and the reversed full product.

    Previously the function built these steps but never returned, so it yielded
    ``None`` instead of the ``(prompt, target, None)`` triple its sibling tasks
    return.

    NOTE(review): ``skill`` is accepted but unused, and position letters run
    past 'Z' when ``b`` has more than 26 digits. Confirm the intended target
    format.
    """
    la, lb, a, b = generate_operands(rng, la, lb)
    product = str(int(a) * int(b))

    a_val = int(a)
    steps = []
    for i, digit in enumerate(reversed(b)):
        partial = a_val * int(digit)
        tagged = chr(ord('A') + i) + str(partial).rjust(la + 1, '0') + '0' * i
        steps.append(tagged[::-1])
    return f'{a[::-1]}*{b[::-1]}=', '+'.join(steps) + '=' + product[::-1], None

if __name__ == '__main__':
    # Quick demo. Length ranges are (low, high) with an exclusive high, so `a`
    # has exactly 8 digits and `b` exactly 3. Each task gets its own fresh rng
    # seeded with 42.
    la = (8, 9)
    lb = (3, 4)
    demos = (
        (reverse_mult_COT, ()),
        (reverse_mult_with_padding, (8,)),
    )
    for task, extra_args in demos:
        print(task(np.random.default_rng(42), la, lb, *extra_args))
