import _init_path
import argparse
import datetime
import glob
import os
import re
import time
from pathlib import Path

import numpy as np
import torch
from tensorboardX import SummaryWriter

from tools.eval_utils.inversion_4block_test_utils import eval_one_epoch
from pcdet.config import cfg, cfg_from_list, cfg_from_yaml_file, log_config_to_file
from pcdet.datasets import build_dataloader
from pcdet.models import build_network
from pcdet.utils import common_utils


def parse_config():
    """Parse command-line arguments and load the experiment config.

    The YAML file given by ``--cfg_file`` is merged into the global ``cfg``
    object, then any ``--set KEY VALUE ...`` overrides are applied on top.

    Returns:
        tuple: ``(args, cfg)`` -- the parsed ``argparse.Namespace`` and the
        populated global ``cfg`` object.
    """
    parser = argparse.ArgumentParser(description="arg parser")
    parser.add_argument(
        "--cfg_file", type=str, default=None, help="specify the config for training"
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=None,
        required=False,
        help="batch size for training",
    )
    parser.add_argument(
        "--workers", type=int, default=4, help="number of workers for dataloader"
    )
    parser.add_argument(
        "--extra_tag", type=str, default="default", help="extra tag for this experiment"
    )
    parser.add_argument(
        "--ckpt", type=str, default=None, help="checkpoint to start from"
    )
    parser.add_argument(
        "--launcher", choices=["none", "pytorch", "slurm"], default="none"
    )
    parser.add_argument(
        "--tcp_port", type=int, default=18888, help="tcp port for distributed training"
    )
    parser.add_argument(
        "--local_rank", type=int, default=0, help="local rank for distributed training"
    )
    parser.add_argument(
        "--set",
        dest="set_cfgs",
        default=None,
        nargs=argparse.REMAINDER,
        help="set extra config keys if needed",
    )

    parser.add_argument(
        "--max_waiting_mins", type=int, default=30, help="max waiting minutes"
    )
    parser.add_argument("--start_epoch", type=int, default=0, help="")
    parser.add_argument(
        "--eval_tag", type=str, default="default", help="eval tag for this experiment"
    )
    parser.add_argument(
        "--eval_all",
        action="store_true",
        default=False,
        help="whether to evaluate all checkpoints",
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=None,
        help="specify a ckpt directory to be evaluated if needed",
    )
    parser.add_argument("--save_to_file", action="store_true", default=False, help="")

    parser.add_argument("--start_layer", type=int, default=5, help="")

    args = parser.parse_args()

    # Everything below needs a config file. Without this check a missing
    # --cfg_file surfaces as an obscure error from the YAML loader or
    # Path(None) instead of a usage message.
    if args.cfg_file is None:
        parser.error("--cfg_file is required")

    cfg_from_yaml_file(args.cfg_file, cfg)
    cfg.TAG = Path(args.cfg_file).stem
    # Drop the first path component (expected to be 'cfgs') and the trailing
    # 'xxxx.yaml' file name; what remains groups experiments under output/.
    cfg.EXP_GROUP_PATH = "/".join(args.cfg_file.split("/")[1:-1])

    # Seed NumPy's global RNG so any NumPy-driven randomness is repeatable.
    np.random.seed(1024)

    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs, cfg)

    return args, cfg


def eval_single_ckpt(
    detector_meanVFE,
    detector_backbone,
    inversion_model_out,
    inversion_model_4,
    inversion_model_3,
    inversion_model_2,
    test_loader,
    args,
    eval_output_dir,
    logger,
    epoch_id,
    dist_test=False,
):
    """Load pretrained weights into the backbone and inversion stages, then run one evaluation.

    Each weighted model reads its checkpoint path from
    ``<config section>.PREMODEL_ARGS["pretrained_path"]`` in the module-level
    ``cfg``. ``args.ckpt`` is not consulted here; ``args`` only supplies
    ``save_to_file`` and ``start_layer``. All weights are loaded first, and
    only then are the models moved to the GPU.

    ``detector_meanVFE`` is forwarded to ``eval_one_epoch`` untouched: no
    weights are loaded into it and it is not moved to the GPU.
    NOTE(review): presumably it has no learnable parameters -- confirm.

    Args:
        detector_meanVFE: Voxel feature encoder stage, passed straight through.
        detector_backbone: Backbone network; weights from ``cfg.DETECTOR_BACKBONE``.
        inversion_model_out: Stage configured by ``cfg.INVERSION_MODEL_out``.
        inversion_model_4: Stage configured by ``cfg.INVERSION_MODEL_4``.
        inversion_model_3: Stage configured by ``cfg.INVERSION_MODEL_3``.
        inversion_model_2: Stage configured by ``cfg.INVERSION_MODEL_2``.
        test_loader: Dataloader for the evaluation split.
        args: Parsed CLI namespace (reads ``save_to_file`` and ``start_layer``).
        eval_output_dir: Directory passed to ``eval_one_epoch`` as ``result_dir``.
        logger: Logger used while loading and evaluating.
        epoch_id: Epoch label forwarded to ``eval_one_epoch``.
        dist_test: Whether this is a distributed run; also forwarded as
            ``to_cpu`` when loading weights.
    """
    # (model, name of the cfg section holding its pretrained path). The section
    # is looked up inside the loop, so each lookup happens right before its load.
    checkpoint_sources = (
        (detector_backbone, "DETECTOR_BACKBONE"),
        (inversion_model_out, "INVERSION_MODEL_out"),
        (inversion_model_4, "INVERSION_MODEL_4"),
        (inversion_model_3, "INVERSION_MODEL_3"),
        (inversion_model_2, "INVERSION_MODEL_2"),
    )
    for network, section_name in checkpoint_sources:
        weights_path = getattr(cfg, section_name).PREMODEL_ARGS["pretrained_path"]
        network.load_params_from_file(
            filename=weights_path, to_cpu=dist_test, logger=logger
        )

    # Move to the GPU only once every checkpoint has been loaded.
    for network, _ in checkpoint_sources:
        network.cuda()

    eval_one_epoch(
        cfg,
        detector_meanVFE,
        detector_backbone,
        inversion_model_out,
        inversion_model_4,
        inversion_model_3,
        inversion_model_2,
        test_loader,
        epoch_id,
        logger,
        dist_test=dist_test,
        result_dir=eval_output_dir,
        save_to_file=args.save_to_file,
        start_layer=args.start_layer,
    )



