
# %%

import os
import numpy as np
import scipy
import pandas as pd
import time

from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegressionCV, SGDClassifier, \
    RidgeClassifier, RidgeClassifierCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, roc_auc_score, balanced_accuracy_score
from sklearn.preprocessing import StandardScaler

import matplotlib.pyplot as plt
import argparse

import tensorflow as tf
from tensorflow import keras

from xgboost import XGBClassifier

from maml.code.sklearn_pipeline import SklearnPipeline

# Print the working directory at import time so the location that relative
# paths in this script (e.g. the "tmp/..." directory built for the 'maml'
# model) resolve against is visible in the run log.
print(os.getcwd())

# %%
def clr(X):
    """Apply the centred log-ratio (clr) transform to each row of ``X``.

    Every entry is log-transformed and each row is then shifted by its own
    mean log value, so the rows of the result sum to zero.

    Args:
        X: 2-D array-like of strictly positive values (rows are samples).

    Returns:
        ndarray of the same shape as ``X`` holding the clr coordinates.
    """
    log_X = np.log(X)
    row_centre = log_X.mean(axis=1, keepdims=True)
    return log_X - row_centre


# %%

# X_scatter = pca_clr.transform(X_clr)[:, 0:2]

# plt.scatter(X_scatter[:, 0], X_scatter[:, 1], c=y)
# plt.show()

# %%

# y2 = pd.read_csv(
#     in_dir + 'hmp-task-sex.txt-y.csv',
#     index_col=0,
# )

# X_df2 = X_df.loc[y2.index]
# X_clr2 = clr(X_df2.to_numpy())

# X_scatter = pca_clr.transform(X_clr2)

# color = pd.get_dummies(y2['Var'], drop_first=True).to_numpy()
# plt.scatter(X_scatter[:, 0], X_scatter[:, 1], c=color)
# plt.show()

from augmenter import Augmenter
# from sklearn.base import BaseEstimator,TransformerMixin
# class Augmenter(BaseEstimator, TransformerMixin):
#     # See https://stackoverflow.com/questions/25539311/custom-transformer-for-sklearn-pipeline-that-alters-both-x-and-y
#     def __init__(self, nover2, factor=10):
#         self.nover2 = nover2
#         self.factor = factor # data augmentation factor

#     def fit( self, X, y=None):
#         return self 

#     def transform( self, X, y=None):
#         return X

#     def fit_transform(self, X, y):
#         print('\n\n\nMADE IT TO FIRST CKPT')
#         # print(X.shape)
#         # import traceback
#         # traceback.print_stack()

#         if y is None:
#             return X

#         X_bal = X.copy()
#         y_bal = y.copy()
#         for i in range(2, self.factor):
#             X_bal = np.concatenate([X_bal, X.copy()], axis=0)
#             y_bal = np.concatenate([y_bal, y.copy()])

#         # Check that we are operating on the train fold
#         if X.shape[0] > self.nover2:
#             for val in [0, 1]:
#                 print('\n\n\n\nI AM HEREEEEEEEE')
#                 idxs = y == val
#                 X_temp = X[idxs, :]
#                 n = X_temp.shape[0]
#                 n_aug = int(self.factor * n) - n
#                 lam = np.random.rand(n_aug).reshape([-1, 1])
#                 idx1 = np.random.choice(n, size=n_aug)
#                 idx2 = np.random.choice(n, size=n_aug)
#                 # Take convex combination
#                 X_aug = lam * X_temp[idx1, :] + (1 - lam) * X_temp[idx2, :]

#                 X = np.concatenate([X, X_aug], axis=0)
#                 y = np.concatenate([y, np.repeat(val, n_aug)])
#             X = np.concatenate([X, X_bal], axis=0)
#             y = np.concatenate([y, y_bal])
#             return X, y
#         return self.transform(X, y)


