import numpy as np
import random
import copy
import torch

# Symbol vocabulary: the marker glyphs followed by the ten decimal digits.
vocab = ["E", "S", "I", "e", "s", "i", "F", "T", "J", "?", "c", "#"] + list(map(str, range(10)))
# Vocabulary symbols get ids 1..len(vocab); the blank cell " " is reserved as id 0
# and is inserted last.
token_ids = dict(zip(vocab, range(1, len(vocab) + 1)))
token_ids[" "] = 0
# Reverse lookup: token id -> symbol.
inv_token_ids = dict(zip(token_ids.values(), token_ids.keys()))


def generate_multiplication(length):
    """
    Build the full step-by-step trace of one random long multiplication ``a * b``.

    ``a`` and ``b`` are drawn with 1..``length`` random decimal digits each.
    ``int`` drops leading zeros, so an operand can end up shorter, or be 0.
    Digit ``I`` of ``a`` is multiplied by digit ``i`` of ``b``. Both are counted
    from the least significant digit, with ``i`` varying fastest. Each partial
    product is then added into the running result one column at a time, with
    carries.

    Every row is a list of ``max_len = len(str(a)) + len(str(b)) + 1`` cells.
    Rows are right-aligned, so the last cell is column 0 (units). Blank cells
    are " ".

    Rows of each tuple:
        input0  - column index of each cell (ints, max_len-1 down to 0)
        input1  - digits of a
        input2  - "E" one column past a's top digit, "S" at column 0
        input3  - "I" at column I (current digit of a)
        input4  - digits of b
        input5  - "e" one column past b's top digit, "s" at column 0
        input6  - "i" at column i (current digit of b)
        input7  - running result before this step
        input8  - "F" at column I + len(b), "T" at column I
        input9  - "J" at column I + i
        input10 - partial-product digits still to be added (previous output4)
        input11 - addition indicator before this step (previous output5)
        output0..output3 - the I, i, F/T and J markers for the next digit pair
        output4 - partial product a[I] * b[i] shifted to column I + i; its
                  digits are blanked out as the addition consumes them
        output5 - addition indicator at the column the addition reaches next:
                  "?" = more digits to add, "c" = carry pending, "#" = done
        output6 - running result after this step
        final   - digits of a * b

    Illustrative layout (a = 123, b = 4567, I = 2, i = 1; the running result
    shown is only a placeholder value):

        Input:
           43210 - position
             123 = a
            E  S - start, end
             I   - indicator
            4567 = b
           e   s - start, end
              i  - indicator
           88275 = current result
         F   T   - result start, result end
            J    - result indicator
                 - multiplication answer
                 - addition indicator

        Output:
             I   - new indicator
             i   - new indicator
         F   T   - result start, result end
           J     - result indicator
            6    - multiplication answer
            ?    - addition indicator
           88275 = current result

    Returns:
        list of 20-tuples ``(input0, ..., input11, output0, ..., output6, final)``.
        There is one multiplication step per digit pair, each followed by its
        addition steps.

    Raises:
        ValueError: from ``np.random.randint`` when ``length < 1``.
        AssertionError: if the last running result differs from ``a * b``.
    """
    len_a = np.random.randint(1, length + 1)
    len_b = np.random.randint(1, length + 1)
    a = int("".join(str(np.random.randint(0, 10)) for _ in range(len_a)))
    b = int("".join(str(np.random.randint(0, 10)) for _ in range(len_b)))
    # int() strips leading zeros, so take the real digit counts from the values.
    len_a = len(str(a))
    len_b = len(str(b))
    inv_str_a = str(a)[::-1]  # least significant digit first
    inv_str_b = str(b)[::-1]

    # a * b has at most len_a + len_b digits. The spare column holds the "F"
    # marker once I has moved past a's last digit.
    max_len = len_a + len_b + 1

    def pad(cells):
        """Right-align ``cells`` in a row of ``max_len`` cells, padding with blanks."""
        return [" "] * (max_len - len(cells)) + cells

    final = pad(list(str(a * b)))

    data = []
    # Addition indicator and running result, carried from step to step.
    output5 = pad(["#"])
    output6 = [" "] * max_len
    for I in range(len_a):
        for i in range(len_b):
            # Advance to the next digit pair: i varies fastest, then wraps to 0
            # and I moves on to the next digit of a.
            new_i, new_I = i + 1, I
            if new_i == len_b:
                new_i = 0
                new_I = I + 1
            add_ind = "?"
            add_pos = I + i  # column of the partial product's units digit

            # --- multiplication step ---
            input0 = list(range(max_len))[::-1]
            input1 = pad(list(str(a)))
            input2 = pad(["E"] + [" "] * (len_a - 1) + ["S"])
            input3 = pad(["I"] + [" "] * I)
            input4 = pad(list(str(b)))
            input5 = pad(["e"] + [" "] * (len_b - 1) + ["s"])
            input6 = pad(["i"] + [" "] * i)
            input7 = list(output6)
            input8 = pad(["F"] + [" "] * (len_b - 1) + ["T"] + [" "] * I)
            input9 = pad(["J"] + [" "] * (I + i))
            input10 = [" "] * max_len
            input11 = list(output5)

            output0 = pad(["I"] + [" "] * new_I)
            output1 = pad(["i"] + [" "] * new_i)
            output2 = pad(["F"] + [" "] * (len_b - 1) + ["T"] + [" "] * new_I)
            output3 = pad(["J"] + [" "] * (new_I + new_i))
            digit_product = int(inv_str_a[I]) * int(inv_str_b[i])
            output4 = pad(list(str(digit_product)) + [" "] * (I + i))
            output5 = pad([add_ind] + [" "] * add_pos)
            output6 = list(input7)  # multiplying does not change the running result

            data.append((input0, input1, input2, input3, input4, input5, input6, input7, input8, input9, input10, input11,
                         output0, output1, output2, output3, output4, output5, output6, final))

            # --- addition steps: add output4 into output6 one column at a time ---
            while add_ind != "#":
                input3 = list(output0)
                input6 = list(output1)
                input8 = list(output2)
                input9 = list(output3)
                input10 = list(output4)
                input11 = list(output5)
                input7 = list(output6)

                u = int(output4[-add_pos - 1]) if output4[-add_pos - 1] != " " else 0  # partial-product digit
                v = 1 if add_ind == "c" else 0  # incoming carry
                w = int(output6[-add_pos - 1]) if output6[-add_pos - 1] != " " else 0  # running-result digit
                o = u + v + w
                if o >= 10:
                    add_ind = "c"
                elif output4[-add_pos - 2] != " ":
                    add_ind = "?"
                else:
                    add_ind = "#"
                add_pos += 1
                output4 = list(output4)
                output4[-add_pos] = " "  # consume the digit just added
                output5 = pad([add_ind] + [" "] * add_pos)
                output6 = list(output6)
                output6[-add_pos] = str(o % 10)

                data.append((
                    input0, input1, input2, input3, input4, input5, input6, input7, input8, input9, input10, input11,
                    output0, output1, output2, output3, output4, output5, output6, final))

    # Sanity check: the trace must end with the true product.
    assert int("".join(data[-1][-2]).strip()) == int("".join(final).strip()), f"{a}, {b}"

    return data


