import numpy as np
import torch


# Symbols used by the addition task: operator/marker characters, then digits.
vocab = list("+?c=") + list("0123456789")
# Vocabulary symbols get ids 1..len(vocab); id 0 goes to the space character,
# which the rest of this file uses to pad sequences.
token_ids = dict(zip(vocab, range(1, len(vocab) + 1)))
token_ids[" "] = 0
# Reverse lookup: id -> symbol.
inv_token_ids = {idx: symbol for symbol, idx in token_ids.items()}


def _partial_answer(a, b, digit):
    """Render the answer field once the `digit` lowest columns are summed."""
    if digit == 0:
        return "?"
    # Slicing past the start of the shorter operand simply yields all of it.
    partial = str(int(a[-digit:]) + int(b[-digit:])).zfill(digit)
    # One extra digit means a carry out of the columns processed so far.
    marker = "c" if len(partial) > digit else "?"
    return f"{marker}{partial[-digit:]}"


def generate_one_line_addition(length, digit=None, ret_ab=False):
    """
    Build one random step of right-to-left column addition.

    The answer field after ``=`` is filled from the lowest digit upwards.
    While columns remain, it starts with ``?`` (no carry pending) or ``c``
    (a carry is pending out of the columns done so far), followed by the
    digits resolved so far:

        12345+67890=?
        12345+67890=?5
        12345+67890=c35
        12345+67890=c235
        12345+67890=c0235
        12345+67890=80235

    Args:
        length: maximum number of digits drawn for each operand (>= 1).
            Leading zeros are dropped, so operands may come out shorter.
        digit: number of low-order columns already resolved in the input
            state. ``None`` draws it uniformly from ``0 .. width - 1``, where
            ``width`` is the digit count of the longer operand. Explicit
            values above ``width - 1`` are clamped to ``width - 1``, because
            the operand widths are random and unknown to the caller.
        ret_ab: when true, also return the two operand strings.

    Returns:
        ``(input0, output0, final)``: the state after ``digit`` columns, the
        state after one more column (the complete sum once every column of
        the longer operand is done), and the fully solved line. With
        ``ret_ab`` the tuple is ``(input0, output0, final, a, b)``.

    Raises:
        ValueError: if ``length < 1`` or ``digit`` is negative.
    """
    if length < 1:
        raise ValueError(f"length must be >= 1, got {length}")
    if digit is not None and digit < 0:
        raise ValueError(f"digit must be >= 0, got {digit}")

    len_a = np.random.randint(1, length + 1)
    len_b = np.random.randint(1, length + 1)
    # int() round-trip strips leading zeros from the drawn digit strings.
    a = str(int("".join(str(np.random.randint(0, 10)) for _ in range(len_a))))
    b = str(int("".join(str(np.random.randint(0, 10)) for _ in range(len_b))))
    width = max(len(a), len(b))

    if digit is None:
        digit = np.random.randint(0, width)
    else:
        # Keep the input state one step short of the finished sum at most.
        digit = min(digit, width - 1)

    prefix = f"{a}+{b}="
    total = f"{int(a) + int(b)}"

    input0 = prefix + _partial_answer(a, b, digit)

    next_digit = digit + 1
    if next_digit == width:
        # Last column done: the step output is the plain, unmarked sum.
        output0 = prefix + total
    else:
        output0 = prefix + _partial_answer(a, b, next_digit)

    final = prefix + total

    if ret_ab:
        return input0, output0, final, a, b
    return input0, output0, final


def one_hot(i, num):
    """Return a list of `num` zeros with a single 1 stored at index `i`.

    Indexing follows normal list rules: a negative `i` counts from the end
    and an out-of-range `i` raises IndexError.
    """
    encoded = [0 for _ in range(num)]
    encoded[i] = 1
    return encoded


class Generator:
    """Endless iterator over encoded single-step addition samples.

    Each item is a tuple ``(inputs, targets, mask)``:

    * ``inputs``: the input line padded with spaces to ``3 * length + 3``
      characters, one one-hot list per character;
    * ``targets``: the expected next line, padded the same way, as token ids;
    * ``mask``: 1 for the first ``len(input line) + 1`` positions and 0 for
      the remaining padding.

    Args:
        length: maximum operand length passed to
            ``generate_one_line_addition``.
    """

    def __init__(self, length):
        self._length = length
        # Calling the generator function does not run its body yet; samples
        # are produced lazily on each next().
        self._generator = self.generate()

    def __iter__(self):
        # Completes the iterator protocol so instances work in for-loops.
        return self

    def __next__(self):
        return next(self._generator)

    def generate(self):
        """Yield ``(inputs, targets, mask)`` samples forever."""
        while True:
            maxlen = 3 * self._length + 3
            num_tokens = len(token_ids)
            input0, output0, _ = generate_one_line_addition(self._length)

            # Unmask the input characters plus one position after them.
            visible = len(input0) + 1
            mask = [1] * visible + [0] * (maxlen - visible)

            input0 = input0.ljust(maxlen)
            output0 = output0.ljust(maxlen)

            _input0 = [one_hot(token_ids[t], num_tokens) for t in input0]
            _output0 = [token_ids[t] for t in output0]

            yield _input0, _output0, mask


def evaluate_one_line_addition(model, length=10):
    """Let `model` solve one random addition step by step and check the result.

    A fresh problem is generated in its initial state (no columns resolved).
    The model's decoded output is fed back as the next input for up to
    `length` rounds, stopping early once the line contains neither ``?`` nor
    ``c``. Progress is printed to stdout.

    Args:
        model: callable invoked as ``model(inputs, mask)`` with a float
            tensor of one-hot rows shaped ``(1, seq_len, len(token_ids))``
            and a float mask tensor shaped ``(1, seq_len)``. It must return
            three values; only the first is used, and its last axis is
            arg-maxed to obtain token ids.
        length: maximum operand length, and the maximum number of rounds.

    Returns:
        True if the last decoded line equals the fully solved line. False
        otherwise, including when the model call or decoding raises an
        ``Exception`` (the error is printed rather than propagated).
    """
    input0, _, final = generate_one_line_addition(length, digit=0)

    print("=============================")
    print(input0, final)

    correct = False
    try:
        # Pure inference: no autograd bookkeeping is needed.
        with torch.no_grad():
            for _ in range(length):
                # One trailing space gives the model room for one more symbol.
                padded = input0 + " "
                mask = [1] * len(padded)
                encoded = [one_hot(token_ids[t], len(token_ids)) for t in padded]
                scores, _dist0, _loss = model(
                    torch.tensor([encoded], dtype=torch.float),
                    torch.tensor([mask], dtype=torch.float),
                )
                predicted_ids = torch.argmax(scores, dim=-1)[0].tolist()
                print("----------")

                input0 = "".join(inv_token_ids[t] for t in predicted_ids).strip()

                print(input0)

                if "?" not in input0 and "c" not in input0:
                    break

        correct = input0 == final
    except Exception as exc:
        # Best-effort evaluation: a failing model counts as a wrong answer,
        # but the reason is reported, and KeyboardInterrupt/SystemExit are
        # no longer swallowed.
        print("error: ", repr(exc))
    if not correct:
        print("wrong: ", final)
    return correct
