import numpy as np
import torch

import random
import pickle
from datetime import datetime
import os
import sys
import argparse

import get_data
import util
from hyperparams import all_hyperparams, hp_ranges
from data_settings import all_settings 
from exp_settings import all_exp_settings
import process_results

from directories import results_dir

# Single fixed seed passed to random.seed, np.random.seed and torch.manual_seed
# everywhere this file reseeds its random number generators.
seed = 123456789

###################################################################################################
def run_exp(dataset_name, data_package, approach, tune, val_gt, settings, split_seed=0, date='0', boot=True):
    '''
    run an experiment with a specific approach on one dataset

    dataset_name: key into all_settings / all_hyperparams / hp_ranges; also the name of the
                  sub-directory created under results_dir
    data_package: indexable dataset bundle; data_package[0][0].shape[1] is stored as the feature
                  count and data_package[1] is what gets evaluated when not tuning
                  (presumably train and test splits - confirm against get_data)
    approach:     key selecting the hyperparameters (or their ranges) for this dataset
    tune:         if true, run util.tune_hyperparams; otherwise build one model and evaluate it
    val_gt:       forwarded unchanged to util.tune_hyperparams / util.get_model
    settings:     not used in this function
    split_seed:   not used in this function
    date, boot:   forwarded to util.tune_hyperparams only

    returns the third value produced by util.tune_hyperparams when tune is true, otherwise None
    (the evaluation result is only printed)
    '''
    #random seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    #exist_ok avoids the check-then-create race between parallel runs, and makedirs also
    #creates results_dir itself if it does not exist yet
    os.makedirs(results_dir + dataset_name + '/', exist_ok=True)

    #note: this is the shared dict stored in all_settings, so n_feats is written into it in place
    data_params = all_settings[dataset_name]
    data_params['n_feats'] = data_package[0][0].shape[1]

    if tune:
        #ranges are only needed (and only looked up) when tuning
        hyperparam_ranges = hp_ranges[dataset_name][approach]
        _, hyperparams, res = util.tune_hyperparams(data_package, approach, data_params,
                                                    hyperparam_ranges, val_gt, results_dir, date,
                                                    dataset_name, boot=boot)
        print(dataset_name, approach, hyperparams, hyperparam_ranges)
        print(res)
        return res

    #fixed hyperparameters are only needed (and only looked up) when not tuning
    hyperparams = all_hyperparams[dataset_name][approach]
    mod, _, _ = util.get_model(dataset_name, data_package, approach, data_params, hyperparams, val_gt)
    print(util.eval_overall(mod, data_package[1], data_params, use_gt=True))
   

def run_bulk_exp(dataset_name, approaches, val_gt, split_seed, date, setting, tune, boot=True):
    '''
    run a full set of experiments from a list of approaches on one dataset

    the dataset is loaded once through get_data.get_dataset for the given setting and split_seed,
    then run_exp is called for every entry of approaches in order; the wall-clock time of the
    whole batch is printed at the end
    '''
    started = datetime.now()

    #reseed all three random number generators before the data is loaded
    for reseed in (random.seed, np.random.seed, torch.manual_seed):
        reseed(seed)

    data_package = get_data.get_dataset(dataset_name, setting,
                                        params=all_settings[dataset_name],
                                        split_seed=split_seed)

    for approach in approaches:
        run_exp(dataset_name, data_package, approach, tune, val_gt, setting,
                split_seed=split_seed, date=date, boot=boot)

    print('################################################################',
          'done!',
          dataset_name,
          sep='\n')

    print('time elapsed: ', datetime.now() - started)


