import numpy as np
import torch

# All arithmetic is carried out modulo the prime p; R is the radius of the
# local character window used when parsing around a position.
p = 7
R = 4
# Token vocabulary: "+", "-", "*", "/", "(", ")" followed by the digits 0..p-1
# (each digit is one character only while p <= 10).
vocab = [*"+-*/()", *map(str, range(p))]
# Real tokens are numbered from 1; id 0 is reserved for the space/padding char.
token_ids = dict(zip(vocab, range(1, len(vocab) + 1)))
token_ids[" "] = 0
inv_token_ids = {idx: tok for tok, idx in token_ids.items()}


def add(a, b):
    """Return the sum of ``a`` and ``b`` reduced modulo ``p``."""
    total = a + b
    return total % p


def minus(a, b):
    """Return ``a - b`` reduced modulo ``p`` (always in ``[0, p)`` for int inputs)."""
    difference = a - b
    return difference % p


def multiply(a, b):
    """Return the product of ``a`` and ``b`` reduced modulo ``p``."""
    product = a * b
    return product % p


def divided_by(a, b):
    """Return ``a / b`` modulo ``p``: the x in [0, p) with ``x * b = a (mod p)``.

    Both operands are first reduced modulo ``p``, so any integer
    representatives are accepted and the result is always in ``[0, p)``.

    Raises:
        ZeroDivisionError: if ``b`` is 0 modulo ``p``.
        ValueError: if no quotient exists (not possible for prime ``p`` and
            nonzero ``b``; can only happen with a composite modulus).
    """
    a_mod, b_mod = a % p, b % p
    if b_mod == 0:
        raise ZeroDivisionError(f"cannot divide by {b}: it is 0 modulo {p}")
    # Search a, a + p, a + 2p, ... for a multiple of b. When b is invertible
    # modulo p such a multiple exists for some i < b, which keeps the quotient
    # (a + i*p) // b inside [0, p).
    for i in range(p):
        candidate = a_mod + i * p
        if candidate % b_mod == 0:
            return candidate // b_mod
    raise ValueError(f"{a} is not divisible by {b} modulo {p}")


def cal(a, sign, b):
    """Apply the operator ``sign`` to ``a`` and ``b`` in arithmetic modulo ``p``.

    ``sign`` is compared against "+", "-", "*", "/" in that order; when it
    matches none of them the single space string " " is returned instead of
    a number.
    """
    for symbol, operation in (("+", add), ("-", minus), ("*", multiply), ("/", divided_by)):
        if sign == symbol:
            return operation(a, b)
    return " "


def one_hot(i, num):
    """Return a list of ``num`` zeros with a single 1 at index ``i``.

    Uses ordinary list indexing, so a negative ``i`` counts from the end and
    an out-of-range ``i`` raises IndexError.
    """
    vector = [0 for _ in range(num)]
    vector[i] = 1
    return vector


def _wrap_left(text, sign, new_sign):
    """Parenthesize a left operand only when it is a non-*/ expression placed under * or /."""
    if sign is not None and new_sign not in "+-" and sign not in "*/":
        return f"({text})"
    return text


def _wrap_right(text, sign, new_sign):
    """Parenthesize a right operand wherever dropping the parentheses would change its grouping."""
    if sign is None:
        return text
    if new_sign in "+-":
        # a + (b +/- c) == a + b +/- c, but a - (b +/- c) needs the parens;
        # a */-product on the right binds tighter and never needs them.
        needs_parens = sign not in "*/" and new_sign != "+"
    else:
        # Only a * (b */ c) may drop its parentheses; everything else keeps them.
        needs_parens = not (sign in "*/" and new_sign == "*")
    return f"({text})" if needs_parens else text


