# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Generate counterfactual interventions and assess their naturalness."""

from collections.abc import Sequence
import os
import pdb
import sys
import time
import traceback

from absl import app
from absl import flags
from absl import logging
import accelerate
from corr_faith import common_io
from corr_faith.experiments import counterfactual_interventions
from corr_faith.experiments import dataset_specific
from corr_faith.experiments import intervention_naturalness
from corr_faith.experiments import performance_logging
from corr_faith.experiments.models import huggingface_models_container
from ml_collections import config_dict
from ml_collections import config_flags
import pandas as pd
import torch
import transformers
import tree

# absl's global flag registry.
FLAGS = flags.FLAGS

# Both identifiers become directories in the results path, so distinct values
# keep runs from writing over each other's output files.
_EXPERIMENT_ID = flags.DEFINE_integer(
    name="experiment_id",
    default=0,
    help=(
        "Identifier for the current experiment. Increment to avoid"
        " overwriting existing results."
    ),
)
_WORKER_ID = flags.DEFINE_integer(
    name="worker_id",
    default=0,
    help="Identifier for the current experiment worker.",
)


def get_config():
  """Builds the default hyperparameters for this experiment.

  Returns:
    A `config_dict.ConfigDict` holding top-level run settings (evaluation
    slice, model, dataset, failure-handling switches) plus three nested
    sub-configs: `interventions`, `esnli_kwargs` and `io`.
  """
  # Intervention generation / filtering settings.
  interventions_cfg = config_dict.ConfigDict()
  interventions_cfg.n_interventions_per_example = 20
  interventions_cfg.keep_top_frac = 0.05

  # Dataset-specific options for e-SNLI.
  esnli_cfg = config_dict.ConfigDict()
  esnli_cfg.keys_to_intervene_on = ("Sentence1", "Sentence2")

  # Filesystem / GCS locations and download behaviour.
  io_cfg = config_dict.ConfigDict()
  io_cfg.base_data_dir = "/home/nonroot/data/"
  io_cfg.download_hf_weights_from_gcs_path = ""
  io_cfg.max_download_attempts = 5
  io_cfg.save_results_df_path = "/home/nonroot/results/"

  # Fields are attached in the same order as before so key order is stable.
  cfg = config_dict.ConfigDict()
  cfg.eval_start_idx = 0
  cfg.eval_end_idx = 100
  cfg.model_is_instruction_tuned = True
  cfg.model = "Qwen/Qwen2.5-14B-Instruct"
  cfg.dataset = "esnli"
  # Can set to True for local debugging.
  cfg.raise_on_generation_failure = False
  cfg.raise_on_high_prob_of_non_class_tokens = False
  cfg.interventions = interventions_cfg
  cfg.esnli_kwargs = esnli_cfg
  cfg.io = io_cfg
  return cfg


# Registers the `config` flag with get_config() as its default value;
# main() reads the parsed result via `_CONFIG.value`.
_CONFIG = config_flags.DEFINE_config_dict("config", get_config())


# Module-level (constructed at import time): used both by main() and by the
# crash handler in the `__main__` block to check `is_main_process`.
accelerator = accelerate.Accelerator()


