# Copyright (c) OpenMMLab. All rights reserved.
import copy as cp
import os.path as osp

import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (
    DistSamplerSeedHook,
    EpochBasedRunner,
    OptimizerHook,
    build_optimizer,
    get_dist_info,
)
from mmcv.runner.hooks import Fp16OptimizerHook

from DetTRC.utill.PrepareTrainClipHook import PrepareTrainClipHook
from ..core import (
    DistEvalHook,
    EvalHook,
    OmniSourceDistSamplerSeedHook,
    OmniSourceRunner,
)
from ..datasets import build_dataloader, build_dataset
from ..utils import PreciseBNHook, get_root_logger
from .test import multi_gpu_test


def train_model(
    model,
    dataset,
    cfg,
    distributed=False,
    validate=False,
    test=dict(test_best=False, test_last=False),
    timestamp=None,
    meta=None,
):
    """Train model entry function.

    Builds the data loaders, wraps the model for (distributed) data
    parallelism, creates an AdamW optimizer with a reduced learning rate for
    selected parameters, registers the training hooks, runs the training
    workflow and optionally evaluates the last/best checkpoint on the test
    set afterwards.

    Args:
        model (nn.Module): The model to be trained.
        dataset (:obj:`Dataset`): Train dataset. A single dataset or a
            list/tuple of datasets.
        cfg (dict): The config dict for training.
        distributed (bool): Whether to use distributed training.
            Default: False.
        validate (bool): Whether to do evaluation. Default: False.
        test (dict): The testing option, with two keys: test_last & test_best.
            The value is True or False, indicating whether to test the
            corresponding checkpoint. The dict is copied, never modified.
            Default: dict(test_best=False, test_last=False).
        timestamp (str | None): Local time for runner. Default: None.
        meta (dict | None): Meta dict to record some important information.
            Default: None
    """
    logger = get_root_logger(log_level=cfg.log_level)

    # Work on a private copy: the flags may be switched off further down and
    # neither the shared default dict nor the caller's dict must be altered.
    test = dict(test or {})

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    dataloader_setting = dict(
        videos_per_gpu=cfg.data.get("videos_per_gpu", 1),
        workers_per_gpu=cfg.data.get("workers_per_gpu", 1),
        persistent_workers=cfg.data.get("persistent_workers", False),
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        seed=cfg.seed,
    )
    dataloader_setting = dict(
        dataloader_setting, **cfg.data.get("train_dataloader", {})
    )
    # Loader options for a second (validation) dataset in the workflow:
    # the training options overridden by `cfg.data.val_dataloader`.
    val_dataloader_setting = dict(
        dataloader_setting, **cfg.data.get("val_dataloader", {})
    )

    if cfg.omnisource:
        # The option can override videos_per_gpu
        train_ratio = cfg.data.get("train_ratio", [1] * len(dataset))
        omni_videos_per_gpu = cfg.data.get("omni_videos_per_gpu", None)
        if omni_videos_per_gpu is None:
            dataloader_settings = [dataloader_setting] * len(dataset)
        else:
            dataloader_settings = []
            for videos_per_gpu in omni_videos_per_gpu:
                this_setting = cp.deepcopy(dataloader_setting)
                this_setting["videos_per_gpu"] = videos_per_gpu
                dataloader_settings.append(this_setting)
        data_loaders = [
            build_dataloader(ds, **setting)
            for ds, setting in zip(dataset, dataloader_settings)
        ]
    elif len(dataset) == 1:
        data_loaders = [build_dataloader(dataset[0], **dataloader_setting)]
    else:
        # Only the first two datasets are used: train, then validation.
        data_loaders = [
            build_dataloader(dataset[0], **dataloader_setting),
            build_dataloader(dataset[1], **val_dataloader_setting),
        ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get("find_unused_parameters", False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters,
        )
    else:
        model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # Build the optimizer by hand instead of through `build_optimizer` so the
    # sampling-offset / attention-weight projections can train at 0.1x the
    # base learning rate. Only `cfg.optimizer["lr"]` is read here.
    reduced_lr_suffixes = ("sampling_offsets.weight", "atten_weight.weight")
    weight_params = [
        param
        for pname, param in model.named_parameters()
        if pname.endswith(reduced_lr_suffixes)
    ]
    reduced_lr_ids = {id(param) for param in weight_params}
    other_params = [
        param for param in model.parameters() if id(param) not in reduced_lr_ids
    ]
    learning_rate = cfg.optimizer["lr"]
    optimizer = torch.optim.AdamW(
        [
            {"params": weight_params, "lr": 0.1 * learning_rate},
            {"params": other_params},
        ],
        lr=learning_rate,
    )

    Runner = OmniSourceRunner if cfg.omnisource else EpochBasedRunner

    # The runner is created without an optimizer and the hand-built one is
    # attached afterwards.
    runner = Runner(
        model, optimizer=None, work_dir=cfg.work_dir, logger=logger, meta=meta
    )
    runner.optimizer = optimizer
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get("fp16", None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed
        )
    elif distributed and "type" not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get("momentum_config", None),
    )
    # NOTE(review): this spot used to call `runner.register_hook()` without a
    # hook when cfg.temporal_aug_config == "part_exchange_and_accelerate",
    # which always raised TypeError. PrepareTrainClipHook is registered
    # unconditionally below, so the argument-less call was dropped; confirm
    # that no additional hook was intended for that setting.
    if distributed:
        if cfg.omnisource:
            runner.register_hook(OmniSourceDistSamplerSeedHook())
        else:
            runner.register_hook(DistSamplerSeedHook())

    # precise bn setting
    if cfg.get("precise_bn", False):
        precise_bn_dataset = build_dataset(cfg.data.train)
        dataloader_setting = dict(
            videos_per_gpu=cfg.data.get("videos_per_gpu", 1),
            workers_per_gpu=1,  # save memory and time
            persistent_workers=cfg.data.get("persistent_workers", False),
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed,
        )
        data_loader_precise_bn = build_dataloader(
            precise_bn_dataset, **dataloader_setting
        )
        precise_bn_hook = PreciseBNHook(data_loader_precise_bn, **cfg.get("precise_bn"))
        runner.register_hook(precise_bn_hook)

    runner.register_hook(PrepareTrainClipHook(data_loaders[0]))

    # Stays None when validation is off, so the test_best lookup after
    # training can tell that there is no best checkpoint to pick up.
    eval_hook = None
    if validate:
        eval_cfg = cfg.get("evaluation", {})
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        dataloader_setting = dict(
            videos_per_gpu=cfg.data.get("videos_per_gpu", 1),
            workers_per_gpu=cfg.data.get("workers_per_gpu", 1),
            persistent_workers=cfg.data.get("persistent_workers", False),
            # cfg.gpus will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            shuffle=False,
        )
        dataloader_setting = dict(
            dataloader_setting, **cfg.data.get("val_dataloader", {})
        )
        val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
        eval_hook = (
            DistEvalHook(val_dataloader, **eval_cfg)
            if distributed
            else EvalHook(val_dataloader, **eval_cfg)
        )
        runner.register_hook(eval_hook, priority=80)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner_kwargs = dict()
    if cfg.omnisource:
        runner_kwargs = dict(train_ratio=train_ratio)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs, **runner_kwargs)

    test_last = test.get("test_last", False)
    test_best = test.get("test_best", False)
    if not (test_last or test_best):
        return

    best_ckpt_path = None
    if test_best:
        best_ckpt_path = getattr(eval_hook, "best_ckpt_path", None)

        if best_ckpt_path is None or not osp.exists(best_ckpt_path):
            test_best = False
            if best_ckpt_path is None:
                runner.logger.info(
                    "Warning: test_best set as True, but "
                    "is not applicable "
                    "(eval_hook.best_ckpt_path is None)"
                )
            else:
                runner.logger.info(
                    "Warning: test_best set as True, but "
                    "is not applicable (best_ckpt "
                    f"{best_ckpt_path} not found)"
                )
            if not test_last:
                return

    test_dataset = build_dataset(cfg.data.test, dict(test_mode=True))
    evaluation_cfg = cfg.get("evaluation", {})
    gpu_collect = evaluation_cfg.get("gpu_collect", False)
    tmpdir = evaluation_cfg.get("tmpdir", osp.join(cfg.work_dir, "tmp"))
    dataloader_setting = dict(
        videos_per_gpu=cfg.data.get("videos_per_gpu", 1),
        workers_per_gpu=cfg.data.get("workers_per_gpu", 1),
        persistent_workers=cfg.data.get("persistent_workers", False),
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        shuffle=False,
    )
    dataloader_setting = dict(
        dataloader_setting, **cfg.data.get("test_dataloader", {})
    )

    test_dataloader = build_dataloader(test_dataset, **dataloader_setting)

    # These keys are stripped before calling `dataset.evaluate`. The
    # filtered mapping is a copy, so `cfg.evaluation` itself stays intact.
    hook_only_keys = {
        "interval",
        "tmpdir",
        "start",
        "gpu_collect",
        "save_best",
        "rule",
        "by_epoch",
        "broadcast_bn_buffers",
    }
    eval_kwargs = {
        key: value
        for key, value in evaluation_cfg.items()
        if key not in hook_only_keys
    }

    # A `None` checkpoint means "test the weights currently in the runner",
    # i.e. the state at the end of training.
    checkpoints = []
    if test_last:
        checkpoints.append(("last", None))
    if test_best:
        checkpoints.append(("best", best_ckpt_path))

    for name, ckpt in checkpoints:
        if ckpt is not None:
            runner.load_checkpoint(ckpt)

        outputs = multi_gpu_test(runner.model, test_dataloader, tmpdir, gpu_collect)
        rank, _ = get_dist_info()
        if rank == 0:
            out = osp.join(cfg.work_dir, f"{name}_pred.pkl")
            test_dataset.dump_results(outputs, out)

            eval_res = test_dataset.evaluate(outputs, **eval_kwargs)
            runner.logger.info(f"Testing results of the {name} checkpoint")
            for metric_name, val in eval_res.items():
                runner.logger.info(f"{metric_name}: {val:.04f}")
