from utils.logic.fol import *
from utils.wikidata_types import *
import pickle, json, random
import copy


def parse_triple(triple):
    """Split a ``Predicate(arg1, arg2)`` or ``Predicate(arg1)`` string.

    The predicate is the text before the first ``(``. The arguments are the
    text between that ``(`` and the LAST ``)``, so entity names that contain
    parentheses themselves (e.g. ``PartOf(Georgia (country), Asia)``) stay
    intact. Every returned part is whitespace-stripped.

    Args:
        triple: the triple string.

    Returns:
        ``(predicate, arg1, arg2)``. ``arg2`` is the string ``'None'`` for a
        unary predicate such as ``GramPositive(x)``.

    Raises:
        ValueError: if there is no ``(``, or if the argument list contains
            more than one comma (the split would be ambiguous).
    """
    predicate, paren, rest = triple.partition('(')
    if not paren:
        raise ValueError(f'not a triple (missing "("): {triple!r}')
    close = rest.rfind(')')
    inner = rest if close == -1 else rest[:close]

    n_commas = inner.count(',')
    if n_commas == 1:
        arg1, arg2 = inner.split(',')
        return predicate.strip(), arg1.strip(), arg2.strip()
    if n_commas == 0:
        return predicate.strip(), inner.strip(), 'None'
    # Previously this case fell through and failed with UnboundLocalError.
    raise ValueError(f'cannot split arguments of {triple!r}: expected at most one comma')


