from utils.logic.fol import *
from utils.wikidata_types import *
import pickle, json, random
import copy



def generate_1step_queries(abox,inexact_matches, ents, no_queries=4, no_queries_perkind=2):
    """Build one-hop "activity in a country" queries from ``LocatedIn`` facts.

    Every ``LocatedIn(place, country)`` fact whose place is typed (by
    ``get_wikidata_types``) as a mountain, museum, river or forest yields a
    positive query -- ``CanSkiIn``, ``CanVisitMuseum``, ``CanSailIn`` or
    ``CanCampIn`` applied to the country -- and a negative query applying the
    same predicate to an entity drawn at random from ``ents``.  Each query is
    produced in an exact form and in an "inexact" form whose predicate name
    is looked up in ``inexact_matches``.

    Args:
        abox: Iterable of fact strings shaped like ``Predicate(arg1, arg2)``.
            Entries that do not parse as a binary fact are skipped.
        inexact_matches: Mapping from an exact predicate name to the name
            used in the inexact variant of the query.
        ents: Collection of entity names that negatives are drawn from.
        no_queries: Maximum number of positive/negative *pairs*; every
            returned list therefore holds up to ``2 * no_queries`` items.
        no_queries_perkind: Unused here; accepted so that the signature
            matches the other ``generate_*step_queries`` functions.

    Returns:
        ``(queries_nl, queries_nl_inexact, queries, queries_inexact)``.  The
        first two are lists of dicts (``query_id``, ``query``, ``label``,
        ``axiom``, ``triples``, ``types``); the last two are the matching
        lists of ``Predicate`` objects.

    Raises:
        KeyError: If ``inexact_matches`` lacks a derived predicate name.
        ValueError: If no entity in ``ents`` is usable as a negative example.
    """
    # (place type, derived predicate) pairs, tested in this order.
    activities = (
        ('mountain', 'CanSkiIn'),
        ('museum', 'CanVisitMuseum'),
        ('river', 'CanSailIn'),
        ('forest', 'CanCampIn'),
    )

    def parse(triple):
        # Returns (predicate, arg1, arg2), or None for anything that is not
        # a binary fact (no parenthesis, or not exactly two arguments).
        pieces = triple.split('(')
        if len(pieces) < 2:
            return None
        args = pieces[1].split(')')[0].split(',')
        if len(args) != 2:
            return None
        return pieces[0].strip(), args[0].strip(), args[1].strip()

    def add_record(bucket, query, label, axiom, triple, types):
        # query_id is the record's position inside its own list.
        bucket.append({
            'query_id': f'{len(bucket)}',
            'query': query,
            'label': label,
            'axiom': axiom,
            'triples': [triple],
            'types': dict(types),
        })

    queries, queries_inexact = [], []
    queries_nl, queries_nl_inexact = [], []
    entity_pool = list(ents)
    emitted = set()    # str() of every query appended to `queries` so far
    positives = set()  # (predicate, country) pairs already labelled 'True'
    pair_count = 0

    for triple in abox:
        if pair_count >= no_queries:
            break

        fact = parse(triple)
        if fact is None:
            continue
        predicate, place, country = fact
        if predicate != 'LocatedIn' or get_wikidata_types(country) != 'country':
            continue

        place_type = get_wikidata_types(place)
        new_predicate = next(
            (name for kind, name in activities if place_type == kind), None
        )
        if new_predicate is None:
            continue

        # Skip a positive whose text was already emitted, whether as an
        # earlier positive or as an earlier negative.
        positive = Predicate(new_predicate, Constant(country, 'country'))
        positive_key = str(positive)
        if positive_key in emitted:
            continue

        # A negative must not restate this positive, nor one emitted earlier
        # for the same predicate; otherwise its 'False' label would
        # contradict a 'True' one.
        candidates = [
            entity for entity in entity_pool
            if entity != country and (new_predicate, entity) not in positives
        ]
        if not candidates:
            raise ValueError(
                f'no entity in ents can serve as a negative example for '
                f'{new_predicate}({country})'
            )
        neg_country = random.choice(candidates)
        neg_type = get_wikidata_types(neg_country)
        inexact_predicate = inexact_matches[new_predicate]

        negative = Predicate(new_predicate, Constant(neg_country, neg_type))
        queries.append(positive)
        queries.append(negative)
        emitted.update((positive_key, str(negative)))
        queries_inexact.append(
            Predicate(inexact_predicate, Constant(country, 'country'))
        )
        queries_inexact.append(
            Predicate(inexact_predicate, Constant(neg_country, neg_type))
        )
        positives.add((new_predicate, country))
        pair_count += 1

        premise = f'LocatedIn({place}, {country})'
        types = {place: place_type, country: 'country'}
        for bucket, name in (
            (queries_nl, new_predicate),
            (queries_nl_inexact, inexact_predicate),
        ):
            add_record(
                bucket, f'{name}({country})', 'True',
                f'{premise} => {name}({country})', triple, types,
            )
            add_record(
                bucket, f'{name}({neg_country})', 'False',
                f'{premise} => {name}({neg_country})', triple, types,
            )

    return queries_nl, queries_nl_inexact, queries, queries_inexact



