import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split



def data_prep_transfer_mimic(ds_id, seed, task, stage='pretrain', pretrain_proportion=0, downstream_samples_per_class = 2, pretrain_subsample = False, sample_validation_set = False, input_seed_for_validation = 0, data_dir='data'):
    """Load the MIMIC CSV splits and prepare them for pretraining or downstream transfer.

    Reads ``mimic_{train,val,test}_X.csv`` and ``mimic_{train,val,test}_y.csv``
    from ``data_dir``. ``gender`` is the only categorical feature (missing
    values become ``"MissingValue"``); every other feature column is numerical.

    Args:
        ds_id: Dataset identifier, stored as ``info['name']``.
        seed: Seeds NumPy's global RNG and the few-shot ``train_test_split``;
            also stored as ``info['split']``.
        task: Must be ``'binclass'``.
        stage: Must contain ``'downstream'`` or ``'pretrain'`` (checked in that
            order).
        pretrain_proportion: Index into the MIMIC target-column list. In the
            downstream stage that column is the single label; in the pretrain
            stage it is dropped and the remaining label columns are used.
        downstream_samples_per_class: Training rows kept per class when the
            training set is subsampled (downstream stage or
            ``pretrain_subsample``).
        pretrain_subsample: In the pretrain stage, pick one remaining target at
            random and subsample it like a downstream task.
        sample_validation_set: Carve a stratified validation set out of the
            (subsampled) training rows instead of using the existing one.
        input_seed_for_validation: ``random_state`` for that validation split.
        data_dir: Directory containing the CSV files. Defaults to ``'data'``,
            the previously hard-coded location.

    Returns:
        ``(N, C, y, info, full_cat_data_for_encoder)``:
        - ``N`` and ``C`` map ``'train'``/``'val'``/``'test'`` to numerical and
          categorical feature arrays, or are ``None`` when no such columns exist.
        - ``y`` maps the same keys to float label arrays.
        - ``info`` holds dataset metadata.
        - ``full_cat_data_for_encoder`` is the categorical columns of the
          training pool before subsampling.
        In the downstream stage without ``sample_validation_set``, the
        validation arrays are the training arrays themselves.

    Raises:
        NotImplementedError: if ``task`` is not ``'binclass'``.
        ValueError: if ``stage`` contains neither ``'downstream'`` nor
            ``'pretrain'``.
    """
    import os

    np.random.seed(seed)
    # Validate the arguments before reading any files.
    if task != 'binclass':
        raise NotImplementedError('Mimic only accepts binclass tasks')
    is_downstream = 'downstream' in stage
    if not is_downstream and 'pretrain' not in stage:
        raise ValueError('Stage is incorrect!')

    mimic_target_columns = ['diabetes_diagnosed', 'hypertensive_diagnosed', 'ischematic_diagnosed',
                            'heart_diagnosed', 'overweight_diagnosed', 'anemia_diagnosed', 'respiratory_diagnosed',
                            'hypotension_diagnosed', 'lipoid_diagnosed', 'atrial_diagnosed', 'purpura_diagnosed',
                            'alcohol_diagnosed']

    def _load(filename):
        # Read one split file from the configured data directory.
        return pd.read_csv(os.path.join(data_dir, filename))

    X_train = _load('mimic_train_X.csv')
    X_val = _load('mimic_val_X.csv')
    X_test = _load('mimic_test_X.csv')
    y_train_full = _load('mimic_train_y.csv')
    y_val_full = _load('mimic_val_y.csv')
    y_test_full = _load('mimic_test_y.csv')

    categorical_columns = ['gender']
    numerical_columns = list(X_train.columns[X_train.columns != 'gender'])
    for frame in (X_train, X_val, X_test):
        frame[categorical_columns] = frame[categorical_columns].fillna("MissingValue")

    # The target at index `pretrain_proportion` is the held-out transfer task.
    held_out_target = mimic_target_columns[pretrain_proportion]
    if is_downstream:
        # Pool the train and val rows; the held-out target is the only label.
        X_train = pd.concat([X_train, X_val], ignore_index=True)
        y_train_full = pd.concat([y_train_full, y_val_full], ignore_index=True)
        y_train = y_train_full[held_out_target]
        y_val = y_val_full[held_out_target]
        y_test = y_test_full[held_out_target]
    else:
        # Pretraining uses every label column except the held-out target.
        y_train = y_train_full.drop(columns=[held_out_target])
        y_val = y_val_full.drop(columns=[held_out_target])
        y_test = y_test_full.drop(columns=[held_out_target])
        if pretrain_subsample:
            possible_tuning_targets = np.array(y_train.columns)
            # NOTE(review): randint(10) only reaches indices 0-9. If the label
            # files hold exactly the 12 targets listed above, 11 candidates
            # remain and the last one can never be picked. This is kept as-is
            # so seeded runs stay reproducible; confirm whether it is intended.
            tuning_target = possible_tuning_targets[np.random.randint(10)]
            y_train = y_train_full[tuning_target]
            y_val = y_val_full[tuning_target]
            y_test = y_test_full[tuning_target]

    # Keep the pre-subsampling training pool. The fallback sampler draws from
    # it, and its categorical columns are returned for fitting an encoder.
    X_train_pool = X_train.copy()
    y_train_pool = y_train.copy()
    if is_downstream or pretrain_subsample:
        total_num_of_classes = len(set(y_train))
        X_train, _, y_train, _ = train_test_split(
            X_train, y_train,
            train_size=downstream_samples_per_class * total_num_of_classes,
            stratify=y_train, random_state=seed)
        if len(set(y_train)) < total_num_of_classes:
            # The stratified split lost a class; resample so every class has at least one row.
            X_train, y_train = stratified_sample_at_least_one_per_class(
                X_train_pool, y_train_pool, downstream_samples_per_class, seed)
        assert len(set(y_train)) == total_num_of_classes

    if sample_validation_set:
        X_train, X_val, y_train, y_val = sample_small_validation_set(
            X_train, y_train, downstream_samples_per_class, input_seed_for_validation)

    X_cat_train = X_train[categorical_columns].values
    X_num_train = X_train[numerical_columns].values
    y_train = y_train.values.astype('float')

    X_cat_val = X_val[categorical_columns].values
    X_num_val = X_val[numerical_columns].values
    y_val = y_val.values.astype('float')

    X_cat_test = X_test[categorical_columns].values
    X_num_test = X_test[numerical_columns].values
    y_test = y_test.values.astype('float')

    # `val_size` is recorded before the downstream branch below may swap the
    # validation arrays for the training arrays.
    info = {
        'name': ds_id,
        'stage': stage,
        'split': seed,
        'task_type': task,
        'n_num_features': len(numerical_columns),
        'n_cat_features': len(categorical_columns),
        'train_size': X_train.shape[0],
        'val_size': X_val.shape[0],
        'test_size': X_test.shape[0],
        'replacement_sampling': False,
        # Only 'binclass' reaches this point. A 1-D label vector is one binary
        # target; a 2-D label matrix has one binary target per column.
        'n_classes': y_train.shape[1] if y_train.ndim > 1 else 1,
    }

    if is_downstream and not sample_validation_set:
        # Few-shot downstream runs without a sampled validation set validate on the training rows.
        X_num_val = X_num_train
        X_cat_val = X_cat_train
        y_val = y_train

    N = {'train': X_num_train, 'val': X_num_val, 'test': X_num_test} if numerical_columns else None
    C = {'train': X_cat_train, 'val': X_cat_val, 'test': X_cat_test} if categorical_columns else None
    y = {'train': y_train, 'val': y_val, 'test': y_test}
    print('\n Train size:{} Val size:{} Test size:{}'.format(len(y_train), len(y_val), len(y_test)))
    full_cat_data_for_encoder = X_train_pool[categorical_columns] if categorical_columns else None
    return N, C, y, info, full_cat_data_for_encoder

