# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Methods for loading and processing e-SNLI data."""

import os
from typing import Mapping, Sequence
from corr_faith.experiments.dataset_specific import classification_datasets
import pandas as pd


# Expected number of examples in each e-SNLI split, keyed by the same split
# names ("train", "dev", "test") that `ESNLIDataset.load_data_splits` returns.
# Note that dev and test have different sizes (9842 vs. 9824).
DATA_SPLIT_COUNTS = {
    "train": 549367,
    "dev": 9842,
    "test": 9824,
}


class ESNLIDataset(
    classification_datasets.ClassificationDatasetWithExplanation
):
  """Class for working with data from e-SNLI.

  SNLI (Stanford Natural Language Inference) is a dataset on natural language
  inference: the task of determining the inference relation between two
  texts: entailment, contradiction, or neutral
  (https://nlp.stanford.edu/projects/snli/).

  e-SNLI adds human-written explanations for each NLI example.

  Full description of the dataset: https://github.com/OanaMariaCamburu/e-SNLI
  """

  def __init__(
      self, keys_to_intervene_on: Sequence[str] = ("Sentence1", "Sentence2")
  ):
    """Initializes the dataset wrapper.

    Args:
      keys_to_intervene_on: Column names returned by the
        `keys_to_intervene_on` property. Defaults to the two sentence columns
        ("Sentence1", "Sentence2") used by `problem_instance_template`.
    """
    super().__init__()
    self._keys_to_intervene_on = keys_to_intervene_on

  def load_data_splits(self, base_data_dir: str) -> Mapping[str, pd.DataFrame]:
    """Loads the train, dev, and test splits from the e-SNLI CSV files.

    Args:
      base_data_dir: Directory that contains the `e-SNLI/dataset/` folder.

    Returns:
      A dict with keys "train", "dev", and "test" mapping to DataFrames read
      from the corresponding CSV files. The training split is stored as two
      CSV shards, which are concatenated into a single DataFrame.
    """
    dataset_dir = os.path.join(base_data_dir, "e-SNLI", "dataset")

    def read_split_csv(filename: str) -> pd.DataFrame:
      # Reads one CSV file from the e-SNLI dataset directory.
      return pd.read_csv(os.path.join(dataset_dir, filename))

    # Each training shard is indexed from 0. `ignore_index=True` gives the
    # combined frame a unique 0..n-1 index instead of duplicate labels, so
    # label-based access (e.g. `.loc`) stays unambiguous.
    train_df = pd.concat(
        [
            read_split_csv("esnli_train_1.csv"),
            read_split_csv("esnli_train_2.csv"),
        ],
        ignore_index=True,
    )
    dev_df = read_split_csv("esnli_dev.csv")
    test_df = read_split_csv("esnli_test.csv")

    return dict(
        train=train_df,
        dev=dev_df,
        test=test_df,
    )

  @property
  def dataset_unique_id(self) -> int:
    """Integer identifier for this dataset (fixed at 0)."""
    return 0

  @property
  def start_of_next_example_str(self) -> str:
    """String that begins each example; the prefix of the instance template."""
    return "TEXT:"

  @property
  def class_labels(self) -> Sequence[str]:
    """The three NLI labels, in a fixed order."""
    return ("entailment", "neutral", "contradiction")

  def describe_example(self, include_explanation: bool = True) -> str:
    """Returns a natural-language description of the task format.

    Args:
      include_explanation: If True, also describes the "EXPLANATION" field.

    Returns:
      The task description string.
    """
    prompt_prefix = (
        'An example consists of a pair of statements, "TEXT" and "HYPOTHESIS".'
        ' The task is to label each pair with a "JUDGEMENT": given the text, is'
        ' the hypothesis definitely true ("entailment"), maybe true'
        ' ("neutral"), or definitely false ("contradiction")?'
    )
    if include_explanation:
      prompt_prefix += (
          ' "EXPLANATION" explains why the selected judgement is chosen.'
      )
    return prompt_prefix

  @property
  def problem_instance_template(self) -> str:
    """Format string for an example, filled from the raw CSV column names."""
    return "TEXT: {Sentence1}\nHYPOTHESIS: {Sentence2}"

  @property
  def problem_instance_template_modified_format(self) -> str:
    """Format string for an example, filled from `get_problem_instance`."""
    return "TEXT: {text}\nHYPOTHESIS: {hypothesis}"

  @property
  def problem_class_prefix(self) -> str:
    """Prefix that introduces the label in a formatted example."""
    return "JUDGEMENT:"

  @property
  def problem_class_template(self) -> str:
    """Format string for the label line, filled from the `gold_label` column."""
    return self.problem_class_prefix + " {gold_label}"

  @property
  def problem_explanation_prefix(self) -> str:
    """Prefix that introduces the explanation in a formatted example."""
    return "EXPLANATION:"

  @property
  def problem_explanation_template(self) -> str:
    """Format string for the explanation line, from the `Explanation_1` column."""
    return self.problem_explanation_prefix + " {Explanation_1}"

  @property
  def keys_to_intervene_on(self) -> Sequence[str]:
    """Column names configured at construction, returned as a list."""
    # Pandas requires lists as indices, not tuples.
    return list(self._keys_to_intervene_on)

  def get_true_label(self, row) -> str:
    """Returns the gold label of `row` (its "gold_label" column)."""
    return row["gold_label"]

  def get_problem_instance(self, row) -> Mapping[str, str]:
    """Maps a raw row to the fields of the modified-format template.

    Args:
      row: A record indexable by column name (e.g. a DataFrame row) with
        "Sentence1" and "Sentence2" entries.

    Returns:
      A dict with keys "text" and "hypothesis", matching the placeholders of
      `problem_instance_template_modified_format`.
    """
    return dict(
        text=row["Sentence1"],
        hypothesis=row["Sentence2"],
    )
