from numpy.random import Generator
from tasks import register

def generate_operands(rng: Generator, la, lb=None):
    """Sample two random positive integers and return them as digit strings.

    Args:
        rng: NumPy random generator; every draw goes through it.
        la: Arguments unpacked into ``rng.integers`` to draw the number of
            digits of the first operand, e.g. ``(low, high)`` with ``high``
            exclusive.
        lb: Same as ``la`` but for the second operand. When ``None`` the
            second operand reuses the length drawn for the first one.

    Returns:
        ``(len_a, len_b, a, b)``: the two drawn lengths followed by the
        operands as decimal strings whose leading digit is never ``'0'``.
    """
    len_a = rng.integers(*la)
    len_b = len_a if lb is None else rng.integers(*lb)

    def draw_number(n_digits):
        # Leading digit comes from 1-9 so the number really has n_digits
        # digits; the remaining n_digits - 1 digits come from 0-9.
        head = rng.integers(1, 10)
        tail = rng.integers(0, 10, size=n_digits - 1)
        return str(head) + ''.join(str(d) for d in tail)

    # Draw order matters for reproducibility: first operand, then second.
    first = draw_number(len_a)
    second = draw_number(len_b)
    return len_a, len_b, first, second

# @register(type='check_ood')
# def check_ood(train_kwargs, test_kwargs):
#     return train_kwargs['la'] <= test_kwargs['la'] and train_kwargs.get('lb', train_kwargs['la']) <= test_kwargs.get('lb', test_kwargs['la'])

@register()
def reverse_add(rng: Generator, la, lb=None, rjust=True):
    """Build one addition example with every number written in reverse.

    Args:
        rng: NumPy random generator used to sample the operands.
        la: Length spec of the first operand, forwarded to
            ``generate_operands``.
        lb: Length spec of the second operand; ``None`` reuses the first
            operand's drawn length.
        rjust: When true, zero-pad the sum to ``max(len_a, len_b) + 1``
            digits before reversing it.

    Returns:
        ``(prompt, target, None)`` where ``prompt`` is
        ``'<a reversed>+<b reversed>='`` and ``target`` is the sum with its
        least-significant digit first.
    """
    len_a, len_b, a, b = generate_operands(rng, la, lb)
    total = str(int(a) + int(b))
    if rjust:
        # One extra column leaves room for a final carry.
        total = total.rjust(max(len_a, len_b) + 1, '0')
    prompt = a[::-1] + '+' + b[::-1] + '='
    return prompt, total[::-1], None

@register()
def reverse_add_with_padding(rng: Generator, la, lb=None, lp=0):
    """Build a reversed addition example whose numbers carry random zero padding.

    A total width ``p`` is drawn from ``[max(len_a, len_b), lp]`` and an
    ``offset`` from ``[0, p - max(len_a, len_b)]``. Each number gets
    ``offset`` zeros appended on its low-order side and is then left-filled
    with zeros up to width ``p``. Appending the same number of low-order
    zeros to both operands and to the sum keeps the equation valid.

    Args:
        rng: NumPy random generator used for the operands and the padding.
        la: Length spec of the first operand, forwarded to
            ``generate_operands``.
        lb: Length spec of the second operand; ``None`` reuses the first
            operand's drawn length.
        lp: Largest padded width (inclusive). Values smaller than the longer
            operand, including the default ``0``, result in no padding.

    Returns:
        ``(prompt, target, None)`` with all three numbers written
        least-significant digit first.
    """
    la, lb, a, b = generate_operands(rng, la, lb)
    s = str(int(a) + int(b))

    longest = max(la, lb)
    # rng.integers raises ValueError when low >= high, which made the default
    # lp=0 unusable. Clamping the upper bound turns "lp shorter than the
    # operands" into "no padding" and leaves lp >= longest unchanged.
    p = rng.integers(longest, max(lp, longest) + 1)
    offset = rng.integers(0, p - longest + 1)

    lead = p - offset
    tail = '0' * offset
    a = '0' * (lead - la) + a + tail
    b = '0' * (lead - lb) + b + tail
    # When the sum has a final carry and there is no left padding,
    # lead - len(s) is negative; '0' * negative is '' so the sum simply ends
    # up one character longer than the operands.
    s = '0' * (lead - len(s)) + s + tail
    return f'{a[::-1]}+{b[::-1]}=', s[::-1], None

@register()
def reverse_sub(rng: Generator, la, lb=None):
    """Build one reversed-digit example whose target is ``a - b``.

    Args:
        rng: NumPy random generator used to sample the operands.
        la: Length spec of the first operand, forwarded to
            ``generate_operands``.
        lb: Length spec of the second operand; ``None`` reuses the first
            operand's drawn length.

    Returns:
        ``(prompt, target, None)``. ``target`` is the difference zero-filled
        to ``max(len(a), len(b)) + 1`` characters and then reversed. For a
        negative difference the ``'-'`` sign counts towards that width and is
        the last character of the reversed target.
    """
    la, lb, a, b = generate_operands(rng, la, lb)
    s = str(int(a) - int(b))
    width = max(len(a), len(b)) + 1
    # zfill pads *after* a leading sign ('-123' -> '-0000123'); rjust with
    # '0' would produce '000-123', burying the sign inside the padding.
    # Non-negative results are formatted exactly as before.
    s = s.zfill(width)
    # NOTE(review): the prompt uses '+' like the sibling tasks even though
    # the target is a difference; confirm this is intentional.
    return f'{a[::-1]}+{b[::-1]}=', s[::-1], None

