
import os

import pandas as pd
import torch

from ucimlrepo import fetch_ucirepo

# Datasets that were considered but deliberately left out:
#   Assessing_Mathematics_Learning_in_Higher_Education - not good
#   Product Classification and Clustering              - not good
#   TCGA Kidney Cancers                                - no longer available
#   Recipe_Reviews_and_User_Feedback (UCI id 911)      - text data
#   Printed_Circuit_Board_Processed_Image              - no label
#   Rocket_League_Skillshots                           - too complicated
#   Somerville_Happiness_Survey                        - no label
#   Sundanese Twitter Dataset                          - language dataset, may be used in the future

# UCI repository ids of datasets whose features/targets frames are used as fetched.
_UCI_AS_IS = {
    'Phishing_Websites': 327,
    'NPHA': 936,
    'Parkinsons_Telemonitoring': 189,
    'AIDS_Clinical_Trials_Group_Study_175': 890,
    'Secondary_Mushroom': 848,
    'Land_Mines': 763,
    'Glioma_Grading_Clinical_and_Mutation_Features': 759,
    'Infrared_Thermography_Temperature': 925,
    'DARWIN': 732,
    'Auction_Verification': 713,
    'NATICUSdroid': 722,
    'Iris': 53,
    'Wine': 109,
    'Students_Dropout_and_Academic_Success': 697,
    'Abalone': 1,
    'CDC_Diabetes_Health_Indicators': 891,
}

# UCI repository ids of datasets whose first target column is replaced by
# the integer codes produced by ``pd.factorize``.
_UCI_FACTORIZED_TARGET = {
    'RT-IoT2022': 942,
    'NHANES': 887,
    'Cirrhosis_Patient_Survival_Prediction': 878,
    'Differentiated_Thyroid_Cancer_Recurrence': 915,
    'Toxicity': 728,
}

# Datasets read from local CSV files:
#   name -> (path, target column, index_col for read_csv, extra columns to drop)
_LOCAL_CSV = {
    'MetroPT-3': ('./data/MetroPT3.csv', 'y', 0, ()),
    # need R: https://allisonhorst.github.io/palmerpenguins/
    'Palmer_Penguins': ('./data/penguins_processed.csv', 'island', 0, ()),
    'Caesarian_Section_Classification_Dataset': ('./data/caesarian_processed.csv', 'Caesarian', 0, ()),
    'TUNADROMD': ('./data/TUANDROMD.csv', 'Label', None, ()),
    'Drug_induced_Autoimmunity_Prediction': (
        './data/Drug_induced_Autoimmunity_Prediction_processed.csv', 'Label', 0, ('SMILES',)),
    'Cryotherapy': ('./data/Cryotherapy.csv', 'Result_of_Treatment', 0, ()),
    'Period_Changer': ('./data/period_changer-2_processed.csv', 'Class', 0, ()),
    'Sirtuin6_Small_Molecules': ('./data/SIRTUIN6_processed.csv', 'Class', 0, ()),
    'Z-Alizadeh_Sani': ('./data/Z-Alizadeh_sani_dataset_processed.csv', 'Cath', 0, ()),
    'PIRvision_FoG_presence_detection': ('./data/pirvision_office_dataset_processed.csv', 'Label', 0, ()),
}

# Names that pass validation but cannot be processed yet, with the reason.
_NOT_AVAILABLE = {
    'Single_Elder_Home_Monitoring_Gas_and_Position': 'dataset has no label',
    'MOVER': 'waiting for a download link',
}


def _fetch_uci(dataset_id):
    """Fetch one UCI dataset and return its ``(features, targets)`` frames."""
    data = fetch_ucirepo(id=dataset_id).data
    return data.features, data.targets


def _load_local_csv(path, target, index_col, extra_drop):
    """Read a local CSV and split it into ``(features, target column)``."""
    df = pd.read_csv(path, index_col=index_col)
    df = df.drop(columns=list(extra_drop))
    return df.drop(columns=target), df[target]


def _load_phiusiil():
    """PhiUSIIL (UCI id 967) without the URL/Domain/Title/TLD columns."""
    x, y = _fetch_uci(967)
    return x.drop(columns=['URL', 'Domain', 'Title', 'TLD']), y


