"""BIO experiment code"""
import argparse
import numpy as np 
import matplotlib.pyplot as plt
import torch

import mmd_estimators as mmde
import bounds as bd
import pandas as pd
import pickle


def epsilon_contaminate(x, y, epsilon):
    '''Randomly move rows of x into y to simulate epsilon contamination.

    The number of moved rows is int(len(y) * epsilon). Rows are drawn
    uniformly without replacement from x (using NumPy's global RNG),
    appended to the end of y and removed from x. The inputs themselves are
    not modified; copies are returned.

    Returns:
        (x_corr, y_corr, m): x without the moved rows, y with the moved rows
        stacked underneath, and the number of rows moved.
    '''
    source = np.copy(x)
    target = np.copy(y)

    # The contamination budget is expressed relative to the size of y.
    m = int(target.shape[0] * epsilon)
    moved_rows = np.random.choice(source.shape[0], m, replace=False)
    print(f'Corrupted samples length {m}')

    contaminated_y = np.vstack([target, source[moved_rows, :]])
    remaining_x = np.delete(source, moved_rows, axis=0)

    return remaining_x, contaminated_y, m
    
def nonrandom_epsilon_contaminate(x, y, epsilon):
    '''Move rows of x into y, preferring rows ranked highest by the witness values.

    m = int(len(y) * epsilon) rows are moved. A candidate pool of 2*m rows
    of x is taken from the positions that np.argpartition places first for
    -wd (presumably the rows with the largest witness values, assuming
    mmde.witness_simple returns one value per row of x -- TODO confirm), and
    m of those candidates are drawn uniformly without replacement.

    If that selection fails (for example when 2*m is not a valid partition
    index for wd), the function falls back to drawing m rows of x uniformly
    at random and reports the fallback on stdout.

    Returns:
        (x_corr, y_corr, m): x without the moved rows, y with the moved rows
        stacked underneath, and the number of rows moved.
    '''
    x_corr = np.copy(x)
    y_corr = np.copy(y)

    wd = mmde.witness_simple(x, y, x)

    m = int(y_corr.shape[0] * epsilon)

    try:
        worst_case_corr = np.argpartition(-wd, 2 * m)[:2 * m]
        picks = np.random.choice(range(2 * m), m, replace=False)
        ids = worst_case_corr[picks]
        # Build into temporaries so a failure part-way through cannot leave
        # y_corr already extended when the fallback below runs.
        y_new = np.vstack([y_corr, x[ids, :]])
        x_new = np.delete(x_corr, ids, axis=0)
    except Exception as err:
        # Catch Exception rather than everything so Ctrl-C / SystemExit still
        # propagate, and make the switch to random contamination visible.
        print(f'Nonrandom contamination failed ({err!r}); '
              'falling back to random contamination')
        ids = np.random.choice(x_corr.shape[0], m, replace=False)
        y_new = np.vstack([y_corr, x_corr[ids, :]])
        x_new = np.delete(x_corr, ids, axis=0)
    print(f'Corrupted samples length {m}')

    return x_new, y_new, m

def bootstrap_sample(df, bootstrap_num):
    '''Create a bootstrap resample of the AML/ALL data.

    Args:
        df: DataFrame with a "cancer" column holding the strings "ALL" and
            "AML". It is not modified.
        bootstrap_num: pair (n_all, n_aml) giving how many rows to draw,
            with replacement, from the "ALL" rows and the "AML" rows.

    Returns:
        A new DataFrame with the n_all resampled "ALL" rows followed by the
        n_aml resampled "AML" rows, where the "cancer" column is recoded to
        integers: "AML" -> 0 and "ALL" -> 1.
    '''
    all_rows = df[df.cancer == "ALL"].sample(n=bootstrap_num[0], replace=True)  # bootstrap
    aml_rows = df[df.cancer == "AML"].sample(n=bootstrap_num[1], replace=True)  # bootstrap

    resampled = pd.concat([all_rows, aml_rows])

    # Change string identifiers to 0/1. Assign the column explicitly: an
    # in-place replace on `resampled.cancer` is a chained assignment that does
    # not update the frame when pandas Copy-on-Write is active.
    resampled["cancer"] = resampled["cancer"].map({'AML': 0, 'ALL': 1})

    return resampled