def _load_text_generation_pipeline(config):
  """Returns a Hugging Face text-generation pipeline for `config.model`.

  When `config.io.download_hf_weights_from_gcs_path` is set, the main process
  downloads the weights into `config.io.base_data_dir`, every process waits
  for that download to finish, and the pipeline is loaded from the local
  copy. Otherwise `config.model` is passed to `transformers.pipeline` as-is.

  Args:
    config: The experiment ConfigDict (see get_config()).

  Returns:
    The loaded `transformers` text-generation pipeline.
  """
  if config.io.download_hf_weights_from_gcs_path:
    if accelerator.is_main_process:
      # Debug crashes during model loading.
      with performance_logging.LogMemoryPeriodically(period=5.0):
        common_io.download_model_from_gcs(
            model_name=config.model,
            source_dir=config.io.download_hf_weights_from_gcs_path,
            destination_dir=config.io.base_data_dir,
            max_download_attempts=config.io.max_download_attempts,
        )
    # Only the main process downloads; the other processes must not start
    # reading the weights from disk before the download has finished.
    accelerator.wait_for_everyone()
    # There are no plans to support remote models other than from HF hub:
    # https://github.com/huggingface/transformers/issues/23412
    # So to use remote models, we have to download them first.
    model_path = os.path.join(config.io.base_data_dir, config.model)
  else:
    model_path = config.model

  logging.info("Loading transformers pipeline from model: %s", config.model)
  huggingface_models_container.suppress_dynamo_errors(config.model)
  # Copy: filling in a default dtype must not mutate the shared module-level
  # special-cases table.
  model_kwargs = dict(
      huggingface_models_container.MODEL_KWARGS_SPECIAL_CASES.get(
          config.model, {}
      )
  )
  model_kwargs.setdefault("torch_dtype", torch.float16)
  model = transformers.pipeline(
      "text-generation",
      model=model_path,
      device_map="auto",
      trust_remote_code=False,
      model_kwargs=model_kwargs,
  )
  logging.info("Loaded huggingface model.")
  return model


def _keep_most_natural_interventions(
    assessed_intervention_df: pd.DataFrame, keep_top_frac: float
) -> pd.DataFrame:
  """Keeps, per example, the interventions rated most natural.

  For each `data_idx` group, computes the `1 - keep_top_frac` quantile of
  `intervention_naturalness_yes_prob` and keeps the rows at or above it.
  Rows tied at the threshold are all kept, so a group may retain slightly
  more than `keep_top_frac` of its rows.

  Args:
    assessed_intervention_df: Interventions with `data_idx` and
      `intervention_naturalness_yes_prob` columns.
    keep_top_frac: Fraction of interventions to keep per example.

  Returns:
    The filtered subset of `assessed_intervention_df`.
  """
  quantile = 1 - keep_top_frac
  threshold_by_group = assessed_intervention_df.groupby("data_idx")[
      "intervention_naturalness_yes_prob"
  ].transform(lambda x: x.quantile(quantile))
  return assessed_intervention_df[
      assessed_intervention_df["intervention_naturalness_yes_prob"]
      >= threshold_by_group
  ]


def _config_to_df(config) -> pd.DataFrame:
  """Flattens `config` into a one-row DataFrame keyed by dotted field paths.

  Nested fields become columns such as "io.base_data_dir"; sequence-valued
  fields are flattened per element by `tree.flatten_with_path`.
  """
  flattened_config = tree.flatten_with_path(config.to_dict())
  return pd.DataFrame(
      [{".".join(str(k) for k in path): v for path, v in flattened_config}]
  )


