import torch
import pickle as pkl

from multiplication.model import ParseModel
from multiplication.data import Generator, token_ids, evaluate_multiplication


import ray
# Initialise Ray before anything is submitted: the periodic evaluation below
# dispatches evaluate_multiplication as Ray remote tasks.
ray.init()


# When True, the main loop only launches periodic evaluations of the loaded
# checkpoint and skips all optimizer steps (and the final save).
EVALUATION = True


batch_size = 512
g = Generator(length=5)  # source of training samples, drawn with next(g)
model = ParseModel(len(token_ids))
optim = torch.optim.Adam(model.parameters(), lr=5e-6)

# The checkpoint is loaded in both modes: evaluation scores it, and training
# resumes from it. The model in this script is never moved off the CPU, so map
# the stored tensors there; this also lets a checkpoint written on a GPU host
# be loaded on a CPU-only machine.
ckpt = torch.load("multiplication/model.pth", map_location="cpu")
model.load_state_dict(ckpt["model"])


# Periodic evaluation settings. Every EVAL_EVERY steps the current model is
# scored on each length in EVAL_LENGTHS by EVAL_RUNS independent Ray tasks;
# per-length means are appended to a history that is re-pickled to RESULT_PATH.
EVAL_LENGTHS = (5, 6, 7, 8, 9, 10, 15, 20, 25)
EVAL_RUNS = 10
EVAL_EVERY = 1000
NUM_STEPS = 50000
RESULT_PATH = "multiplication/result.pkl"
# Each sample produced by the generator is 12 input fields, 7 target output
# fields, then a mask, in that order.
NUM_INPUTS = 12
NUM_OUTPUTS = 7

# Wrap the evaluation function once rather than on every task submission.
_remote_evaluate = ray.remote(evaluate_multiplication)


def _launch_evaluation(model):
    """Submit one evaluation round for the current state of ``model``.

    Returns a dict mapping each length in EVAL_LENGTHS to a list of EVAL_RUNS
    Ray object refs, in the same submission order as before.
    """
    # Serialise the model into the object store once per round instead of
    # once per task; Ray hands the stored value to tasks that receive the ref.
    model_ref = ray.put(model)
    return {
        length: [_remote_evaluate.remote(model_ref, length=length) for _ in range(EVAL_RUNS)]
        for length in EVAL_LENGTHS
    }


def _record_evaluation(pending, history):
    """Block on an evaluation round, store its per-length mean, and persist history.

    ``pending`` is a dict as returned by ``_launch_evaluation``. The averaged
    dict is appended to ``history``, printed, and the full history is written
    to RESULT_PATH. Returns the averaged dict.
    """
    result = {length: sum(ray.get(refs)) / len(refs) for length, refs in pending.items()}
    history.append(result)
    print("correct", result)
    # Context manager so the file is flushed and closed after every save.
    with open(RESULT_PATH, "wb") as f:
        pkl.dump(history, f)
    return result


def _next_batch(generator, size):
    """Draw ``size`` samples from ``generator`` and stack each field into a float tensor.

    Returns ``(inputs, outputs, mask)``: a list of NUM_INPUTS tensors, a list
    of NUM_OUTPUTS tensors and the mask tensor, each with leading dimension
    ``size``.

    Raises ValueError if a sample does not have exactly
    NUM_INPUTS + NUM_OUTPUTS + 1 fields.
    """
    samples = [next(generator) for _ in range(size)]
    # zip(*samples) regroups per-sample tuples into per-field columns.
    fields = [torch.tensor(column, dtype=torch.float) for column in zip(*samples)]
    expected = NUM_INPUTS + NUM_OUTPUTS + 1
    if len(fields) != expected:
        raise ValueError(f"expected {expected} fields per sample, got {len(fields)}")
    return fields[:NUM_INPUTS], fields[NUM_INPUTS:-1], fields[-1]


step = 0
loss_list = list()
correct_list = list()
correct = None  # the in-flight evaluation round (dict of Ray refs), or None
while step < NUM_STEPS:
    step += 1
    if step % EVAL_EVERY == 0:
        # Collect the previous round before launching the next one, so at most
        # one round is in flight. In EVALUATION mode no training happens, so
        # every round scores the same checkpoint.
        if correct is not None:
            _record_evaluation(correct, correct_list)
        correct = _launch_evaluation(model)

    if EVALUATION:
        continue

    inputs, outputs, mask = _next_batch(g, batch_size)
    # The model returns 15 values (seven ``out*``, seven ``dist*``, then the
    # loss); only the loss is needed for training.
    *_, loss = model(*inputs, mask, *outputs)
    loss = loss.mean()
    loss_list.append(loss.item())
    if step % 10 == 0:
        print(f"step {step}, loss {sum(loss_list[-10:]) / 10}")

    optim.zero_grad()
    loss.backward()
    optim.step()

# The round launched on the last evaluation step is still pending; record it
# instead of letting it be dropped when the script exits.
if correct is not None:
    _record_evaluation(correct, correct_list)


# Persist the weights only after a training run; evaluation runs never
# overwrite the checkpoint on disk.
if not EVALUATION:
    checkpoint = dict(model=model.state_dict())
    torch.save(checkpoint, "multiplication/model.pth")
