from config import cfg
from .stats import make_stats


def process_control():
    cfg['data_name'] = cfg['control']['data_name']
    cfg['model_name'] = cfg['control']['model_name']
    cfg['num_groups'] = cfg['control'].get('num_groups', None)


    cfg['batch_size'] = 400
    cfg['step_period'] = 1
    cfg['num_steps'] = 10
    cfg['eval_period'] = 1
    cfg['num_epochs'] = 100
    cfg['collate_mode'] = 'dict'

    cfg['model'] = {}
    cfg['model']['model_name'] = cfg['model_name']
    data_shape = {'SimulateR': [1], 'SimulateC': [11], 'Adult': [89], 'SimulateCM':[11], 'AdultM': [89] }
    target_size = {'SimulateR': 1, 'SimulateC': 2, 'Adult': 2, 'SimulateCM': 2, 'AdultM': 2}
    cfg['model']['data_shape'] = data_shape[cfg['data_name']]
    cfg['model']['target_size'] = target_size[cfg['data_name']]
    cfg['model']['linear'] = {}
    cfg['model']['stats'] = make_stats(cfg['data_name'], cfg.get('init_seed', 0))
    cfg['model']['debias'] = cfg['control']['debias']
    cfg['model']['hparams'] = cfg['control']['hparams']


    tag = cfg['tag']
    cfg[tag] = {}
    cfg[tag]['optimizer'] = {}
    cfg[tag]['optimizer']['optimizer_name'] = 'SGD'
    cfg[tag]['optimizer']['lr'] = 0.01 
    cfg[tag]['optimizer']['momentum'] = 0.9
    cfg[tag]['optimizer']['betas'] = (0.9, 0.999)
    cfg[tag]['optimizer']['weight_decay'] = 0
    cfg[tag]['optimizer']['nesterov'] = True
    cfg[tag]['optimizer']['batch_size'] = {'train': cfg['batch_size'], 'test': cfg['batch_size']}
    cfg[tag]['optimizer']['step_period'] = cfg['step_period']
    cfg[tag]['optimizer']['num_steps'] = cfg['num_steps']
    cfg[tag]['optimizer']['scheduler_name'] = 'CosineAnnealingLR'
    return
