
from imblearn.datasets import fetch_datasets
from sklearn.preprocessing import StandardScaler
import torch
from pathlib import Path
from dataset_utils import imbalance_preserving_Kfold, imbalance_preserving_downsample
import os, sys
import numpy as np
from collections import namedtuple

# Never abbreviate arrays with "..." when they are converted to text, and
# show three digits after the decimal point.
np.set_printoptions(threshold=sys.maxsize, precision=3)

# Descriptor of a dataset that is processed without one-hot ranges.
NumericDataset = namedtuple("NumericDataset", "name")
# Descriptor of a dataset with categorical columns; `onehot_ranges` is what
# gets forwarded to preprocess_dataset(onehot_ranges=...).
MixedDataset = namedtuple("MixedDataset", "name onehot_ranges")

def print_X_y(X, y, output_txt_file):
    """Dump features and labels to a text file as one combined matrix.

    The labels are appended to ``X`` as an extra last column and the string
    representation of the resulting array is written to ``output_txt_file``,
    overwriting any existing file.

    Args:
        X: 2-D array-like of features, one row per sample.
        y: 1-D array-like of labels with one entry per row of ``X``.
        output_txt_file: Path of the text file to (over)write.
    """
    # y[:, None] turns the label vector into a column so it can be appended.
    combined = np.append(X, y[:, None], axis=1)
    # The context manager closes the file even if formatting or writing fails.
    with open(output_txt_file, "w") as out_file:
        out_file.write(str(combined))

# Datasets that the driver preprocesses without any one-hot ranges, i.e. every
# feature column is standardized.
only_numerical = [
    NumericDataset(dataset_name)
    for dataset_name in (
        'ecoli',
        'letter_img',
        'libras_move',
        'mammography',
        'ozone_level',
        'pen_digits',
        'satimage',
        'spectrometer',
        'us_crime',
        'webpage',
        'wine_quality',
        'yeast_me2',
        'yeast_ml8',
        'coil_2000',
        'oil',
        'optical_digits',
    )
]

# Datasets containing one-hot encoded categorical features.
# Each entry of `onehot_ranges` is an inclusive (first, last) pair of column
# indices in the raw feature matrix that together form one one-hot group;
# onehot2int() collapses every such group into a single integer column.
mixed = [MixedDataset('abalone', [(0, 2)]),
         MixedDataset('abalone_19', [(0, 2)]),
         MixedDataset('sick_euthyroid', [(i, i+1) for i in range(1, 27, 2)]+[(28, 29), (31, 32), (34, 35), (37, 38), (40, 41)]),
         MixedDataset('thyroid_sick', [(i, i+1) for i in range(1, 32, 2)]+[(34, 35), (37, 38), (40, 41), (43, 44), (47, 51)]),
         MixedDataset('solar_flare_m0',
                                [(0, 5), (6, 11), (12, 15), (16, 17), (18, 20), (21, 23), (24, 25), (26, 27), (28, 29), (30, 31)]),
         ]

# Datasets whose categorical features are all binary (a single bit each)
# rather than one-hot groups: ['arrhythmia', 'car_eval_4', 'car_eval_34'].
# Only 'arrhythmia' is currently listed. The 'all_binary' marker makes
# preprocess_dataset() keep the raw features without normalizing them.
optional = [MixedDataset('arrhythmia', 'all_binary')]

def onehot2int(X: np.ndarray, onehot_ranges):
    """Collapse one-hot column groups of ``X`` into single integer columns.

    Columns are processed left to right. Columns that lie outside every range
    are copied unchanged, in their original order. Each one-hot group is
    replaced by one column holding the index of the active bit within the
    group, or -1 when no bit is set (treated as a missing value).

    Args:
        X: 2-D array of shape (N, num_columns).
        onehot_ranges: Iterable of inclusive (first, last) column index pairs,
            each delimiting one one-hot group.

    Returns:
        A tuple ``(new_X, cat_idx_list, numeric_idx_list)``: the converted
        matrix, the column indices of the integer-coded categorical features
        in ``new_X``, and the indices of all remaining columns of ``new_X``.

    Raises:
        AssertionError: If a group has a negative entry or a row whose entries
            do not sum to 0 or 1.
    """
    n_rows = X.shape[0]
    converted = np.array([], dtype=np.int64).reshape(n_rows, 0)
    next_unread_col = 0   # first column of X that has not been consumed yet
    out_width = 0         # number of columns emitted into `converted` so far
    cat_idx_list = []

    for first, last in onehot_ranges:
        group = X[:, first:last + 1]
        row_sums = np.sum(group, axis=1)
        # A row summing to zero is valid: it denotes a missing value.
        assert np.all((row_sums == 1) | (row_sums == 0)) and np.all(group >= 0)

        if next_unread_col != first:
            # Plain columns sit between the previous group and this one.
            converted = np.concatenate((converted, X[:, next_unread_col:first]), axis=1)
            out_width += first - next_unread_col

        codes = np.argmax(group, axis=1).reshape(n_rows, 1)
        codes[row_sums == 0] = -1
        converted = np.concatenate((converted, codes), axis=1)
        cat_idx_list.append(out_width)
        out_width += 1
        next_unread_col = last + 1

    if next_unread_col < X.shape[1]:
        # Trailing columns after the last one-hot group.
        converted = np.concatenate((converted, X[:, next_unread_col:]), axis=1)

    categorical_cols = set(cat_idx_list)
    numeric_idx_list = [col for col in range(converted.shape[1]) if col not in categorical_cols]
    return converted, cat_idx_list, numeric_idx_list


