import pandas as pd
import numpy as np
import pickle

import imodels
from torchvision.datasets import MNIST
from PIL import Image
from scipy.io import arff
import warnings
from pmlb import fetch_data
warnings.filterwarnings("ignore")


# NOTE(review): `one_hot` is not referenced by any loader visible in this
# section of the file -- confirm whether other code still reads it.
one_hot = True
# Directory prefix that the loaders prepend to each dataset's file path.
# It is a relative path, so it resolves against the current working directory.
datapath = "../data"

def column_to_numeric(series):
    """Assign an integer code to every distinct value of ``series``.

    Codes start at 0 and follow the order in which ``series.unique()``
    yields the values.

    Returns:
        dict mapping each distinct value to its integer code.
    """
    return {value: code for code, value in enumerate(series.unique())}
    
# German Credit Dataset (Binary Classification Risk/No Risk)
# https://archive.ics.uci.edu/ml/datasets/statlog+(german+credit+data)
def load_credit():
    """Load the German Credit data as a binary risk-classification task.

    Reads ``<datapath>/credit/german_credit_data.csv`` (first column is the
    index), drops rows with missing values, encodes ``Risk`` as 0 for "good"
    and 1 for "bad", gives ``Job`` readable names and one-hot encodes the
    ``Job``, ``Housing``, ``Saving accounts``, ``Checking account`` and
    ``Purpose`` columns. Any feature column that is still non-numeric after
    that is replaced by integer codes.

    Returns:
        dict with the keys
        ``df``: copy of the processed frame (still containing ``Risk``),
        taken before the remaining non-numeric features are integer-coded;
        ``target``: NumPy array of the encoded ``Risk`` column;
        ``mapper``: ``{column position: {code: original value}}`` for every
        feature that was integer-coded;
        ``target_name``: display name of the target;
        ``data``: NumPy array of the feature columns;
        ``feature_names``: list of feature names in ``data`` column order.
    """
    df = pd.read_csv(datapath+"/credit/german_credit_data.csv",sep=",",index_col=0)
    df.dropna(inplace=True)

    # Results are assigned back rather than using df[col].replace(...,
    # inplace=True): that chained form only mutates a temporary Series when
    # pandas Copy-on-Write is active, silently leaving the frame unchanged.
    # pandas has deprecated the implicit downcast of object results from
    # replace, so infer_objects() re-derives the integer dtype explicitly.
    df["Risk"] = df["Risk"].replace({"good": 0, "bad": 1}).infer_objects()
    # Job codes become names before get_dummies so the resulting indicator
    # columns carry those names.
    df["Job"] = df["Job"].replace(
        {0: "unskilled non-resident", 1: "unskilled resident", 2: "skilled", 3: "highly skilled"}
    )
    df = pd.get_dummies(df,columns=["Job","Housing","Saving accounts","Checking account","Purpose"],dtype=int)

    for numeric_col in ("Age", "Duration", "Credit amount"):
        df[numeric_col] = df[numeric_col].astype("int")

    output = {}
    output["df"] = df.copy(True)
    output["target"] = df["Risk"].to_numpy()
    df = df.drop("Risk",axis=1)
    output["mapper"] = {}
    output["target_name"] = "Risk of Default"

    # Column order does not change while encoding, so compute it once.
    columns = df.columns.values.tolist()
    categorical = [c for c in columns if not pd.api.types.is_numeric_dtype(df[c])]
    for col in categorical:
        codes = column_to_numeric(df[col])
        # Every value of the column is a key of `codes` (it was built from the
        # column's own distinct values), so map() covers every row.
        df[col] = df[col].map(codes)
        output["mapper"][columns.index(col)] = {code: value for value, code in codes.items()}

    output["data"] = df.to_numpy()
    output["feature_names"] = columns
    return output