def generate_2step_queries(abox, inexact_matches,ents,  no_queries=4, no_queries_perkind=2):
    """Build two-hop queries, each derived from a pair of abox facts.

    Two query families are generated, in this order:

    1. ``At(place, location)`` and ``PartOf(location, country)`` give the
       positive query ``At(place, country)``; the negative replaces the
       country with a random entity that ``location`` is not recorded as
       being part of.
    2. ``ShareLandBorders(c1, c2)`` and ``ShareLandBorders(c2, c3)`` (with
       ``c3 != c1``) give the positive query ``LandConnected(c1, c3)``; the
       negative replaces ``c1`` with a random entity that is neither ``c2``
       nor recorded as sharing a border with ``c2`` in either direction.

    Each query is produced in an exact form and in an "inexact" form whose
    predicate name is looked up in ``inexact_matches``.

    Every positive/negative pair adds 2 to both counters, so ``no_queries``
    and ``no_queries_perkind`` bound the number of individual queries: a
    family stops once it has produced ``no_queries_perkind`` of them, and
    generation as a whole stops at ``no_queries``.

    Args:
        abox: Iterable of fact strings shaped like ``Predicate(arg1, arg2)``.
            Entries that do not parse as a binary fact are skipped.
        inexact_matches: Mapping with the keys ``'At'`` and
            ``'LandConnected'`` giving the inexact predicate names.
        ents: Collection of entity names that negatives are drawn from.
        no_queries: Overall query budget (see above).
        no_queries_perkind: Per-family query budget (see above).

    Returns:
        ``(queries_nl, queries_nl_inexact, queries, queries_inexact)``.  The
        first two are lists of dicts (``query_id``, ``query``, ``label``,
        ``axiom``, ``triples``, ``types``); the last two are the matching
        lists of ``Predicate`` objects.

    Raises:
        KeyError: If ``inexact_matches`` lacks a required predicate name.
        ValueError: If no entity in ``ents`` is usable as a negative example.
    """
    def parse(triple):
        # Returns (predicate, arg1, arg2), or None for anything that is not
        # a binary fact (no parenthesis, or not exactly two arguments).
        pieces = triple.split('(')
        if len(pieces) < 2:
            return None
        args = pieces[1].split(')')[0].split(',')
        if len(args) != 2:
            return None
        return pieces[0].strip(), args[0].strip(), args[1].strip()

    def add_record(bucket, query, label, axiom, triples, types):
        # query_id is the record's position inside its own list.
        bucket.append({
            'query_id': f'{len(bucket)}',
            'query': query,
            'label': label,
            'axiom': axiom,
            'triples': list(triples),
            'types': dict(types),
        })

    def pick_negative(candidates, what):
        if not candidates:
            raise ValueError(
                f'no entity in ents can serve as a negative example for {what}'
            )
        return random.choice(candidates)

    queries, queries_inexact = [], []
    queries_nl, queries_nl_inexact = [], []
    entity_pool = list(ents)

    # Parse the abox once and index the facts that get joined below.  Every
    # index keeps abox order so queries come out in scan order.
    facts = []           # (predicate, arg1, arg2, original triple)
    known = set()        # (predicate, arg1, arg2), whitespace-normalised
    at_by_location = {}  # location -> [(place, triple), ...]
    borders_from = {}    # country  -> [(neighbour, triple), ...]
    neighbours = {}      # country  -> neighbours in either direction
    for triple in abox:
        fact = parse(triple)
        if fact is None:
            continue
        predicate, arg1, arg2 = fact
        facts.append((predicate, arg1, arg2, triple))
        known.add(fact)
        if predicate == 'At':
            at_by_location.setdefault(arg2, []).append((arg1, triple))
        elif predicate == 'ShareLandBorders':
            borders_from.setdefault(arg1, []).append((arg2, triple))
            neighbours.setdefault(arg1, set()).add(arg2)
            neighbours.setdefault(arg2, set()).add(arg1)

    query_counter = 0
    kind_counter = 0

    # Family 1: places located in something that is part of a country.
    for predicate, bigger_place, country, triple in facts:
        if query_counter >= no_queries or kind_counter >= no_queries_perkind:
            break
        if predicate != 'PartOf' or 'country' not in get_wikidata_types(country):
            continue

        for place, triple_new in at_by_location.get(bigger_place, ()):
            if kind_counter >= no_queries_perkind or query_counter >= no_queries:
                break
            query_counter += 2
            kind_counter += 2

            # Checked against the parsed facts rather than formatted strings
            # so the test does not depend on spacing inside the abox entries.
            neg_country = pick_negative(
                [
                    entity for entity in entity_pool
                    if ('PartOf', bigger_place, entity) not in known
                ],
                f'At({place}, {country})',
            )
            place_type = get_wikidata_types(place)
            neg_type = get_wikidata_types(neg_country)
            inexact_at = inexact_matches['At']

            queries.append(Predicate('At', Constant(place, place_type), Constant(country, 'country')))
            queries.append(Predicate('At', Constant(place, place_type), Constant(neg_country, neg_type)))
            queries_inexact.append(Predicate(inexact_at, Constant(place, place_type), Constant(country, 'country')))
            queries_inexact.append(Predicate(inexact_at, Constant(place, place_type), Constant(neg_country, neg_type)))

            premise = f'At({place}, {bigger_place}) && PartOf({bigger_place}, {country})'
            types = {
                place: place_type,
                country: 'country',
                bigger_place: get_wikidata_types(bigger_place),
            }
            # The textual query names the same predicate as the formal query
            # and as the axiom's conclusion.
            for bucket, name in ((queries_nl, 'At'), (queries_nl_inexact, inexact_at)):
                add_record(
                    bucket, f'{name}({place}, {country})', 'True',
                    f'{premise} => {name}({place}, {country})',
                    (triple, triple_new), types,
                )
                add_record(
                    bucket, f'{name}({place}, {neg_country})', 'False',
                    f'{premise} => {name}({place}, {neg_country})',
                    (triple, triple_new), types,
                )
    print('queries generated:', len(queries))

    # Each family gets its own per-kind budget.  Reset here unconditionally:
    # the loop above may also end by running out of facts, in which case a
    # used-up budget would otherwise suppress the next family entirely.
    kind_counter = 0

    # Family 2: countries connected through a shared neighbour.
    for predicate, country1, country2, triple in facts:
        if query_counter >= no_queries or kind_counter >= no_queries_perkind:
            break
        if predicate != 'ShareLandBorders':
            continue

        for country3, triple_new in borders_from.get(country2, ()):
            if country3 == country1:
                continue
            if kind_counter >= no_queries_perkind or query_counter >= no_queries:
                break
            query_counter += 2
            kind_counter += 2

            # The negative stands in for country1, so it must not border
            # country2 itself (that would make the query derivable by the
            # same axiom).  This also rules out country1 and country3.
            excluded = neighbours.get(country2, set()) | {country2}
            neg_country = pick_negative(
                [entity for entity in entity_pool if entity not in excluded],
                f'LandConnected({country1}, {country3})',
            )
            neg_type = get_wikidata_types(neg_country)
            inexact_connected = inexact_matches['LandConnected']

            queries.append(Predicate('LandConnected', Constant(country1, 'country'), Constant(country3, 'country')))
            queries.append(Predicate('LandConnected', Constant(neg_country, neg_type), Constant(country3, 'country')))
            queries_inexact.append(Predicate(inexact_connected, Constant(country1, 'country'), Constant(country3, 'country')))
            queries_inexact.append(Predicate(inexact_connected, Constant(neg_country, neg_type), Constant(country3, 'country')))

            premise = (
                f'ShareLandBorders({country1}, {country2}) && '
                f'ShareLandBorders({country2}, {country3})'
            )
            positive_types = {country1: 'country', country2: 'country', country3: 'country'}
            negative_types = {neg_country: neg_type, country2: 'country', country3: 'country'}
            for bucket, name in (
                (queries_nl, 'LandConnected'),
                (queries_nl_inexact, inexact_connected),
            ):
                add_record(
                    bucket, f'{name}({country1}, {country3})', 'True',
                    f'{premise} => {name}({country1}, {country3})',
                    (triple, triple_new), positive_types,
                )
                add_record(
                    bucket, f'{name}({neg_country}, {country3})', 'False',
                    f'{premise} => {name}({neg_country}, {country3})',
                    (triple, triple_new), negative_types,
                )
    print('two step queries generated:', len(queries))

    #TODO: places contained in another place
    #TODO: forest and mountain, etc.

    return queries_nl, queries_nl_inexact, queries, queries_inexact

