import numpy as np
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, QuantileTransformer, KBinsDiscretizer

# Column names for the "adlt" train/test files; "label" is the target.
ADLT_COLUMNS = [
    "age", "workclass", "fnlwgt", "education", "education-num", "marital-status", "occupation",
    "relationship", "race", "sex", "capital-gain", "capital-loss", "hours-per-week", "native-country",
    "label",
]

# Column names for covtype.csv: 10 numeric measurements, 4 wilderness-area
# indicators, 40 soil-type indicators (zero-padded ids) and the "Type" target.
COVR_COLUMNS = (
    [
        "Elevation", "Aspect", "Slope", "Horizontal_Distance_To_Hydrology",
        "Vertical_Distance_To_Hydrology", "Horizontal_Distance_To_Roadways",
        "Hillshade_9am", "Hillshade_Noon", "Hillshade_3pm", "Horizontal_Distance_To_Fire_Points",
    ]
    + [f"Wilderness_Area_{idx:02d}" for idx in range(1, 5)]
    + [f"Soil_Type_{idx:02d}" for idx in range(1, 41)]
    + ["Type"]
)

# Column names for YearPredictionMSD.txt: the "Year" target first, then
# 12 timbre averages and 78 timbre covariances.
YEAR_COLUMNS = (
    ["Year"]
    + [f"TimbreAvg{idx}" for idx in range(1, 13)]
    + [f"TimbreCovariance{idx}" for idx in range(1, 79)]
)

# Column names for the "cnss" train/test files; "Y" is the target.
CNS_COLUMNS = [
    "age","class of worker","detailed industry recode","detailed occupation recode","education",
    "wage per hour","enroll in edu inst last wk","marital stat","major industry code","major occupation code",
    "race","hispanic origin","sex","member of a labor union","reason for unemployment",
    "full or part time employment stat","capital gains","capital losses","dividends from stocks","tax filer stat",
    "region of previous residence","state of previous residence","detailed household and family stat","detailed household summary in household","instance weight",
    "migration code-change in msa","migration code-change in reg","migration code-move within reg","live in this house 1 year ago","migration prev res in sunbelt",
    "num persons worked for employer","family members under 18","country of birth father","country of birth mother","country of birth self",
    "citizenship","own business or self employed","fill inc questionnaire for veteran's admin","veterans benefits","weeks worked in year",
    "year","Y"
]

# Subset of CNS_COLUMNS treated as categorical (label-encoded).
CNS_CATE_COLUMNS = [
    "class of worker","detailed industry recode","detailed occupation recode","education","enroll in edu inst last wk",
    "marital stat","major industry code","major occupation code","race","hispanic origin",
    "sex","member of a labor union","reason for unemployment","full or part time employment stat","tax filer stat",
    "region of previous residence","state of previous residence","detailed household and family stat","detailed household summary in household","migration code-change in msa",
    "migration code-change in reg","migration code-move within reg","live in this house 1 year ago","migration prev res in sunbelt","family members under 18",
    "country of birth father","country of birth mother","country of birth self","citizenship","own business or self employed",
    "fill inc questionnaire for veteran's admin","veterans benefits","year"
]

# Column names for the "chrn" tab-separated files: target "Y" then Var1..Var230.
KDDT_COLUMNS = ["Y"] + [f"Var{idx}" for idx in range(1, 231)]

# Categorical KDD columns: Var191..Var229, deliberately excluding Var209.
KDDT_CATE_COLUMNS = [f"Var{idx}" for idx in range(191, 230) if idx != 209]

# Categorical columns of the "blst" data.
BLST_CATE_COLUMNS = [
    'gender', 'SeniorCitizen', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines',
    'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport',
    'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod',
]

# Sarcos input columns: 7 joint positions, then 7 velocities, then 7 accelerations.
SRCS_COLUMNS = [
    f"{quantity}{joint}"
    for quantity in ("positions", "velocities", "accelerations")
    for joint in range(1, 8)
]

# Per-dataset metadata. Each value is a tuple, in this order:
#   (label, cate_columns, to_drop, label_convert)
#   label         -- name of the target column
#   cate_columns  -- columns label-encoded as categorical features
#   to_drop       -- columns removed before building the feature matrix
#   label_convert -- mapping applied to the target with Series.replace
#                    (e.g. string classes -> integer codes); {} leaves it unchanged
# The load_* functions unpack these entries and forward them to do_basic_op.
DATA_INFO_DICT = {
    "rssm":("Sales",[],["Year"],{}),
    "year":("Year",[],[],{}),
    "covr":("Type",[],[],{1:0,2:1,3:2,4:3,5:4,6:5,7:6}),
    "adlt":("label",["workclass","education","marital-status","occupation","relationship","race","sex","native-country"],[],{' <=50K':0, ' >50K':1, ' <=50K.':0, ' >50K.':1}),
    "adlt2":("label",["workclass","education","marital-status","occupation","relationship","race","sex","native-country"],[],{' <=50K':0, ' >50K':1, ' <=50K.':0, ' >50K.':1}),
    "cnss":("Y",CNS_CATE_COLUMNS,["instance weight"],{' - 50000.':0,' 50000+.':1}),
    "cnss2":("Y",CNS_CATE_COLUMNS,["instance weight"],{' - 50000.':0,' 50000+.':1}),
    # Only torques1 is predicted; the other six torque columns are dropped.
    "srcs":("torques1",[],[f"torques{no}" for no in range(2,8)],{}),
    "chrn":("Y",KDDT_CATE_COLUMNS,[],{-1:0,1:1}),
    "blst":("Churn",BLST_CATE_COLUMNS,["customerID"],{"No":0,"Yes":1}),
    "shrt":("Exited",['Geography', 'Gender', 'HasCrCard', 'IsActiveMember'],["RowNumber","CustomerId","Surname"],{}),
    "gpsp":("Y",[],[],{"'D'": 0, "'P'": 1, "'S'": 2, "'H'": 3, "'R'": 4}),
    "gddc":("Y",[],[],{4: 0, 3: 1, 6: 2, 2: 3, 1: 4, 5: 5}),
    "eyem":("Y",[],[],{}),
    "clhp":("median_house_value",["ocean_proximity"],[],{}),
    "clhp2":("median_house_value",["ocean_proximity"],[],{}),
    "hloc":("RiskPerformance",[],[],{"Bad":1,"Good":0}),
    "higg":("target",[],[],{})
    
}

