import numpy as np
import torch
import pickle as pkl

from two_line_parity.model import ParseModel
from two_line_parity.data import Generator, token_ids, evaluate_two_line_parity


# When True, load the saved checkpoint and only run the periodic evaluation;
# when False, train from scratch and save the weights at the end.
EVALUATION = True

# Checkpoint written by the training run and read back in evaluation mode.
MODEL_PATH = "two_line_parity/model.pth"

batch_size = 256
g = Generator(length=7)
model = ParseModel(len(token_ids))
optim = torch.optim.Adam(model.parameters(), lr=0.0001)

if EVALUATION:
    # The model is never moved off the CPU in this script, so map any saved
    # tensors (e.g. from a GPU run) onto the CPU to keep loading portable.
    ckpt = torch.load(MODEL_PATH, map_location="cpu")
    model.load_state_dict(ckpt["model"])


# Sequence lengths probed by the periodic evaluation. The training Generator is
# built with length=7, so entries beyond 7 measure length extrapolation. Each
# recorded dict is keyed by exactly the length that was evaluated.
EVAL_LENGTHS = (7, 8, 9, 10, 60, 100, 200, 500, 1000)
EVAL_TRIALS = 10  # independent evaluate_two_line_parity calls averaged per length
EVAL_INTERVAL = 1000  # steps between evaluations
MAX_STEPS = 50000
RESULT_PATH = "two_line_parity/result.pkl"


def _evaluate_lengths(model, lengths=EVAL_LENGTHS, trials=EVAL_TRIALS):
    """Average ``evaluate_two_line_parity`` over ``trials`` calls per length.

    Returns a dict mapping each length in ``lengths`` (in that order) to the
    mean of the values ``evaluate_two_line_parity`` returned for it
    (presumably an accuracy -- confirm against its definition).
    """
    return {
        length: sum(
            evaluate_two_line_parity(model, length=length) for _ in range(trials)
        ) / trials
        for length in lengths
    }


def _sample_batch(gen, size):
    """Draw ``size`` examples from ``gen`` and collate them into tensors.

    Each ``next(gen)`` is unpacked as ``(_, line1, line2, y, mask)`` and the
    first field is discarded. Returns ``(line1, line2, y, mask)`` as tensors:
    float inputs, int targets, float32 mask.
    """
    line1, line2, y, mask = [], [], [], []
    for _ in range(size):
        _, l1, l2, target, m = next(gen)
        line1.append(l1)
        line2.append(l2)
        y.append(target)
        mask.append(m)
    return (
        torch.tensor(line1, dtype=torch.float),
        torch.tensor(line2, dtype=torch.float),
        torch.tensor(y, dtype=torch.int),
        torch.tensor(mask, dtype=torch.float32),
    )


step = 0
loss_list = []
correct_list = []
while step < MAX_STEPS:
    step += 1
    if step % EVAL_INTERVAL == 0:
        correct = _evaluate_lengths(model)
        correct_list.append(correct)
        print("correct", correct)
        # Rewrite the whole history each time so the file is current even if
        # the run is interrupted; the context manager closes the handle.
        with open(RESULT_PATH, "wb") as f:
            pkl.dump(correct_list, f)

    # NOTE(review): in EVALUATION mode the weights never change, so this loop
    # repeats the same evaluation MAX_STEPS // EVAL_INTERVAL times (differing
    # only by whatever sampling evaluate_two_line_parity does internally).
    if EVALUATION:
        continue

    line1, line2, y, mask = _sample_batch(g, batch_size)
    _, _, loss = model(line1, line2, mask, y)
    loss = loss.mean()
    loss_list.append(loss.item())
    if step % 10 == 0:
        # Running mean of the last 10 recorded losses.
        print(f"step {step}, loss {sum(loss_list[-10:]) / 10}")

    optim.zero_grad()
    loss.backward()
    optim.step()


# Persist the trained weights only after an actual training run; evaluation
# runs read their weights from this same file and must not overwrite it.
if not EVALUATION:
    torch.save({"model": model.state_dict()}, "two_line_parity/model.pth")