def one_hot(i, num):
    """Return a list of ``num`` ints that is 1 at index ``i`` and 0 elsewhere.

    ``i`` follows normal list indexing: negative values count from the end,
    and an out-of-range index raises ``IndexError``.
    """
    vector = [0 for _ in range(num)]
    vector[i] = 1
    return vector


class Generator:
    """Endless stream of encoded training samples for the multiplication task.

    Each ``next()`` yields one step of a trace from
    ``generate_multiplication(length)`` as a 20-tuple::

        (positions, in1, ..., in11, out0, ..., out6, mask)

    * ``positions`` - column indices (ints), zero-padded on the right;
    * ``in1`` .. ``in11`` - one-hot vectors (one per cell) of the input rows;
    * ``out0`` .. ``out6`` - token ids of the target rows;
    * ``mask`` - 1 for real cells, 0 for right padding.

    Token rows are right-padded with blanks to ``2 * length + 1`` cells. That
    is the widest row ``generate_multiplication`` can produce, since each
    operand has at most ``length`` digits. A new random problem is drawn after
    every complete trace.
    """

    def __init__(self, length):
        self._length = length
        self._generator = self.generate()

    def __iter__(self):
        # Complete the iterator protocol so the object works with ``for`` and ``iter()``.
        return self

    def __next__(self):
        return next(self._generator)

    def generate(self):
        """Yield encoded samples forever, one trace step at a time."""
        while True:
            maxlen = self._length * 2 + 1
            vocab_size = len(token_ids)
            for row in generate_multiplication(self._length):
                # row = (input0, input1..input11, output0..output6, final);
                # the final product is not part of the sample.
                positions, input_rows, output_rows = row[0], row[1:12], row[12:19]
                padding = maxlen - len(positions)
                mask = [1] * len(positions) + [0] * padding
                encoded_inputs = [
                    [one_hot(token_ids[t], vocab_size) for t in cells + [" "] * (maxlen - len(cells))]
                    for cells in input_rows
                ]
                encoded_outputs = [
                    [token_ids[t] for t in cells + [" "] * (maxlen - len(cells))]
                    for cells in output_rows
                ]
                yield (positions + [0] * padding, *encoded_inputs, *encoded_outputs, mask)