def generate_animals_queries(abox,inexact_matches, ents, max_queries_perkind= 100):
    """Generate labelled animal / bacteria reasoning queries from ``abox``.

    Four query families are mined in turn. Each matching fact yields a
    positive query (label ``'True'``) and, when a suitable counter-example is
    found, a negative query (label ``'False'``):

    1. ``PreysOn(b, x)`` with ``b`` typed bird  ->  ``HuntsFromAir(b, x)``.
    2. ``MainFoodSource(t, x)`` with ``t`` typed taxon  ->  ``ReliesOnToSurvive(t, x)``.
    3. ``GramPositive(t)`` + ``SubClassOf(b, t)`` + each ``TreatsGramNegative(a, ...)``
       ->  ``CannotBeUsedToTreat(b, a)`` (negative: ``CanBeUsedToTreat``).
    4. ``GramNegative(t)`` + ``SubClassOf(b, t)`` + each ``TreatsGramNegative(a, ...)``
       ->  ``Disinfects(b, a)`` (negative: ``DoesNotDisinfect``).

    Family ``k`` stops once the running total of queries reaches
    ``k * max_queries_perkind``. ``max_queries_perkind`` queries are then
    sampled uniformly from everything collected.

    Args:
        abox: list of triple strings understood by ``parse_triple``.
        inexact_matches: unused; kept so all generators share one signature.
        ents: unused; kept so all generators share one signature.
        max_queries_perkind: number of queries returned, and the unit of the
            per-family collection caps.

    Returns:
        ``(queries_nl, queries_nl_inexact, queries, queries_inexact)``. The
        inexact lists are always empty here. ``queries[i]`` (a ``Predicate``)
        and ``queries_nl[i]`` (a dict) describe the same query.
    """
    queries = []; queries_inexact = []
    queries_nl = []; queries_nl_inexact = []
    # abox is a list; a set makes the "fact is absent" checks O(1).
    abox_set = set(abox)

    # Family 1: PreysOn(bird, x) => HuntsFromAir(bird, x)
    for triple in abox:
        predicate, arg1, arg2 = parse_triple(triple)
        if predicate == 'PreysOn' and get_wikidata_types(arg1) == 'bird':
            query = Predicate('HuntsFromAir', Constant(arg1, 'bird'), Constant(arg2, 'animal'))
            # Scan the facts in random order so different birds get different negatives.
            candidates = list(abox)
            random.shuffle(candidates)
            neg_query = None
            for other_triple in candidates:
                _, other_arg1, _ = parse_triple(other_triple)
                # A negative animal must not be backed by a PreysOn fact (the rule would
                # then make it True), nor be a stated HuntsFromAir fact.
                if (get_wikidata_types(other_arg1) == 'animal'
                        and f'PreysOn({arg1}, {other_arg1})' not in abox_set
                        and f'HuntsFromAir({arg1}, {other_arg1})' not in abox_set):
                    neg_query = Predicate('HuntsFromAir', Constant(arg1, 'bird'), Constant(other_arg1, 'animal'))
                    break
            queries.append(query)
            queries_nl.append(
                {
                    'query_id': f'{(len(queries_nl))}',
                    'query': f'HuntsFromAir({arg1}, {arg2})',
                    'label': 'True',
                    'axiom': f'PreysOn({arg1}, {arg2}) => HuntsFromAir({arg1}, {arg2})',
                    'triples': [triple],
                    'types': {arg1: 'bird', arg2: 'animal'}
                })
            # Without a valid counter-example emit only the positive, instead of
            # reusing a stale negative from an earlier iteration.
            if neg_query is not None:
                queries.append(neg_query)
                queries_nl.append(
                    {
                        'query_id': f'{(len(queries_nl))}',
                        'query': f'HuntsFromAir({arg1}, {other_arg1})',
                        'label': 'False',
                        'axiom': f'PreysOn({arg1}, {arg2}) => HuntsFromAir({arg1}, {other_arg1})',
                        'triples': [triple, other_triple],
                        'types': {arg1: 'bird', other_arg1: 'animal'}
                    })
            if len(queries) % 10 == 0:
                print(len(queries))
        if len(queries) >= max_queries_perkind:
            break

    # Family 2: MainFoodSource(taxon, x) => ReliesOnToSurvive(taxon, x)
    for triple in abox:
        predicate, arg1, arg2 = parse_triple(triple)
        if predicate == 'MainFoodSource' and get_wikidata_types(arg1) == 'taxon':
            query = Predicate('ReliesOnToSurvive', Constant(arg1, 'taxon'), Constant(arg2, 'food source'))
            neg_query = None
            for other_triple in abox:
                _, other_arg1, _ = parse_triple(other_triple)
                if get_wikidata_types(other_arg1) == 'animal' and f'MainFoodSource({arg1}, {other_arg1})' not in abox_set:
                    # NOTE(review): this Constant is typed 'animal', while the NL record
                    # below types the same entity as 'food source'. Confirm which is intended.
                    neg_query = Predicate('ReliesOnToSurvive', Constant(arg1, 'taxon'), Constant(other_arg1, 'animal'))
                    break
            queries.append(query)
            queries_nl.append(
                {
                    'query_id': f'{(len(queries_nl))}',
                    'query': f'ReliesOnToSurvive({arg1}, {arg2})',
                    'label': 'True',
                    'axiom': f'MainFoodSource({arg1}, {arg2}) => ReliesOnToSurvive({arg1}, {arg2})',
                    'triples': [triple],
                    'types': {arg1: 'taxon', arg2: 'food source'}
                })
            if neg_query is not None:
                queries.append(neg_query)
                queries_nl.append(
                    {
                        'query_id': f'{(len(queries_nl))}',
                        'query': f'ReliesOnToSurvive({arg1}, {other_arg1})',
                        'label': 'False',
                        'axiom': f'MainFoodSource({arg1}, {arg2}) => ReliesOnToSurvive({arg1}, {other_arg1})',
                        'triples': [triple, other_triple],
                        'types': {arg1: 'taxon', other_arg1: 'food source'}
                    })
            # Report progress only when queries were added (previously ran on every triple).
            if len(queries) % 10 == 0:
                print(len(queries))
        if len(queries) >= 2*max_queries_perkind:
            break

    # Family 3: GramPositive(t) && SubClassOf(b, t) => CannotBeUsedToTreat(b, a)
    for triple in abox:
        predicate, arg1, _ = parse_triple(triple)
        if predicate == 'GramPositive':
            taxon = arg1
            for new_triple in abox:
                new_predicate, new_arg1, new_arg2 = parse_triple(new_triple)
                if new_predicate == 'SubClassOf' and new_arg2 == taxon:
                    bacteria = new_arg1
                    break
            else:
                # No subclass of this taxon: skip it rather than reuse a stale bacterium.
                continue
            for other_triple in abox:
                other_predicate, other_arg1, _ = parse_triple(other_triple)
                if other_predicate == 'TreatsGramNegative':
                    antibiotic = other_arg1
                    query = Predicate('CannotBeUsedToTreat', Constant(bacteria, 'bacteria'), Constant(antibiotic, 'anti biotic'))
                    queries.append(query)
                    neg_query = Predicate('CanBeUsedToTreat', Constant(bacteria, 'bacteria'), Constant(antibiotic, 'anti biotic'))
                    queries.append(neg_query)
                    if len(queries) % 10 == 0:
                        print(len(queries))
                    queries_nl.append(
                        {
                            'query_id': f'{(len(queries_nl))}',
                            'query': f'CannotBeUsedToTreat({bacteria}, {antibiotic})',
                            'label': 'True',
                            'axiom': f'GramPositive({taxon}) && SubClassOf({bacteria}, {taxon}) => CannotBeUsedToTreat({bacteria}, {antibiotic})',
                            'triples': [triple, new_triple, other_triple],
                            'types': {bacteria: 'bacteria', antibiotic: 'anti biotic'}
                        })
                    queries_nl.append(
                        {
                            'query_id': f'{(len(queries_nl))}',
                            'query': f'CanBeUsedToTreat({bacteria}, {antibiotic})',
                            'label': 'False',
                            'axiom': f'GramPositive({taxon}) && SubClassOf({bacteria}, {taxon}) => CannotBeUsedToTreat({bacteria}, {antibiotic})',
                            'triples': [triple, new_triple, other_triple],
                            'types': {bacteria: 'bacteria', antibiotic: 'anti biotic'}
                        })
                    if len(queries) >= 3* max_queries_perkind:
                        break
        if len(queries) >= 3*max_queries_perkind:
            break

    # Family 4: GramNegative(t) && SubClassOf(b, t) => Disinfects(b, a)
    for triple in abox:
        predicate, arg1, _ = parse_triple(triple)
        if predicate == 'GramNegative':
            taxon = arg1
            for new_triple in abox:
                new_predicate, new_arg1, new_arg2 = parse_triple(new_triple)
                if new_predicate == 'SubClassOf' and new_arg2 == taxon:
                    bacteria = new_arg1
                    break
            else:
                # No subclass of this taxon: skip it rather than reuse a stale bacterium.
                continue
            for other_triple in abox:
                other_predicate, other_arg1, _ = parse_triple(other_triple)
                if other_predicate == 'TreatsGramNegative':
                    antibiotic = other_arg1
                    query = Predicate('Disinfects', Constant(bacteria, 'bacteria'), Constant(antibiotic, 'anti biotic'))
                    queries.append(query)
                    neg_query = Predicate('DoesNotDisinfect', Constant(bacteria, 'bacteria'), Constant(antibiotic, 'anti biotic'))
                    queries.append(neg_query)
                    if len(queries) % 10 == 0:
                        print(len(queries))
                    queries_nl.append(
                        {
                            'query_id': f'{(len(queries_nl))}',
                            'query': f'Disinfects({bacteria}, {antibiotic})',
                            'label': 'True',
                            'axiom': f'GramNegative({taxon}) && SubClassOf({bacteria}, {taxon}) => Disinfects({bacteria}, {antibiotic})',
                            'triples': [triple, new_triple, other_triple],
                            'types': {bacteria: 'bacteria', antibiotic: 'anti biotic'}
                        })
                    queries_nl.append(
                        {
                            'query_id': f'{(len(queries_nl))}',
                            'query': f'DoesNotDisinfect({bacteria}, {antibiotic})',
                            'label': 'False',
                            'axiom': f'GramNegative({taxon}) && SubClassOf({bacteria}, {taxon}) => Disinfects({bacteria}, {antibiotic})',
                            'triples': [triple, new_triple, other_triple],
                            'types': {bacteria: 'bacteria', antibiotic: 'anti biotic'}
                        })
                    if len(queries) >= 4*max_queries_perkind:
                        break
        if len(queries) >= 4*max_queries_perkind:
            break

    # queries and queries_nl are appended in lockstep, so one index sample selects both.
    selected_indices = random.sample(range(len(queries)), min(len(queries), max_queries_perkind))
    queries = [queries[i] for i in selected_indices]
    queries_nl = [queries_nl[i] for i in selected_indices]
    print("animal queries generated")

    return queries_nl, queries_nl_inexact, queries, queries_inexact


