import json
import os

from tqdm import tqdm

from src.template_utils.modify import check_rules_if_contain_default_negation, add_fact_against_default_rule
from src.utils.eval_utils import get_overall_datasets
from src.dlv_utils.dlv import DLVHandler



def _answer_set(dlv2, facts, rules):
    """Run DLV on ``facts`` + ``rules`` and return the answer as a set of atom strings.

    NOTE(review): assumes ``quick_run`` returns a single brace-wrapped,
    ``', '``-separated string such as ``'{a, b}'`` (the outer characters are
    stripped and the rest is split on ``', '``); confirm against DLVHandler.
    """
    return set(dlv2.quick_run(facts, rules)[1:-1].split(', '))


def _greedy_prune(items, answer_of, target):
    """Greedily drop items whose removal keeps ``answer_of(kept_items) == target``.

    Items are visited in order; an item is restored if removing it changes the
    answer. Returns the surviving items in their original order.
    """
    keep = [True] * len(items)
    for i in range(len(items)):
        keep[i] = False
        if answer_of([item for item, k in zip(items, keep) if k]) != target:
            keep[i] = True
    return [item for item, k in zip(items, keep) if k]


def _minimize_program(dlv2, facts, rules):
    """Return ``(facts_min, rules_min)``: a greedily minimal subset preserving the answer set.

    Facts are pruned first (against all rules), then rules (against the pruned facts).
    """
    o_result = _answer_set(dlv2, facts, rules)
    facts_min = _greedy_prune(facts, lambda fs: _answer_set(dlv2, fs, rules), o_result)
    rules_min = _greedy_prune(rules, lambda rs: _answer_set(dlv2, facts_min, rs), o_result)
    return facts_min, rules_min


def _attach_defeater(dlv2, entry):
    """Add a fact that defeats a default rule; return the extended entry or None.

    ``entry`` holds ``facts_min``/``rules_min``. The entry is dropped when
    ``add_fact_against_default_rule`` yields no fact, or when adding that fact
    leaves the answer set unchanged.
    """
    facts, rules = entry['facts_min'], entry['rules_min']
    o_result = _answer_set(dlv2, facts, rules)
    new_fact, o_answer = add_fact_against_default_rule(facts, rules, dlv2)
    if new_fact is None:
        return None
    if _answer_set(dlv2, facts + [new_fact], rules) == o_result:
        return None
    return {**entry, 'new_fact': new_fact, 'o_answer': o_answer}


def _collect_irrelevant(dlv2, entry):
    """Pick facts/rules from the source sample that do not affect either answer set.

    A candidate is accepted only if, together with everything accepted so far,
    it changes neither the answer set of the minimal program nor the answer set
    of the minimal program plus ``new_fact``.
    Returns ``(irrelated_facts, irrelated_rules)``.
    """
    facts, rules = entry['facts_min'], entry['rules_min']
    new_fact = entry['new_fact']
    source = entry['sample']
    o_result = _answer_set(dlv2, facts, rules)
    d_result = _answer_set(dlv2, facts + [new_fact], rules)

    def unchanged(fs, rs):
        return (_answer_set(dlv2, fs, rs) == o_result
                and _answer_set(dlv2, fs + [new_fact], rs) == d_result)

    irrelated_facts = []
    for fact in source['facts']:
        # The candidate itself must be part of the tested program.
        if fact not in facts and unchanged(facts + irrelated_facts + [fact], rules):
            irrelated_facts.append(fact)

    irrelated_rules = []
    for rule in source['rules']:
        # Test rules against the irrelevant facts too, since the final record combines them.
        if rule not in rules and unchanged(facts + irrelated_facts,
                                           rules + irrelated_rules + [rule]):
            irrelated_rules.append(rule)
    return irrelated_facts, irrelated_rules


def _build_final_record(dlv2, entry):
    """Assemble the output record for one defeated sample, or return None if it fails validation.

    The record pairs the original query (label 'T', or 'F' when ``o_answer``
    carries a leading '-') with the same query after adding ``new_fact``
    (label 'M'). It is kept only if ``o_answer`` is derived without
    ``new_fact`` and neither the query atom nor its '-' form is derived with it.
    """
    facts, rules = entry['facts_min'], entry['rules_min']
    new_fact = entry['new_fact']
    o_answer = entry['o_answer']
    irrelated_facts, irrelated_rules = _collect_irrelevant(dlv2, entry)

    # Only a leading '-' marks a negated answer; a '-' elsewhere in the atom
    # (e.g. a negative argument) must not flip the label.
    if o_answer.startswith('-'):
        query, label = o_answer[1:], 'F'
    else:
        query, label = o_answer, 'T'

    all_facts = facts + irrelated_facts
    all_rules = rules + irrelated_rules
    before = _answer_set(dlv2, all_facts, all_rules)
    after = _answer_set(dlv2, all_facts + [new_fact], all_rules)
    # Compare whole atoms rather than substrings of the raw output, so that
    # 'p(a)' does not falsely match inside '-p(a)' or 'top(a)'.
    if o_answer not in before or query in after or '-' + query in after:
        return None

    return {
        'Source_ID': entry['sample']['id'],
        'facts': facts,
        'irrelated_facts': irrelated_facts,
        'rules': rules,
        'irrelated_rules': irrelated_rules,
        'queries': [
            {'query': query, 'label': label, 'type': 'original'},
            {'query': query, 'label': 'M', 'new_fact': new_fact, 'type': 'new'}
        ]
    }


def main():
    """Build the default-negation dataset for every non-'language' dataset.

    For each dataset:
      1. minimize every sample and keep those whose minimal rules contain default negation;
      2. add a fact that defeats a default rule and changes the answer set;
      3. collect facts/rules from the source sample that affect neither answer set,
         and validate the resulting query pair;
      4. write the records to ``datasets/nm_dataset/<name>.jsonl``.
    """
    save_path_base = 'datasets/nm_dataset'
    os.makedirs(save_path_base, exist_ok=True)

    datasets = get_overall_datasets()
    dlv2 = DLVHandler(useDlv2=True)
    for dname, dataset in tqdm(datasets.items()):
        if 'language' in dname:
            continue

        # filter the samples containing default negation, using the minimal rules and facts
        DN_dataset = []
        for sample in tqdm(dataset):
            facts_min, rules_min = _minimize_program(dlv2, sample['facts'], sample['rules'])
            if check_rules_if_contain_default_negation(rules_min):
                DN_dataset.append({'sample': sample, 'rules_min': rules_min, 'facts_min': facts_min})
        print('count_DN:', len(DN_dataset))

        # build a new fact against the default negation
        DN_handled_dataset = []
        for entry in tqdm(DN_dataset):
            handled = _attach_defeater(dlv2, entry)
            if handled is not None:
                DN_handled_dataset.append(handled)
        # Guard against datasets where no sample uses default negation.
        ratio = len(DN_handled_dataset) / len(DN_dataset) if DN_dataset else 0.0
        print('df p:', ratio)

        # build irrelevant facts and rules, then validate the query pair
        DN_final_dataset = []
        for entry in tqdm(DN_handled_dataset):
            record = _build_final_record(dlv2, entry)
            if record is not None:
                DN_final_dataset.append(record)

        # save to jsonl
        out_path = os.path.join(save_path_base, f'{dname}.jsonl')
        with open(out_path, 'w', encoding='utf-8') as f:
            for record in DN_final_dataset:
                f.write(json.dumps(record) + '\n')
        print(f'{dname} done')

if __name__ == "__main__":
    # Run the dataset build only when executed as a script, not on import.
    main()
