# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Methods for working with classification datasets."""
import functools

from corr_faith.experiments.dataset_specific import classification_datasets
from corr_faith.experiments.dataset_specific import comve
from corr_faith.experiments.dataset_specific import ecqa
from corr_faith.experiments.dataset_specific import esnli
from ml_collections import config_dict


@functools.cache
def dataset_from_string(
    dataset_name: str,
    **kwargs,
) -> classification_datasets.ClassificationDatasetWithExplanation:
  """Builds the classification dataset identified by `dataset_name`.

  Results are memoized with `functools.cache`, so repeated calls with the same
  arguments return the same dataset object. Because of this, every value in
  `kwargs` must be hashable.

  Args:
    dataset_name: One of "esnli", "comve" or "ecqa".
    **kwargs: Keyword arguments forwarded to `esnli.ESNLIDataset`. They are
      ignored for "comve" and "ecqa".

  Returns:
    The dataset instance for `dataset_name`.

  Raises:
    ValueError: If `dataset_name` is not a recognized dataset.
  """
  if dataset_name == "esnli":
    return esnli.ESNLIDataset(**kwargs)
  if dataset_name == "comve":
    # ComVE takes no constructor arguments; `kwargs` is intentionally unused.
    return comve.ComVEDataset()
  if dataset_name == "ecqa":
    # ECQA takes no constructor arguments; `kwargs` is intentionally unused.
    return ecqa.ECQADataset()
  raise ValueError(f"Unknown dataset: {dataset_name}")


def dataset_from_config(
    config: config_dict.ConfigDict,
) -> classification_datasets.ClassificationDatasetWithExplanation:
  """Builds the classification dataset selected by an experiment config.

  No memoization is applied; the dataset constructor is invoked on every call.

  Args:
    config: Config whose `dataset` field names the dataset ("esnli", "comve"
      or "ecqa"). When `dataset` is "esnli", `config.esnli_kwargs` is unpacked
      as keyword arguments to `esnli.ESNLIDataset`; it is not read otherwise.

  Returns:
    The dataset instance selected by `config.dataset`.

  Raises:
    ValueError: If `config.dataset` is not a recognized dataset.
  """
  selected = config.dataset
  if selected == "esnli":
    # Only ESNLI is configurable, so its kwargs are read lazily here.
    return esnli.ESNLIDataset(**config.esnli_kwargs)
  if selected == "comve":
    return comve.ComVEDataset()
  if selected == "ecqa":
    return ecqa.ECQADataset()
  raise ValueError(f"Unknown dataset: {selected}")
