# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Methods for loading and processing ComVE data."""

import os
from typing import Mapping, Sequence
from corr_faith.experiments.dataset_specific import classification_datasets
import numpy as np
import pandas as pd


# Expected number of examples in each split once invalid examples (duplicate
# sentence pairs, or pairs where neither sentence matches the recorded false
# sentence) have been removed.
DATA_SPLIT_COUNTS = dict(
    train=9995,
    dev=997,
    test=999,
)


def merge_dfs_on_index(df_a: pd.DataFrame, df_b: pd.DataFrame) -> pd.DataFrame:
  """Merge two dataframes on index, ensuring their indices correspond exactly.

  Args:
    df_a: Left dataframe.
    df_b: Right dataframe. Must have the same number of rows and the same set
      of index labels as `df_a`, and no column names in common with it.

  Returns:
    A dataframe containing the columns of `df_a` followed by those of `df_b`,
    with one row per index label.

  Raises:
    ValueError: If the dataframes differ in length or in their sets of index
      labels, or (raised by pandas) if any column name appears in both.
    pandas.errors.MergeError: If either index contains duplicate labels
      (enforced by `validate="one_to_one"`).
  """
  # Explicit checks instead of `assert`: asserts are stripped under
  # `python -O`, which would silently disable these safety checks.
  if len(df_a) != len(df_b):
    raise ValueError(
        "Cannot merge dataframes of different lengths:"
        f" {len(df_a)} vs. {len(df_b)}."
    )
  labels_a = set(df_a.index)
  labels_b = set(df_b.index)
  if labels_a != labels_b:
    only_a = labels_a - labels_b
    only_b = labels_b - labels_a
    raise ValueError(
        "Cannot merge dataframes whose index labels differ:"
        f" {len(only_a)} only in the left (e.g. {list(only_a)[:5]}),"
        f" {len(only_b)} only in the right (e.g. {list(only_b)[:5]})."
    )
  # With identical label sets, an outer join keeps every row. The join never
  # introduces missing values.
  return pd.merge(
      left=df_a,
      right=df_b,
      left_index=True,
      right_index=True,
      how="outer",
      validate="one_to_one",
      suffixes=(False, False),  # Raise an exception if any columns overlap.
  )


def merge_all_dfs_on_index(
    dfs: Sequence[pd.DataFrame],
) -> pd.DataFrame:
  """Merge multiple dataframes on index, ensuring their indices match exactly.

  The dataframes are merged pairwise from left to right with
  `merge_dfs_on_index`. Every dataframe must therefore share the first one's
  index labels, and no column name may appear in more than one dataframe.

  Args:
    dfs: Non-empty sequence of dataframes to merge.

  Returns:
    The merged dataframe. If `dfs` has exactly one element, that dataframe
    itself is returned (not a copy).

  Raises:
    ValueError: If `dfs` is empty.
    Any error raised by `merge_dfs_on_index` propagates unchanged.
  """
  if not dfs:
    # Previously this surfaced as an opaque IndexError from `dfs[0]`.
    raise ValueError("merge_all_dfs_on_index requires at least one dataframe.")
  merged_df = dfs[0]
  for df in dfs[1:]:
    merged_df = merge_dfs_on_index(merged_df, df)
  return merged_df