def main(train_dir, test_dir, label_dir, eps_list, iters, n_steps, n_rep, biased_mmd=False):
    '''Run the BIO contamination experiment and print summary statistics.

    For each of `iters` bootstrap resamples of the combined train/test data,
    and for each epsilon in `eps_list`, rows of x are moved into y with
    nonrandom_epsilon_contaminate and six kinds of MMD bounds are computed on
    the contaminated samples (QNO, SM, bootstrap, S-SD, SD, S-QNO). For every
    method the interval width and whether the interval contains the MMD of
    the uncontaminated resample are recorded, and per-epsilon means and
    standard errors are printed at the end.

    Args:
        train_dir: path passed to pd.read_csv for the training data.
        test_dir: path passed to pd.read_csv for the test data.
        label_dir: path passed to pd.read_csv for the labels; the result is
            merged on a "patient" column and must provide a "cancer" column.
        eps_list: sequence of contamination levels.
        iters: number of bootstrap trials.
        n_steps: step count forwarded to bd.stepwise_seq_bounds and
            bd.opt_bounds_sequential.
        n_rep: replicate count forwarded to bd.bootstrap_bounds.
        biased_mmd: use mmde.mmd_b instead of mmde.mmd_u, and forward the
            flag to the SM, bootstrap, S-SD and SD bounds.

    Returns:
        None. Results are printed to stdout. NumPy's global RNG is seeded
        with 0 at the start.
    '''
    # --- Set seed
    np.random.seed(0)

    # --- read in training, test, and label files
    train_data = pd.read_csv(train_dir, engine='python')
    test_data = pd.read_csv(test_dir, engine='python')
    labels = pd.read_csv(label_dir)

    # --- preprocess
    # Drop every column whose name contains "call", then transpose so that
    # each remaining original column becomes a row.
    cols_train = [col for col in train_data.columns if "call" in col]
    cols_test = [col for col in test_data.columns if "call" in col]
    train_data.drop(cols_train, axis=1, inplace=True)
    test_data.drop(cols_test, axis=1, inplace=True)
    train_data = train_data.T
    test_data = test_data.T

    # Each frame takes its column labels from its own second row (presumably
    # the "Gene Accession Number" row -- TODO confirm). The training frame
    # previously read this row from test_data.
    train_data.columns = train_data.iloc[1].values
    train_data.drop(["Gene Description", "Gene Accession Number"], axis=0, inplace=True)
    test_data.columns = test_data.iloc[1].values
    test_data.drop(["Gene Description", "Gene Accession Number"], axis=0, inplace=True)

    # The remaining index values identify the patients; expose them as a
    # column so the labels can be merged in.
    train_data["patient"] = train_data.index.values
    test_data["patient"] = test_data.index.values

    train_data = train_data.astype("int32")
    test_data = test_data.astype("int32")

    train_data = pd.merge(train_data, labels, on="patient")
    test_data = pd.merge(test_data, labels, on="patient")

    data = pd.concat([train_data, test_data])

    # --- store objects
    # One row per epsilon, one column per trial. The *_FCR arrays hold a 0/1
    # indicator of whether the interval contained the true MMD.
    n_eps = len(eps_list)

    qno_FCR = np.zeros((n_eps, iters))
    qno_interval = np.zeros((n_eps, iters))

    sm_FCR = np.zeros((n_eps, iters))
    sm_interval = np.zeros((n_eps, iters))

    boot_FCR = np.zeros((n_eps, iters))
    boot_interval = np.zeros((n_eps, iters))

    SSD_FCR = np.zeros((n_eps, iters))
    SSD_interval = np.zeros((n_eps, iters))

    step1_FCR = np.zeros((n_eps, iters))
    step1_interval = np.zeros((n_eps, iters))

    sqno_FCR = np.zeros((n_eps, iters))
    sqno_interval = np.zeros((n_eps, iters))

    # --- loop over trials
    for i in range(iters):
        # --- load data: 47 "ALL" rows and 25 "AML" rows, drawn with replacement.
        resample = bootstrap_sample(data, (47, 25))

        # bootstrap_sample recodes "ALL" as 1 and "AML" as 0.
        x = resample[resample.cancer == 1].reset_index(drop=True)
        y = resample[resample.cancer == 0].reset_index(drop=True)

        # --- Drop patient ID and AML/ALL label
        x = x.drop(columns=['patient', 'cancer']).values
        y = y.drop(columns=['patient', 'cancer']).values

        # --- set the bandwidth based on median heuristic
        gamma = mmde.kernelwidthPair(x, y)

        # --- get the true MMD (computed on the uncontaminated resample)
        if biased_mmd:
            mmd_true = mmde.mmd_b(x, y, gamma=gamma)
        else:
            mmd_true = mmde.mmd_u(x, y, gamma=gamma)
        print(f'True MMD: {mmd_true:.3f}')

        # --- loop over epsilon
        for count, epsilon in enumerate(eps_list):
            print(f'| ------ Running epsilon = {epsilon} ------ |')
            # --- Simulated epsilon contamination
            x_c, y_c, m = nonrandom_epsilon_contaminate(x, y, epsilon)
            # x_c, y_c, m = epsilon_contaminate(x, y, epsilon)

            # Fraction of the contaminated y sample that came from x.
            effective_epsilon = m / y_c.shape[0]

            # --- get the naive MMD:
            if biased_mmd:
                mmd_naive = mmde.mmd_b(x_c, y_c, gamma=gamma)
            else:
                mmd_naive = mmde.mmd_u(x_c, y_c, gamma=gamma)
            print(f'Naive MMD: {mmd_naive:.3f}')

            # get QNO bounds
            # initialize at sd picks.
            mh_lower_pick, mh_upper_pick = bd.extreme_picks(x_c, y_c, m, gamma)

            # NOTE(review): biased_mmd is hard-coded to False here and for
            # S-QNO below, unlike the other bounds -- confirm this is intended.
            qno_lower, qno_upper = bd.opt_bounds(x_c, y_c, m,
                gamma, x0_lower = mh_lower_pick, x0_upper = mh_upper_pick,
                biased_mmd=False, tol_fun=1e-10, disp=False, maxiter=5000)

            print(f'QNO bds: {qno_lower:.3f}, {qno_upper:.3f}')

            # -- get submodular bounds
            sm_lower, sm_upper = bd.submodular_bounds(x_c, y_c, effective_epsilon,
                gamma, return_samples=False, biased_mmd=biased_mmd,reg="logdet")
            print(f'SM bds: {sm_lower:.3f}, {sm_upper:.3f}')

            # --- get bootstrap bounds
            boot_lower, boot_upper = bd.bootstrap_bounds(x_c, y_c,
                effective_epsilon, gamma, n_rep=n_rep, biased_mmd=biased_mmd)
            print(f'Boot bds: {boot_lower:.3f}, {boot_upper:.3f}')

            # --- get stepwise stochastic dominance bounds
            step_lower, step_upper = bd.stepwise_seq_bounds(x_c, y_c, m, n_steps, gamma, biased_mmd=biased_mmd, return_samples=False)
            print(f'S-SD bds: {step_lower:.3f}, {step_upper:.3f}')

            # --- get stochastic dominance bounds (same routine with a single step)
            sd_lower, sd_upper = bd.stepwise_seq_bounds(x_c, y_c, m, 1, gamma, biased_mmd=biased_mmd, return_samples=False)
            print(f'SD bds: {sd_lower:.3f}, {sd_upper:.3f}')

            # --- get stepwise qno bounds
            sqno_lower, sqno_upper = bd.opt_bounds_sequential(x_c, y_c, m, n_steps, gamma, biased_mmd=False,
                                                                    tol_fun=1e-10, disp=False, maxiter=5000)

            print(f'S-QNO bds: {sqno_lower:.3f}, {sqno_upper:.3f}')

            #Get stats
            if((qno_upper >= mmd_true) and (mmd_true >= qno_lower)): #If true contained
                qno_FCR[count][i] += 1

            if((sm_upper >= mmd_true) and (mmd_true >= sm_lower)): #If true contained
                sm_FCR[count][i] += 1

            if((boot_upper >= mmd_true) and (mmd_true >= boot_lower)): #If true contained
                boot_FCR[count][i] += 1

            if((step_upper >= mmd_true) and (mmd_true >= step_lower)): #If true contained
                SSD_FCR[count][i] += 1

            if((sd_upper >= mmd_true) and (mmd_true >= sd_lower)): #If true contained
                step1_FCR[count][i] += 1

            if((sqno_upper >= mmd_true) and (mmd_true >= sqno_lower)): #If true contained
                sqno_FCR[count][i] += 1

            qno_interval[count][i] = (qno_upper - qno_lower)
            sm_interval[count][i] = (sm_upper - sm_lower)
            boot_interval[count][i] = (boot_upper - boot_lower)
            SSD_interval[count][i] = (step_upper - step_lower)
            step1_interval[count][i] = (sd_upper - sd_lower)
            sqno_interval[count][i] = (sqno_upper - sqno_lower)

    # --- report: MIW is the mean interval width, and FCR is one minus the
    # mean containment indicator; SE divides the standard deviation over
    # trials by sqrt(iters).
    print("\n\n\nEpsilons: " + str(eps_list) + "\n")
    np.set_printoptions(precision=3)

    print("S-SD:")
    print("MIW: \t" + str(np.mean(SSD_interval, axis = 1)))
    print("MIW SE: " + str(np.std(SSD_interval, axis = 1) / np.sqrt(iters)))
    print("FCR: \t" + str(1 - np.mean(SSD_FCR, axis = 1)))
    print("FCR SE: " + str(np.std(SSD_FCR, axis = 1) / np.sqrt(iters)))

    print("\nS-QNO:")
    print("MIW: \t" + str(np.mean(sqno_interval, axis = 1)))
    print("MIW SE: " + str(np.std(sqno_interval, axis = 1) / np.sqrt(iters)))
    print("FCR: \t" + str(1 - np.mean(sqno_FCR, axis = 1)))
    print("FCR SE: " + str(np.std(sqno_FCR, axis = 1) / np.sqrt(iters)))

    print("\nQNO:")
    print("MIW: \t" + str(np.mean(qno_interval, axis = 1)))
    print("MIW SE: " + str(np.std(qno_interval, axis = 1)  / np.sqrt(iters)))
    print("FCR: \t" + str(1 - np.mean(qno_FCR, axis = 1)))
    print("FCR SE:" + str(np.std(qno_FCR, axis = 1) / np.sqrt(iters)))

    print("SD:")
    print("MIW: \t" + str(np.mean(step1_interval, axis = 1)))
    print("MIW SE: " + str(np.std(step1_interval, axis = 1) / np.sqrt(iters)))
    print("FCR: \t" + str(1 - np.mean(step1_FCR, axis = 1)))
    print("FCR SE: " + str(np.std(step1_FCR, axis = 1) / np.sqrt(iters)))

    print("\nSM:")
    print("MIW: \t" + str(np.mean(sm_interval, axis = 1)))
    print("MIW SE: " + str(np.std(sm_interval, axis = 1) / np.sqrt(iters)))
    print("FCR: \t" + str(1 - np.mean(sm_FCR, axis = 1)))
    print("FCR SE: " + str(np.std(sm_FCR, axis = 1) / np.sqrt(iters)))

    print("\nBoot:")
    print("MIW: \t" + str(np.mean(boot_interval, axis = 1)))
    print("MIW SE: " + str(np.std(boot_interval, axis = 1)  / np.sqrt(iters)))
    print("FCR: \t" + str(1 - np.mean(boot_FCR, axis = 1)))
    print("FCR SE: " + str(np.std(boot_FCR, axis = 1)  / np.sqrt(iters)))



