import json
import re

from tqdm import tqdm


def transfer_predicate_to_nl(predicate, remove_quotation=False):
    """Render one symbolic predicate such as ``name(a,b)`` as an English clause.

    ``related(X,Y)`` becomes ``X and Y are related``; ``-big(X)`` (a ``-``
    prefix on the predicate name) becomes ``X is not big``.

    Args:
        predicate: A string of the form ``name(arg1,arg2,...)``, optionally
            with a ``-`` prefix on ``name``. Text after the first closing
            parenthesis (e.g. a trailing period) is ignored.
        remove_quotation: If True, drop the first and last character of every
            argument (intended to strip surrounding double quotes).

    Returns:
        The clause as a string, without a trailing period.

    Raises:
        AttributeError: If ``predicate`` contains no ``name(...)`` pattern.
    """
    # Only the first parenthesised group is treated as the argument list.
    params = re.match(r'.*?\((.*?)\)', predicate).group(1)
    # Strip whitespace so "p(a, b)" renders like "p(a,b)" and quote removal
    # below cuts the quotes rather than a leading space.
    params = [param.strip() for param in params.split(',')]
    if remove_quotation:
        # Applied regardless of arity: unary predicates need it too.
        params = [param[1:-1] for param in params]

    if len(params) > 1:
        nl_predicate = ', '.join(params[:-1]) + ' and ' + params[-1] + ' are '
    else:
        nl_predicate = params[0] + ' is '

    predicate_desc = re.match(r'(.*?)\(', predicate).group(1).strip()
    # Negation is a '-' prefix on the predicate name only; a '-' inside an
    # argument (e.g. "well-known" or -1) must not trigger it.
    if predicate_desc.startswith('-'):
        nl_predicate += 'not '
        predicate_desc = predicate_desc[1:]

    return nl_predicate + predicate_desc

def transfer_facts_to_nl(sample=None, facts=None):
    """Render symbolic facts as English sentences.

    Each fact ``name("a","b",...)`` becomes ``a and b are name.``. Every
    argument has its first and last character removed (quote stripping). A
    ``-`` prefix on the predicate name inserts ``not``.

    Args:
        sample: A dict whose ``'facts'`` entry is used when ``facts`` is None.
        facts: A list of fact strings; takes precedence over ``sample``.

    Returns:
        A list of sentences, each ending in ``'.'``, in input order.
    """
    if facts is None:
        facts = sample['facts']
    nl_facts = []
    for fact in facts:
        objs = re.match(r'.*?\((.*?)\)', fact).group(1).split(',')
        # Strip whitespace before removing the surrounding quotes so that
        # 'p("a", "b")' loses its quotes rather than a letter.
        objs = [obj.strip()[1:-1] for obj in objs]
        predicate = re.match(r'(.*?)\(', fact).group(1).strip()

        if len(objs) > 1:
            objs_str = ', '.join(objs[:-1]) + ' and ' + objs[-1] + ' are '
        else:
            objs_str = objs[0] + ' is '

        # Only a '-' prefix on the predicate name means negation; hyphens
        # inside arguments (e.g. "well-known") must not trigger it.
        if predicate.startswith('-'):
            objs_str += 'not '
            predicate = predicate[1:]

        nl_facts.append(objs_str + predicate + '.')
    return nl_facts

def transfer_rules_to_nl(sample=None, rules=None):
    """Render symbolic ``head :- body`` rules as English sentences.

    Body conditions containing ``not `` go into an ``[unless]`` clause, and
    the others go into an ``[If]`` clause. A disjunctive head (``a | b``) is
    joined with ``or`` in the ``[then]`` clause. A rule with an empty head
    (``:- body``) is rendered as a prohibition on the body conditions holding
    together.

    Args:
        sample: A dict whose ``'rules'`` entry is used when ``rules`` is None.
        rules: A list of rule strings; takes precedence over ``sample``.

    Returns:
        A list of rule sentences, in input order.
    """
    if rules is None:
        rules = sample['rules']
    nl_rules = []
    for rule in rules:
        consequents, preconditions = rule.split(':-')
        preconditions = preconditions.strip()

        # Split the body into default-negated ('not ...') and ordinary conditions.
        default_set = []
        normal_set = []
        for precondition in preconditions.split(', '):
            if 'not ' in precondition:
                default_set.append(precondition)
            else:
                normal_set.append(precondition)

        # Descending sort: '-' sorts before letters, so within each group the
        # plain conditions come before classically negated '-p(...)' ones.
        normal_set.sort(reverse=True)
        default_set.sort(reverse=True)

        normal_set_str_list = [transfer_predicate_to_nl(p) for p in normal_set]
        # Drop the 'not ' keyword; the '[unless]' marker expresses it instead.
        default_set_str_list = [
            transfer_predicate_to_nl(p.split('not ')[1]) for p in default_set
        ]

        # Empty head: render as a constraint. Checking the head (not
        # rule.startswith) also tolerates leading whitespace.
        if not consequents.strip():
            rule_str = 'It\'s not permissible for ['
            rule_str += '; '.join(normal_set_str_list+default_set_str_list) + '] to be true at the same time'
            nl_rules.append(rule_str)
            continue

        consequents_str_list = [
            transfer_predicate_to_nl(consequent.strip())
            for consequent in consequents.split('|')
        ]

        # Join only the clauses that exist, so a rule with no positive
        # condition does not start with a stray ", ".
        clauses = []
        if normal_set_str_list:
            clauses.append('[If] ' + '; '.join(normal_set_str_list))
        if default_set_str_list:
            clauses.append('[unless] ' + '; '.join(default_set_str_list))
        clauses.append('[then] ' + ' or '.join(consequents_str_list))

        nl_rules.append(', '.join(clauses))
    return nl_rules

