import asyncio
import random
import itertools
import operator
import torch
from abduce_sort import run_pl, parse_pl_result_dyadic, perception_to_kb_1, perception_to_kb_2
from perception import train, test
from data import DyadicDataset


# get the indices of the "sorted" task
def get_task1(all_examples):
    """Return the positions of the examples whose ``srtd`` attribute is 1.

    Args:
        all_examples: iterable of objects exposing a ``srtd`` attribute.

    Returns:
        list[int]: positions (in iteration order) where ``srtd == 1``.
    """
    return [position
            for position, example in enumerate(all_examples)
            if example.srtd == 1]


# Sample batches
def sample_batches(all_examples, batch_size, shuffle=True):
    """Split the example positions into consecutive fixed-size batches.

    Args:
        all_examples: sized collection; only its length is used.
        batch_size: number of positions per batch. The final batch holds
            the remainder when the length is not a multiple of it.
        shuffle: when true, the positions are shuffled (via the global
            ``random`` state) before being cut into batches.

    Returns:
        list[list[int]]: batches of positions into ``all_examples``.
    """
    total = len(all_examples)
    # Ceiling division: one extra batch for a positive remainder.
    whole, leftover = divmod(total, batch_size)
    batch_count = whole + 1 if leftover > 0 else whole

    order = list(range(total))
    if shuffle:
        random.shuffle(order)

    return [order[k * batch_size:(k + 1) * batch_size]
            for k in range(batch_count)]


# Sample curriculum
def sample_curriculums(all_examples, batch_vars, n_batches=500, shuffle=True):
    """Group examples into batches holding roughly ``batch_vars`` variables.

    The size of an example is ``len(example.x_idxs)``. Examples are visited
    in (optionally shuffled) order; each unused example seeds a batch:

    * if it is 0 or 1 variables short of ``batch_vars`` it forms a batch on
      its own;
    * if it is at least 2 variables short, the following unused examples
      are appended one by one while the running total stays within
      ``batch_vars``; filling stops at the first unused example that does
      not fit (or once the total reaches ``batch_vars``);
    * if it is larger than ``batch_vars`` it is consumed but produces no
      batch.

    Progress is printed to stdout.

    Args:
        all_examples: sequence of objects exposing ``x_idxs`` (sized).
        batch_vars: target number of variables per batch.
        n_batches: maximum number of batches to produce.
        shuffle: when true, visit the examples in a random order (global
            ``random`` state).

    Returns:
        list[list[int]]: batches of positions into ``all_examples``.
    """
    print("Sampling batches with {} variables".format(batch_vars), end="...")
    n_samples = len(all_examples)
    batches = []

    indices = list(range(n_samples))
    if shuffle:
        random.shuffle(indices)

    # A set keeps the "already used" test O(1); the original list made every
    # membership probe linear in the number of consumed examples.
    used = set()
    i = 0
    while i < n_samples and len(batches) < n_batches:
        if i % 1000 == 0:
            print(i, end="...")
        idx = indices[i]
        i += 1
        if idx in used:
            continue
        used.add(idx)

        n_vars = len(all_examples[idx].x_idxs)
        gap = batch_vars - n_vars
        if 0 <= gap <= 1:
            batches.append([idx])
        elif gap >= 2:
            batch = [idx]
            # Every position before ``i`` has already been consumed (each is
            # added to ``used`` when the outer loop visits it), so scanning
            # from ``i`` selects the same candidates as scanning from 0.
            for j in range(i, n_samples):
                if n_vars >= batch_vars:
                    break
                candidate = indices[j]
                if candidate in used:
                    continue
                candidate_vars = len(all_examples[candidate].x_idxs)
                if n_vars + candidate_vars > batch_vars:
                    # First unused candidate does not fit: close the batch.
                    break
                batch.append(candidate)
                used.add(candidate)
                n_vars += candidate_vars
            batches.append(batch)
        # gap < 0: the example alone exceeds batch_vars and is dropped.
    print("done ({} batches).".format(len(batches)))
    return batches


