import pandas as pd
import os

import ember

def _drop_unlabelled(metadata, X, y):
    '''
    Pair each row's sha256 with its integer label and drop rows labelled -1.

    ``y`` is aligned with ``metadata`` positionally (the key column keeps the
    metadata index). NOTE(review): assumes metadata rows are in the same order
    as the vectorized features -- confirm against ``ember.create_metadata``.

    Returns:
        (df, X, y): ``df`` has columns ``key`` and ``label``; all three are
        restricted to rows whose label is not -1.
    '''
    df = pd.DataFrame({'key': metadata['sha256'], 'label': y.astype(int)})
    # Plain numpy mask so the same positional selection applies to df, X and y.
    labelled = (df['label'] != -1).to_numpy()
    return df[labelled], X[labelled], y[labelled]


def _load_labelled_ember(data_path):
    '''
    Load EMBER's train and test subsets, drop unlabelled rows and stack them.

    If ``y_train.dat`` is missing from ``data_path``, the vectorized features
    and metadata are created first.

    Returns:
        (all_df, all_X, all_y): key/label frame, feature frame and label
        series, all sharing a fresh 0..n-1 index (train rows first).
    '''
    if not os.path.exists(os.path.join(data_path, 'y_train.dat')):
        ember.create_vectorized_features(data_path)
        ember.create_metadata(data_path)

    X_train, y_train, X_test, y_test = ember.read_vectorized_features(data_path)
    metadata = ember.read_metadata(data_path)
    train_df, X_train, y_train = _drop_unlabelled(
        metadata[metadata['subset'] == 'train'], X_train, y_train)
    test_df, X_test, y_test = _drop_unlabelled(
        metadata[metadata['subset'] == 'test'], X_test, y_test)

    # ignore_index gives all three the same 0..n-1 index so boolean masks on
    # all_df can be applied directly to all_X / all_y.
    all_df = pd.concat([train_df, test_df], ignore_index=True)
    all_X = pd.concat([pd.DataFrame(X_train), pd.DataFrame(X_test)], ignore_index=True)
    all_y = pd.concat([pd.Series(y_train), pd.Series(y_test)], ignore_index=True)
    return all_df, all_X, all_y


def _three_way_split(X, random_state, frac=0.1):
    '''
    Draw disjoint train/val/test row samples, each ~``frac`` of ``len(X)``.

    The val and test fractions are rescaled to the rows left after the earlier
    draws, so each split ends up ~``frac`` of the original rows. ``frac`` must
    be at most 1/3, otherwise the rescaled fraction exceeds 1 and pandas
    rejects the sample.
    '''
    train = X.sample(frac=frac, random_state=random_state)
    remaining = X.drop(train.index)
    val = remaining.sample(frac=frac / (1 - frac), random_state=random_state)
    remaining = remaining.drop(val.index)
    test = remaining.sample(frac=frac / (1 - 2 * frac), random_state=random_state)
    return train, val, test


def _assemble_split(pos_X, neg_X, pos_key, neg_key):
    '''
    Stack positive rows then negative rows; build matching labels (1/0) and keys.

    Keys are looked up by the rows' original index, so ``pos_X``/``neg_X`` must
    be row subsets of the frames ``pos_key``/``neg_key`` were taken from.
    '''
    X = pd.concat([pos_X, neg_X], ignore_index=True)
    y = pd.concat([pd.Series([1] * len(pos_X)), pd.Series([0] * len(neg_X))], ignore_index=True)
    key = pd.concat([pos_key.loc[pos_X.index], neg_key.loc[neg_X.index]], ignore_index=True)
    return X, y, key


def _shuffle_X_y_key(X, y, key, random_state):
    '''
    Shuffle X, y and key together so their rows stay aligned.

    Inputs must share the same 0..n-1 index. Returns (X, y, key) with a fresh
    0..n-1 index; X's columns are renamed ``feature_<i>``.
    '''
    data = pd.concat([key, X, y], axis=1)
    data.columns = ['key'] + [f'feature_{i}' for i in range(X.shape[1])] + ['label']
    data = data.sample(frac=1, random_state=random_state).reset_index(drop=True)
    return data.iloc[:, 1:-1], data['label'], data['key']


