# -*- coding: utf-8 -*-
import argparse

from common.log import access_log
from common.config import load_config

from hgtft.inference.inference_task import InferenceTaskManager
from hgtft.dataset.dataset_main import dataset_main
from hgtft.train.train_task import TrainTaskManager
from hgtft.finetuning.finetuning_task import FinetuningManager
from make_picture import make_html


import os
# Default to the first GPU, but respect a CUDA_VISIBLE_DEVICES value that the
# caller (shell, job scheduler) has already exported instead of clobbering it.
# NOTE(review): this runs after the hgtft imports above; it only has an effect
# if none of them created a CUDA context at import time — confirm.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')


def main(config_name):
    """Run the hgtft pipeline stage selected by ``config_name``.

    The config name is logged and its configuration loaded via
    ``load_config`` before dispatching:

    * ``'hgtft_data_process'`` -> ``dataset_main(config_name)`` (receives the
      name, not the loaded configs)
    * ``'hgtft_train'``        -> ``TrainTaskManager(configs).main()``
    * ``'hgtft_finetuning'``   -> ``FinetuningManager(configs).main()``
    * anything else            -> ``InferenceTaskManager(configs).main()``
      followed by ``make_html(configs)``

    Args:
        config_name: Name of the configuration to load and run.
    """
    access_log.info(f'config: {config_name}')
    configs = load_config(config_name)

    if config_name == 'hgtft_data_process':
        # The dataset entry point is driven by the config name itself;
        # the loaded ``configs`` are not passed to it.
        dataset_main(config_name)
        return

    manager_by_name = {
        'hgtft_train': TrainTaskManager,
        'hgtft_finetuning': FinetuningManager,
    }
    manager_cls = manager_by_name.get(config_name)
    if manager_cls is not None:
        manager_cls(configs).main()
        return

    # NOTE(review): every unrecognised name falls through to inference;
    # the CLI restricts this to 'hgtft_inference', but programmatic callers
    # are not validated here.
    InferenceTaskManager(configs).main()
    make_html(configs)




if __name__ == '__main__':
    # Command-line entry point: choose which pipeline stage `main` runs.
    cli = argparse.ArgumentParser(description="hgtft")
    cli.add_argument(
        '--config_name',
        type=str,
        default='hgtft_finetuning',
        choices=['hgtft_train', 'hgtft_data_process', 'hgtft_finetuning', 'hgtft_inference'],
    )
    main(cli.parse_args().config_name)