def split_and_scale(X,y,test_size=0.2,seed=None,feature_range=(-1.0,1.0)):
    """Randomly split ``(X, y)`` and min-max scale the features.

    The scaler is fitted on the training split only and then applied to both
    splits, so no statistics leak from the held-out rows.

    Args:
        X, y: features and targets accepted by ``train_test_split``.
        test_size: forwarded to ``train_test_split``.
        seed: random state for the split.
        feature_range: output range of the ``MinMaxScaler``.

    Returns:
        (X_train, X_test, y_train, y_test, scaler)
    """
    X_fit, X_held, y_fit, y_held = train_test_split(X, y, test_size=test_size, random_state=seed)

    scaler = MinMaxScaler(feature_range=feature_range)
    scaler.fit(X_fit)

    return scaler.transform(X_fit), scaler.transform(X_held), y_fit, y_held, scaler
    
def do_basic_op(train,test,label,to_drop,label_convert,cate_columns,as_frame,normalize_feature=False,normalize_label=False,qtran_columns=None,y_scale=1.0):
    """Split off the label, encode/scale the features and optionally rescale the label.

    train(pd.DataFrame): training frame (left unmodified).
    test(pd.DataFrame): test frame (left unmodified).
    label(str): Column name of the label; must exist in train and test.
    to_drop(list[str]): Columns not to be used as features.
    label_convert(dict): Mapping from original to new label values, applied with
        ``Series.replace``. Intended to convert string labels (e.g. "bankrupt")
        to integer codes (e.g. 1).
    cate_columns(list[str]): Categorical columns, relabeled with a LabelEncoder
        fitted on train. Test values never seen in train are encoded as -2.
    as_frame(bool): Return DataFrame/Series if True, else numpy arrays.
    normalize_feature(bool): Min-max scale all features to [-1, 1] (fitted on train).
    normalize_label(bool): Scale the label to [0, y_scale] using train's min/max.
        Intended for regression.
    qtran_columns(list[str] | None): Columns mapped through a QuantileTransformer
        (uniform output, fitted on train). Names that are not feature columns
        (the label or a ``to_drop`` column) are ignored.
    y_scale(float): Deprecated. Multiplier applied after label normalization.

    Returns:
        X_train, X_test, y_train, y_test, info -- ``info`` holds "columns"
        (feature names), "num_feat" and, when ``normalize_label``, "scale"
        (the train label range used as divisor).
    """
    # Select instead of pop()/replace(inplace=True) so the caller's frames are
    # not mutated.
    y_train = train[label].replace(label_convert)
    y_test = test[label].replace(label_convert)
    X_train = train.drop(columns=[label, *to_drop]).copy()
    X_test = test.drop(columns=[label, *to_drop]).copy()

    info = {"columns": X_train.columns.tolist()}

    for col in cate_columns:
        enc = LabelEncoder()
        # Whole-column assignment gives the column an integer dtype; a partial
        # .loc[:, col] assignment would keep the original (object) dtype.
        X_train[col] = enc.fit_transform(X_train[col])
        test_values = X_test[col]
        seen = test_values.isin(enc.classes_).to_numpy()
        codes = np.full(len(test_values), -2, dtype=np.int64)  # -2 marks unseen labels
        if seen.any():
            codes[seen] = enc.transform(test_values[seen])
        X_test[col] = codes

    # Loaders typically request "every column except the label", which may
    # include to_drop columns already removed above; skip anything absent.
    qtran_columns = [col for col in (qtran_columns or []) if col in X_train.columns]
    if qtran_columns:
        qt = QuantileTransformer(n_quantiles=200,output_distribution="uniform",subsample=len(X_train))
        # Maps each selected column to a uniform distribution on [0, 1].
        X_train[qtran_columns] = qt.fit_transform(X_train[qtran_columns])
        X_test[qtran_columns] = qt.transform(X_test[qtran_columns])

    if normalize_feature:
        scaler = MinMaxScaler(feature_range=(-1.0, 1.0)).fit(X_train)
        # Rebuild the frames rather than writing floats into int columns in place.
        X_train = pd.DataFrame(scaler.transform(X_train), index=X_train.index, columns=X_train.columns)
        X_test = pd.DataFrame(scaler.transform(X_test), index=X_test.index, columns=X_test.columns)

    if normalize_label:
        min_val, max_val = y_train.min(), y_train.max()
        span = max_val - min_val
        if span == 0:
            # Constant training label: avoid 0/0 (NaN); values map to 0.
            span = 1.0
        y_train = y_scale*(y_train-min_val)/span
        y_test = y_scale*(y_test-min_val)/span
        info["scale"] = span

    if not as_frame:
        X_train = X_train.values
        X_test = X_test.values
        y_train = y_train.values
        y_test = y_test.values
    info["num_feat"] = X_train.shape[-1]
    return X_train,X_test,y_train,y_test,info

