"""
@Description :   训练主函数
@Author      :   tqychy 
@Time        :   2025/01/02 16:04:41
"""
import argparse
import random
import warnings

import numpy as np
import torch

from config.default import cfg
from logger.logger import build_logger
from trainers import trainers


def set_seed(seed: int, cuda_deterministic=True):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN determinism.

    Args:
        seed: Value passed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``.
        cuda_deterministic: When truthy, force deterministic cuDNN kernels and
            disable the cuDNN autotuner (``benchmark``); when falsy, do the
            opposite, trading reproducibility for speed.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # The two flags are always set to opposite values: deterministic kernels
    # and the benchmark autotuner are mutually exclusive choices here.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = bool(cuda_deterministic)
    cudnn.benchmark = not cuda_deterministic


def main(cfg):
    """Build the configured trainer and run training.

    Creates a logger and log directory via ``build_logger``, logs the full
    config, seeds the RNGs from ``cfg.GLOBALS.SEED``, then instantiates the
    trainer registered under ``cfg.TRAIN.TYPE`` with
    ``(log_dir, cfg, logger)`` and calls its ``train()`` method.

    Args:
        cfg: The merged configuration node (frozen when run as a script).

    Raises:
        KeyError: If ``cfg.TRAIN.TYPE`` is not a key of ``trainers``.
    """
    logger, log_dir = build_logger(cfg)
    logger.info(cfg)
    set_seed(cfg.GLOBALS.SEED)

    # Named ``trainer_type`` (not ``type``) to avoid shadowing the builtin.
    trainer_type = cfg.TRAIN.TYPE
    try:
        trainer_factory = trainers[trainer_type]
    except KeyError:
        # A bare KeyError with just the key is hard to act on; say which
        # config entry is wrong and list the registered choices.
        raise KeyError(
            f"Unknown trainer type {trainer_type!r} (cfg.TRAIN.TYPE); "
            f"available: {list(trainers)}"
        ) from None

    # Constructed outside the try so KeyErrors raised inside the trainer are
    # not misreported as an unknown trainer type.
    trainer = trainer_factory(log_dir, cfg, logger)
    trainer.train()


if __name__ == "__main__":
    # Deliberately silence every warning category for the whole run.
    warnings.filterwarnings("ignore")

    cli = argparse.ArgumentParser()
    cli.add_argument("--config_path", type=str, default="./config/classify/train.yaml")
    config_path = cli.parse_args().config_path

    # Overlay the YAML file onto the defaults, then lock the config so nothing
    # downstream can mutate it during training.
    cfg.merge_from_file(config_path)
    cfg.freeze()

    main(cfg)