def load_default():
    """Load the credit-card default data as a binary classification task.

    Reads ``<datapath>/default/default.csv`` (first column is the index),
    drops rows with missing values and removes the ``PAY_i``, ``PAY_AMTi``
    and ``BILL_AMTi`` columns for ``i`` in 1..6. ``CREDIT LIMIT`` is divided
    by 33 and rounded; ``Default``, ``SEX``, ``EDUCATION`` and ``MARRIAGE``
    codes are replaced by readable labels. Codes without a label in the
    mappings below are left as they are.

    Returns:
        dict with the keys
        ``df``: copy of the frame with readable labels (including
        ``Default``), taken before any integer coding;
        ``target``: NumPy array with ``Default`` encoded as 0 ("No Default")
        or 1 ("Default");
        ``mapper``: ``{column position: {code: original value}}`` for every
        feature that was integer-coded;
        ``target_name``: display name of the target;
        ``data``: NumPy array of the feature columns;
        ``feature_names``: list of feature names in ``data`` column order.

    Raises:
        KeyError: if one of the columns to remove is missing from the file.
    """
    df = pd.read_csv(datapath+"/default/default.csv",sep=",",index_col=0)
    df.dropna(inplace=True)
    unused = [
        "{}{:d}".format(prefix, i)
        for i in range(1, 7)
        for prefix in ("PAY_", "PAY_AMT", "BILL_AMT")
    ]
    df.drop(columns=unused, inplace=True)

    output = {}
    # NOTE(review): 33 looks like a currency conversion factor -- confirm.
    df["CREDIT LIMIT"] = (df["CREDIT LIMIT"]/33).round()

    # Results are assigned back rather than using df[col].replace(...,
    # inplace=True): that chained form only mutates a temporary Series when
    # pandas Copy-on-Write is active, silently leaving the frame unchanged.
    df["Default"] = df["Default"].replace({0: "No Default", 1: "Default"})
    df["SEX"] = df["SEX"].replace({1: "male", 2: "female"})
    df["EDUCATION"] = df["EDUCATION"].replace(
        {1: "graduate school", 2: "university", 3: "high school", 4: "other"}
    )
    df["MARRIAGE"] = df["MARRIAGE"].replace({1: "married", 2: "single", 3: "other"})
    output["df"] = df.copy(True)

    # pandas has deprecated the implicit downcast of object results from
    # replace, so infer_objects() re-derives the integer dtype explicitly.
    df["Default"] = df["Default"].replace({"No Default": 0, "Default": 1}).infer_objects()
    output["target"] = df["Default"].to_numpy()
    df = df.drop("Default",axis=1)
    output["mapper"] = {}
    output["target_name"] = "Risk of Default"

    # Column order does not change while encoding, so compute it once.
    columns = df.columns.values.tolist()
    categorical = [c for c in columns if not pd.api.types.is_numeric_dtype(df[c])]
    for col in categorical:
        codes = column_to_numeric(df[col])
        # Every value of the column is a key of `codes` (it was built from the
        # column's own distinct values), so map() covers every row.
        df[col] = df[col].map(codes)
        output["mapper"][columns.index(col)] = {code: value for value, code in codes.items()}

    output["data"] = df.to_numpy()
    output["feature_names"] = columns
    return output

def load_covid():
    """Load the COVID patient data as a binary outcome-classification task.

    Reads ``<datapath>/covid/covid.csv``, removes identifier and
    admission/status-timing columns, discards rows whose ``group`` is
    "Patient" and rows with missing values, encodes ``group`` as 0 for
    "Expired" and 1 for "Discharged", and integer-codes ``sex`` and ``race``.

    Returns:
        dict with the keys
        ``df``: copy of the filtered frame taken before any encoding;
        ``target``: NumPy array of the encoded ``group`` column;
        ``feature_names``: list of feature names in ``data`` column order;
        ``data``: NumPy array of the feature columns;
        ``target_name``: the string "group".
    """
    df = pd.read_csv(datapath+"/covid/covid.csv")
    output = {}
    unused = [
        "outcome","id","patient_id","weekday_change_of_status","hour_change_of_status",
        "weekday_admit","hour_admit","days_change_of_status","date_admit",
        "date_change_of_status","hospital",
    ]
    df.drop(unused,axis=1,inplace=True)
    df.drop(df[df["group"]=="Patient"].index,inplace=True)
    df.dropna(inplace=True)
    output["df"] = df.copy(True)

    # Results are assigned back rather than using df[col].replace(...,
    # inplace=True): that chained form only mutates a temporary Series when
    # pandas Copy-on-Write is active, silently leaving the frame unchanged.
    # pandas has deprecated the implicit downcast of object results from
    # replace, so infer_objects() re-derives the integer dtype explicitly.
    df["group"] = df["group"].replace(["Expired","Discharged"],[0,1]).infer_objects()
    for col in ["sex","race"]:
        # Missing values were dropped above, so every value of the column is a
        # key of the code table and map() covers every row.
        df[col] = df[col].map(column_to_numeric(df[col]))
    # Retained from the original flow; only removes rows if the encoding
    # above left missing values behind.
    df.dropna(inplace=True)

    output["target"] = df["group"].to_numpy()
    df.drop("group",axis=1,inplace=True)
    output["feature_names"] = df.columns.values.tolist()
    output["data"] = df.to_numpy()
    output["target_name"] = "group"
    return output

