"""MIMIC experiment: coverage and width of MMD bounds under epsilon-contamination."""
import argparse
import numpy as np 
import matplotlib.pyplot as plt
import torch

import mmd_estimators as mmde
import bounds as bd
import pandas as pd
import pickle


def epsilon_contaminate(x, y, epsilon):
    '''Randomly move an epsilon fraction of rows from x into y.

    ``m = int(len(y) * epsilon)`` rows of ``x`` are picked uniformly at random
    without replacement (global NumPy RNG), appended to the end of ``y`` and
    removed from ``x``. The caller's arrays are not modified.

    Args:
        x, y: 2-D arrays whose rows are samples.
        epsilon: contamination fraction, relative to the number of rows of y.

    Returns:
        (x_corr, y_corr, m): the reduced x, the augmented y and the number
        of moved rows.

    Raises:
        ValueError: if m exceeds the number of rows of x.
    '''
    source = np.array(x)
    target = np.array(y)

    n_moved = int(target.shape[0] * epsilon)
    picked = np.random.choice(source.shape[0], n_moved, replace=False)
    print(f'Corrupted samples length {n_moved}')

    contaminated = np.vstack([target, source[picked, :]])
    remaining = np.delete(source, picked, axis=0)
    return remaining, contaminated, n_moved
    
def nonrandom_epsilon_contaminate(x, y, epsilon):
    '''Move an epsilon fraction of x into y, favouring rows with large witness values.

    ``m = int(len(y) * epsilon)`` rows are moved. The 2m rows of ``x`` with
    the largest values of ``mmde.witness_simple(x, y, x)`` are taken as
    candidates, and m of them are chosen at random (global NumPy RNG). The
    chosen rows are appended to ``y`` and removed from ``x``.

    If that selection is not possible, m rows of ``x`` are chosen uniformly at
    random instead. This happens when 2m is too large for ``argpartition``, or
    when m == 0.

    Args:
        x, y: 2-D arrays whose rows are samples.
        epsilon: contamination fraction, relative to the number of rows of y.

    Returns:
        (x_corr, y_corr, m): the reduced x, the augmented y and the number
        of moved rows.
    '''
    x_corr = np.copy(x)
    y_corr = np.copy(y)

    # Witness values evaluated at the rows of x; presumably one value per row
    # (1-D) -- TODO confirm against mmd_estimators.witness_simple.
    wd = mmde.witness_simple(x, y, x)

    m = int(y_corr.shape[0] * epsilon)

    try:
        # Indices of the 2m largest witness values (their mutual order is unspecified).
        worst_case_corr = np.argpartition(-wd, 2 * m)[:2 * m]
        # Pick m of the 2m candidates at random so the corruption is not fully deterministic.
        ids = np.random.choice(range(2 * m), m, replace=False)
        chosen = worst_case_corr[ids]
        y_corr = np.vstack([y_corr, x[chosen, :]])
        x_corr = np.delete(x_corr, chosen, axis=0)
    except (ValueError, IndexError):
        # Expected failure modes only (previously a bare ``except:`` hid everything):
        #  - ValueError: 2m is not a valid partition index for wd (too few rows);
        #  - IndexError: e.g. with m == 0 the candidate draw comes back as an
        #    empty non-integer array, which cannot be used as an index.
        # Fall back to uniformly random contamination.
        ids = np.random.choice(x_corr.shape[0], m, replace=False)
        y_corr = np.vstack([y_corr, x_corr[ids, :]])
        x_corr = np.delete(x_corr, ids, axis=0)
    print(f'Corrupted samples length {m}')

    return x_corr, y_corr, m



def _mmd(x, y, gamma, biased_mmd):
    '''Return the biased (mmde.mmd_b) or unbiased (mmde.mmd_u) MMD of x vs y.'''
    if biased_mmd:
        return mmde.mmd_b(x, y, gamma=gamma)
    return mmde.mmd_u(x, y, gamma=gamma)


