'''
    This file contains the default parameters.
    Parameters provided on the command line overwrite some of them.
'''
# CUDA settings
import torch

# Sets flags for determinism when using CUDA (potentially slow!).
cuda_deterministic = False
# Set to True to disable CUDA training.
no_cuda = False
# CUDA is used only when it is not disabled and torch reports it as available.
cuda = (not no_cuda) and torch.cuda.is_available()
# Which device to use (the original note lists 0 as the default).
device = 3

# important model parameters

# Algorithm to use: a2c | ppo | acktr.
algo = 'acktr'
# Discount factor for rewards (default: 1).
gamma = 1
# Number of training CPU processes (default: 16).
num_processes = 16
# Entropy term coefficient (default: 0.01).
entropy_coef = 0.01
# Invalid action possibility term coefficient (default: 0.1).
invalid_coef = 0.01
# Value loss coefficient (default: 0.5).
value_loss_coef = 0.5
# Presumably the weight of an environment loss term -- TODO confirm.
env_loss_coef = 0.001
# Hidden layer cell number (default: 256).
hidden_size = 256
# Pruning threshold (default: 0.5).
pruning_threshold = 0.5
# Number of known items (default: 1); 1 means 'not lookahead'.
preview = 1

# default environment parameters
env_name = 'RealBpp-v0'  # environment to train on
container_size = (12, 10, 10)  # the size of the bin (container)
# the item size range (x_min, y_min, z_min, x_max, y_max, z_max)
box_range = (2, 2, 2, 5, 5, 5)
pallet_size = container_size[0]

# Every (x, y, z) item size inside box_range, with both bounds inclusive.
# The first component varies slowest and the last one fastest.
# A comprehension is used so that the loop variables do not leak into this
# module's namespace (and therefore into any `from <module> import *`).
box_size_set = [
    (size_x, size_y, size_z)
    for size_x in range(box_range[0], box_range[3] + 1)
    for size_y in range(box_range[1], box_range[4] + 1)
    for size_z in range(box_range[2], box_range[5] + 1)
]
# other parameters of our environment

# Adjust agents' actions to touch corners or sides.
adjust = False
# Adjust distance.
adjust_ratio = 0
# Original note reads "the cnn vec" (presumably: cnn or vec -- TODO confirm).
input_format = 'cnn'
# Original notes: "7 or 3", and "channels of CNN: 4 for hmap+next box,
# 5 for hmap next box+truemask".
# NOTE(review): the two notes disagree on the valid values -- confirm.
channel = 7
# Item sequence generators: depen|sample|md.
data_type = 'depen'
# Whether the agent can give up; for now this can only be False.
give_up = False
# Whether the agent can rotate a box.
enable_rotation = 1

# saving and loading settings

# Number of sequences used for test (default 100).
cases = 49
# Whether to save the training model.
save_model = True
# Save interval, one save per n updates (default: 100).
save_interval = 1000
# Eval interval, one eval per n updates (default: None).
# NOTE(review): log_interval is assigned again further down in this file
# (to 4), which overrides this value; the note above describes an eval
# interval -- confirm which name was intended.
log_interval = 100
# Directory to save agent logs (default: ./saved_models/).
save_dir = './saved_models/'
# Directory to load agent logs (default: ./pretrained_models/).
load_dir = './pretrained_models/'
# Default trained models for testing or continuing training.
load_name = 'RealBpp-v02022.09.29-12-30.pt'
load_name_sub = 'RealBpp-v02022.09.29-12-30_s.pt'
# Directory storing datasets.
data_dir = './dataset/'
# Name of the dataset; check 'data_dir' for details.
data_name = 'cut_2_reorder.pt'
# Whether to use tensorboard to trace the training process.
tensorboard = True
# Directory to save tensorboard logs (default: ./runs).
tbx_dir = './runs'
# Directory to save agent logs (default: ./log).
log_dir = './log'
# Directory to save pictures (default: None).
image_folder = None

# Load whole model.
pretrain = False
# load_model = True # load model parameters only

# other parameters

