import numpy as np
import torch


# Symbols the model can read or emit: operators, the pointer markers
# ("I", "i", "J"), the end-of-sequence marker "#", then the ten digits.
vocab = ["+", "?", "c", "=", "I", "i", "J", "#"]  # '#' is 'EOS'
vocab.extend(str(d) for d in range(10))

# Vocabulary entries are numbered from 1; id 0 is kept for the blank
# character, which is also what rows are padded with.
token_ids = dict(zip(vocab, range(1, len(vocab) + 1)))
token_ids[" "] = 0

# Reverse lookup: id -> character.
inv_token_ids = {idx: tok for tok, idx in token_ids.items()}


def get_ind(a, digit, ind, pad):
    """Build a marker row aligned with ``a``.

    The result holds ``ind`` placed ``digit`` slots from the right-hand end,
    with every other slot filled by ``pad``. ``digit`` is clamped to
    ``len(a) - 1`` so an over-large value puts the marker in the leftmost
    slot instead of running off the row.

    Args:
        a: Sequence whose length fixes the number of slots.
        digit: Distance of the marker from the right end (0 = last slot).
        ind: Marker text.
        pad: Filler text used for every non-marker slot.

    Returns:
        ``pad`` repeated on the left, then ``ind``, then ``pad`` repeated
        ``digit`` times (after clamping).
    """
    slots = len(a)
    last = slots - 1
    from_right = digit if digit < last else last
    leading = pad * (last - from_right)
    trailing = pad * from_right
    return leading + ind + trailing


def _random_operand(num_digits):
    """Draw ``num_digits`` random decimal digits and return them without leading zeros."""
    digits = "".join(str(np.random.randint(0, 10)) for _ in range(num_digits))
    # Round-tripping through int drops leading zeros ("007" -> "7", "000" -> "0").
    return str(int(digits))


def _sum_field(a, b, digit, width):
    """Render the answer field once the ``digit`` lowest columns have been added.

    The field is right-aligned to ``width + 1`` characters, leaving room for
    a final carry or for the one-character status flag.
    """
    if digit > width:
        # Every column has been processed: show the plain sum.
        text = str(int(a) + int(b))
    elif digit == 0:
        # Nothing processed yet: only the "no carry" flag is shown.
        text = "?"
    else:
        partial = str(int(a[-digit:]) + int(b[-digit:])).zfill(digit)
        # An extra leading digit means a carry into the next column.
        flag = "c" if len(partial) > digit else "?"
        text = f"{flag}{partial[-digit:]}"
    return text.rjust(width + 1)


def _pointer_row(a, b, c, digit, answer_marker):
    """Build the marker row that sits under ``" {a}+ {b}= {c}"``."""
    return (
        f"{get_ind(' ' + a, digit, 'I', ' ')} "
        f"{get_ind(' ' + b, digit, 'i', ' ')} "
        f"{get_ind(' ' + c, digit, answer_marker, ' ')}"
    )


def generate_two_line_addition(length, digit=None):
    """Generate one step of a column-by-column addition as two aligned rows.

    Each state is a pair of equally long strings: a text row
    ``" {a}+ {b}= {c}"`` and a pointer row carrying the markers ``I``
    (under ``a``), ``i`` (under ``b``) and ``J`` (under the answer). The
    pointer sits ``digit`` columns from the right. In the answer field a
    leading ``c`` means a carry is pending and ``?`` means none is.

    Illustrative sample for ``a = "12"``, ``b = "34"``, ``digit = 0``::

        input0  = " 12+ 34=   ?"
        input1  = "  I   i    J"
        output0 = " 12+ 34=  ?6"
        output1 = " I   i    J "
        final   = " 12+ 34=  46"

    Args:
        length: Maximum number of digits drawn for each operand.
        digit: Number of low-order columns already processed in the input
            state. When ``None`` it is drawn uniformly from
            ``0 .. max(len(a), len(b))``.

    Returns:
        A tuple ``(input0, input1, output0, output1, final)``: the state
        before the step, the state after it, and the text row holding the
        complete sum. When the step finishes the addition, the answer
        marker in ``output1`` is ``#`` instead of ``J``.

    Raises:
        ValueError: If ``digit`` is negative.
    """
    if digit is not None and digit < 0:
        raise ValueError(f"digit must be non-negative, got {digit}")

    # Both lengths are drawn before any digit so that a seeded NumPy RNG
    # is consumed in the same order as it always has been.
    len_a = np.random.randint(1, length + 1)
    len_b = np.random.randint(1, length + 1)
    a = _random_operand(len_a)
    b = _random_operand(len_b)
    # Stripping leading zeros may have shortened the operands.
    width = max(len(a), len(b))

    if digit is None:
        digit = np.random.randint(0, width + 1)

    # State before the step.
    c = _sum_field(a, b, digit, width)
    input0 = f" {a}+ {b}= {c}"
    input1 = _pointer_row(a, b, c, digit, "J")

    # State after processing one more column.
    digit += 1
    c = _sum_field(a, b, digit, width)
    output0 = f" {a}+ {b}= {c}"
    # Once every column is consumed the answer marker becomes the EOS symbol.
    answer_marker = "#" if digit > width else "J"
    output1 = _pointer_row(a, b, c, digit, answer_marker)

    # Fully solved text row.
    final = f" {a}+ {b}= {_sum_field(a, b, width + 1, width)}"

    return input0, input1, output0, output1, final