async def abduce_coroutine_dyadic(i, model, samples, all_examples,
                                  train_img_data,
                                  bk_file,
                                  pl_file_dir, timeout,
                                  task=1):
    """Run one abduction job for batch ``i`` and report its outcome.

    The job's Prolog file path is ``pl_file_dir + '<i>_bk.pl'``. For
    ``task`` 1 or 2 the matching ``perception_to_kb_*`` helper is called
    with that path (presumably it writes the knowledge base there - confirm
    against ``abduce_sort``). For any other ``task`` value no helper is
    called and ``run_pl`` is still awaited on the same path.

    A marker is printed per job: ``[i]`` when the error code is non-zero,
    ``i!`` otherwise.

    Args:
        i: batch number, used for the file name and the progress marker.
        model, samples, all_examples, train_img_data: forwarded unchanged
            to the ``perception_to_kb_*`` helper.
        bk_file: forwarded as ``bk_file`` to the helper.
        pl_file_dir: prefix prepended to the generated file name (plain
            string concatenation, so it should end with a separator).
        timeout: forwarded to ``run_pl``.
        task: 1 or 2, selecting the helper.

    Returns:
        dict: ``{'err': <error code>, 'out': <output>}`` as given by
        ``run_pl``.
    """
    pl_file_path = pl_file_dir + '{}_bk.pl'.format(i)

    # Pick the knowledge-base writer for the task, if there is one.
    if task == 1:
        write_kb = perception_to_kb_1
    elif task == 2:
        write_kb = perception_to_kb_2
    else:
        write_kb = None

    if write_kb is not None:
        write_kb(model, samples, all_examples, train_img_data,
                 bk_file=bk_file,
                 path=pl_file_path)

    pl_err, pl_out = await run_pl(file_path=pl_file_path, timeout=timeout)

    marker = "[{}]" if pl_err != 0 else "{}!"
    print(marker.format(i), end=" ")
    return {'err': pl_err, 'out': pl_out}


async def safe_abduce_dyadic(i, sem, model, samples, all_examples,
                             train_img_data,
                             bk_file,
                             pl_file_dir, timeout, task=1):
    """Run ``abduce_coroutine_dyadic`` while holding ``sem``.

    Args:
        i: batch number forwarded to the abduction coroutine.
        sem: async context manager (e.g. an ``asyncio.Semaphore``) acquired
            for the duration of the job.
        model, samples, all_examples, train_img_data, bk_file, pl_file_dir,
        timeout, task: forwarded unchanged to ``abduce_coroutine_dyadic``.

    Returns:
        Whatever ``abduce_coroutine_dyadic`` returns for this batch.
    """
    # Holding the semaphore caps how many abduction jobs run at once.
    async with sem:
        outcome = await abduce_coroutine_dyadic(
            i, model, samples, all_examples,
            train_img_data,
            bk_file, pl_file_dir,
            timeout, task=task)
        return outcome


async def abduce_concurrent_dyadic(sem, batches, model, all_examples,
                                   train_img_data, bk_file,
                                   pl_file_dir, timeout,
                                   task=1):
    """Schedule one abduction job per batch and wait for all of them.

    Args:
        sem: semaphore handed to every job to limit concurrency.
        batches: iterable of batches; the position of each batch is used
            as its job number.
        model, all_examples, train_img_data, bk_file, pl_file_dir, timeout,
        task: forwarded unchanged to ``safe_abduce_dyadic``.

    Returns:
        list: the jobs' results, in the same order as ``batches``.
    """
    scheduled = []
    for job_no, samples in enumerate(batches):
        job = safe_abduce_dyadic(job_no, sem, model, samples,
                                 all_examples, train_img_data,
                                 bk_file,
                                 pl_file_dir, timeout,
                                 task=task)
        # Wrapping the coroutine in a future schedules it on the loop.
        scheduled.append(asyncio.ensure_future(job))
    # Resumes once every scheduled job has finished.
    return await asyncio.gather(*scheduled)