def generate_foods_queries(abox, inexact_matches, ents, max_queries_perkind = 100):
    """Generate labelled ``ConsumedByPeopleIn`` queries from food-origin facts.

    Each pair of ``OriginatesFrom(food, country)`` and ``PartOf(city, country)``
    facts yields a positive ``ConsumedByPeopleIn(food, city)`` query. When one
    exists, it is paired with a negative ``OriginatesFrom(food, other_country)``
    query, where ``other_country`` is any ``OriginatesFrom`` object typed
    ``'country'`` whose fact is absent from ``abox``.

    Collection stops once ``5 * max_queries_perkind`` queries exist; then
    ``max_queries_perkind`` queries are sampled uniformly.

    Args:
        abox: list of triple strings understood by ``parse_triple``.
        inexact_matches: unused; kept so all generators share one signature.
        ents: unused; kept so all generators share one signature.
        max_queries_perkind: number of queries returned.

    Returns:
        ``(queries_nl, queries_nl_inexact, queries, queries_inexact)``. The
        inexact lists are always empty here. ``queries[i]`` and
        ``queries_nl[i]`` describe the same query.
    """
    queries = []; queries_inexact = []
    queries_nl = []; queries_nl_inexact = []
    # abox is a list; a set makes the "fact is absent" check O(1).
    abox_set = set(abox)

    # Candidate negative countries: every object of an OriginatesFrom fact.
    countries_set = set()
    for triple in abox:
        pred, _, arg2 = parse_triple(triple)
        if pred == 'OriginatesFrom':
            countries_set.add(arg2)

    for triple in abox:
        predicate, arg1, arg2 = parse_triple(triple)
        if predicate == 'OriginatesFrom':
            food = arg1; country = arg2
            for other_triple in abox:
                other_predicate, other_arg1, other_arg2 = parse_triple(other_triple)
                if other_predicate != 'PartOf' or other_arg2 != country:
                    continue
                city = other_arg1
                query = Predicate('ConsumedByPeopleIn', Constant(food, 'food'), Constant(city, 'city'))
                queries.append(query)
                queries_nl.append(
                    {
                        'query_id': f'{(len(queries_nl))}',
                        'query': f'ConsumedByPeopleIn({food}, {city})',
                        'label': 'True',
                        # Consequent argument order now matches the query: (food, city).
                        'axiom': f'OriginatesFrom({food}, {country}) && PartOf({city}, {country}) => ConsumedByPeopleIn({food}, {city})',
                        'triples': [triple, other_triple],
                        'types': {food: 'food', country: 'country', city: 'city'}
                    })

                neg_query = None
                for other_country in countries_set:
                    if f'OriginatesFrom({food}, {other_country})' not in abox_set and get_wikidata_types(other_country) == 'country':
                        neg_query = Predicate('OriginatesFrom', Constant(food, 'food'), Constant(other_country, 'country'))
                        break
                # Without a valid counter-example emit only the positive, instead of
                # reusing a stale negative from an earlier iteration.
                if neg_query is not None:
                    queries.append(neg_query)
                    queries_nl.append(
                        {
                            'query_id': f'{(len(queries_nl))}',
                            'query': f'OriginatesFrom({food}, {other_country})',
                            'label': 'False',
                            'axiom': f'OriginatesFrom({food}, {country}) => ConsumedByPeopleIn({food}, {country})',
                            'triples': [triple],
                            # Type the country actually named in the query (previously the original country).
                            'types': {food: 'food', other_country: 'country'}
                        })
                if len(queries) % 10 == 0:
                    print(len(queries))
        if len(queries) >= 5* max_queries_perkind:
            break

    # queries and queries_nl are appended in lockstep, so one index sample selects both.
    selected_indices = random.sample(range(len(queries)), min(len(queries), max_queries_perkind))
    queries = [queries[i] for i in selected_indices]
    queries_nl = [queries_nl[i] for i in selected_indices]
    print("foods queries generated")
    return queries_nl, queries_nl_inexact, queries, queries_inexact

