import argparse
import json
import pandas as pd
import pathlib
import time
import typing as tp
from pprint import pformat

from loguru import logger

from llm_inference import output_parsers, tasks, eval_utils

# Column/key names used with the inference output table.
# NOTE(review): presumably these match the columns written by the inference
# step into infer.parquet — confirm against the producer.
DATASET_IDX_KEY = "dataset_idx"
# Passed to the evaluator as the name of the prediction field.
GENERATED_TEXT_KEY = "generated_text"
# Pair of column names; presumably identifies one generated sample — TODO confirm.
INDEX_COLUMNS = (DATASET_IDX_KEY, "sample_idx")


def eval_predictions(
  task: tasks.Task,
  infer_outputs: str,
  output_parser: tp.Callable[[str], str],
  save: bool = True,
  overwrite: bool = False,
  output_file: str | None = None,
):
  """Evaluate model predictions on the given task.

  The predictions are read from a parquet file, scored with the task's
  evaluation config, logged, and (optionally) written as JSON next to the
  predictions.

  Args:
      task: task to evaluate on.
      infer_outputs: path to the model predictions. Either a parquet file or a
          directory containing ``infer.parquet``.
      output_parser: parser for extracting valid code from model predictions.
          NOTE(review): accepted for interface compatibility but not used by
          this function — confirm whether the evaluator should receive it.
      save: whether or not to save evaluation results to a file. When False,
          results are only logged and no file is touched. Defaults to True.
      overwrite: if saving eval results, whether or not to overwrite existing
          eval results (if they exist). Defaults to False.
      output_file: name of the output file (this must be just the name rather
          than a complete path). Defaults to None, meaning ``eval.json``.

  Raises:
      ValueError: if results are to be saved, the results file already exists
          and ``overwrite`` is False; or if ``infer_outputs`` is a directory
          that does not contain ``infer.parquet``.
  """
  import upath

  path = upath.UPath(infer_outputs)

  # Results always live alongside the predictions: in the same directory as the
  # predictions file, or inside the directory that was passed in. This holds
  # whether or not a custom output file name was given.
  output_dir = path.parent if path.is_file() else path
  eval_results_path = output_dir / (
    "eval.json" if output_file is None else output_file
  )

  # Refuse to clobber earlier results before doing any expensive work.
  if save and not overwrite and eval_results_path.exists():
    # st_mtime is the last-modification time of the existing results file.
    modified_time = time.localtime(eval_results_path.stat().st_mtime)
    raise ValueError(
      (
        "Evaluation file already at '{}' (created at '{}')\n"
        "Use 'overwrite=True' to overwrite existing eval file."
      ).format(
        str(eval_results_path),
        time.strftime("%Y-%m-%d %H:%M:%S", modified_time),
      )
    )

  if path.is_dir():
    parquet_path = path / "infer.parquet"
    if not parquet_path.exists():
      raise ValueError(f"Could not find infer.parquet in '{str(path)}'")
    path = parquet_path

  df = pd.read_parquet(path.as_posix())
  eval_cfg = task.get_evaluation_cfg()
  evaluator = eval_utils.Evaluator(eval_cfg, prediction=GENERATED_TEXT_KEY)

  summary, details = evaluator.run(df, task.load_dataset())
  logger.info("Summary:\n{}", summary)

  if not save:
    # Caller asked for a dry evaluation: nothing is written to disk.
    return

  logger.info("Saving evaluation results to '{}'", eval_results_path.as_posix())
  with eval_results_path.open("w") as f:
    json.dump(
      {
        "eval_results": summary,
        # `details` is serialised row-by-row via to_dict("records");
        # presumably a pandas DataFrame — TODO confirm.
        "aux": details.to_dict("records"),
      },
      f,
    )


if __name__ == "__main__":
  # Command-line entry point: parse the flags, resolve the task and output
  # parser by name from their registries, then run the evaluation.
  cli = argparse.ArgumentParser(description="Evaluate model predictions on a task.")
  cli.add_argument(
    "--inputs", required=True, type=str, help="Path to model predictions"
  )
  cli.add_argument(
    "--task",
    required=True,
    type=str,
    choices=tasks.TASKS.keys(),
    help="Task to evaluate on.",
  )
  cli.add_argument(
    "--output-parser",
    type=str,
    default="passthrough_output_parser",
    choices=output_parsers.OUTPUT_PARSERS.keys(),
    help="Output parser to use for extracting code from model predictions.",
  )
  # Both switches below are opt-in: they are False unless given on the command line.
  cli.add_argument("--save", help="Save evaluation results.", action="store_true")
  cli.add_argument(
    "--overwrite",
    help="Overwrite existing evaluation results.",
    action="store_true",
  )
  cli.add_argument(
    "--output-file",
    default="eval.json",
    type=str,
    help="Name of the outputfile (this must be just the name rather than a complete path).",
  )

  opts = cli.parse_args()

  eval_predictions(
    task=tasks.TASKS[opts.task],
    output_parser=output_parsers.OUTPUT_PARSERS[opts.output_parser],
    infer_outputs=opts.inputs,
    save=opts.save,
    overwrite=opts.overwrite,
    output_file=opts.output_file,
  )
