##########################################################################################
# Machine Environment Config
# Debug switch: when True, main() calls _set_debug_mode() and CUDA is disabled below.
DEBUG_MODE = False
# Use the GPU unless debugging; forwarded to the tester as tester_params['use_cuda'].
USE_CUDA = not DEBUG_MODE
# CUDA device index; forwarded to the tester as tester_params['cuda_device_num'].
CUDA_DEVICE_NUM = 0

##########################################################################################
# Path Config
import os
import sys
import torch
import numpy as np

# Switch to this file's directory so that the relative entries added to
# sys.path below (and relative paths such as './result/Test') resolve the
# same way regardless of where the script is launched from.
# This must run before the project-local imports that follow.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, "..")  # for problem_def
sys.path.insert(0, "../..")  # for utils

##########################################################################################
# import
import logging
from utils.utils import create_logger, copy_all_src


from MOTSPTester_efficient import TSPTester
from MOTSPProblemDef import get_random_problems

##########################################################################################
import time
import hvwfg

##########################################################################################
# parameters
# Problem settings passed to TSPTester as env_params; 'problem_size' is also
# forwarded to get_random_problems() in main().
env_params = {
    'problem_size': 10,
    'pomo_size': 10,
}

# Selects the encoder/decoder pair that is resolved after the config section.
architecture = "GMS-EB" # GMS-DH, GMS-EB or MBM
# GMS-DH: Change GREAT_params and dh_params for encoder, MP_params for decoder
# GMS-EB: Change GREAT_params for encoder, MP_params for decoder
# MBM: Change MatNet_params for encoder, MP_params for decoder

# Passed to TSPTester and also stored in decoder_params['training_method'].
training_method = "Chb" # Either Linear or Chb

# Passed to TSPTester and to get_random_problems() in main().
distribution = "TMAT" # Either EUC, TMAT, XASY

### Encoders ###

# Encoder settings used when architecture is "GMS-EB"; for "GMS-DH" this dict
# is handed to the hybrid encoder as 'edge_attention_params'.
GREAT_params = {
    'embedding_dim': 128,
    'encoder_layer_num': 6,
    'qkv_dim': 16,
    'head_num': 8, 
    'ff_hidden_dim': 512,
    "great_nodeless": False,
    "great_asymmetric": True,
    "dropout": 0.1, 
}

# Encoder settings used when architecture is "MBM" (MatNet encoder).
MatNet_params = {
    'one_hot_seed_cnt': 10,
    'embedding_dim': 128, 
    'encoder_layer_num': 6,
    'qkv_dim': 16,
    'head_num': 8, 
    'ff_hidden_dim': 512,
    'ms_hidden_dim': 16,
    'ms_layer1_init': (1/2)**(1/2),
    'ms_layer2_init': (1/16)**(1/2),
}

# Layer counts for the hybrid encoder used when architecture is "GMS-DH".
dh_params = {
    'L1': 5, # GNN layer number, overrides GREAT_params['encoder_layer_num']
    'L2': 2 # Transformer layer number
}

### Encoders end ###

### Decoders ###

# Decoder settings for every architecture; decoder_params is derived from
# this dict together with the selected training_method.
MP_params = {
    'embedding_dim': 128,
    'qkv_dim': 16, 
    'head_num': 8,
    'logit_clipping': 10,
    'eval_type': 'argmax',
}

### Decoders end ###

# Evaluation settings passed to TSPTester; main() additionally reads 'seed',
# 'test_episodes' and 'reference' directly.
tester_params = {
    'use_cuda': USE_CUDA,
    'cuda_device_num': CUDA_DEVICE_NUM,
    'seed': 1000,  # forwarded to get_random_problems() in main()
    'model_load': {
        'path': './result/Test',  # directory path of pre-trained model and log files saved.
        'epoch': 10, 
    },
    'reference': [15, 15, 15],  # reference point for the hypervolume computed in main()
    'test_episodes': 200,  # instance count forwarded to get_random_problems() in main()
    'test_batch_size': 200,
    'augmentation_enable': False,
    'aug_factor': 8,
    'aug_batch_size': 25
}
# With augmentation enabled, the augmentation batch size replaces the test batch size.
if tester_params['augmentation_enable']:
    tester_params['test_batch_size'] = tester_params['aug_batch_size']

# Keyword arguments for create_logger(), called in main().
logger_params = {
    'log_file': {
        # NOTE(review): the label says n20 while env_params['problem_size'] is 10 -- confirm which is intended.
        'desc': 'test__tsp_n20',
        'filename': 'run_log'
    }
}

### Config end

# Resolve the encoder/decoder pair for the selected architecture.  An unknown
# name is rejected explicitly: previously any misspelling silently fell through
# to the MatNet ("MBM") configuration and evaluated the wrong model.
_ARCHITECTURES = {
    "GMS-DH": ("hybrid", "MP"),
    "GMS-EB": ("GREAT-E", "MP-E"),
    "MBM": ("MatNet", "MP"),
}
if architecture not in _ARCHITECTURES:
    raise ValueError(
        "Unknown architecture {!r}; expected one of {}".format(
            architecture, sorted(_ARCHITECTURES)))
encoder, decoder = _ARCHITECTURES[architecture]

# Select the parameter set that belongs to the resolved encoder.
if encoder in ("GREAT", "GREAT-E"):
    encoder_params = GREAT_params
elif encoder == "MatNet":
    encoder_params = MatNet_params
elif encoder == "hybrid":
    # Build a new dict (same key order as before) so the user-facing
    # dh_params config is not modified as a side effect.
    encoder_params = dict(
        dh_params,
        edge_attention_type="GREAT",
        edge_attention_params=GREAT_params,
    )

