# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Methods for assessing the naturalness of counterfactual interventions."""

import warnings
from absl import logging
from corr_faith.experiments.models import tokenization
import numpy as np
import pandas as pd
import transformers


# Prompt template for the naturalness check. It has two `str.format`
# placeholders: `{sent1}` (the original sentence) and `{sent2}` (the same
# sentence with a word added). The prompt asks the model to start its answer
# with "Yes" or "No".
INTERVENTION_NATURALNESS_PROMPT_TEMPLATE = """I'm going to show a sentence, followed by the same sentence with a word added. It's fine if the added word changes the meaning of the sentence. However, I want you to tell me if the second sentence still makes sense with the added word.

Sentence 1: "{sent1}"

Sentence 2: "{sent2}"

Does the second sentence make sense with the added word? Please begin your answer with "Yes" or "No"."""


def ask_yes_no_question(
    pipeline: transformers.TextGenerationPipeline,
    prompt: str,
    top_n_tokens: int = 20,
) -> dict[str, float]:
  """Asks the model a yes/no question and scores its first response token.

  The prompt is sent as a single user message through the tokenizer's chat
  template, and exactly one new token is generated greedily. The scores of
  that single generation step are converted to per-token probabilities, and
  the first `top_n_tokens` rows of the resulting table are bucketed by whether
  the token text contains "yes", "no", or neither/both (case-insensitive).

  Args:
    pipeline: Text generation pipeline providing `tokenizer`, `model` and
      `device`.
    prompt: The question to ask, used verbatim as the user message.
    top_n_tokens: Number of leading rows of the token-probability table to
      inspect.

  Returns:
    A dict with keys "yes", "no" and "invalid" holding the summed probability
    of the inspected tokens in each bucket. Only the inspected tokens are
    counted, so the three values need not sum to 1.
  """
  tokenizer = pipeline.tokenizer
  messages = [dict(role="user", content=prompt)]
  input_ids = tokenizer.apply_chat_template(
      messages,
      add_generation_prompt=True,
      return_tensors="pt",
  )
  with warnings.catch_warnings():
    # Example warning:
    # /usr/local/lib/python3.8/dist-packages/transformers/generation/configuration_utils.py:601:
    # UserWarning: `do_sample` is set to `False`. However, `temperature` is
    # set to `0.6` -- this flag is only used in sample-based generation modes.
    # You should set `do_sample=True` or unset `temperature`.
    warnings.filterwarnings(
        "ignore",
        message=(
            "`do_sample` is set to `False`. However,"
            " `(temperature|top_p|top_k)` is set to"
        ),
    )
    warnings.filterwarnings(
        # The consensus seems to be that defaults are actually correct here:
        # https://www.reddit.com/r/KoboldAI/comments/yz26ol/comment/ix1za2e/
        # https://github.com/suno-ai/bark/issues/402#issuecomment-1672725122
        "ignore",
        message=(
            "The attention mask and the pad token id were not set. As a"
            " consequence, you may observe unexpected behavior. Please pass"
            " your input's `attention_mask` to obtain reliable results."
        ),
    )
    warnings.filterwarnings(
        "ignore",
        message=(
            "Setting `pad_token_id` to `eos_token_id`:None for open-end"
            " generation."
        ),
    )
    outputs = pipeline.model.generate(
        input_ids=input_ids.to(pipeline.device),
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=1,
        do_sample=False,
    )
  # Upcast to float32 before converting: bfloat16 tensors cannot be turned
  # into NumPy arrays directly.
  scores = np.asarray(
      [score.detach().cpu().float().numpy() for score in outputs.scores]
  )
  del outputs  # Ensure we don't leak any memory.
  # Expected layout: (generation steps, batch, vocabulary) with exactly one
  # step and one sequence.
  assert len(scores.shape) == 3
  assert scores.shape[:2] == (1, 1)
  token_prob_df = tokenization.logit_vector_to_token_probs_df(
      tokenizer=tokenizer,
      logit_vector=scores[0, 0],
  )
  # Presumably the helper returns rows sorted by descending probability, so
  # this keeps the most likely tokens -- TODO confirm against the helper.
  # Copy explicitly: the columns added below must land on an independent
  # frame rather than on a slice of `token_prob_df` (SettingWithCopyWarning).
  top_tokens_df = token_prob_df[:top_n_tokens].copy()
  # NOTE(review): these are substring matches, so tokens such as "not" or
  # "know" are also counted as "no".
  contains_yes = top_tokens_df["tokens"].apply(lambda x: "yes" in x.lower())
  contains_no = top_tokens_df["tokens"].apply(lambda x: "no" in x.lower())
  # Some tokens contain both, e.g. #76894: ".YesNo" for Qwen 2.5.
  contains_both = contains_yes & contains_no
  contains_neither = ~(contains_yes | contains_no)
  top_tokens_df["contains_yes"] = contains_yes & ~contains_both
  top_tokens_df["contains_no"] = contains_no & ~contains_both
  top_tokens_df["invalid"] = contains_neither | contains_both
  # Every inspected token must fall into exactly one of the three buckets.
  assert top_tokens_df["contains_yes"].sum() + top_tokens_df[
      "contains_no"
  ].sum() + top_tokens_df["invalid"].sum() == len(top_tokens_df)
  yes_prob = top_tokens_df.loc[top_tokens_df["contains_yes"], "probs"].sum()
  no_prob = top_tokens_df.loc[top_tokens_df["contains_no"], "probs"].sum()
  invalid_prob = top_tokens_df.loc[top_tokens_df["invalid"], "probs"].sum()
  if yes_prob + no_prob <= 0.95:
    logging.warning(
        "Significant probability of invalid response: %s", invalid_prob
    )
  return dict(
      yes=yes_prob,
      no=no_prob,
      invalid=invalid_prob,
  )