def _load_regensburg():
    """Regensburg appendicitis (UCI id 938) labelled by the 'Diagnosis' target."""
    x, y = _fetch_uci(938)
    # Missing feature values become the category 'no', so the dropna() below
    # can only remove rows whose 'Diagnosis' label is missing.
    merged = x.fillna('no').assign(y=y['Diagnosis']).dropna()
    codes, _ = pd.factorize(merged['y'])
    return merged.drop(columns='y'), codes


def _load_support2():
    """SUPPORT2 (UCI id 880): 'death' target, rows containing any NaN removed."""
    x, y = _fetch_uci(880)
    # assign() builds a new frame instead of writing into the fetched one.
    merged = x.assign(y=y['death']).dropna()
    return merged.drop(columns='y'), merged['y']


def _load_soybean():
    """Soybean cultivars (UCI id 913): the 'Cultivar' feature column is the label."""
    x, _ = _fetch_uci(913)
    codes, _ = pd.factorize(x['Cultivar'])
    return x.drop(columns=['Repetition', 'Cultivar']), codes


def _load_accelerometer():
    """Accelerometer/gyro (UCI id 755): the 'Activity' feature column is the label."""
    x, _ = _fetch_uci(755)
    return x.drop(columns=['timestamp', 'Activity']), x['Activity']


# Datasets that need bespoke clean-up after fetching.
_SPECIAL_LOADERS = {
    'PhiUSIIL': _load_phiusiil,
    'Regensburg_Pediatric_Appendicitis': _load_regensburg,
    'SUPPORT2': _load_support2,
    'Forty_Soybean_Cultivars_from_Subsequent_Harvests': _load_soybean,
    'Accelerometer_Gyro_Mobile_Phone': _load_accelerometer,
}

_KNOWN_NAMES = (frozenset(_UCI_AS_IS) | frozenset(_UCI_FACTORIZED_TARGET) | frozenset(_LOCAL_CSV)
                | frozenset(_SPECIAL_LOADERS) | frozenset(_NOT_AVAILABLE))


def process_dataset_by_name(d_name):
    """Load a raw dataset, clean it, one-hot encode it and cache the result.

    The dataset is fetched from the UCI repository or read from a CSV under
    ``./data``. The features are passed through ``pd.get_dummies`` and the
    ``(x, y)`` pair is written with ``torch.save`` to
    ``./downloaded_data/{d_name}_x_y.tar`` (the directory is created if needed).

    Args:
        d_name: Name of the dataset; must be one of the known dataset names.

    Returns:
        ``(x, y)`` where ``x`` is the one-hot encoded feature DataFrame and
        ``y`` is, depending on the dataset, the targets DataFrame, a single
        target Series, or an array of integer codes from ``pd.factorize``.

    Raises:
        AssertionError: If ``d_name`` is not a known dataset name.
        NotImplementedError: If the name is known but the dataset cannot be
            processed (see ``_NOT_AVAILABLE``).
    """
    assert d_name in _KNOWN_NAMES, f'unknown dataset name: {d_name!r}'

    if d_name in _UCI_AS_IS:
        x, y = _fetch_uci(_UCI_AS_IS[d_name])
    elif d_name in _UCI_FACTORIZED_TARGET:
        x, targets = _fetch_uci(_UCI_FACTORIZED_TARGET[d_name])
        y, _ = pd.factorize(targets.iloc[:, 0])
    elif d_name in _LOCAL_CSV:
        x, y = _load_local_csv(*_LOCAL_CSV[d_name])
    elif d_name in _SPECIAL_LOADERS:
        x, y = _SPECIAL_LOADERS[d_name]()
    else:
        # Fail before touching x/y so the caller gets a meaningful error
        # instead of an UnboundLocalError further down.
        reason = _NOT_AVAILABLE.get(d_name, 'no loader defined')
        raise NotImplementedError(f'{d_name}: {reason}')

    # transform categorical values into one-hot columns
    x = pd.get_dummies(x)
    os.makedirs('./downloaded_data', exist_ok=True)
    torch.save((x, y), f'./downloaded_data/{d_name}_x_y.tar')
    return x, y


