import torch
from stability import exp_stability
import os
from itertools import product
import pickle as pkl
import yaml
from datetime import datetime
from pytz import timezone
import numpy as np


import argparse

def save_yaml_config(config, path):
    """Save the config to a file in YAML format.

    The file at ``path`` is created, or overwritten if it already exists.
    ``default_flow_style=False`` makes PyYAML write nested collections in
    block style (one key per line) rather than inline ``{...}`` style.

    Args:
        config (dict object): Config.
        path (str): Path to save the config.
    """
    # Explicit encoding so the output does not depend on the platform locale.
    with open(path, 'w', encoding='utf-8') as outfile:
        yaml.dump(config, outfile, default_flow_style=False)

def get_datetime_str(add_random_str=False):
    """Return a timestamp string for the current time in the 'EST' zone.

    The format is ``YYYY-MM-DD_HH-MM-SS-mmm``. ``%f`` yields six microsecond
    digits and the last three are cut off, leaving milliseconds.

    NOTE(review): pytz's 'EST' is a fixed UTC-5 offset with no DST handling;
    confirm that 'US/Eastern' was not intended.

    Args:
        add_random_str (bool): If True, append ``_<n>`` with ``n`` drawn from
            numpy's global RNG in the range [1, 9999].
    """
    stamp = datetime.now(timezone('EST')).strftime('%Y-%m-%d_%H-%M-%S-%f')[:-3]
    if not add_random_str:
        return stamp
    # The random suffix makes it much less likely that experiments launched
    # concurrently (within the same millisecond) collide on a directory name.
    suffix = np.random.randint(low=1, high=10000)
    return '_'.join((stamp, str(suffix)))

# Command-line interface. Choices are tuples (not sets) so the order printed
# in --help and in "invalid choice" errors is fixed; set iteration order for
# strings varies between interpreter runs under hash randomization.
parser = argparse.ArgumentParser(description='Training setting')
parser.add_argument('--model', default='vgan', choices=('vgan', 'dcgan'), help='GAN structure (default: vgan)')
# NOTE(review): an earlier comment here read "nsgan is non-saturating gan", but
# 'nsgan' is not an accepted choice below; confirm whether it should be added.
parser.add_argument('--loss', default='gan', choices=('bce', 'gan', 'wgan'), help='GAN loss (default: gan)')
parser.add_argument('--data', default='mnist', choices=('mnist', 'cifar10'), help='Dataset (default: mnist)')
parser.add_argument('--metric', default='fro', choices=('fro', 'ned'), help='Distance metric')
parser.add_argument('--node', default=5, type=int, help='number of nodes')
parser.add_argument('--lr', default=0.0002, type=float, help='learning rate')
parser.add_argument('--sample_size', default=0, type=int, help='sample size')
parser.add_argument('-p', '--path', help='path of dataset, default=\'../data\'',
                    dest='path', type=str, default='../data')
# Free-form string (no choices enforced here); stored as options['mode'].
parser.add_argument('--mode', default='ring', help='mode of graph')
# Used as a sub-directory name under experiments/<data>/.
parser.add_argument('--exp_type', default='different_lr', help='type of experiment')
args = parser.parse_args()

# Ensure a ./data directory exists under the current working directory.
# NOTE(review): data_path is not referenced elsewhere in this file; datasets are
# read from args.path (default '../data'). Confirm this directory is still needed.
# exist_ok=True avoids the check-then-create race when several runs start at once.
data_path = os.path.join(os.getcwd(), 'data')
os.makedirs(data_path, exist_ok=True)

# One directory per run: experiments/<data>/<exp_type>/<timestamp>_<random int>.
args.work_dir = os.path.join(
    'experiments',
    args.data,
    args.exp_type,
    get_datetime_str(add_random_str=True),
)
os.makedirs(args.work_dir, exist_ok=True)

# Record the full CLI configuration (including work_dir) inside the run directory.
save_yaml_config(vars(args), path=os.path.join(args.work_dir, 'args_info.yaml'))

# Learning parameters passed to exp_stability. CLI-provided values are mixed
# with fixed training settings; key insertion order is kept as-is.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
options = {
    'model': args.model,
    'loss': args.loss,
    'data': args.data,
    'metric': args.metric,
    'node': args.node,
    'learning_rate': args.lr,
    'sample_size': args.sample_size,
    'path': args.path,
    'batch_size': 100,   # a large batch size can speed up GPU training
    'num_epochs': 10,
    'iterations': 100,   # also sizes the last axis of the result buffers below
    'nz': 128,           # latent vector size
    'device': device,
    'mode': args.mode,
}
print('device: %s' % options['device'])

# Cross-validation grid: each (manual seed, removed index) pair is one run.
# For a quick smoke test, slice both lists (e.g. [:2]).
manual_seed_list = [1, 2, 3, 4]
remove_index_list = [3, 7, 15, 20]
cv_grid = list(product(manual_seed_list, remove_index_list))
num_cv = len(cv_grid)

# Result buffers of shape (num_cv, num_layers, iterations).
# 4 is the number of layers, per the original note; presumably it must match
# what exp_stability returns (TODO confirm against stability.py).
num_layers = 4
gen_diff_cv = torch.zeros(num_cv, num_layers, options['iterations'])
dis_diff_cv = torch.zeros(num_cv, num_layers, options['iterations'])

for run_no, (manual_seed, remove_index) in enumerate(cv_grid, start=1):
    print('Current cross validation number: %d / Total: %d' % (run_no, num_cv))
    row = run_no - 1
    gen_diff_cv[row], dis_diff_cv[row] = exp_stability(remove_index, manual_seed, options)

# Aggregate over the cross-validation runs (dim 0 of the buffers). torch.std
# uses the unbiased (N-1) estimator by default, so it yields NaN if num_cv == 1.
results_gen = {
    'mean': torch.mean(gen_diff_cv, 0),
    'std': torch.std(gen_diff_cv, 0),
}
results_dis = {
    'mean': torch.mean(dis_diff_cv, 0),
    'std': torch.std(dis_diff_cv, 0),
}
results = {
    'gen': results_gen,
    'dis': results_dis,
    'options': options,
}

# Write results into this run's work directory. The context manager closes
# the file even if pickling raises.
with open(os.path.join(args.work_dir, 'result.pkl'), 'wb') as f:
    pkl.dump(results, f)