def _print_split(name, X, y, key):
    '''Print the shapes and class counts of one named split.'''
    print(f"[{name.upper()}]")
    print(f"X_{name}_key shape: ", key.shape)
    print(f"X_{name} shape: ", X.shape)
    print(f"y_{name} shape: ", y.shape, ", pos: ", y[y == 1].shape[0], ", neg: ", y[y == 0].shape[0])


def preprocess(data_path='data/ember/ember2018/', output_dir='data/ember/preprocessed', random_state=42):
    '''
    Build class-balanced, disjoint train/val/test CSV splits from EMBER 2018.

    Rows labelled -1 are dropped and EMBER's own train and test subsets are
    pooled. Positives (label 1) and negatives (label 0) are then each split
    independently, so no sample appears in more than one of train/val/test.

    [All positive samples]
        all_pos_key.csv, all_pos_X.csv
        100% of positive samples
    [Samples for Training ML model]
        X_train_key.csv, X_train.csv, y_train.csv
        10% of the positive samples and 10% of the negative samples
    [Samples for Validation]
        X_val_key.csv, X_val.csv, y_val.csv
        Another 10% of each class, which are not used for training ML model
    [Samples for Testing]
        X_test_key.csv, X_test.csv, y_test.csv
        Another 10% of each class, not used for training or validation

    Rows of every split are shuffled. All CSVs are written without header or
    index.

    Args:
        data_path: EMBER 2018 data directory passed to the ``ember`` helpers.
        output_dir: directory that receives the CSV files; created if missing.
        random_state: seed used for every sampling and shuffling step.
    '''
    os.makedirs(output_dir, exist_ok=True)
    all_df, all_X, all_y = _load_labelled_ember(data_path)

    is_pos = all_df['label'] == 1
    all_pos_df = all_df[is_pos]
    all_neg_df = all_df[~is_pos]
    all_pos_key = all_pos_df['key']
    all_neg_key = all_neg_df['key']
    all_pos_X = all_X[is_pos]
    all_neg_X = all_X[~is_pos]

    print("=== EMBER dataset ===")
    print("all_df shape: ", all_df.shape)
    print("all_X shape: ", all_X.shape)
    print("all_y shape: ", all_y.shape, ", pos: ", all_y[all_y == 1].shape[0], ", neg: ", all_y[all_y == 0].shape[0])
    print("all_pos_df shape: ", all_pos_df.shape)
    print("all_pos_X shape: ", all_pos_X.shape)
    print("all_pos_y shape: ", all_y[is_pos].shape)
    print("all_neg_df shape: ", all_neg_df.shape)
    print("all_neg_X shape: ", all_neg_X.shape)
    print("all_neg_y shape: ", all_y[~is_pos].shape)

    # Split each class the same way (train 10%, val 10%, test 10%, disjoint)
    # so that no row leaks between splits and every split is class-balanced.
    split_names = ('train', 'val', 'test')
    pos_splits = _three_way_split(all_pos_X, random_state)
    neg_splits = _three_way_split(all_neg_X, random_state)
    splits = {}
    for name, pos_X, neg_X in zip(split_names, pos_splits, neg_splits):
        X, y, key = _assemble_split(pos_X, neg_X, all_pos_key, all_neg_key)
        splits[name] = _shuffle_X_y_key(X, y, key, random_state)

    print("[ALL POSITIVE]")
    print("all_pos_key shape: ", all_pos_key.shape)
    print("all_pos_X shape: ", all_pos_X.shape)
    for name in split_names:
        _print_split(name, *splits[name])

    outputs = [('all_pos_key', all_pos_key), ('all_pos_X', all_pos_X)]
    for name in split_names:
        X, y, key = splits[name]
        outputs += [(f'X_{name}_key', key), (f'X_{name}', X), (f'y_{name}', y)]
    for stem, obj in outputs:
        obj.to_csv(os.path.join(output_dir, f'{stem}.csv'), index=False, header=False)

if __name__ == "__main__":
    # Run the full preprocessing pipeline when executed as a script.
    preprocess()