def stratified_sample_at_least_one_per_class(X_train, y_train, downstream_samples_per_class, seed):
    """Draw a few-shot training sample that is guaranteed to contain every class.

    First one row per class is drawn with ``groupby(...).sample(n=1)``. No
    ``random_state`` is passed, so that draw uses NumPy's global RNG. The rest
    of the ``downstream_samples_per_class * n_classes`` budget is then filled
    with a stratified ``train_test_split`` over the rows not yet chosen.

    Args:
        X_train: Feature frame to sample from. It is not modified.
        y_train: Labels aligned with ``X_train`` by index.
        downstream_samples_per_class: Target number of rows per class.
        seed: ``random_state`` for the stratified split of the remaining rows.

    Returns:
        ``(X_sample, y_sample)``: the stratified extra rows followed by the
        one-per-class rows. When the one-per-class rows already meet the budget,
        only those rows are returned.
    """
    # Work on a copy. Adding the helper 'y' column must not leak into the caller's frame.
    labelled = X_train.copy()
    labelled['y'] = y_train

    one_per_class = labelled.groupby(by='y').sample(n=1)
    y_one_sample = one_per_class['y']
    X_one_sample = one_per_class.drop(columns=['y'])

    rest = labelled[~labelled.index.isin(one_per_class.index)]
    y_rest = rest['y']
    X_rest = rest.drop(columns=['y'])

    n_extra = downstream_samples_per_class * len(set(y_rest)) - len(X_one_sample)
    if n_extra <= 0:
        # The guaranteed rows already fill the budget, and train_test_split rejects train_size <= 0.
        return X_one_sample, y_one_sample

    X_rest, _, y_rest, _ = train_test_split(X_rest, y_rest,
                                            train_size=n_extra,
                                            stratify=y_rest, random_state=seed)
    X_sample = pd.concat([X_rest, X_one_sample], axis=0)
    y_sample = pd.concat([y_rest, y_one_sample], axis=0)
    return X_sample, y_sample

def sample_small_validation_set(X_train, y_train, downstream_samples_per_class, seed):
    """Split a small labelled set into stratified training and validation parts.

    The training part gets ``int(0.8 * downstream_samples_per_class * n_classes)``
    rows. The rest of the input becomes the validation part.

    Args:
        X_train: Features to split.
        y_train: Labels aligned with ``X_train``; also used for stratification.
        downstream_samples_per_class: Per-class budget the 80% share is taken from.
        seed: ``random_state`` forwarded to ``train_test_split``.

    Returns:
        ``(X_train, X_val, y_train, y_val)`` as produced by ``train_test_split``.

    Raises:
        AssertionError: if the training part does not contain every class.
    """
    class_count = len(set(y_train))
    train_rows = int(0.8 * downstream_samples_per_class * class_count)

    parts = train_test_split(X_train, y_train,
                             train_size=train_rows,
                             stratify=y_train,
                             random_state=seed)
    X_fit, X_holdout, y_fit, y_holdout = parts

    # Every class must survive in the training part.
    assert class_count == len(set(y_fit))
    return X_fit, X_holdout, y_fit, y_holdout
