import asyncio
import os

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import yaml

from learn_sort import get_task1, sample_curriculums, train_abduce_concurrent_dyadic
from data import load_mnist, DyadicDataset
from perception import PairCMLP


def main():
    """Train a ``PairCMLP`` model on the ``sort_5`` task with abductive learning.

    Pipeline:
      1. Load the first 3000 examples from ``<data_path>/sort_5.yaml`` and
         the held-out pairs from ``<data_path>/test_pairs.yaml``.
      2. Build ``PairCMLP(10, 2)`` on the CUDA device (a GPU is required).
      3. Run ``EPOCHS`` rounds of ``train_abduce_concurrent_dyadic``.
         - The first ``EPOCHS_STAGE1`` rounds are task 1 (learn "sorted").
           They use only the examples selected by ``get_task1`` and the
           background file ``../sort1_bk.pl``.
         - The remaining rounds are task 2 (learn "sorting"). They use all
           examples and ``../sort2_bk.pl``.
      4. After every round, save the whole model object to
         ``models/<task_name>.pt``.

    This function creates the asyncio event loop and semaphore it hands to
    the trainer, and always closes the loop on exit.
    """
    task_name = 'sort_5'
    data_path = '../data/dyadic/'
    tmp_bk_dir = '../prolog/tmp_pl/'
    model_dir = 'models'
    file_name = data_path + task_name + '.yaml'

    # yaml.full_load is only acceptable because these are local, trusted
    # data files; do not point it at untrusted input.
    with open(file_name, 'r', encoding='utf-8') as f:
        all_examples = yaml.full_load(f)

    examples = all_examples[0:3000]

    imgs_train, imgs_test = load_mnist()

    with open(data_path + 'test_pairs.yaml', 'r', encoding='utf-8') as f:
        test_dydata_dict = yaml.full_load(f)
    test_dydata = DyadicDataset(test_dydata_dict['pairs'],
                                test_dydata_dict['targets'],
                                imgs_test)

    device = torch.device("cuda")
    p_model = PairCMLP(10, 2).to(device)
    print(p_model)
    # Extra keyword arguments forwarded to train_abduce_concurrent_dyadic.
    # They look like DataLoader options; confirm against learn_sort.
    kwargs = {'batch_size': 64,
              'num_workers': 8,
              'pin_memory': True,
              'shuffle': True}

    # Learning with abduction
    EPOCHS = 50
    EPOCHS_STAGE1 = 5     # epochs [0, EPOCHS_STAGE1) train task 1, rest task 2
    N_BATCHES = 3000
    N_CORES = 20          # concurrency limit of the semaphore given to the trainer
    NN_EPOCHS = 10
    LR = 1e-4
    GAMMA = 1.0           # StepLR decay factor; 1.0 keeps the LR constant
    LOG_INTERVAL = 500

    optimizer = optim.Adam(p_model.parameters(), lr=LR)
    scheduler = StepLR(optimizer, step_size=1, gamma=GAMMA)

    # Create the checkpoint directory up front. Otherwise a missing directory
    # would only fail torch.save after a full (expensive) training epoch.
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, task_name + '.pt')

    # Create and install a loop explicitly: asyncio.get_event_loop() without
    # a running loop is deprecated (3.10+) and raises on newer Pythons.
    # Create the semaphore only after set_event_loop, so that on Python < 3.10
    # it binds to this loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    sem = asyncio.Semaphore(N_CORES)

    try:
        # Inside the try so a failure here still closes the loop.
        task1_examples = [examples[i] for i in get_task1(examples)]

        for T in range(EPOCHS):
            if T < EPOCHS_STAGE1:
                # learn "sorted"
                tsk, exp, bk_file = 1, task1_examples, '../sort1_bk.pl'
            else:
                # learn "sorting"
                tsk, exp, bk_file = 2, examples, '../sort2_bk.pl'

            print("======\nEpoch {}: Task {}\n======".format(T, tsk))
            batches = sample_curriculums(
                exp, 5, n_batches=N_BATCHES, shuffle=True)

            train_abduce_concurrent_dyadic(loop, sem, batches,
                                           p_model, optimizer,
                                           scheduler, exp,
                                           imgs_train, test_dydata,
                                           bk_file=bk_file,
                                           pl_file_dir=tmp_bk_dir,
                                           epochs=NN_EPOCHS,
                                           timeout=30,  # units per learn_sort; presumably seconds
                                           log_interval=LOG_INTERVAL,
                                           task=tsk,
                                           **kwargs)

            # Pickles the whole module, not just its state_dict, so loading
            # requires the PairCMLP class to be importable.
            torch.save(p_model, model_path)
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
        asyncio.set_event_loop(None)
        loop.close()


if __name__ == '__main__':
    # Start training only when run as a script; importing this module is side-effect free.
    main()