def preprocess_rssm(data_dir,save_dir):
    """Preprocess the Rossmann store-sales data and write train/test CSV files.

    Reads ``train.csv`` and ``store.csv`` from *data_dir*, then:
      * replaces ``Date`` with integer ``Year``/``Month``/``Day`` columns,
      * attaches each row's store attributes, matched on the ``Store`` id,
      * expands ``PromoInterval`` into twelve 0/1 ``Promo2Start_<Mon>`` columns,
      * fills missing values with 0,
      * label-encodes every column containing strings (encoder fitted on 2014).

    Rows from 2014 form the training split and rows from 2015 the test split.
    Writes ``rssm_train.csv``, ``rssm_test.csv`` and ``rssm_code.csv`` (one
    header-less ``index,Label`` / ``index,Categ,name`` / ``index,Num,name`` line
    per column) into *save_dir*, which is created if needed.

    Note: store attributes used to be joined against store.csv's positional
    row index instead of its ``Store`` column, pairing rows with the wrong
    store. Regenerated files therefore differ from ones produced earlier.
    """
    train = pd.read_csv(os.path.join(data_dir,"train.csv"),sep=",")
    store = pd.read_csv(os.path.join(data_dir,"store.csv"),sep=",")
    os.makedirs(save_dir,exist_ok=True)

    # NOTE(review): 'Sept' (not 'Sep') presumably mirrors the spelling used in
    # store.csv's PromoInterval values -- confirm before changing.
    month_abbrs = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sept', 'Oct', 'Nov', 'Dec']

    def preprocess(df, stores):
        """Split Date into Year/Month/Day, attach store info, expand PromoInterval."""
        date = np.array([list(map(int, s.split('-'))) for s in df['Date']])
        df = df.drop(['Date'], axis=1)
        df['Year'] = date[:, 0]
        df['Month'] = date[:, 1]
        df['Day'] = date[:, 2]
        # join(on=...) looks the caller's column up in the *index* of the other
        # frame, so index the store table by its Store id first.
        df = df.join(stores.set_index('Store'), on='Store', rsuffix='_right')

        promo2_start_months = [(s.split(',') if not pd.isnull(s) else []) for s in df['PromoInterval']]
        for month_abbr in month_abbrs:
            df['Promo2Start_' + month_abbr] = np.array([(1 if month_abbr in s else 0) for s in promo2_start_months])
        df = df.drop(['PromoInterval'], axis=1)

        df = df.fillna(0)
        return df

    prepared = preprocess(train, store)

    def get_str_column_names(df):
        """Return the names of columns holding at least one Python ``str``."""
        str_names = []
        for col in df.columns:
            series = df[col]
            # Numeric (incl. bool) columns cannot hold str; skip the per-cell scan.
            if pd.api.types.is_numeric_dtype(series):
                continue
            if series.map(lambda x: isinstance(x, str)).any():
                str_names.append(col)
        return str_names

    # Boolean masks select by value (the old iloc[...index] relied on a
    # RangeIndex); .copy() makes the column rewrites below act on real frames.
    train2 = prepared.loc[prepared['Year'] == 2014].copy()
    test2 = prepared.loc[prepared['Year'] == 2015].copy()

    str_cat_columns = get_str_column_names(prepared)

    def fix_strs(df, cat_names, test_df=None):
        """Label-encode *cat_names* in *df* (and *test_df* with the same encoders).

        NOTE(review): a test value absent from *df* makes LabelEncoder.transform
        raise ValueError.
        """
        df[cat_names] = df[cat_names].fillna(0)
        if test_df is not None:
            test_df[cat_names] = test_df[cat_names].fillna(0)
        for col in cat_names:
            enc = LabelEncoder()
            df[col] = enc.fit_transform(df[col])
            if test_df is not None:
                test_df[col] = enc.transform(test_df[col])
        return df, test_df

    # Cast to str first so mixed values such as 0 and '0' collapse into one label.
    train2[str_cat_columns] = train2[str_cat_columns].astype(str)
    test2[str_cat_columns] = test2[str_cat_columns].astype(str)

    train2, test2 = fix_strs(train2, str_cat_columns, test2)

    all_cat_names = (['Store', 'DayOfWeek', 'Open', 'Promo', 'StateHoliday', 'SchoolHoliday',
                      'StoreType', 'Assortment', 'Promo2']
                     + ['Promo2Start_' + month_abbr for month_abbr in month_abbrs])

    train2.to_csv(os.path.join(save_dir, 'rssm_train.csv'), sep=',', header=True, index=False)
    test2.to_csv(os.path.join(save_dir, 'rssm_test.csv'), sep=',', header=True, index=False)

    with open(os.path.join(save_dir, 'rssm_code.csv'), 'w', encoding='utf-8') as cd:
        for idx, name in enumerate(train2.columns):
            cd.write('{},{}\n'.format(
                idx,
                'Label' if name == 'Sales' else ('Categ,' + name if name in all_cat_names else 'Num,' + name))
            )

