import numpy as np
import torch

# Symbol table shared by the generator and the evaluator: ten marker letters
# followed by the ten decimal digits.  Ids 1..20 follow the order of `vocab`;
# id 0 is reserved for the blank used as padding.
vocab = list("ESIesiFTJb") + list("0123456789")
token_ids = dict(zip(vocab, range(1, len(vocab) + 1)))
token_ids[" "] = 0
inv_token_ids = {token_id: symbol for symbol, token_id in token_ids.items()}


def generate_division(length):
    """Generate a step-by-step trace of the long division ``c / a = b``.

    Two random operands ``a`` and ``b`` with at most ``length`` digits each
    are drawn (and redrawn until their product is non-zero).  The product
    ``c = a * b`` is the dividend, ``a`` the divisor, and the quotient ``b``
    is rebuilt by schoolbook long division: for every decimal shift ``I``
    (from ``len(str(c)) - 1`` down to 0) the divisor is subtracted from the
    remainder one digit at a time, with an explicit borrow flag, for as long
    as it fits.  Each elementary action is recorded as one step.

    A step is a 20-tuple of strings, all ``len(str(c)) + len(str(a)) + 2``
    characters wide: twelve input rows (state before the action) followed by
    eight output rows (state after it).

    Input rows::

        0   dividend c, right-aligned
        1   "S" in the last column and "E" len(c) columns to its left
        2   one blank, then divisor a, then blanks
        3   "e" in the first column and "s" len(a) columns to its right
        4   remainder confirmed so far (c_cur), right-aligned
        5   working remainder of the subtraction in progress (c_tmp)
        6   "I" marker, I columns away from the right edge
        7   "i" marker, i columns to the left of the last digit of a
        8   quotient so far, zero-padded to len(c) digits, right-aligned
        9   "T" I columns from the right edge and "F" len(a) columns left of it
        10  "J" marker, i + I columns away from the right edge
        11  borrow flag ("b" or blank) in the same column as "J"

    Output rows::

        12  new "I" marker            16  new quotient
        13  new "i" marker            17  confirmed remainder (c_cur)
        14  new "F"/"T" markers       18  new working remainder
        15  new "J" marker            19  new borrow flag

    Once the shift drops below zero (the last step), rows 12, 15 and 19 are
    blank and row 14 keeps only the "F" marker.

    Args:
        length: Maximum number of digits drawn for each operand; must be at
            least 1 (``np.random.randint`` rejects an empty range).

    Returns:
        A tuple ``(out, a, b, c)`` where ``out`` is the list of steps and
        ``a``, ``b``, ``c`` are the divisor, the quotient and the dividend.

    Raises:
        RuntimeError: If the internal state machine reaches a state that
            should be impossible (defensive check).
    """
    # Draw operands until the product is non-zero, so that neither is zero.
    # The order of the random calls is kept fixed for seeded reproducibility.
    while True:
        len_a = np.random.randint(1, length + 1)
        len_b = np.random.randint(1, length + 1)
        a = int("".join(str(np.random.randint(0, 10)) for _ in range(len_a)))
        b = int("".join(str(np.random.randint(0, 10)) for _ in range(len_b)))
        c = a * b
        if c > 0:
            break

    # Leading zeros may have been drawn, so measure the real digit counts.
    len_a = len(str(a))
    len_c = len(str(c))
    a_digits = str(a)[::-1]  # least significant digit first
    max_len = len_c + 1 + len_a + 1

    shift = len_c - 1  # "I": decimal shift of the quotient digit being built
    idx = 0            # "i": divisor digit being subtracted; len_a = borrow check
    c_cur = c          # remainder confirmed so far
    c_tmp = c          # working remainder while a subtraction is in progress
    b_cur = 0          # quotient accumulated so far
    borrow = " "       # "b" when the previous digit subtraction borrowed

    out = []

    while True:
        input0 = str(c).rjust(max_len)
        input1 = ("E" + " " * (len_c - 1) + "S").rjust(max_len)
        input2 = " " + str(a) + " " * (len_c + 1)
        input3 = ("e" + " " * (len_a - 1) + "s").ljust(max_len)
        input4 = str(c_cur).rjust(max_len)
        input5 = str(c_tmp).rjust(max_len)
        input6 = ("I" + " " * shift).rjust(max_len)
        input7 = ("i" + " " * idx + " " * (len_c + 1)).rjust(max_len)
        input8 = str(b_cur).rjust(len_c, "0").rjust(max_len)
        input9 = ("F" + " " * (len_a - 1) + "T" + " " * shift).rjust(max_len)
        input10 = ("J" + " " * idx + " " * shift).rjust(max_len)
        input11 = (borrow + " " * idx + " " * shift).rjust(max_len)

        move_shift, move_idx, reset_idx, increase_b = False, False, False, False
        pos = idx + shift  # decimal position this action operates on
        if idx < len_a:
            if pos >= len(str(c_tmp)):
                # Divisor digits remain but the remainder has no digit left at
                # this position: the shifted divisor does not fit.  Abandon the
                # subtraction and move on to the next lower shift.
                move_shift = True
                reset_idx = True
                new_c_tmp = c_cur
                new_borrow = " "
            else:
                # Subtract one divisor digit plus any pending borrow.
                m = int(a_digits[idx]) + int(borrow == "b")
                if c_tmp % 10 ** (pos + 1) < m * 10 ** pos:
                    new_borrow = "b"
                    new_c_tmp = c_tmp + 10 ** (pos + 1) - m * 10 ** pos
                else:
                    new_borrow = " "
                    new_c_tmp = c_tmp - m * 10 ** pos
                move_idx = True
        elif idx == len_a:
            # All divisor digits are processed; settle the final borrow.
            m = int(borrow == "b")
            if c_tmp % 10 ** (pos + 1) < m * 10 ** pos:
                # The borrow cannot be paid: the subtraction failed, so
                # discard the working remainder and lower the shift.
                move_shift = True
                reset_idx = True
                new_borrow = " "
                new_c_tmp = c_cur
            else:
                # The subtraction succeeded: commit it and count it in the
                # quotient, then try the same shift again.
                increase_b = True
                reset_idx = True
                new_borrow = " "
                new_c_tmp = c_tmp - m * 10 ** pos
                c_cur = new_c_tmp
        else:
            raise RuntimeError("divisor digit index ran past the borrow-check position")

        new_shift = shift - 1 if move_shift else shift
        if reset_idx:
            new_idx = 0
        elif move_idx:
            new_idx = idx + 1
        else:
            raise RuntimeError("division step neither advanced nor reset the digit index")
        new_b_cur = b_cur + 10 ** shift if increase_b else b_cur

        output0 = ("I" + " " * new_shift).rjust(max_len)
        output1 = ("i" + " " * new_idx + " " * (len_c + 1)).rjust(max_len)
        output2 = ("F" + " " * (len_a - 1) + "T" + " " * new_shift).rjust(max_len)
        output3 = ("J" + " " * new_idx + " " * new_shift).rjust(max_len)
        output4 = str(new_b_cur).rjust(len_c, "0").rjust(max_len)
        output5 = str(c_cur).rjust(max_len)
        output6 = str(new_c_tmp).rjust(max_len)
        output7 = (new_borrow + " " * new_idx + " " * new_shift).rjust(max_len)

        if new_shift < 0:
            # Division finished: blank the position markers.
            output0 = max_len * " "
            output2 = ("F" + " " * (len_a - 1)).rjust(max_len)
            output3 = max_len * " "
            output7 = max_len * " "

        out.append((input0, input1, input2, input3, input4, input5, input6, input7, input8, input9, input10, input11,
                    output0, output1, output2, output3, output4, output5, output6, output7))

        shift = new_shift
        idx = new_idx
        c_tmp = new_c_tmp
        b_cur = new_b_cur
        borrow = new_borrow

        if shift < 0:
            break

    return out, a, b, c