def load_imodels_data(name):
    """Fetch the dataset ``name`` through ``imodels.get_clean_dataset``.

    The three values returned by that call are passed through unchanged under
    the keys ``data``, ``target`` and ``feature_names``.
    """
    features, labels, names = imodels.get_clean_dataset(name, data_source='imodels')
    return dict(data=features, target=labels, feature_names=names)

def load_pmlb_data(name):
    """Fetch the PMLB dataset ``name`` and binarise its target.

    The dataset is fetched with ``datapath`` as the local cache directory.
    The target becomes 1 where it equals its maximum value and 0 elsewhere.

    Returns:
        dict with ``data`` (feature values), ``target`` (0/1 integer array)
        and ``feature_names`` (the feature columns as a pandas Index).
    """
    frame = fetch_data(name, local_cache_dir=datapath)
    raw_target = frame['target'].values
    is_top_class = raw_target == raw_target.max()
    features = frame.drop(columns=['target'])
    return {
        "data": features.values,
        "target": is_top_class.astype(int),
        "feature_names": features.columns,
    }


def _load_arff_binary(relative_path, label_column, positive_value):
    """Load an ARFF file under ``datapath`` as a binary classification task.

    Args:
        relative_path: path of the ARFF file, appended verbatim to ``datapath``.
        label_column: name of the column holding the class label.
        positive_value: raw label value (a bytes literal in every caller)
            that is encoded as 1; every other value becomes 0.

    Returns:
        dict with ``data`` (NumPy array of all non-label columns), ``target``
        (pandas Series of 0/1 integers) and ``feature_names`` (list of the
        non-label column names in ``data`` column order).
    """
    # loadarff returns (records, metadata); only the records are needed.
    records = arff.loadarff(datapath + relative_path)[0]
    df = pd.DataFrame(records)
    labels = (df[label_column] == positive_value).astype(int)
    features = df.drop(label_column, axis=1)
    return {
        "data": features.to_numpy(),
        "target": labels,
        "feature_names": features.columns.values.tolist(),
    }


def load_ozone_level():
    """Load the ozone-level dataset; ``Class`` equal to ``b'1'`` is the positive label."""
    return _load_arff_binary('/ozone-level/ozone-level.arff', "Class", b'1')


def load_madelon():
    """Load the madelon dataset; ``Class`` equal to ``b'1'`` is the positive label."""
    return _load_arff_binary('/madelon/madelon.arff', "Class", b'1')


def load_pc1():
    """Load the pc1 dataset; ``defects`` equal to ``b'true'`` is the positive label."""
    return _load_arff_binary('/pc1/pc1.arff', "defects", b'true')


def load_phoneme():
    """Load the phoneme dataset; ``Class`` equal to ``b'2'`` is the positive label."""
    return _load_arff_binary('/phoneme/phoneme.arff', "Class", b"2")


def load_qsar_biodeg():
    """Load the qsar-biodeg dataset; ``Class`` equal to ``b'2'`` is the positive label."""
    return _load_arff_binary('/qsar-biodeg/qsar-biodeg.arff', "Class", b"2")


def load_eeg_eye_state():
    """Load the eeg-eye-state dataset; ``Class`` equal to ``b'2'`` is the positive label."""
    return _load_arff_binary('/eeg-eye-state/eeg-eye-state.arff', "Class", b"2")


def load_electricity():
    """Load the electricity dataset; ``class`` equal to ``b'UP'`` is the positive label."""
    return _load_arff_binary('/electricity/electricity.arff', "class", b"UP")


def load_phishing():
    """Load the phishing dataset; ``Result`` equal to ``b'1'`` is the positive label."""
    return _load_arff_binary('/phishing/phishing.arff', "Result", b"1")

def load_android():
    """Load the android-permission CSV as a classification task.

    Reads ``<datapath>/android-permission/data.csv`` and uses the ``Result``
    column, cast to int, as the target; all remaining columns are features.

    Returns:
        dict with ``data`` (NumPy array of the feature columns), ``target``
        (pandas Series of integers) and ``feature_names`` (list of feature
        names in ``data`` column order).
    """
    frame = pd.read_csv(datapath+'/android-permission/data.csv')
    target = frame["Result"].astype(int)
    features = frame.drop("Result",axis=1)
    return dict(
        data=features.to_numpy(),
        target=target,
        feature_names=features.columns.values.tolist(),
    )
 