def help(input0, input1, output0, output1, final):
    """Cut out the neighbourhoods of the three pointer markers from each row.

    The positions of ``"I"``, ``"i"`` and ``"J"`` are looked up in
    ``input1``. For every row the characters from one position before each
    marker to one position after it (clipped at the start of the row) are
    concatenated in that marker order.

    Args:
        input0: Text row of the input state.
        input1: Pointer row of the input state; must contain all three markers.
        output0: Text row of the output state.
        output1: Pointer row of the output state.
        final: Not used.

    Returns:
        ``(key0, key1, value0, value1)``: the windows taken from
        ``input0``, ``input1``, ``output0`` and ``output1`` respectively.

    Raises:
        ValueError: If one of the markers is missing from ``input1``.
    """
    windows = [
        slice(max(0, pos - 1), pos + 2)
        for pos in (input1.index("I"), input1.index("i"), input1.index("J"))
    ]

    def _cut(row):
        # Concatenate the window around I, then i, then J.
        around_a, around_b, around_c = (row[w] for w in windows)
        return around_a + around_b + around_c

    return _cut(input0), _cut(input1), _cut(output0), _cut(output1)


def one_hot(i, num):
    """Return a list of ``num`` zeros with a single 1 stored at index ``i``.

    Indexing follows normal list rules, so a negative ``i`` counts from the
    end and an out-of-range ``i`` raises ``IndexError``.
    """
    encoding = [0 for _ in range(num)]
    encoding[i] = 1
    return encoding


class Generator:
    """Endless iterator over encoded single-step addition samples.

    Every item is a tuple ``(input0, input1, output0, output1, mask)``:

    * ``input0`` / ``input1``: the input text row and pointer row, each a
      list of one-hot lists over ``token_ids``;
    * ``output0`` / ``output1``: the target rows as lists of token ids;
    * ``mask``: 1 for positions of the unpadded input row, 0 for padding.

    All rows are right-padded with blanks to ``3 * length + 6`` characters,
    the longest row ``generate_two_line_addition(length)`` can produce.
    """

    def __init__(self, length):
        """Store the maximum operand length and start the sample stream."""
        self._length = length
        self._generator = self.generate()

    def __iter__(self):
        """Return ``self`` so the object works with ``iter()`` and ``for`` loops."""
        return self

    def __next__(self):
        """Return the next encoded sample."""
        return next(self._generator)

    def generate(self):
        """Yield encoded samples forever."""
        while True:
            maxlen = 3 * self._length + 6
            input0, input1, output0, output1, _ = generate_two_line_addition(self._length)

            # The mask marks the real (unpadded) part of the input row.
            real_len = len(input0)
            mask = [1] * real_len + [0] * (maxlen - real_len)

            # Right-pad every row with blanks up to the common width.
            input0, input1, output0, output1 = (
                row.ljust(maxlen) for row in (input0, input1, output0, output1)
            )

            num_tokens = len(token_ids)
            _input0 = [one_hot(token_ids[t], num_tokens) for t in input0]
            _input1 = [one_hot(token_ids[t], num_tokens) for t in input1]
            _output0 = [token_ids[t] for t in output0]
            _output1 = [token_ids[t] for t in output1]

            yield _input0, _input1, _output0, _output1, mask


def evaluate_two_line_addtion(model, length=10):
    """Roll ``model`` out on one random addition problem and report success.

    A fresh problem is generated in its initial state (``digit=0``). The
    model is then applied repeatedly, each time being fed its own previous
    prediction, for at most ``length + 1`` steps or until the predicted
    pointer row contains the end marker ``"#"``. The rollout counts as
    correct when the last predicted text row equals the solved row.

    Args:
        model: Callable invoked as ``model(x0, x1, mask)`` with float
            tensors built from the one-hot text row, the one-hot pointer
            row and an all-ones mask, each with a leading batch axis of
            size 1. It must return five values; token ids are taken as the
            argmax over the last axis of the first two.
        length: Maximum number of digits per operand.

    Returns:
        ``True`` if the rollout reproduced the solved row, else ``False``.
        Any ``Exception`` raised during the rollout is reported on stdout
        and counted as a wrong answer.
    """
    input0, input1, _output0, _output1, final = generate_two_line_addition(length, digit=0)

    print("=============================")
    print(input0, final)

    correct = False
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for _ in range(length + 1):
                width = len(input0)
                mask = [1] * width
                # Only the pointer row can be shorter than the text row.
                input1 = input1.ljust(width)

                num_tokens = len(token_ids)
                _input0 = [one_hot(token_ids[t], num_tokens) for t in input0]
                _input1 = [one_hot(token_ids[t], num_tokens) for t in input1]

                out0, out1, _dist0, _dist1, _loss = model(
                    torch.tensor([_input0], dtype=torch.float),
                    torch.tensor([_input1], dtype=torch.float),
                    torch.tensor([mask], dtype=torch.float),
                )
                ids0 = torch.argmax(out0, dim=-1)[0].tolist()
                ids1 = torch.argmax(out1, dim=-1)[0].tolist()

                # Decode the prediction; it becomes the next step's input.
                input0 = "".join(inv_token_ids[t] for t in ids0)
                input1 = "".join(inv_token_ids[t] for t in ids1)

                if "#" in input1:
                    break

        correct = input0 == final
    except Exception as err:
        # Best effort: a failing rollout is a wrong answer, but say why.
        # KeyboardInterrupt and SystemExit are deliberately not caught.
        print("evaluation aborted: ", repr(err))
    if not correct:
        print("wrong: ", final)
    return correct