def str2bool(value):
    '''Parse a command-line boolean.

    argparse's type=bool calls bool() on the raw string, so any non-empty
    text (including "False") becomes True. This converter instead accepts
    true/t/yes/y/1 and false/f/no/n/0 (case-insensitive); the empty string is
    treated as False, matching what bool("") returned before.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    '''
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # The three data paths have no usable default, so fail at parse time with
    # a usage message instead of later inside pd.read_csv.
    parser.add_argument('--train_dir', '-train_dir',
        help="Directory where the BIO training data is saved",
        type=str,
        required=True)

    parser.add_argument('--test_dir', '-test_dir',
        help="Directory where the BIO test data is saved",
        type=str,
        required=True)

    parser.add_argument('--label_dir', '-label_dir',
        help="Directory where the BIO label data is saved",
        type=str,
        required=True)

    parser.add_argument('--eps_list', '-eps_list',
        help="comma separated list of epsilons to use for corruption",
        type=str,
        default='0.01,0.05,0.1,0.5')

    parser.add_argument('--iters', '-iters',
        help="Number of individual experiments on bootstrap resamples of the BIO data",
        type=int,
        default=100)

    parser.add_argument('--n_steps', '-n_steps',
        help="Number of steps for both S-QNO and S-SD.",
        type=int,
        default=10)

    parser.add_argument('--n_rep', '-n_rep',
        help="Number of bootstrap replicates for the bootstrap bounds",
        type=int,
        default=500)

    # nargs='?' with const=True lets the option be used as a bare flag
    # (--biased_mmd) as well as with an explicit value (--biased_mmd False).
    parser.add_argument('--biased_mmd', '-biased_mmd',
        help="Boolean flag to change calculations between the biased and unbiased MMD. Default is False.",
        type=str2bool,
        nargs='?',
        const=True,
        default=False)

    args = vars(parser.parse_args())
    # Turn the comma separated string into a list of floats for main().
    args['eps_list'] = [float(eps) for eps in args['eps_list'].split(',')]
    main(**args)