def generate_one_arithmetic(length, debug=False):
    """Randomly build an arithmetic expression over the integers modulo ``p``.

    Starts from ``length`` random digits and repeatedly merges two randomly
    chosen pool entries with a random operator, adding parentheses only where
    the grouping would otherwise change.

    Args:
        length: target number of characters; must be at least 2.
        debug: if true, print both operands and the merged result each step.

    Returns:
        ``(expr, value)``. If a merge produces exactly ``length`` characters,
        that expression is returned. If a merge overshoots, the left operand of
        that merge (without any parentheses it was about to receive) is
        returned instead, so ``expr`` may be shorter than ``length``.
        ``value`` is the expression's value modulo ``p``.

    Raises:
        ValueError: if ``length < 2`` (there are not two operands to merge).
    """
    if length < 2:
        raise ValueError(f"length must be at least 2, got {length}")
    # Pool entries: (text, top-level operator or None for a bare digit, value).
    pool = list()
    for _ in range(length):
        x = np.random.randint(0, p)
        pool.append((str(x), None, x))
    while True:
        # Same random draws as sampling from [0, 1, ..., len(pool) - 1].
        i, j = np.random.choice(len(pool), 2, replace=False)
        x_i, sign_i, v_i = pool[i]
        x_j, sign_j, v_j = pool[j]
        new_sign = np.random.choice(list("+-*/"))
        if new_sign == "/" and v_j == 0:
            continue  # division by zero is undefined; draw again

        left_text = _wrap_left(x_i, sign_i, new_sign)
        right_text = _wrap_right(x_j, sign_j, new_sign)
        new_x = f"{left_text}{new_sign}{right_text}"
        new_v = cal(v_i, new_sign, v_j)

        if debug:
            print("1.", left_text, sign_i, v_i)
            print("2.", right_text, sign_j, v_j)
            print("3.", new_x, new_sign, new_v)

        if len(new_x) == length:
            return new_x, new_v
        if len(new_x) > length:
            # Overshot the target length: fall back to the unwrapped left operand.
            return x_i, v_i

        # Pop the larger index first so the smaller index stays valid.
        pool.pop(max(i, j))
        pool.pop(min(i, j))
        pool.append((new_x, new_sign, new_v))


def parse_in_2R(x, mid):
    """Find the binary sub-expression ``a <sign> b`` touching position ``mid`` of ``x``.

    The character at ``mid`` decides the strategy:

    * ``"("`` / ``")"``: match a parenthesized ``(a?b)`` starting/ending there;
    * an operator: take the digits immediately on each side (and absorb a
      directly enclosing pair of parentheses);
    * anything else (a digit): pair it with the operator next to it; if there
      are operators on both sides, bind to the left one when it is ``*``/``/``
      or when both are ``+``/``-``, else to the right one.

    Returns:
        ``(mask, (left, right), (a, sign, b))`` where ``mask`` holds 1s over
        the inclusive span ``[left, right]`` and 0s elsewhere. In the
        digit-centred case the span is still returned with
        ``(None, None, None)`` operands when a neighbour is not a digit or the
        span runs past the end of ``x``. When nothing matches, returns
        ``([0] * len(x), (None, None), (None, None, None))``.
    """
    n = len(x)
    if x[mid] == "(":
        # A "(a?b)" group: the closing parenthesis sits four characters later.
        left, right = mid, mid + 4
        if right < n and x[right] == ")":
            a, sign, b = int(x[left + 1]), x[left + 2], int(x[left + 3])
            mask = [0] * left + [1] * (right - left + 1) + [0] * (n - right - 1)
            return mask, (left, right), (a, sign, b)
    elif x[mid] == ")":
        left, right = mid - 4, mid
        if left >= 0 and x[left] == "(":
            a, sign, b = int(x[left + 1]), x[left + 2], int(x[left + 3])
            mask = [0] * left + [1] * (right - left + 1) + [0] * (n - right - 1)
            return mask, (left, right), (a, sign, b)
    elif x[mid] in "+-*/":
        left, right = mid - 1, mid + 1
        if left >= 0 and right < n and x[left] not in "()+-*/" and x[right] not in "()+-*/":
            a, sign, b = int(x[left]), x[left + 1], int(x[left + 2])
            # Include a directly enclosing "(...)" in the span.
            if left - 1 >= 0 and right + 1 < n and x[left - 1] == "(" and x[right + 1] == ")":
                left, right = left - 1, right + 1
            mask = [0] * left + [1] * (right - left + 1) + [0] * (n - right - 1)
            return mask, (left, right), (a, sign, b)
    else:
        num_signs, on_left, on_right = 0, False, False
        if mid - 1 >= 0 and x[mid - 1] in "+-*/":
            num_signs += 1
            on_left = True
        if mid + 1 < n and x[mid + 1] in "+-*/":
            num_signs += 1
            on_right = True
        if num_signs == 1:
            if on_left:
                left, right = mid - 2, mid
            else:
                left, right = mid, mid + 2
            try:
                a, sign, b = int(x[left]), x[left + 1], int(x[left + 2])
            except (ValueError, IndexError):
                # Other operand is a parenthesis (int() fails) or lies past the end.
                a, sign, b = None, None, None
            if left - 1 >= 0 and right + 1 < n and x[left - 1] == "(" and x[right + 1] == ")":
                left, right = left - 1, right + 1
            mask = [0] * left + [1] * (right - left + 1) + [0] * (n - right - 1)
            return mask, (left, right), (a, sign, b)
        elif num_signs == 2:
            # Digit between two operators: bind left for */ (higher precedence)
            # or when both are +/- (left associativity); otherwise bind right.
            if x[mid - 1] in "*/" or (x[mid - 1] in "+-" and x[mid + 1] in "+-"):
                left, right = mid - 2, mid
            else:
                left, right = mid, mid + 2
            try:
                a, sign, b = int(x[left]), x[left + 1], int(x[left + 2])
            except (ValueError, IndexError):
                a, sign, b = None, None, None
            mask = [0] * left + [1] * (right - left + 1) + [0] * (n - right - 1)
            return mask, (left, right), (a, sign, b)

    return [0] * n, (None, None), (None, None, None)