def run_bulk_and_vary_setting(dataset_name, approaches, val_gt, split_seed, date, settings, tune, vary=False, boot=True):
    '''
    run bulk experiments on different versions of a dataset (example - varying the noise rate on the synthetic data)

    every entry of settings gets its own results tag, date + '-' + <position in settings>;
    all settings are run first, and only afterwards are the results of each one post-processed
    vary is accepted but not used in this function
    '''
    #one tag per setting, shared by the run pass and the post-processing pass
    tags = [date + '-' + str(position) for position in range(len(settings))]

    for setting, tag in zip(settings, tags):
        run_bulk_exp(dataset_name, approaches, val_gt, split_seed, tag, setting, tune, boot=boot)
    print('done with varying the setting')

    for setting, tag in zip(settings, tags):
        print('setting ', setting)
        process_results.postprocess_results(dataset_name, approaches, tag, val_gt)
        print('************************************************************************')


###################################################################################################
'''
main block

what to check before running a mass experiment:
    dataset name
    the date that gets prepended to the results file name (update it so the previous day's results are not overwritten)

datasets: synth_random, synth_feat1, synth_feat2, MIMIC-ARF-{random,feat1,feat2}, MIMIC-Shock-{same as arf}, 
          adult-{same as mimic}, compas-{same as mimic}
main approaches: 'baseline_plain_clean', 'baseline_plain', 'baseline_sln_filt', 'baseline_transition', 'baseline_transit_conf', 'baseline_fair_gpl', 'baseline_js_loss', 'proposed1'
other approaches: 'baseline_filt', 'baseline_sln', 'baseline_fair_reweight', 'anchor'

feat1 - unbiased feature noise
feat2 - biased feature noise

experiments: baseline, sa_size, sa_bias
'''
if __name__ == '__main__':
    #random seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    time_now = datetime.now()
    #prefix of every results tag - update before a new mass experiment
    date = '0520'

    parser = argparse.ArgumentParser(description='Datasets and experimental conditions for noisy labels experiments')
    #required: without it dataset_name is None and the all_exp_settings lookup below fails with a KeyError
    parser.add_argument('--dataset', required=True)
    parser.add_argument('--experiment', default='baseline')
    args = parser.parse_args()

    dataset_name = args.dataset
    tune = True
    val_gt = True
    boot = False
    approaches = ['baseline_plain', 'baseline_sln_filt', 'baseline_transition', 'baseline_fair_gpl',
                  'baseline_transit_conf', 'baseline_js_loss', 'proposed1', 'baseline_plain_clean']

    #any experiment name containing 'sa' sweeps over a range of settings
    vary_setting = 'sa' in args.experiment
    exp_setting = all_exp_settings[dataset_name]

    if vary_setting:
        #which slice of the per-dataset settings list each sweep covers
        sweep_ranges = {
            'sa_size': slice(0, 10),
            'sa_bias': slice(10, 20),
            'sa_noise_rate': slice(20, 30),       #20-29 inclusive
            'sa_noise_rate_rand': slice(20, 30),
            'sa_noise_disp': slice(30, 40),       #30-39 inclusive
            'sa_minprop': slice(40, 50),
        }
        #an experiment name that is not listed keeps the full list of settings
        if args.experiment in sweep_ranges:
            exp_setting = exp_setting[sweep_ranges[args.experiment]]
        seeds = [0, 1, 2, 3, 4]
    else:
        #7 for 10% and representative
        #spot checks: 1/4/7 size, 10/15/19 bias, 40/45 min prop, 21/28 rate, 31/38 disp
        exp_setting = [exp_setting[10]]
        seeds = [0]

    for split_seed in seeds:
        #build the per-seed tag in its own variable so the base date never has to be
        #recovered by slicing (the old date[:4] reset only worked for a 4-character date)
        run_date = date + 's' + str(split_seed) + args.experiment
        run_bulk_and_vary_setting(dataset_name, approaches, val_gt, split_seed, run_date,
                                  exp_setting, tune, boot=boot)

    end_time = datetime.now()
    print('time elapsed: ', end_time - time_now)


########################################
'''
anaconda stuff

# To activate this environment, use
#
#     $ conda activate generic_env
#
# To deactivate an active environment, use
#
#     $ conda deactivate
'''
