import _init_path
import os
from pathlib import Path
import argparse
import datetime
import glob
import wandb

import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
import torch.distributed as dist
from test import repeat_eval_ckpt

from pcdet.config import cfg, cfg_from_list, cfg_from_yaml_file, log_config_to_file
from pcdet.datasets import build_dataloader
from pcdet.models.model_utils.dsnorm import DSNorm
from pcdet.models import build_network, model_fn_decorator, model_fn_decorator_for_mt
from pcdet.utils import common_utils
from train_utils.optimization import build_optimizer, build_scheduler
from train_utils.train_utils import train_model
from tools.train_utils.train_st_utils import train_model_st
# from train_utils.train_mt_utils import train_model_mt
# from train_utils.train_mine_utils import train_model_st as train_model_mine
from tools.train_utils.train_st_utils_redb import train_model_st_redb, save_scratch_model

from tools.train_utils.train_st_utils_pere import train_model_pere
def parse_config():
    """Parse command-line arguments and load the experiment configuration.

    The YAML file given by ``--cfg_file`` is merged into the global ``cfg``
    object, ``cfg.TAG`` is set to the config file's stem, and
    ``cfg.EXP_GROUP_PATH`` is set to the directory part of the config path with
    its first component dropped.  Any ``--set KEY VALUE ...`` overrides are then
    applied on top.

    Returns:
        tuple: ``(args, cfg)`` -- the parsed ``argparse.Namespace`` and the
        populated global config object.

    Raises:
        SystemExit: if the command line is invalid or ``--cfg_file`` is missing.
    """
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--cfg_file', type=str, default=None, help='specify the config for training')

    parser.add_argument('--batch_size', type=int, default=4, required=False, help='batch size for training')
    parser.add_argument('--epochs', type=int, default=None, required=False, help='number of epochs to train for')
    parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader')
    parser.add_argument('--extra_tag', type=str, default='default', help='extra tag for this experiment')
    parser.add_argument('--ckpt', type=str, default=None, help='checkpoint to start from')
    parser.add_argument('--pretrained_model', type=str, default=None, help='pretrained_model')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none')
    parser.add_argument('--tcp_port', type=int, default=18888, help='tcp port for distrbuted training')
    # NOTE: --sync_bn already defaults to True, so passing the flag changes
    # nothing; --no_sync_bn is the separate switch provided for turning it off.
    parser.add_argument('--sync_bn', action='store_true', default=True, help='whether to use sync bn')
    parser.add_argument('--no_sync_bn', action='store_true', default=False, help='whether to do not use sync bn')
    # NOTE: defaults to True as well, so the seed is always reported as fixed.
    parser.add_argument('--fix_random_seed', action='store_true', default=True, help='')
    parser.add_argument('--ckpt_save_interval', type=int, default=1, help='number of training epochs')
    parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
    parser.add_argument('--max_ckpt_save_num', type=int, default=1000, help='max number of saved checkpoint')
    parser.add_argument('--merge_all_iters_to_one_epoch', action='store_true', default=False, help='')
    # REMAINDER: everything after --set is collected verbatim as override tokens.
    parser.add_argument('--set', dest='set_cfgs', default=None, nargs=argparse.REMAINDER,
                        help='set extra config keys if needed')

    parser.add_argument('--max_waiting_mins', type=int, default=0, help='max waiting minutes')
    parser.add_argument('--start_epoch', type=int, default=0, help='')
    parser.add_argument('--save_to_file', action='store_true', default=False, help='')
    parser.add_argument('--eval_fov_only', action='store_true', default=False, help='')
    parser.add_argument('--eval_src', action='store_true', default=False, help='')
    parser.add_argument('--num_epochs_to_eval', type=int, default=1000, help='number of checkpoints to be evaluated')

    args = parser.parse_args()

    # Fail early with a usage message instead of an opaque error from the
    # YAML loader / Path() when no config file was supplied.
    if args.cfg_file is None:
        parser.error('--cfg_file is required: specify the config for training')

    cfg_from_yaml_file(args.cfg_file, cfg)
    cfg.TAG = Path(args.cfg_file).stem
    cfg.EXP_GROUP_PATH = '/'.join(args.cfg_file.split('/')[1:-1])  # remove 'cfgs' and 'xxxx.yaml'

    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs, cfg)

    return args, cfg


