import numpy as np
import torch


# Symbol alphabet used by the parity task; '#' is the end-of-sequence (EOS)
# marker.  The space character is not part of `vocab`: it is the padding /
# separator symbol and is always mapped to id 0.
vocab = ["?", "0", "1", "#"]

# Symbol -> integer id.  Vocabulary symbols get ids 1..len(vocab) in order.
token_ids = dict(zip(vocab, range(1, len(vocab) + 1)))
token_ids[" "] = 0

# Integer id -> symbol (inverse of `token_ids`).
inv_token_ids = {idx: symbol for symbol, idx in token_ids.items()}


def generate_two_line_parity(length, from_beginning=False):
    """Sample one step of the two-line running-parity task.

    A random bit string of between 1 and `length` bits is drawn, and its
    running parity (parity of every prefix) is computed.  A prefix of that
    running parity is "revealed"; the answer is the revealed prefix extended
    by one more parity digit, or the full parity followed by '#' when
    everything was already revealed.

    Args:
        length: Maximum number of input bits (must be >= 1).
        from_beginning: If true, nothing is revealed (the prefix is empty);
            otherwise the revealed prefix length is drawn uniformly from
            0..number_of_bits inclusive.

    Returns:
        A tuple of four strings ``(line1, line2, ans, final)``:
        * ``line1`` - the input bits followed by a single space,
        * ``line2`` - the revealed parity prefix followed by '?',
        * ``ans``   - the expected next state of line 2 (one more digit and
          '?', or the complete parity and '#'),
        * ``final`` - the complete running parity, one digit per input bit.
    """
    n_bits = np.random.randint(1, length + 1)
    bits = [int(np.random.choice([0, 1])) for _ in range(n_bits)]

    # Running parity: XOR of all bits seen so far.
    prefix_parity = []
    running = 0
    for bit in bits:
        running ^= bit
        prefix_parity.append(running)

    revealed = 0 if from_beginning else np.random.randint(0, n_bits + 1)

    # Every digit is exactly one character, so slicing the joined string is
    # equivalent to slicing the underlying list of digits.
    parity_text = "".join(map(str, prefix_parity))
    line1 = "".join(map(str, bits)) + " "
    line2 = parity_text[:revealed] + "?"

    if revealed == n_bits:
        # Everything is already revealed: the next step just emits EOS.
        ans = parity_text + "#"
    else:
        ans = parity_text[:revealed + 1] + "?"

    return line1, line2, ans, parity_text


def one_hot(i, num):
    """Return a list of `num` zeros with a single 1 at position `i`.

    Indexing follows normal Python list rules: a negative `i` counts from
    the end, and an `i` outside the list raises ``IndexError``.
    """
    encoding = [0 for _ in range(num)]
    encoding[i] = 1
    return encoding