def get_no_evaluated_ckpt(ckpt_dir, ckpt_record_file, args):
    """Return the oldest checkpoint in ``ckpt_dir`` that has not been evaluated yet.

    Args:
        ckpt_dir: Directory searched for ``*checkpoint_epoch_*.pth`` files.
        ckpt_record_file: Text file listing already-evaluated epoch ids, one
            per line. Blank lines are ignored.
        args: Parsed CLI namespace; only ``args.start_epoch`` is read.

    Returns:
        ``(epoch_id, ckpt_path)``, where ``epoch_id`` is the epoch string
        taken from the file name. Returns ``(-1, None)`` when no eligible
        checkpoint remains.

    Skipped checkpoints:
        - epoch ids containing ``"optim"``;
        - epoch ids that are not numbers (e.g. ``checkpoint_epoch_latest.pth``);
        - epochs already listed in the record file;
        - epochs below ``args.start_epoch``.
    """
    ckpt_list = glob.glob(os.path.join(ckpt_dir, "*checkpoint_epoch_*.pth"))
    # Oldest first, so checkpoints are picked in the order they were written.
    ckpt_list.sort(key=os.path.getmtime)

    # Read the record with a context manager so the handle is closed, and skip
    # blank lines, which float() would reject.
    with open(ckpt_record_file, "r") as record:
        evaluated_epochs = {float(line.strip()) for line in record if line.strip()}

    for cur_ckpt in ckpt_list:
        num_list = re.findall(r"checkpoint_epoch_(.*)\.pth", cur_ckpt)
        if not num_list:
            continue

        epoch_id = num_list[-1]
        # e.g. "checkpoint_epoch_5_optim.pth" -- presumably optimizer state only.
        if "optim" in epoch_id:
            continue
        try:
            epoch_value = float(epoch_id)
        except ValueError:
            # Not an epoch number; skip it rather than abort the whole scan.
            continue
        if epoch_value not in evaluated_epochs and int(epoch_value) >= args.start_epoch:
            return epoch_id, cur_ckpt
    return -1, None


def _resolve_eval_output_dir(args, cfg, output_dir):
    """Return ``(eval_output_dir, epoch_id)`` for this run (directory not yet created)."""
    # epoch_id is the last run of digits in --ckpt, or "no_number". It is
    # computed in every mode because main() always forwards it to evaluation.
    num_list = re.findall(r"\d+", args.ckpt) if args.ckpt is not None else []
    epoch_id = num_list[-1] if num_list else "no_number"

    eval_output_dir = output_dir / "eval"
    if args.eval_all:
        eval_output_dir = eval_output_dir / "eval_all_default"
    else:
        eval_output_dir = (
            eval_output_dir
            / ("epoch_%s" % epoch_id)
            / cfg.DATA_CONFIG.DATA_SPLIT["test"]
        )

    if args.eval_tag is not None:
        eval_output_dir = eval_output_dir / args.eval_tag
    return eval_output_dir, epoch_id