def init_wandb(args, cfg):
    """Start a Weights & Biases run and give it a descriptive name.

    The wandb project is named after the adaptation task, written as
    ``<source initial>2<target initial>`` using the first character of the
    source and target ``DATASET`` names.  When the config has no
    ``DATA_CONFIG_TAR`` section the source dataset is used for both sides.

    The run name is
    ``<train mode>-<aug type>-<model name>-<config file stem>-<extra tag>``.

    Args:
        args: parsed command-line namespace; ``cfg_file`` and ``extra_tag`` are
            read, and the whole namespace is recorded in the wandb config.
        cfg: global experiment config; also passed to wandb as run config.
    """
    # Raised wandb service wait setting (value taken as-is from the original
    # script; presumably a timeout in seconds -- confirm against wandb docs).
    os.environ["WANDB__SERVICE_WAIT"] = "300"

    # DATA_CONFIG_TAR is optional for source-only configs, so fall back to the
    # source dataset instead of raising on the missing section.
    target_data_cfg = cfg.get('DATA_CONFIG_TAR', None) or cfg.DATA_CONFIG
    adaptation_task = '{}2{}'.format(
        cfg.DATA_CONFIG.DATASET.split('Data')[0][0],
        target_data_cfg.DATASET.split('Data')[0][0])

    # The original `"".format(adaptation_task)` always produced an empty
    # project name; use the task tag itself.
    proj_name = adaptation_task
    entity_name = ""  # Please put your WANDB entity name here
    # An empty entity is passed as None so wandb falls back to its default.
    wandb.init(project=proj_name, entity=entity_name or None, config=cfg)
    wandb.config.update(args)

    # Training mode: 'PT' without self-training; with self-training, 'ST3D'
    # when the memory ensemble is enabled and plain 'ST' otherwise.
    if not cfg.get('SELF_TRAIN', None):
        train_mode = 'PT'
    elif cfg.SELF_TRAIN.MEMORY_ENSEMBLE.ENABLED is True:
        train_mode = 'ST3D'
    else:
        train_mode = 'ST'

    # Augmentation tag, derived from which augmentors are NOT disabled.
    disabled_augs = cfg.DATA_CONFIG.DATA_AUGMENTOR.DISABLE_AUG_LIST
    if 'random_object_scaling' not in disabled_augs:
        aug_type = 'ROS'
    elif 'normalize_object_size' not in disabled_augs:
        aug_type = 'SN'
    else:
        aug_type = 'SO'  # source only

    file_name = args.cfg_file.split('/')[-1].split('.')[0]
    run_name = '-'.join(
        [train_mode, aug_type, cfg.MODEL.NAME, file_name, args.extra_tag])

    wandb.run.name = run_name


