# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Evaluate faithfulness of LLM explanations with counterfactual interventions."""

from collections.abc import Sequence
import functools
import os
import pdb
import sys
import time
import traceback

from absl import app
from absl import flags
from absl import logging
import accelerate
from corr_faith import common_io
from corr_faith.experiments import counterfactual_interventions
from corr_faith.experiments import dataset_specific
from corr_faith.experiments import performance_logging
from corr_faith.experiments.dataset_specific import classification_datasets
from corr_faith.experiments.models import api_models
from corr_faith.experiments.models import huggingface_models_container
from ml_collections import config_dict
from ml_collections import config_flags
import numpy as np
import pandas as pd
import scipy as sp
import spacy
import torch
import transformers
import tree


# Used as the default `eval_end_idx` in `get_config`; presumably the number of
# examples in the e-SNLI test split -- TODO confirm against the dataset.
N_ESNLI_TEST_EXAMPLES = 9842

# Global absl flag registry.
FLAGS = flags.FLAGS

# Results are written under
# <config.io.save_results_df_path>/<experiment_id>/<worker_id>/, so these two
# flags together select the output location of a run.
_EXPERIMENT_ID = flags.DEFINE_integer(
    name="experiment_id",
    default=0,
    help=(
        "Identifier for the current experiment. Increment to avoid"
        " overwriting existing results."
    ),
)
_WORKER_ID = flags.DEFINE_integer(
    name="worker_id",
    default=0,
    help="Identifier for the current experiment worker.",
)


def get_config():
  """Builds the default hyperparameters for a faithfulness experiment.

  Returns:
    A ConfigDict with top-level evaluation/model/prompt settings plus the
    nested `interventions`, `esnli_kwargs` and `io` sub-configs.
  """
  # Where counterfactual interventions come from: a non-empty path means load
  # pre-assessed ones, otherwise random ones are generated per example.
  interventions_cfg = config_dict.ConfigDict()
  interventions_cfg.load_assessed_interventions_from_path = ""
  interventions_cfg.n_interventions_per_example = 1

  # Options specific to the e-SNLI dataset.
  esnli_cfg = config_dict.ConfigDict()
  esnli_cfg.keys_to_intervene_on = ("Sentence1", "Sentence2")

  # Data, model-weight and result locations.
  io_cfg = config_dict.ConfigDict()
  io_cfg.base_data_dir = "/home/nonroot/data/"
  io_cfg.download_hf_weights_from_gcs_path = ""
  io_cfg.max_download_attempts = 5
  io_cfg.save_results_df_path = "/home/nonroot/results/"

  cfg = config_dict.ConfigDict()
  # Slice of the shuffled test split to evaluate.
  cfg.eval_start_idx = 0
  cfg.eval_end_idx = N_ESNLI_TEST_EXAMPLES
  # Model and prompting setup.
  cfg.model_is_instruction_tuned = True
  cfg.model = "Qwen/Qwen2.5-14B"
  cfg.prompt_template = "predict_then_explain"
  cfg.include_explanations_in_fewshot = True
  cfg.explanation_length_instructions = "EMPTY"
  cfg.explanation_faithfulness_instructions = "EMPTY"
  cfg.prompt_n_shots = 10
  cfg.lm_random_sampling = False
  # Datasets are relatively small, so we build them into the container.
  cfg.dataset = "esnli"
  # Can set to True for local debugging.
  cfg.raise_on_generation_failure = False
  cfg.raise_on_high_prob_of_non_class_tokens = False

  # Attach the sub-configs last, in a fixed order.
  cfg.interventions = interventions_cfg
  cfg.esnli_kwargs = esnli_cfg
  cfg.io = io_cfg
  return cfg


def get_explanation_instructions(config: config_dict.ConfigDict) -> str:
  """Concatenates the configured explanation-instruction snippets.

  Args:
    config: Experiment config. Its `explanation_length_instructions` and
      `explanation_faithfulness_instructions` fields are used as keys into
      `classification_datasets.ExplanationLengthInstructions` and
      `classification_datasets.ExplanationFaithfulnessInstructions`.

  Returns:
    The `.value` of the selected length instructions immediately followed by
    the `.value` of the selected faithfulness instructions.
  """
  length_entry = classification_datasets.ExplanationLengthInstructions[
      config.explanation_length_instructions
  ]
  length_text = length_entry.value
  faithfulness_entry = (
      classification_datasets.ExplanationFaithfulnessInstructions[
          config.explanation_faithfulness_instructions
      ]
  )
  faithfulness_text = faithfulness_entry.value
  return length_text + faithfulness_text