def generate_3step_queries(abox, inexact_matches, ents, no_queries=4, no_queries_perkind=2):
    """Build "place is part of a continent" queries.

    A ``PartOf(country, continent)`` fact whose continent is Asia or Africa
    (compared case-insensitively) is joined with every
    ``LocatedIn(place, country)`` fact to give the positive query
    ``PartOf(place, continent)``.  The negative replaces the continent with
    a random entity that is not the continent itself and that neither the
    country nor the place is recorded as being part of / located in.  Each
    query is backed by the two abox facts it was joined from and is produced
    in an exact form and in an "inexact" form whose predicate name is
    ``inexact_matches['PartOf']``.

    Each positive/negative pair counts as 1, and generation stops as soon as
    that count reaches ``no_queries`` or ``no_queries_perkind``; nothing is
    generated when either limit is zero or negative.

    Args:
        abox: Iterable of fact strings shaped like ``Predicate(arg1, arg2)``.
            Entries that do not parse as a binary fact are skipped.
        inexact_matches: Mapping with the key ``'PartOf'`` giving the
            inexact predicate name.
        ents: Collection of entity names that negatives are drawn from.
        no_queries: Limit on the number of pairs (see above).
        no_queries_perkind: Second limit on the number of pairs (see above).

    Returns:
        ``(queries_nl, queries_nl_inexact, queries, queries_inexact)``.  The
        first two are lists of dicts (``query_id``, ``query``, ``label``,
        ``axiom``, ``triples``, ``types``); the last two are the matching
        lists of ``Predicate`` objects.

    Raises:
        KeyError: If ``inexact_matches`` has no ``'PartOf'`` entry.
        ValueError: If no entity in ``ents`` is usable as a negative example.
    """
    def parse(triple):
        # Returns (predicate, arg1, arg2), or None for anything that is not
        # a binary fact (no parenthesis, or not exactly two arguments).
        pieces = triple.split('(')
        if len(pieces) < 2:
            return None
        args = pieces[1].split(')')[0].split(',')
        if len(args) != 2:
            return None
        return pieces[0].strip(), args[0].strip(), args[1].strip()

    def add_record(bucket, query, label, axiom, triples, types):
        # query_id is the record's position inside its own list.
        bucket.append({
            'query_id': f'{len(bucket)}',
            'query': query,
            'label': label,
            'axiom': axiom,
            'triples': list(triples),
            'types': dict(types),
        })

    queries, queries_inexact = [], []
    queries_nl, queries_nl_inexact = [], []
    result = (queries_nl, queries_nl_inexact, queries, queries_inexact)

    if no_queries <= 0 or no_queries_perkind <= 0:
        return result

    entity_pool = list(ents)

    # Parse the abox once; the index keeps abox order so queries come out in
    # scan order.
    facts = []       # (predicate, arg1, arg2, original triple)
    known = set()    # (predicate, arg1, arg2), whitespace-normalised
    located_in = {}  # container -> [(place, triple), ...]
    for triple in abox:
        fact = parse(triple)
        if fact is None:
            continue
        predicate, arg1, arg2 = fact
        facts.append((predicate, arg1, arg2, triple))
        known.add(fact)
        if predicate == 'LocatedIn':
            located_in.setdefault(arg2, []).append((arg1, triple))

    pair_count = 0
    for predicate, country, continent, triple in facts:
        if predicate != 'PartOf' or continent.lower() not in ('asia', 'africa'):
            continue

        for place, triple_new in located_in.get(country, ()):
            # A negative must not name something that is known to contain
            # the place or the country, or its 'False' label would be wrong.
            candidates = [
                entity for entity in entity_pool
                if entity != continent
                and ('PartOf', country, entity) not in known
                and ('PartOf', place, entity) not in known
                and ('LocatedIn', place, entity) not in known
            ]
            if not candidates:
                raise ValueError(
                    f'no entity in ents can serve as a negative example for '
                    f'PartOf({place}, {continent})'
                )
            neg_continent = random.choice(candidates)
            place_type = get_wikidata_types(place)
            neg_type = get_wikidata_types(neg_continent)
            inexact_part_of = inexact_matches['PartOf']

            queries.append(Predicate('PartOf', Constant(place, place_type), Constant(continent, 'continent')))
            queries.append(Predicate('PartOf', Constant(place, place_type), Constant(neg_continent, neg_type)))
            queries_inexact.append(Predicate(inexact_part_of, Constant(place, place_type), Constant(continent, 'continent')))
            queries_inexact.append(Predicate(inexact_part_of, Constant(place, place_type), Constant(neg_continent, neg_type)))

            premise = f'LocatedIn({place}, {country}) && PartOf({country}, {continent})'
            positive_types = {place: place_type, country: 'country', continent: 'continent'}
            negative_types = {place: place_type, country: 'country', neg_continent: 'continent'}
            for bucket, name in ((queries_nl, 'PartOf'), (queries_nl_inexact, inexact_part_of)):
                add_record(
                    bucket, f'{name}({place}, {continent})', 'True',
                    f'{premise} => {name}({place}, {continent})',
                    (triple, triple_new), positive_types,
                )
                add_record(
                    bucket, f'{name}({place}, {neg_continent})', 'False',
                    f'{premise} => {name}({place}, {neg_continent})',
                    (triple, triple_new), negative_types,
                )

            pair_count += 1
            # Reaching either limit ends generation altogether.
            if pair_count >= no_queries_perkind or pair_count >= no_queries:
                return result

    return result