def generate_vehicle_queries(abox, inexact_matches, ents, max_queries_perkind = 100):
    """Generate labelled ``UsedIn`` queries from manufacturer locations.

    For every ``ManufacturedBy(vehicle, company)`` and every
    ``LocatedIn(company, country)``, a positive ``UsedIn(vehicle, country)``
    query is produced. It is followed by a negative query for the first
    ``other_country`` (drawn from the objects of ``OriginatesFrom`` facts)
    for which ``LocatedIn(company, other_country)`` is absent.

    Collection stops once ``5 * max_queries_perkind`` queries exist; then
    ``max_queries_perkind`` queries are sampled uniformly.

    Args:
        abox: list of triple strings understood by ``parse_triple``.
        inexact_matches: unused; kept so all generators share one signature.
        ents: unused; kept so all generators share one signature.
        max_queries_perkind: number of queries returned.

    Returns:
        ``(queries_nl, queries_nl_inexact, queries, queries_inexact)``. The
        inexact lists are always empty here. ``queries[i]`` and
        ``queries_nl[i]`` describe the same query.
    """
    queries = []; queries_inexact = []
    queries_nl = []; queries_nl_inexact = []
    # abox is a list; this membership test runs in a triple-nested loop, so use a set.
    abox_set = set(abox)

    # Candidate negative countries: every object of an OriginatesFrom fact.
    countries_set = set()
    for triple in abox:
        pred, _, arg2 = parse_triple(triple)
        if pred == 'OriginatesFrom':
            countries_set.add(arg2)

    for triple in abox:
        predicate, arg1, arg2 = parse_triple(triple)
        if predicate == 'ManufacturedBy':
            vehicle = arg1; company = arg2
            for other_triple in abox:
                other_predicate, other_arg1, other_arg2 = parse_triple(other_triple)
                if other_predicate == 'LocatedIn' and other_arg1 == company:
                    country = other_arg2
                    query = Predicate('UsedIn', Constant(vehicle, 'vehicle'), Constant(country, 'country'))
                    queries.append(query)
                    if len(queries) % 10 == 0:
                        print(len(queries))
                    queries_nl.append(
                        {
                            'query_id': f'{(len(queries_nl))}',
                            'query': f'UsedIn({vehicle}, {country})',
                            'label': 'True',
                            'axiom': f'ManufacturedBy({vehicle}, {company}) && LocatedIn({company}, {country}) => UsedIn({vehicle}, {country})',
                            'triples': [triple, other_triple],
                            'types': {vehicle: 'vehicle', company: 'company', country: 'country'}
                        })
                    for other_country in countries_set:
                        if f'LocatedIn({company}, {other_country})' not in abox_set:
                            # NOTE(review): the negative uses OriginatesFrom while its axiom
                            # concludes UsedIn. Confirm the intended predicate.
                            neg_query = Predicate('OriginatesFrom', Constant(vehicle, 'vehicle'), Constant(other_country, 'country'))
                            queries.append(neg_query)
                            queries_nl.append(
                                {
                                    'query_id': f'{(len(queries_nl))}',
                                    'query': f'OriginatesFrom({vehicle}, {other_country})',
                                    'label': 'False',
                                    'axiom': f'ManufacturedBy({vehicle}, {company}) && LocatedIn({company}, {country}) => UsedIn({vehicle}, {other_country})',
                                    'triples': [triple, other_triple],
                                    'types': {vehicle: 'vehicle', company: 'company', other_country: 'country'}
                                })
                            break

            if len(queries) >= 5* max_queries_perkind:
                break

    # queries and queries_nl are appended in lockstep, so one index sample selects both.
    selected_indices = random.sample(range(len(queries)), min(len(queries), max_queries_perkind))
    queries = [queries[i] for i in selected_indices]
    queries_nl = [queries_nl[i] for i in selected_indices]
    print("vehicles queries generated")
    return queries_nl, queries_nl_inexact, queries, queries_inexact