# Registers the defaults from `get_config()` as the `--config` flag; `main`
# reads the resulting ConfigDict via `_CONFIG.value`.
_CONFIG = config_flags.DEFINE_config_dict("config", get_config())


# Created once at import time and shared module-wide: `main` and the
# `__main__` crash handler both consult `accelerator.is_main_process`.
accelerator = accelerate.Accelerator()


@functools.lru_cache(maxsize=128)
def get_spacy_nlp() -> spacy.Language:
  """Returns the `en_core_web_lg` spaCy pipeline, loading it at most once.

  The result is memoized, so repeated calls return the same pipeline object.
  """
  nlp = spacy.load("en_core_web_lg")
  return nlp


def conf_interval_mean(a, conf):
  """Computes the sample mean with a Student's t confidence interval.

  Args:
    a: Sequence of numeric observations.
    conf: Two-sided confidence level, e.g. 0.95.

  Returns:
    A tuple `(mean, lower_bound, upper_bound)`.
  """
  sample_mean = np.mean(a)
  std_err = sp.stats.sem(a)
  # Two-sided critical value with len(a) - 1 degrees of freedom.
  t_crit = sp.stats.t.ppf((1 + conf) / 2.0, len(a) - 1)
  half_width = t_crit * std_err
  return sample_mean, sample_mean - half_width, sample_mean + half_width


def mean_and_conf_interval_as_str(seq: Sequence[float], conf=0.95):
  """Formats the mean of `seq` and its confidence interval for logging.

  Args:
    seq: Numeric observations to summarize.
    conf: Confidence level passed through to `conf_interval_mean`.

  Returns:
    A string of the form "<mean> ((<lower>, <upper>) <conf> CI)" with the
    three statistics rendered to four decimal places.
  """
  center, low, high = conf_interval_mean(seq, conf=conf)
  interval = f"({low:.4f}, {high:.4f})"
  return f"{center:.4f} ({interval} {conf} CI)"