def generate_4step_queries(abox, inexact_matches, ents, no_queries=4, no_queries_perkind=2):
    """Build "places in bordering countries are land connected" queries.

    A ``ShareLandBorders(c1, c2)`` fact is joined with every
    ``PartOf(place, c1)`` and every ``PartOf(place2, c2)`` fact to give the
    positive query ``LandConnected(place, place2)``.  The negative replaces
    ``place2`` with a random entity that is neither one of the two countries
    nor recorded as part of either of them.  Each query is backed by the
    three abox facts it was joined from and is produced in an exact form and
    in an "inexact" form whose predicate name is
    ``inexact_matches['LandConnected']``.

    Each positive/negative pair adds 2 to the running count, and generation
    stops as soon as that count reaches ``no_queries`` or
    ``no_queries_perkind``; nothing is generated when either limit is zero
    or negative.  The number of generated queries is printed once per call.

    Args:
        abox: Iterable of fact strings shaped like ``Predicate(arg1, arg2)``.
            Entries that do not parse as a binary fact are skipped.
        inexact_matches: Mapping with the key ``'LandConnected'`` giving the
            inexact predicate name.
        ents: Collection of entity names that negatives are drawn from.
        no_queries: Limit on the number of queries (see above).
        no_queries_perkind: Second limit on the number of queries (see above).

    Returns:
        ``(queries_nl, queries_nl_inexact, queries, queries_inexact)``.  The
        first two are lists of dicts (``query_id``, ``query``, ``label``,
        ``axiom``, ``triples``, ``types``); the last two are the matching
        lists of ``Predicate`` objects.

    Raises:
        KeyError: If ``inexact_matches`` has no ``'LandConnected'`` entry.
        ValueError: If no entity in ``ents`` is usable as a negative example.
    """
    def parse(triple):
        # Returns (predicate, arg1, arg2), or None for anything that is not
        # a binary fact (no parenthesis, or not exactly two arguments).
        pieces = triple.split('(')
        if len(pieces) < 2:
            return None
        args = pieces[1].split(')')[0].split(',')
        if len(args) != 2:
            return None
        return pieces[0].strip(), args[0].strip(), args[1].strip()

    def add_record(bucket, query, label, axiom, triples, types):
        # query_id is the record's position inside its own list.
        bucket.append({
            'query_id': f'{len(bucket)}',
            'query': query,
            'label': label,
            'axiom': axiom,
            'triples': list(triples),
            'types': dict(types),
        })

    queries, queries_inexact = [], []
    queries_nl, queries_nl_inexact = [], []
    result = (queries_nl, queries_nl_inexact, queries, queries_inexact)

    if no_queries <= 0 or no_queries_perkind <= 0:
        print('four step queries generated:', len(queries))
        return result

    entity_pool = list(ents)

    # Parse the abox once; both indexes keep abox order so queries come out
    # in scan order.
    borders = []   # (country1, country2, triple)
    parts_of = {}  # country -> [(part, triple), ...]
    for triple in abox:
        fact = parse(triple)
        if fact is None:
            continue
        predicate, arg1, arg2 = fact
        if predicate == 'ShareLandBorders':
            borders.append((arg1, arg2, triple))
        elif predicate == 'PartOf':
            parts_of.setdefault(arg2, []).append((arg1, triple))

    query_counter = 0
    for country1, country2, triple in borders:
        parts1 = parts_of.get(country1, ())
        parts2 = parts_of.get(country2, ())
        # Entities a negative must avoid: the two countries and everything
        # recorded as part of either, which includes `place` and `place2`.
        excluded = {country1, country2}
        excluded.update(part for part, _ in parts1)
        excluded.update(part for part, _ in parts2)

        for place, triple_new in parts1:
            for place2, triple_new2 in parts2:
                candidates = [
                    entity for entity in entity_pool if entity not in excluded
                ]
                if not candidates:
                    raise ValueError(
                        f'no entity in ents can serve as a negative example '
                        f'for LandConnected({place}, {place2})'
                    )
                neg_place2 = random.choice(candidates)
                place_type = get_wikidata_types(place)
                place2_type = get_wikidata_types(place2)
                neg_type = get_wikidata_types(neg_place2)
                inexact_connected = inexact_matches['LandConnected']

                queries.append(Predicate('LandConnected', Constant(place, place_type), Constant(place2, place2_type)))
                queries.append(Predicate('LandConnected', Constant(place, place_type), Constant(neg_place2, neg_type)))
                queries_inexact.append(Predicate(inexact_connected, Constant(place, place_type), Constant(place2, place2_type)))
                queries_inexact.append(Predicate(inexact_connected, Constant(place, place_type), Constant(neg_place2, neg_type)))

                border = f'ShareLandBorders({country1}, {country2})'
                positive_premise = f'PartOf({place}, {country1}) && PartOf({place2}, {country2}) && {border}'
                negative_premise = f'PartOf({place}, {country1}) && PartOf({neg_place2}, {country2}) && {border}'
                positive_types = {place2: place2_type, place: place_type, country1: 'country', country2: 'country'}
                negative_types = {neg_place2: neg_type, place: place_type, country1: 'country', country2: 'country'}
                support = (triple, triple_new, triple_new2)
                for bucket, name in (
                    (queries_nl, 'LandConnected'),
                    (queries_nl_inexact, inexact_connected),
                ):
                    # The conclusion names (place, place2), the same pair
                    # the query itself asks about.
                    add_record(
                        bucket, f'{name}({place}, {place2})', 'True',
                        f'{positive_premise} => {name}({place}, {place2})',
                        support, positive_types,
                    )
                    add_record(
                        bucket, f'{name}({place}, {neg_place2})', 'False',
                        f'{negative_premise} => {name}({place}, {neg_place2})',
                        support, negative_types,
                    )

                query_counter += 2
                # Reaching either limit ends generation altogether.
                if query_counter >= no_queries_perkind or query_counter >= no_queries:
                    print('four step queries generated:', len(queries))
                    return result

    print('four step queries generated:', len(queries))
    return result