def _load_comve_split(
    split_dir: str,
    subtask_a_data_file: str,
    subtask_a_answers_file: str,
    subtask_c_data_file: str,
    subtask_c_answers_file: str,
) -> pd.DataFrame:
  """Load one ComVE split by joining its subtask A and C data and answers.

  Args:
    split_dir: Directory holding this split's CSV files.
    subtask_a_data_file: Subtask A statements. The first row holds column
      names and the first column is used as the index.
    subtask_a_answers_file: Subtask A gold labels, read without a header row.
    subtask_c_data_file: Subtask C statements, read like the subtask A data.
    subtask_c_answers_file: Subtask C reference explanations, read without a
      header row.

  Returns:
    One dataframe holding the columns of all four files, aligned on index.
  """

  def read(file_name, **read_csv_kwargs):
    return pd.read_csv(os.path.join(split_dir, file_name), **read_csv_kwargs)

  return merge_all_dfs_on_index([
      # Subtask data files give column names in their first row.
      read(subtask_a_data_file, index_col=0),
      # Answer files have no header row, so their columns are named
      # explicitly. NOTE(review): the files presumably also carry a leading
      # example-id column. When `names` lists fewer columns than the file has,
      # pandas uses the surplus leading column as the index. That is what lets
      # these frames align with the `index_col=0` data frames. Confirm against
      # the files.
      read(subtask_a_answers_file, names=["false_idx"]),
      read(subtask_c_data_file, index_col=0),
      read(subtask_c_answers_file, names=["exp0", "exp1", "exp2"]),
  ])


def _drop_invalid_comve_examples(df: pd.DataFrame) -> pd.DataFrame:
  """Return a copy of `df` without examples whose label cannot be trusted.

  An example is dropped when its two sentences are identical, or when neither
  sentence equals the recorded `FalseSent`.

  Args:
    df: A merged split with `sent0`, `sent1`, `FalseSent` and `false_idx`
      columns.

  Returns:
    The filtered dataframe; `df` itself is not modified.

  Raises:
    ValueError: If, after filtering, `false_idx` disagrees with which sentence
      equals `FalseSent`.
  """
  is_duplicate = df["sent0"] == df["sent1"]
  # Test row 491 is unmatched:
  # "he was swimming in Hawaii" vs. "he was swimming to Hawaii"
  is_unmatched = (df["sent0"] != df["FalseSent"]) & (
      df["sent1"] != df["FalseSent"]
  )
  is_invalid = is_duplicate | is_unmatched
  # Drop by index label rather than by np.where positions: the DF index may
  # differ from integer location.
  df = df.drop(index=df[is_invalid].index)
  # `false_idx` should be 1 exactly when sentence 1 is the false one
  # (bool == int comparison: True == 1, False == 0). Checked explicitly rather
  # than with `assert`, which is stripped under `python -O`.
  labels_match = (df["FalseSent"] == df["sent1"]) == df["false_idx"]
  if not np.all(labels_match):
    raise ValueError(
        f"{int((~labels_match).sum())} ComVE examples have a `false_idx` that"
        " does not match `FalseSent`."
    )
  return df