def _encode_one_hot(cells):
    """One-hot encode a row of vocabulary tokens, one vector per cell."""
    return [one_hot(token_ids[t], len(token_ids)) for t in cells]


def _decode_tokens(logits):
    """Argmax the per-cell logits of the first batch item and map the ids back to tokens."""
    ids = torch.argmax(logits, dim=-1)[0]
    return [inv_token_ids[t] for t in np.array(ids)]


# Model input slot that receives each predicted row (out0 .. out6, in order)
# as input for the next step.
_FEEDBACK_SLOTS = (3, 6, 8, 9, 10, 11, 7)


def evaluate_multiplication(model, length=10):
    """Check whether ``model`` reproduces one random multiplication trace end to end.

    A problem is drawn with ``generate_multiplication(length)``. Starting from
    its first step, the model runs autoregressively: its seven predicted rows
    are fed back into input slots 3, 6, 8, 9, 10, 11 and 7 for the next step.

    The run stops early when any of these holds:

    * the predicted running result (out6) differs from the ground-truth trace;
    * out0 contains no ``"I"`` marker;
    * the ``"I"`` marker has reached column ``len(str(a))`` or beyond and out5
      contains the ``"#"`` end-of-addition marker.

    Args:
        model: callable taking 13 float tensors, each with a leading batch
            dimension of 1: the column-index row, the eleven one-hot input
            rows, and the mask. It must return exactly 15 values. Only the
            first seven (per-cell logits for out0..out6) are used.
        length: maximum number of digits of each operand.

    Returns:
        True if the last predicted running result parses to ``a * b``.
        The operands are printed, and ``wrong:`` with ``a, b, a*b`` is printed
        on failure.
    """
    data = generate_multiplication(length)
    first = data[0]
    final = first[-1]

    a = int("".join(first[1]).strip())
    b = int("".join(first[4]).strip())
    c = int("".join(final).strip())

    print("=============================")
    print(first[1])
    print(first[4])

    # All rows of a trace share one width, so no padding is needed and every
    # cell is unmasked.
    mask = [1] * len(first[0])
    # Slot 0 holds the raw column indices; slots 1-11 are one-hot input rows.
    model_inputs = [first[0]] + [_encode_one_hot(cells) for cells in first[1:12]]

    predicted_result = None  # out6 decoded at the most recent step
    for expected in data:
        # Inference only: avoid building an autograd graph at every step.
        with torch.no_grad():
            (o0, o1, o2, o3, o4, o5, o6, _, _, _, _, _, _, _, _) = model(
                *[torch.tensor([x], dtype=torch.float) for x in model_inputs],
                torch.tensor([mask], dtype=torch.float),
            )
        predicted = [_decode_tokens(logits) for logits in (o0, o1, o2, o3, o4, o5, o6)]
        predicted_result = predicted[6]

        if predicted_result != expected[-2]:
            break
        if "I" not in predicted[0]:
            break
        index_I = predicted[0][::-1].index("I")  # column of the predicted I marker

        # Feed the predictions back in as the next step's inputs.
        for slot, cells in zip(_FEEDBACK_SLOTS, predicted):
            model_inputs[slot] = _encode_one_hot(cells)

        if index_I >= len(str(a)) and "#" in predicted[5]:  # eos
            break

    cot_output = None
    if predicted_result is not None:
        try:
            cot_output = int("".join(predicted_result).strip())
        except ValueError:
            # The result row contains non-digit tokens (or only blanks).
            cot_output = None
    correct = cot_output == c

    if not correct:
        print("wrong: ", a, b, c)
    return correct