@register()
def reverse_add_no_carry(rng: Generator, la, lb=None):
    """Build an example whose target is the digit-wise sum modulo 10.

    Args:
        rng: NumPy random generator used to sample the operands.
        la: Length spec of the first operand, forwarded to
            ``generate_operands``.
        lb: Length spec of the second operand; ``None`` reuses the first
            operand's drawn length.

    Returns:
        ``(prompt, target, None)``. ``target`` lists, least-significant
        position first, ``(a_i + b_i) % 10`` for every column, followed by a
        trailing ``'0'``, giving ``max(len_a, len_b) + 1`` characters.
    """
    len_a, len_b, a, b = generate_operands(rng, la, lb)
    width = max(len_a, len_b)
    # Reverse first, then pad on the right: the shorter operand gets zeros
    # in its missing high-order columns.
    rev_a = a[::-1].ljust(width, '0')
    rev_b = b[::-1].ljust(width, '0')
    column_sums = (int(da) + int(db) for da, db in zip(rev_a, rev_b))
    target = ''.join(str(total % 10) for total in column_sums) + '0'
    return f'{a[::-1]}+{b[::-1]}=', target, None

@register()
def reverse_add_only_carry(rng: Generator, la, lb=None):
    """Build an example whose target is the per-column carry digit.

    Args:
        rng: NumPy random generator used to sample the operands.
        la: Length spec of the first operand, forwarded to
            ``generate_operands``.
        lb: Length spec of the second operand; ``None`` reuses the first
            operand's drawn length.

    Returns:
        ``(prompt, target, None)``. ``target`` starts with ``'0'`` and then
        lists, least-significant position first, ``(a_i + b_i) // 10`` for
        every column, giving ``max(len_a, len_b) + 1`` characters.
    """
    len_a, len_b, a, b = generate_operands(rng, la, lb)
    width = max(len_a, len_b)
    # Reverse first, then pad on the right: the shorter operand gets zeros
    # in its missing high-order columns.
    rev_a = a[::-1].ljust(width, '0')
    rev_b = b[::-1].ljust(width, '0')
    carries = [str((int(da) + int(db)) // 10) for da, db in zip(rev_a, rev_b)]
    # The leading '0' shifts each carry one position towards the column it
    # would be added to.
    target = ''.join(['0'] + carries)
    return f'{a[::-1]}+{b[::-1]}=', target, None

@register()
def reverse_add_trans(rng: Generator, la, lb=None):
    """Build an addition example whose answer is spelled with letters.

    Args:
        rng: NumPy random generator used to sample the operands.
        la: Length spec of the first operand, forwarded to
            ``generate_operands``.
        lb: Length spec of the second operand; ``None`` reuses the first
            operand's drawn length.

    Returns:
        ``(prompt, target, None)``. ``target`` is the (unpadded) sum with
        digits ``0``-``9`` replaced by ``a``-``j`` and then reversed.
    """
    _, _, a, b = generate_operands(rng, la, lb)
    digit_to_letter = str.maketrans(dict(zip('0123456789', 'abcdefghij')))
    encoded = str(int(a) + int(b)).translate(digit_to_letter)
    prompt = a[::-1] + '+' + b[::-1] + '='
    return prompt, encoded[::-1], None

@register()
def copy_first_op(rng: Generator, la, lb=None):
    """Build an example whose target is simply the first operand.

    Args:
        rng: NumPy random generator used to sample the operands.
        la: Length spec of the first operand, forwarded to
            ``generate_operands``.
        lb: Length spec of the second operand; ``None`` reuses the first
            operand's drawn length.

    Returns:
        ``(prompt, target, None)`` where ``target`` equals the reversed
        first operand exactly as it appears in ``prompt``.
    """
    _, _, a, b = generate_operands(rng, la, lb)
    rev_a = a[::-1]
    rev_b = b[::-1]
    return f'{rev_a}+{rev_b}=', rev_a, None

@register()
def reverse_add_ICL(rng: Generator, la, lb=None, num_examples=6):
    """Build a few-shot prompt of reversed-digit additions.

    The prompt contains ``num_examples`` problems separated by commas. All
    but the last are shown with their answers; the last is left open and its
    answer is returned as the target.

    Args:
        rng: NumPy random generator used to sample every operand pair.
        la: Length spec of the first operand, forwarded to
            ``generate_operands``.
        lb: Length spec of the second operand; ``None`` reuses the first
            operand's drawn length.
        num_examples: Total number of problems in the prompt, including the
            final unanswered one. Must be at least 1.

    Returns:
        ``(prompt, target, None)`` where ``target`` is the (unpadded) sum of
        the final problem, least-significant digit first.

    Raises:
        ValueError: If ``num_examples`` is smaller than 1.
    """
    if num_examples < 1:
        # Previously this fell through to the return with no answer defined
        # and failed with an UnboundLocalError.
        raise ValueError('num_examples must be at least 1')

    parts = []
    # Solved demonstrations come first; draw order matches one
    # generate_operands call per problem, in prompt order.
    for _ in range(num_examples - 1):
        _, _, a, b = generate_operands(rng, la, lb)
        answer = str(int(a) + int(b))[::-1]
        parts.append(f'{a[::-1]}+{b[::-1]}=' + answer + ',')

    # The final problem is the query whose answer the caller must predict.
    _, _, a, b = generate_operands(rng, la, lb)
    s = str(int(a) + int(b))
    parts.append(f'{a[::-1]}+{b[::-1]}=')
    return ''.join(parts), s[::-1], None

if __name__ == '__main__':
    from numpy.random import default_rng

    # Manual smoke run: print one sample from a few tasks. Each task gets a
    # freshly seeded generator and the same first-operand length spec.
    demos = [
        (reverse_add_only_carry, (5, 6)),
        (reverse_add_no_carry, (5, 6)),
        (reverse_add, (7, 8)),
        (reverse_sub, (2, 3)),
    ]
    for task, lb_range in demos:
        rng = default_rng(43)
        print(task(rng, (7, 8), lb_range))
