import sys
import traceback

from config import parse_args
from utils import *
import torch.distributed as dist

from train import Trainer
def should_delete(answer):
    """Decide whether a prompt answer means "delete the experiment data".

    Any answer except ``n`` means delete, matching the prompt text. The answer
    is compared case- and whitespace-insensitively so a stray ``N`` or ``" n"``
    does not accidentally wipe the experiment.

    Args:
        answer: Raw string returned by ``input()``.

    Returns:
        True if the experiment directory should be removed.
    """
    return answer.strip().lower() != "n"


def ask_delete():
    """Interactively ask whether to delete the current training data.

    Returns:
        True if the user chose to delete. False if the user answered ``n``,
        stdin is closed (``EOFError``, e.g. a non-interactive run), or the
        prompt itself was interrupted with Ctrl+C. Aborting a destructive
        prompt must never delete data.
    """
    try:
        # Prompt: "Delete the current training data? [any key: delete; n: keep]"
        answer = input("是否删除当前训练数据：[任意键:删除；n:不删除]")
    except (EOFError, KeyboardInterrupt):
        return False
    return should_delete(answer)


def is_primary_rank(cfg):
    """Return True if this process is responsible for experiment cleanup.

    Without DDP there is only one process. With DDP, the global rank is used
    when the default process group is initialised. If the failure happened
    before initialisation (e.g. while building the Trainer),
    ``dist.get_rank()`` would itself raise. In that case this falls back to
    ``cfg.train.ddp.local_rank``.
    NOTE(review): in multi-node runs that fallback is true once per node;
    confirm exp_dir is node-local or that this is acceptable.
    """
    if not cfg.train.ddp.istrue:
        return True
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return cfg.train.ddp.local_rank == 0


if __name__ == "__main__":
    args = parse_args()
    cfg = prepare_env(args, sys.argv)
    try:
        # Offset the seed per local rank under DDP so workers don't draw
        # identical random streams.
        seed = cfg.train.seed
        if cfg.train.ddp.istrue:
            seed += cfg.train.ddp.local_rank
        setup_seed(seed)
        trainer = Trainer(cfg)
        trainer.run()
    except (Exception, KeyboardInterrupt) as err:
        print(traceback.format_exc())
        # Only the primary rank prompts: the other ranks cannot act on the
        # answer and would otherwise block on (or crash in) input().
        if is_primary_rank(cfg) and ask_delete():
            clear_exp(cfg.common.exp_dir)
        # Report the failure to the caller (shell / torchrun) instead of
        # exiting 0. Use 130 for Ctrl+C, per shell convention.
        sys.exit(130 if isinstance(err, KeyboardInterrupt) else 1)