def augment_X(X_train, y_train, params):
    """Augment a binary-labelled training set and build per-sample weights.

    Augmentation is done separately within each class (labels 0 and 1) and is
    controlled by ``params``:

    * ``'mult'`` (``== True``): multinomial resampling. Each synthetic row is
      drawn from a multinomial whose total and probabilities come from
      ``row - 1`` of a random member of the class; 1 is then added back.
      This runs on the untransformed ``X_train``, before any ``'space'``.
    * ``'space'``: ``'clr'`` or ``'prop'`` maps the original and all
      augmented rows into that space before the ``'conv'`` / ``'comb'`` steps.
      Any other value, or a missing key, leaves the data as-is.
    * ``'conv'``: ``'rand'`` adds random convex combinations of two members
      of the same class; ``'half'`` is not implemented and raises.
    * ``'comb'``: ``'rand'`` adds rows whose features are taken element-wise
      (random per-row mixing probability) from one of two class members.
    * ``'factor'``: a class with ``n`` members gets ``int(factor * n) - n``
      synthetic rows from each enabled method.
    * ``'weight'``: each class's synthetic rows together carry
      ``weight / (1 - weight)`` times the number of original rows as total
      sample weight. Defaults to ``factor / (1 + factor)``.

    Args:
        X_train: 2-D array of training features, rows are samples. For
            ``'mult'`` and ``'clr'`` these are presumably positive counts
            (with a pseudocount of 1 already added) — TODO confirm.
        y_train: 1-D array of 0/1 labels aligned with ``X_train``.
        params: augmentation settings as above; ``{}`` disables augmentation.

    Returns:
        ``(X, y, w)``: original rows followed by synthetic rows, matching
        labels, and sample weights (1 for original rows). With
        ``space == 'clr'`` the returned ``X`` is mapped back with a row-wise
        softmax (i.e. to proportions). With ``space == 'prop'`` it stays as
        proportions.

    Raises:
        ValueError: if ``params['conv'] == 'half'`` (not implemented).
    """
    X = X_train.copy()
    y = y_train.copy()
    w = np.ones_like(y)

    if params == {}:
        return X, y, w

    if 'weight' in params:
        weight = params['weight']
    else:
        weight = params['factor'] / (1 + params['factor'])

    # A missing 'space' key means "leave the data untransformed".
    space = params.get('space')
    n_orig = X_train.shape[0]

    def n_synthetic(n):
        # Synthetic rows needed to grow a class of n members to factor * n.
        return int(params['factor'] * n) - n

    def n_draws(n, n_aug):
        # Always draw at least 10 * n candidates, as before, so seeded runs
        # reproduce earlier results. Never draw fewer than n_aug: otherwise
        # X would get fewer new rows than y and w (happened for factor > 11).
        return max(10 * n, n_aug)

    def synthetic_weights(n_aug):
        # Spread a total of weight / (1 - weight) * n_orig over n_aug rows.
        return np.repeat(weight / (1 - weight) * n_orig / n_aug, n_aug)

    if params.get('mult') == True:
        for val in [0, 1]:
            X_temp = X_train[y_train == val, :]
            n = X_temp.shape[0]
            n_aug = n_synthetic(n)
            if n_aug <= 0:
                # Nothing to add (factor <= 1 or empty class). Skipping also
                # avoids a division by zero in synthetic_weights.
                continue
            X_aug = []
            for _ in range(n_draws(n, n_aug)):
                idx = np.random.choice(n)
                # Drop 1 from every entry (presumably a pseudocount), resample
                # with the same total, and add the 1 back after the loop.
                counts = X_temp[idx, :] - 1
                X_aug.append(
                    np.random.multinomial(int(counts.sum()), counts / counts.sum())
                )
            X = np.concatenate([X, np.array(X_aug[:n_aug]) + 1], axis=0)
            y = np.concatenate([y, np.repeat(val, n_aug)])
            w = np.concatenate([w, synthetic_weights(n_aug)])

    if space == 'clr':
        X_train = clr(X_train)
        X = clr(X)
    elif space == 'prop':
        X_train = X_train / X_train.sum(axis=1, keepdims=True)
        X = X / X.sum(axis=1, keepdims=True)

    if params.get('conv') == 'half':
        # Pairwise-midpoint augmentation is not implemented. It needs proper
        # seeding and must also extend the sample weights w.
        raise ValueError("Haven't implemented proper seeding")

    if params.get('conv') == 'rand':
        for val in [0, 1]:
            X_temp = X_train[y_train == val, :]
            n = X_temp.shape[0]
            n_aug = n_synthetic(n)
            if n_aug <= 0:
                continue
            n_draw = n_draws(n, n_aug)

            lam = np.random.rand(n_draw).reshape([-1, 1])
            idx1 = np.random.choice(n, size=n_draw)
            idx2 = np.random.choice(n, size=n_draw)

            # Random convex combination of two members of the same class.
            X_aug = lam * X_temp[idx1, :] + (1 - lam) * X_temp[idx2, :]

            X = np.concatenate([X, X_aug[:n_aug]], axis=0)
            y = np.concatenate([y, np.repeat(val, n_aug)])
            w = np.concatenate([w, synthetic_weights(n_aug)])

    if params.get('comb') == 'rand':
        for val in [0, 1]:
            X_temp = X_train[y_train == val, :]
            n = X_temp.shape[0]
            # Same rounding as the other methods (was int(factor) * n - n).
            n_aug = n_synthetic(n)
            if n_aug <= 0:
                continue
            n_draw = n_draws(n, n_aug)

            idx1 = np.random.choice(n, size=n_draw)
            idx2 = np.random.choice(n, size=n_draw)

            # Each synthetic row picks every feature from idx1 with its own
            # random probability p, otherwise from idx2.
            p = np.random.rand(n_draw)
            mask = np.random.binomial(1, p, [X_temp.shape[1], n_draw]).T
            X_aug = mask * X_temp[idx1, :] + (1 - mask) * X_temp[idx2, :]

            # In clr space, re-centre each row on its mean so it is again a
            # valid clr vector (rows sum to zero).
            if space == 'clr':
                X_aug = X_aug - X_aug.mean(axis=1, keepdims=True)

            X = np.concatenate([X, X_aug[:n_aug]], axis=0)
            y = np.concatenate([y, np.repeat(val, n_aug)])
            w = np.concatenate([w, synthetic_weights(n_aug)])

    if space == 'clr':
        # Row-wise softmax of clr coordinates gives the closed composition
        # (proportions).
        X = scipy.special.softmax(X, axis=1)
    # TODO: allow for counts to be returned here

    return X, y, w

def transform_X(X_tr, X_te, param):
    """Map train and test features into the requested space, then optionally standardise.

    ``param['space']`` selects ``'clr'`` (row-wise centred log-ratio) or
    ``'prop'`` (each row divided by its sum). Any other value, or no key,
    leaves the matrices untouched. When ``param['std'] == True`` a
    ``StandardScaler`` is fit on the training matrix and applied to both.

    Args:
        X_tr: 2-D training feature matrix.
        X_te: 2-D test feature matrix.
        param: transformation settings (keys ``'space'`` and ``'std'``).

    Returns:
        ``(X_tr, X_te)`` after the requested transformations.
    """
    space = param.get('space')
    if space == 'clr':
        X_tr, X_te = clr(X_tr), clr(X_te)
    elif space == 'prop':
        X_tr, X_te = [M / M.sum(axis=1, keepdims=True) for M in (X_tr, X_te)]

    if param.get('std') == True:
        # Scaling statistics come from the training data only.
        scaler = StandardScaler()
        X_tr = scaler.fit_transform(X_tr)
        X_te = scaler.transform(X_te)

    return X_tr, X_te


