import sys
sys.path.append('..')

import torch
import numpy as np

import dataset_utils
import fetch_tabular

def count_classes(y):
    """Return the number of distinct labels in ``y``.

    ``y`` may be a pandas Series or DataFrame, a NumPy array (1-D or a single
    column) or a plain sequence. It is flattened before counting, so a
    single-column target DataFrame counts its *values*. Iterating a DataFrame
    directly (``set(df)``) would yield its column labels instead and report a
    single class. A multi-column ``y`` is flattened too, which pools labels
    across all of its columns.
    """
    labels = np.asarray(y).reshape(-1)
    return len(set(labels.tolist()))


if __name__ == "__main__":
    import pandas as pd

    # Print a LaTeX table with the number of samples ('size'), number of
    # features ('dims') and number of distinct labels ('n_class') for each
    # dataset below, as loaded by fetch_tabular.get_dataset_by_name.
    #
    # Names that appeared only in an earlier, unused candidate list and are
    # not tabulated here: Parkinsons_Telemonitoring, Secondary_Mushroom,
    # Infrared_Thermography_Temperature, Recipe_Reviews_and_User_Feedback,
    # Single_Elder_Home_Monitoring_Gas_and_Position, MOVER, Abalone.
    dataset_list = ["Palmer_Penguins", "AIDS_Clinical_Trials_Group_Study_175",
                    "Forty_Soybean_Cultivars_from_Subsequent_Harvests", "Cirrhosis_Patient_Survival_Prediction", "NATICUSdroid",
                    "Glioma_Grading_Clinical_and_Mutation_Features", "PIRvision_FoG_presence_detection", "CDC_Diabetes_Health_Indicators",
                    "Auction_Verification", "Period_Changer", "SUPPORT2", "Wine", "Iris", "Cryotherapy", "Phishing_Websites", "Sirtuin6_Small_Molecules",
                    "Land_Mines", "Z-Alizadeh_Sani", "NHANES", "PhiUSIIL", "Differentiated_Thyroid_Cancer_Recurrence", "RT-IoT2022",
                    "Regensburg_Pediatric_Appendicitis", "DARWIN", "TUNADROMD", "Toxicity", "Drug_induced_Autoimmunity_Prediction", "MetroPT-3",
                    "Accelerometer_Gyro_Mobile_Phone", "Caesarian_Section_Classification_Dataset", "NPHA"]

    rows = []
    for d_name in dataset_list:
        x, y = fetch_tabular.get_dataset_by_name(d_name)
        # The float cast raises if any feature value cannot be converted,
        # so non-numeric feature tables fail loudly here.
        x = x.to_numpy().astype(float)
        n, m = x.shape
        n_class = count_classes(y)
        rows.append({'names': d_name, 'size': n, 'dims': m, 'n_class': n_class})

        # Disabled alternative pipeline, kept for reference. Instead of only
        # tabulating, it standardizes the features and either saves the full
        # pairwise-distance matrix (n <= 4000) or delegates to
        # dataset_utils.generate_from_multi_class_dataset. It expects `x` as
        # the float array and the raw `y` from above.
        # if not type(y) is np.ndarray:
        #     y = y.to_numpy()
        # y = y.reshape(-1)
        # n, _ = x.shape
        # n_class = len(set(y.tolist()))
        #
        # if n <= 4000:
        #     # x = x
        #     mu = np.mean(x, axis=0)
        #     std = np.std(x, axis=0)
        #     x = (x - mu) / (std + 1e-9)
        #
        #     x = torch.from_numpy(x)
        #
        #     cdist = torch.cdist(x, x)
        #     cdist = cdist.cpu().numpy().astype(np.float16)
        #     indices = np.arange(0, n)
        #     print(indices)
        #
        #     to_save = (cdist, y, indices, np.arange(n_class))
        #     save_verbo = f'{d_name}-{n_class}class-comb{0}'
        #     torch.save(to_save, f'/mnt/data01/public/aad_data/uci/{save_verbo}_size{n}_seed{0}_cdist.tar')
        #     print(f'/mnt/data01/public/aad_data/uci/{save_verbo}_size{n}_seed{0}_cdist.tar')
        # else:
        #     dataset_utils.generate_from_multi_class_dataset(dataset_name=d_name, x_data=x, y_data=y,
        #                                                     n_total_class=n_class, n_selected_class=n_class,
        #                                                     n_used=64,
        #                                                     min_size=500, max_size=4000, n_interval=10,
        #                                                     normalized=True, repeat=10,
        #                                                     save_path=f'/mnt/data01/public/aad_data/uci',
        #                                                     )

    # Build the table in one step. DataFrame._append is private pandas API
    # (the public DataFrame.append was removed in pandas 2.0), and appending
    # row by row copies the whole frame on every iteration.
    df = pd.DataFrame(rows, columns=['names', 'size', 'dims', 'n_class'])
    print(df.to_latex(index=False))