def transfer_queries_to_nl(sample):
    """Render each query in ``sample['queries']`` as an English clause.

    Returns a list of ``{'query': <clause>, 'label': <label>}`` dicts, one per
    input query and in the same order. The query predicate is converted with
    quote removal, and any other keys on the input query are not carried over.
    """
    return [
        {
            'query': transfer_predicate_to_nl(item['query'], remove_quotation=True),
            'label': item['label'],
        }
        for item in sample['queries']
    ]

def transfer_nm_queries_to_nl(sample, is_nm_dataset=False):
    """Render queries as English while keeping all of their other fields.

    Each query dict is shallow-copied. Its ``'query'`` predicate is replaced
    by its natural-language clause (with quote removal), and ``'label'`` must
    be present. When ``is_nm_dataset`` is True, an optional ``'new_fact'``
    predicate is converted as well.

    Returns:
        A list of converted query dicts, in input order.
    """
    def _convert(raw):
        # Copy first so every original key (and its order) is preserved.
        item = {**raw}
        item['query'] = transfer_predicate_to_nl(raw['query'], remove_quotation=True)
        item['label'] = raw['label']
        if is_nm_dataset and 'new_fact' in item:
            # NOTE(review): new_fact is converted WITHOUT remove_quotation, so
            # quoted constants keep their quotes — confirm this is intended.
            item['new_fact'] = transfer_predicate_to_nl(item['new_fact'])
        return item

    return [_convert(raw) for raw in sample['queries']]

def build_nl_dataset(symbolic_dataset_path, language_dataset_path):
    """Convert a symbolic JSONL dataset into a natural-language JSONL dataset.

    Each non-blank input line is parsed as a JSON object. Its ``'facts'``,
    ``'rules'`` and ``'queries'`` fields are replaced by their
    natural-language renderings, and every other field is copied unchanged.
    Samples that fail to convert are reported on stdout and skipped.

    Args:
        symbolic_dataset_path: Path of the input JSONL file (one sample per line).
        language_dataset_path: Path of the output JSONL file; it is overwritten.
    """
    # Read the symbolic dataset, skipping blank lines (json.loads rejects them).
    with open(symbolic_dataset_path, 'r', encoding='utf-8') as f:
        symbolic_dataset = [json.loads(line) for line in f if line.strip()]

    nl_dataset = []
    # Transfer the symbolic dataset to a natural-language dataset.
    for sample in tqdm(symbolic_dataset):
        try:
            nl_sample = {
                **sample,
                'facts': transfer_facts_to_nl(sample),
                'rules': transfer_rules_to_nl(sample),
                'queries': transfer_queries_to_nl(sample),
            }
        except Exception as e:
            # Best-effort: report and skip malformed samples. Unlike a bare
            # except, this lets KeyboardInterrupt/SystemExit propagate.
            print('Error in sample', sample, f'({type(e).__name__}: {e})')
            continue
        nl_dataset.append(nl_sample)

    # ensure_ascii=False writes raw non-ASCII text, so pin the file encoding
    # instead of relying on the platform default.
    with open(language_dataset_path, 'w', encoding='utf-8') as f:
        for sample in nl_dataset:
            f.write(json.dumps(sample, ensure_ascii=False) + '\n')
    print('NL dataset saved to', language_dataset_path)

if __name__ == '__main__':
    # Default input (symbolic) and output (natural-language) locations for
    # the related-word dataset.
    symbolic_dataset_path, language_dataset_path = (
        'logicalDatasets/checked_dataset/related_word_symbolic_dataset.jsonl',
        'logicalDatasets/checked_dataset/related_word_natural_language_dataset.jsonl',
    )
    build_nl_dataset(
        symbolic_dataset_path=symbolic_dataset_path,
        language_dataset_path=language_dataset_path,
    )