def dim_reduction(X_train_aug, X_test, dr_params):
    """Optionally project train/test features onto leading principal components.

    Args:
        X_train_aug: 2-D training matrix (possibly augmented); PCA is fit on it.
        X_test: 2-D test matrix, projected with the PCA fit on ``X_train_aug``.
        dr_params: ``{}`` for no reduction (copies of the inputs are
            returned), or ``{'PCs': k}`` to keep the first ``k`` components.

    Returns:
        ``(X_train_dr, X_test_dr)``.

    Raises:
        ValueError: if ``dr_params`` is non-empty but has no ``'PCs'`` key.
            Previously this case silently returned ``None``.
    """
    if dr_params == {}:
        return X_train_aug.copy(), X_test.copy()

    if 'PCs' not in dr_params:
        raise ValueError(f"Unsupported dimension-reduction params: {dr_params!r}")

    n_pcs = dr_params['PCs']
    # Fit all components and slice afterwards instead of PCA(n_components=k):
    # sklearn's 'auto' solver choice depends on n_components, so this keeps
    # results identical to earlier runs.
    pca = PCA()
    pca.fit(X_train_aug)
    X_train_dr = pca.transform(X_train_aug)
    X_test_dr = pca.transform(X_test)
    return X_train_dr[:, :n_pcs], X_test_dr[:, :n_pcs]


def evaluate_classifier(X_train, y_train, w, X_test, y_test, params):
    """Fit the classifier named by ``params['model']`` and score it on the test set.

    Supported models: 'svm', 'ridge', 'ridgecv', 'rf', 'xgb', 'mlp', 'nn',
    'metann', 'maml', 'deepcoda'. Several read module-level globals set in
    the ``__main__`` block: ``rf_par``, ``seed``, ``tf_seed``, ``np_seed`` and
    ``data_idx``. 'nn' also reads ``n_train_orig``.

    Args:
        X_train: 2-D training features.
        y_train: 1-D 0/1 training labels.
        w: per-sample training weights aligned with ``X_train``.
        X_test: 2-D test features.
        y_test: 1-D 0/1 test labels.
        params: model settings. ``'model'`` is required; other keys are model
            specific (see the ``_fit_predict_*`` helpers).

    Returns:
        dict of single-element lists: ``'acc_bl'`` (majority-class accuracy
        over the concatenated train + test labels), ``'acc'``, ``'bacc'``,
        ``'auc'``, and ``'runtime'`` (seconds spent fitting and predicting).

    Raises:
        ValueError: if ``params['model']`` is not a supported model name.
    """
    start_time = time.time()

    model_name = params['model']
    if model_name not in _CLASSIFIERS:
        raise ValueError(f"Unknown model: {model_name!r}")
    # Keras models return shape (n, 1); flatten every score vector to (n,).
    y_pred = np.ravel(
        _CLASSIFIERS[model_name](X_train, y_train, w, X_test, y_test, params)
    )

    end_time = time.time()

    acc_bl = np.mean(np.concatenate([y_train, y_test]))
    acc_bl = max(acc_bl, 1 - acc_bl)
    # Scores are in [0, 1]; threshold at 0.5 for the hard-label metrics.
    y_pred_bin = np.round(y_pred)
    acc = accuracy_score(y_test, y_pred_bin)
    bacc = balanced_accuracy_score(y_test, y_pred_bin)
    auc = roc_auc_score(y_test, y_pred)

    return {
        'acc_bl': [acc_bl],
        'acc': [acc],
        'bacc': [bacc],
        'auc': [auc],
        'runtime': [end_time - start_time],
    }


def _sigmoid(z):
    """Logistic function, used to map decision-function scores into (0, 1)."""
    return 1 / (1 + np.exp(-z))


def _fit_predict_decision(model, X_train, y_train, w, X_test):
    """Fit a weighted sklearn linear classifier; return the sigmoid of its test decision scores."""
    model.fit(X_train, y_train, sample_weight=w)
    return _sigmoid(model.decision_function(X_test))


def _fit_predict_svm(X_train, y_train, w, X_test, y_test, params):
    """Linear SVM trained with hinge-loss SGD."""
    return _fit_predict_decision(SGDClassifier(loss='hinge'), X_train, y_train, w, X_test)


def _fit_predict_ridge(X_train, y_train, w, X_test, y_test, params):
    """Ridge classifier with default regularisation."""
    return _fit_predict_decision(RidgeClassifier(), X_train, y_train, w, X_test)


def _fit_predict_ridgecv(X_train, y_train, w, X_test, y_test, params):
    """Ridge classifier with built-in cross-validated regularisation."""
    return _fit_predict_decision(RidgeClassifierCV(), X_train, y_train, w, X_test)


def _fit_predict_rf(X_train, y_train, w, X_test, y_test, params):
    """Random forest configured by the module-level ``rf_par`` dict."""
    # TODO: Pick mtry by CV
    model = RandomForestClassifier(**rf_par)
    model.fit(X_train, y_train, sample_weight=w)
    return model.predict_proba(X_test)[:, 1]


def _fit_predict_xgb(X_train, y_train, w, X_test, y_test, params):
    """XGBoost; with ``params['early'] == True`` the round count is chosen by 5-fold CV."""
    if params.get('early') == True:
        # The module only imports XGBClassifier; the functional API is needed here.
        import xgboost as xgb

        dtrain = xgb.DMatrix(X_train, label=y_train, weight=w)
        cv_results = xgb.cv(
            {'objective': 'binary:logistic'}, dtrain, num_boost_round=100, nfold=5
        )
        num_rounds = 1 + cv_results['test-logloss-mean'].argmin()
    else:
        num_rounds = 100

    model = XGBClassifier(random_state=seed, n_estimators=num_rounds)
    model.fit(X_train, y_train, sample_weight=w)
    return model.predict_proba(X_test)[:, 1]


