import ast
import asyncio
import os

from perception import pseudo_label_pairs_probs


# Perceive the images of each sampled example and convert the result into a
# Prolog background-knowledge file.
def perception_to_kb_1(model, sample_indices,
                       all_examples, all_imgs_data,
                       bk_file='../sort1_bk.pl',
                       path='../prolog/tmp_bk.pl',
                       neg_sample_str="[s([9,8,7,6,5,4,3,1,2])]"):
    """Write the Prolog file for the monadic task (examples ``s([...])``).

    The generated file consults ``bk_file`` and then contains, in order:
    the pairwise probabilistic ``nn/3`` facts computed from ``model``, the
    ``id/2`` facts mapping each variable name to its image index, and the
    query clause ``a :- Pos=..., Neg=..., metaabd(Pos,Neg).``

    Args:
        model: perception model handed to ``pseudo_label_pairs_probs``.
        sample_indices: indices into ``all_examples`` of the examples to use.
        all_examples: example objects (read by ``gen_vpairs_and_exlist``).
        all_imgs_data: image data handed to ``pseudo_label_pairs_probs``.
        bk_file: background-knowledge file consulted in the header.
        path: output path of the generated Prolog file (overwritten).
        neg_sample_str: Prolog list of negative examples put into the query.
            The default is a single fixed, hand-picked unsorted sequence.

    Returns:
        ``(var_pairs, idx_pairs)``: per-example lists of variable-name pairs
        and the matching image-index pairs.
    """
    header = ":- ['{}'].\n\n".format(bk_file)
    # 1. variable pairs, image-index pairs, positive examples and id/2 facts
    var_pairs, idx_pairs, pos_sample_str, mapping_str = gen_vpairs_and_exlist(
        sample_indices, all_examples, monadic=True)
    # 2. p(z|x) for every image pair; the ground truths also returned by the
    #    perception helper are not needed to build the knowledge base
    prob_dist, _ = pseudo_label_pairs_probs(model, idx_pairs, all_imgs_data)
    # 3. one "nn('Xi','Xj',p).\n" fact per variable pair
    prob_facts = gen_dyadic_prob_facts(var_pairs, idx_pairs, prob_dist)
    # query clause "a :- Pos=..., Neg=..., metaabd(Pos,Neg)."
    query_str = "\na :- Pos={}, Neg={}, metaabd(Pos,Neg).\n".format(
        pos_sample_str, neg_sample_str)
    # explicit encoding so the file does not depend on the platform locale
    with open(path, 'w', encoding='utf-8') as pl:
        pl.write(header + prob_facts + mapping_str + query_str)
    return var_pairs, idx_pairs


def perception_to_kb_2(model, sample_indices,
                       all_examples, all_imgs_data,
                       bk_file='../sort2_bk.pl',
                       path='../prolog/tmp_bk.pl',
                       neg_sample_str="[f([9,8,7,6,4,5],[1,2,3,4,5,6]),f([1,2,3,4,5,6,7],[7,5,6,4,3,2,1])]"):
    """Write the Prolog file for the dyadic task (examples ``f([...],Y)``).

    The generated file consults ``bk_file`` and then contains, in order:
    the pairwise probabilistic ``nn/3`` facts computed from ``model``, the
    ``id/2`` facts mapping each variable name to its image index, and the
    query clause ``a :- Pos=..., Neg=..., metaabd(Pos,Neg).``

    Args:
        model: perception model handed to ``pseudo_label_pairs_probs``.
        sample_indices: indices into ``all_examples`` of the examples to use.
        all_examples: example objects (read by ``gen_vpairs_and_exlist``).
        all_imgs_data: image data handed to ``pseudo_label_pairs_probs``.
        bk_file: background-knowledge file consulted in the header.
        path: output path of the generated Prolog file (overwritten).
        neg_sample_str: Prolog list of negative examples put into the query.
            The default is a fixed, hand-picked pair of ``f/2`` examples.

    Returns:
        ``(var_pairs, idx_pairs)``: per-example lists of variable-name pairs
        and the matching image-index pairs.
    """
    header = ":- ['{}'].\n\n".format(bk_file)
    # 1. variable pairs, image-index pairs, positive examples and id/2 facts
    var_pairs, idx_pairs, pos_sample_str, mapping_str = gen_vpairs_and_exlist(
        sample_indices, all_examples, monadic=False)
    # 2. p(z|x) for every image pair; the ground truths also returned by the
    #    perception helper are not needed to build the knowledge base
    prob_dist, _ = pseudo_label_pairs_probs(model, idx_pairs, all_imgs_data)
    # 3. one "nn('Xi','Xj',p).\n" fact per variable pair
    prob_facts = gen_dyadic_prob_facts(var_pairs, idx_pairs, prob_dist)
    # query clause "a :- Pos=..., Neg=..., metaabd(Pos,Neg)."
    query_str = "\na :- Pos={}, Neg={}, metaabd(Pos,Neg).\n".format(
        pos_sample_str, neg_sample_str)
    # explicit encoding so the file does not depend on the platform locale
    with open(path, 'w', encoding='utf-8') as pl:
        pl.write(header + prob_facts + mapping_str + query_str)
    return var_pairs, idx_pairs


