from typing import Dict, Union

from omegaconf import DictConfig, OmegaConf

from common.utils.robot_utils import ROBOTS


def _assert_contain_same_str(str1: str, str2: str, keyword: str):
    if keyword in str1:
        assert keyword in str2


def check_dataset_and_env_consistency(dataset_name: str, env_id: str):
    """Check that ``dataset_name`` agrees with ``env_id`` on identifying tags.

    The tags checked are ``'ant'``, ``'maze2d'``, ``'umaze'``, ``'medium'``,
    ``'Lift'``, followed by every entry of ``ROBOTS``. For each tag, if it
    occurs as a substring of ``env_id`` it must also occur in
    ``dataset_name``. The check is done by ``_assert_contain_same_str``,
    which signals a mismatch with ``AssertionError``.

    Args:
        dataset_name: Name of the offline dataset.
        env_id: Identifier of the environment.
    """
    tags = ('ant', 'maze2d', 'umaze', 'medium', 'Lift', *ROBOTS)
    for tag in tags:
        _assert_contain_same_str(env_id, dataset_name, tag)


def check_source_and_target_consistency(
    source_name: str,
    target_name: str,
):
    """Check that the source and target names share layout/task tags.

    For each of ``'umaze'``, ``'medium'`` and ``'Lift'`` (in that order), a
    tag that occurs as a substring of ``source_name`` must also occur in
    ``target_name``. The check is done by ``_assert_contain_same_str``,
    which signals a mismatch with ``AssertionError``.

    Args:
        source_name: Name describing the source domain.
        target_name: Name describing the target domain.
    """
    for tag in ('umaze', 'medium', 'Lift'):
        _assert_contain_same_str(source_name, target_name, tag)


def _add_prefix_to_dict(dic: Union[Dict, DictConfig], prefix: str) -> Dict:
    """Return a new plain dict with every key of ``dic`` renamed to ``'<prefix>_<key>'``.

    Values are taken as yielded by ``dic.items()``. The input mapping is not
    modified.

    Args:
        dic: Mapping (plain dict or DictConfig) whose entries are copied.
        prefix: Text prepended to each key, joined with an underscore.

    Returns:
        A fresh ``dict`` holding the prefixed keys.
    """
    prefixed = {}
    for name, value in dic.items():
        prefixed[f'{prefix}_{name}'] = value
    return prefixed


def read_env_config_yamls(args: DictConfig):
    """Load the source/target domain YAML configs and merge them into ``args``.

    Steps, in order:
      1. Normalize ``args.source`` and ``args.target`` in place by replacing
         every ``-`` with ``_`` (this mutates the caller's ``args`` object).
      2. Load ``common/config_utils/<args.source>.yaml`` and
         ``common/config_utils/<args.target>.yaml``. These are relative paths,
         resolved against the current working directory.
      3. If ``args.image_observation`` is truthy, double each domain's
         ``obs_dim`` before merging.
      4. Merge every field of each domain config into ``args`` under a
         ``source_`` / ``target_`` prefix, e.g. ``source_env``,
         ``source_dataset``, ``source_obs_dim``.
      5. Set ``max_obs_dim`` and ``max_action_dim`` to the larger of the two
         domains' values.
      6. If ``args.image_observation`` is truthy, replace every ``.hdf5`` in
         ``source_dataset`` and ``target_dataset`` with ``_image.hdf5``.
      7. Run ``check_source_and_target_consistency`` on the two domains'
         ``env`` fields.

    Args:
        args: OmegaConf config providing at least ``source``, ``target`` and
            ``image_observation``.

    Returns:
        The merged config returned by ``OmegaConf.merge``, carrying the
        prefixed domain fields plus ``max_obs_dim`` and ``max_action_dim``.
    """
    args.source = args.source.replace('-', '_')
    args.target = args.target.replace('-', '_')

    # (prefix, loaded config) pairs; source is always loaded and handled first.
    domains = [
        ('source', OmegaConf.load(f'common/config_utils/{args.source}.yaml')),
        ('target', OmegaConf.load(f'common/config_utils/{args.target}.yaml')),
    ]

    if args.image_observation:
        # NOTE(review): presumably image features are appended to the state
        # vector, hence the doubled width -- confirm against the model code.
        for _, domain_conf in domains:
            domain_conf.obs_dim *= 2

    for side, domain_conf in domains:
        args = OmegaConf.merge(args,
                               _add_prefix_to_dict(domain_conf, prefix=side))

    args.max_obs_dim = max(args.source_obs_dim, args.target_obs_dim)
    args.max_action_dim = max(args.source_action_dim, args.target_action_dim)

    if args.image_observation:
        for side, _ in domains:
            field = f'{side}_dataset'
            setattr(args, field,
                    getattr(args, field).replace(".hdf5", "_image.hdf5"))

    (_, source_conf), (_, target_conf) = domains
    check_source_and_target_consistency(source_name=source_conf.env,
                                        target_name=target_conf.env)

    return args
