# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Produce docker commands to run a given hyperparameter sweep."""

from collections.abc import Sequence
import functools
import math
import os
import shlex
from typing import Any
from absl import app
from absl import flags
from absl import logging
from corr_faith import common_io
from corr_faith.analysis import model_metadata
import pandas as pd


# User that the generated docker commands run as; container-side paths are
# rooted at this user's home directory.
CONTAINER_USER = "nonroot"
# Datasets covered by the sweeps.
DATASETS = ["esnli", "ecqa", "comve"]

# Anything accepted by the `pd.DataFrame` constructor.
DataFrameLike = Any


# Number of evaluation examples per dataset (evaluating on test sets).
EVAL_EXAMPLES_BY_DATASET = {
    "esnli": 9842,
    "ecqa": 2194,
    "comve": 999,
}

# Experiment id of the intervention sweep. It is also embedded in the path
# that the faithfulness sweep loads assessed interventions from.
_INTERVENTION_EXPERIMENT_ID = flags.DEFINE_integer(
    "intervention_experiment_id",
    0,
    "Identifier for the current experiment. Increment to avoid overwriting"
    " existing results.",
)

# Experiment id stamped on every command of the faithfulness sweep. Worker ids
# are assigned separately, one per sweep row.
_FAITHFULNESS_EXPERIMENT_ID = flags.DEFINE_integer(
    "faithfulness_experiment_id",
    1,
    "Identifier for the faithfulness evaluation experiment.",
)

# Directory that the generated command files are written to.
_SAVE_SWEEPS_PATH = flags.DEFINE_string(
    "save_sweeps_path",
    "/home/nonroot/results/",
    "Directory to save sweeps to.",
)