# Use a recurrent policy.
recurrent_policy = False
# Use a linear schedule on the learning rate.
use_linear_lr_decay = True
# Compute returns taking into account time limits.
use_proper_time_limits = False
# Learning rate (default: 7e-4).
lr = 1e-4
# RMSprop optimizer epsilon (default: 1e-5).
eps = 1e-5
# RMSprop optimizer alpha (default: 0.99).
alpha = 0.99
# Max norm of gradients (default: 0.5).
max_grad_norm = 0.5
# Random seed (default: 1).
seed = 7

boxlist_len = (70, 70)
# Log interval, one log per n updates (default: 10).
log_interval = 4
# Number of forward steps in A2C (default: 5).
num_steps = 5
# Number of environment steps to train (default: 10e6).
num_env_steps = 1e7

# Number of batches for ppo (default: 32).
num_mini_batch = 32
# ppo clip parameter (default: 0.2).
clip_param = 0.2

# Environment settings, grouped in one dictionary.
env_params = dict(
    node_cnt=21,
    problem_gen_params=dict(
        int_min=0,
        int_max=1000 * 1000,
        scaler=1000 * 1000,
    ),
    # Original note: "same as node_cnt".
    # NOTE(review): node_cnt is 21 above while this is 1 -- confirm.
    pomo_size=1,
)

# Model settings, grouped in one dictionary.
model_params = dict(
    input_box_dim=3,
    embedding_dim=64,
    # Square root of embedding_dim (64).
    sqrt_embedding_dim=64 ** (1 / 2),
    encoder_layer_num=3,
    decoder_layer_num=5,
    qkv_dim=16,
    # Square root of qkv_dim (16).
    sqrt_qkv_dim=16 ** (1 / 2),
    head_num=8,
    logit_clipping=10,
    ff_hidden_dim=512,
    ms_hidden_dim=16,
    ms_layer1_init=(1 / 2) ** (1 / 2),
    ms_layer2_init=(1 / 16) ** (1 / 2),
    eval_type='argmax',
    # Original note: "must be >= node_cnt".
    # NOTE(review): env_params in this file sets node_cnt to 21, which is
    # larger than 20 -- confirm.
    one_hot_seed_cnt=20,
)

# Optimizer and learning-rate scheduler settings.
optimizer_params = dict(
    optimizer=dict(
        lr=1e-5,
        weight_decay=1e-6,
    ),
    scheduler=dict(
        # if further training is needed
        milestones=[2001, 2101],
        gamma=0.1,
    ),
)

# Trainer settings, grouped in one dictionary.
trainer_params = dict(
    use_cuda=True,
    cuda_device_num=0,
    epochs=2000,
    train_episodes=10 * 1000,
    train_batch_size=200,
    logging=dict(
        model_save_interval=100,
        img_save_interval=200,
        log_image_params_1=dict(
            json_foldername='log_image_style',
            filename='style.json',
        ),
        log_image_params_2=dict(
            json_foldername='log_image_style',
            filename='style_loss.json',
        ),
    ),
    model_load=dict(
        # Enable loading a pre-trained model.
        enable=False,
        # Directory path where the pre-trained model and log files are saved.
        path='./saved_models/1421',
        # Epoch version of the pre-trained model to load.
        name='checkpoint-2022.01.14-17-23',
    ),
)

# Tester settings, grouped in one dictionary.
tester_params = dict(
    use_cuda=True,
    cuda_device_num=0,
    model_name='checkpoint-2022.01.14-12-53.pt',
    model_path='./pretrained_models',
    model_load=dict(
        # Directory path where the pre-trained model and log files are saved.
        path='./pretrained_models',
        # Epoch version of the pre-trained model to load.
        epoch=5000,
    ),
    file_count=10 * 1000,
    test_batch_size=1000,
    augmentation_enable=True,
    aug_factor=128,
    aug_batch_size=100,
)


# Logger settings: description tag and file name of the log file.
logger_params = dict(
    log_file=dict(
        desc='matnet_train',
        filename='log.txt',
    ),
)