def load_rssm(data_dir,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the preprocessed Rossmann store-sales regression data.

    Runs ``preprocess_rssm(data_dir, data_dir)`` first when ``rssm_train.csv`` or
    ``rssm_test.csv`` is missing. The test set is ``rssm_test.csv``; 100,000
    training rows are held out as the validation set. The label ("Sales") is
    min-max normalized.

    Args:
        data_dir: directory holding the raw or preprocessed files.
        seed: random state for the train/validation split.
        as_frame: return DataFrames/Series instead of numpy arrays.
        normalize: min-max scale every feature to [-1, 1].
        qtran: quantile-transform every feature column.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info) where ``info``
        also holds "code", the column-kind table read from ``rssm_code.csv``.
    """
    train_path = os.path.join(data_dir, "rssm_train.csv")
    test_path = os.path.join(data_dir, "rssm_test.csv")
    if not (os.path.exists(train_path) and os.path.exists(test_path)):
        print(f"'rssm_train.csv' or 'rssm_test.csv' not in {data_dir}, do preprocess first")
        preprocess_rssm(data_dir,data_dir)

    train = pd.read_csv(train_path)
    test = pd.read_csv(test_path)
    columns = train.columns

    label, cate_columns, to_drop, label_convert = DATA_INFO_DICT["rssm"]
    # Leave dropped columns (e.g. "Year") out: do_basic_op removes them before
    # transforming, so naming them raises a KeyError.
    qtran_columns = [col for col in columns if col != label and col not in to_drop] if qtran else []
    X_train,X_test,y_train,y_test,basic_info =\
        do_basic_op(train,test,as_frame=as_frame,normalize_feature=normalize,normalize_label=True,
                    label=label,to_drop=to_drop,label_convert=label_convert,cate_columns=cate_columns,
                    qtran_columns=qtran_columns
                   )
    X_train,X_valid,y_train,y_valid = train_test_split(X_train,y_train,test_size=100000,random_state=seed)

    # rssm_code.csv is written without a header row; reading it with the
    # default header inference would consume its first line as column names.
    code = pd.read_csv(os.path.join(data_dir, "rssm_code.csv"), header=None, names=["index", "kind", "name"])
    info = {"code": code, "columns": columns.drop(label).tolist()}
    info.update(basic_info)
    return X_train,X_valid,X_test,y_train,y_valid,y_test, info

def load_year(data_dir,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the YearPredictionMSD regression data (target: ``Year``).

    The first 463,715 rows of ``YearPredictionMSD.txt`` are the train/valid pool
    and the remaining rows the test set; 20% of the pool becomes the validation
    set. The label is min-max normalized.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info)
    """
    label, cate_cols, drop_cols, label_map = DATA_INFO_DICT["year"]
    raw = pd.read_csv(os.path.join(data_dir, "YearPredictionMSD.txt"), sep=",", header=None, names=YEAR_COLUMNS)

    split_at = 463715
    train = raw.iloc[:split_at].copy()
    test = raw.iloc[split_at:].copy()
    quantile_cols = [name for name in train.columns if name != label] if qtran else []

    X_pool, X_test, y_pool, y_test, meta = do_basic_op(
        train, test,
        label=label, cate_columns=cate_cols, to_drop=drop_cols, label_convert=label_map,
        as_frame=as_frame, normalize_feature=normalize, normalize_label=True,
        qtran_columns=quantile_cols,
    )
    X_train, X_valid, y_train, y_valid = train_test_split(X_pool, y_pool, test_size=0.2, random_state=seed)
    return X_train, X_valid, X_test, y_train, y_valid, y_test, dict(meta)