def parse_in_4R(x, mid):
    """Return the reducible ``a <sign> b`` span at ``mid``, or an all-zero result.

    The candidate is found with ``parse_in_2R`` on the radius-``R`` window
    around ``mid``. It is accepted only if, for every position it covers,
    ``parse_in_2R`` on that position's own radius-``R`` window yields a span
    lying inside the candidate; otherwise a neighbouring operator claims one
    of its operands and the candidate is rejected.

    Returns:
        ``(mask, (left, right), (a, sign, b))`` with indices relative to
        ``x``, or ``([0] * len(x), (None, None), (None, None, None))``.
    """
    n = len(x)
    no_match = ([0] * n, (None, None), (None, None, None))
    lo, hi = max(0, mid - R), min(n, mid + R + 1)
    window_mask, (w_left, w_right), operands = parse_in_2R(x[lo:hi], mid - lo)
    if sum(window_mask) == 0 or operands[0] is None:
        return no_match
    span_left, span_right = lo + w_left, lo + w_right
    for offset, pos in enumerate(range(lo, hi)):
        if not window_mask[offset]:
            continue
        inner_lo = max(0, pos - R)
        inner_window = x[inner_lo:min(n, pos + R + 1)]
        _, (in_left, in_right), _ = parse_in_2R(inner_window, pos - inner_lo)
        if inner_lo + in_left < span_left or inner_lo + in_right > span_right:
            return no_match
    full_mask = [0] * lo + window_mask + [0] * (n - hi)
    return full_mask, (span_left, span_right), operands


