import numpy
import torch
from utils import Timer, get_device
from setting import save_model_state_dict, load_bats_for_train
from wv_wrapper import (BaseWord2VecWrapper, IdentityWord2VecWrapper,
                        MLPConcatWord2VecWrapper, MLPWord2VecWrapper, AGNWord2VecWrapper)


def test_word_analogies(wv, dic):  # dic: {category: [(a, b, c, d), ...]}
    correct_acc, all_acc = 0, 0
    for category, examples in dic.items():
        correct, all = 0, 0
        for a, b, c, d in examples:
            if a not in wv or b not in wv or c not in wv or d not in wv:
                continue
            correct += d == wv.most_similar(positive=[b, c], negative=[a], topn=1)[0][0]  # [(word, similarity)]
            all += 1
        print(f"{category}:\n{correct / all} ({correct} / {all})", flush=True)
        correct_acc += correct
        all_acc += all
    print(f"[overall] {correct_acc / all_acc} ({correct_acc} / {all_acc})", flush=True)


def test(model, data_ds, batch_size=64, num=-1):
    """Evaluate ``model`` on a random subset of ``data_ds``, print and return accuracy.

    Args:
        model: Wrapper exposing ``eval_data((a, b, c, ds))``; its return value
            is treated as the number of correct predictions (presumably an int
            count — confirm against the wrapper implementation).
        data_ds: Tuple ``(a_batch, b_batch, c_batch, ds_batch)``; each element
            must support fancy indexing with an integer index array and is
            assumed to have the same length as ``a_batch``.
        batch_size: Currently unused; kept for backward compatibility.
        num: Number of randomly chosen examples to evaluate; ``-1`` means all.
            Values larger than the dataset are clamped by slicing.

    Returns:
        ``correct / evaluated``, or ``nan`` if there was nothing to evaluate.
    """
    timer = Timer()
    a_batch, b_batch, c_batch, ds_batch = data_ds
    size = len(a_batch)
    if num == -1:
        num = size
    perm = numpy.random.permutation(size)[:num]
    # Report the number of examples actually evaluated, not the requested ``num``
    # (the slice above silently clamps ``num`` to the dataset size).
    evaluated = len(perm)
    if evaluated == 0:
        print("nan (0 / 0)", flush=True)
        print(timer.stop(), flush=True)
        return float("nan")

    # Evaluate without building an autograd graph and with dropout/batch-norm
    # style layers in inference mode; restore the caller's mode afterwards.
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()
    try:
        with torch.no_grad():
            correct = model.eval_data((a_batch[perm], b_batch[perm], c_batch[perm], ds_batch[perm]))
    finally:
        if was_training:
            model.train()

    print(f"{correct / evaluated} ({correct} / {evaluated})", flush=True)
    print(timer.stop(), flush=True)
    return correct / evaluated


def train(model, train_abcd, train_ds, valid_ds, batch_size=32, epoch_num=100, learning_rate=1e-3,
          weight_decay=1e-4, verbose=True):
    """Train ``model`` with Adam, save its state dict, and return validation accuracy.

    The loss is the negative mean of ``model.batch_similarity`` between the
    model's prediction for ``(a, b, c)`` and ``model.normalized_vectors[d]``,
    so training maximizes that similarity.

    Args:
        model: Wrapper module exposing ``forward_by_index``, ``batch_similarity``,
            ``normalized_vectors`` and ``eval_data``.
        train_abcd: Tuple ``(ta, tb, tc, td)`` of index arrays used for training.
        train_ds: Training split in the format accepted by ``test`` (for logging).
        valid_ds: Validation split in the format accepted by ``test``.
        batch_size: Target mini-batch size. ``numpy.array_split`` divides each
            epoch into ``ceil(len(ta) / batch_size)`` near-equal batches.
        epoch_num: Number of passes over the training data.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay (L2 penalty).
        verbose: If true, evaluate on up to 1000 train/valid examples every
            10 epochs, before that epoch's updates.

    Returns:
        Accuracy returned by ``test(model, valid_ds)`` after training.

    Raises:
        ValueError: If ``train_abcd`` is empty and ``epoch_num > 0``.
    """
    print(model, flush=True)
    ta, tb, tc, td = train_abcd
    sample_num = len(ta)
    if epoch_num > 0 and sample_num == 0:
        raise ValueError("train_abcd must contain at least one training example")
    batch_num = (sample_num + batch_size - 1) // batch_size  # ceil division
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    for epoch in range(epoch_num):
        # periodic progress report
        if verbose and epoch % 10 == 0:
            print(f"epoch {epoch}", flush=True)
            print("train:", flush=True)
            test(model, train_ds, num=1000)
            print("test:", flush=True)
            test(model, valid_ds, num=1000)

        model.train()
        epoch_loss = 0.0
        perm = numpy.random.permutation(sample_num)
        timer = Timer()
        for index in numpy.array_split(perm, batch_num):
            pred = model.forward_by_index(ta[index], tb[index], tc[index])
            loss = -model.batch_similarity(pred, model.normalized_vectors[td[index]]).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Detach so the running sum does not chain every batch's autograd
            # graph together for the whole epoch (a memory leak otherwise).
            epoch_loss += loss.detach() * len(index)
        print(f"mean loss {float(epoch_loss) / sample_num}", flush=True)
        print(timer.stop(), flush=True)
    save_model_state_dict(model)
    return test(model, valid_ds)


if __name__ == "__main__":
    # Other wrappers can be swapped in here to compare architectures, e.g.:
    #   BaseWord2VecWrapper()
    #   IdentityWord2VecWrapper().to(get_device())
    #   MLPWord2VecWrapper(hidden_dim=32, layer_num=3).to(get_device())
    #   MLPConcatWord2VecWrapper(hidden_dim=32, layer_num=3).to(get_device())
    model = AGNWord2VecWrapper(hidden_dim=64, layer_num=4)
    model = model.to(get_device())

    # test_ds is loaded only to report its size; it is not used for training.
    train_abcd, train_ds, valid_ds, test_ds = load_bats_for_train(model)
    print(*(len(split[0]) for split in (train_abcd, valid_ds, test_ds)))

    train(model, train_abcd, train_ds, valid_ds, batch_size=32, epoch_num=1)