def load_covr(data_dir,test_size,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the covertype multi-class data (target ``Type`` remapped 1..7 -> 0..6).

    ``covtype.csv`` is split randomly into pool/test by *test_size*; the
    validation split then takes ``0.2 / (1 - test_size)`` of the pool, i.e. 20%
    of all rows.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info)
    """
    label, cate_cols, drop_cols, label_map = DATA_INFO_DICT["covr"]
    frame = pd.read_csv(os.path.join(data_dir, "covtype.csv"), sep=",", names=COVR_COLUMNS)
    train, test = train_test_split(frame, test_size=test_size, random_state=seed)
    all_columns = train.columns
    quantile_cols = [name for name in all_columns if name != label] if qtran else []

    X_pool, X_test, y_pool, y_test, meta = do_basic_op(
        train, test,
        label=label, cate_columns=cate_cols, to_drop=drop_cols, label_convert=label_map,
        as_frame=as_frame, normalize_feature=normalize,
        qtran_columns=quantile_cols,
    )
    valid_fraction = 0.2 / (1 - test_size)
    X_train, X_valid, y_train, y_valid = train_test_split(X_pool, y_pool, test_size=valid_fraction, random_state=seed)

    info = {"columns": all_columns.drop(label).tolist(), **meta}
    return X_train, X_valid, X_test, y_train, y_valid, y_test, info

def load_adlt(data_dir,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the "adlt" income classification data from train.csv / test.csv.

    Both files are read with ``ADLT_COLUMNS`` as column names; 20% of the
    training rows are held out for validation.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info)
    """
    read = lambda fname: pd.read_csv(os.path.join(data_dir, fname), names=ADLT_COLUMNS)
    train = read("train.csv")
    test = read("test.csv")
    all_columns = train.columns

    label, cate_cols, drop_cols, label_map = DATA_INFO_DICT["adlt"]
    quantile_cols = [name for name in all_columns if name != label] if qtran else []

    X_pool, X_test, y_pool, y_test, meta = do_basic_op(
        train, test,
        label=label, cate_columns=cate_cols, to_drop=drop_cols, label_convert=label_map,
        as_frame=as_frame, normalize_feature=normalize,
        qtran_columns=quantile_cols,
    )
    X_train, X_valid, y_train, y_valid = train_test_split(X_pool, y_pool, test_size=0.2, random_state=seed)

    info = {"columns": all_columns.drop(label).tolist(), **meta}
    return X_train, X_valid, X_test, y_train, y_valid, y_test, info

def load_cnss(data_dir,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the "cnss" census-income classification data.

    Reads header-less ``train.csv`` / ``test.csv`` with ``CNS_COLUMNS`` as column
    names and holds out 20% of the training rows for validation.

    Args:
        data_dir: directory containing ``train.csv`` and ``test.csv``.
        seed: random state for the train/validation split.
        as_frame: return DataFrames/Series instead of numpy arrays.
        normalize: min-max scale every feature to [-1, 1].
        qtran: quantile-transform every feature column.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info)
    """
    train = pd.read_csv(os.path.join(data_dir,"train.csv"),header=None,names=CNS_COLUMNS)
    test = pd.read_csv(os.path.join(data_dir,"test.csv"),header=None,names=CNS_COLUMNS)
    columns = train.columns

    label, cate_columns, to_drop, label_convert = DATA_INFO_DICT["cnss"]
    # Leave dropped columns ("instance weight") out: do_basic_op removes them
    # before transforming, so naming them raises a KeyError.
    qtran_columns = [col for col in columns if col != label and col not in to_drop] if qtran else []
    X_train,X_test,y_train,y_test,basic_info =\
        do_basic_op(train,test,as_frame=as_frame,normalize_feature=normalize,
                    label=label,to_drop=to_drop,label_convert=label_convert,cate_columns=cate_columns,
                    qtran_columns=qtran_columns
                   )
    X_train,X_valid,y_train,y_valid = train_test_split(X_train,y_train,test_size=0.2,random_state=seed)
    info = {"columns": columns.drop(label).tolist()}
    info.update(basic_info)
    return X_train,X_valid,X_test,y_train,y_valid,y_test, info

def load_srcs(data_dir,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the Sarcos inverse-dynamics regression data (target: ``torques1``).

    Train/test come from ``sarcos_inv.csv`` / ``sarcos_inv_test.csv``; the other
    torque columns are dropped and the label is min-max normalized. 20% of the
    training rows are held out for validation.

    Args:
        data_dir: directory containing the two CSV files.
        seed: random state for the train/validation split.
        as_frame: return DataFrames/Series instead of numpy arrays.
        normalize: min-max scale every feature to [-1, 1].
        qtran: quantile-transform every feature column.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info)
    """
    train = pd.read_csv(os.path.join(data_dir,"sarcos_inv.csv"))
    test = pd.read_csv(os.path.join(data_dir,"sarcos_inv_test.csv"))
    columns = train.columns

    label, cate_columns, to_drop, label_convert = DATA_INFO_DICT["srcs"]
    # Leave dropped columns (torques2..7) out: do_basic_op removes them before
    # transforming, so naming them raises a KeyError.
    qtran_columns = [col for col in columns if col != label and col not in to_drop] if qtran else []
    X_train,X_test,y_train,y_test,basic_info =\
        do_basic_op(train,test,as_frame=as_frame,normalize_feature=normalize,normalize_label=True,
                    label=label,to_drop=to_drop,label_convert=label_convert,cate_columns=cate_columns,
                    qtran_columns=qtran_columns
                   )
    X_train,X_valid,y_train,y_valid = train_test_split(X_train,y_train,test_size=0.2,random_state=seed)
    info = {"columns": columns.drop(label).tolist()}
    info.update(basic_info)
    return X_train,X_valid,X_test,y_train,y_valid,y_test, info