def generate_medical_queries(abox, inexact_matches, ents, max_queries_perkind = 100):
    """Generate labelled ``CanBeUsedInterchangeably`` queries for medications.

    Two distinct medications with ``Treats`` facts for the same disease give a
    positive query. For each such pair, a negative query pairs the first
    medication with the first entity encountered in ``ents`` that has no
    ``Treats(entity, disease)`` fact and is typed ``'chemical entity'``.

    Collection stops once ``5 * max_queries_perkind`` queries exist; then
    ``max_queries_perkind`` queries are sampled uniformly.

    Args:
        abox: list of triple strings understood by ``parse_triple``.
        inexact_matches: unused; kept so all generators share one signature.
        ents: iterable of entity names searched for negative partners.
        max_queries_perkind: number of queries returned.

    Returns:
        ``(queries_nl, queries_nl_inexact, queries, queries_inexact)``. The
        inexact lists are always empty here. ``queries[i]`` and
        ``queries_nl[i]`` describe the same query.
    """
    queries = []; queries_inexact = []
    queries_nl = []; queries_nl_inexact = []
    # abox is a list; this membership test runs inside a loop over all ents, so use a set.
    abox_set = set(abox)

    for triple in abox:
        predicate, arg1, arg2 = parse_triple(triple)
        if predicate == 'Treats':
            medication1 = arg1; disease = arg2
            for other_triple in abox:
                other_predicate, other_arg1, other_arg2 = parse_triple(other_triple)
                if other_predicate == 'Treats' and other_arg1 != medication1 and other_arg2 == disease:
                    medication2 = other_arg1
                    query = Predicate('CanBeUsedInterchangeably', Constant(medication1, 'chemical entity'), Constant(medication2, 'chemical entity'))
                    queries.append(query)
                    if len(queries) % 10 == 0 or len(queries) % 9 == 0:
                        print(len(queries))
                    queries_nl.append(
                        {
                            'query_id': f'{(len(queries_nl))}',
                            'query': f'CanBeUsedInterchangeably({medication1}, {medication2})',
                            'label': 'True',
                            'axiom': f'Treats({medication1}, {disease}) && Treats({medication2}, {disease}) => CanBeUsedInterchangeably({medication1}, {medication2})',
                            'triples': [triple, other_triple],
                            'types': {medication1: 'chemical entity', medication2: 'chemical entity', disease: 'disease'}
                        })
                    for ent in ents:
                        # Interpolate the disease variable; the old check tested the literal
                        # text 'Treats(<ent>, disease)' and therefore never excluded anything.
                        # The cheap membership test runs before the type lookup.
                        if f'Treats({ent}, {disease})' not in abox_set and get_wikidata_types(ent) == 'chemical entity':
                            neg_query = Predicate('CanBeUsedInterchangeably', Constant(medication1, 'chemical entity'), Constant(ent, 'chemical entity'))
                            queries.append(neg_query)
                            queries_nl.append(
                                {
                                    'query_id': f'{(len(queries_nl))}',
                                    'query': f'CanBeUsedInterchangeably({medication1}, {ent})',
                                    'label': 'False',
                                    'axiom': f'Treats({medication1}, {disease}) && Treats({ent}, {disease}) => CanBeUsedInterchangeably({medication1}, {ent})',
                                    'triples': [triple, other_triple],
                                    'types': {medication1: 'chemical entity', ent: 'chemical entity', disease: 'disease'}
                                })
                            break
        if len(queries) >= 5* max_queries_perkind:
            break

    # queries and queries_nl are appended in lockstep, so one index sample selects both.
    selected_indices = random.sample(range(len(queries)), min(len(queries), max_queries_perkind))
    queries = [queries[i] for i in selected_indices]
    queries_nl = [queries_nl[i] for i in selected_indices]
    print("medical queries generated")
    return queries_nl, queries_nl_inexact, queries, queries_inexact