def generate_5step_queries(abox, inexact_matches, ents, no_queries=4, no_queries_perkind=2):
    """Generate ``NeedsBorderCrossing`` queries for attractions in bordering countries.

    The ABox is parsed once and indexed by predicate:

    * ``ShareLandBorders(c1, c2)`` facts give the (symmetric) neighbour lists,
    * ``PartOf(place, country)`` facts give the places of each country,
    * ``At(attraction, place)`` facts give the attractions of each place.

    For every ordered pair of bordering countries, every attraction found in
    the first country is paired with every attraction found in the second one,
    in ABox order.  Each pairing appends two entries to each of the four
    result lists: a positive query (label ``'True'``) and a negative query
    (label ``'False'``) whose second argument is an entity sampled at random
    from ``ents``.

    NOTE(review): the sampled negative entity is not checked against the ABox,
    so it may coincide with a genuinely valid second argument.
    NOTE(review): the axiom text writes the attraction/place premise as
    ``PartOf(attraction, place)`` although the matching facts use the ``At``
    predicate; the text is kept verbatim - confirm against the rule base.

    Args:
        abox: Re-iterable collection of fact strings of the form
            ``Predicate(arg1, arg2)``.
        inexact_matches: Mapping from exact predicate name to the name used in
            the "inexact" variants; only ``'NeedsBorderCrossing'`` is read.
        ents: Collection of entity names from which negatives are sampled.
        no_queries: Generation stops once this many pairings were produced.
        no_queries_perkind: Generation also stops once this many pairings were
            produced (both counters advance by one per pairing).

    Returns:
        Tuple ``(queries_nl, queries_nl_inexact, queries, queries_inexact)``;
        the first two are lists of dicts, the last two lists of ``Predicate``
        objects.

    Raises:
        IndexError: If a fact contains no ``'('``.
        ValueError: If a fact does not have exactly two comma-separated
            arguments, or if ``ents`` is empty when a negative is sampled.
    """
    queries = []
    queries_inexact = []
    queries_nl = []
    queries_nl_inexact = []
    query_counter = 0
    query_counter_perkind = 0

    neighbors_dict = {}          # country -> bordering countries (both directions)
    places_by_country = {}       # country -> [(place, PartOf fact)]
    attractions_by_place = {}    # place -> [(attraction, At fact)]

    # Single pass over the ABox; list order inside every index follows ABox
    # order so the pairings below come out in the same order as a nested scan.
    for fact in abox:
        predicate = fact.split('(')[0].strip()
        arg1, arg2 = fact.split('(')[1].split(')')[0].split(',')
        arg1 = arg1.strip()
        arg2 = arg2.strip()
        if predicate == 'ShareLandBorders':
            neighbors_dict.setdefault(arg1, []).append(arg2)
            neighbors_dict.setdefault(arg2, []).append(arg1)
        elif predicate == 'PartOf':
            places_by_country.setdefault(arg2, []).append((arg1, fact))
        elif predicate == 'At':
            attractions_by_place.setdefault(arg2, []).append((arg1, fact))

    def sites_in(country):
        """Yield (place, attraction, PartOf fact, At fact) for one country."""
        for place, part_of_fact in places_by_country.get(country, ()):
            for attraction, at_fact in attractions_by_place.get(place, ()):
                yield place, attraction, part_of_fact, at_fact

    for country1, bordering in neighbors_dict.items():
        for country2 in bordering:
            for place1, attraction1, triple, triple2 in sites_in(country1):
                for place2, attraction2, triple3, triple4 in sites_in(country2):
                    query_counter += 1
                    query_counter_perkind += 1
                    neg_attraction2 = random.sample(list(ents), 1)[0]

                    # Entity types are looked up once per pairing and reused.
                    type_a1 = get_wikidata_types(attraction1)
                    type_a2 = get_wikidata_types(attraction2)
                    type_neg = get_wikidata_types(neg_attraction2)
                    inexact_name = inexact_matches['NeedsBorderCrossing']
                    type_p1 = get_wikidata_types(place1)
                    type_p2 = get_wikidata_types(place2)

                    for bucket, predicate_name in (
                        (queries, 'NeedsBorderCrossing'),
                        (queries_inexact, inexact_name),
                    ):
                        bucket.append(Predicate(predicate_name, Constant(attraction1, type_a1), Constant(attraction2, type_a2)))
                        bucket.append(Predicate(predicate_name, Constant(attraction1, type_a1), Constant(neg_attraction2, type_neg)))

                    # Positive then negative record, first for the exact and
                    # then for the inexact predicate name.
                    for bucket, predicate_name, other, other_type, label in (
                        (queries_nl, 'NeedsBorderCrossing', attraction2, type_a2, 'True'),
                        (queries_nl, 'NeedsBorderCrossing', neg_attraction2, type_neg, 'False'),
                        (queries_nl_inexact, inexact_name, attraction2, type_a2, 'True'),
                        (queries_nl_inexact, inexact_name, neg_attraction2, type_neg, 'False'),
                    ):
                        bucket.append(
                            {
                                'query_id': f'{len(bucket)}',
                                'query': f'{predicate_name}({attraction1}, {other})',
                                'label': label,
                                'axiom': f'PartOf({attraction1}, {place1}) && PartOf({place1}, {country1}) && PartOf({other}, {place2}) && PartOf({place2}, {country2}) && ShareLandBorders({country1}, {country2}) => {predicate_name}({attraction1}, {other})',
                                'triples': [triple, triple2, triple3, triple4],
                                'types': {attraction1: type_a1, other: other_type, place1: type_p1, place2: type_p2, country1: 'country', country2: 'country'}
                            }
                        )

                    if query_counter_perkind >= no_queries_perkind or query_counter >= no_queries:
                        return queries_nl, queries_nl_inexact, queries, queries_inexact

    return queries_nl, queries_nl_inexact, queries, queries_inexact