def main(argv: Sequence[str]) -> None:
  """Runs the counterfactual-intervention faithfulness experiment end to end.

  Loads the dataset and (optionally) pre-assessed interventions, builds either
  an API model or a local HuggingFace text-generation pipeline, evaluates
  few-shot predictions on a shuffled slice of the test split, writes the
  intervention/accuracy/config dataframes as parquet on the main process, and
  logs summary accuracy and faithfulness metrics.

  Args:
    argv: Positional command-line arguments left over after flag parsing; only
      the program name is allowed.

  Raises:
    app.UsageError: If extra positional arguments are passed.
    ValueError: If interventions are not loaded from a path and
      `config.interventions.n_interventions_per_example` is not an int, or if
      `config.raise_on_generation_failure` is set and either every prediction
      failed or no intervention data was produced.
  """
  if len(argv) > 1:
    raise app.UsageError(
        f'Too many command-line arguments. Try removing an extra "--"? {argv}'
    )

  logging.info("Starting program.")
  performance_logging.init_logging(accelerator)
  performance_logging.log_system_debugging_info()

  experiment_id = _EXPERIMENT_ID.value
  worker_id = _WORKER_ID.value
  config = _CONFIG.value
  logging.info("Config values: %s", config)
  logging.info("UID: %d, GID: %d", os.getuid(), os.getgid())

  if common_io.is_gcs(config.io.save_results_df_path):
    # If saving to GCS, ensure we have the proper credentials; check this
    # before running the experiment, so we don't just crash at the end.
    common_io.get_gcs_client()

  logging.info("Loading dataset.")
  dataset = dataset_specific.dataset_from_config(config)
  data_splits = dataset.load_data_splits(base_data_dir=config.io.base_data_dir)
  train_df = data_splits["train"]
  test_df = data_splits["test"]
  if config.interventions.load_assessed_interventions_from_path:
    assessed_interventions_path = os.path.join(
        config.interventions.load_assessed_interventions_from_path,
        config.dataset,
        "natural_interventions.parquet",
    )
    logging.info(
        "Loading interventions from %s.",
        assessed_interventions_path,
    )
    with common_io.OpenFile(assessed_interventions_path, "rb") as f:
      interventions_df = pd.read_parquet(f)
  else:
    logging.info(
        "Not loading interventions from GCS, will instead generate %s random"
        " interventions per example.",
        config.interventions.n_interventions_per_example,
    )
    if not isinstance(config.interventions.n_interventions_per_example, int):
      raise ValueError(
          "If not loading interventions,"
          " config.interventions.n_interventions_per_example must be an"
          " integer, but got"
          f" {config.interventions.n_interventions_per_example}."
      )
    interventions_df = None

  try:
    model = api_models.from_model_name(config.model)
    logging.info("Using API model.")
  except api_models.UnknownApiError:
    # If the model doesn't have a known API prefix, interpret it as an HF model.
    # NOTE(review): only the main process downloads, and nothing visible here
    # makes the other processes wait for it to finish -- confirm whether a
    # barrier is needed for multi-process runs.
    if accelerator.is_main_process:
      if config.io.download_hf_weights_from_gcs_path:
        # Debug crashes during model loading.
        with performance_logging.LogMemoryPeriodically(period=5.0):
          common_io.download_model_from_gcs(
              model_name=config.model,
              source_dir=config.io.download_hf_weights_from_gcs_path,
              destination_dir=config.io.base_data_dir,
              max_download_attempts=config.io.max_download_attempts,
          )
    # There are no plans to support remote models other than from HF hub:
    # https://github.com/huggingface/transformers/issues/23412
    # So to use remote models, we have to download them first.
    if config.io.download_hf_weights_from_gcs_path:
      model_path = os.path.join(config.io.base_data_dir, config.model)
    else:
      model_path = config.model
    logging.info("Loading transformers pipeline from model: %s", config.model)
    huggingface_models_container.suppress_dynamo_errors(config.model)
    # Copy so that filling in a default dtype below never mutates the shared
    # MODEL_KWARGS_SPECIAL_CASES entry.
    model_kwargs = dict(
        huggingface_models_container.MODEL_KWARGS_SPECIAL_CASES.get(
            config.model, {}
        )
    )
    if "torch_dtype" not in model_kwargs:
      model_kwargs["torch_dtype"] = torch.float16
    model = transformers.pipeline(
        "text-generation",
        model=model_path,
        device_map="auto",
        trust_remote_code=False,
        model_kwargs=model_kwargs,
    )
    logging.info("Loaded huggingface model.")

  # Few-shot examples only carry explanations when requested; otherwise they
  # fall back to the prediction-only template.
  if config.include_explanations_in_fewshot:
    fewshot_prompt_template = config.prompt_template
  else:
    fewshot_prompt_template = "predict_only"
  include_explanation_in_prefix = fewshot_prompt_template != "predict_only"

  def sample_fewshot_prompt_base(rng: np.random.Generator) -> str:
    """Joins the dataset's prompt prefix with few-shot examples drawn via rng."""
    return classification_datasets.FEWSHOT_EXAMPLE_SEP.join([
        dataset.fewshot_prompt_prefix(
            include_explanation=include_explanation_in_prefix
        ),
        dataset.get_fewshot_prompt_examples(
            rng=rng,
            train_df=train_df,
            n_shots=config.prompt_n_shots,
            prompt_template=fewshot_prompt_template,
        ),
    ])

  logging.info("is_instruction_tuned: %s", config.model_is_instruction_tuned)

  explanation_instructions = get_explanation_instructions(config)
  logging.info("explanation_instructions: %s", explanation_instructions)

  def predict_fewshot(
      row: pd.Series, prompt_base: str
  ) -> classification_datasets.Prediction:
    """Appends the prompt ending for `row` to `prompt_base` and predicts."""
    prompt = classification_datasets.FEWSHOT_EXAMPLE_SEP.join([
        prompt_base,
        dataset.get_prompt_end_for_new_example(
            row=row,
            prompt_template=config.prompt_template,
            is_instruction_tuned=config.model_is_instruction_tuned,
            include_explanations_in_fewshot=config.include_explanations_in_fewshot,
            explanation_instructions=explanation_instructions,
        ),
    ])
    return dataset.predict(
        model=model,
        prompt=prompt,
        prompt_template=config.prompt_template,
        lm_random_sampling=config.lm_random_sampling,
        is_instruction_tuned=config.model_is_instruction_tuned,
        raise_on_high_prob_of_non_class_tokens=config.raise_on_high_prob_of_non_class_tokens,
    )

  # If datasets aren't initially shuffled, correlated examples may be next to
  # each other; therefore, we shuffle the dataset so that if we're only
  # evaluating on a subset, we maximize diversity.
  # Ensure we use the same shuffle for all workers, and that it's independent
  # of sample length.
  shuffled_df = test_df.sample(frac=1, random_state=0)
  eval_df = shuffled_df[config.eval_start_idx : config.eval_end_idx]
  accuracy_df, intervention_df = (
      counterfactual_interventions.evaluate_counterfactual_faithfulness(
          predict_fewshot=predict_fewshot,
          dataset=dataset,
          sample_fewshot_prompt_base=sample_fewshot_prompt_base,
          eval_df=eval_df,
          n_samples_per_example=1,
          raise_on_generation_failure=config.raise_on_generation_failure,
          n_interventions=config.interventions.n_interventions_per_example,
          interventions_df=interventions_df,
      )
  )
  # Compute total variation distance resulting from interventions.
  # These DFs will have a column for each class token.
  base_probs_df = pd.json_normalize(intervention_df["base_token_probs"])
  modified_probs_df = pd.json_normalize(intervention_df["modified_token_probs"])
  assert np.array_equal(base_probs_df.columns, modified_probs_df.columns)
  base_probs_np = base_probs_df.values
  modified_probs_np = modified_probs_df.values
  # Shape: (n_examples, n_classes)
  assert base_probs_np.shape == modified_probs_np.shape
  # Could be 0, if we failed to get probabilities, e.g. for API models.
  assert base_probs_np.shape[1] in {0, 2, 3, 5}
  # Total variation distance is half the L1 distance between the two rows.
  intervention_df["total_variation_distance"] = (
      np.linalg.norm(base_probs_np - modified_probs_np, ord=1, axis=1) / 2
  )

  # Save hyperparameters as a dataframe.
  flattened_config = tree.flatten_with_path(config.to_dict())
  flattened_config = {
      ".".join([str(k) for k in path]): v for path, v in flattened_config
  }
  config_df = pd.DataFrame([flattened_config])
  logging.info("Completed evaluation.")
  performance_logging.log_gpu_and_cpu_mem_stats()
  # Write dataframes.
  if accelerator.is_main_process:
    logging.info("Saving results...")
    for df, name in [
        (intervention_df, "intervention"),
        (accuracy_df, "accuracy"),
        (config_df, "config"),
    ]:
      # Tag every output table so results from different runs/workers can be
      # told apart after concatenation.
      df["model_is_instruction_tuned"] = config.model_is_instruction_tuned
      df["experiment_id"] = experiment_id
      df["worker_id"] = worker_id
      filename = os.path.join(
          config.io.save_results_df_path,
          f"{experiment_id}/{worker_id}/{name}.parquet",
      )
      with common_io.OpenFile(filename, mode="wb") as f:
        df.to_parquet(f)
      logging.info("Saved %s.", filename)

    # pylint: disable=g-explicit-length-test
    # This should always be nonempty, since we log accuracy data even if
    # prediction fails.
    # Run these checks after writing the dataframes, so that we still log data
    # even if these checks have errors.
    assert len(accuracy_df) > 0, f"No accuracy data: {accuracy_df}"
    if config.raise_on_generation_failure:
      if np.all(accuracy_df["base_pred_label"].isna()):
        raise ValueError(f"All {len(accuracy_df)} predictions failed.")
      if len(intervention_df) == 0:
        raise ValueError(f"No intervention data generated: {intervention_df}")
    logging.info("Average accuracy: %s", accuracy_df["accuracy"].mean())
    accuracies_by_example = accuracy_df.groupby("data_idx")["accuracy"].mean()
    logging.info(
        "Classified %d examples. Accuracy: %s",
        len(accuracies_by_example),
        mean_and_conf_interval_as_str(accuracies_by_example),
    )
    total_intervention_count = len(intervention_df)
    if total_intervention_count > 0:
      # Convert to pd.array so we can handle null data:
      # https://pandas.pydata.org/docs/user_guide/boolean.html#nullable-boolean-data-type
      as_bool = functools.partial(pd.array, dtype="boolean")
      pred_label_changed = as_bool(intervention_df["pred_label_changed"])
      mentions_inserted_word = as_bool(
          intervention_df["modified_explanation_mentions_inserted_word"]
      ) | as_bool(
          intervention_df["stemmed_modified_explanation_mentions_inserted_word"]
      )
      logging.info(
          "%d interventions, %d label changes, %d explanation mentions.",
          total_intervention_count,
          pred_label_changed.sum(),
          mentions_inserted_word.sum(),
      )
      true_positive = mentions_inserted_word & pred_label_changed
      false_positive = mentions_inserted_word & (~pred_label_changed)
      false_negative = (~mentions_inserted_word) & pred_label_changed
      true_negative = (~mentions_inserted_word) & (~pred_label_changed)
      logging.info(
          "TP: %d, FP: %d, FN: %d, TN: %d",
          true_positive.sum(),
          false_positive.sum(),
          false_negative.sum(),
          true_negative.sum(),
      )
      # NA values are treated as False when indexing:
      # https://pandas.pydata.org/docs/user_guide/boolean.html#indexing-with-na-values
      ct_faithfulness = mentions_inserted_word[pred_label_changed].mean()
      # With no (non-null) label changes the selection is empty and the mean
      # can be pd.NA, which "%f" cannot format; report NaN instead so the
      # metric line is still logged.
      if pd.isna(ct_faithfulness):
        ct_faithfulness = float("nan")
      logging.info("Counterfactual Test Faithfulness: %f", ct_faithfulness)
      logging.info(
          "Phi-CCT Faithfulness: %f",
          pd.Series(mentions_inserted_word).corr(pd.Series(pred_label_changed)),
      )
      if isinstance(model, transformers.TextGenerationPipeline):
        # NOTE(review): Series.corr aligns on index; the left side has a fresh
        # default index, so this assumes intervention_df is also indexed
        # 0..n-1 -- confirm against evaluate_counterfactual_faithfulness.
        logging.info(
            "CCT Faithfulness: %f",
            pd.Series(mentions_inserted_word).corr(
                intervention_df["total_variation_distance"]
            ),
        )
      else:
        logging.warning(
            "CCT Faithfulness requires token probabilities, not supported for"
            " API models."
        )
    else:
      logging.warning("No interventions generated.")

  logging.info("Program completed.")


