##########################################################################################
# Machine Environment Config
# Debug runs are CPU-only; full runs train on the CUDA device selected below.
DEBUG_MODE = False
CUDA_DEVICE_NUM = 0
USE_CUDA = not DEBUG_MODE

##########################################################################################
# Path Config
import os
import sys

# Switch to this script's own directory first: relative sys.path entries are
# resolved against the current working directory, so the ".." / "../.."
# entries below only point at the intended folders after this chdir.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, "..")  # for problem_def
sys.path.insert(0, "../..")  # for utils (inserted last, so searched before "..")

##########################################################################################
# import
import logging
from utils.utils import create_logger, copy_all_src

from MOCVRPTrainer import CVRPTrainer

##########################################################################################
# parameters
# Instance sizes. 'problem_size' excludes the depot here; main() adds 1 to it
# because the depot is treated as a regular node.
env_params = {
    'problem_size': 5,
    'pomo_size': 5,
}

architecture = "MBM" # One of "GMS-DH", "GMS-EB" or "MBM"
# GMS-DH: tune GREAT_params and dh_params for the encoder, MP_params for the decoder
# GMS-EB: tune GREAT_params for the encoder, MP_params for the decoder
# MBM: tune MatNet_params for the encoder, MP_params for the decoder

training_method = "Chb" # Either "Linear" or "Chb" (also copied into decoder_params below)
curriculum_learning = True  # when True, curriculum_function is passed to the trainer

distribution = "TMAT" # One of "EUC", "TMAT" or "XASY"

### Encoders ###

# Encoder settings used by the GREAT-based architectures (GMS-EB directly,
# GMS-DH via dh_params['edge_attention_params']).
GREAT_params = {
    'embedding_dim': 128,
    'encoder_layer_num': 6,
    'qkv_dim': 16,
    'head_num': 8, 
    'ff_hidden_dim': 512,
    "great_nodeless": False,
    "great_asymmetric": True,
    "depot_embedder": False, 
    "dropout": 0.1, 
}

# Encoder settings for the MBM architecture (MatNet encoder). The derivation
# code further down adds 1 to 'one_hot_seed_cnt' before handing these on.
MatNet_params = {
    'one_hot_seed_cnt': 20, # Only relevant if demand_row_emb is True
    'embedding_dim': 128, 
    'encoder_layer_num': 6,
    'qkv_dim': 16,
    'head_num': 8, 
    'ff_hidden_dim': 512,
    'ms_hidden_dim': 16,
    'ms_layer1_init': (1/2)**(1/2),
    'ms_layer2_init': (1/16)**(1/2),
    "depot_embedder": False, 
    'demand_row_emb': False, # Whether to put demand embeddings in rows or cols
}

# Layer counts for the hybrid (GMS-DH) encoder.
dh_params = {
    'L1': 5, # GNN layer number, overrides GREAT_params['encoder_layer_num']
    'L2': 2 # Transformer layer number
}

### Encoders end ###

### Decoders ###

# Decoder settings; 'training_method' is added to the decoder params below.
MP_params = {
    'embedding_dim': 128,
    'qkv_dim': 16, 
    'head_num': 8,
    'logit_clipping': 10,
    'eval_type': 'argmax',
    'load_query': True
}

### Decoders end ###

# NOTE(review): the scheduler milestone (180) is larger than
# trainer_params['epochs'] (10), so presumably the LR decay never triggers in
# this configuration — confirm this is intended.
optimizer_params = {
    'optimizer': {
        'lr': 1e-4, 
        'weight_decay': 1e-6
    },
    'scheduler': {
        'milestones': [180,],
        'gamma': 0.1
    }
}

# Training-loop settings. 'epochs', 'train_episodes' and 'train_batch_size'
# are overridden by _set_debug_mode() when DEBUG_MODE is True.
trainer_params = {
    'use_cuda': USE_CUDA,
    'cuda_device_num': CUDA_DEVICE_NUM,
    'pref_dist': 'unif-2', 
    'epochs': 10,
    'train_episodes': 1000,
    'train_batch_size': 64,
    'logging': {
        'model_save_interval': 5,
        'img_save_interval': 10,
        'log_image_params_1': {
            'json_foldername': 'log_image_style',
            'filename': 'style_cvrp_20.json'
        },
        'log_image_params_2': {
            'json_foldername': 'log_image_style',
            'filename': 'style_loss_1.json'
        },
    },
    'model_load': {
        'enable': False,  # enable loading a pre-trained model
        'path': './Final_result/edge_50',  # directory containing the pre-trained model and its saved log files
        'epoch': 20,  # epoch version of the pre-trained model to load
    }
}

# Passed to create_logger() in main().
logger_params = {
    'log_file': {
        'desc': 'train__cvrp',
        'filename': 'run_log'
    }
}

