"""Example config file for large LR experiments on Cloud."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
from collections import namedtuple, OrderedDict
import uuid

# Parameters whose values make up the job directory name (the common dir for
# a group of jobs). Values are short abbreviations (cf. ALIAS); '' presumably
# means "no abbreviation" -- confirm against the code that builds the name.
JOB_DIR_PARAMS = OrderedDict([
    ('job_id', ''), ('dataset', ''), ('model', ''), ('depth', 'd'),
    ('nonlinearity', ''), ('loss_type', ''), ('opt', ''), ('n_train', 'n'),
])
# Parameters whose values make up the job_id folder inside the common dir.
JOB_ID_PARAMS = OrderedDict([
    ('width', 'w'), ('learning_rate', 'lr'), ('batch_size', 'bs'),
])
# Short / alternative names -> canonical parameter names.
ALIAS = {
    'd': 'depth',
    'w': 'width',
    'bs': 'batch_size',
    'lr': 'learning_rate',
    'jobid': 'job_id',
    'n': 'n_train',
    'steps': 'train_steps',
    'total_steps': 'train_steps',
    'num_steps': 'train_steps',
}
# Fields shared by every config namedtuple; each get_default_*_config()
# appends its model-specific fields to this list.
PARAMS0 = [
    'w_std', 'b_std', 'output_dim', 'nonlinearity', 'linearize', 'loss_type',
    'seed', 'data_seed', 'opt', 'learning_rate', 'n_train', 'augment',
    'upload_to_cloud', 'time_unit', 'logdir', 'savew_dir', 'std_meas',
    'save_weights', 'load_weights', 'alias', 'verbose', 'similarjobsfolder',
    'job_id', 'job_dir_params', 'job_id_params', 'data_dir', 'meas_simple',
    'std_meas_size', 'wrnsch', 'momcos', 'NTK_norm', 'distgrad', 'L2',
    'randomlabels', 'randomdata', 'shuffledata', 'savew_name', 'meashesseval',
    'n_test', 'physical', 'exit_acc',
    # Currently disabled: 'log_meas', 'earlymicromeas'
]
def get_default_params():
    """Return a fresh dict of defaults shared by every model config.

    The get_default_*_config() factories in this module start from this dict,
    override or add model-specific entries, and pack the result into a
    namedtuple.

    Returns:
        dict: parameter name -> default value. A new random ``job_id`` is
        drawn on every call. ``job_dir_params``, ``job_id_params`` and
        ``alias`` are copies of the module-level JOB_DIR_PARAMS,
        JOB_ID_PARAMS and ALIAS, so mutating them on one config never alters
        the module constants or any config built later.
    """
    dic = {}
    # Standard deviation of weight initialization (constant O(1) piece).
    # Set uniform throughout the network.
    dic['w_std'] = math.sqrt(2.0)
    # Standard deviation of bias initialization (constant O(1) piece).
    # Set uniform throughout the network.
    dic['b_std'] = 0.0
    # Output dim (10 for 10-class or 1 for binary classification).
    dic['output_dim'] = 10
    # Nonlinearity. With NT Lib, options are: 'Relu', 'Erf', 'Identity'.
    dic['nonlinearity'] = 'Relu'
    # NOTE(review): an older note described this as "time, in steps, at which
    # to linearize the model; None for no linearization", but the default is
    # False -- presumably False also means a nonlinear model; confirm.
    dic['linearize'] = False
    # Optimization algorithm. Options: 'sgd'.
    dic['opt'] = 'sgd'
    # Loss function. Options: 'mse', 'xent'.
    dic['loss_type'] = 'mse'
    # Train on a subset of the full training set; 0 uses the full dataset.
    dic['n_train'] = 0
    # Test-set counterpart of n_train (presumably 0 = full test set).
    dic['n_test'] = 0
    # Use data augmentation (random flips & crops). True or False.
    dic['augment'] = False
    # Cloud bucket to upload to; None means nothing is uploaded.
    dic['upload_to_cloud'] = None
    # Unit in which measurements are made; meas_util uses it, don't change it.
    dic['time_unit'] = 'step'
    # Learning rate, in NTK parameterization.
    dic['learning_rate'] = 0.1
    # Total number of training steps.
    dic['train_steps'] = 2000.0
    dic['logdir'] = 'largelr_logs/'
    # Make standard measurements.
    dic['std_meas'] = True
    # Random seed for initializing parameters.
    dic['seed'] = 0
    # Random seed for batching of data for training.
    dic['data_seed'] = 1
    # Parameters saved in the job name (edit JOB_DIR_PARAMS / JOB_ID_PARAMS
    # manually): job_dir goes into the common dir, job_id into a folder within
    # it. Copied so per-config mutation cannot leak into the shared constants.
    dic['job_dir_params'] = OrderedDict(JOB_DIR_PARAMS)
    dic['job_id_params'] = OrderedDict(JOB_ID_PARAMS)
    # Short-name aliases (could be merged with the above and cleaned up).
    dic['alias'] = dict(ALIAS)
    # Folder to save weights in. Set to None for default.
    dic['savew_dir'] = 'weights/'
    dic['savew_name'] = None
    # Timesteps at which to save weights, e.g. [0, 20, 40, ...]; None to not save.
    dic['save_weights'] = None
    # File name (with extension) to load weights from, e.g.
    # 'weights/model100.pkl'. None otherwise.
    dic['load_weights'] = None
    # 0: loss and accuracy are measured over all samples.
    dic['std_meas_size'] = 0
    dic['verbose'] = 1
    # Group similar jobs (jobs that only differ in job_id_params) in one folder.
    dic['similarjobsfolder'] = True
    dic['data_dir'] = '/tmp/datasets'
    # Random 8-hex-digit suffix, freshly drawn on each call.
    dic['job_id'] = 'largelrjob_' + uuid.uuid4().hex[:8]
    # Only make standard measurements.
    dic['meas_simple'] = False
    dic['meashesseval'] = False
    # LR schedule for wide resnet.
    dic['wrnsch'] = False
    # Cosine schedule for WRN.
    dic['momcos'] = False
    # Use NTK normalization.
    dic['NTK_norm'] = False
    # Distribute gradients in mini-batches of size distgrad.
    dic['distgrad'] = 0
    # L2 regularization.
    dic['L2'] = 0.0
    # Random labels.
    dic['randomlabels'] = 0
    dic['randomdata'] = 0
    dic['shuffledata'] = 0
    # Steps are given in physical time = step * learning_rate.
    dic['physical'] = False
    # Exit when accuracy hits this value; 1.1 means it never exits.
    dic['exit_acc'] = 1.1
    return dic

# TODO: parameters still to explore: w_std, b_std.
def get_default_fc_config():
    """Return the default config namedtuple for the fully-connected ('fc') model.

    Starts from get_default_params() and adds the fc-specific fields.

    Raises:
        Exception: if any namedtuple field has no default value; the message
            lists the missing field names.
    """
    PARAMS = PARAMS0 + [
        'dataset', 'model', 'depth', 'width', 'batch_size', 'train_steps',
        'meas_freq', 'save_freq', 'meas_bs', 'meas_samples', 'Hutch',
        'Hutch_tol', 'NTK_SAMPLES', 'meas_overlaps', 'measNTKspec',
        'measHigherOrders', 'batch_norm', 'unit_norm', 'w_std_last',
    ]

    ConfigObj = namedtuple('config', PARAMS)
    dic = get_default_params()
    # Dataset. Current options: 'cifar10', 'mnist', 'fashion_mnist'.
    dic['dataset'] = 'mnist'
    # Architecture. Current options: 'fc', 'cnn_real', 'wrn_original'.
    dic['model'] = 'fc'
    # Hidden layers.
    dic['depth'] = 2
    # Hidden units.
    dic['width'] = 512
    # Option to include BatchNorm. Currently implemented for WideResnet only.
    # Note kernel computation is invalid for such models.
    dic['batch_norm'] = False
    # Batch size for SGD.
    dic['batch_size'] = 512
    # Measure overlap and g eigenval.
    dic['meas_overlaps'] = False
    # Freq of measuring and freq of saving (locally and cloud);
    # save_freq=-1 only saves at the end.
    dic['meas_freq'] = 20.0
    dic['save_freq'] = 100
    # Measurement batch size and samples for this (and also tr NTK without Hutch).
    dic['meas_bs'] = 512
    dic['meas_samples'] = 4096
    # Measure NTK tr and eigenval, and how many samples to use for the NTK.
    dic['measNTKspec'] = True
    dic['NTK_SAMPLES'] = 2048
    # Use Hutch for tr eval, and its tolerance.
    dic['Hutch'] = True
    dic['Hutch_tol'] = 0.005
    # Measure higher order operators (NOT IMPLEMENTED).
    dic['measHigherOrders'] = False
    dic['unit_norm'] = False
    dic['w_std_last'] = math.sqrt(2.0)

    # Report every field lacking a default instead of hiding it behind a bare
    # except.
    missing = [el for el in PARAMS if el not in dic]
    if missing:
        raise Exception('Some parameter does not have a default value: '
                        + ', '.join(missing))
    return ConfigObj(*[dic[el] for el in PARAMS])


def get_default_cnn_real_config():
    """Return the default config namedtuple for the 'cnn_real' model.

    Starts from get_default_params() and adds the CNN-specific fields
    ('XL', 'pooling', 'batch_norm', ...).

    Raises:
        Exception: if any namedtuple field has no default value; the message
            lists the missing field names.
    """
    PARAMS = PARAMS0 + [
        'dataset', 'model', 'depth', 'width', 'batch_size', 'train_steps',
        'meas_freq', 'save_freq', 'meas_bs', 'meas_samples', 'Hutch',
        'Hutch_tol', 'NTK_SAMPLES', 'meas_overlaps', 'measNTKspec',
        'measHigherOrders', 'XL', 'pooling', 'batch_norm',
    ]

    ConfigObj = namedtuple('config', PARAMS)
    dic = get_default_params()
    # Dataset. Current options: 'cifar10', 'mnist', 'fashion_mnist'.
    dic['dataset'] = 'cifar10'
    # Architecture. Current options: 'fc', 'cnn_flatten', 'cnn_real', 'wrn_original'.
    dic['model'] = 'cnn_real'
    # UNUSED: hidden layers.
    dic['depth'] = 0
    # FOR THIS MODEL: factor by which to multiply widths (see model def).
    dic['width'] = 100
    # Batch size for SGD.
    dic['batch_size'] = 256
    # Learning rate, in NTK parameterization.
    dic['learning_rate'] = 1.0
    # Freq of measuring and freq of saving (locally and cloud);
    # save_freq=-1 only saves at the end.
    dic['meas_freq'] = 20.0
    dic['save_freq'] = 100
    # Measure overlap and g eigenval.
    dic['meas_overlaps'] = False
    # Measurement batch size and samples for this (and also tr NTK without
    # Hutch). Previously 256 / 4096.
    dic['meas_bs'] = 64
    dic['meas_samples'] = 2048
    # Measure NTK tr and eigenval, and how many samples to use for the NTK
    # (previously 2048).
    dic['measNTKspec'] = True
    dic['NTK_SAMPLES'] = 512
    # Use Hutch for tr eval, and its tolerance.
    dic['Hutch'] = True
    dic['Hutch_tol'] = 0.005
    # Measure higher order operators (marked NOT IMPLEMENTED).
    # NOTE(review): enabled here but False in the fc/wrn configs -- confirm.
    dic['measHigherOrders'] = True
    dic['XL'] = False
    dic['pooling'] = 'max'
    dic['batch_norm'] = False

    # Report every field lacking a default instead of hiding it behind a bare
    # except.
    missing = [el for el in PARAMS if el not in dic]
    if missing:
        raise Exception('Some parameter does not have a default value: '
                        + ', '.join(missing))
    return ConfigObj(*[dic[el] for el in PARAMS])


def get_default_wrn_config():
    """Return the default config namedtuple for the Wide ResNet ('wrn_original') model.

    Starts from get_default_params(), adds the WRN-specific fields
    ('wrn_block_size', 'wrn_widening_f', ...), and replaces the job-name
    parameters so that block size and widening factor appear in the job path.

    Raises:
        Exception: if any namedtuple field has no default value; the message
            lists the missing field names.
    """
    PARAMS = PARAMS0 + [
        'dataset', 'model', 'depth', 'width', 'batch_size', 'train_steps',
        'meas_freq', 'save_freq', 'meas_bs', 'meas_samples', 'Hutch',
        'Hutch_tol', 'NTK_SAMPLES', 'meas_overlaps', 'measNTKspec',
        'measHigherOrders', 'batch_norm', 'wrn_block_size', 'wrn_widening_f',
    ]
    ConfigObj = namedtuple('config', PARAMS)
    dic = get_default_params()
    # Dataset. Current options: 'cifar10', 'mnist', 'fashion_mnist'.
    dic['dataset'] = 'cifar10'
    # Architecture. Current options: 'fc', 'cnn_flatten', 'cnn_gap', 'wrn_original'.
    dic['model'] = 'wrn_original'
    # UNUSED: hidden layers.
    dic['depth'] = 0
    # UNUSED: hidden units.
    dic['width'] = 0
    # Option to include BatchNorm. Currently implemented for WideResnet only.
    # Note kernel computation is invalid for such models.
    dic['batch_norm'] = False
    # Architectural feature for WRN: block size "N" in the paper.
    # Set to None otherwise.
    dic['wrn_block_size'] = 1
    # Architectural feature for WRN: widening factor "k" governing the channel
    # counts (16*k -> 32*k -> 64*k across groups). Set to None otherwise.
    dic['wrn_widening_f'] = 1
    # Batch size for SGD.
    dic['batch_size'] = 128
    # Freq of measuring and freq of saving (locally and cloud);
    # save_freq=-1 only saves at the end.
    dic['meas_freq'] = 20.0
    dic['save_freq'] = 100
    # Measure overlap and g eigenval.
    dic['meas_overlaps'] = False
    # Measurement batch size and samples for this (and also tr NTK without
    # Hutch). Previously 256 / 4096.
    dic['meas_bs'] = 32
    dic['meas_samples'] = 256
    # Measure NTK tr and eigenval, and how many samples to use for the NTK
    # (previously 2048).
    dic['measNTKspec'] = True
    dic['NTK_SAMPLES'] = 256
    # Use Hutch for tr eval, and its tolerance.
    dic['Hutch'] = True
    dic['Hutch_tol'] = 0.005
    dic['measHigherOrders'] = False
    # Parameters saved in the job name (change manually): job_dir goes into
    # the common dir, job_id into a folder within it. WRN swaps depth/width
    # for block size / widening factor.
    dic['job_dir_params'] = OrderedDict([
        ('job_id', ''), ('dataset', ''), ('model', ''), ('wrn_block_size', 'bl'),
        ('nonlinearity', ''), ('loss_type', ''), ('opt', ''), ('n_train', 'n'),
    ])
    dic['job_id_params'] = OrderedDict([
        ('wrn_widening_f', 'wf'), ('learning_rate', 'lr'), ('batch_size', 'bs'),
    ])

    # Report every field lacking a default instead of hiding it behind a bare
    # except.
    missing = [el for el in PARAMS if el not in dic]
    if missing:
        raise Exception('Some parameter does not have a default value: '
                        + ', '.join(missing))
    return ConfigObj(*[dic[el] for el in PARAMS])