class ComVEDataset(
    classification_datasets.ClassificationDatasetWithExplanation
):
  """Class for working with data from ComVE.

  Commonsense Validation and Explanation (ComVE) tests whether a system can
  differentiate natural language statements that make sense from those that do
  not.

  We use task A (classifying which of two statements violates common sense),
  paired with the ground truth justifications from task C (explaining why the
  chosen sentence violates common sense).

  Full description of the dataset:
  https://github.com/wangcunxiang/SemEval2020-Task4-Commonsense-Validation-and-Explanation
  """

  def load_data_splits(self, base_data_dir: str) -> Mapping[str, pd.DataFrame]:
    """Load and filter the train, dev and test splits.

    Args:
      base_data_dir: Directory containing the
        `SemEval2020-Task4-Commonsense-Validation-and-Explanation` checkout.

    Returns:
      A dict mapping "train", "dev" and "test" to dataframes whose columns
      include `sent0`, `sent1`, `false_idx`, `FalseSent`, `exp0`, `exp1` and
      `exp2`. Duplicate or unmatched sentence pairs have been removed.

    Raises:
      ValueError: If the per-split dataframes cannot be aligned, or if labels
        are inconsistent after filtering.
    """
    all_data_dir = os.path.join(
        base_data_dir,
        "SemEval2020-Task4-Commonsense-Validation-and-Explanation/ALL data",
    )
    # Per split: (directory, subtask A data, subtask A answers,
    # subtask C data, subtask C answers).
    # NOTE(review): "Training  Data" contains a double space; presumably this
    # mirrors the upstream directory name, so confirm before changing it.
    split_files = {
        "train": (
            "Training  Data",
            "subtaskA_data_all.csv",
            "subtaskA_answers_all.csv",
            "subtaskC_data_all.csv",
            "subtaskC_answers_all.csv",
        ),
        "dev": (
            "Dev Data",
            "subtaskA_dev_data.csv",
            "subtaskA_gold_answers.csv",
            "subtaskC_dev_data.csv",
            "subtaskC_gold_answers.csv",
        ),
        "test": (
            "Test Data",
            "subtaskA_test_data.csv",
            "subtaskA_gold_answers.csv",
            "subtaskC_test_data.csv",
            "subtaskC_gold_answers.csv",
        ),
    }
    splits = {}
    for split_name, (split_dir_name, *file_names) in split_files.items():
      split_df = _load_comve_split(
          os.path.join(all_data_dir, split_dir_name), *file_names
      )
      splits[split_name] = _drop_invalid_comve_examples(split_df)
    return splits

  @property
  def dataset_unique_id(self) -> int:
    return 2

  @property
  def start_of_next_example_str(self) -> str:
    # Alternative (currently unused):
    # classification_datasets.FEWSHOT_EXAMPLE_SEP + "SENTENCE 0:"
    return "SENTENCE 0:"

  @property
  def class_labels(self) -> Sequence[str]:
    # Values of `false_idx` as strings: the index of the false sentence.
    return ("0", "1")

  def describe_example(self, include_explanation: bool = True) -> str:
    """Return a natural-language description of the example format.

    Args:
      include_explanation: Whether to also describe the "EXPLANATION" field.
    """
    prompt_prefix = (
        'An example consists of a pair of sentences, "SENTENCE 0" and "SENTENCE'
        ' 1". One of these sentences violates common sense. The task is to'
        ' predict which one violated common sense: this is the "FALSE'
        ' SENTENCE", either 0 or 1.'
    )
    if include_explanation:
      prompt_prefix += (
          ' "EXPLANATION" explains why the selected sentence is chosen.'
      )
    return prompt_prefix

  @property
  def problem_instance_template(self) -> str:
    # Formatted with the raw dataframe column names.
    return "SENTENCE 0: {sent0}\nSENTENCE 1: {sent1}"

  @property
  def problem_instance_template_modified_format(self) -> str:
    # Formatted with the keys produced by `get_problem_instance`.
    return "SENTENCE 0: {sentence_0}\nSENTENCE 1: {sentence_1}"

  @property
  def problem_class_prefix(self) -> str:
    # NOTE: Numbers don't seem to be tokenized with leading spaces.
    # NOTE(review): the space separating this prefix from the label lives in
    # `problem_class_template` (" {false_idx}"), not here. If callers treat
    # the token right after this prefix as the class label, that token may be
    # the space rather than the label. Confirm which placement is intended.
    return "FALSE SENTENCE:"

  @property
  def problem_class_template(self) -> str:
    return self.problem_class_prefix + " {false_idx}"

  @property
  def problem_explanation_prefix(self) -> str:
    return "EXPLANATION:"

  @property
  def problem_explanation_template(self) -> str:
    # Only the first of the three reference explanations is used.
    return self.problem_explanation_prefix + " {exp0}"

  @property
  def keys_to_intervene_on(self) -> Sequence[str]:
    return ["sent0", "sent1"]

  def get_true_label(self, row) -> str:
    """Return the gold label (index of the false sentence) as a string."""
    return str(row["false_idx"])

  def get_problem_instance(self, row) -> Mapping[str, str]:
    """Map a dataframe row to `problem_instance_template_modified_format` keys."""
    return dict(
        sentence_0=row["sent0"],
        sentence_1=row["sent1"],
    )