def get_cot(x):
    """Return the one-step "chain of thought" rewrite of expression ``x``.

    Every sub-expression that ``parse_in_4R`` accepts, when run on the
    radius-``2 * R`` window around each position, is evaluated with ``cal``.
    Its span is blanked out with spaces and the result is written at the
    span's middle position. When every result is a single character (digits
    with p <= 10), the output has the same length as ``x``.
    """
    n = len(x)
    # Spans keyed by their left edge so a span found from several centres is
    # recorded once (the last evaluation wins, as before).
    spans = dict()
    for mid in range(n):
        orig = max(0, mid - 2 * R)
        window = x[orig:min(n, mid + 2 * R + 1)]
        sub_mask, (left, right), (a, sign, b) = parse_in_4R(window, mid - orig)
        if sum(sub_mask) > 0:
            spans[orig + left] = (orig + left, orig + right, cal(a, sign, b))
    new = list(x)
    for l, r, c in spans.values():
        for i in range(l, r + 1):
            new[i] = " "
        new[(l + r) // 2] = str(c)
    return "".join(new)


class Generator:
    """Endless iterator of padded (expression, chain-of-thought) samples.

    Each item is ``(x, cot, x_onehot, y, mask)``:

    * ``x``: an expression from ``generate_one_arithmetic``, right-padded with
      spaces to ``length`` characters;
    * ``cot``: ``get_cot(x)`` padded the same way;
    * ``x_onehot``: one one-hot list (size ``len(token_ids)``) per character of ``x``;
    * ``y``: the token ids of ``cot``;
    * ``mask``: 1 for each real character of ``x`` and 0 for each padding position.
    """

    def __init__(self, length):
        self._length = length
        self._generator = self.generate()

    def __iter__(self):
        # Completes the iterator protocol so instances work in ``for`` loops,
        # ``itertools.islice`` and other iterator consumers.
        return self

    def __next__(self):
        return next(self._generator)

    def generate(self):
        while True:
            x, _ = generate_one_arithmetic(self._length, False)
            cot = get_cot(x)
            # Expressions can be shorter than the target length; mark and pad them.
            mask = [1] * len(x) + [0] * (self._length - len(x))
            x = x + " " * (self._length - len(x))
            cot = cot + " " * (self._length - len(cot))

            _x = [token_ids[t] for t in x]
            _x_onehot = [one_hot(i, len(token_ids)) for i in _x]
            _y = [token_ids[t] for t in cot]

            yield x, cot, _x_onehot, _y, mask


def evaluate_one_arithmetic(model, length=50):
    """Check whether ``model`` reduces a random expression to its correct value.

    A fresh expression of up to ``length`` characters is fed to ``model``
    repeatedly. Each time, the argmax token ids of the model output (id 0,
    padding, dropped) become the next, shorter expression, until a single
    character remains. Progress is printed along the way.

    ``model`` is called as ``model(onehot, mask)`` with float tensors built
    from a batch of one. It must return a pair whose first element is
    indexed as ``argmax(out, dim=-1)[0]`` (presumably per-position logits of
    shape (1, seq, vocab); NOTE(review): confirm against the model).

    Returns:
        True if the final character equals the expression's value, else False.
        Any failure (model made no progress, unknown token id, non-numeric
        result, or an exception from the model) counts as wrong.
    """
    x, y_x = generate_one_arithmetic(length)
    tmp = x
    print(tmp, y_x)
    correct = False
    try:
        while len(tmp) > 1:
            _x = [token_ids[t] for t in tmp]
            _x_onehot = [one_hot(i, len(token_ids)) for i in _x]
            _out, _ = model(torch.tensor([_x_onehot], dtype=torch.float),
                            torch.tensor([[1] * len(_x)], dtype=torch.float))
            _y = torch.argmax(_out, dim=-1)[0]
            # Drop padding (id 0) and map the remaining ids back to characters.
            new_tmp = "".join(str(inv_token_ids[i]) for i in _y.tolist() if i > 0)
            if len(new_tmp) >= len(tmp):
                # An explicit raise (not `assert`) so `python -O` cannot turn a
                # non-shrinking output into an infinite loop.
                raise ValueError("model output did not shrink the expression")
            tmp = new_tmp
            print(tmp)
        correct = int(tmp) == int(y_x)
    except Exception:
        # Best-effort evaluation: any failure is scored as a wrong answer, but
        # KeyboardInterrupt/SystemExit still propagate.
        correct = False
    if not correct:
        print("wrong: ", x, y_x)
    return correct


if __name__ == '__main__':
    # Demo: generate one expression and reduce it step by step, each time
    # evaluating one randomly chosen sub-expression accepted by parse_in_4R.
    x, y_x = generate_one_arithmetic(20, debug=True)
    tmp = x
    print(tmp, y_x)
    while len(tmp) > 1:
        candidates = list()
        for mid in range(len(tmp)):
            orig = max(0, mid - 2 * R)
            window = tmp[orig:min(len(tmp), mid + 2 * R + 1)]
            sub_mask, (left, right), (a, sign, b) = parse_in_4R(window, mid - orig)
            if sum(sub_mask) > 0:
                candidates.append((orig + left, orig + right, (a, sign, b)))
        if not candidates:
            # np.random.randint(0, 0) would raise; report and stop instead.
            print("no reducible sub-expression found in:", tmp)
            break
        start, end, (a, sign, b) = candidates[np.random.randint(0, len(candidates))]
        tmp = f"{tmp[:start]}{cal(a, sign, b)}{tmp[end + 1:]}"
        print(tmp)
