# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Methods for generating random counterfactual interventions."""

import dataclasses
import string
import time
from typing import Callable, Tuple
from absl import logging
from corr_faith.experiments import classical_nlp
from corr_faith.experiments import performance_logging
from corr_faith.experiments.dataset_specific import classification_datasets
from nltk.corpus import wordnet as wn
from nltk.stem import snowball
import numpy as np
import pandas as pd
import spacy
import tree


def produce_random_interventions(
    rng: np.random.Generator,
    nlp: spacy.Language,
    original_sentences: dict[str, str],
    n_interventions: int,
) -> list[dict]:
  """Modify one of "sentences" by inserting random modifiers.

  Each intervention inserts one word directly before a noun or verb of one of
  the original sentences: a word from the `wn.ADJ` pool before a noun, or a word
  from the `wn.ADV` pool before a verb. A word is never reused within one call
  and is never a substring of the sentence it is inserted into.

  Args:
    rng: Random number generator, to sample words and positions to intervene on.
    nlp: Spacy NLP object, to tag parts of speech.
    original_sentences: Mapping from key (column name in the dataset) to the
      original sentence.
    n_interventions: Number of interventions to produce.

  Returns:
    List of dictionaries, each representing a modification. Each dictionary
    contains:
      sentences: Mapping from key to sentences.
      modified_sentence_key: Key of the modified sentence.
      inserted_word: Word that was inserted, in the form it appears in the
        modified sentence (capitalized at the start of a sentence, underscores
        replaced by spaces).
      inserted_position: Index, in the spaCy tokenization of the original
        sentence, of the token before which the word was inserted.
    The list is empty if the original sentences contain no nouns or verbs.

  Raises:
    ValueError: If no unused candidate word is left for a chosen position.
  """
  words_by_part_of_speech = classical_nlp.get_wordnet_words_by_part_of_speech(
      parts_of_speech=(wn.ADJ, wn.ADV),
      only_alphabetic=True,
  )
  # Sample from sorted pools rather than from set iteration order: the order of
  # a set of strings depends on the per-process hash seed, which would make the
  # chosen words differ between runs even with an identically seeded rng.
  candidate_pools = {
      "NOUN": sorted(words_by_part_of_speech[wn.ADJ]),
      "VERB": sorted(words_by_part_of_speech[wn.ADV]),
  }
  docs = tree.map_structure(nlp, original_sentences)
  noun_and_verb_positions = []
  for key, doc in docs.items():
    for token_idx, token in enumerate(doc):
      if token.pos_ in candidate_pools:
        noun_and_verb_positions.append([key, token_idx])

  if not noun_and_verb_positions:
    # Example from ComVE:
    # "{'sent0': 'It is cold and hot outside', 'sent1': 'It is cold outside'}"
    logging.warning(
        'Original sentences have no nouns or verbs: "%s"', original_sentences
    )
    return []
  # The number of times each position is chosen should be within 1 of
  # (n_interventions/n_positions): every position is used the same number of
  # times, and the remainder is spread over distinct random positions.
  n_positions = len(noun_and_verb_positions)
  intervention_positions = (
      noun_and_verb_positions * (n_interventions // n_positions)
  ) + ([
      noun_and_verb_positions[idx]
      for idx in rng.choice(
          n_positions,
          size=n_interventions % n_positions,
          replace=False,
      )
  ])

  modified_examples = []
  previously_chosen_words = set()

  for key, token_idx in intervention_positions:
    original_sentence = original_sentences[key]
    doc = docs[key]
    # Only NOUN and VERB positions were collected, so this lookup succeeds.
    pool = candidate_pools[doc[token_idx].pos_]
    # Filter up front instead of rejection sampling, so that an exhausted pool
    # raises instead of looping forever.
    candidates = [
        word
        for word in pool
        if word not in previously_chosen_words
        and word not in original_sentence
    ]
    if not candidates:
      raise ValueError(
          "No unused candidate word left to insert before token"
          f" {token_idx} of {key!r}."
      )
    word_to_insert = candidates[rng.integers(len(candidates))]
    previously_chosen_words.add(word_to_insert)
    prefix = doc[:token_idx].text
    postfix = doc[token_idx:].text
    # Fix capitalization when inserting at the start of the sentence.
    if token_idx == 0:
      word_to_insert = word_to_insert.capitalize()
      if doc[0].pos_ != "PROPN":
        # Lowercase only the first token and keep its original trailing
        # whitespace, so the spacing of the rest of the sentence is unchanged
        # (e.g. "Dogs, cats" must not become "dogs , cats").
        postfix = doc[0].text_with_ws.lower() + doc[1:].text
    else:
      prefix += " "
    word_to_insert = word_to_insert.replace(
        "_", " "
    )  # Wordnet uses underscores to denote spaces
    modified_example_text = f"{prefix}{word_to_insert} {postfix}"
    modified_sentences = original_sentences.copy()
    modified_sentences[key] = modified_example_text
    modified_examples.append(
        dict(
            sentences=modified_sentences,
            modified_sentence_key=key,
            inserted_word=word_to_insert,
            inserted_position=token_idx,
        )
    )
  assert modified_examples
  return modified_examples


def get_rng(
    data_idx: int,
    dataset: classification_datasets.ClassificationDatasetWithExplanation,
) -> np.random.Generator:
  """Build a generator seeded by the example index and the dataset's id.

  Calling this twice with the same `data_idx` and a dataset with the same
  `dataset_unique_id` yields generators that produce the same random stream.
  """
  seed_entropy = [data_idx, dataset.dataset_unique_id]
  return np.random.default_rng(seed=seed_entropy)


def generate_interventions(
    dataset: classification_datasets.ClassificationDatasetWithExplanation,
    eval_df: pd.DataFrame,
    n_interventions: int = 16,
) -> pd.DataFrame:
  """Build a dataframe of intervened copies of the rows of `eval_df`.

  Args:
    dataset: Supplies `keys_to_intervene_on` (the columns whose sentences may be
      modified) and the id used to seed the per-example rng.
    eval_df: Examples to intervene on; its index values are stored as
      `data_idx`.
    n_interventions: Number of interventions requested per example.

  Returns:
    Dataframe with one row per intervention: the original row with the modified
    sentences applied, plus the columns `data_idx`, `intervention_idx`,
    `inserted_word`, `inserted_position` and `modified_sentence_key`.
  """
  spacy_pipeline = classical_nlp.get_spacy_nlp()
  total = len(eval_df)
  intervened_rows = []
  for position, (data_idx, example) in enumerate(eval_df.iterrows()):
    logging.info(
        "Generating interventions for row %d (%d/%d)...",
        data_idx,
        position,
        total,
    )
    # Seed per (example, dataset) so that all evaluated configurations use the
    # same interventions.
    example_rng = get_rng(data_idx, dataset)
    sentences_to_modify = example[dataset.keys_to_intervene_on].to_dict()
    interventions = produce_random_interventions(
        rng=example_rng,
        nlp=spacy_pipeline,
        original_sentences=sentences_to_modify,
        n_interventions=n_interventions,
    )

    ### Perturbation
    # Each output row is the original row, overlaid with the modified
    # sentences, followed by metadata describing the intervention.
    intervened_rows.extend(
        {
            **example.to_dict(),
            **intervention["sentences"],
            "data_idx": data_idx,
            "intervention_idx": intervention_idx,
            "inserted_word": intervention["inserted_word"],
            "inserted_position": intervention["inserted_position"],
            "modified_sentence_key": intervention["modified_sentence_key"],
        }
        for intervention_idx, intervention in enumerate(interventions)
    )
  return pd.DataFrame(intervened_rows)


def evaluate_counterfactual_faithfulness(
    predict_fewshot: Callable[
        [pd.Series, str], classification_datasets.Prediction
    ],
    dataset: classification_datasets.ClassificationDatasetWithExplanation,
    sample_fewshot_prompt_base: Callable[[np.random.Generator], str],
    eval_df: pd.DataFrame,
    n_samples_per_example: int = 1,
    raise_on_generation_failure: bool = False,
    interventions_df: pd.DataFrame | None = None,
    n_interventions: int = 16,
    log_memory_usage_every_n_examples: int = 10,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
  """Evaluate explanation faithfulness using the counterfactual test.

  For every example in `eval_df` and every sample, a base prediction is made on
  the original row. Then a prediction is made on each intervened version of the
  row (one inserted word each), recording whether the predicted label changed
  and whether the new explanation mentions the inserted word, both by
  case-insensitive substring match and by stem match.

  Args:
    predict_fewshot: Called as `predict_fewshot(row, prompt_base)`. Returns a
      prediction dataclass from which `pred_label` and `explanation` are read. A
      `ValueError` raised by it is treated as a generation failure.
    dataset: Provides the true label and problem text for a row, the keys to
      intervene on, and the id used to seed the per-example rng.
    sample_fewshot_prompt_base: Builds the few-shot prompt from an rng. It is
      given an rng seeded per example, so all samples of an example are offered
      the same seed.
    eval_df: Examples to evaluate; its index values are used as `data_idx`.
    n_samples_per_example: Number of times each example is evaluated.
    raise_on_generation_failure: If True, re-raise a `ValueError` from
      `predict_fewshot` instead of logging it and skipping.
    interventions_df: Optional precomputed interventions. The columns
      `data_idx`, `inserted_word`, `modified_sentence_key` and
      `intervention_idx` are read from it, and each of its rows is passed to
      `predict_fewshot`. If given, only examples whose index appears in its
      `data_idx` column are evaluated. If None, interventions are generated on
      the fly.
    n_interventions: Number of interventions requested per example when
      `interventions_df` is None.
    log_memory_usage_every_n_examples: Log memory statistics every this many
      examples. A value of 0 disables this logging.

  Returns:
    Tuple `(accuracy_df, interventions_df)`. The first has one row per
    (example, sample), including those whose base generation failed. The second
    has one row per (example, sample, intervention) for which both the base and
    the modified generation succeeded.

  Raises:
    ValueError: If `raise_on_generation_failure` is True and `predict_fewshot`
      raises it.
  """
  nlp = classical_nlp.get_spacy_nlp()
  stemmer = snowball.SnowballStemmer("english")
  logging.info("Using stemmer: %s", stemmer)
  accuracy_rows = []
  intervention_rows = []
  if interventions_df is not None:
    # Evaluate only examples with interventions.
    data_idxs_with_interventions = interventions_df["data_idx"].unique()
    orig_len = len(eval_df)
    # We don't need to explicitly cast unique to a set; pandas will choose which
    # implementation to use for isin:
    # https://github.com/pandas-dev/pandas/blob/5da9eb726c915412928ee44a0631a74bda614556/pandas/core/algorithms.py#L552
    eval_df = eval_df.loc[eval_df.index.isin(data_idxs_with_interventions)]
    logging.info(
        "Filtered %d examples to %d examples with interventions.",
        orig_len,
        len(eval_df),
    )
  n_examples = len(eval_df)
  logging.info("Evaluating faithfulness on %d examples.", n_examples)
  for count, (data_idx, row) in enumerate(eval_df.iterrows()):
    logging.info(
        "Evaluating %d samples for row %d (%d/%d)...",
        n_samples_per_example,
        data_idx,
        count,
        n_examples,
    )
    # A falsy interval disables memory logging (and avoids a modulo by zero).
    if (
        log_memory_usage_every_n_examples
        and count % log_memory_usage_every_n_examples == 0
    ):
      performance_logging.log_gpu_and_cpu_mem_stats()
    for sample_idx in range(n_samples_per_example):
      ### Initial Prediction
      true_label = dataset.get_true_label(row)
      # Use the same fewshot prompt for all samples of a given example.
      prompt_base = sample_fewshot_prompt_base(get_rng(data_idx, dataset))
      base_problem = dataset.get_problem_instance(row)
      accuracy_row = dict(
          data_idx=data_idx,
          sample_idx=sample_idx,
          base_problem=base_problem,
          true_label=true_label,
          prompt_base_char_len=len(prompt_base),
      )
      try:
        # This will fail if the LM doesn't generate something matching format.
        lm_generation_start_time = time.perf_counter()
        base_prediction = predict_fewshot(row, prompt_base)
        lm_generation_time = time.perf_counter() - lm_generation_start_time
        accuracy = bool(str(base_prediction.pred_label) == str(true_label))
        # Should we just include all the rows?
        accuracy_row.update(
            dict(
                **{
                    "base_" + key: value
                    for key, value in dataclasses.asdict(
                        base_prediction
                    ).items()
                },
                accuracy=accuracy,
                text_generation_wall_time=lm_generation_time,
            )
        )
      except ValueError as e:
        logging.warning(
            "Failed to generate prediction for row %d: %s",
            data_idx,
            e,
        )
        # Don't explicitly add NA values here; they'll be set to NA
        # automatically by pd.concat. The full DF should only miss columns if
        # 100% of generations fail for all work units.
        if raise_on_generation_failure:
          raise
        continue  # Note that "finally" is executed even after "continue".
      finally:
        # Record the accuracy row whether or not the generation succeeded.
        accuracy_row["timestamp"] = pd.Timestamp.now()
        accuracy_rows.append(accuracy_row)

      ### Perturbation
      if interventions_df is None:
        modified_rows = []
        modifications = produce_random_interventions(
            # Ensure all evaluated configurations use the same interventions.
            rng=get_rng(data_idx, dataset),
            nlp=nlp,
            original_sentences=row[dataset.keys_to_intervene_on].to_dict(),
            n_interventions=n_interventions,
        )
        for intervention_idx, modification in enumerate(modifications):
          modified_row = row.copy()
          modified_row.update(modification["sentences"])
          modified_row["inserted_word"] = modification["inserted_word"]
          modified_row["modified_sentence_key"] = modification[
              "modified_sentence_key"
          ]
          modified_row["intervention_idx"] = intervention_idx
          modified_rows.append(modified_row)
        modified_rows = pd.DataFrame(modified_rows)
      else:
        modified_rows = interventions_df[
            interventions_df["data_idx"] == data_idx
        ]
      for _, modified_row in modified_rows.iterrows():
        inserted_word = modified_row["inserted_word"]
        modified_sentence_key = modified_row["modified_sentence_key"]
        intervention_idx = modified_row["intervention_idx"]
        lm_generation_start_time = time.perf_counter()
        try:
          modified_prediction = predict_fewshot(modified_row, prompt_base)
        except ValueError as e:
          logging.info("Error: %s", repr(e))
          if raise_on_generation_failure:
            raise
          else:
            continue
        lm_generation_time = time.perf_counter() - lm_generation_start_time
        # PD
        pred_label_changed = (
            modified_prediction.pred_label != base_prediction.pred_label
        )
        # Evaluate faithfulness without and with stemming.
        # Use pd.isna rather than an identity check against pd.NA so that NaN
        # and None (e.g. a missing value in a dataframe loaded from disk) are
        # also treated as missing instead of failing on `.lower()`.
        if pd.isna(inserted_word) or pd.isna(modified_prediction.explanation):
          modified_explanation_mentions_inserted_word = pd.NA
          stemmed_explanation_words = pd.NA
          stemmed_inserted_word = pd.NA
          stemmed_modified_explanation_mentions_inserted_word = pd.NA
        else:
          modified_explanation_mentions_inserted_word = (
              inserted_word.lower() in modified_prediction.explanation.lower()
          )
          # Truncate trailing punctuation in each word
          stemmed_explanation_words = {
              stemmer.stem(word.rstrip(string.punctuation))
              for word in modified_prediction.explanation.split()
          }
          stemmed_inserted_word = stemmer.stem(inserted_word).rstrip(
              string.punctuation
          )
          stemmed_modified_explanation_mentions_inserted_word = (
              stemmed_inserted_word in stemmed_explanation_words
          )
        intervention_row = dict(
            data_idx=data_idx,
            sample_idx=sample_idx,
            intervention_idx=intervention_idx,
            base_problem=base_problem,
            true_label=true_label,
            **{
                "base_" + key: value
                for key, value in dataclasses.asdict(base_prediction).items()
            },
            inserted_word=inserted_word,
            stemmed_inserted_word=stemmed_inserted_word,
            stemmed_explanation_words=str(stemmed_explanation_words),
            modified_sentence_key=modified_sentence_key,
            modified_problem=dataset.get_problem_instance(modified_row),
            **{
                "modified_" + key: value
                for key, value in dataclasses.asdict(
                    modified_prediction
                ).items()
            },
            pred_label_changed=pred_label_changed,
            modified_explanation_mentions_inserted_word=modified_explanation_mentions_inserted_word,
            stemmed_modified_explanation_mentions_inserted_word=stemmed_modified_explanation_mentions_inserted_word,
            text_generation_wall_time=lm_generation_time,
        )
        intervention_row["timestamp"] = pd.Timestamp.now()
        # Log only the first intervention of the first sample at INFO level to
        # keep the logs readable.
        if sample_idx == 0 and intervention_idx == 0:
          level = logging.INFO
        else:
          level = logging.DEBUG
        logging.log(level, "%s", intervention_row)
        intervention_rows.append(intervention_row)
  return pd.DataFrame(accuracy_rows), pd.DataFrame(intervention_rows)
