import requests 
import os
import pandas as pd
from sparql_config import URL, HEADERS, SPARQL_QUERIES

def get_birth(dataitem):
    """Return the birth year of a SPARQL result binding as a string.

    Args:
        dataitem: One result binding (a dict). Its optional 'birth' entry is
            expected to be a dict with 'type' and 'value' keys.

    Returns:
        The text before the first '-' of the literal value (the year for
        values such as '1879-03-14T00:00:00Z'), '<year> BC' when the value
        starts with '-' (e.g. '-0427-01-01T00:00:00Z' -> '0427 BC'), or '?'
        when no usable birth literal is present.
    """
    birth = dataitem.get('birth')
    # Only literals carry a date; URIs, blank nodes and any other term type
    # are reported as unknown rather than falling through and returning None.
    if not birth or birth.get('type') not in ('literal', 'typed-literal'):
        return '?'

    parts = birth.get('value', '').split('-')
    if parts[0]:
        # birthyear AD
        return parts[0]
    if len(parts) > 1 and parts[1]:
        # birthyear BC: a leading '-' leaves an empty first element
        return parts[1] + " BC"
    # Empty or malformed value (e.g. '' or '-')
    return '?'

def return_if_exists(dataitem, field):
    """Return the 'value' entry stored under *field* in *dataitem*, or '?'.

    '?' is the placeholder returned when *field* is not present in *dataitem*.
    """
    if field not in dataitem:
        return '?'
    entry = dataitem[field]
    return entry['value']

def get_nid(nid_data):
    """Return the part of *nid_data* after its last '/'.

    The placeholder '?' is passed through unchanged; a string without any
    '/' is returned as is.
    """
    if nid_data == '?':
        return '?'
    # Only the last path segment is needed, so split once from the right.
    return nid_data.rsplit('/', 1)[-1]

def _run_query(query):
    """Send *query* to the SPARQL endpoint and return its result bindings."""
    response = requests.get(URL, params={'query': query, 'format': 'json'}, headers=HEADERS)
    # Surface HTTP errors with their status code instead of an opaque JSON
    # decoding failure on the error page.
    response.raise_for_status()
    return response.json()['results']['bindings']


def _make_record(item, nationality, nationality_id):
    """Build one output row from a result binding and its nationality info."""
    return {
        'name': return_if_exists(item, 'individual'),
        'gender': return_if_exists(item, 'gender'),
        'nationality': nationality,
        'nationalityID': nationality_id,
        'birth': get_birth(item),
    }


def main():
    """Extract not-yet-extracted occupations and recompute the women ratios.

    Steps:
      1. Read occupation names and IDs from "occupations.csv" and skip every
         occupation that already has a CSV in "extracted_occupations".
      2. Query each remaining occupation and write the de-duplicated rows to
         "extracted_occupations/<occupation>.csv". Occupations without any
         result produce no file.
      3. Write the share of 'female' among rows labelled 'female' or 'male'
         for every extracted occupation to "occupation_ratios.csv", sorted
         from highest to lowest.
    """
    output_dir = "extracted_occupations"
    suffix = ".csv"
    # A fresh checkout may not have the output directory yet.
    os.makedirs(output_dir, exist_ok=True)

    # import the occupations and check which ones are new (not yet extracted)
    # The file is read without a header and the first row is dropped
    # (presumably the column titles); columns 0 and 1 hold name and ID.
    occupation_rows = pd.read_csv("occupations.csv", header=None).to_numpy()[1:]
    occupation_ids = {}
    for row in occupation_rows:
        # Keep the first ID when an occupation is listed more than once.
        occupation_ids.setdefault(row[0], row[1])

    # Strip only the trailing extension, not every ".csv" inside the name.
    extracted_occupations = {
        filename[:-len(suffix)]
        for filename in os.listdir(output_dir)
        if filename.endswith(suffix)
    }
    # Iterating the dict keeps the order of occupations.csv, so runs are
    # reproducible.
    new_occupations = [
        occupation for occupation in occupation_ids
        if occupation not in extracted_occupations
    ]

    known_countries = pd.read_csv("known_countries.csv")
    # Nationality -> nID lookup built once; the first entry wins when a
    # nationality is listed several times.
    unique_countries = known_countries.drop_duplicates(subset='nationality', keep='first')
    country_ids = dict(zip(unique_countries['nationality'].tolist(), unique_countries['nID'].tolist()))

    # query all new occupations from wikidata
    for occupation in new_occupations:
        print(f"Processing occupation: {occupation}")
        occ_id = occupation_ids[occupation]
        # Reset for every occupation so an empty result can neither raise a
        # NameError nor reuse the rows of the previously processed occupation.
        occ_data = []

        # The number of footballers is very high, so we restrict queries to known countries only
        if occupation == 'footballer':
            for row in known_countries[['nID', 'nationality']].itertuples(index=False):
                country_id = row.nID
                country_name = row.nationality
                query = SPARQL_QUERIES['footballer_query'].format(occupationID=occ_id, countryID=country_id)
                print(f"\tQuerying for footballers from {country_name}")
                bindings = _run_query(query)
                print(f"\tFound {len(bindings)} results in wikidata.")

                occ_data.extend(
                    _make_record(item, country_name, country_id) for item in bindings
                )
                print(f"\tTotal so far: {len(occ_data)} datapoints")

        else:
            query = SPARQL_QUERIES['occupation_query'].format(occupationID=occ_id)
            bindings = _run_query(query)
            print(f"Queried for {occupation} and found {len(bindings)} results in wikidata.")

            for item in bindings:
                nationality = return_if_exists(item, 'nationality')
                nid = '?'
                if nationality != '?':
                    nid = country_ids.get(nationality, '?')
                occ_data.append(_make_record(item, nationality, nid))

        if occ_data:
            # save cleaned dataframe
            file = os.path.join(output_dir, occupation + suffix)
            df = pd.DataFrame(occ_data)
            df = df.drop_duplicates(subset='name', keep='first')
            print(f"\tAfter cleanup: {len(df)} datapoints left")
            df.to_csv(file, index=False)

    extracted_occupations_gender_ratios = []
    for filename in sorted(os.listdir(output_dir)):
        if not filename.endswith(suffix):
            continue
        occupation = filename[:-len(suffix)]
        df = pd.read_csv(os.path.join(output_dir, filename))
        num_female = (df['gender'] == 'female').sum()
        num_male = (df['gender'] == 'male').sum()
        labelled = num_female + num_male
        # No row labelled female/male: record NaN explicitly rather than
        # dividing by zero.
        ratio = num_female / labelled if labelled else float('nan')
        extracted_occupations_gender_ratios.append((occupation, ratio))

    ratios = pd.DataFrame(
        extracted_occupations_gender_ratios,
        columns=['occupation', 'women_ratio'],
    ).sort_values(by='women_ratio', ascending=False)
    ratios.to_csv("occupation_ratios.csv", index=False)


# Run the extraction only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()