def _compute_bounds(x_c, y_c, m, gamma, n_steps, n_rep, biased_mmd):
    '''Run every bounding method on one contaminated sample.

    Each method's bounds are printed as soon as they are computed. The call
    order is kept fixed: some methods may draw from the global NumPy RNG
    (presumably at least the bootstrap), so reordering could change results.

    Returns:
        dict mapping method name -> (lower, upper), in call order.
    '''
    effective_epsilon = m / y_c.shape[0]
    bounds = {}

    def record(name, result):
        lower, upper = result
        print(f'{name} bds: {lower:.3f}, {upper:.3f}')
        bounds[name] = (lower, upper)

    # QNO is initialised at the stochastic-dominance extreme picks.
    lower_pick, upper_pick = bd.extreme_picks(x_c, y_c, m, gamma)
    # biased_mmd is forwarded here (it used to be hard-coded to False), so QNO
    # bounds the same estimator as mmd_true and the other methods.
    record('QNO', bd.opt_bounds(
        x_c, y_c, m, gamma, x0_lower=lower_pick, x0_upper=upper_pick,
        biased_mmd=biased_mmd, tol_fun=1e-10, disp=False, maxiter=5000))
    record('SM', bd.submodular_bounds(
        x_c, y_c, effective_epsilon, gamma,
        return_samples=False, biased_mmd=biased_mmd, reg="None"))
    record('Boot', bd.bootstrap_bounds(
        x_c, y_c, effective_epsilon, gamma, n_rep=n_rep, biased_mmd=biased_mmd))
    record('S-SD', bd.stepwise_seq_bounds(
        x_c, y_c, m, n_steps, gamma, biased_mmd=biased_mmd, return_samples=False))
    # SD is the stepwise procedure run with a single step.
    record('SD', bd.stepwise_seq_bounds(
        x_c, y_c, m, 1, gamma, biased_mmd=biased_mmd, return_samples=False))
    record('S-QNO', bd.opt_bounds_sequential(
        x_c, y_c, m, n_steps, gamma, biased_mmd=biased_mmd,
        tol_fun=1e-10, disp=False, maxiter=5000))
    return bounds


def _print_summary(title, widths, covered, iters):
    '''Print per-epsilon interval-width and miss-rate statistics for one method.

    MIW is the mean interval width. FCR is the fraction of trials whose
    interval did not contain the true MMD. Both are reported with standard
    errors std / sqrt(iters).
    '''
    print(title)
    print("MIW: \t" + str(np.mean(widths, axis=1)))
    print("MIW SE: " + str(np.std(widths, axis=1) / np.sqrt(iters)))
    print("FCR: \t" + str(1 - np.mean(covered, axis=1)))
    print("FCR SE: " + str(np.std(covered, axis=1) / np.sqrt(iters)))