def load_higg(data_dir,test_size,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the HIGGS binary classification data from ``higgs.csv``.

    The file is split randomly into pool/test by *test_size*; 20% of the pool
    is then held out for validation.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info)
    """
    source = os.path.join(data_dir, "higgs.csv")
    train, test = train_test_split(pd.read_csv(source), test_size=test_size, random_state=seed)
    all_columns = train.columns

    label, cate_cols, drop_cols, label_map = DATA_INFO_DICT["higg"]
    quantile_cols = [name for name in all_columns if name != label] if qtran else []

    X_pool, X_test, y_pool, y_test, meta = do_basic_op(
        train, test,
        label=label, cate_columns=cate_cols, to_drop=drop_cols, label_convert=label_map,
        as_frame=as_frame, normalize_feature=normalize,
        qtran_columns=quantile_cols,
    )
    X_train, X_valid, y_train, y_valid = train_test_split(X_pool, y_pool, test_size=0.2, random_state=seed)

    info = {"columns": all_columns.tolist(), **meta}
    return X_train, X_valid, X_test, y_train, y_valid, y_test, info

def load_chrn(data_dir,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the "chrn" churn classification data (labels -1/1 mapped to 0/1).

    Reads tab-separated, header-less ``train.csv`` / ``test.csv`` with
    ``KDDT_COLUMNS`` as column names; 20% of the training rows are held out for
    validation.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info)
    """
    def _read(fname):
        return pd.read_csv(os.path.join(data_dir, fname), header=None, names=KDDT_COLUMNS, sep="\t")

    train, test = _read("train.csv"), _read("test.csv")
    label, cate_cols, drop_cols, label_map = DATA_INFO_DICT["chrn"]
    quantile_cols = [name for name in train.columns if name != label] if qtran else []

    X_pool, X_test, y_pool, y_test, meta = do_basic_op(
        train, test,
        label=label, cate_columns=cate_cols, to_drop=drop_cols, label_convert=label_map,
        as_frame=as_frame, normalize_feature=normalize,
        qtran_columns=quantile_cols,
    )
    X_train, X_valid, y_train, y_valid = train_test_split(X_pool, y_pool, test_size=0.2, random_state=seed)
    return X_train, X_valid, X_test, y_train, y_valid, y_test, dict(meta)

def load_blst(data_dir,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the "blst" churn classification data (target ``Churn``: No/Yes -> 0/1).

    Blank ``TotalCharges`` entries (" ") become -1 before casting to float. The
    data is split 80/20 into pool/test, and ``0.1 / 0.8`` of the pool (10% of all
    rows) is held out for validation.

    Args:
        data_dir: directory containing ``blst.csv``.
        seed: random state for both splits.
        as_frame: return DataFrames/Series instead of numpy arrays.
        normalize: min-max scale every feature to [-1, 1].
        qtran: quantile-transform every feature column.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info)
    """
    data = pd.read_csv(os.path.join(data_dir,"blst.csv")).replace({"TotalCharges":{" ":-1}}).astype({"TotalCharges":"float"})
    train, test = train_test_split(data,test_size=0.2,random_state=seed)
    columns = train.columns

    label, cate_columns, to_drop, label_convert = DATA_INFO_DICT["blst"]
    # Leave dropped columns ("customerID") out: do_basic_op removes them before
    # transforming, so naming them raises a KeyError.
    qtran_columns = [col for col in columns if col != label and col not in to_drop] if qtran else []
    X_train,X_test,y_train,y_test,basic_info =\
        do_basic_op(train,test,as_frame=as_frame,normalize_feature=normalize,
                    label=label,to_drop=to_drop,label_convert=label_convert,cate_columns=cate_columns,
                    qtran_columns=qtran_columns
                   )
    X_train,X_valid,y_train,y_valid = train_test_split(X_train,y_train,test_size=0.1/(1-0.2),random_state=seed)
    info = dict(basic_info)
    return X_train,X_valid,X_test,y_train,y_valid,y_test, info

def load_shrt(data_dir,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the "shrt" classification data (target ``Exited``) from ``shrt.csv``.

    Identifier columns (RowNumber, CustomerId, Surname) are dropped. The data is
    split 80/20 into pool/test, and ``0.1 / 0.8`` of the pool (10% of all rows)
    is held out for validation.

    Args:
        data_dir: directory containing ``shrt.csv``.
        seed: random state for both splits.
        as_frame: return DataFrames/Series instead of numpy arrays.
        normalize: min-max scale every feature to [-1, 1].
        qtran: quantile-transform every feature column.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info)
    """
    data = pd.read_csv(os.path.join(data_dir,"shrt.csv"))
    train, test = train_test_split(data,test_size=0.2,random_state=seed)
    columns = train.columns

    label, cate_columns, to_drop, label_convert = DATA_INFO_DICT["shrt"]
    # Leave dropped identifier columns out: do_basic_op removes them before
    # transforming, so naming them raises a KeyError.
    qtran_columns = [col for col in columns if col != label and col not in to_drop] if qtran else []
    X_train,X_test,y_train,y_test,basic_info =\
        do_basic_op(train,test,as_frame=as_frame,normalize_feature=normalize,
                    label=label,to_drop=to_drop,label_convert=label_convert,cate_columns=cate_columns,
                    qtran_columns=qtran_columns
                   )
    X_train,X_valid,y_train,y_valid = train_test_split(X_train,y_train,test_size=0.1/(1-0.2),random_state=seed)
    info = dict(basic_info)
    return X_train,X_valid,X_test,y_train,y_valid,y_test, info

