from numpy import number
from config import cfg
from typing import List
import numpy as np
import collections

def _control_option(control, key, default, cast=None):
    """Return ``control[key]``, or *default* when *key* is absent.

    When the key is present and *cast* is given, the stored value is passed
    through *cast* (e.g. ``float``). *default* is returned untouched, so an
    integer default such as ``0`` stays an ``int``.
    """
    if key not in control:
        return default
    value = control[key]
    return value if cast is None else cast(value)


def _trailing_int(text):
    """Parse the run of ASCII digits at the end of *text* as an ``int``.

    Raises:
        ValueError: if *text* does not end in a digit.
    """
    prefix = text.rstrip('0123456789')
    return int(text[len(prefix):])


def process_command() -> None:
    """Expand ``cfg['control']`` into the full set of run-time settings in ``cfg``.

    The function mutates the module-level ``cfg`` mapping in place: it copies
    the mandatory control entries, fills in fixed hyper-parameters, and for
    every optional control entry either converts the supplied value or falls
    back to a default. The resulting ``cfg`` is printed at the end.

    Required entries of ``cfg['control']``: ``algo_mode``, ``data_name``,
    ``model_name``, ``num_clients`` and ``active_rate``.

    Raises:
        ValueError: if ``num_clients`` is missing from ``cfg['control']``
            (raised after the client-independent settings were written), or
            if a numeric control value cannot be converted.
        KeyError: if another required control entry is missing or
            ``data_name`` is not one of the known datasets (assuming ``cfg``
            behaves like a regular ``dict``).
    """
    cfg['save_interval'] = 30
    control = cfg['control']

    cfg['algo_mode'] = control['algo_mode']
    cfg['data_name'] = control['data_name']
    cfg['model_name'] = control['model_name']

    data_shape = {
        'CINIC10': [3, 32, 32],
        'CIFAR10': [3, 32, 32],
        'CIFAR100': [3, 32, 32],
        'FashionMNIST': [28, 28],
        'MNIST': [28, 28],
        'SVHN': [3, 32, 32],
    }
    cfg['data_shape'] = data_shape[cfg['data_name']]

    # Per-architecture hyper-parameters, keyed by model name.
    cfg['conv'] = {'hidden_size': [32, 64]}
    cfg['resnet9'] = {'hidden_size': [64, 128, 256, 512]}
    cfg['resnet18'] = {'hidden_size': [64, 128, 256, 512]}
    cfg['wresnet28x2'] = {'depth': 28, 'widen_factor': 2, 'drop_rate': 0.0}
    cfg['wresnet28x8'] = {'depth': 28, 'widen_factor': 8, 'drop_rate': 0.0}

    cfg['threshold'] = 0.95
    cfg['alpha'] = 0.75
    cfg['max_clip_norm'] = 10

    cfg['lambda'] = _control_option(control, 'lambda', 0.1, float)
    # NOTE(review): 999999 looks like a sentinel for an "infinite" lambda
    # (judging by the key name) -- confirm with the consumers of this key.
    if cfg['lambda'] == 999999:
        cfg['lambda_for_infinite_situation'] = 0.99

    # TODO tune
    cfg['objective_func_sigmoid_s'] = _control_option(
        control, 'objective_func_sigmoid_s', 0.005, float)
    cfg['malicious_way'] = _control_option(control, 'malicious_way', 'no')
    cfg['malicious_clients_ratio'] = _control_option(
        control, 'malicious_clients_ratio', 0, float)
    cfg['malicious_data_in_clients_ratio'] = _control_option(
        control, 'malicious_data_in_clients_ratio', 0, float)

    # Everything below requires a client count.
    if 'num_clients' not in control:
        raise ValueError('no num_clients')

    cfg['pricing_threshold_1'] = 0.3
    cfg['pricing_threshold_2'] = 0.7
    cfg['pricing_0'] = 0.9
    cfg['pricing_1'] = 0.1
    cfg['pricing_2'] = 0.1

    # The stored interval includes round 0, hence the +1 on a supplied value.
    if 'pricing_interval' in control:
        cfg['pricing_interval'] = int(control['pricing_interval']) + 1
    else:
        cfg['pricing_interval'] = 3

    cfg['num_clients'] = int(control['num_clients'])
    cfg['active_rate'] = float(control['active_rate'])
    cfg['data_split_mode'] = _control_option(control, 'data_split_mode', 'iid-equal')

    cfg['local_epoch'] = 5
    cfg['gm'] = 0
    cfg['server'] = {
        'shuffle': {'train': True, 'test': False},
        'batch_size': {'train': 250, 'test': 500},
    }
    cfg['client'] = {
        'shuffle': {'train': True, 'test': False},
        'batch_size': {'train': 10, 'test': 500},
    }
    cfg['theta1_window_size'] = 10
    cfg['train_sample_portion'] = _control_option(
        control, 'train_sample_portion', 1, float)

    cfg['local_training_type'] = _control_option(
        control, 'local_training_type', 'epoch-5')
    if 'epoch' in cfg['local_training_type']:
        # Parse the whole trailing number: reading only the last character
        # would turn 'epoch-10' into 0 local epochs.
        cfg['local_epoch'] = _trailing_int(cfg['local_training_type'])

    cfg['update_server_model'] = _control_option(
        control, 'update_server_model', 'average')

    # Client-side (local) optimizer.
    client = cfg['client']
    client['optimizer_name'] = 'SGD'
    client['lr'] = _control_option(control, 'learning_rate', 3e-2, float)
    client['momentum'] = 0.9
    client['weight_decay'] = 5e-4
    client['nesterov'] = True
    client['maximize'] = False
    client['num_epochs'] = cfg['local_epoch']

    # Server-side (global) optimizer; its momentum is the 'gm' setting.
    server = cfg['server']
    server['num_epochs'] = 100
    server['optimizer_name'] = 'SGD'
    server['lr'] = 1
    server['momentum'] = cfg['gm']
    server['weight_decay'] = 0
    server['nesterov'] = False
    server['scheduler_name'] = 'CosineAnnealingLR'
    server['maximize'] = False

    # Pricing function parameters.
    cfg['pricing_func'] = _control_option(control, 'pricing_func', 'lin')
    cfg['theta_0'] = _control_option(control, 'theta_0', -0.1, float)
    cfg['theta_1'] = _control_option(control, 'theta_1', 0.3, float)
    cfg['pricing_func_s'] = _control_option(control, 'pricing_func_s', 1, float)
    cfg['pricing_func_sigmoid_s'] = _control_option(
        control, 'pricing_func_sigmoid_s', 0.005, float)

    # Optimizer settings for the 'fedincen_objective' entry; it runs for the
    # same number of epochs as the server.
    cfg['fedincen_objective'] = {
        'num_epochs': server['num_epochs'],
        'optimizer_name': 'SGD',
        'lr': _control_option(control, 'objective_func_lr', 0.0001, float),
        'momentum': _control_option(control, 'objective_func_momentum', 0.5, float),
        'weight_decay': _control_option(
            control, 'objective_func_weight_decay', 5e-4, float),
        'nesterov': True,
        'scheduler_name': 'CosineAnnealingLR',
        'maximize': False,
    }

    # The regularizer is enabled only by an explicit control value of 1.
    cfg['regularizer'] = 'regularizer' in control and int(control['regularizer']) == 1

    cfg['pricing_plan'] = _control_option(control, 'pricing_plan', '1')
    # Derive the coefficient from the effective plan so that the default plan
    # also gets one; an unrecognised plan leaves the key unset.
    plan_coefficient = {'1': 11, '2': 101, '3': 2001}.get(cfg['pricing_plan'])
    if plan_coefficient is not None:
        cfg['pricing_plan_coefficient'] = plan_coefficient

    cfg['minimum_participation_clients'] = int(
        np.ceil(cfg['active_rate'] * cfg['num_clients']))

    print(f'cfg: {cfg}')