def cartesian_product(dfs: Sequence[DataFrameLike]) -> pd.DataFrame:
  """Returns the cross join of all the given tables.

  Args:
    dfs: Tables to combine, in order. Each item is passed to `pd.DataFrame`, so
      it may be a DataFrame or anything that constructor accepts (e.g. a dict
      of columns).

  Returns:
    A DataFrame with one row for every combination of one row from each input.
  """

  def _cross_join(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    return left.merge(right, how="cross")

  frames = list(map(pd.DataFrame, dfs))
  return functools.reduce(_cross_join, frames)


def intervention_naturalness_dataset_config(
    dataset: str,
    required_total_interventions: int = 10000,
    intervention_top_frac: float = 0.05,
) -> pd.DataFrame:
  """Builds the single-row intervention config for one dataset.

  The per-example requirement is `required_total_interventions` spread over the
  dataset's evaluation examples, rounded up. The number of candidates per
  example is that requirement divided by `intervention_top_frac`, rounded up, so
  that keeping that fraction of the candidates still leaves at least the
  required number.

  Args:
    dataset: Key into `EVAL_EXAMPLES_BY_DATASET`.
    required_total_interventions: Interventions wanted across the whole dataset.
    intervention_top_frac: Value written to the `keep_top_frac` setting.

  Returns:
    A one-row DataFrame with the dataset name, the number of candidate
    interventions per example and the keep fraction.
  """
  n_examples = EVAL_EXAMPLES_BY_DATASET[dataset]
  needed_per_example = math.ceil(required_total_interventions / n_examples)
  candidates_per_example = math.ceil(
      needed_per_example / intervention_top_frac
  )
  row = {
      "config.dataset": dataset,
      "config.interventions.n_interventions_per_example": (
          candidates_per_example
      ),
      "config.interventions.keep_top_frac": intervention_top_frac,
  }
  return pd.DataFrame([row])


def intervention_naturalness_sweep(
    judge_model: str = "Qwen/Qwen2.5-72B-Instruct",
) -> pd.DataFrame:
  """Returns the sweep for generating and assessing interventions.

  Args:
    judge_model: Value of the `config.model` column on every row.

  Returns:
    A DataFrame with one row per entry of `DATASETS`, in that order, combining
    `judge_model` with that dataset's intervention config.
  """
  per_dataset_configs = [
      intervention_naturalness_dataset_config(dataset=name)
      for name in DATASETS
  ]
  dataset_settings = pd.concat(per_dataset_configs, axis="rows")
  model_settings = {"config.model": [judge_model]}
  return cartesian_product([model_settings, dataset_settings])


def evaluate_faithfulness_sweep() -> pd.DataFrame:
  """Returns the hyperparameter sweep for evaluating faithfulness.

  Each row is one run configuration. Settings shared by all models (dataset
  together with its assessed-interventions path, prompt template, results
  path) are crossed with model-specific settings. Models not marked as
  instruction-tuned get a single configuration each; instruction-tuned models
  are additionally swept over whether few-shot examples include explanations
  and over the explanation length instructions.

  Reads the `intervention_experiment_id` flag, so flags must already be parsed.
  """
  # Container-side results directory; matches the destination of the results
  # mount in `get_docker_command`.
  results_dir = f"/home/{CONTAINER_USER}/results/"

  models_df = model_metadata.get_models_df()
  is_instruction_tuned = models_df["instruction_tuned"]
  it_models = models_df[is_instruction_tuned]["model_name"]
  pt_models = models_df[~is_instruction_tuned]["model_name"]

  shared_settings = cartesian_product([
      {
          # Both columns live in the same dict, so they are paired row by row
          # rather than crossed: dataset i loads the output of intervention
          # worker i. This relies on the intervention sweep having exactly one
          # row per dataset, in `DATASETS` order, which `main` numbers as
          # worker ids.
          "config.dataset": DATASETS,
          "config.interventions.load_assessed_interventions_from_path": [
              f"{results_dir}{_INTERVENTION_EXPERIMENT_ID.value}/{worker_id}/"
              for worker_id in range(len(DATASETS))
          ],
      },
      {
          "config.prompt_template": [
              "predict_then_explain",
              "explain_then_predict",
          ]
      },
      {"config.io.save_results_df_path": [results_dir]},
  ])
  pt_settings = cartesian_product([
      {"config.model": pt_models},
      {"config.model_is_instruction_tuned": [False]},
      {"config.include_explanations_in_fewshot": [True]},
      {"config.explanation_length_instructions": ["EMPTY"]},
  ])
  it_settings = cartesian_product([
      {"config.model": it_models},
      {"config.model_is_instruction_tuned": [True]},
      {"config.include_explanations_in_fewshot": [True, False]},
      {
          "config.explanation_length_instructions": [
              "EMPTY",
              "COMPREHENSIVE",
              "VERY_COMPREHENSIVE",
              "CONCISE",
              "VERY_CONCISE",
          ]
      },
  ])
  return cartesian_product([
      shared_settings,
      pd.concat([pt_settings, it_settings], axis="rows"),
  ])


def get_docker_command(
    arg_dict: dict[str, Any],
    container_image: str = "$USER/corr_faith",
    container_entrypoint: str = "-m corr_faith.experiments.scripts.evaluate_faithfulness",
) -> str:
  """Returns a single-line `docker run` command for the given settings.

  The command runs as `CONTAINER_USER`, forwards the GEMINI, OPENAI and
  ANTHROPIC API keys from the invoking shell's environment, and bind-mounts the
  results directory, the Hugging Face hub cache and (read-only) the gcloud
  application default credentials into that user's home directory.

  Args:
    arg_dict: Flag names mapped to values. Each pair becomes a `--name=value`
      argument, shell-quoted when it contains characters the shell would
      interpret. Must contain "config.model", which is looked up in
      `model_metadata` to decide whether GPUs are requested.
    container_image: Image to run. Left unquoted so that shell variables such
      as `$USER` are expanded when the command is executed.
    container_entrypoint: Text placed between the image name and the flags.

  Returns:
    The full command as one line of shell.

  Raises:
    KeyError: If `arg_dict` has no "config.model" entry.
  """
  results_source = "/tmp/corr_faith/"
  results_dest = f"/home/{CONTAINER_USER}/results/"
  hf_cache_source = "$HOME/.cache/huggingface/hub/"
  hf_cache_dest = f"/home/{CONTAINER_USER}/.cache/huggingface/hub/"
  gcloud_cred_source = (
      "$HOME/.config/gcloud/application_default_credentials.json"
  )
  gcloud_cred_dest = (
      f"/home/{CONTAINER_USER}/.config/gcloud/"
      "application_default_credentials.json"
  )
  # The trailing backslashes are Python line continuations inside the string
  # literal, so the resulting command contains no newlines.
  base_docker_cmd = f"""docker run --rm -it \
    --user {CONTAINER_USER} \
    --env GEMINI_API_KEY=$GEMINI_API_KEY \
    --env OPENAI_API_KEY=$OPENAI_API_KEY \
    --env ANTHROPIC_API_KEY=$ANTHROPIC_API_KEY \
    --mount type=bind,source={results_source},destination={results_dest} \
    --mount type=bind,source={hf_cache_source},destination={hf_cache_dest} \
    --mount readonly,type=bind,source={gcloud_cred_source},destination={gcloud_cred_dest} \
    """
  model_info = model_metadata.get_model_info(arg_dict["config.model"])
  # Only locally hosted models need GPUs; API models run on CPU only.
  gpus_cmd = "--gpus all" if model_info["api_type"] == "local" else ""
  # Quote each flag so that values containing spaces or shell metacharacters
  # reach the script as a single argument. Shell-safe tokens are left as-is.
  flag_args = " ".join(
      shlex.quote(f"--{key}={value}") for key, value in arg_dict.items()
  )
  return " ".join([
      base_docker_cmd,
      gpus_cmd,
      container_image,
      container_entrypoint,
      flag_args,
  ])


def _build_sweep_commands(
    sweep: pd.DataFrame,
    experiment_id: int,
    container_entrypoint: str,
) -> list[str]:
  """Adds id columns to `sweep` and returns one docker command per row."""
  sweep["experiment_id"] = experiment_id
  # Rows are numbered in order; the row number is the worker id.
  sweep["worker_id"] = range(len(sweep))
  return [
      get_docker_command(
          arg_dict=row.to_dict(),
          container_image="$USER/corr_faith",
          container_entrypoint=container_entrypoint,
      )
      for _, row in sweep.iterrows()
  ]


def _save_commands(commands: Sequence[str], filename: str) -> str:
  """Writes one command per line to the sweeps directory; returns the path."""
  # Join rather than concatenate so that a `--save_sweeps_path` without a
  # trailing slash still yields a file inside that directory.
  path = os.path.join(_SAVE_SWEEPS_PATH.value, filename)
  with common_io.OpenFile(path, "w") as f:
    f.write("\n".join(commands) + "\n")
  return path


def main(argv: Sequence[str]) -> None:
  """Generates both sweeps and saves their docker commands to text files.

  Writes `intervention_sweep.txt` and `faithfulness_sweep.txt`, one command per
  line, to the directory given by `--save_sweeps_path`.

  Args:
    argv: Positional command-line arguments; only the program name is allowed.

  Raises:
    app.UsageError: If any extra positional arguments are given.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  logging.info("Generating sweeps.")

  # Build both command lists before writing anything, so that a failure while
  # generating either sweep leaves no partial output behind.
  intervention_commands = _build_sweep_commands(
      intervention_naturalness_sweep(),
      experiment_id=_INTERVENTION_EXPERIMENT_ID.value,
      container_entrypoint=(
          "-m"
          " corr_faith.experiments.scripts.generate_and_assess_interventions"
      ),
  )
  faithfulness_commands = _build_sweep_commands(
      evaluate_faithfulness_sweep(),
      experiment_id=_FAITHFULNESS_EXPERIMENT_ID.value,
      container_entrypoint=(
          "-m corr_faith.experiments.scripts.evaluate_faithfulness"
      ),
  )

  intervention_sweep_path = _save_commands(
      intervention_commands, "intervention_sweep.txt"
  )
  logging.info("Saved intervention sweep to %s", intervention_sweep_path)

  faithfulness_sweep_path = _save_commands(
      faithfulness_commands, "faithfulness_sweep.txt"
  )
  logging.info("Saved faithfulness sweep to %s", faithfulness_sweep_path)


# Run through absl so that command-line flags are parsed before `main` reads
# them.
if __name__ == "__main__":
  app.run(main)
