##########################################################################################
# Machine Environment Config
# Debug switch; USE_CUDA below is derived from it and it is echoed in the config log.
DEBUG_MODE = False
# CUDA is requested whenever debug mode is off.
USE_CUDA = not DEBUG_MODE
# Forwarded to the tester through tester_params['cuda_device_num'].
CUDA_DEVICE_NUM = 0

##########################################################################################
# Path Config
import os
import sys
import time
import torch
import numpy as np
import hvwfg

# Make this script's directory the working directory so that the relative paths
# used below (the sys.path entries and './result/test') do not depend on where
# the script is launched from.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, "..")  # for problem_def
sys.path.insert(0, "../..")  # for utils

##########################################################################################
# import
import logging
from utils.utils import create_logger, copy_all_src, get_result_folder

from MOTSPTester_efficient import TSPTester
from MOTSPProblemDef import get_random_problems, sparsify

##########################################################################################
# parameters
# Environment settings handed to TSPTester. train_test() runs with
# problem_size + 1 because the depot is treated as a regular node.
env_params = {
    'problem_size': 5,
    'pomo_size': 5,
}

# Passed to TSPTester and also stored in decoder_params['training_method'].
training_method = "Chb" # Either Linear or Chb

# Selects the instance generator; the MG_* values make train_test() use the
# edge_attr / edge_indices form of get_random_problems() and tester.run().
distribution = "MG_flex" # Either TMAT, XASY, MG_flex or MG_fix
emax = 2 # Only relevant for MG distributions

### Encoders ###

# Encoder hyperparameters (used as encoder_params further down).
# NOTE(review): these presumably have to match the checkpoint that is loaded
# via tester_params['model_load'] -- confirm against the training config.
MatNet_params = {
    # train_test() runs the tester with this count plus one (extra depot node).
    'one_hot_seed_cnt': 20,
    'embedding_dim': 128, 
    'encoder_layer_num': 6,
    # head_num * qkv_dim == embedding_dim (8 * 16 = 128).
    'qkv_dim': 16,
    'head_num': 8, 
    'ff_hidden_dim': 512,
    'ms_hidden_dim': 16,
    # sqrt(1/2) ~= 0.7071
    'ms_layer1_init': (1/2)**(1/2),
    # sqrt(1/16) = 0.25; the 16 presumably mirrors ms_hidden_dim -- TODO confirm.
    'ms_layer2_init': (1/16)**(1/2),
    'tw_row_emb': True
}

### Encoders end ###

### Decoders ###

# Decoder hyperparameters (used as decoder_params further down, where a
# 'training_method' entry is added to this same dict).
MP_params = {
    # Same embedding_dim / qkv_dim / head_num values as in MatNet_params.
    'embedding_dim': 128,
    'qkv_dim': 16, 
    'head_num': 8,
    'logit_clipping': 10,
    # NOTE(review): presumably selects greedy decoding at test time -- confirm in the model code.
    'eval_type': 'argmax',
}

### Decoders end ###

# Settings handed to TSPTester; some entries are also read directly by train_test().
tester_params = {
    'use_cuda': USE_CUDA,
    'cuda_device_num': CUDA_DEVICE_NUM,
    # Seed passed to get_random_problems() when the test instances are generated.
    'seed': 1000,
    'model_load': {
        'path': './result/test',  # directory path of pre-trained model and log files saved.
        'epoch': 10, 
    },
    # Reference point for the hypervolume; the HV ratio divides by reference[0] * reference[1].
    'reference': [15, 15],
    # Number of instances requested from get_random_problems().
    'test_episodes': 200, 
    'test_batch_size': 200,
    'augmentation_enable': False,
    'aug_factor': 8,
    'aug_batch_size': 25
}
# With augmentation enabled, the augmentation batch size replaces the regular test batch size.
if tester_params['augmentation_enable']:
    tester_params['test_batch_size'] = tester_params['aug_batch_size']

# Keyword arguments for create_logger() (expanded with ** in train_test()).
logger_params = {
    'log_file': {
        'desc': 'train_test_tsp',
        'filename': 'run_log'
    }
}

### Config end

# Select the encoder/decoder configuration in use. These are aliases, not
# copies: encoder_params is the same dict object as MatNet_params, and
# decoder_params is the same dict object as MP_params.
encoder_params = MatNet_params
decoder_params = MP_params
# Because of the aliasing above, this also adds 'training_method' to MP_params.
decoder_params["training_method"] = training_method