def _build_inversion_models(cfg):
    """Instantiate the four inversion stages selected by ``cfg``.

    Returns:
        tuple: ``(inversion_model_out, inversion_model_4, inversion_model_3,
        inversion_model_2)``.

    Raises:
        ValueError: if ``cfg.INVERSION_MODEL_4.LAYER`` or
            ``cfg.INVERSION_MODEL_out.LAYER`` is not a supported value.
    """
    from pcdet.models.inversion_model import Inversion_Model_Conv2_CLS as Inversion_Model_Conv2_to_Conv1
    from pcdet.models.inversion_model import Inversion_Model_Conv3_to_Conv2_CLS_REG as Inversion_Model_Conv3_to_Conv2

    layer_4 = cfg.INVERSION_MODEL_4.LAYER
    if layer_4 == "xconv4_to_xconv3_cls_reg":
        from pcdet.models.inversion_model import Inversion_Model_Conv4_to_Conv3_CLS_REG as Inversion_Model_Conv4_to_Conv3
    elif layer_4 == "xconv4_to_xconv3_cls_reg_voxelresbackbone":
        from pcdet.models.inversion_model import Inversion_Model_Conv4_to_Conv3_CLS_REG_Voxelresbackbone as Inversion_Model_Conv4_to_Conv3
    else:
        # Without this branch the class name stays unbound and fails later
        # with a confusing NameError.
        raise ValueError("Unsupported INVERSION_MODEL_4.LAYER: %r" % (layer_4,))

    layer_out = cfg.INVERSION_MODEL_out.LAYER
    if layer_out == "xconvout_to_xconv4_cls_reg":
        from pcdet.models.inversion_model import Inversion_Model_Convout_to_Conv4_CLS_REG as Inversion_Model_ConvOut_to_Conv4
    elif layer_out == "xconvout_to_xconv4_cls_reg_voxelresbackbone":
        from pcdet.models.inversion_model import Inversion_Model_Convout_to_Conv4_CLS_REG_Voxelresbackbone as Inversion_Model_ConvOut_to_Conv4
    else:
        raise ValueError("Unsupported INVERSION_MODEL_out.LAYER: %r" % (layer_out,))

    # Construct in the original order (2, 3, 4, out) in case constructors
    # consume global RNG state.
    inversion_model_2 = Inversion_Model_Conv2_to_Conv1()
    inversion_model_3 = Inversion_Model_Conv3_to_Conv2()
    inversion_model_4 = Inversion_Model_Conv4_to_Conv3()
    inversion_model_out = Inversion_Model_ConvOut_to_Conv4()
    return inversion_model_out, inversion_model_4, inversion_model_3, inversion_model_2


def main():
    """Evaluate the pretrained backbone plus inversion-model chain on the test split.

    Steps:
        1. Parse the CLI and config.
        2. Optionally initialise distributed mode and split the batch size
           across GPUs.
        3. Set up the output and log directories.
        4. Build the dataloader and networks.
        5. Run a single evaluation pass under ``torch.no_grad()``.
    """
    args, cfg = parse_config()
    if args.launcher == "none":
        dist_test = False
        total_gpus = 1
    else:
        total_gpus, cfg.LOCAL_RANK = getattr(
            common_utils, "init_dist_%s" % args.launcher
        )(args.tcp_port, args.local_rank, backend="nccl")
        dist_test = True

    if args.batch_size is None:
        args.batch_size = cfg.OPTIMIZATION.BATCH_SIZE_PER_GPU
    else:
        # --batch_size is the total across all GPUs; convert it to per-GPU.
        assert (
            args.batch_size % total_gpus == 0
        ), "Batch size should match the number of gpus"
        args.batch_size = args.batch_size // total_gpus

    output_dir = cfg.ROOT_DIR / "output" / cfg.EXP_GROUP_PATH / cfg.TAG / args.extra_tag
    output_dir.mkdir(parents=True, exist_ok=True)

    eval_output_dir, epoch_id = _resolve_eval_output_dir(args, cfg, output_dir)
    eval_output_dir.mkdir(parents=True, exist_ok=True)
    log_file = eval_output_dir / (
        "log_eval_%s.txt" % datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    )
    logger = common_utils.create_logger(log_file, rank=cfg.LOCAL_RANK)

    # log to file
    logger.info("**********************Start logging**********************")
    gpu_list = os.environ.get("CUDA_VISIBLE_DEVICES", "ALL")
    logger.info("CUDA_VISIBLE_DEVICES=%s" % gpu_list)

    if dist_test:
        logger.info("total_batch_size: %d" % (total_gpus * args.batch_size))
    for key, val in vars(args).items():
        logger.info("{:16} {}".format(key, val))
    log_config_to_file(cfg, logger=logger)

    test_set, test_loader, _sampler = build_dataloader(
        dataset_cfg=cfg.DATA_CONFIG,
        class_names=cfg.CLASS_NAMES,
        batch_size=args.batch_size,
        dist=dist_test,
        workers=args.workers,
        logger=logger,
        training=False,
    )

    detector_meanVFE = build_network(
        model_cfg=cfg.DETECTOR_MEANVFE, num_class=len(cfg.CLASS_NAMES), dataset=test_set
    )
    detector_backbone = build_network(
        model_cfg=cfg.DETECTOR_BACKBONE, num_class=len(cfg.CLASS_NAMES), dataset=test_set
    )
    (
        inversion_model_out,
        inversion_model_4,
        inversion_model_3,
        inversion_model_2,
    ) = _build_inversion_models(cfg)

    with torch.no_grad():
        eval_single_ckpt(
            detector_meanVFE,
            detector_backbone,
            inversion_model_out,
            inversion_model_4,
            inversion_model_3,
            inversion_model_2,
            test_loader,
            args,
            eval_output_dir,
            logger,
            epoch_id,
            dist_test=dist_test,
        )


if __name__ == "__main__":
    # Script entry point; importing this module does not start an evaluation.
    main()