if __name__ == "__main__":
  print("Running main.", flush=True)
  try:
    app.run(main)
  except:  # pylint: disable=bare-except
    # Deliberately catches everything, including SystemExit, so that crashes
    # can be inspected interactively when running locally.
    #
    # accelerate appears to start one duplicate process per GPU. If every
    # process opened a pdb post-mortem session on a crash, stdin characters
    # would be split between them and no session would be usable; but if any
    # process exits, the whole session seems to be torn down. So only the main
    # process debugs, and all others sleep until interrupted.
    caught = sys.exc_info()[1]
    # A SystemExit with a falsy code is a successful exit, not an error. In
    # particular, the test framework raises SystemExit(False) even if all
    # tests passed.
    if isinstance(caught, SystemExit) and not caught.code:
      raise
    logging.error("Caught exception, logging device statistics for debugging.")
    performance_logging.log_gpu_and_cpu_mem_stats(log_level=logging.WARNING)
    performance_logging.log_nvidia_smi_if_cuda_available(
        log_level=logging.WARNING
    )
    # Without an interactive terminal there is nobody to debug: just propagate.
    if not sys.stdout.isatty():
      raise
    traceback.print_exc()
    if not accelerator.is_main_process:
      # Hacky workaround, see above: park this process until a
      # KeyboardInterrupt instead of exiting or competing for stdin.
      while True:
        time.sleep(1000)
    print("\n *** Entering post-mortem debugging ***\n")
    pdb.post_mortem()
    raise