def get_dataset_by_name(d_name):
    """Load the cached ``(x, y)`` pair of an already processed dataset.

    Args:
        d_name: Dataset name; must be one of the names listed below.

    Returns:
        The ``(x, y)`` tuple stored in ``./downloaded_data/{d_name}_x_y.tar``.

    Raises:
        AssertionError: If ``d_name`` is not one of the accepted names.
    """
    known_names = (
        'Phishing_Websites',
        'NPHA',
        'PhiUSIIL',
        'RT-IoT2022',
        'Parkinsons_Telemonitoring',
        'AIDS_Clinical_Trials_Group_Study_175',
        'Secondary_Mushroom',
        'NHANES',
        'Cirrhosis_Patient_Survival_Prediction',
        'MetroPT-3',
        'Regensburg_Pediatric_Appendicitis',
        'Land_Mines',
        'Glioma_Grading_Clinical_and_Mutation_Features',
        'Differentiated_Thyroid_Cancer_Recurrence',
        'SUPPORT2',
        'Infrared_Thermography_Temperature',
        'DARWIN',
        'MOVER',
        'Forty_Soybean_Cultivars_from_Subsequent_Harvests',
        'Auction_Verification',
        'Palmer_Penguins',
        'NATICUSdroid',
        'Toxicity',
        'Accelerometer_Gyro_Mobile_Phone',
        'Caesarian_Section_Classification_Dataset',
        'TUNADROMD',
        'Drug_induced_Autoimmunity_Prediction',
        'Cryotherapy',
        'Period_Changer',
        'Sirtuin6_Small_Molecules',
        'Z-Alizadeh_Sani',
        'PIRvision_FoG_presence_detection',
        'Iris',
        'Wine',
        'Abalone',
        'CDC_Diabetes_Health_Indicators',
    )
    assert d_name in known_names

    cache_file = f'./downloaded_data/{d_name}_x_y.tar'
    # NOTE: weights_only=False unpickles arbitrary Python objects, so only
    # load cache files that come from a trusted source.
    cached = torch.load(cache_file, weights_only=False)
    x, y = cached
    return x, y


if __name__ == '__main__':
    # Smoke check: load the cached version of each enabled dataset and print
    # its targets, the number of distinct target values and the shapes.
    # Uncomment entries to include more datasets.
    name_list = [
        'Phishing_Websites',
        # 'NPHA',
        # 'PhiUSIIL',
        # 'RT-IoT2022',
        # 'Parkinsons_Telemonitoring',
        # 'AIDS_Clinical_Trials_Group_Study_175',
        # 'Secondary_Mushroom',
        # 'NHANES',
        # 'Cirrhosis_Patient_Survival_Prediction',
        # 'MetroPT-3',
        # 'Regensburg_Pediatric_Appendicitis',
        # 'Land_Mines',
        # 'Glioma_Grading_Clinical_and_Mutation_Features',
        # 'Differentiated_Thyroid_Cancer_Recurrence',
        # 'SUPPORT2',
        # 'Infrared_Thermography_Temperature',
        # 'DARWIN',
        # # 'Recipe_Reviews_and_User_Feedback',
        # # 'Single_Elder_Home_Monitoring_Gas_and_Position',
        # # 'MOVER',
        # 'Forty_Soybean_Cultivars_from_Subsequent_Harvests',
        # 'Auction_Verification',
        # 'Palmer_Penguins',
        # 'NATICUSdroid',
        # 'Toxicity',
        # 'Accelerometer_Gyro_Mobile_Phone',
        # 'Caesarian_Section_Classification_Dataset',
        # 'TUNADROMD',
        # 'Drug_induced_Autoimmunity_Prediction',
        # 'Cryotherapy',
        # 'Period_Changer',
        # 'Sirtuin6_Small_Molecules',
        # 'Z-Alizadeh_Sani',
        # 'PIRvision_FoG_presence_detection',
        # 'Iris',
        # 'Wine',
        # 'Abalone',
        # 'CDC_Diabetes_Health_Indicators',
    ]
    for d_name in name_list:
        x, y = get_dataset_by_name(d_name)
        # Iterating a DataFrame yields its column labels, so set(y) would
        # count target columns rather than labels. Flatten the values first;
        # this also works when y is a Series or a plain array of codes.
        label_values = pd.DataFrame(y).to_numpy().ravel()
        print(y, len(pd.unique(label_values)))
        print(d_name, x.shape, y.shape)