def generate_sports_queries(abox, inexact_matches, ents, max_queries_perkind = 100):
    """Generate labelled ``PaidTaxesIn`` queries for team members.

    For every ``MemberOf(player, team)`` and every ``LocatedIn(team, country)``,
    a positive ``PaidTaxesIn(player, country)`` query is produced. If the team
    has at least one location, one negative query is added with a randomly
    chosen ``LocatedIn`` object that is not among the team's countries.

    Collection stops once ``5 * max_queries_perkind`` queries exist; then
    ``max_queries_perkind`` queries are sampled uniformly.

    Args:
        abox: list of triple strings understood by ``parse_triple``.
        inexact_matches: unused; kept so all generators share one signature.
        ents: unused; kept so all generators share one signature.
        max_queries_perkind: number of queries returned.

    Returns:
        ``(queries_nl, queries_nl_inexact, queries, queries_inexact)``. The
        inexact lists are always empty here. ``queries[i]`` and
        ``queries_nl[i]`` describe the same query.
    """
    queries = []; queries_inexact = []
    queries_nl = []; queries_nl_inexact = []

    # Candidate negative countries: every object of a LocatedIn fact.
    countries_set = set()
    for triple in abox:
        pred, _, arg2 = parse_triple(triple)
        if pred == 'LocatedIn':
            countries_set.add(arg2)

    for triple in abox:
        predicate, arg1, arg2 = parse_triple(triple)
        if predicate == 'MemberOf':
            player = arg1; team = arg2
            player_countries = set()
            for other_triple in abox:
                predicate2, arg1_2, arg2_2 = parse_triple(other_triple)
                if predicate2 == 'LocatedIn' and arg1_2 == team:
                    player_countries.add(arg2_2)
                    country = arg2_2
                    query = Predicate('PaidTaxesIn', Constant(player, 'human'), Constant(country, 'country'))
                    queries.append(query)
                    queries_nl.append(
                        {
                            'query_id': f'{(len(queries_nl))}',
                            'query': f'PaidTaxesIn({player}, {country})',
                            'label': 'True',
                            'axiom': f'MemberOf({player}, {team}) && LocatedIn({team}, {country}) => PaidTaxesIn({player}, {country})',
                            'triples': [triple, other_triple],
                            'types': {player: 'human', team: 'sport club', country: 'country'}
                        })
            neg_countries = list(countries_set - player_countries)
            # The negative's axiom names one of the team's countries, so it needs at least
            # one (otherwise `country` was unbound or stale from a previous player). It also
            # needs a country to pick (random.sample on an empty list raised ValueError).
            if player_countries and neg_countries:
                neg_country = random.choice(neg_countries)
                neg_query = Predicate('PaidTaxesIn', Constant(player, 'human'), Constant(neg_country, 'country'))
                queries.append(neg_query)
                queries_nl.append(
                    {
                        'query_id': f'{(len(queries_nl))}',
                        'query': f'PaidTaxesIn({player}, {neg_country})',
                        'label': 'False',
                        'axiom': f'MemberOf({player}, {team}) && LocatedIn({team}, {country}) => PaidTaxesIn({player}, {neg_country})',
                        'triples': [triple],
                        'types': {player: 'human', team: 'sport club', neg_country: 'country'}
                    })
            if len(queries) % 10 == 0:
                print(len(queries))
        if len(queries) >= 5* max_queries_perkind:
            break

    # queries and queries_nl are appended in lockstep, so one index sample selects both.
    selected_indices = random.sample(range(len(queries)), min(len(queries), max_queries_perkind))
    queries = [queries[i] for i in selected_indices]
    queries_nl = [queries_nl[i] for i in selected_indices]
    print("sports queries generated")
    return queries_nl, queries_nl_inexact, queries, queries_inexact




