# coding=utf-8
# Copyright 2022 The Conceptual Learning Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized

from conceptual_learning.tools.metrics import config_utils


class ConfigUtilsTest(parameterized.TestCase):
  """Tests for `config_utils` parsing/matching of config overrides.

  All cases exercise the 't5x' config name.
  """

  @parameterized.named_parameters(
      # A bare 'EXP_' prefix yields only the default num_layers for t5x.
      ('simple_case', 'EXP_', {
          'num_layers': 6
      }),
      ('more_complicated', "EXP_{'num_layers': 24, 'seed': 0}", {
          'num_layers': 24,
          'seed': 0
      }),
      # Explicitly passing the default num_layers (6) is still kept as-is.
      ('with_defaults', "EXP_{'num_layers': 6, 'seed': 0}", {
          'num_layers': 6,
          'seed': 0
      }),
  )
  def test_parse_config_overrides(self, trial_name, expected_value):
    """Checks 'EXP_<dict literal>' trial names parse to the expected dict."""
    parsed_overrides = config_utils.parse_config_overrides(trial_name, 't5x')
    self.assertDictEqual(parsed_overrides, expected_value)

  @parameterized.named_parameters(
      ('random_invalid_config_name', 'SOME INVALID CONFIG NAME'),
      ('baseline_name_that_is_not_same_as_config_name', 't5x_no_context'),
  )
  def test_parse_config_overrides_raises_error_if_config_name_is_invalid(
      self, config_name):
    """Checks an unrecognized config name raises ValueError."""
    with self.assertRaisesRegex(ValueError, 'Invalid config name'):
      config_utils.parse_config_overrides('EXP_', config_name)

  @parameterized.named_parameters(
      ('trial_name_empty', ''),
      ('trial_name_does_not_begin_with_EXP', "{'num_layers':6}"),
  )
  def test_parse_config_overrides_raises_error_if_trial_name_is_invalid(
      self, trial_name):
    """Checks trial names lacking the 'EXP_' prefix raise ValueError."""
    with self.assertRaisesRegex(ValueError, 'Invalid baseline variant name'):
      config_utils.parse_config_overrides(trial_name, 't5x')

  # The cases below encode the expected matching semantics: a list value in
  # `allowed_overrides` accepts any of its elements, a scalar value must match
  # exactly, and a key absent from one side is compared using its default.
  # pyformat: disable
  @parameterized.named_parameters(
      ('all_values_match',
       {'num_layers': 12, 'seed': 0},
       {'num_layers': [12, 24], 'seed': 0},
       True),
      ('default_value_matches_allowed_value',
       {'seed': 0},  # Default value of num_layers for t5x is 6.
       {'num_layers': [6, 12, 24], 'seed': 0},
       True),
      ('value_matches_default',
       {'num_layers': 6, 'seed': 0},
       {'seed': 0},
       True),
      ('value_does_not_match_list',
       {'num_layers': 6, 'seed': 0},
       {'num_layers': [12, 24], 'seed': 0},
       False),
      ('value_does_not_match_non_list',
       {'num_layers': 12, 'seed': 1},
       {'num_layers': [12, 24], 'seed': 0},
       False),
      ('key_missing_from_overrides',
       {'num_layers': 12},
       {'num_layers': [12, 24], 'seed': 1},
       False),
      ('key_missing_from_allowed_overrides',
       {'num_layers': 12, 'seed': 1},
       {'num_layers': [12, 24]},
       False),
  )
  # pyformat: enable
  def test_match_config_overrides(self, overrides, allowed_overrides, is_match):
    """Checks `match_config_overrides` against the t5x defaults."""
    self.assertEqual(
        is_match,
        config_utils.match_config_overrides(overrides, allowed_overrides,
                                            't5x'))


if __name__ == '__main__':
  # Run all test cases in this module via the absl test runner.
  absltest.main()