# Decoder settings plus the training method, again without mutating MP_params.
decoder_params = dict(MP_params, training_method=training_method)

##########################################################################################
def _set_debug_mode():
    """Reduce the run for debugging by setting tester_params['test_episodes'] to 100."""
    # The dict is modified in place, so no ``global`` declaration is required.
    tester_params.update(test_episodes=100)

def _print_config():
    """Log the run configuration to the 'root' logger.

    Writes the machine flags, the resolved encoder/decoder names, the training
    method and distribution, followed by one ``key: value`` line for every
    entry of ``encoder_params`` and then of ``decoder_params``.
    """
    logger = logging.getLogger('root')
    logger.info('DEBUG_MODE: {}'.format(DEBUG_MODE))
    logger.info('USE_CUDA: {}, CUDA_DEVICE_NUM: {}'.format(USE_CUDA, CUDA_DEVICE_NUM))
    logger.info('Encoder: {}'.format(encoder))
    logger.info('Decoder: {}'.format(decoder))
    logger.info('Training Method: {}'.format(training_method))
    logger.info('Distribution: {}'.format(distribution))
    # Plain loops: logging is a side effect, so no throwaway lists are built.
    for params in (encoder_params, decoder_params):
        for key, value in params.items():
            logger.info(key + ": {}".format(value))

def das_dennis_recursion(ref_dirs, ref_dir, n_partitions, beta, depth):
    """Recursively enumerate Das-Dennis reference directions.

    Args:
        ref_dirs: List that receives one ``(1, n_dim)`` array per completed
            direction; it is extended in place.
        ref_dir: 1-D array holding the direction being built; the entry at
            ``depth`` is overwritten in place.
        n_partitions: Number of partitions per axis (the denominator).
        beta: Partition units still to be distributed over the remaining axes.
        depth: Index of the coordinate assigned by this call.
    """
    denom = float(n_partitions)

    if depth != len(ref_dir) - 1:
        # Inner coordinate: try every share of the remaining budget and
        # recurse on a copy so sibling branches do not interfere.
        for share in range(beta + 1):
            ref_dir[depth] = share / denom
            das_dennis_recursion(
                ref_dirs, np.copy(ref_dir), n_partitions, beta - share, depth + 1
            )
        return

    # Last coordinate: it receives whatever budget is left.
    ref_dir[depth] = beta / denom
    ref_dirs.append(ref_dir[None, :])

def das_dennis(n_partitions, n_dim):
    """Return the Das-Dennis reference directions as an ``(n_points, n_dim)`` array.

    Args:
        n_partitions: Number of partitions per axis. With 0 a single
            direction with all entries equal to ``1 / n_dim`` is returned.
        n_dim: Length of each direction vector.
    """
    if n_partitions != 0:
        collected = []
        start = np.full(n_dim, np.nan)
        das_dennis_recursion(collected, start, n_partitions, n_partitions, 0)
        return np.concatenate(collected, axis=0)

    # No partitions: only the uniform direction.
    return np.full((1, n_dim), 1 / n_dim)

##########################################################################################
def main(n_sols=10011):
    """Evaluate the loaded model and print run time and hypervolume ratio.

    Builds a TSPTester from the module-level configuration, generates the
    shared test problems with the configured seed, runs the tester on a set of
    Das-Dennis preference vectors, and prints the elapsed time of the run and
    the hypervolume of the returned solutions relative to the reference box.

    Args:
        n_sols: Number of 3-objective preference vectors to evaluate.
            Supported values are 105, 1035 and 10011.

    Raises:
        ValueError: If ``n_sols`` is not one of the supported values.
    """
    # Das-Dennis with p partitions yields (p + 1) * (p + 2) / 2 directions in
    # three dimensions; these are the partition counts for the supported sizes.
    partitions_by_n_sols = {105: 13, 1035: 44, 10011: 140}
    # Fail before any model is built; previously an unsupported value left
    # ``prefs`` unbound and crashed later with UnboundLocalError.
    if n_sols not in partitions_by_n_sols:
        raise ValueError(
            "Unsupported n_sols {!r}; expected one of {}".format(
                n_sols, sorted(partitions_by_n_sols)))

    if DEBUG_MODE:
        _set_debug_mode()

    create_logger(**logger_params)
    _print_config()

    tester = TSPTester(
        encoder=encoder,
        decoder=decoder,
        training_method=training_method,
        distribution=distribution,
        env_params=env_params,
        encoder_params=encoder_params,
        decoder_params=decoder_params,
        tester_params=tester_params,
    )

    #copy_all_src(tester.result_folder)

    shared_problem = get_random_problems(
        distribution,
        tester_params['test_episodes'],
        env_params['problem_size'],
        set_seed=True,
        seed=tester_params['seed'],
    )

    prefs = torch.Tensor(das_dennis(partitions_by_n_sols[n_sols], 3))

    # perf_counter is monotonic, so the measured interval cannot be distorted
    # by system clock adjustments.
    timer_start = time.perf_counter()
    sols = tester.run(shared_problem, prefs)
    total_time = time.perf_counter() - timer_start

    # Hypervolume of the returned solutions, normalised by the volume of the
    # box spanned by the reference point (product of its coordinates).
    ref = np.asarray(tester_params['reference'], dtype=float)
    hv = hvwfg.wfg(sols.astype(float), ref)
    hv_ratio = hv / ref.prod()

    print('Run Time(s): {:.4f}'.format(total_time))
    print('HV Ratio: {:.4f}'.format(hv_ratio))

##########################################################################################
# Script entry point: run the evaluation with main()'s default n_sols.
if __name__ == "__main__":
    main()