if __name__ == "__main__":

    # query_preds = []

    # with open('data/queries_onto_inexact.pkl', 'rb') as f:
    #     all_queries_inexact = pickle.load(f)
    # with open('data/queries_onto_nl.json', 'r') as f:
    #     all_queries_nl = json.load(f)
    # with open('data/onto_inexact_matches.json', 'r') as f:
    #         inexact_matches_dict = json.load(f)
    
    # exact_names = list(inexact_matches_dict.keys())

    # for i, query in enumerate(all_queries_inexact):

    #     name = query.name
    #     for exact_name in exact_names:
    #         if name in inexact_matches_dict[exact_name]:
    #             query_preds.append({'predicate': name, 'correct_match': exact_name})
    #             break

    # with open('data/onto_pred_matches.json', 'w') as f:
    #     json.dump(query_preds, f, indent=4, ensure_ascii=False)


    # "generation": build the exact query datasets from the A-box.
    # "extension":  derive inexact-predicate variants of previously generated queries.
    mode = 'extension'

    if mode == "generation":

        inexact_matches = {
            'CanDriveBetween': 'CanTravelByCarBetween',
            'NeedsBorderCrossing': 'RequiresCrossingCountryBorders',
            'LandConnected': 'ConnectedThroughLand',
            'PartOf': 'IsARegionOf',
            'LocatedIn': 'PositionedIn',
            'At': 'PositionedAt',
            'CanSkiIn': 'HasSkiingOpportunities',
            'CanVisitMuseum': 'HasMuseumVisitingOpportunities',
            'CanSailIn': 'HasSailingOpportunities',
            'CanCampIn': 'HasCampingOpportunities'
        }

        with open('data/a-box-onto.json', 'r', encoding='utf-8') as f:
            abox = json.load(f)

        # Collect every entity mentioned in the A-box and set aside unparseable triples.
        ents = set()
        problematic_triples = []
        for triple in abox:
            try:
                predicate, arg1, arg2 = parse_triple(triple)
            except (IndexError, ValueError, NameError, AttributeError):
                # Malformed triple (missing parentheses, too many commas, non-string entry).
                problematic_triples.append(triple)
                continue
            ents.add(arg1.strip())
            # 'None' is parse_triple's placeholder for unary predicates, not an entity.
            # (The previous inverted test added 'None' and skipped every real arg2.)
            if arg2 != 'None':
                ents.add(arg2.strip())
        ents.discard('')
        for triple in problematic_triples:
            abox.remove(triple)

        animals_queries_nl, animals_queries_nl_inexact, animals_queries, animals_queries_inexact = generate_animals_queries(abox, inexact_matches, ents,  max_queries_perkind=200)
        foods_queries_nl, foods_queries_nl_inexact, foods_queries, foods_queries_inexact = generate_foods_queries(abox, inexact_matches, ents,  max_queries_perkind=200)
        vehicles_queries_nl, vehicles_queries_nl_inexact, vehicles_queries, vehicles_queries_inexact = generate_vehicle_queries(abox, inexact_matches, ents, max_queries_perkind=200)
        medical_queries_nl, medical_queries_nl_inexact, medical_queries, medical_queries_inexact = generate_medical_queries(abox, inexact_matches, ents, max_queries_perkind=200)
        sports_queries_nl, sports_queries_nl_inexact, sports_queries, sports_queries_inexact = generate_sports_queries(abox, inexact_matches, ents, max_queries_perkind=200)
        all_queries = animals_queries + foods_queries + vehicles_queries + medical_queries + sports_queries
        all_queries_nl = animals_queries_nl + foods_queries_nl + vehicles_queries_nl + medical_queries_nl + sports_queries_nl

        # Each generator numbers its records from 0 (and sampling shuffles them), so
        # renumber globally. Doing it in memory replaces the old write/read/rewrite round trip.
        for i, query in enumerate(all_queries_nl):
            query['query_id'] = f'{i}'

        with open('data/queries_onto.pkl', 'wb') as f:
            pickle.dump(all_queries, f)

        # ensure_ascii=False writes non-ASCII text, so pin the file encoding.
        with open('data/queries_onto_nl.json', 'w', encoding='utf-8') as f:
            json.dump(all_queries_nl, f, indent=4, ensure_ascii=False)

        # with open('data/queries_onto_inexact.pkl', 'wb') as f:
        #     pickle.dump(all_queries_inexact, f)

        # with open('data/queries_onto_inexact_nl.json', 'w') as f:
        #     json.dump(all_queries_nl_inexact, f, indent=4, ensure_ascii=False)

    elif mode == "extension":

        # NOTE: pickle.load can execute arbitrary code; only load files this script produced.
        with open('data/queries_onto.pkl', 'rb') as f:
            all_queries = pickle.load(f)
        with open('data/onto_inexact_matches.json', 'r', encoding='utf-8') as f:
            inexact_matches_dict = json.load(f)

        # Swap each query's predicate for a randomly chosen inexact synonym. Every query
        # yields exactly one output, so positions stay aligned with queries_onto_nl.json
        # (previously, queries whose mapping was not a list were silently dropped).
        all_queries_inexact = []
        for query in all_queries:
            synonyms = inexact_matches_dict.get(query.name)
            if isinstance(synonyms, list) and synonyms:
                new_predicate = random.choice(synonyms)
                all_queries_inexact.append(Predicate(new_predicate, query.args[0], query.args[1]))
            else:
                all_queries_inexact.append(query)

        with open('data/queries_onto_inexact.pkl', 'wb') as f:
            pickle.dump(all_queries_inexact, f)