# Learn neural net with metagol induction
def train_abduce_concurrent_dyadic(loop, sem, batches, model, optimizer,
                                   scheduler, all_examples,
                                   train_img_data, test_dyadic_data,
                                   bk_file='../sort1_bk.pl',
                                   pl_file_dir='../prolog/tmp/',
                                   timeout=10, device=torch.device("cuda"),
                                   epochs=10, log_interval=500, task=1,
                                   **kwargs):
    """Abduce pairwise labels for ``batches`` and train ``model`` on them.

    Steps:
      1. Run ``abduce_concurrent_dyadic`` on ``loop`` (one job per batch).
      2. For every job whose error code is 0, parse its output with
         ``parse_pl_result_dyadic`` and collect the program string, the
         image pairs and the abduced targets. Failed jobs are skipped.
      3. Print the most frequent program, the number of successful
         batches, and the agreement between abduced targets and the
         ground truth derived from ``train_img_data.targets``
         (target 1 when the first image's label is greater than the
         second's, else 0).
      4. Build a ``DyadicDataset`` from the abduced pairs and run
         ``epochs`` rounds of ``train`` / ``test`` / ``scheduler.step()``.

    If no pair could be abduced, the summary is printed and training is
    skipped (there is nothing to train on).

    Args:
        loop: asyncio event loop used to drive the abduction jobs. Its
            async generators are shut down after abduction, whether or
            not abduction raised.
        sem: semaphore forwarded to the abduction jobs.
        batches: sized sequence of batches, one abduction job each.
        model, optimizer, scheduler: forwarded to ``train`` / ``test``;
            ``scheduler.step()`` is called once per epoch.
        all_examples, train_img_data, bk_file, pl_file_dir, timeout, task:
            forwarded to ``abduce_concurrent_dyadic``. ``train_img_data``
            must also expose ``targets`` indexable by the pair elements.
        test_dyadic_data: dataset wrapped in the evaluation DataLoader.
        device: device forwarded to ``train`` / ``test``.
        epochs: number of training epochs (numbered 1..``epochs``).
        log_interval: forwarded to ``train``.
        **kwargs: forwarded to both ``torch.utils.data.DataLoader`` calls.

    Returns:
        list[dict]: the raw per-batch abduction results
        (``{'err': ..., 'out': ...}``), in batch order.
    """
    # Start concurrent abduction
    try:
        tasks_results = loop.run_until_complete(
            abduce_concurrent_dyadic(sem, batches, model,
                                     all_examples, train_img_data,
                                     bk_file,
                                     pl_file_dir, timeout,
                                     task=task)
        )
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())

    train_pairs = []    # image pairs for updating the neural model
    train_targets = []  # logically abduced labels, aligned with train_pairs
    progs = []
    succ = 0
    for result in tasks_results:
        if result['err'] != 0:
            # if anything went wrong, just skip this batch
            continue
        succ += 1
        prog_str, pairs, targets = parse_pl_result_dyadic(result['out'])
        train_pairs.extend(pairs)
        train_targets.extend(targets)
        progs.append(prog_str)

    # Compare the abduced targets with the labels of the underlying images
    acc = 0
    for i, p in enumerate(train_pairs):
        first_label = train_img_data.targets[p[0]]
        second_label = train_img_data.targets[p[1]]
        ground_truth = 1 if first_label > second_label else 0
        if train_targets[i] == ground_truth:
            acc += 1

    print("\nMost Frequent Program:")
    if progs:
        # most_common() cannot handle an empty list
        print(most_common(progs), end="")
    print("Successfully abduced batches: {}/{}".format(succ, len(batches)))
    if not train_pairs:
        # Avoid dividing by zero and training on an empty dataset.
        print("Abduced label Acc: n/a (no labels abduced)")
        return tasks_results
    print("Abduced label Acc: {}".format(acc/len(train_pairs)))

    abduced_dataset = DyadicDataset(train_pairs, train_targets, train_img_data)
    sup_train_loader = torch.utils.data.DataLoader(abduced_dataset, **kwargs)
    sup_test_loader = torch.utils.data.DataLoader(test_dyadic_data, **kwargs)

    # Update neural model; the upper bound is inclusive so that exactly
    # ``epochs`` passes are run (range(1, epochs) ran one pass too few).
    for epoch in range(1, epochs + 1):
        train(model, device, sup_train_loader,
              optimizer, epoch, log_interval)
        test(model, device, sup_test_loader)
        scheduler.step()
    return tasks_results


def most_common(L):
    """Return the most frequent item of ``L``.

    Ties are broken in favour of the item whose first occurrence comes
    earliest in ``L``. Items must be mutually orderable because the
    implementation sorts them to bring equal items together.

    Args:
        L: non-empty sequence of orderable items.

    Returns:
        The item with the highest count (earliest first occurrence on ties).

    Raises:
        ValueError: if ``L`` is empty.
    """
    # Tag every item with its position, then sort so equal items are adjacent.
    tagged = sorted((value, position) for position, value in enumerate(L))
    runs = itertools.groupby(tagged, key=operator.itemgetter(0))

    def _rank(run):
        # Rank a run by (occurrences, earliest position); the position is
        # negated so that an earlier first occurrence compares as larger.
        _value, members = run
        positions = [position for _, position in members]
        return len(positions), -min(positions)

    winner, _ = max(runs, key=_rank)
    return winner
