import multiprocessing
import os
import sys
import argparse

def ancestor_dir(path, levels):
    """Return the directory ``levels`` path components above ``path``.

    ``levels=1`` gives the directory containing ``path``; each extra level
    strips one more trailing directory. ``os.path.dirname`` understands the
    current platform's separators, so a single code path works on both
    Windows and Linux.
    """
    for _ in range(levels):
        path = os.path.dirname(path)
    return path


# Make the project root importable (needed for ``SourceCode.Factory`` below).
# The root is this file's path with its last four components removed: the
# file name plus its three enclosing directories.
root_path = os.path.abspath(__file__)
project_root = ancestor_dir(root_path, 4)
# Only append a real, non-duplicate entry; an empty string here would
# silently add the current working directory to the import path.
if project_root and project_root not in sys.path:
    sys.path.append(project_root)

from SourceCode.Factory import Factory


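# Debugging tip: run with the environment variable CUDA_LAUNCH_BLOCKING=1 so CUDA errors are reported synchronously at the failing call.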

def parse_command_line(config, argv=None):
    """Parse command-line flags and apply environment-specific overrides.

    Args:
        config: Nested configuration dict with at least the ``train_config``,
            ``logger_config`` and ``data_config`` sections. When ``--prod`` is
            given, several batch-size / logging entries are overwritten in
            place.
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, exactly as before; passing an explicit list
            makes the function usable from tests or notebooks.

    Returns:
        Tuple ``(prod, pass_cuda_tensor)``. The two values are always equal:
        prod mode passes GPU tensors, dev mode passes CPU tensors.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--prod', action='store_true', default=False)
    prod = parser.parse_args(argv).prod

    if prod:
        print('working in prod, set big batch size and eval gap:')
        # Built inside the function so every call assigns fresh list objects
        # (no list is ever shared between two config dicts).
        prod_overrides = {
            'train_config': {
                'query_batch_size': 60000,
                'support_batch_size': 60000,
                'queue_size': 20,
            },
            'logger_config': {
                'eval_query_batch_size': 81000,
                'eval_support_batch_size': 81000,
                'flush_gap': 1,
                'eval_gap': 5000,
                'save_gap': 5000,
                'test_task_group_size': 5,
                'test_task_zip_param': [0.4, 0.6, 0.8],
                'test_task_item_size': [5000, 10000, 40000, 80000],
            },
        }
        for section, values in prod_overrides.items():
            config[section].update(values)
    else:
        # Dev mode keeps whatever values ``config`` already holds.
        print('working in dev, set small hyperparameter for less memory consume')

    print('cuda num', config['train_config']['cuda_num'])
    print('queue_size', config['train_config']['queue_size'])
    print('eval_gap', config['logger_config']['eval_gap'])
    print('lr ', config['train_config']['lr'])
    print('support_set_item_upper:', config['data_config']['item_upper'])
    print('support_set_item_lower:', config['data_config']['item_lower'])

    # Tensor placement follows the environment: GPU tensors in prod (which,
    # per the message below, is expected to run on Linux), CPU tensors in dev.
    pass_cuda_tensor = prod
    if pass_cuda_tensor:
        print('working in prod environment ,passing gpu tensor,attention: must in linux!!')
    else:
        print('working in pc environment, passing cpu tensor')
    return prod, pass_cuda_tensor


def init_config():
    """Build the default (dev-sized) experiment configuration.

    A brand-new nested dict is returned on every call, so callers may mutate
    it freely (``parse_command_line`` overwrites some entries in place for
    ``--prod`` runs). Sections, in order: ``train_config``,
    ``factory_config``, ``dim_config``, ``hidden_layer_config``,
    ``data_config``, ``logger_config`` and ``pre_config``.
    """
    return dict(
        train_config=dict(
            query_batch_size=10000,
            support_batch_size=10000,
            lr=0.0005,
            cuda_num=3,
            num_of_meta_task=500000,
            queue_size=20,
        ),
        factory_config=dict(
            # Options: LossFunc_for_Simple_AAE_and_ARE, LossFunc_for_Simple_MSE_and_ARE
            loss_class="LossFunc_for_Simple_MSE_and_ARE",
            # Options: BasicMemoryMatrixAndCM12
            memory_class="BasicMemoryMatrixAndCM12",
            # Must be set (not None) when memory_class is one of
            # SparseMemoryMatrixAttentionSum, SparseMemoryMatrixSeperate,
            # SparseMemoryMatrixSeperateAndCMRead.
            sparse_degree=None,
            decode_weight_class="ResWeightDecodeNet",
            exist_weight_class=None,
            # Options: BasicModel
            model_class="BasicModel",
            # Options: ScaleAttentionMatrix, AttentionMatrix
            attention_class="ScaleAttentionMatrix",
        ),
        dim_config=dict(
            dest_input_dim=16,
            source_input_dim=16,
            source_refined_dim=16,
            dest_refined_dim=16,
            source_embedding_dim=8,
            dest_embedding_dim=8,
            edge_embedding_dim=16,
            row_dim=45,
            col_dim=45,
        ),
        hidden_layer_config=dict(
            source_embedding_hidden_layer_size=32,
            dest_embedding_hidden_layer_size=32,
            edge_embedding_hidden_layer_size=16,
            source_refined_hidden_layer_size=32,
            dest_refined_hidden_layer_size=32,
            weight_decode_hidden_layer_size=64,
        ),
        data_config=dict(
            dataset_path=None,
            train_comment="128KB_Baseline",
            dataset_name="BasicSketch",
            skew_lower=1,
            skew_upper=10,
            item_lower=2,
            item_upper=60000,
            zipf_param_upper=0.8,
            zipf_param_lower=0.3,
        ),
        logger_config=dict(
            eval_query_batch_size=10000,
            eval_support_batch_size=10000,
            flush_gap=1,
            save_gap=100,
            eval_gap=200,
            test_task_group_size=1,
            test_task_item_size=[5000, 10000, 20000, 40000],
            test_task_stream_length=None,
            test_task_stream_length_ratio=None,
            test_task_zip_param=[0.4, 0.6, 0.8],
        ),
        pre_config=dict(
            pre_ratio_list=[0.6, 0.2, 0.15],
            pre_item_lower_list=[2, 2, 2],
            pre_item_upper_list=[10000, 20000, 40000],
        ),
    )


if __name__ == '__main__':
    # Defaults first, then let the command line (--prod) override them in place.
    config = init_config()
    prod, pass_cuda_tensor = parse_command_line(config)
    train_config = config['train_config']

    # NOTE(review): "spawn" is presumably required because prod mode hands
    # CUDA tensors to worker processes — confirm against Factory/MGS.
    multiprocessing.set_start_method("spawn")
    MGS = Factory.init_basic_MGS(prod, config)

    print('train begin...')
    pre_config = config['pre_config']
    MGS.train(
        # NOTE(review): 100 tasks beyond num_of_meta_task are requested; the
        # reason is not visible in this file.
        train_config["num_of_meta_task"] + 100,
        train_config["support_batch_size"],
        train_config["query_batch_size"],
        pass_cuda_tensor=pass_cuda_tensor,
        queue_size=train_config['queue_size'],
        pre_ratio_list=pre_config['pre_ratio_list'],
        pre_item_lower_list=pre_config['pre_item_lower_list'],
        pre_item_upper_list=pre_config['pre_item_upper_list'],
    )
    print('train end...')
