import __init__
import os, argparse, yaml, numpy as np
from torch import multiprocessing as mp
from examples.classification.train_nascl_modelnet40 import main as train
from examples.classification.pretrain import main as pretrain
from openpoints_student.utils import EasyConfig, dist_utils, find_free_port, generate_exp_directory, resume_exp_directory, Wandb


if __name__ == "__main__":
    parser = argparse.ArgumentParser('S3DIS scene segmentation training')
    parser.add_argument('--cfg', type=str, required=True, help='config file')
    parser.add_argument('--cfg_s', type=str, required=True, help='config file')
    parser.add_argument('--profile', action='store_true', default=False, help='set to True to profile speed')
    args, opts = parser.parse_known_args()
    cfg = EasyConfig()
    cfg_s = EasyConfig()
    cfg.load(args.cfg, recursive=True)
    cfg_s.load(args.cfg_s, recursive=True)
    cfg.update(opts)
    cfg_s.update(opts)

    if cfg.seed is None:
        cfg.seed = np.random.randint(1, 10000)
    if cfg_s.seed is None:
        cfg_s.seed = np.random.randint(1, 10000)

    # init distributed env first, since logger depends on the dist info.
    cfg.rank, cfg.world_size, cfg.distributed, cfg.mp = dist_utils.get_dist_info(cfg)
    cfg.sync_bn = cfg.world_size > 1
    cfg_s.rank, cfg_s.world_size, cfg_s.distributed, cfg_s.mp = dist_utils.get_dist_info(cfg_s)
    cfg_s.sync_bn = cfg_s.world_size > 1

    # init log dir
    cfg.task_name = args.cfg.split('.')[-2].split('/')[-2]
    cfg.exp_name = args.cfg.split('.')[-2].split('/')[-1]
    cfg_s.task_name = args.cfg_s.split('.')[-2].split('/')[-2]  # task/dataset name, \eg s3dis, modelnet40_cls
    cfg_s.cfg_basename = args.cfg_s.split('.')[-2].split('/')[-1]  # cfg_basename, \eg pointnext-xl
    tags = [
        cfg.task_name,  # task name (the folder of name under ./cfgs
        cfg.mode,
        cfg.exp_name,  # cfg file name
        f'ngpus{cfg.world_size}',
        f'seed{cfg.seed}',
    ]
    tags_s = [
        cfg_s.task_name,  # task name (the folder of name under ./cfgs
        cfg_s.mode,
        cfg_s.cfg_basename,  # cfg_s file name
        f'ngpus{cfg_s.world_size}',
    ]
    opt_list = [] # for checking experiment configs from logging file
    opt_list_s = []
    for i, opt in enumerate(opts):
        if 'rank' not in opt and 'dir' not in opt and 'root' not in opt and 'pretrain' not in opt and 'path' not in opt and 'wandb' not in opt and '/' not in opt:
            opt_list.append(opt)
    cfg.root_dir = os.path.join(cfg.root_dir, cfg.task_name)
    cfg.opts = '-'.join(opt_list)
    cfg_s.root_dir = os.path.join(cfg_s.root_dir, cfg_s.task_name)
    cfg_s.opts = '-'.join(opt_list_s)

    if cfg.mode in ['resume', 'val', 'test']:
        resume_exp_directory(cfg, pretrained_path=cfg.pretrained_path)
        cfg.wandb.tags = [cfg.mode]
    else:  # resume from the existing ckpt and reuse the folder.
        generate_exp_directory(cfg, tags, additional_id=os.environ.get('MASTER_PORT', None))
        cfg.wandb.tags = tags
    if cfg_s.mode in ['resume', 'val', 'test']:
        resume_exp_directory(cfg_s, pretrained_path=cfg_s.pretrained_path)
        cfg_s.wandb.tags = [cfg_s.mode]
    else:  # resume from the existing ckpt and reuse the folder.
        generate_exp_directory(cfg_s, tags, additional_id=os.environ.get('MASTER_PORT', None))
        cfg_s.wandb.tags = tags
    
    os.environ["JOB_LOG_DIR"] = cfg.log_dir
    os.environ["JOB_LOG_DIR_S"] = cfg_s.log_dir
    cfg_path = os.path.join(cfg.run_dir, "cfg.yaml")
    cfg_s_path = os.path.join(cfg_s.run_dir, "cfg.yaml")
    with open(cfg_path, 'w') as f:
        yaml.dump(cfg, f, indent=2)
        os.system('cp %s %s' % (args.cfg, cfg.run_dir))
    with open(cfg_s_path, 'w') as f:
        yaml.dump(cfg_s, f, indent=2)
        os.system('cp %s %s' % (args.cfg_s, cfg_s.run_dir))
    cfg.cfg_path = cfg_path
    cfg_s.cfg_path = cfg_s_path
    cfg.wandb.name = cfg.run_name
    cfg_s.wandb.name = cfg_s.run_name

    if cfg.mode == 'pretrain':
        main = pretrain
    else:
        main = train

    # multi processing.
    if cfg.mp:
        port = find_free_port()
        cfg.dist_url = f"tcp://localhost:{port}"
        print('using mp spawn for distributed training')
        mp.spawn(main, nprocs=cfg.world_size, args=(cfg, args.profile))
    else:
        main(0, cfg, cfg_s,profile=args.profile)