def assess_intervention_naturalness(
    pipeline: transformers.TextGenerationPipeline,
    eval_df: pd.DataFrame,
    intervention_df: pd.DataFrame,
    top_n_tokens: int = 20,
) -> pd.DataFrame:
  """Assesses the naturalness of each intervention using an LLM.

  For every row of `eval_df`, the rows of `intervention_df` whose `data_idx`
  equals that row's index label are looked up. Each such intervention is
  turned into a yes/no prompt comparing the original and the modified
  sentence, and the model's yes/no/invalid probabilities are recorded.

  Args:
    pipeline: Text generation pipeline used to answer the yes/no prompts.
    eval_df: Original examples. Its index labels are matched against
      `intervention_df["data_idx"]`.
    intervention_df: Interventions. Must contain the columns `data_idx` and
      `modified_sentence_key`; the latter names the column that holds the
      sentence in both `eval_df` and `intervention_df`.
    top_n_tokens: Forwarded to `ask_yes_no_question`.

  Returns:
    A DataFrame with one row per assessed intervention, containing all fields
    of the intervention row plus `intervention_naturalness_yes_prob`,
    `intervention_naturalness_no_prob`,
    `intervention_naturalness_invalid_prob` and `prompt`. Interventions whose
    `data_idx` does not occur in the index of `eval_df` are not assessed.
  """
  n_examples = len(eval_df)
  naturalness_rows = []
  # Count from 1 so the progress log reads (1/N) ... (N/N).
  for count, (data_idx, original_data_row) in enumerate(
      eval_df.iterrows(), start=1
  ):
    # `data_idx` is an index label and is not guaranteed to be an integer, so
    # it is formatted with %s.
    logging.info(
        "Assessing intervention naturalness for row %s (%d/%d)...",
        data_idx,
        count,
        n_examples,
    )
    # Boolean-mask selection already yields a new frame, and the rows are only
    # read below, so no extra copy is needed.
    intervention_rows = intervention_df[
        intervention_df["data_idx"] == data_idx
    ]
    for _, intervention_row in intervention_rows.iterrows():
      # The intervention row names which sentence column was modified.
      modified_sentence_key = intervention_row["modified_sentence_key"]
      prompt = INTERVENTION_NATURALNESS_PROMPT_TEMPLATE.format(
          sent1=original_data_row[modified_sentence_key],
          sent2=intervention_row[modified_sentence_key],
      )
      naturalness_probs = ask_yes_no_question(
          pipeline,
          prompt,
          top_n_tokens,
      )
      naturalness_rows.append(
          dict(
              **intervention_row,
              **{
                  "intervention_naturalness_yes_prob": naturalness_probs["yes"],
                  "intervention_naturalness_no_prob": naturalness_probs["no"],
                  "intervention_naturalness_invalid_prob": naturalness_probs[
                      "invalid"
                  ],
                  "prompt": prompt,
              },
          )
      )
  return pd.DataFrame(naturalness_rows)