def _fit_predict_mlp(X_train, y_train, w, X_test, y_test, params):
    """Dense ReLU MLP with optional dropout, batch norm and early stopping.

    ``params`` keys (default): 'layers' (4) hidden layers of 128 units,
    'bn' (True) batch norm after each, 'dp' (0) dropout rate, 'ep' (100)
    epochs, 'bs' (2048) batch size, 'early' (falsy) hold out 20% of the
    training rows and stop on validation PR-AUC.
    """
    tf.random.set_seed(tf_seed)
    metrics = [
        keras.metrics.TruePositives(name='tp'),
        keras.metrics.FalsePositives(name='fp'),
        keras.metrics.TrueNegatives(name='tn'),
        keras.metrics.FalseNegatives(name='fn'),
        keras.metrics.BinaryAccuracy(name='accuracy'),
        keras.metrics.Precision(name='precision'),
        keras.metrics.Recall(name='recall'),
        keras.metrics.AUC(name='auc'),
        keras.metrics.AUC(name='prc', curve='PR'),  # precision-recall curve
    ]

    n_layers = params.get('layers', 4)
    use_bn = params.get('bn', True)
    dropout = params.get('dp', 0)
    epochs = params.get('ep', 100)
    batch_size = params.get('bs', 2048)

    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_prc',
        verbose=1,
        patience=10,
        mode='max',
        restore_best_weights=True)

    stack = []
    for _ in range(n_layers):
        stack.append(keras.layers.Dense(128, activation='relu'))
        if dropout > 0:
            # Use the configured rate (previously hard-coded to 0.5).
            stack.append(keras.layers.Dropout(dropout))
        if use_bn:
            stack.append(keras.layers.BatchNormalization())
    # TODO: set a good initial output bias, see
    # https://www.tensorflow.org/tutorials/structured_data/imbalanced_data
    # NOTE(review): the explicit bias_initializer=None is kept from earlier
    # runs; in Keras this may not mean zeros — confirm before changing.
    stack.append(keras.layers.Dense(1, activation='sigmoid', bias_initializer=None))
    model = keras.Sequential(stack)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-3),
        loss=keras.losses.BinaryCrossentropy(),
        metrics=metrics)

    if params.get('early'):
        # Split the sample weights together with X and y so sample_weight
        # matches the rows actually passed to fit().
        X1, X2, y1, y2, w1, _ = train_test_split(
            X_train, y_train, w, test_size=0.2, stratify=y_train
        )
        model.fit(
            X1,
            y1,
            batch_size=batch_size,
            epochs=epochs,
            callbacks=[early_stopping],
            validation_data=(X2, y2),
            sample_weight=w1,
        )
    else:
        model.fit(
            X_train,
            y_train,
            batch_size=batch_size,
            epochs=epochs,
            sample_weight=w,
        )

    return model.predict(X_test).flatten()


def _fit_predict_nn(X_train, y_train, w, X_test, y_test, params):
    """VAE-pretrained encoder plus a logistic head on [latent means, raw features].

    1. Train a VAE without labels on ``X_train[:n_train_orig]``.
       ``n_train_orig`` is a module-level global, presumably the number of
       non-augmented rows — TODO confirm.
    2. Fit an L2 ``LogisticRegressionCV`` on the concatenated encoder means
       and raw features. Use it to initialise a sigmoid dense head and to set
       that head's L2 penalty.
    3. Fine-tune encoder and head end-to-end at a tenth of the learning rate.
    """
    tf.random.set_seed(tf_seed)

    epochs = 100
    batch_size = 128
    lr = 0.001
    dp = 0.0  # encoder dropout rate

    class VAE(keras.Model):
        def __init__(self, input_dim, latent_dim=32):
            super().__init__()
            self.latent_dim = latent_dim
            self.input_dim = input_dim
            self.encoder = keras.Sequential([
                keras.layers.Dense(128, activation='relu'),
                keras.layers.Dropout(dp),
                keras.layers.BatchNormalization(),
                keras.layers.Dense(128, activation='relu'),
                keras.layers.Dropout(dp),
                keras.layers.BatchNormalization(),
                keras.layers.Dense(64, activation='relu'),
                keras.layers.Dropout(dp),
                keras.layers.BatchNormalization(),
                keras.layers.Dense(64, activation='relu'),
                keras.layers.Dropout(dp),
                keras.layers.BatchNormalization(),
                # Outputs [mean, logvar], split in encode().
                keras.layers.Dense(latent_dim * 2),
            ])
            self.decoder = keras.Sequential([
                keras.layers.Dense(32, activation='relu'),
                keras.layers.BatchNormalization(),
                keras.layers.Dense(64, activation='relu'),
                keras.layers.BatchNormalization(),
                keras.layers.Dense(64, activation='relu'),
                keras.layers.BatchNormalization(),
                keras.layers.Dense(128, activation='relu'),
                keras.layers.BatchNormalization(),
                keras.layers.Dense(128, activation='relu'),
                keras.layers.BatchNormalization(),
                keras.layers.Dense(self.input_dim),
            ])

        def encode(self, x):
            mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
            return mean, logvar

        def reparameterize(self, mean, logvar):
            eps = tf.random.normal(shape=mean.shape)
            return eps * tf.exp(logvar * .5) + mean

        def decode(self, z):
            return self.decoder(z)

    def vae_loss(model, x):
        # Summed squared reconstruction error plus Gaussian KL divergence.
        mean, logvar = model.encode(x)
        z = model.reparameterize(mean, logvar)
        x_recon = model.decode(z)
        recon_loss = tf.reduce_sum(tf.square(x_recon - x))
        kl_loss = -0.5 * tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(logvar))
        return recon_loss + kl_loss

    def vae_step(model, x, optimizer):
        with tf.GradientTape() as tape:
            loss = vae_loss(model, x)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss

    def train_vae(model, train_dataset, epochs, optimizer):
        for epoch in range(epochs):
            for x in train_dataset:
                loss = vae_step(model, x, optimizer)
            if epoch % 10 == 0:
                print("Epoch: {}, Loss: {}".format(epoch, loss))

    vae = VAE(input_dim=X_train.shape[1])
    for layer in vae.encoder.layers:
        if 'dropout' in layer.name:
            layer.trainable = False
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    train_dataset = (
        tf.data.Dataset.from_tensor_slices(X_train[0:n_train_orig, :].astype("float32"))
        .shuffle(X_train.shape[0])
        .batch(batch_size)
    )
    train_vae(vae, train_dataset, epochs, optimizer)

    # Linear probe on [latent means, raw features] (residual connection),
    # used for initialisation and to pick the weight decay.
    lreg = LogisticRegressionCV(penalty='l2', Cs=50)
    z_train, _ = vae.encode(X_train)
    lreg.fit(np.concatenate([z_train.numpy(), X_train], axis=1), y_train, sample_weight=w)
    # sklearn minimises C * sum(loss) + 0.5 * ||coef||^2. Dividing by C * N
    # gives mean(loss) + ||coef||^2 / (2 * C * N), hence this Keras L2 factor.
    l2 = 1 / (2 * lreg.C_[0] * X_train.shape[0])

    inputs = keras.Input(shape=(X_train.shape[1],))
    means, _ = vae.encode(inputs)
    head = keras.layers.Dense(
        1, kernel_regularizer=keras.regularizers.l2(l2), activation='sigmoid'
    )
    outputs = head(tf.concat([means, inputs], axis=1))
    model = keras.Model(inputs, outputs)
    head.set_weights([lreg.coef_.reshape([-1, 1]), lreg.intercept_])

    # Fine-tune encoder and head together.
    vae.trainable = True
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr / 10)
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.binary_crossentropy,
        metrics=[keras.metrics.BinaryAccuracy()],
    )
    model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size, sample_weight=w)
    return model.predict(X_test)