def generate_6step_queries(abox, inexact_matches, ents, no_queries=4, no_queries_perkind=2):
    """Generate ``CanDriveBetween`` queries for places in land-connected countries.

    Every ``LocatedIn(place1, country1)`` fact whose country is one of the keys
    of the hard-coded ``neighbors`` table is paired with every
    ``LocatedIn(place2, country2)`` fact whose country is listed for
    ``country1`` in that table.  Country names are compared
    case-insensitively on both sides (the table itself is all lower case).
    Each pairing appends two entries to each of the four result lists: a
    positive query (label ``'True'``) and a negative query (label ``'False'``)
    whose second argument is an entity sampled at random from ``ents``.

    NOTE(review): the sampled negative entity is not checked against the ABox,
    so it may coincide with a genuinely reachable place.
    NOTE(review): ``'thailand'`` appears in its own neighbour list, so places
    inside Thailand are paired with each other (and with themselves).

    Args:
        abox: Re-iterable collection of fact strings of the form
            ``Predicate(arg1, arg2)``.
        inexact_matches: Mapping from exact predicate name to the name used in
            the "inexact" variants; only ``'CanDriveBetween'`` is read.
        ents: Collection of entity names from which negatives are sampled.
        no_queries: Generation stops once the query counter reaches this value.
        no_queries_perkind: Generation also stops once the per-kind counter
            reaches this value.  Both counters advance by two per pairing
            (one per generated query).

    Returns:
        Tuple ``(queries_nl, queries_nl_inexact, queries, queries_inexact)``;
        the first two are lists of dicts, the last two lists of ``Predicate``
        objects.

    Raises:
        IndexError: If a fact contains no ``'('``.
        ValueError: If a fact does not have exactly two comma-separated
            arguments, or if ``ents`` is empty when a negative is sampled.
    """
    neighbors = {'iran': ['turkmenistan', 'afghanistan', 'pakistan', 'azerbaijan', 'armenia', 'turkey', 'iraq', 'kuwait', 'syria', 'jordan', 'georgia', 'russia', 'india', 'china'],
                 'thailand': ['myanmar', 'laos', 'cambodia', 'malaysia', 'china', 'cambodia', 'myanmar', 'vietnam', 'thailand'],
                 'niger': ['mali', 'algeria', 'libya', 'chad', 'cameroon', 'benin', 'togo', 'burkina faso', 'mauritania']}
    queries = []
    queries_inexact = []
    queries_nl = []
    queries_nl_inexact = []
    query_counter = 0
    query_counter_perkind = 0

    # Parse the ABox once; only LocatedIn facts are relevant below.  Every fact
    # is still parsed, so malformed input is reported up front.
    located_in = []
    for fact in abox:
        predicate = fact.split('(')[0].strip()
        arg1, arg2 = fact.split('(')[1].split(')')[0].split(',')
        if predicate == 'LocatedIn':
            located_in.append((arg1.strip(), arg2.strip(), fact))

    for place1, country1, triple in located_in:
        reachable = neighbors.get(country1.lower())
        if reachable is None:
            continue

        for place2, country2, triple_new in located_in:
            # Lower-case the candidate as well: the table is lower case, so a
            # capitalised country name would otherwise never match.
            if country2.lower() not in reachable:
                continue

            query_counter_perkind += 2
            query_counter += 2
            neg_place2 = random.sample(list(ents), 1)[0]

            # Entity types are looked up once per pairing and reused.
            type_p1 = get_wikidata_types(place1)
            type_p2 = get_wikidata_types(place2)
            type_neg = get_wikidata_types(neg_place2)
            inexact_name = inexact_matches['CanDriveBetween']

            for bucket, predicate_name in (
                (queries, 'CanDriveBetween'),
                (queries_inexact, inexact_name),
            ):
                bucket.append(Predicate(predicate_name, Constant(place1, type_p1), Constant(place2, type_p2)))
                bucket.append(Predicate(predicate_name, Constant(place1, type_p1), Constant(neg_place2, type_neg)))

            # Positive then negative record, first for the exact and then for
            # the inexact predicate name.
            for bucket, predicate_name, other, other_type, label in (
                (queries_nl, 'CanDriveBetween', place2, type_p2, 'True'),
                (queries_nl, 'CanDriveBetween', neg_place2, type_neg, 'False'),
                (queries_nl_inexact, inexact_name, place2, type_p2, 'True'),
                (queries_nl_inexact, inexact_name, neg_place2, type_neg, 'False'),
            ):
                bucket.append(
                    {
                        'query_id': f'{len(bucket)}',
                        'query': f'{predicate_name}({place1}, {other})',
                        'label': label,
                        'axiom': f'LocatedIn({place1}, {country1}) && LocatedIn({other}, {country2}) && LandConnected({country1}, {country2}) => {predicate_name}({place1}, {other})',
                        'triples': [triple, triple_new],
                        'types': {place1: type_p1, other: other_type, country1: 'country', country2: 'country'}
                    }
                )

            if query_counter_perkind >= no_queries_perkind or query_counter >= no_queries:
                return queries_nl, queries_nl_inexact, queries, queries_inexact

    return queries_nl, queries_nl_inexact, queries, queries_inexact
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Stage 1: for every stored inexact query, record which exact
    # predicate its name belongs to and write the mapping to disk.
    # This reads 'data/queries_geo_inexact.pkl' as it exists *before* the
    # 'extension' branch below rewrites that file.
    # NOTE(review): pickle.load can execute code embedded in the file;
    # only run this against trusted data files.
    # ------------------------------------------------------------------
    query_preds = []

    with open('data/queries_geo_inexact.pkl', 'rb') as pkl_in:
        all_queries_inexact = pickle.load(pkl_in)
    with open('data/queries_geo_nl.json') as nl_in:
        all_queries_nl = json.load(nl_in)
    with open('data/geo_inexact_matches.json') as matches_in:
        inexact_matches_dict = json.load(matches_in)

    for query in all_queries_inexact:
        name = query.name
        # First exact predicate (in file order) whose entry contains the name.
        correct = next(
            (exact for exact, variants in inexact_matches_dict.items() if name in variants),
            None,
        )
        if correct is not None:
            query_preds.append({'predicate': name, 'correct_match': correct})

    with open('data/geo_pred_matches.json', 'w') as preds_out:
        json.dump(query_preds, preds_out, indent=4, ensure_ascii=False)

    # ------------------------------------------------------------------
    # Stage 2: either generate fresh queries from the ABox ('generation')
    # or derive inexact variants of already stored queries ('extension').
    # ------------------------------------------------------------------
    mode = 'extension'

    if mode == 'generation':
        with open('data/geo_inexact_matches.json') as matches_in:
            inexact_matches = json.load(matches_in)

        with open('data/abox.json') as abox_in:
            abox = json.load(abox_in)

        # Every argument of every fact is a candidate entity for negatives.
        ents = set()
        for triple in abox:
            first, second = triple.split('(')[1].split(')')[0].split(',')
            ents.add(first.strip())
            ents.add(second.strip())

        # The 1-, 3- and 5-step generators are currently switched off (they
        # were last invoked with no_queries=10, no_queries_perkind=10); only
        # the 2-, 4- and 6-step results are combined below.
        (two_step_queries_nl, two_step_queries_nl_inexact,
         two_step_queries, two_step_queries_inexact) = generate_2step_queries(
            abox, inexact_matches, ents, no_queries=800, no_queries_perkind=700)
        (four_step_queries_nl, four_step_queries_nl_inexact,
         four_step_queries, four_step_queries_inexact) = generate_4step_queries(
            abox, inexact_matches, ents, no_queries=100, no_queries_perkind=100)
        (six_step_queries_nl, six_step_queries_nl_inexact,
         six_step_queries, six_step_queries_inexact) = generate_6step_queries(
            abox, inexact_matches, ents, no_queries=100, no_queries_perkind=100)

        all_queries = two_step_queries + four_step_queries + six_step_queries
        all_queries_nl = two_step_queries_nl + four_step_queries_nl + six_step_queries_nl

        with open('data/queries.pkl', 'wb') as pkl_out:
            pickle.dump(all_queries, pkl_out)

        with open('data/queries_nl.json', 'w') as nl_out:
            json.dump(all_queries_nl, nl_out, indent=4, ensure_ascii=False)

    elif mode == 'extension':
        with open('data/queries_geo.pkl', 'rb') as pkl_in:
            all_queries = pickle.load(pkl_in)

        with open('data/geo_inexact_matches.json') as matches_in:
            inexact_matches_dict = json.load(matches_in)

        all_queries_inexact = []
        for query in all_queries:
            candidates = inexact_matches_dict[query.name]
            if not isinstance(candidates, list):
                # No list of alternatives for this predicate: keep the query as is.
                all_queries_inexact.append(query)
                continue
            # Swap in one randomly chosen alternative name; the first two
            # arguments of the query are carried over.
            new_predicate = random.sample(candidates, 1)[0]
            all_queries_inexact.append(Predicate(new_predicate, query.args[0], query.args[1]))

        with open('data/queries_geo_inexact.pkl', 'wb') as pkl_out:
            pickle.dump(all_queries_inexact, pkl_out)

