import numpy as np
import torch
import pickle as pkl

from one_line_addition.model import ParseModel
from one_line_addition.data import Generator, token_ids, evaluate_one_line_addition


# Run configuration shared by the training loop below.
# Number of samples pulled from the generator for every optimisation step.
batch_size = 256
# Sample source; each next(g) is unpacked into (input, output, mask).
g = Generator(length=7)
# The model is sized by the number of entries in token_ids.
model = ParseModel(len(token_ids))
# Adam over all model parameters with a fixed learning rate.
optim = torch.optim.Adam(model.parameters(), lr=1e-4)


# Training loop: draws `batch_size` samples from `g` per step, optimises
# `model` on the mean loss, prints a 10-step running loss, and every 100
# steps evaluates the model and rewrites the result pickle.

# Lengths passed to evaluate_one_line_addition at each evaluation point
# (5 and 6 are currently disabled).
eval_lengths = (7, 8, 9, 10, 15, 20)

loss_list = list()     # per-step mean loss, one float per step
correct_list = list()  # one {length: averaged score} dict per evaluation
for step in range(1, 100000 + 1):
    # break  # debug: uncomment to skip training entirely

    # Collect one batch; every sample unpacks into (input, output, mask).
    samples = [next(g) for _ in range(batch_size)]
    input0, output0, mask = (list(column) for column in zip(*samples))
    input0 = torch.tensor(input0, dtype=torch.float)
    output0 = torch.tensor(output0, dtype=torch.float)
    mask = torch.tensor(mask, dtype=torch.float32)

    # The model returns three values; only the loss is used here.
    out0, dist0, loss = model(input0, mask, output0)
    loss = loss.mean()
    total_loss = loss  # the quantity that is back-propagated below
    loss_list.append(loss.item())
    if step % 10 == 0:
        # Average over the ten most recent steps.
        print(f"step {step}, loss {sum(loss_list[-10:]) / 10}")

    optim.zero_grad()
    total_loss.backward()
    optim.step()

    if step % 100 == 0:
        # Each score is the mean of 10 independent evaluation calls.
        # NOTE(review): evaluation runs with the model in whatever mode it is
        # in during training and with autograd enabled, unless
        # evaluate_one_line_addition handles that itself -- confirm.
        correct = {
            length: 0.1 * sum([evaluate_one_line_addition(model, length=length) for _ in range(10)])
            for length in eval_lengths
        }
        correct_list.append(correct)
        print("correct", correct)
        # Rewrite the full history each time; the context manager guarantees
        # the file is flushed and closed instead of leaking the handle.
        with open("one_line_addition/result.pkl", "wb") as result_file:
            pkl.dump(correct_list, result_file)

# Persist the final weights; the state dict is stored under the "model" key.
checkpoint_path = "one_line_addition/model.pth"
torch.save({"model": model.state_dict()}, checkpoint_path)


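# Disabled: uncomment the lines below to load previously saved weights
# (stored under the "model" key) back into `model`.
# ckpt = torch.load("one_line_addition/model.pth")
# toload_state_dict = ckpt["model"]
# model.load_state_dict(toload_state_dict)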

