from tcto.task_mapping import task_dict
import pandas as pd
import numpy as np
import argparse
from sklearn.ensemble import RandomForestRegressor,RandomForestClassifier
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import f1_score


def relative_absolute_error(y_test, y_predict):
    """Return the relative absolute error (RAE) of a set of predictions.

    The RAE is the total absolute residual divided by the total absolute
    deviation of the true values from their own mean.

    Args:
        y_test: Array-like of true target values.
        y_predict: Array-like of predicted values, same length as ``y_test``.

    Returns:
        The RAE as a NumPy scalar. When every true value is identical the
        denominator is zero, so the result is non-finite.
    """
    actual = np.array(y_test)
    predicted = np.array(y_predict)
    # Numerator: how far the predictions are from the truth.
    residual_total = np.abs(actual - predicted).sum()
    # Denominator: how far a "predict the mean" baseline would be.
    baseline_total = np.abs(actual.mean() - actual).sum()
    return residual_total / baseline_total

def init_param():
    """Parse the command-line options of the evaluation script.

    Only ``--file_name`` (default ``'airfoil'``) is recognised; any other
    command-line arguments are tolerated and discarded.

    Returns:
        argparse.Namespace holding the parsed ``file_name``.
    """
    cli = argparse.ArgumentParser(
        description="PyTorch experiment, testing generated features."
    )
    cli.add_argument(
        '--file_name',
        type=str,
        default='airfoil',
        help='data file name',
    )
    # parse_known_args returns (namespace, leftovers); the leftovers are ignored.
    known_args = cli.parse_known_args()[0]
    return known_args
    
def downstream_task_new(data:pd.DataFrame, task_type:str, state_num=0):
    """Score a feature table with a 5-fold cross-validated random forest.

    The last column of ``data`` is used as the target and every preceding
    column as a feature.

    Args:
        data: Table whose final column is the label / regression target.
        task_type: ``'cls'`` for classification, ``'reg'`` for regression.
            ``'rank'`` is an unimplemented placeholder.
        state_num: Seed passed to both the fold splitter and the forest.

    Returns:
        For ``'cls'``, the mean weighted F1-score over the folds.
        For ``'reg'``, the mean of ``1 - RAE`` over the folds.
        For ``'rank'``, ``None``. For any other task type, ``-1``.
    """
    features = data.iloc[:, :-1]
    target = data.iloc[:, -1]

    if task_type == 'cls':
        # Stratified folds keep the class proportions in every split.
        splitter = StratifiedKFold(n_splits=5, random_state=state_num, shuffle=True)
        model = RandomForestClassifier(random_state=state_num)
        fold_scores = []
        for train_idx, test_idx in splitter.split(features, target):
            model.fit(features.iloc[train_idx, :], target.iloc[train_idx])
            predicted = model.predict(features.iloc[test_idx, :])
            fold_scores.append(
                f1_score(target.iloc[test_idx], predicted, average='weighted')
            )
        return np.mean(fold_scores)

    if task_type == 'reg':
        splitter = KFold(n_splits=5, random_state=state_num, shuffle=True)
        model = RandomForestRegressor(random_state=state_num)
        fold_scores = []
        for train_idx, test_idx in splitter.split(features):
            model.fit(features.iloc[train_idx, :], target.iloc[train_idx])
            predicted = model.predict(features.iloc[test_idx, :])
            # Report 1 - RAE so that, like F1, larger is better.
            fold_scores.append(
                1 - relative_absolute_error(target.iloc[test_idx], predicted)
            )
        return np.mean(fold_scores)

    if task_type == 'rank':
        # Ranking is not implemented; callers receive None.
        return None

    # Unrecognised task type.
    return -1
    


if __name__ == '__main__':
    # Compare downstream performance on the original dataset against the
    # generated-feature dataset for the dataset named by --file_name.
    params = init_param()
    params = vars(params)
    file_name = params['file_name']
    # Known dataset names, kept for reference only; nothing below reads it.
    # NOTE(review): 'genrman_credit' looks like a typo for 'german_credit' --
    # confirm against the files in ./data before changing it.
    file_name_list = ['higgs','amazon_employee','pima_indian','spectf','svmguide3','genrman_credit',
        'credit_default','messidor_features','wine_red','wine_white','spam_base','ap_omentum_ovary',
        'lymphography','ionosphere','housing_boston','airfoil','openml_618','openml_589','openml_616',
        'openml_607','openml_620','openml_637','openml_586']

    # Look the task type up once, before any file is read, so an unknown
    # dataset name fails immediately with a KeyError.
    task_type = task_dict[file_name]

    #test original data
    data = pd.read_hdf(f'./data/{file_name}.hdf',key='test')
    original_per = downstream_task_new(data,task_type)
    #test generated data
    data = pd.read_csv(f'./results/{file_name}.csv',index_col='Unnamed: 0')
    generated_per = downstream_task_new(data,task_type)

    if task_type == 'cls':
        print(f'F1-score on {file_name} original data is : {original_per}.')
        print(f'F1-score on {file_name} generated data is : {generated_per}.')
    elif task_type == 'reg':
        print(f'1-RAE on {file_name} original data is : {original_per}.')
        # Regression scores are 1-RAE, not F1 (the label used to say F1-score).
        print(f'1-RAE on {file_name} generated data is : {generated_per}.')