def load_gpsp(data_dir,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the "gpsp" multi-class data from ``gpsp.csv`` (quoted class letters -> 0..4).

    Split 80/20 into pool/test, then ``0.1 / 0.8`` of the pool (10% of all rows)
    becomes the validation set.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info)
    """
    label, cate_cols, drop_cols, label_map = DATA_INFO_DICT["gpsp"]
    frame = pd.read_csv(os.path.join(data_dir, "gpsp.csv"))
    train, test = train_test_split(frame, test_size=0.2, random_state=seed)

    quantile_cols = []
    if qtran:
        quantile_cols = [name for name in train.columns if name != label]

    X_pool, X_test, y_pool, y_test, meta = do_basic_op(
        train, test,
        label=label, cate_columns=cate_cols, to_drop=drop_cols, label_convert=label_map,
        as_frame=as_frame, normalize_feature=normalize,
        qtran_columns=quantile_cols,
    )
    X_train, X_valid, y_train, y_valid = train_test_split(X_pool, y_pool, test_size=0.1 / (1 - 0.2), random_state=seed)
    return X_train, X_valid, X_test, y_train, y_valid, y_test, dict(meta)

def load_gddc(data_dir,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the "gddc" multi-class data from ``gddc.csv`` (classes remapped to 0..5).

    Split 80/20 into pool/test, then ``0.1 / 0.8`` of the pool (10% of all rows)
    becomes the validation set.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info)
    """
    label, cate_cols, drop_cols, label_map = DATA_INFO_DICT["gddc"]
    frame = pd.read_csv(os.path.join(data_dir, "gddc.csv"))
    train, test = train_test_split(frame, test_size=0.2, random_state=seed)

    quantile_cols = []
    if qtran:
        quantile_cols = [name for name in train.columns if name != label]

    X_pool, X_test, y_pool, y_test, meta = do_basic_op(
        train, test,
        label=label, cate_columns=cate_cols, to_drop=drop_cols, label_convert=label_map,
        as_frame=as_frame, normalize_feature=normalize,
        qtran_columns=quantile_cols,
    )
    X_train, X_valid, y_train, y_valid = train_test_split(X_pool, y_pool, test_size=0.1 / (1 - 0.2), random_state=seed)
    return X_train, X_valid, X_test, y_train, y_valid, y_test, dict(meta)

def load_eyem(data_dir,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the "eyem" classification data from ``eyem.csv`` (label ``Y`` unchanged).

    Split 80/20 into pool/test, then ``0.1 / 0.8`` of the pool (10% of all rows)
    becomes the validation set.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info)
    """
    label, cate_cols, drop_cols, label_map = DATA_INFO_DICT["eyem"]
    frame = pd.read_csv(os.path.join(data_dir, "eyem.csv"))
    train, test = train_test_split(frame, test_size=0.2, random_state=seed)

    quantile_cols = []
    if qtran:
        quantile_cols = [name for name in train.columns if name != label]

    X_pool, X_test, y_pool, y_test, meta = do_basic_op(
        train, test,
        label=label, cate_columns=cate_cols, to_drop=drop_cols, label_convert=label_map,
        as_frame=as_frame, normalize_feature=normalize,
        qtran_columns=quantile_cols,
    )
    X_train, X_valid, y_train, y_valid = train_test_split(X_pool, y_pool, test_size=0.1 / (1 - 0.2), random_state=seed)
    return X_train, X_valid, X_test, y_train, y_valid, y_test, dict(meta)

def load_clhp(data_dir,test_size,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the California-housing regression data (target ``median_house_value``).

    Missing ``total_bedrooms`` values are filled with 0. The data is split into
    pool/test by *test_size* (previously this argument was ignored and 0.2 was
    always used); 20% of the pool is held out for validation. The label is
    min-max normalized.

    Args:
        data_dir: directory containing ``housing.csv``.
        test_size: test fraction (or count) passed to ``train_test_split``.
        seed: random state for both splits.
        as_frame: return DataFrames/Series instead of numpy arrays.
        normalize: min-max scale every feature to [-1, 1].
        qtran: quantile-transform every feature column.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info)
    """
    data = pd.read_csv(os.path.join(data_dir,"housing.csv"))
    data = data.fillna({"total_bedrooms":0})
    train, test = train_test_split(data,test_size=test_size,random_state=seed)
    columns = train.columns

    label, cate_columns, to_drop, label_convert = DATA_INFO_DICT["clhp"]
    qtran_columns = [col for col in columns if col != label and col not in to_drop] if qtran else []
    X_train,X_test,y_train,y_test,basic_info =\
        do_basic_op(train,test,as_frame=as_frame,normalize_feature=normalize,normalize_label=True,
                    label=label,to_drop=to_drop,label_convert=label_convert,cate_columns=cate_columns,
                    qtran_columns=qtran_columns
                   )
    X_train,X_valid,y_train,y_valid = train_test_split(X_train,y_train,test_size=0.2,random_state=seed)
    info = dict(basic_info)
    return X_train,X_valid,X_test,y_train,y_valid,y_test, info

