import numpy as np
import argparse


def set_cfgs(args=None):
    """Build the experiment configuration dictionary.

    Args:
        args: Optional object holding already-parsed options (for example an
            ``argparse.Namespace``). It must expose every attribute read
            below (``seed``, ``gpu_id``, ``device``, ``image_shape``,
            ``duration``, ``model_type``, ...). When ``None`` (the default)
            the options are obtained from ``process_args()``, i.e. from the
            command line, so existing callers are unaffected.

    Returns:
        dict: The option values together with fixed constants and values
        derived from them (number of time steps, period boundaries, input
        and output sizes, and a seeded ``np.random.RandomState``).
    """
    if args is None:
        args = process_args()

    # Values that several entries below are derived from.
    dt = 1
    image_shape = [args.image_shape]
    num_pixels = np.prod(image_shape)
    total_steps = int(args.duration / dt)
    stim_end = int(500 / dt)
    decision_start = int(1500 / dt)

    config = dict()
    config['seed'] = args.seed
    config['gpu_id'] = args.gpu_id
    config['device'] = args.device
    # Random generator seeded from the configured seed.
    config['rng'] = np.random.RandomState(config['seed'])

    # data cfgs
    config['batch_size'] = 1
    config['dt'] = dt
    config['image_shape'] = image_shape
    config['duration'] = args.duration
    config['tdim'] = total_steps
    # Each period is a [start, end] pair expressed in units of dt.
    config['stimPeriod'] = np.array([0, stim_end])
    config['fixationPeriod'] = np.array([0, decision_start])
    config['decisionPeriod'] = np.array([decision_start, total_steps])
    config['fixationInput'] = 1.0 / np.sqrt(num_pixels)
    config['data_type'] = 2

    # model cfgs
    config['model_type'] = args.model_type
    config['neuron_thr'] = args.neuron_thr
    # Input and output sizes are both the image element count plus one.
    config['num_input'] = num_pixels + 1
    config['num_rnn'] = args.num_rnn
    config['num_rnn_out'] = num_pixels + 1
    config['num_branch'] = args.num_branch
    config['tau_minitializer'] = 'uniform'
    config['low_m'] = 0
    config['high_m'] = 4
    config['tau_ninitializer'] = 'uniform'
    config['low_n'] = 0
    config['high_n'] = 4

    # training cfgs
    config['learning_rate'] = args.learning_rate
    config['beta1'] = args.beta1
    config['beta2'] = args.beta2
    config['weight_decay'] = 0

    # loss cfgs
    config['l2_h'] = args.l2_h
    config['l2_wR'] = args.l2_wR
    config['l2_wI'] = args.l2_wI
    config['l2_wO'] = args.l2_wO

    # criterion cfgs
    config['thr'] = args.thr
    config['problems'] = 1000
    # 'iters' and 'max_iters' both carry the --max_iters value.
    config['iters'] = args.max_iters
    config['max_iters'] = args.max_iters

    # other cfgs
    config['out_root'] = args.out_root
    config['damage_train_root'] = args.damage_train_root
    config['neuron_role_path'] = args.neuron_role_path
    return config


def process_args(argv=None):
    """Parse the experiment options.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace: One attribute per option defined below.

    Raises:
        SystemExit: Raised by argparse when an argument is unknown, has the
            wrong type, or is outside its ``choices``.
    """
    parser = argparse.ArgumentParser(description='custom your configs')

    # other arguments
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'], help='device to use')
    parser.add_argument('--gpu_id', type=str, default='0', choices=['0', '1', '2', '3'], help='device id to use')
    parser.add_argument('--seed', type=int, default=42, help='random seed')

    # for datasets
    parser.add_argument('--image_shape', type=int, default=4, help='image size')
    parser.add_argument('--duration', type=int, default=2000, help='duration of input')

    # for model
    model_choices = [
        'custom-rnn', 'custom-snn-homo',
        'custom-snn-heter-states', 'custom-snn-heter-tau', 'custom-snn-heter-thr',
        'custom-snn-heter-states_tau', 'custom-snn-heter-states_thr', 'custom-snn-heter-tau_thr',
        'custom-snn-heter', 'custom-snn-optim',
    ]
    parser.add_argument('--model_type', type=str, default='custom-snn-heter',
                        choices=model_choices, help='model to use')
    parser.add_argument('--neuron_thr', type=float, default=0.25, help='neuron threshold')
    parser.add_argument('--num_rnn', type=int, default=256, help='number of neurons')
    parser.add_argument('--num_branch', type=int, default=2, help='number of branch')

    # for snn loss
    # NOTE: the alternative "rnn loss" defaults previously kept here as
    # commented-out code differ only in l2_wO (0.1 instead of 0.00001).
    parser.add_argument('--l2_h', type=float, default=0.0005, help='l2_h for loss')
    parser.add_argument('--l2_wR', type=float, default=0.001, help='l2_wR for loss')
    parser.add_argument('--l2_wI', type=float, default=0.0001, help='l2_wI for loss')
    parser.add_argument('--l2_wO', type=float, default=0.00001, help='l2_wO for loss')

    # for snn optim
    # NOTE: the alternative "rnn optim" defaults previously kept here as
    # commented-out code were learning_rate=0.0001, beta1=0.3, beta2=0.999.
    parser.add_argument('--learning_rate', type=float, default=0.01, help='learning rate for optimizer')
    parser.add_argument('--beta1', type=float, default=0.1, help='beta1 for optim')
    parser.add_argument('--beta2', type=float, default=0.3, help='beta2 for optim')

    # for criterion
    parser.add_argument('--thr', type=float, default=0.006, help='threshold for criterion')
    parser.add_argument('--max_iters', type=int, default=5000, help='maximum number of iterations')

    # for checkpoint
    parser.add_argument('--out_root', type=str, default='./results/results_snn_256_optim', help='output directory')
    parser.add_argument('--damage_train_root', type=str, default='../damage_schema')
    parser.add_argument('--neuron_role_path', type=str, default='../neuron_role/task-3_model-custom-snn-dr.json')

    # argv=None falls back to sys.argv[1:] inside argparse.
    return parser.parse_args(argv)
