# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Methods for loading and processing ECQA data."""

import os
from typing import Mapping, Sequence
from absl import logging
from corr_faith.experiments.dataset_specific import classification_datasets
import pandas as pd


# Number of examples in each data split once invalid examples have been
# filtered out.
DATA_SPLIT_COUNTS = dict(
    train=7598,
    dev=1086,
    test=2194,
)


def get_correct_answer_idx_from_row(row: pd.Series) -> int:
  """Get the index of the correct answer, given a row of the ECQA dataset.

  Args:
    row: A row holding the five answer options under the keys "q_op1" through
      "q_op5" and the text of the correct answer under "q_ans".

  Returns:
    The 1-based position of the single option equal to `row["q_ans"]`, or -1
    when more than one option matches the answer.

  Raises:
    ValueError: If none of the options matches the answer. This is raised
      explicitly (rather than asserted) so the check survives `python -O`.
  """
  option_keys = ("q_op1", "q_op2", "q_op3", "q_op4", "q_op5")
  # The dataset uses 1-indexing.
  idx_and_option = [
      (idx, row[op_key]) for idx, op_key in enumerate(option_keys, start=1)
  ]
  answer = row["q_ans"]
  correct_idxs = [idx for idx, option in idx_and_option if option == answer]
  if not correct_idxs:
    raise ValueError(
        f"No option matches the answer! {answer}, {idx_and_option}"
    )
  if len(correct_idxs) > 1:
    logging.warning(
        "Multiple matching indices found! %s, %s", answer, idx_and_option
    )
    # Use -1 rather than nan to keep int type
    return -1
  return correct_idxs[0]


class ECQADataset(classification_datasets.ClassificationDatasetWithExplanation):
  """Dataset wrapper for ECQA (Explanations for CommonsenseQA).

  ECQA consists of multiple choice, natural language common sense questions,
  each paired with a human-written explanation.

  Full description of the dataset: https://github.com/dair-iitd/ECQA-Dataset
  """

  def load_data_splits(
      self, base_data_dir: str, ans_idx_key: str = "q_ans_idx"
  ) -> Mapping[str, pd.DataFrame]:
    """Read the train/dev/test CSV files and attach the answer index column.

    Args:
      base_data_dir: Directory that contains the "ECQA-Dataset" folder.
      ans_idx_key: Name of the column that receives the 1-based index of the
        correct option.

    Returns:
      A mapping from split name ("train", "dev", "test") to its DataFrame,
      with rows whose answer index is -1 removed.
    """
    dataset_dir = os.path.join(base_data_dir, "ECQA-Dataset")
    csv_name_by_split = (
        ("train", "cqa_data_train.csv"),
        ("dev", "cqa_data_val.csv"),
        ("test", "cqa_data_test.csv"),
    )
    splits = {
        split_name: pd.read_csv(os.path.join(dataset_dir, csv_name))
        for split_name, csv_name in csv_name_by_split
    }
    for split_df in splits.values():
      split_df[ans_idx_key] = split_df.apply(
          get_correct_answer_idx_from_row, axis=1
      )
      # -1 is the marker for rows without a unique answer index; drop them.
      is_invalid = split_df[ans_idx_key] == -1
      split_df.drop(index=split_df[is_invalid].index, inplace=True)
    return splits

  @property
  def dataset_unique_id(self) -> int:
    """Integer identifier of this dataset."""
    return 1

  @property
  def start_of_next_example_str(self) -> str:
    """String with which every problem instance begins."""
    return "QUESTION:"

  @property
  def class_labels(self) -> Sequence[str]:
    """The five option numbers, as strings."""
    return tuple(str(label) for label in range(1, 6))

  def describe_example(self, include_explanation: bool = True) -> str:
    """Return a natural language description of the example format.

    Args:
      include_explanation: Whether to append the sentence describing the
        "EXPLANATION" field.

    Returns:
      The description text.
    """
    description = (
        "An example consists of a question followed by five multiple choice"
        " options. The task is to choose the option that makes the most sense"
        ' as answer to the question; this option is labelled as "CORRECT'
        ' OPTION".'
    )
    if not include_explanation:
      return description
    explanation_note = (
        ' "EXPLANATION" explains why the selected option is chosen.'
    )
    return description + explanation_note

  @property
  def problem_instance_template(self) -> str:
    """Format string for a question, keyed by the raw dataset column names."""
    option_lines = [f"OPTION {num}: {{q_op{num}}}" for num in range(1, 6)]
    return "\n".join(["QUESTION: {q_text}", *option_lines])

  @property
  def problem_instance_template_modified_format(self) -> str:
    """Format string keyed by the names used in `get_problem_instance`."""
    option_lines = [f"OPTION {num}: {{op{num}}}" for num in range(1, 6)]
    return "\n".join(["QUESTION: {question}", *option_lines])

  @property
  def problem_class_prefix(self) -> str:
    """Text that precedes the class label."""
    # The prefix has no trailing space; `problem_class_template` supplies the
    # space that separates it from the label.
    # NOTE: Numbers don't seem to be tokenized with leading spaces, which
    # presumably matters for where that space is placed -- verify against the
    # tokenizer in use before changing either string.
    return "CORRECT OPTION:"

  @property
  def problem_class_template(self) -> str:
    """Format string for the class line: prefix, a space, then the label."""
    return f"{self.problem_class_prefix} {{q_ans_idx}}"

  @property
  def problem_explanation_prefix(self) -> str:
    """Text that precedes the explanation."""
    return "EXPLANATION:"

  @property
  def problem_explanation_template(self) -> str:
    """Format string for the explanation line, filled from the `taskB` key."""
    return f"{self.problem_explanation_prefix} {{taskB}}"

  @property
  def keys_to_intervene_on(self) -> Sequence[str]:
    """Row keys to intervene on: only the question text."""
    return ["q_text"]

  def get_true_label(self, row) -> str:
    """Return the row's answer index ("q_ans_idx") as a string."""
    answer_idx = row["q_ans_idx"]
    return str(answer_idx)

  def get_problem_instance(self, row) -> Mapping[str, str]:
    """Map a row onto the keys of the modified-format instance template."""
    instance = {"question": row["q_text"]}
    for num in range(1, 6):
      instance[f"op{num}"] = row[f"q_op{num}"]
    return instance
