import numpy as np
import torch


# Symbols that can appear in a sample: the operators, the "?" marker and digits.
vocab = [*"+*?=", *map(str, range(10))]
# Symbol -> integer id; ids for vocab entries start at 1.
token_ids = {symbol: idx for idx, symbol in enumerate(vocab, start=1)}
# Id 0 belongs to the blank, which the code below uses to pad strings.
token_ids[" "] = 0
# Integer id -> symbol, the reverse of token_ids.
inv_token_ids = {idx: symbol for symbol, idx in token_ids.items()}


def generate_one_line_multiplication(maximum, from_beginning=False):
    """Draw a random product and one step of its repeated-addition derivation.

    Both operands ``a`` and ``b`` are drawn with
    ``np.random.randint(0, maximum)``.

    Args:
        maximum: Exclusive upper bound passed to ``np.random.randint``.
        from_beginning: When true (and ``b`` is at least 2), always return the
            very first step, from ``a*b=?`` to ``a*b=a+?``.

    Returns:
        A ``(question, answer, final)`` tuple of strings. ``final`` is always
        ``"a*b=<product>"``; ``answer`` is the line that follows ``question``:

        * ``b`` is 0 or 1: ``"a*b=?"`` -> ``final``.
        * expansion step: the question holds ``k`` copies of ``"a+"`` and a
          ``"?"``; the answer holds one more ``a`` and ends in ``"+?"`` unless
          that term is the ``b``-th one.
        * folding step: the question is ``"a*b=<a*k>+a+...+a"`` and the answer
          merges the next ``a`` into the running total.
    """
    a = np.random.randint(0, maximum)
    b = np.random.randint(0, maximum)
    prefix = f"{a}*{b}="
    final = f"{prefix}{a * b}"

    # Multiplying by 0 or 1 needs no derivation: jump straight to the result.
    if b in (0, 1):
        return f"{prefix}?", final, final

    if from_beginning:
        return f"{prefix}?", f"{prefix}{a}+?", final

    if np.random.random() < 0.5:
        # Expansion step: `written` terms exist already, the answer adds one.
        written = np.random.randint(0, b)
        so_far = f"{a}+" * written
        tail = f"{a}" if written == b - 1 else f"{a}+?"
        return f"{prefix}{so_far}?", f"{prefix}{so_far}{tail}", final

    # Folding step: the first `folded` terms are already summed up.
    folded = np.random.randint(1, b)
    pending = b - folded
    question = f"{prefix}{a * folded}" + f"+{a}" * pending
    answer = f"{prefix}{a * (folded + 1)}" + f"+{a}" * (pending - 1)
    return question, answer, final


def one_hot(i, num):
    """Return a list of ``num`` zeros with a single 1 at position ``i``.

    ``i`` is used as a plain list index, so negative values count from the
    end and an out-of-range value raises ``IndexError``.
    """
    encoded = [0 for _ in range(num)]
    encoded[i] = 1
    return encoded


class Generator:
    """Endless iterator over padded, tokenised multiplication samples.

    Each item is a tuple ``(x, cot, x_onehot, y, mask)``:

    * ``x``: question string, right-padded with blanks to ``maxlen``.
    * ``cot``: answer string, padded the same way.
    * ``x_onehot``: one one-hot row of length ``len(token_ids)`` per
      character of ``x``.
    * ``y``: token id of every character of ``cot``.
    * ``mask``: ``maxlen`` ones (padding positions are not masked out).

    Args:
        maximum: Exclusive upper bound for both operands; forwarded to
            ``generate_one_line_multiplication``.
    """

    def __init__(self, maximum):
        self._maximum = maximum
        self._generator = self.generate()

    def __iter__(self):
        # Completes the iterator protocol so instances work in ``for`` loops,
        # ``iter()`` and ``itertools`` helpers, not only with ``next()``.
        return self

    def __next__(self):
        return next(self._generator)

    def generate(self):
        """Yield samples forever; see the class docstring for the layout."""
        while True:
            # Read on every round, so a changed ``_maximum`` takes effect.
            n = len(str(self._maximum - 1))
            # Length of "a*b=" followed by (maximum - 1) n-digit terms joined
            # by "+", i.e. the longest fully expanded line.
            maxlen = n + 1 + n + 1 + n + (self._maximum - 2) * (n + 1)
            x, cot, _ = generate_one_line_multiplication(self._maximum)
            mask = [1] * maxlen
            x = x.ljust(maxlen)
            cot = cot.ljust(maxlen)

            num_tokens = len(token_ids)
            _x_onehot = [one_hot(token_ids[t], num_tokens) for t in x]
            _y = [token_ids[t] for t in cot]

            yield x, cot, _x_onehot, _y, mask


def evaluate_one_line_multiplication(model, maximum=9):
    """Roll ``model`` forward from a fresh ``a*b=?`` prompt and check the result.

    A problem is drawn with ``from_beginning=True``. The current line is then
    repeatedly padded to ``maxlen``, one-hot encoded, fed to ``model``, and
    replaced by the argmax decoding of the model output (id 0, the blank, is
    dropped). The loop stops after ``maximum * 2`` rounds or as soon as the
    decoded line contains neither ``"?"`` nor ``"+"``.

    Progress is printed: the prompt with the expected final line, every
    decoded line, any error raised while running the model, and a
    ``wrong:`` line when the outcome is not correct.

    Args:
        model: Callable taking ``(inputs, mask)`` as float tensors built from
            ``[onehot_rows]`` and ``[[1] * maxlen]``, and returning a 2-tuple
            whose first element is argmax-ed over its last dimension.
        maximum: Exclusive upper bound for both operands.

    Returns:
        ``True`` if the last decoded line equals the expected ``a*b=<product>``
        string, ``False`` otherwise (including when an error occurred).
    """
    x, _, y_x = generate_one_line_multiplication(maximum, from_beginning=True)
    n = len(str(maximum - 1))
    # Same padded length that Generator.generate uses for its samples.
    maxlen = n + 1 + n + 1 + n + (maximum - 2) * (n + 1)
    num_tokens = len(token_ids)
    tmp = x
    print(tmp, y_x)
    correct = False
    try:
        # Inference only: the argmax decode needs no autograd graph.
        with torch.no_grad():
            for _ in range(maximum * 2):
                padded = tmp.ljust(maxlen)
                onehot = [one_hot(token_ids[t], num_tokens) for t in padded]
                _out, _dist = model(
                    torch.tensor([onehot], dtype=torch.float),
                    torch.tensor([[1] * maxlen], dtype=torch.float),
                )
                # First batch entry, capped at maxlen positions.
                predicted = torch.argmax(_out, dim=-1)[0][:maxlen].tolist()
                # Id 0 is the blank; dropping it strips padding from the line.
                tmp = "".join(inv_token_ids[i] for i in predicted if i > 0)
                print(tmp)
                if "?" not in tmp and "+" not in tmp:
                    break
        correct = tmp == y_x
    except Exception as exc:
        # Still best-effort: a failing rollout counts as wrong, but the cause
        # is reported and KeyboardInterrupt/SystemExit are no longer swallowed.
        print("error: ", exc)
    if not correct:
        print("wrong: ", x, y_x)
    return correct