def main(argv: Sequence[str]) -> None:
  """Generates counterfactual interventions and scores their naturalness.

  Loads the configured dataset and model, generates interventions for a
  shuffled slice of the test split, scores each intervention's naturalness
  with the model, keeps the top `keep_top_frac` per example, and (on the
  main process) writes the results and the flattened config as parquet.

  Args:
    argv: Positional command-line arguments left after flag parsing; must
      contain only the program name.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If `config.model_is_instruction_tuned` is False.
  """
  if len(argv) > 1:
    raise app.UsageError(
        f'Too many command-line arguments. Try removing an extra "--"? {argv}'
    )

  # Remove this when pandas is updated to 3.0.
  # https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  pd.options.mode.copy_on_write = True

  logging.info("Starting program.")
  performance_logging.init_logging(accelerator)
  performance_logging.log_system_debugging_info()

  experiment_id = _EXPERIMENT_ID.value
  worker_id = _WORKER_ID.value
  config = _CONFIG.value
  logging.info("Config values: %s", config)

  # Validate before spending time loading the dataset and model weights.
  if not config.model_is_instruction_tuned:
    raise ValueError(
        "Assessing interventions requires an instruction-tuned model."
    )

  if common_io.is_gcs(config.io.save_results_df_path):
    # If saving to GCS, ensure we have the proper credentials; check this
    # before running the experiment, so we don't just crash at the end.
    common_io.get_gcs_client()

  logging.info("Loading dataset.")
  dataset = dataset_specific.dataset_from_config(config)
  data_splits = dataset.load_data_splits(base_data_dir=config.io.base_data_dir)
  test_df = data_splits["test"]

  model = _load_text_generation_pipeline(config)

  # If datasets aren't initially shuffled, correlated examples may be next to
  # each other; therefore, we shuffle the dataset so that if we're only
  # evaluating on a subset, we maximize diversity.
  # Ensure we use the same shuffle for all workers, and that it's independent
  # of sample length.
  shuffled_df = test_df.sample(frac=1, random_state=0)
  # Positional slice, regardless of the index's dtype or labels.
  eval_df = shuffled_df.iloc[config.eval_start_idx : config.eval_end_idx]
  generated_intervention_df = (
      counterfactual_interventions.generate_interventions(
          dataset=dataset,
          eval_df=eval_df,
          n_interventions=config.interventions.n_interventions_per_example,
      )
  )
  assessed_intervention_df = (
      intervention_naturalness.assess_intervention_naturalness(
          pipeline=model,
          eval_df=eval_df,
          intervention_df=generated_intervention_df,
      )
  )
  natural_interventions_df = _keep_most_natural_interventions(
      assessed_intervention_df, config.interventions.keep_top_frac
  )

  # Save hyperparameters as a dataframe.
  config_df = _config_to_df(config)
  logging.info("Completed evaluation.")
  if accelerator.is_main_process:
    for df, name in (
        (assessed_intervention_df, "assessed_interventions"),
        (natural_interventions_df, "natural_interventions"),
        (config_df, "config"),
    ):
      # Tag every output with run identifiers without mutating the inputs.
      tagged_df = df.assign(
          model_is_instruction_tuned=config.model_is_instruction_tuned,
          experiment_id=experiment_id,
          worker_id=worker_id,
      )
      filename = os.path.join(
          config.io.save_results_df_path,
          f"{experiment_id}/{worker_id}/{config.dataset}/{name}.parquet",
      )
      with common_io.OpenFile(filename, mode="wb") as f:
        tagged_df.to_parquet(f)
      logging.info("Saved %s.", filename)


if __name__ == "__main__":
  print("Running app.", flush=True)
  try:
    performance_logging.add_stop_signal_handlers()
    app.run(main)
  except:  # pylint: disable=bare-except
    # Crash handler for interactive local runs: when stdout is a TTY, print
    # the traceback and drop the main process into a pdb post-mortem session.
    #
    # Why only the main process: accelerate appears to start a duplicate
    # process per GPU. If all of them crashed into pdb post-mortem at once,
    # they would compete for the same stdin (each keystroke seems to reach a
    # random process), making every session unusable. Also, one process
    # exiting seems to kill the entire session. So the main process runs the
    # post-mortem session, and the other processes sleep forever (until a
    # KeyboardInterrupt) instead of exiting.
    #
    # A gate on absl's `--pdb_post_mortem` flag was considered
    # (`if FLAGS.pdb_post_mortem and sys.stdout.isatty():`) but is disabled;
    # only the TTY check below gates the handler.
    exc = sys.exc_info()[1]
    # Don't try to post-mortem debug successful SystemExits (falsy exit code),
    # since those mean there wasn't actually an error. In particular, the test
    # framework raises SystemExit(False) even if all tests passed.
    if isinstance(exc, SystemExit) and not exc.code:
      raise
    if sys.stdout.isatty():
      traceback.print_exc()
      if accelerator.is_main_process:
        print()
        print(" *** Entering post-mortem debugging ***")
        print()
        # With no argument, post_mortem() debugs the exception currently
        # being handled.
        pdb.post_mortem()
        # After the debugging session ends, propagate the original error.
        raise
      else:
        # Hacky workaround, see above: keep non-main processes alive so they
        # don't tear down the session while the main process is in pdb.
        while True:
          time.sleep(1000)
    else:
      # Not attached to a terminal: no interactive debugging, just propagate.
      raise