def pairwise(iterable):
    """Return every unordered pair of elements of ``iterable``.

    Each pair is a two-element list ``[iterable[i], iterable[j]]`` with
    ``i < j``, produced in lexicographic index order, e.g.
    ``[a, b, c] -> [[a, b], [a, c], [b, c]]``. Unlike ``itertools.pairwise``
    this yields all pairs, not just neighbouring ones. ``iterable`` must
    support ``len()`` and integer indexing.
    """
    size = len(iterable)
    return [[iterable[first], iterable[second]]
            for first in range(size)
            for second in range(first + 1, size)]


def gen_vpairs_and_exlist(sample_indices, all_examples, monadic=False):
    """Name the images of the selected examples and render them for Prolog.

    Every image of every selected example gets a fresh quoted variable name
    ``'X0'``, ``'X1'``, ..., numbered consecutively across all examples.

    Args:
        sample_indices: indices into ``all_examples`` to process, in order.
        all_examples: example objects. Each is read via ``x`` (one entry per
            image), ``x_idxs`` (image indices, expected to have the same
            length as ``x``) and, when ``monadic`` is False, ``y``.
        monadic: if True each example is rendered as ``s([...])``, otherwise
            as ``f([...],Y)`` where ``Y`` is ``str(sample.y)``.

    Returns:
        A 4-tuple ``(var_pairs, idx_pairs, examples_str, id_facts)``:
        per-example lists of all variable-name pairs ``[Xi, Xj]`` (i < j),
        the matching image-index pairs, a Prolog list of the rendered
        examples, and one ``id('Xk',Idx).`` fact per line mapping each
        variable name to its image index.
    """
    var_pairs = []
    idx_pairs = []
    example_terms = []
    id_facts = []
    cnt = 0
    for i in sample_indices:
        sample = all_examples[i]
        vlist = []
        for j in range(len(sample.x)):
            vname = 'X' + str(cnt)
            cnt += 1
            vlist.append(vname)
            id_facts.append("id('{}',{}).\n".format(vname, sample.x_idxs[j]))
        # join() keeps an empty image list valid Prolog ("[]")
        in_str = "[" + ",".join("'" + v + "'" for v in vlist) + "]"
        if monadic:
            ex_str = "s({0:s})".format(in_str)
        else:
            ex_str = "f({0:s},{1:s})".format(in_str, str(sample.y))
        var_pairs.append(pairwise(vlist))
        idx_pairs.append(pairwise(sample.x_idxs))
        example_terms.append(ex_str)
    # join() also yields "[]" when no examples were selected
    examples_str = "[" + ",".join(example_terms) + "]"
    return var_pairs, idx_pairs, examples_str, "".join(id_facts)


# Generate probabilistic facts
def gen_prob_facts(var_list, label_names, prob_dist):
    """Render per-variable label probabilities as ``nn/3`` Prolog facts.

    For every example ``i``, every variable ``var_list[i][j]`` and every
    label index ``k``, one line ``nn('<var>',<label_names[k]>,<p>).`` is
    emitted, where ``p`` is ``prob_dist[i][j][k]``.

    Args:
        var_list: per-example lists of variable names.
        label_names: labels, in the same order as the probability entries.
        prob_dist: per-example, per-variable probability vectors; its outer
            two lengths must match ``var_list`` (checked with ``assert``).

    Returns:
        All facts concatenated into one string (empty if there are none).
    """
    assert len(prob_dist) == len(var_list)
    lines = []
    for example_vars, example_probs in zip(var_list, prob_dist):
        assert len(example_probs) == len(example_vars)
        for vname, prob in zip(example_vars, example_probs):
            for k, lname in enumerate(label_names):
                lines.append("nn('{}',{},{}).\n".format(vname, lname, prob[k]))
    return "".join(lines)


def gen_dyadic_prob_facts(var_pairs, idx_pairs, prob_dist):
    """
    Render one ``nn('Xi','Xj',P).`` Prolog fact per variable pair of a batch.

    ``P`` is ``prob_dist[i][j][1]``: only the entry at index 1 of each pair's
    distribution (the probability of "true") is used. ``idx_pairs`` is only
    used to check, via ``assert``, that its per-example lengths match
    ``var_pairs``.

    Returns the facts concatenated into one string, one per line.
    """
    assert len(idx_pairs) == len(var_pairs)
    facts = []
    for ex_no, ex_var_pairs in enumerate(var_pairs):
        assert len(idx_pairs[ex_no]) == len(ex_var_pairs)
        for pair_no, vpair in enumerate(ex_var_pairs):
            p_true = prob_dist[ex_no][pair_no][1]
            facts.append("nn('{}','{}',{}).\n".format(vpair[0], vpair[1], p_true))
    return "".join(facts)