def main():
    """Train a detector, optionally with self-training, then evaluate checkpoints.

    Steps, in order:
      1. Parse arguments / config and optionally initialise distributed training.
      2. Create the output, checkpoint and pseudo-label directories and the logger.
      3. Build the source (and, for self-training, target) dataloaders.
      4. Build the model, an optional EMA copy and, for the ``MINE`` option,
         a separate source model; apply SyncBN / DSNorm conversions.
      5. Load pretrained weights and/or resume from ``--ckpt`` or from the most
         recently modified checkpoint in the checkpoint directory.
      6. Select the trainer from the ``SELF_TRAIN`` options and run training.
      7. Evaluate the saved checkpoints with ``repeat_eval_ckpt``.
    """
    import copy
    import shutil
    import warnings

    # Silence known, noisy third-party warnings.
    warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument.")
    warnings.filterwarnings(
        "ignore",
        message="The parameter 'weights' should be normalized, but got sum(weights) = 1.00000"
    )
    warnings.filterwarnings(
        "ignore",
        message="Default grid_sample and affine_grid behavior has changed to align_corners=False"
    )
    warnings.filterwarnings(
        "ignore",
        message="elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison"
    )

    # Disable TF32 for both matmul and cuDNN.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    args, cfg = parse_config()
    if args.launcher == 'none':
        dist_train = False
        total_gpus = 1
    else:
        total_gpus, cfg.LOCAL_RANK = getattr(common_utils, 'init_dist_%s' % args.launcher)(
            args.tcp_port, args.local_rank, backend='nccl'
        )
        dist_train = True

    # --batch_size is the total batch size; convert it to a per-GPU value.
    assert args.batch_size % total_gpus == 0, 'Batch size should match the number of gpus'
    args.batch_size = args.batch_size // total_gpus

    args.epochs = cfg.OPTIMIZATION.NUM_EPOCHS if args.epochs is None else args.epochs

    if args.fix_random_seed:
        common_utils.set_random_seed(666)
    if args.no_sync_bn:
        args.sync_bn = False

    # Self-training option flags, evaluated once. Each is only ever used as a
    # boolean switch below.
    self_train_cfg = cfg.get('SELF_TRAIN', None)
    use_self_train = bool(self_train_cfg)
    use_redb = bool(self_train_cfg and self_train_cfg.get('REDB', None))
    use_mine = bool(self_train_cfg and self_train_cfg.get('MINE', None))
    use_dsnorm = bool(self_train_cfg and self_train_cfg.get('DSNORM', None))
    use_triplet = bool(self_train_cfg and self_train_cfg.get('Triplet', None))

    output_dir = cfg.ROOT_DIR / 'output' / cfg.EXP_GROUP_PATH / cfg.TAG / args.extra_tag
    ckpt_dir = output_dir / 'ckpt'
    ps_label_dir = output_dir / 'ps_label'
    for directory in (ps_label_dir, output_dir, ckpt_dir):
        directory.mkdir(parents=True, exist_ok=True)

    log_file = output_dir / ('log_train_%s.txt' % datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))
    logger = common_utils.create_logger(log_file, rank=cfg.LOCAL_RANK)

    # log to file
    logger.info('**********************Start logging**********************')
    gpu_list = os.environ.get('CUDA_VISIBLE_DEVICES', 'ALL')
    logger.info('CUDA_VISIBLE_DEVICES=%s' % gpu_list)

    if dist_train:
        logger.info('total_batch_size: %d' % (total_gpus * args.batch_size))
    for key, val in vars(args).items():
        logger.info('{:16} {}'.format(key, val))
    log_config_to_file(cfg, logger=logger)
    if cfg.LOCAL_RANK == 0:
        # Keep a copy of the config next to the results. shutil.copy avoids
        # building a shell command from a user-supplied path; a failed copy is
        # logged rather than aborting the run.
        try:
            shutil.copy(args.cfg_file, output_dir)
        except OSError as err:
            logger.warning('Failed to copy %s to %s: %s' % (args.cfg_file, output_dir, err))

    tb_log = SummaryWriter(log_dir=str(output_dir / 'tensorboard')) if cfg.LOCAL_RANK == 0 else None
    # With the REDB option the source loader always uses batch size 1.
    source_batch_size = 1 if use_redb else args.batch_size
    # -----------------------create dataloader & network & optimizer---------------------------
    source_set, source_loader, source_sampler = build_dataloader(
        dataset_cfg=cfg.DATA_CONFIG,
        class_names=cfg.CLASS_NAMES,
        batch_size=source_batch_size,
        dist=dist_train, workers=args.workers,
        logger=logger,
        training=True,
        merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
        total_epochs=args.epochs
    )
    if use_mine:
        # Non-training view of the source data, used to build the source model.
        source_set_detect, source_loader_detect, source_sampler_detect = build_dataloader(
            dataset_cfg=cfg.DATA_CONFIG,
            class_names=cfg.CLASS_NAMES,
            batch_size=1,
            dist=dist_train, workers=args.workers,
            logger=logger,
            training=False
        )

    if use_self_train:
        target_set, target_loader, target_sampler = build_dataloader(
            cfg.DATA_CONFIG_TAR, cfg.DATA_CONFIG_TAR.CLASS_NAMES, args.batch_size,
            dist_train, workers=args.workers, logger=logger, training=True
        )
    else:
        target_set = target_loader = target_sampler = None

    model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES),
                          dataset=source_set)

    if use_mine:
        source_model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES),
                                     dataset=source_set_detect)
        source_model.cuda()
        source_model.eval()

    ema_model = None
    if cfg.MODEL.NAME.endswith('MT') or cfg.MODEL.get('EMA_MODEL_ALPHA', None):
        # Clone the model to serve as the EMA model; its parameters never
        # receive gradients.
        ema_model = copy.deepcopy(model)
        for param in ema_model.parameters():
            param.requires_grad_(False)

    if args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if ema_model is not None:
            ema_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(ema_model)
    if use_dsnorm:
        model = DSNorm.convert_dsnorm(model)
        if use_mine:
            source_model = DSNorm.convert_dsnorm(source_model)
        if ema_model is not None:
            ema_model = DSNorm.convert_dsnorm(ema_model)

    model.cuda()
    if ema_model is not None:
        ema_model.cuda()
    if use_redb:
        # Save the model before any weights are loaded, for later reloading
        # after pseudo labeling.
        if cfg.SELF_TRAIN.LOAD_SCRATCH_AFTER_PSEUDO_LABELING or \
                cfg.SELF_TRAIN.LOAD_OPTIMIZER_AFTER_PSEUDO_LABELING:
            save_scratch_model(ckpt_dir, model)

    optimizer = build_optimizer(model, cfg.OPTIMIZATION)
    # load checkpoint if it is possible
    start_epoch = it = 0
    last_epoch = -1
    # NOTE: `to_cpu` receives the boolean `dist_train`. The original passed
    # `dist`, which is the `torch.distributed` module and therefore always truthy.
    if args.pretrained_model is not None:
        model.load_params_from_file(filename=args.pretrained_model, to_cpu=dist_train, logger=logger)
        if use_mine:
            source_model.load_params_from_file(filename=args.pretrained_model, to_cpu=dist_train,
                                               logger=logger)
        if ema_model is not None:
            ema_model.load_params_from_file(filename=args.pretrained_model, to_cpu=dist_train, logger=logger)

    # Resume from --ckpt if given, otherwise from the most recently modified
    # checkpoint already present in ckpt_dir (if any).
    resume_ckpt = args.ckpt
    if resume_ckpt is None:
        ckpt_list = glob.glob(str(ckpt_dir / '*checkpoint_epoch_*.pth'))
        if len(ckpt_list) > 0:
            ckpt_list.sort(key=os.path.getmtime)
            resume_ckpt = ckpt_list[-1]

    if resume_ckpt is not None:
        if ema_model is not None:
            ema_model.load_params_with_optimizer(
                resume_ckpt, to_cpu=dist_train, optimizer=optimizer, logger=logger
            )
        it, start_epoch = model.load_params_with_optimizer(
            resume_ckpt, to_cpu=dist_train, optimizer=optimizer, logger=logger
        )
        last_epoch = start_epoch + 1

    model.train()  # before wrap to DistributedDataParallel to support fixed some parameters
    if ema_model is not None:
        ema_model.train()
    if dist_train:
        device_ids = [cfg.LOCAL_RANK % torch.cuda.device_count()]
        if use_triplet:
            # The Triplet option additionally turns off buffer broadcasting.
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=device_ids, broadcast_buffers=False, find_unused_parameters=True)
        else:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=device_ids, find_unused_parameters=True)
        if use_mine:
            source_model = nn.parallel.DistributedDataParallel(source_model, device_ids=device_ids)
    logger.info(model)
    if ema_model is not None:
        logger.info(ema_model)

    # Iterations per epoch follow the target loader when self-training and the
    # source loader otherwise.
    epoch_loader = target_loader if use_self_train else source_loader
    total_iters_each_epoch = len(epoch_loader)
    if args.merge_all_iters_to_one_epoch:
        total_iters_each_epoch = total_iters_each_epoch // args.epochs

    lr_scheduler, lr_warmup_scheduler = build_scheduler(
        optimizer, total_iters_each_epoch=total_iters_each_epoch, total_epochs=args.epochs,
        last_epoch=last_epoch, optim_cfg=cfg.OPTIMIZATION
    )

    # select proper trainer; later checks override earlier ones
    train_func = train_model_st if use_self_train else train_model
    if use_redb:
        train_func = train_model_st_redb
    if use_triplet:
        train_func = train_model_pere

    # wandb is only initialised on the process that owns the tensorboard writer.
    if tb_log is not None:
        init_wandb(args, cfg)
    # -----------------------start training---------------------------
    logger.info('**********************Start training %s/%s(%s)**********************'
                % (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
    train_func(
        model,
        optimizer,
        source_loader,
        target_loader,
        model_func=model_fn_decorator(),
        lr_scheduler=lr_scheduler,
        optim_cfg=cfg.OPTIMIZATION,
        start_epoch=start_epoch,
        total_epochs=args.epochs,
        start_iter=it,
        rank=cfg.LOCAL_RANK,
        tb_log=tb_log,
        ckpt_save_dir=ckpt_dir,
        ps_label_dir=ps_label_dir,
        source_sampler=source_sampler,
        target_sampler=target_sampler,
        lr_warmup_scheduler=lr_warmup_scheduler,
        ckpt_save_interval=args.ckpt_save_interval,
        max_ckpt_save_num=args.max_ckpt_save_num,
        merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
        logger=logger,
        ema_model=ema_model,
        pretrained=args.pretrained_model,
    )

    logger.info('**********************End training %s/%s(%s)**********************\n\n\n'
                % (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
    wandb.finish()

    # Switch EMA_TEACHER off so that the dataset returns its normal output
    # during evaluation.
    if cfg.get('SELF_TRAIN', None) and cfg.SELF_TRAIN.get('EMA_TEACHER', None):
        cfg.SELF_TRAIN.EMA_TEACHER = False

    logger.info('**********************Start evaluation %s/%s(%s)**********************' %
                (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))

    if args.eval_fov_only:
        cfg.DATA_CONFIG_TAR.FOV_POINTS_ONLY = True

    # Evaluate on the target domain when one is configured, unless --eval_src
    # asks for the source domain.
    if cfg.get('DATA_CONFIG_TAR', None) and not args.eval_src:
        eval_data_cfg = cfg.DATA_CONFIG_TAR
        eval_class_names = cfg.DATA_CONFIG_TAR.CLASS_NAMES
    else:
        eval_data_cfg = cfg.DATA_CONFIG
        eval_class_names = cfg.CLASS_NAMES
    _, test_loader, _ = build_dataloader(
        dataset_cfg=eval_data_cfg,
        class_names=eval_class_names,
        batch_size=args.batch_size,
        dist=dist_train, workers=args.workers, logger=logger, training=False
    )

    eval_output_dir = output_dir / 'eval' / 'eval_with_train'
    eval_output_dir.mkdir(parents=True, exist_ok=True)
    # Only evaluate the last args.num_epochs_to_eval epochs
    args.start_epoch = max(args.epochs - args.num_epochs_to_eval, 0)

    repeat_eval_ckpt(
        model.module if dist_train else model,
        test_loader, args, eval_output_dir, logger, ckpt_dir,
        dist_test=dist_train
    )
    logger.info('**********************End evaluation %s/%s(%s)**********************' %
                (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))


# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
