import asyncio
import ast

import torch
import yaml
import numpy as np

from perception import pseudo_label_pairs_probs
from abduce_sort import gen_vpairs_and_exlist, gen_dyadic_prob_facts, run_pl

from data import load_mnist


async def eval_model_dyadic(i, model, learned_prog, pl_file_dir,
                            all_examples, imgs_data):
    """Build and run the Prolog program that predicts example ``i``'s output.

    The perception model's pairwise probabilities for example ``i`` are turned
    into probabilistic facts.  These facts, the variable mapping, and a clause
    ``a`` are written to ``<pl_file_dir><i>_bk.pl``.  The clause consults
    ``learned_prog`` and prints ``Y`` for ``f(<x_idxs>, Y)``.  That file is
    then handed to ``run_pl``.

    Every 1000th example prints a progress marker on one line:
    ``<i>!`` when the error code is 0, ``[<i>]`` otherwise.

    Note: ``pl_file_dir`` is used as a plain string prefix, so it must end
    with a path separator.

    Returns:
        dict with ``'err'`` (error code from ``run_pl``; callers treat 0 as
        success) and ``'out'`` (the output text from ``run_pl``).
    """
    var_pairs, idx_pairs, _, mapping_str = gen_vpairs_and_exlist(
        [i], all_examples)
    # The second value (ground truths) is not needed for evaluation.
    probs, _unused_truths = pseudo_label_pairs_probs(
        model, idx_pairs, imgs_data)
    facts = gen_dyadic_prob_facts(var_pairs, idx_pairs, probs)

    bk_path = pl_file_dir + '{}_bk.pl'.format(i)
    x_idxs = str(all_examples[i].x_idxs)
    goal = ":-['{}'].\n\na:-f({},Y),writeln(Y).".format(learned_prog, x_idxs)

    with open(bk_path, 'w') as bk_file:
        bk_file.writelines((facts, mapping_str, goal))

    err_code, output = await run_pl(file_path=bk_path, timeout=60)

    if i % 1000 == 0:
        marker = "{}!".format(i) if err_code == 0 else "[{}]".format(i)
        print(marker, end=" ")

    return {'err': err_code, 'out': output}


async def safe_eval_model_dyadic(sem, i, model, learned_prog, pl_file_dir,
                                 all_examples, imgs_data):
    """Run ``eval_model_dyadic`` for example ``i`` while holding ``sem``.

    The semaphore caps how many Prolog evaluations are in flight at once.
    It is released even if the evaluation raises.
    """
    await sem.acquire()
    try:
        return await eval_model_dyadic(i, model, learned_prog, pl_file_dir,
                                       all_examples, imgs_data)
    finally:
        sem.release()


async def eval_model_dyadic_concurrent(sem, model, all_examples, imgs_data,
                                       learned_prog, pl_file_dir):
    """Evaluate every example concurrently, bounded by ``sem``.

    One task is scheduled per example index.  The returned list holds the
    per-example result dicts in index order, because ``asyncio.gather``
    preserves the order of its arguments.
    """
    pending = []
    for idx in range(len(all_examples)):
        job = safe_eval_model_dyadic(sem, idx, model, learned_prog,
                                     pl_file_dir, all_examples, imgs_data)
        # Wrapping in a task starts the coroutine right away.
        pending.append(asyncio.ensure_future(job))
    results = await asyncio.gather(*pending)
    return results