def one_hot(i, num):
    """Return a list of ``num`` zeros with a single 1 at index ``i``.

    Plain list indexing is used, so a negative ``i`` counts from the end and
    an out-of-range ``i`` raises ``IndexError``.
    """
    encoding = [0 for _ in range(num)]
    encoding[i] = 1
    return encoding


class Generator:
    """Endless stream of encoded long-division training steps.

    Every item produced is a 21-tuple: the twelve input rows of one division
    step, one-hot encoded per character; the eight output rows as lists of
    token ids; and a mask holding 1 for each real column and 0 for each
    padding column.  All rows are padded with blanks to
    ``length * 3 + 2`` columns.
    """

    def __init__(self, length):
        """Store the maximum operand length and create the step stream."""
        self._length = length
        self._generator = self.generate()

    def __iter__(self):
        """Return ``self`` so the object works in ``for`` loops and ``iter()``."""
        return self

    def __next__(self):
        """Return the next encoded division step."""
        return next(self._generator)

    def generate(self):
        """Yield encoded steps forever, one division problem after another."""
        while True:
            maxlen = self._length * 3 + 2
            out, _a, _b, _c = generate_division(self._length)
            for step in out:
                num_tokens = len(token_ids)
                # All rows of a step share the width of the first row, so it
                # alone determines which columns are real and which are padding.
                real_width = len(step[0])
                mask = [1] * real_width + [0] * (maxlen - real_width)

                padded = [row + " " * (maxlen - len(row)) for row in step]
                # The first twelve rows are model inputs, the last eight targets.
                inputs = [[one_hot(token_ids[t], num_tokens) for t in row]
                          for row in padded[:12]]
                outputs = [[token_ids[t] for t in row] for row in padded[12:]]

                yield (*inputs, *outputs, mask)