def _fit_predict_metann(X_train, y_train, w, X_test, y_test, params):
    """Fixed 512-256-128 ReLU MLP with 0.5 dropout, 200 epochs, batch size 32."""
    tf.random.set_seed(tf_seed)

    model = keras.Sequential([
        keras.layers.Dense(512),
        keras.layers.Dropout(0.5),
        keras.layers.ReLU(),
        keras.layers.Dense(256),
        keras.layers.ReLU(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(128),
        keras.layers.ReLU(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(1, activation='sigmoid'),
    ])
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.01),
        loss=keras.losses.BinaryCrossentropy(),
        metrics=[keras.metrics.BinaryAccuracy(name='accuracy')])

    model.fit(
        X_train,
        y_train,
        batch_size=32,
        epochs=200,
        sample_weight=w,
    )
    return model.predict(X_test).flatten()


def _fit_predict_maml(X_train, y_train, w, X_test, y_test, params):
    """Scaler/classifier selection and tuning via the project's ``SklearnPipeline``.

    With ``params['aug'] == 'aitch'`` an ``Augmenter`` is passed to the
    selection step. NOTE: the sample weights ``w`` are not used by this model.

    Raises:
        ValueError: if the selected estimator has neither ``decision_function``
            nor ``predict_proba``.
    """
    start_time = time.time()
    print(y_test)
    X_train_df = pd.DataFrame(X_train)
    X_test_df = pd.DataFrame(X_test)
    y_train_df = pd.DataFrame(y_train.astype('str'))
    tmp_dir = "tmp/d" + str(data_idx) + "s" + str(np_seed)
    skp = SklearnPipeline(tmp_dir, X_train_df, y_train_df)
    skp.filter_low_prevalence_features()
    # TODO: fix this
    skp.over_sampling()
    from maml.code.sklearn_pipeline_config import SCALERS, Tree_based_CLASSIFIERS, Other_CLASSIFIERS

    if params.get('aug') == 'aitch':
        augmenter = Augmenter(nover2=X_train.shape[0]/2)
    else:
        augmenter = None

    All_CLASSIFIERS = Tree_based_CLASSIFIERS + Other_CLASSIFIERS
    skp.select_best_scl_clf(SCALERS, Tree_based_CLASSIFIERS, Other_CLASSIFIERS, augmenter, n_jobs=8)
    skp.hypertune_best_classifier(All_CLASSIFIERS, n_jobs=8)

    best = skp.best_estimator_
    if hasattr(best, "decision_function"):
        # Signed distance to the separating hyperplane, squashed into (0, 1).
        y_pred = _sigmoid(best.decision_function(X_test_df))
    elif hasattr(best, "predict_proba"):
        y_pred = best.predict_proba(X_test_df)[:, 1]
    else:
        raise ValueError(
            f"Selected estimator {best!r} has neither decision_function nor predict_proba"
        )
    print(time.time() - start_time)
    print(y_test)
    print(y_pred)
    return y_pred


def _fit_predict_deepcoda(X_train, y_train, w, X_test, y_test, params):
    """DeepCoDA-style model: cascaded log-linear bottlenecks with self-explaining weights.

    Each of ``cascade_level`` bottlenecks is a linear map of ``log(x)`` whose
    weights are pushed to sum to zero and to be sparse. A small network
    computes per-sample coefficients ``beta`` that are dotted with the
    bottleneck outputs before a sigmoid output layer.
    """
    tf.random.set_seed(tf_seed)
    epochs = 200
    cascade_level = 5
    bottle_dim = 1
    hidden_dim = 16
    output_dim = 1
    batch_size = 32

    from keras import backend as K

    # Regularise the sum of weights at each cascade towards 0 and the weights
    # towards sparsity.
    class SumZeroL1Reg(keras.regularizers.Regularizer):
        def __init__(self, sumzero_lambda=1e0, l1_lambda=1e-2):
            self.sumzero_lambda = K.cast_to_floatx(sumzero_lambda)
            self.l1_lambda = K.cast_to_floatx(l1_lambda)

        def __call__(self, w):
            sumzero_reg = self.sumzero_lambda * K.square(K.sum(w))
            l1_reg = self.l1_lambda * K.sum(K.abs(w))
            return sumzero_reg + l1_reg

        def get_config(self):
            return {'sumzero_lambda': float(self.sumzero_lambda),
                    'l1_lambda': float(self.l1_lambda)}

    x = keras.Input(shape=(X_train.shape[1],))
    concat_z = []
    for _ in range(cascade_level):
        x_log = keras.layers.Lambda(lambda t: K.log(t))(x)
        z = keras.layers.Dense(bottle_dim, activation='linear',
                               kernel_regularizer=SumZeroL1Reg())(x_log)
        concat_z.append(z)
    if cascade_level == 1:
        all_z = z
    else:
        all_z = keras.layers.Concatenate()(concat_z)
    h = keras.layers.Dense(hidden_dim, activation='relu')(all_z)
    beta = keras.layers.Dense(cascade_level, activation='linear')(h)
    all_z_beta = keras.layers.Dot(axes=1)([all_z, beta])
    decoder = keras.Sequential([
        keras.layers.Dense(output_dim, input_dim=output_dim, activation='sigmoid')
    ])
    model = keras.Model(inputs=x, outputs=decoder(all_z_beta), name='bottleneck_model')
    model.compile(optimizer=keras.optimizers.Adam(), loss='binary_crossentropy', metrics=['accuracy'])
    model.fit(X_train, y_train, shuffle=True, sample_weight=w,
              epochs=epochs, batch_size=batch_size, verbose=0)
    return model.predict(X_test)