# Run SWI-Prolog on the generated file and collect its output.
# Returns (code, text): (0, STDOUT) on success, (-1, STDERR) on a runtime
# error, (-2, message) on timeout.
async def run_pl(file_path='../prolog/tmp_bk.pl', timeout=10):
    """Run SWI-Prolog on ``file_path``, then delete the file.

    Executes ``swipl --stack-limit=8g -s <file_path> -g a -t halt``, which
    loads the file, runs the goal ``a`` and halts. The file is removed
    afterwards in every case: success, error, failure to start, or timeout.

    Args:
        file_path: Prolog file to load; it is deleted after the run.
        timeout: seconds to wait for swipl before killing it.

    Returns:
        One of:

        - ``(0, stdout)`` if swipl exits with status 0.
        - ``(-1, stderr)`` on a non-zero exit status.
        - ``(-1, message)`` if swipl could not be started.
        - ``(-2, "Timeout ...")`` on timeout.
    """
    try:
        try:
            # No shell: the path is passed verbatim (no quoting issues), and
            # the process we wait on and may kill is swipl itself, so its own
            # exit status is what gets inspected below.
            proc = await asyncio.create_subprocess_exec(
                'swipl', '--stack-limit=8g', '-s', file_path,
                '-g', 'a', '-t', 'halt',
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE)
        except OSError as e:
            return -1, str(e)  # e.g. swipl is not installed
        try:
            stdout, stderr = await asyncio.wait_for(proc.communicate(),
                                                    timeout=timeout)
        except asyncio.TimeoutError as e:
            if proc.returncode is None:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass  # exited between the check and the kill
                await proc.wait()  # reap the killed process
            return -2, "Timeout " + str(e)  # timeout error
        if proc.returncode == 0:
            return 0, stdout.decode('UTF-8')
        return -1, stderr.decode('UTF-8')  # runtime error
    finally:
        # Remove the generated file; like "rm -f", a missing file is ignored.
        try:
            os.remove(file_path)
        except FileNotFoundError:
            pass


# Parse the Prolog output into the learned hypothesis and the abduced pair labels.
def parse_pl_result_dyadic(pl_out_str):
    """Extract the learned program and labelled training pairs from Prolog output.

    Each abduced pair ``[a, b]`` becomes a positive example (target 1), and
    its reversal ``[b, a]`` becomes a negative example (target 0).

    Args:
        pl_out_str: STDOUT of the Prolog run.

    Returns:
        ``(prog_str, pairs, targets)``. ``pairs`` lists all abduced pairs
        followed by their reversals, and ``targets`` holds the matching 1/0
        labels.

    Raises:
        ValueError: if no abduced-facts line was found in ``pl_out_str``.
            A malformed pairs line propagates the error raised by
            ``ast.literal_eval`` (ValueError or SyntaxError).
    """
    prog_str, pairs_str = read_pl_out_dyadic(pl_out_str)
    if pairs_str is None:
        raise ValueError("no 'Abduced Facts' line found in Prolog output")
    # literal_eval (not eval): only Python literals such as nested lists
    pairs_pos = ast.literal_eval(pairs_str)
    pairs_neg = [[p[1], p[0]] for p in pairs_pos]
    targets_tot = [1] * len(pairs_pos) + [0] * len(pairs_neg)
    return prog_str, pairs_pos + pairs_neg, targets_tot


def read_pl_out_dyadic(pl_out_str):
    """Split Prolog output into the learned program and the abduced-facts line.

    Section headers are lines that start with ``-`` and whose text between
    the first two characters and the last character is ``Program`` or
    ``Abduced Facts`` (e.g. ``- Program:``).

    - Program text: lines after the ``Program`` header, up to the
      ``Abduced Facts`` header. Blank lines are kept.
    - Pairs string: the last non-blank line after the ``Abduced Facts``
      header.

    Returns:
        ``(prog, pairs)``: the program text as newline-terminated lines
        (``""`` if absent), and the abduced-facts line (``None`` if absent).
    """
    prog_lines = []
    pairs = None
    prog_start = False
    pairs_start = False
    for line in pl_out_str.splitlines():
        # startswith() rather than line[0]: blank lines must not IndexError
        if line.startswith('-'):
            title = line[2:-1]
            if title == 'Program':
                prog_start = True
                continue
            if title == 'Abduced Facts':
                prog_start = False
                pairs_start = True
                continue
        if prog_start:
            prog_lines.append(line + "\n")
        if pairs_start and line.strip():
            # keep the last non-blank line; trailing blank lines are ignored
            pairs = line
    return "".join(prog_lines), pairs