def evaluate_model_dyadic(loop, sem, model, all_examples, imgs_data,
                          learned_prog, pl_file_dir):
    """Run the learned program on every example and score the predictions.

    An example counts as completely wrong in two cases:
    - its Prolog run returned a non-zero error code, or
    - its answer does not have the same shape as the target.

    Args:
        loop: event loop used to drive the concurrent evaluation.
        sem: asyncio semaphore bounding concurrent Prolog runs.
        model: perception model handed to the per-example evaluation.
        all_examples: examples exposing ``.x_idxs`` and the target ``.y``.
        imgs_data: image data the examples index into.
        learned_prog: path of the learned Prolog program to consult.
        pl_file_dir: prefix (ending in a path separator) for the generated
            ``.pl`` files.

    Returns:
        ``(acc1, acc2)``:
        - ``acc1``: fraction of examples whose whole predicted output
          equals ``y``.
        - ``acc2``: fraction of output positions predicted correctly over
          all examples.
        Both are 0.0 when there are no examples.
    """
    try:
        tasks_results = loop.run_until_complete(
            eval_model_dyadic_concurrent(sem, model, all_examples, imgs_data,
                                         learned_prog, pl_file_dir)
        )
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())

    # Terminate the line of progress markers printed during evaluation.
    print()

    n = len(all_examples)
    if n == 0:
        return 0.0, 0.0

    acc1 = 0         # examples whose entire output is correct
    acc2 = 0         # individual output positions that are correct
    n_positions = 0  # total output positions over all examples
    # Result i belongs to example i (results are gathered in index order).
    # Every example contributes its target, so failures are counted as
    # wrong instead of being dropped or shifting the alignment.
    for ex, res in zip(all_examples, tasks_results):
        target = np.asarray(ex.y)
        n_positions += target.size
        if res['err'] != 0:
            continue
        pred = np.asarray(ast.literal_eval(res['out']))
        if pred.shape != target.shape:
            # A wrong-length answer cannot be compared position-wise.
            continue
        eq = np.equal(pred, target)
        if eq.all():
            acc1 += 1
        acc2 += int(eq.sum())

    return acc1 / n, (acc2 / n_positions if n_positions else 0.0)


def test_sort(test_file, p_model_path, imgs_data, loop, sem):
    """Evaluate a saved perception model on one sorting test set.

    Args:
        test_file: YAML file holding the list of test examples.
        p_model_path: path of the saved perception model.
        imgs_data: image data the examples index into.
        loop: event loop used to run the concurrent Prolog evaluation.
        sem: semaphore bounding concurrent Prolog runs.

    Returns:
        ``(acc1, acc2)`` as returned by ``evaluate_model_dyadic``.
    """
    with open(test_file, 'r') as f:
        # NOTE(review): full_load may construct Python objects from the
        # file; only use it on trusted, locally generated test sets.
        test_data = yaml.full_load(f)

    # To evaluate on a subset only:
    # test_data = test_data[0:1000]

    # The loaded object is used directly as the model, so the file is
    # presumably a whole pickled module rather than a state_dict.
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # such a module, so opt out explicitly (the file is trusted local
    # output).  torch < 1.13 does not accept the argument and raises
    # TypeError, in which case we fall back to the plain call.
    try:
        p_model = torch.load(p_model_path, weights_only=False)
    except TypeError:
        p_model = torch.load(p_model_path)

    # evaluate_model_dyadic performs the event-loop async-generator cleanup
    # itself, so no extra try/finally is needed here.
    acc1, acc2 = evaluate_model_dyadic(loop, sem, p_model, test_data,
                                       imgs_data,
                                       '../sort_learned.pl',
                                       '../prolog/tmp_sort_eval/')
    return (acc1, acc2)


def main():
    """Evaluate the saved sort model(s) on test sets of length 3, 5 and 7.

    For each model, print ``(acc1, acc2)`` per sequence length as it goes.
    At the end, print the collected results list for each length, in the
    order 3, 5, 7.
    """
    data_path = '../../data/dyadic/'
    task_name = 'sort'
    lengths = (3, 5, 7)
    test_files = {
        length: data_path + task_name + '_{}_test'.format(length) + '.yaml'
        for length in lengths
    }

    _, imgs_test = load_mnist()

    # Other trained models: models/sort_5_1.pt, models/sort_5_2.pt,
    # models/sort_5_4.pt, models/sort_5_5.pt
    p_models = ['models/sort_5_3.pt']

    # asyncio.get_event_loop() without a running loop is deprecated since
    # Python 3.12 and raises RuntimeError in 3.14, so create and install a
    # loop explicitly.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # Create the semaphore after installing the loop: on Python < 3.10 it
    # binds to the current event loop at construction time.
    sem = asyncio.Semaphore(16)

    results = {length: [] for length in lengths}

    try:
        for p_model in p_models:
            print('Testing model {}:'.format(p_model))
            for length in lengths:
                print('Length {}:'.format(length), end=" ")
                accs = test_sort(test_files[length], p_model, imgs_test,
                                 loop, sem)
                print(accs)
                results[length].append(accs)
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
        loop.close()

    for length in lengths:
        print(results[length])


if __name__ == "__main__":
    main()