def load_hloc(data_dir,test_size=0.2,seed=None,as_frame=True,normalize=False,qtran=False):
    """Load the HELOC classification data (target ``RiskPerformance``: Good/Bad -> 0/1).

    The CSV is split into pool/test by *test_size*; 20% of the pool is then held
    out for validation. ``normalize_label=True`` is passed to do_basic_op, which
    min-max scales the (0/1) label.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test, info)
    """
    label, cate_cols, drop_cols, label_map = DATA_INFO_DICT["hloc"]
    frame = pd.read_csv(os.path.join(data_dir, "heloc_dataset_v1 (1).csv"))
    train, test = train_test_split(frame, test_size=test_size, random_state=seed)
    quantile_cols = [name for name in train.columns if name != label] if qtran else []

    X_pool, X_test, y_pool, y_test, meta = do_basic_op(
        train, test,
        label=label, cate_columns=cate_cols, to_drop=drop_cols, label_convert=label_map,
        as_frame=as_frame, normalize_feature=normalize, normalize_label=True,
        qtran_columns=quantile_cols,
    )
    X_train, X_valid, y_train, y_valid = train_test_split(X_pool, y_pool, test_size=0.2, random_state=seed)
    return X_train, X_valid, X_test, y_train, y_valid, y_test, dict(meta)

################################## Simulation data
def generate_category(num_x, num_repeat,return_matrix=True, seed =1234, ordering=False, feature_range=(-1.0,1.0)):
    """Simulate one categorical feature with a Bernoulli target per category.

    ``num_x`` evenly spaced category values are laid out over ``feature_range``.
    Each value gets a success probability drawn uniformly from [0.1, 0.9)
    (sorted ascending when ``ordering`` is true), and ``num_repeat`` 0/1 targets
    are sampled for it.

    Returns:
        X: category value of every sample; shape ``(num_x * num_repeat, 1)`` when
           ``return_matrix`` else ``(num_x * num_repeat,)``.
        y: 0/1 targets aligned with ``X``.
        info_dict: ``{"prob": true probability per value,
                      "code": integer code per value,
                      "sample_prob": empirical mean of y per value}``.
    """
    rng = np.random.RandomState(seed)
    low, high = feature_range[0], feature_range[-1]
    levels = np.linspace(low, high, num_x)

    probs = rng.uniform(0.1, 0.9, num_x)
    if ordering:
        probs = np.sort(probs)

    draws = []
    for prob in probs:
        draws.append(rng.choice([0, 1], p=[1 - prob, prob], size=num_repeat))
    targets = np.concatenate(draws)
    values = np.repeat(levels, num_repeat)

    empirical = (
        pd.DataFrame(zip(values, targets), columns=["x", "y"])
        .groupby("x")["y"]
        .mean()
        .to_dict()
    )
    info_dict = {
        "prob": dict(zip(levels, probs)),
        "code": dict(zip(levels, range(len(levels)))),
        "sample_prob": empirical,
    }

    if return_matrix:
        values = values[:, np.newaxis]
    return values, targets, info_dict

def generate_random_data(data_size=1000,features=5,feature_range=(-1.0,1.0),seed=1234):
    """Draw uniform random features and independent random 0/1 labels.

    Returns:
        X: array of shape ``(data_size, features)``, uniform over ``feature_range``.
        y: integer array of shape ``(data_size,)`` with values in {0, 1}.
    """
    rng = np.random.RandomState(seed)
    low, high = feature_range[0], feature_range[-1]
    features_matrix = rng.uniform(low, high, size=(data_size, features))
    labels = rng.randint(0, 2, (data_size,))
    return features_matrix, labels

def generate_random_data2(data_size=1000,features=5,feature_range=(-1.0,1.0),seed=1234):
    """Draw uniform random features and independent continuous targets in [0, 1).

    Returns:
        X: array of shape ``(data_size, features)``, uniform over ``feature_range``.
        y: float array of shape ``(data_size,)``.
    """
    rng = np.random.RandomState(seed)
    low, high = feature_range[0], feature_range[-1]
    features_matrix = rng.uniform(low, high, size=(data_size, features))
    targets = rng.rand(data_size)
    return features_matrix, targets


from sklearn.datasets import make_classification
def generate_classification(data_size,test_size,n_features,n_informative,n_redundant=0,n_cluster=2,
                            flip_y=0.01, class_sep=1.0, hypercube=True, seed=None, feature_range=(-1.0,1.0)):
    """Generate a synthetic binary classification task, split it and min-max scale it.

    Wraps sklearn's ``make_classification`` (2 classes, no repeated features,
    shuffled) followed by ``split_and_scale``.

    Args:
        data_size: total number of samples.
        test_size: passed through to ``train_test_split``.
        n_features, n_informative, n_redundant: feature counts for make_classification.
        n_cluster: clusters per class.
        flip_y, class_sep, hypercube: forwarded to make_classification.
        seed: random state for generation and for the split.
        feature_range: output range of the min-max scaling.

    Returns:
        (X_train, X_test, y_train, y_test); the fitted scaler is discarded.
    """
    X,y = make_classification(n_samples=data_size, n_features=n_features,  n_informative=n_informative, n_redundant=n_redundant
                              ,n_repeated=0, n_classes=2
                              ,n_clusters_per_class=n_cluster, weights=None, flip_y=flip_y, class_sep=class_sep, hypercube=hypercube
                              ,shift=0.0, scale=1.0, shuffle=True, random_state=seed)
    # feature_range used to be accepted but not forwarded, so scaling always
    # used split_and_scale's default of (-1.0, 1.0).
    X_train,X_test,y_train,y_test, _scaler = split_and_scale(X,y,test_size=test_size,seed=seed,feature_range=feature_range)
    return X_train,X_test,y_train,y_test