##########################################################################################
# main
def train_test():
    """Run the configured tester on generated instances and print time and HV ratio.

    Steps:
      1. Set up logging and log the configuration.
      2. Build a TSPTester from the module-level parameter dicts, with the
         problem size and the one-hot seed count each raised by one for the
         depot node.
      3. Generate ``tester_params['test_episodes']`` instances with
         ``get_random_problems`` using the configured seed.
      4. Evaluate them for 101 preference vectors ``(1 - 0.01*i, 0.01*i)``.
      5. Print the wall-clock time of ``tester.run`` and the hypervolume of the
         returned solutions (``hvwfg.wfg``) divided by the product of the two
         reference-point coordinates.

    The module-level configuration dicts are not modified, so calling this
    function more than once uses the same sizes every time.
    """
    create_logger(**logger_params)

    _print_config()

    # We treat depot as any other node, so the run needs one extra node and one
    # extra one-hot seed. Work on per-call copies: incrementing the shared
    # module-level dicts in place would grow the sizes again on every call.
    run_env_params = dict(env_params)
    run_env_params['problem_size'] += 1
    run_encoder_params = dict(encoder_params)
    run_encoder_params['one_hot_seed_cnt'] += 1

    tester = TSPTester(
                training_method=training_method,
                distribution=distribution,
                emax=emax,
                env_params=run_env_params,
                encoder_params=run_encoder_params,
                decoder_params=decoder_params,
                tester_params=tester_params)

    n_sols = 101

    # The MG_* distributions come back as edge attributes plus edge indices;
    # every other distribution comes back as a single `problems` object.
    is_multigraph = distribution in ("MG_fix", "MG_flex")

    n_episodes = tester_params['test_episodes']
    n_nodes = run_env_params['problem_size']
    seed = tester_params['seed']

    if is_multigraph:
        edge_attr, edge_indices, service_times, tw_start, tw_end = get_random_problems(
            distribution, n_episodes, n_nodes, emax=emax, set_seed=True, seed=seed)
    else:
        problems, service_times, tw_start, tw_end = get_random_problems(
            distribution, n_episodes, n_nodes, set_seed=True, seed=seed)

    # Preference vectors sweep from (1, 0) to (0, 1) in steps of 0.01;
    # the two weights of each row sum to one.
    prefs = torch.zeros((n_sols, 2))
    for i in range(n_sols):
        weight = 0.01 * i
        prefs[i, 0] = 1 - weight
        prefs[i, 1] = weight

    # Only the tester.run call is timed; instance generation is excluded.
    timer_start = time.time()

    if is_multigraph:
        sols = tester.run(edge_attr, service_times, tw_start, tw_end, prefs, edge_indices, print_results=False)
    else:
        sols = tester.run(problems, service_times, tw_start, tw_end, prefs, print_results=False)

    total_time = time.time() - timer_start

    # NOTE(review): `sols` is assumed to be a NumPy array of objective vectors
    # (it must support .astype) -- confirm against TSPTester.run.
    ref = np.asarray(tester_params['reference'])
    hv = hvwfg.wfg(sols.astype(float), ref.astype(float))
    # Normalise by the area of the box spanned by the reference point.
    hv_ratio = hv / (ref[0] * ref[1])

    print('Run Time(s): {:.4f}'.format(total_time))
    print('HV Ratio: {:.4f}'.format(hv_ratio))

def _print_config():
    """Log the run configuration to the 'root' logger at INFO level.

    Emits the machine settings, the distribution, emax and the training
    method first, followed by one ``key: value`` line for every entry of
    ``encoder_params`` and then of ``decoder_params``.
    """
    log = logging.getLogger('root')

    header_lines = (
        'DEBUG_MODE: {}'.format(DEBUG_MODE),
        'USE_CUDA: {}, CUDA_DEVICE_NUM: {}'.format(USE_CUDA, CUDA_DEVICE_NUM),
        'Distribution: {}'.format(distribution),
        'Emax: {}'.format(emax),
        'Training Method: {}'.format(training_method),
    )
    for line in header_lines:
        log.info(line)

    # Encoder entries are logged before decoder entries, each in dict order.
    for params in (encoder_params, decoder_params):
        for name, value in params.items():
            log.info(name + ": {}".format(value))

##########################################################################################

# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    train_test()