def preprocess_dataset(dataset, downsample=False, onehot_ranges=None, suffix=None):
    """Fetch an imblearn dataset, normalize it and save 5 train/test folds.

    Output goes to ``imblearn/<dataset>/``. For every fold ``i`` (1-based) the
    files ``<base>_train_<i>.pt/.txt`` and ``<base>_test_<i>.pt/.txt`` are
    written, where ``<base>`` is ``imblearn/<dataset>/<dataset>`` optionally
    followed by ``_<suffix>``. The first fold is additionally written under
    the unnumbered names ``<base>_train`` and ``<base>_test``. Each ``.txt``
    file is a text dump of the same split that is stored in the matching
    ``.pt`` file.

    Args:
        dataset: Name of the dataset as understood by ``fetch_datasets``.
        downsample: If true, datasets with at least 3000 samples are reduced
            with ``imbalance_preserving_downsample`` using the factor
            ``num_samples // 1500``.
        onehot_ranges: ``None`` to standardize every column, the string
            ``'all_binary'`` to keep the raw features untouched, or a list of
            inclusive (first, last) one-hot column ranges that are collapsed
            by ``onehot2int`` before only the remaining columns are
            standardized.
        suffix: Optional tag appended to the output file base name.
    """
    output_dir = f'imblearn/{dataset}'
    output_file_base = f'{output_dir}/{dataset}'
    if suffix:
        output_file_base += f'_{suffix}'
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(output_dir, exist_ok=True)

    data = fetch_datasets(data_home='imblearn/imblearn_raw',
                          filter_data=[dataset],
                          download_if_missing=True,
                          random_state=None,
                          shuffle=False,
                          verbose=False)
    X_pre = data[dataset]['data']
    if onehot_ranges == 'all_binary':
        X = X_pre    # do not normalize
    elif onehot_ranges is None:
        X = StandardScaler(with_std=True).fit_transform(X_pre)
    else:
        X_pre, cat_idx_list, numeric_idx_list = onehot2int(X_pre, onehot_ranges)
        print(cat_idx_list)
        # StandardScale only numeric features
        if numeric_idx_list:
            # Writing scaled floats into an integer array would silently
            # truncate them, so make sure the matrix is floating point first.
            if not np.issubdtype(X_pre.dtype, np.floating):
                X_pre = X_pre.astype(np.float64)
            X_pre[:, numeric_idx_list] = StandardScaler(with_std=True).fit_transform(X_pre[:, numeric_idx_list])
        X = X_pre

    y_pre = data[dataset]['target']
    # Map the label -1 to 0; every other label value is left unchanged.
    y = (y_pre == -1) + y_pre

    X = torch.from_numpy(X)
    y = torch.from_numpy(y)

    if downsample:
        dataset_size = X.shape[0]
        _g = dataset_size // 1500
        if _g > 1:
            X, y = imbalance_preserving_downsample(X, y, _g)

    kf = imbalance_preserving_Kfold(n_splits=5)
    for fold, (x_train, y_train, x_test, y_test) in enumerate(kf.split(X, y), start=1):
        splits = (('train', x_train, y_train), ('test', x_test, y_test))
        # The first fold doubles as the default (unnumbered) train/test split.
        tags = ('', f'_{fold}') if fold == 1 else (f'_{fold}',)
        for tag in tags:
            for split_name, x_split, y_split in splits:
                torch.save([x_split, y_split], Path(f'{output_file_base}_{split_name}{tag}.pt'))
                # Dump the split itself (not the full dataset) so that the
                # text file mirrors the tensors saved next to it.
                print_X_y(x_split, y_split, f'{output_file_base}_{split_name}{tag}.txt')

if __name__ == "__main__":
    # Preprocess every configured dataset. Datasets from `mixed` and
    # `optional` are processed twice: once with all columns standardized
    # (saved with the 'onehot' suffix) and once using their onehot_ranges.
    use_downsampling = False
    for entry in only_numerical:
        preprocess_dataset(entry.name, use_downsampling)
    for entry in mixed + optional:
        preprocess_dataset(entry.name, use_downsampling, suffix='onehot')
        preprocess_dataset(entry.name, use_downsampling, onehot_ranges=entry.onehot_ranges)