def curriculum_function(epoch, *, batch_size=None):
    """Return the curriculum schedule for *epoch*.

    Epochs up to and including 5 use 5 nodes; later epochs use 10. One extra
    node is then added for the depot, which is treated like any other node.

    Args:
        epoch: Current training epoch.
        batch_size: Training batch size. Defaults to
            ``trainer_params['train_batch_size']``, so the single-argument
            call ``curriculum_function(epoch)`` behaves as before.

    Returns:
        Tuple ``(problem_size, pomo_size, batch_size, fwd_batch_size)``.
        ``problem_size`` includes the depot, ``pomo_size`` is
        ``problem_size - 1``, and ``fwd_batch_size`` never exceeds
        ``batch_size``.
    """
    if batch_size is None:
        batch_size = trainer_params['train_batch_size']

    if epoch <= 5:
        problem_size = 5
        # Smaller forward batch for the early, small instances. Clamp it so it
        # never exceeds the batch itself (debug mode shrinks the batch to 10).
        fwd_batch_size = min(32, batch_size)
    else:
        problem_size = 10
        fwd_batch_size = batch_size

    problem_size += 1  # the depot counts as a node
    pomo_size = problem_size - 1

    logger = logging.getLogger('root')
    logger.info(f"Training on problem size: {problem_size}")

    return problem_size, pomo_size, batch_size, fwd_batch_size

### Config end
# Map each supported architecture to its (encoder, decoder) pair. An unknown
# name is a configuration error and fails loudly instead of silently training
# a different model.
_ARCHITECTURES = {
    "GMS-DH": ("hybrid", "MP"),
    "GMS-EB": ("GREAT-E", "MP-E"),
    "MBM": ("MatNet", "MP"),
}
try:
    encoder, decoder = _ARCHITECTURES[architecture]
except KeyError:
    raise ValueError(
        f"Unknown architecture {architecture!r}; "
        f"expected one of {sorted(_ARCHITECTURES)}"
    ) from None

# Work on copies so the base *_params dicts declared above keep their
# declared values (no in-place increments or extra keys).
if encoder in ("GREAT", "GREAT-E"):
    encoder_params = dict(GREAT_params)
elif encoder == "MatNet":
    encoder_params = dict(MatNet_params)
    # NOTE(review): the reason for the +1 is not visible here — presumably it
    # reserves a seed for the depot (treated as an extra node); confirm.
    encoder_params["one_hot_seed_cnt"] += 1
elif encoder == "hybrid":
    encoder_params = dict(dh_params)
    encoder_params['edge_attention_type'] = "GREAT"
    encoder_params['edge_attention_params'] = GREAT_params

decoder_params = dict(MP_params)
decoder_params["training_method"] = training_method
##########################################################################################
# main
def main():
    """Set up logging, build the CVRP trainer and run training.

    Applies the debug overrides when DEBUG_MODE is set, logs the active
    configuration, and counts the depot as an extra node in ``env_params``.
    It then calls ``copy_all_src`` with the trainer's result folder and
    starts training.
    """
    if DEBUG_MODE:
        _set_debug_mode()

    create_logger(**logger_params)
    _print_config()

    # The depot is treated as a regular node, so it adds one to the node count.
    env_params['problem_size'] += 1

    trainer_kwargs = dict(
        encoder=encoder,
        decoder=decoder,
        training_method=training_method,
        curriculum_learning=curriculum_learning,
        curriculum_function=curriculum_function,
        distribution=distribution,
        env_params=env_params,
        encoder_params=encoder_params,
        decoder_params=decoder_params,
        optimizer_params=optimizer_params,
        trainer_params=trainer_params,
    )
    trainer = CVRPTrainer(**trainer_kwargs)

    copy_all_src(trainer.result_folder)
    trainer.run()


def _set_debug_mode():
    """Shrink the run for a quick debug pass: 1 epoch, 10 episodes, batches of 10."""
    # trainer_params is only mutated in place, so no ``global`` is required.
    trainer_params.update(
        epochs=1,
        train_episodes=10,
        train_batch_size=10,
    )


def _print_config():
    """Log the run configuration via the 'root' logger.

    Logs the machine settings, the chosen encoder/decoder, the distribution
    and the training method. It then logs every encoder and decoder parameter
    as a ``key: value`` line.
    """
    logger = logging.getLogger('root')
    logger.info('DEBUG_MODE: {}'.format(DEBUG_MODE))
    logger.info('USE_CUDA: {}, CUDA_DEVICE_NUM: {}'.format(USE_CUDA, CUDA_DEVICE_NUM))
    logger.info('Encoder: {}'.format(encoder))
    logger.info('Decoder: {}'.format(decoder))
    logger.info('Distribution: {}'.format(distribution))
    logger.info('Training Method: {}'.format(training_method))
    # Plain loops: these calls are made only for their logging side effect,
    # so no list needs to be built.
    for key, value in encoder_params.items():
        logger.info(f"{key}: {value}")
    for key, value in decoder_params.items():
        logger.info(f"{key}: {value}")

##########################################################################################

if __name__ == "__main__":
    # Start training only when executed as a script; importing this module
    # does not call main().
    main()