class Generator:
    """Endless iterator over padded, tokenised two-line parity samples.

    Every item is a 5-tuple
    ``((line1, line2, cot, final), line1_onehot, line2_onehot, y, mask)``:

    * ``line1``, ``line2``, ``cot`` - the strings produced by
      `generate_two_line_parity` (``cot`` being its ``ans`` value), each
      right-padded with spaces to ``length + 1`` characters; ``final`` is
      passed through unpadded.
    * ``line1_onehot`` / ``line2_onehot`` - per-character one-hot lists of
      width ``len(token_ids)`` for the padded ``line1`` / ``line2``.
    * ``y`` - token ids of the padded (and possibly corrupted) ``cot``.
    * ``mask`` - 1 for every character of the unpadded ``line1`` (including
      its trailing space), 0 for the padding positions.

    Args:
        length: Maximum number of input bits per sample.
        noise_prob: Probability that one uniformly chosen position of the
            target string ``cot`` is overwritten with a symbol drawn
            uniformly from `vocab` (which may coincide with the original
            symbol).  The default of 1.0 reproduces the previously
            hard-coded behaviour of corrupting every sample.
            NOTE(review): confirm that always-on target noise is intended.
    """

    def __init__(self, length, noise_prob=1.0):
        self._length = length
        self._noise_prob = noise_prob
        self._generator = self.generate()

    def __iter__(self):
        # Required by the iterator protocol so the object can be used in
        # ``for`` loops, ``iter()`` and ``itertools`` helpers.
        return self

    def __next__(self):
        return next(self._generator)

    def generate(self):
        """Yield samples forever; see the class docstring for the layout."""
        while True:
            maxlen = self._length + 1
            num_tokens = len(token_ids)
            line1, line2, cot, final = generate_two_line_parity(self._length)

            # Label noise: replace one random character of the target.
            if np.random.random() < self._noise_prob:
                pos = np.random.randint(0, len(cot))
                cot = cot[:pos] + np.random.choice(vocab) + cot[pos + 1:]

            # The mask must be built before padding so that it marks only
            # the real characters of line1.
            mask = [1] * len(line1) + [0] * (maxlen - len(line1))
            line1 = line1 + " " * (maxlen - len(line1))
            line2 = line2 + " " * (maxlen - len(line2))
            cot = cot + " " * (maxlen - len(cot))

            _line1_onehot = [one_hot(token_ids[ch], num_tokens) for ch in line1]
            _line2_onehot = [one_hot(token_ids[ch], num_tokens) for ch in line2]
            _y = [token_ids[ch] for ch in cot]

            yield (line1, line2, cot, final), _line1_onehot, _line2_onehot, _y, mask


def evaluate_two_line_parity(model, length=10):
    """Roll `model` out on one random parity problem and report success.

    A problem is drawn with `generate_two_line_parity(length,
    from_beginning=True)`, so line 2 starts as just '?'.  The model is then
    applied repeatedly: its arg-max prediction becomes the next line 2,
    until the prediction contains '#' or ``len(line1) + 1`` steps have run.

    The model is called as ``model(line1, line2, mask)`` with float tensors
    of shape ``(1, maxlen, len(token_ids))`` for the two one-hot lines and
    ``(1, maxlen)`` for the all-ones mask, where ``maxlen = len(line1)``.
    It must return a 3-tuple whose first element holds per-position scores
    over token ids in its last dimension.

    Args:
        model: Callable with the interface described above.
        length: Maximum number of input bits of the sampled problem.

    Returns:
        True if the final line 2, with surrounding spaces and '#' stripped,
        equals the true running parity; False otherwise, including when the
        model raised an exception (the exception is printed, not re-raised).
        ``KeyboardInterrupt`` and ``SystemExit`` are not intercepted.
    """
    line1, line2, _, final = generate_two_line_parity(length, from_beginning=True)

    print("=============================")
    print(line1, final)

    maxlen = len(line1)
    mask = [1] * maxlen
    num_tokens = len(token_ids)
    correct = False

    line1_onehot = [one_hot(token_ids[ch], num_tokens) for ch in line1]
    try:
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            # These inputs do not change between steps, so build them once.
            line1_tensor = torch.tensor([line1_onehot], dtype=torch.float)
            mask_tensor = torch.tensor([mask], dtype=torch.float)

            for _ in range(maxlen + 1):
                line2 = line2 + " " * (maxlen - len(line2))
                line2_onehot = [one_hot(token_ids[ch], num_tokens) for ch in line2]
                scores, _dist, _ = model(
                    line1_tensor,
                    torch.tensor([line2_onehot], dtype=torch.float),
                    mask_tensor,
                )
                predicted_ids = torch.argmax(scores, dim=-1)[0].tolist()
                line2 = "".join(inv_token_ids[tok] for tok in predicted_ids)
                # print(line2)
                if "#" in line2:
                    break
    except Exception as err:
        # Best effort: a failing model counts as a wrong answer, but the
        # reason is reported instead of being silently discarded.
        print("error: ", repr(err))
    else:
        # Only judge the output when the roll-out finished without error.
        correct = line2.strip().strip("#") == final.strip()

    if not correct:
        print("wrong: ", line1, line2, final)
    return correct