def evaluate_division(model, length=10):
    """Run ``model`` autoregressively on one random division and check it.

    A fresh problem is drawn with ``generate_division``; only the input rows
    of its first step are used.  The model is then called repeatedly: after
    each call its eight predicted rows replace the corresponding state rows
    (indicator "I", indicator "i", "F"/"T" markers, "J" marker, quotient,
    remainder, working remainder, borrow) and the result is fed back in.  The
    loop stops when the predicted "I" row no longer contains "I", or after
    ``(length + 1) ** 2 * 10`` calls.  The run counts as correct when the last
    predicted quotient equals the true quotient.

    Args:
        model: Callable taking thirteen float tensors (twelve one-hot encoded
            rows with a leading batch dimension of 1, plus the mask) and
            returning a sequence whose first eight items are per-position
            score tensors.  NOTE(review): the original code unpacked exactly
            17 return values (8 scores, 8 distributions, 1 loss) - confirm
            against the model definition.
        length: Maximum number of digits per operand of the generated problem.

    Returns:
        ``True`` if the model reproduced the quotient, ``False`` otherwise
        (including when any step raised an exception).
    """
    out, a, b, c = generate_division(length)
    rows = list(out[0][:12])

    print("=============================")
    print(c)
    print(a)

    num_tokens = len(token_ids)
    # Every row of a step already has the same width, so no padding is needed
    # and every column is a real one.
    mask = [1] * len(rows[0])

    def encode(text):
        # One-hot encode a row string character by character.
        return [one_hot(token_ids[t], num_tokens) for t in text]

    def decode(scores):
        # Pick the best token per position for the single batch item and map
        # the ids back to characters.
        best = torch.argmax(scores, dim=-1)[0]
        return "".join(inv_token_ids[k] for k in best.tolist())

    encoded = [encode(row) for row in rows]
    # Input slot that each of the eight model outputs overwrites, in output
    # order: I, i, F/T, J, quotient, remainder, working remainder, borrow.
    feedback_slots = (6, 7, 9, 10, 8, 4, 5, 11)

    correct = False
    b_cur = None
    try:
        # Pure inference: no gradients are needed.
        with torch.no_grad():
            for _ in range((length + 1) * (length + 1) * 10):
                results = model(
                    *(torch.tensor([row], dtype=torch.float) for row in encoded),
                    torch.tensor([mask], dtype=torch.float),
                )
                predicted = [decode(scores) for scores in results[:8]]

                for slot, text in zip(feedback_slots, predicted):
                    encoded[slot] = encode(text)

                b_cur = int(predicted[4].strip())

                if "I" not in predicted[0]:  # eos
                    break

        if b_cur == b:
            correct = True
    except Exception:
        # Best effort: a malformed prediction (e.g. a quotient row that is not
        # a number) or a failing model call simply counts as a wrong answer.
        # KeyboardInterrupt and SystemExit are deliberately not caught.
        pass
    if not correct:
        print("wrong: ", a, b, c)
    return correct
