# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings

from mmengine.config import ConfigDict


def compat_cfg(cfg):
    """This function would modify some filed to keep the compatibility of
    config.

    For example, it will move some args which will be deprecated to the correct
    fields.
    """
    cfg = copy.deepcopy(cfg)
    cfg = compat_imgs_per_gpu(cfg)
    cfg = compat_loader_args(cfg)
    cfg = compat_runner_args(cfg)
    return cfg


def compat_runner_args(cfg):
    if 'runner' not in cfg:
        cfg.runner = ConfigDict({
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        })
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    else:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs
    return cfg


def compat_imgs_per_gpu(cfg):
    cfg = copy.deepcopy(cfg)
    if 'imgs_per_gpu' in cfg.data:
        warnings.warn('"imgs_per_gpu" is deprecated in MMDet V2.0. '
                      'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' in cfg.data:
            warnings.warn(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f'={cfg.data.imgs_per_gpu} is used in this experiments')
        else:
            warnings.warn('Automatically set "samples_per_gpu"="imgs_per_gpu"='
                          f'{cfg.data.imgs_per_gpu} in this experiments')
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
    return cfg


def compat_loader_args(cfg):
    """Deprecated sample_per_gpu in cfg.data."""

    cfg = copy.deepcopy(cfg)
    if 'train_dataloader' not in cfg.data:
        cfg.data['train_dataloader'] = ConfigDict()
    if 'val_dataloader' not in cfg.data:
        cfg.data['val_dataloader'] = ConfigDict()
    if 'test_dataloader' not in cfg.data:
        cfg.data['test_dataloader'] = ConfigDict()

    # special process for train_dataloader
    if 'samples_per_gpu' in cfg.data:

        samples_per_gpu = cfg.data.pop('samples_per_gpu')
        assert 'samples_per_gpu' not in \
               cfg.data.train_dataloader, ('`samples_per_gpu` are set '
                                           'in `data` field and ` '
                                           'data.train_dataloader` '
                                           'at the same time. '
                                           'Please only set it in '
                                           '`data.train_dataloader`. ')
        cfg.data.train_dataloader['samples_per_gpu'] = samples_per_gpu

    if 'persistent_workers' in cfg.data:

        persistent_workers = cfg.data.pop('persistent_workers')
        assert 'persistent_workers' not in \
               cfg.data.train_dataloader, ('`persistent_workers` are set '
                                           'in `data` field and ` '
                                           'data.train_dataloader` '
                                           'at the same time. '
                                           'Please only set it in '
                                           '`data.train_dataloader`. ')
        cfg.data.train_dataloader['persistent_workers'] = persistent_workers

    if 'workers_per_gpu' in cfg.data:

        workers_per_gpu = cfg.data.pop('workers_per_gpu')
        cfg.data.train_dataloader['workers_per_gpu'] = workers_per_gpu
        cfg.data.val_dataloader['workers_per_gpu'] = workers_per_gpu
        cfg.data.test_dataloader['workers_per_gpu'] = workers_per_gpu

    # special process for val_dataloader
    if 'samples_per_gpu' in cfg.data.val:
        # keep default value of `sample_per_gpu` is 1
        assert 'samples_per_gpu' not in \
               cfg.data.val_dataloader, ('`samples_per_gpu` are set '
                                         'in `data.val` field and ` '
                                         'data.val_dataloader` at '
                                         'the same time. '
                                         'Please only set it in '
                                         '`data.val_dataloader`. ')
        cfg.data.val_dataloader['samples_per_gpu'] = \
            cfg.data.val.pop('samples_per_gpu')
    # special process for val_dataloader

    # in case the test dataset is concatenated
    if isinstance(cfg.data.test, dict):
        if 'samples_per_gpu' in cfg.data.test:
            assert 'samples_per_gpu' not in \
                   cfg.data.test_dataloader, ('`samples_per_gpu` are set '
                                              'in `data.test` field and ` '
                                              'data.test_dataloader` '
                                              'at the same time. '
                                              'Please only set it in '
                                              '`data.test_dataloader`. ')

            cfg.data.test_dataloader['samples_per_gpu'] = \
                cfg.data.test.pop('samples_per_gpu')

    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            if 'samples_per_gpu' in ds_cfg:
                assert 'samples_per_gpu' not in \
                       cfg.data.test_dataloader, ('`samples_per_gpu` are set '
                                                  'in `data.test` field and ` '
                                                  'data.test_dataloader` at'
                                                  ' the same time. '
                                                  'Please only set it in '
                                                  '`data.test_dataloader`. ')
        samples_per_gpu = max(
            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
        cfg.data.test_dataloader['samples_per_gpu'] = samples_per_gpu

    return cfg
