import json
import os
import random
import re
from tqdm import tqdm

from src.dlv_utils.dlv import DLVHandler

# Input: a JSON Lines file of generated samples (one JSON object per line).
samples_path = (
    'logicalDatasets/filtered_samples/'
    'generated_dataset_af_0_5_rf_0_0_ar_0_5_rr_0_0_pac_0.1_pn_0.5_pd_0.2_pc_0.2'
    '_acm_2_pcv_0.3_gw_false_grw_false_gn_100_s_42_sn_9825.jsonl'
)
# When True, the samples that pass the check are written to save_path.
save_checked_result = True
save_path = f'{samples_path}.checked'
# A negative value checks every sample; otherwise a random subset of this size.
check_num = -1

# Load every sample from the JSON Lines file. JSON text is UTF-8 by
# specification, so decode explicitly rather than relying on the platform's
# locale encoding. Blank lines (e.g. a doubled trailing newline) are skipped
# instead of crashing json.loads.
with open(samples_path, 'r', encoding='utf-8') as f:
    samples = [json.loads(line) for line in f if line.strip()]
print('Number of samples:', len(samples))


# Set to an int to make the random subset reproducible; None keeps the
# selection unseeded (a fresh random subset on every run).
sample_seed = None

# Choose which samples to check: all of them when check_num is negative,
# otherwise a random subset of at most check_num samples.
if check_num < 0:
    selected_samples = samples
else:
    # random.sample raises ValueError when k exceeds the population size,
    # so clamp the request to the number of available samples.
    sampler = random.Random(sample_seed)
    selected_samples = sampler.sample(samples, k=min(check_num, len(samples)))

def _dedupe_preserving_order(items):
    """Return *items* without duplicates, keeping first occurrences in order.

    List membership is used (not a set) so unhashable items still work.
    """
    unique = []
    for item in items:
        if item not in unique:
            unique.append(item)
    return unique


def _filter_queries(sample, dlv_result):
    """Validate a sample's queries against the DLV output and drop trivial ones.

    Args:
        sample: the sample dict; reads 'facts' and 'queries', and may
            decrement sample['T'] / sample['F'].
        dlv_result: whatever DLVHandler.run_and_get_results() returned; only
            the `in` operator is used on it.

    Returns:
        (error_flag, kept_queries). error_flag is 0 when every query agrees
        with dlv_result, 1 when an 'M' query appears in dlv_result, and 2
        when any other query does not appear in it. kept_queries holds the
        queries whose text is not already one of the sample's facts.

    Side effect:
        For each non-'M' query dropped because it is already a fact, the
        counter sample[label] is decremented.
    """
    facts = sample['facts']
    kept = []
    for query in sample['queries']:
        query_text = query['query']
        query_label = query['label']
        if query_label == 'M':
            if query_text in dlv_result:
                return 1, kept
            if query_text not in facts:
                kept.append(query)
            continue

        if query_text not in dlv_result:
            return 2, kept
        # NOTE(review): the DLV check above uses the un-negated text even for
        # 'F' queries, while the fact lookup below uses the '-'-prefixed
        # form. Whether that is right depends on how dlv_result implements
        # `in` (substring match on a string vs. exact set membership) —
        # confirm against DLVHandler.
        if query_label == 'F':
            query_text = '-' + query_text
        if query_text not in facts:
            kept.append(query)
        else:
            # The query is trivially answered by an input fact: drop it and
            # keep the per-label counter in sync.
            sample[query_label] -= 1
    return 0, kept


dlv_helper = DLVHandler('check.dlv', useDlv2=True)
# NOTE(review): the same handler is reused for every sample and nothing here
# clears previously added facts/rules. That is only correct if
# run_and_get_results() resets the handler's program — confirm in DLVHandler.

error_count = 0
checked_samples = []
for s in tqdm(selected_samples):
    # Deduplicate facts in place so the saved sample carries the clean list.
    s['facts'] = _dedupe_preserving_order(s['facts'])
    rules = s['rules']

    for fact in s['facts']:
        dlv_helper.add_fact(fact)
    for rule in rules:
        dlv_helper.add_rule(rule)

    dlv_r = dlv_helper.run_and_get_results()
    error_flag, new_query = _filter_queries(s, dlv_r)

    # Reject samples that contradict the solver, have no non-trivial
    # queries left, or have no remaining 'T'/'F' queries.
    if error_flag != 0 or not new_query or s['T'] + s['F'] == 0:
        error_count += 1
    else:
        s['queries'] = new_query
        checked_samples.append(s)

print('Error count:', error_count)

# Persist the samples that passed the check as JSON Lines (one object per
# line) and report where they went and how many there are.
if save_checked_result:
    with open(save_path, 'w') as out_file:
        out_file.writelines(json.dumps(sample) + '\n' for sample in checked_samples)
    print(f'Save checked result to {save_path}')
    print(f'Number of checked samples: {len(checked_samples)}')