def main(data_dir, n_samples, eps_list, iters, n_steps, n_rep, biased_mmd=False):
    '''Main experiment function.

    For each of ``iters`` trials:
      1. Draw ``round(n_samples / 2)`` rows (with replacement) from each dataset.
      2. For every epsilon in ``eps_list``, contaminate the sample with
         ``nonrandom_epsilon_contaminate``.
      3. Record, for each bounding method, whether its interval contains the
         uncontaminated ("true") MMD, and the interval's width.

    Summary statistics per method are printed at the end.

    Args:
        data_dir: path to a pickle file holding ``[x_data, y_data]`` arrays
            (e.g. model logits). Unpickling can execute code -- only load
            trusted files.
        n_samples: total samples per trial (half from each dataset).
        eps_list: contamination fractions to evaluate.
        iters: number of independent trials.
        n_steps: number of steps for the stepwise methods (S-SD, S-QNO).
        n_rep: number of bootstrap replicates.
        biased_mmd: use the biased MMD estimator instead of the unbiased one.
    '''
    # --- Set seed
    np.random.seed(0)
    with open(data_dir, "rb") as f:
        x_full, y_full = pickle.load(f)

    print(x_full.shape)
    print(y_full.shape)

    methods = ('QNO', 'SM', 'Boot', 'S-SD', 'SD', 'S-QNO')
    shape = (len(eps_list), iters)
    # covered[name][j, i] == 1 when method `name` contained the true MMD for
    # eps_list[j] in trial i; widths[name][j, i] is that interval's width.
    covered = {name: np.zeros(shape) for name in methods}
    widths = {name: np.zeros(shape) for name in methods}
    half = round(n_samples / 2)

    for i in range(iters):
        # Subsample each dataset once per trial; reused for every epsilon.
        x = x_full[np.random.choice(x_full.shape[0], half, replace=True)]
        y = y_full[np.random.choice(y_full.shape[0], half, replace=True)]

        # --- set the bandwidth based on median heuristic
        gamma = mmde.kernelwidthPair(x, y)

        mmd_true = _mmd(x, y, gamma, biased_mmd)
        print(f'True MMD: {mmd_true:.3f}')

        for j, epsilon in enumerate(eps_list):
            print(f'| ------ Running epsilon = {epsilon} ------ |')
            x_c, y_c, m = nonrandom_epsilon_contaminate(x, y, epsilon)

            mmd_naive = _mmd(x_c, y_c, gamma, biased_mmd)
            print(f'Naive MMD: {mmd_naive:.3f}')

            bounds = _compute_bounds(x_c, y_c, m, gamma, n_steps, n_rep, biased_mmd)
            for name, (lower, upper) in bounds.items():
                if upper >= mmd_true and mmd_true >= lower:
                    covered[name][j, i] = 1
                widths[name][j, i] = upper - lower

    print("\n\n\nEpsilons: " + str(eps_list) + "\n")
    np.set_printoptions(precision=3)

    report_order = ('S-SD', 'S-QNO', 'QNO', 'SD', 'SM', 'Boot')
    for k, name in enumerate(report_order):
        # Blank line between sections; the "Epsilons" header already ends in one.
        title = ("\n" if k else "") + name + ":"
        _print_summary(title, widths[name], covered[name], iters)



def str_to_bool(value):
    '''Parse a command-line boolean such as "True", "false", "1" or "no".

    argparse's ``type=bool`` treats every non-empty string (including
    "False") as True, so an explicit parser is needed.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean.
    '''
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('--data_dir', '-data_dir',
        help="Path to the pickle file holding [x_data, y_data] arrays (e.g. model logits)",
        type=str,
        required=True)

    parser.add_argument('--n_samples', '-n_samples',
        help="Total number of samples per trial; half are drawn (with replacement) from each of the two datasets.",
        type=int,
        default=100)

    parser.add_argument('--eps_list', '-eps_list',
        help="comma separated list of epsilons to use for corruption",
        type=str,
        default='0.01,0.05,0.1,0.5')

    parser.add_argument('--iters', '-iters',
        help="Number of independent trials, each on a fresh subsample of the data",
        type=int,
        default=100)

    parser.add_argument('--n_steps', '-n_steps',
        help="Number of steps for both S-QNO and S-SD.",
        type=int,
        default=10)

    parser.add_argument('--n_rep', '-n_rep',
        help="Number of bootstrap replicates for the bootstrap bounds",
        type=int,
        default=500)

    # nargs='?' + const=True: a bare "--biased_mmd" means True, while
    # "--biased_mmd True/False" keeps working (and False now really means False).
    parser.add_argument('--biased_mmd', '-biased_mmd',
        help="Use the biased instead of the unbiased MMD estimate (true/false; a bare flag means true). Default is False.",
        type=str_to_bool,
        nargs='?',
        const=True,
        default=False)

    args = vars(parser.parse_args())
    # Ignore empty items, e.g. from a trailing comma.
    args['eps_list'] = [float(eps) for eps in args['eps_list'].split(',') if eps.strip()]
    main(**args)