# Dispatch table used by evaluate_classifier. Every helper takes
# (X_train, y_train, w, X_test, y_test, params) and returns test-set scores
# in [0, 1].
_CLASSIFIERS = {
    'svm': _fit_predict_svm,
    'ridge': _fit_predict_ridge,
    'ridgecv': _fit_predict_ridgecv,
    'rf': _fit_predict_rf,
    'xgb': _fit_predict_xgb,
    'mlp': _fit_predict_mlp,
    'nn': _fit_predict_nn,
    'metann': _fit_predict_metann,
    'maml': _fit_predict_maml,
    'deepcoda': _fit_predict_deepcoda,
}

# %%
if __name__ == '__main__':
    # Experiment driver: pick one dataset (by --data_idx) from the mlrepo12
    # list, make a stratified train/test split, then evaluate every
    # combination of augmentation x transform x dimension reduction x
    # classifier head listed in `run_params`, and save one CSV of metrics.
    #
    # NOTE: this block is not inside a function, so every name assigned here
    # is a module-level global. Helpers defined earlier in the file may read
    # some of them, so names that look unused here (e.g. `seed`, `tf_seed`,
    # `rf_par`, `n_train_orig`) are deliberately kept.

    seed = 0
    # np_seed = 0
    tf_seed = 0
    train_size = 0.8
    rf_par = {'n_estimators': 500, 'n_jobs': 8}

    # Output folder for each supported --method. Restricting the CLI choices
    # to these keys makes an unknown method fail at argument parsing, instead
    # of raising a NameError on `folder_name` after all models have trained.
    out_dirs = {
        'fast': './out/mlrepo12/',
        'maml': './out/mlrepo12maml/',
    }

    # Set up params
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_idx', dest='data_idx', type=int, default=0)
    parser.add_argument('--seed', dest='seed', type=int, default=0)
    parser.add_argument('--method', dest='method', type=str, default='fast',
                        choices=sorted(out_dirs))

    args = parser.parse_args()

    data_idx = args.data_idx
    np_seed = args.seed
    method = args.method
    # Seeds the global NumPy RNG; train_test_split below is called without
    # random_state, so the train/test split depends on --seed.
    np.random.seed(np_seed)

    # all_models = [
    #     {'model': 'ridge'},
    #     # {'model': 'ridgecv'},
    #     {'model': 'svm'},
    #     {'model': 'rf'},
    # ]

    # all_drs = [
    #     {},
    #     {'PCs': 2},
    #     {'PCs': 10},
    #     {'PCs': 20},
    #     {'PCs': 100},
    # ]

    # NOTE: this grid is overwritten by the next `run_params` assignment and
    # never runs; it is kept only as a record of earlier experiment settings.
    run_params = [
        {
            'aug_params': [
                {},
                # {'conv': 'half', 'space': 'raw'},
                # {'conv': 'half', 'space': 'clr'},
            ],
            'tran_params': [
                {'space': "cnt"},
                {'space': "prop"},
                {'space': "clr"},
            ],
            'dr_params': [
                {},
                # {'PCs': 2},
                # {'PCs': 10},
                # {'PCs': 20},
                # {'PCs': 100},
            ],
            'head_params': [
                {'model': 'ridge'},
                # {'model': 'glmnet'},
                {'model': 'svm'},
                # {'model': 'rf'},
            ],
        },

        {
            'aug_params': [
                {},
                # {'mult': True, 'space': 'raw', 'factor': 10},
                # {'mult': True, 'space': 'prop', 'factor': 10},
                {'mult': True, 'space': '', 'factor': 2},
                {'mult': True, 'space': '', 'factor': 4},
                {'mult': True, 'space': '', 'factor': 6},
                {'mult': True, 'space': '', 'factor': 8},
                {'mult': True, 'space': '', 'factor': 10},
                # {'mult': True, 'factor': 20},
                # {'mult': True, 'space': 'clr', 'factor': 20},
                # {'conv': 'half', 'space': 'raw'},
                # {'conv': 'half', 'space': 'clr'},
                # {'conv': 'rand', 'space': 'raw', 'factor': 2},
                # {'conv': 'rand', 'space': 'raw', 'factor': 5},
                # {'conv': 'rand', 'space': 'raw', 'factor': 10},
                # {'conv': 'rand', 'space': 'raw', 'factor': 20},
                {'conv': 'rand', 'space': 'clr', 'factor': 2},
                {'conv': 'rand', 'space': 'clr', 'factor': 4},
                {'conv': 'rand', 'space': 'clr', 'factor': 6},
                {'conv': 'rand', 'space': 'clr', 'factor': 8},
                {'conv': 'rand', 'space': 'clr', 'factor': 10},
                {'conv': 'rand', 'space': 'prop', 'factor': 2},
                {'conv': 'rand', 'space': 'prop', 'factor': 4},
                {'conv': 'rand', 'space': 'prop', 'factor': 6},
                {'conv': 'rand', 'space': 'prop', 'factor': 8},
                {'conv': 'rand', 'space': 'prop', 'factor': 10},
                # {'conv': 'rand', 'space': 'clr', 'factor': 5},
                # {'conv': 'rand', 'space': 'clr', 'factor': 10},
                # {'conv': 'rand', 'space': 'clr', 'factor': 20},
                # {'conv': 'rand', 'space': 'prop', 'factor': 2},
                # {'conv': 'rand', 'space': 'prop', 'factor': 5},
                # {'conv': 'rand', 'space': 'prop', 'factor': 10},
                # {'conv': 'rand', 'space': 'prop', 'factor': 20},
                # {'comb': 'rand', 'space': 'raw', 'factor': 2},
                # {'comb': 'rand', 'space': 'raw', 'factor': 5},
                # {'comb': 'rand', 'space': 'raw', 'factor': 10},
                # {'comb': 'rand', 'space': 'raw', 'factor': 20},
                # {'comb': 'rand', 'space': 'raw', 'factor': 50},

                {'comb': 'rand', 'space': 'prop', 'factor': 2},
                {'comb': 'rand', 'space': 'prop', 'factor': 4},
                {'comb': 'rand', 'space': 'prop', 'factor': 6},
                {'comb': 'rand', 'space': 'prop', 'factor': 8},
                {'comb': 'rand', 'space': 'prop', 'factor': 10},

                {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 1.33},
                {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 2},
                {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 2.66},
                {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 3.33},
                {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 4},

                {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 1.33},
                {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 2},
                {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 2.66},
                {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 3.33},
                {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 4},

                # {'comb': 'rand', 'space': 'clr', 'factor': 2},
                # {'comb': 'rand', 'space': 'clr', 'factor': 5},
                # {'comb': 'rand', 'space': 'clr', 'factor': 10},
                # {'comb': 'rand', 'space': 'clr', 'factor': 20},

                # {'comb': 'rand', 'space': 'clr', 'factor': 20},
                # {'conv': 'half', 'comb': 'rand', 'space': 'clr', 'factor': 2},
                # {'conv': 'half', 'comb': 'rand', 'space': 'clr', 'factor': 5},
                # {'conv': 'half', 'comb': 'rand', 'space': 'clr', 'factor': 10},
                # {'conv': 'half', 'comb': 'rand', 'space': 'clr', 'factor': 20},

                # {'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 2},
                # {'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 5},
                # {'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 10},
                # {'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 20},

                # {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 2},
                # {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 3},
                # {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 4},
                # {'conv': 'rand', 'space': 'clr', 'factor': 1000},
            ],
            'tran_params': [
                # {'space': "cnt"},
                {'space': "prop"},
                {'space': "clr"},
            ],
            'dr_params': [
                {},
            ],
            'head_params': [
                {'model': 'ridge'},
                {'model': 'svm'},
                {'model': 'rf'},
                # {'model': 'xgb'},
                {'model': 'mlp'},
                # {'model': 'mlp', 'layers': 3},
                # {'model': 'mlp', 'dp': 0.5},
                # {'model': 'mlp', 'bs': 128},
                # {'model': 'mlp', 'early': True},
                # {'model': 'mlp', 'ep': 200},
                # {'model': 'mlp', 'ep': 200, 'early': True},
                {'model': 'metann'},
            ],
        },
    ]

    # Active experiment grid for --method fast (replaced below for maml).
    run_params = [
        {
            'aug_params': [
                {},

                # {'mult': True, 'space': '', 'weight': .2, 'factor': 10},
                # {'mult': True, 'space': '', 'weight': .4, 'factor': 10},
                # {'mult': True, 'space': '', 'weight': .6, 'factor': 10},
                # {'mult': True, 'space': '', 'weight': .8, 'factor': 10},

                {'conv': 'rand', 'space': 'clr', 'factor': 2},
                {'conv': 'rand', 'space': 'clr', 'factor': 4},
                {'conv': 'rand', 'space': 'clr', 'factor': 6},
                {'conv': 'rand', 'space': 'clr', 'factor': 8},
                {'conv': 'rand', 'space': 'clr', 'factor': 10},

                {'conv': 'rand', 'space': 'clr', 'weight': 0.2, 'factor': 10},
                {'conv': 'rand', 'space': 'clr', 'weight': 0.4, 'factor': 10},
                {'conv': 'rand', 'space': 'clr', 'weight': 0.6, 'factor': 10},
                {'conv': 'rand', 'space': 'clr', 'weight': 0.8, 'factor': 10},

                # {'comb': 'rand', 'space': 'clr', 'weight': 0.2, 'factor': 10},
                # {'comb': 'rand', 'space': 'clr', 'weight': 0.4, 'factor': 10},
                # {'comb': 'rand', 'space': 'clr', 'weight': 0.6, 'factor': 10},
                # {'comb': 'rand', 'space': 'clr', 'weight': 0.8, 'factor': 10},

                {'conv': 'rand', 'space': 'clr', 'weight': 0.5, 'factor': 10},
                {'conv': 'rand', 'space': 'prop', 'weight': 0.5, 'factor': 10},

                {'comb': 'rand', 'space': 'clr', 'weight': 0.5, 'factor': 10},
                {'mult': True, 'space': '', 'weight': 0.5, 'factor': 10},

                # {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'weight': 0.2, 'factor': 10},

                # {'comb': 'rand', 'space': 'prop', 'factor': 2},
                # {'comb': 'rand', 'space': 'prop', 'factor': 4},
                # {'comb': 'rand', 'space': 'prop', 'factor': 6},
                # {'comb': 'rand', 'space': 'prop', 'factor': 8},
                # {'comb': 'rand', 'space': 'prop', 'factor': 10},

                # {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 1.33},
                # {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 2},
                # {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 2.66},
                # {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 3.33},
                # {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 4},

                # {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 1.33},
                # {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 2},
                # {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 2.66},
                # {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 3.33},
                # {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 4},
            ],
            'tran_params': [
                # {'space': "cnt"},
                {'space': 'prop'},
                # {'space': 'prop', 'std': True},
                # {'space': 'clr'},
                # {'space': 'clr', 'std': True},
            ],
            'dr_params': [
                {},
            ],
            'head_params': [
                {'model': 'deepcoda'},
                # {'model': 'maml'},
                # {'model': 'nn'},
                {'model': 'ridge'},
                {'model': 'svm'},
                {'model': 'rf'},
                {'model': 'xgb'},
                {'model': 'mlp'},
                # {'model': 'mlp', 'layers': 3},
                # {'model': 'mlp', 'dp': 0.5},
                # {'model': 'mlp', 'bs': 128},
                # {'model': 'mlp', 'early': True},
                # {'model': 'mlp', 'ep': 200},
                # {'model': 'mlp', 'ep': 200, 'early': True},
                {'model': 'metann'},
            ],
        },
    ]

    # run_params = [
    #     {
    #         'aug_params': [
    #             {},

    #         ],
    #         'tran_params': [
    #             # {'space': "cnt"},
    #             {'space': 'prop'},
    #             # {'space': 'prop', 'std': True},
    #             # {'space': 'clr'},
    #             # {'space': 'clr', 'std': True},
    #         ],
    #         'dr_params': [
    #             {},
    #         ],
    #         'head_params': [
    #             {'model': 'deepcoda'},
    #             {'model': 'rf'},
    #         ],
    #     },
    # ]

    # run_params = [
    #     {
    #         'aug_params': [
    #             {},
    #             {'conv': 'rand', 'comb': 'rand', 'space': 'raw', 'factor': 5},
    #         ],
    #         'dr_params': [
    #             {},
    #         ],
    #         'head_params': [
    #             {'model': 'rf'},
    #         ],
    #     },
    # ]

    # The maml method replaces the whole grid with the two MAML heads only.
    if method == 'maml':
        run_params = [
            {
                'aug_params': [
                    {},
                ],
                'dr_params': [
                    {},
                ],
                'tran_params': [
                    {'space': 'prop'},
                ],
                'head_params': [
                    {'model': 'maml'},
                    {'model': 'maml', 'aug': 'aitch'},
                ],
            },
        ]

    # Load data: row `data_idx` of the dataset list names the '-x.csv' /
    # '-y.csv' file pair to read.
    data_list = pd.read_csv('./code/mlrepo12.csv', header=None)
    data_name = data_list.iloc[data_idx, 0]
    data_dir = './in/quinn2020/'
    X_df = pd.read_csv(
        data_dir + data_name + '-x.csv',
        index_col=0,
    )
    y_df = pd.read_csv(
        data_dir + data_name + '-y.csv',
        index_col=0,
    )

    # Remove redundant (constant) variables
    X_df = X_df.loc[:, X_df.std(axis=0) > 0]

    # Convert to numpy
    X = X_df.to_numpy()
    # If y has more than one column we just keep the response variable
    if y_df.shape[1] > 1:
        y = y_df['Var']
    else:
        y = y_df.iloc[:, 0]
    # One-hot encode and drop the first level; for a two-level response this
    # leaves a single 0/1 indicator column, flattened to 1-D.
    y = pd.get_dummies(y, drop_first=True).to_numpy().flatten()

    out = pd.DataFrame()

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=1-train_size, stratify=y
    )

    n_train_orig = X_train.shape[0]

    for params in run_params:
        print(params)

        for aug_params in params['aug_params']:
            # Re-seed before every augmentation so each aug_params setting is
            # generated from seed 0, independent of --seed.
            aug_seed = 0
            np.random.seed(aug_seed)
            X_train_aug, y_train_aug, w = augment_X(X_train, y_train, aug_params)
            # print("\n\n\n\n\n\n\n\nXXXXXXXXXXXXXXXXX\n\n\n\---------------------")
            # print(X_train.shape)
            # print(X_train_aug.shape)
            # print(X_train_aug[X_train.shape[0]:(X_train.shape[0] + 10), 0:10])

            # Inverse-frequency class weights: each sample gets
            # (size of the largest class) / (size of its own class), combined
            # with the per-sample weights returned by augment_X.
            # return_inverse gives each label's position in `classes`, so this
            # does not rely on the labels being exactly 0..k-1.
            classes, class_pos, tally = np.unique(
                y_train_aug, return_inverse=True, return_counts=True)
            weights = (np.max(tally) / tally)[class_pos]
            w = w * weights

            for tran_param in params['tran_params']:
                X_train_tr, X_test_tr = transform_X(X_train_aug, X_test, tran_param)

                for dr_params in params['dr_params']:
                    X_train_dr, X_test_dr = dim_reduction(X_train_tr, X_test_tr, dr_params)

                    # scaler = StandardScaler()
                    # X_train_dr = scaler.fit_transform(X_train_dr)
                    # X_test_dr = scaler.transform(X_test_dr)

                    for head_params in params['head_params']:
                        print(aug_params)
                        print(head_params)
                        print(X_test_dr.shape)
                        print(np.max(np.abs(X_test_dr)))
                        print('\n')
                        print(y_test[0:10]) # Check we are always seeing the same test set
                        print('\n\n\n')

                        res = evaluate_classifier(X_train_dr, y_train_aug, w, X_test_dr, y_test, head_params)

                        # One result row: run configuration followed by the
                        # metrics returned by evaluate_classifier.
                        res = {
                            'seed': [np_seed],
                            'data_idx': [data_idx],
                            # 'quickstr': [quickstr],
                            # 'params': [params],
                            'aug_params': [aug_params],
                            'tran_params': [tran_param],
                            'dr_params': [dr_params],
                            'head_params': [head_params],
                            'head': [head_params['model']],
                            'n': [X_train_dr.shape[0]],
                            'p': [X_train_dr.shape[1]],
                            **res,
                        }

                        out = pd.concat([out, pd.DataFrame(res)])
                    print(out)

    print(out)

    # --method is restricted to the keys of out_dirs, so this lookup always
    # succeeds; create the folder on first use so to_csv does not fail.
    folder_name = out_dirs[method]
    os.makedirs(folder_name, exist_ok=True)
    file_name = folder_name + 'd' + str(data_idx) + 's' + str(np_seed)
    out.to_csv(file_name)
    # print(out.sort_values(by='head'))

