"""
Uses OpenAI's Batch API to generate completions for a given task and prompt.

Usage:
    python script_name.py --config-file CONFIG_FILE [OPTIONS]

This script automates the process of submitting batch requests to OpenAI's API,
handling the responses, and organizing the outputs.

Arguments:
    --config-file CONFIG_FILE    Path to the YAML configuration file (required)
    --model MODEL                OpenAI model to use (default: gpt-4o-mini)
    --num-samples NUM            Number of samples per input (default: 25)
    --output-dir DIR             Base output directory (default: ./outputs/)
    --batch-size SIZE            Number of requests per submitted batch file (default: 500)
    --limit LIMIT                Limit the number of dataset rows processed (optional)
    --experiment-name NAME       Custom name for the experiment output directory (optional)
    --dry-run                    Log example prompts and generation kwargs, then exit
                                 without making API calls (for testing)
    --delay SECONDS              Delay between batch completion checks (default: 60)
    --overwrite-outputs          Rebuild infer.parquet even if it already exists

The script requires OpenAI API credentials to be set in the environment (a local
.env file is also loaded via python-dotenv).

"""

import argparse
import dataclasses
import pandas as pd
import json
import textwrap
import time
import typing as tp

import openai
import upath
from dotenv import load_dotenv
from loguru import logger
from tqdm.auto import tqdm

from llm_inference import openai_utils, prompts, tasks, utils

load_dotenv()


def format_prompts(prompts: dict[str, str], width: int = 80) -> str:
  """
  Format a dictionary of prompts as a human-readable, width-wrapped string.

  Each prompt is framed by separator rules and a ``PROMPT: <name>`` header.
  Line breaks already present in a prompt are preserved: every line is
  wrapped independently with ``textwrap.fill``. A single ``textwrap.fill``
  over the whole prompt would replace every newline with a space and
  flatten paragraphs and lists into one run-on block.

  Args:
    prompts: Mapping from prompt name to prompt text.
    width: Maximum width of each wrapped line and of the separator rules.
      Default is 80.

  Returns:
    A single string containing all formatted prompts, each followed by an
    empty line.
  """
  rule = "=" * width
  output = []
  for name, prompt in prompts.items():
    output.extend([rule, f"PROMPT: {name}", rule])
    # Wrap line-by-line so intentional newlines in the prompt survive.
    output.append(
      "\n".join(textwrap.fill(line, width=width) for line in prompt.splitlines())
    )
    output.extend([rule, ""])  # Empty line between prompts
  return "\n".join(output)


@dataclasses.dataclass
class Config:
  model: str
  task: str
  user_prompt: str
  system_prompt: str | None = None
  temperature: float | None = None
  num_samples: int | None = None
  top_p: float | None = None
  logprobs: bool | None = True

  _exclude_openai_gen_kwargs = ("task", "user_prompt", "system_prompt")

  def to_gen_kwargs(self):
    kwargs = {
      k: v
      for k, v in dataclasses.asdict(self).items()
      if k not in self._exclude_openai_gen_kwargs
    }
    kwargs["n"] = kwargs.pop("num_samples")
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    return kwargs


def regex_glob(pattern, directory=".", glob: str | None = None):
  """Return the entries of ``directory`` whose file name matches ``pattern``.

  Args:
    pattern: Regular expression tested with ``re.match`` against each entry's
      ``name``. It is anchored at the start of the name; add ``$`` to anchor
      the end as well.
    directory: Directory to search, wrapped in ``upath.UPath``.
    glob: Optional glob used to pre-select candidates via ``UPath.glob``.
      When falsy, every direct child from ``iterdir()`` is considered.

  Returns:
    A list of matching ``UPath`` objects in the order the filesystem yields
    them. The list is not sorted.
  """
  import re

  root = upath.UPath(directory)
  regex = re.compile(pattern)
  candidates = root.glob(glob) if glob else root.iterdir()
  return list(filter(lambda entry: regex.match(entry.name), candidates))


def _wait_for_output_file_id(client, batch_id: str, poll_interval: float = 1.0) -> str:
  """Poll ``client.batches.retrieve`` until the batch exposes an output file id.

  Presumably the output file can be attached shortly after the batch is
  reported complete, hence the short polling loop.

  Args:
    client: An ``openai.OpenAI`` client.
    batch_id: Id of the submitted batch.
    poll_interval: Seconds to sleep between polls.

  Returns:
    The batch's ``output_file_id``.

  Raises:
    RuntimeError: if the batch is ``failed``, ``expired`` or ``cancelled``
      and has no output file. In that state an output file cannot be
      expected to appear, and an unbounded poll would spin forever.
  """
  failure_statuses = ("failed", "expired", "cancelled")
  while True:
    batch = client.batches.retrieve(batch_id)
    if batch.output_file_id is not None:
      return batch.output_file_id
    if batch.status in failure_statuses:
      raise RuntimeError(
        f"Batch {batch_id} ended with status {batch.status!r} and no output "
        f"file (error_file_id={getattr(batch, 'error_file_id', None)})"
      )
    logger.info("Waiting for output file to be created...")
    time.sleep(poll_interval)


def _load_batch_frames(files, parse_fn, source_dir) -> pd.DataFrame:
  """Parse batch JSONL files with ``parse_fn`` and stack them into one frame.

  The ``custom_id`` column is cast to ``int`` and renamed to ``dataset_idx``
  so that submitted inputs and completion outputs can be joined on it.

  Raises:
    ValueError: if ``files`` is empty. This is clearer than the
      "No objects to concatenate" error ``pd.concat`` would raise.
  """
  if not files:
    raise ValueError(f"No batch files found in {source_dir}")
  df = pd.concat([parse_fn(file) for file in files])
  df["custom_id"] = df["custom_id"].astype(int)
  return df.rename(columns={"custom_id": "dataset_idx"})


def main(
  cfg: Config,
  batch_size: int = 500,
  base_output_dir: str = "./outputs/",
  experiment_name: str | None = None,
  limit: int | None = None,
  dry_run: bool = False,
  delay: int = 60,
  overwrite_outputs: bool = False,
):
  """Run, or resume, a batch-inference experiment end to end.

  Steps:
    1. Resolve ``<base_output_dir>/<experiment_name>``. Write ``config.yaml``
       there, or verify that an existing one matches ``cfg``.
    2. Load the task dataset (truncated with ``dataset.take(limit)`` when
       ``limit`` is set). Log up to three rendered example prompts and the
       generation kwargs.
    3. Unless ``dry_run``, build batch requests and split them into chunks of
       ``batch_size``. For each chunk in turn: submit it, wait for it to
       finish, and download its output file.
    4. Merge all submitted inputs with their outputs on ``dataset_idx`` and
       write ``infer.parquet``.

  Each step checks for the files it would produce and skips work already on
  disk. Re-running with the same arguments therefore resumes an interrupted
  run.

  Args:
    cfg: Model, task, prompt and sampling configuration.
    batch_size: Number of requests per submitted batch file.
    base_output_dir: Parent directory for experiment outputs.
    experiment_name: Output sub-directory. If None, it is derived from the
      model and task names.
    limit: Optional cap passed to ``dataset.take``.
    dry_run: If True, return after logging prompts and kwargs (no API calls).
    delay: Passed as ``delay`` to ``openai_utils.wait_until_batch_completion``.
    overwrite_outputs: Rebuild ``infer.parquet`` even if it already exists.

  Raises:
    ValueError: if an existing ``config.yaml`` differs from ``cfg``, or if no
      batch files are found when building the final output.
    RuntimeError: if a batch ends in a failure status without an output file.
  """
  if experiment_name is None:
    experiment_name = utils.create_experiment_name(
      [
        cfg.model.replace("/", "-"),
        cfg.task.replace("/", "-"),
      ]
    )
  output_dir = upath.UPath(base_output_dir) / experiment_name
  output_dir.mkdir(parents=True, exist_ok=True)
  config_output_path = output_dir / "config.yaml"
  if not config_output_path.exists():
    config_output_path.write_text(utils.dump_yaml(dataclasses.asdict(cfg)))
  else:
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if utils.load_yaml(config_output_path) != dataclasses.asdict(cfg):
      raise ValueError(f"Config file mismatch: {config_output_path}")
    logger.info("[Config] Found config file")

  logger.info(f"Output directory: {output_dir.as_posix()}")
  task = tasks.TASKS[cfg.task]
  user_prompt = utils.with_dict_inputs(prompts.PROMPTS[cfg.user_prompt])

  dataset = task.load_dataset()
  if limit is not None:
    dataset = dataset.take(limit)

  # Log a few rendered prompts so the run can be sanity-checked.
  for example in dataset.take(3):
    prompts_to_log = {
      "user_prompt": user_prompt(example),
    }
    if cfg.system_prompt is not None:
      prompts_to_log["system_prompt"] = cfg.system_prompt
    logger.info("\n" + format_prompts(prompts_to_log))

  logger.info("Generation Kwargs:\n" + utils.dump_yaml(cfg.to_gen_kwargs()))
  if dry_run:
    return
  batch_requests = openai_utils.openai_batches_from_dataset(
    dataset,
    user_prompt,
    system_prompt=(lambda _: tp.cast(str, cfg.system_prompt))
    if cfg.system_prompt
    else None,
    **cfg.to_gen_kwargs(),
  )

  # ---------------------------------------------------------------------------- #
  #                           Collect data sequentially                          #
  # ---------------------------------------------------------------------------- #
  batch_request_chunks = utils.split_dataset(batch_requests, batch_size)
  batch_submissions_path = output_dir / "batch_submissions"
  batch_outputs_path = output_dir / "batch_outputs"
  batch_submissions_path.mkdir(parents=True, exist_ok=True)
  batch_outputs_path.mkdir(parents=True, exist_ok=True)

  # Created lazily and reused, so fully-cached re-runs never build a client.
  client = None
  for chunk_idx, batch_requests_chunk in tqdm(
    enumerate(batch_request_chunks),
    desc="Submitting batches",
    total=len(batch_request_chunks),
  ):
    batch_request_file = batch_submissions_path / f"batch_{chunk_idx}.jsonl"
    # NOTE: the trailing "_" in this name is kept so existing runs still resume.
    batch_info_file = batch_submissions_path / f"batch_info_{chunk_idx}_.json"
    if not batch_info_file.exists():
      batch_requests_chunk.to_json(batch_request_file)
      metadata = {
        "experiment_name": experiment_name,
        "output_dir": output_dir.as_posix(),
        "chunk_idx": str(chunk_idx),
        "task": cfg.task,
        "user_prompt": cfg.user_prompt,
        "system_prompt": cfg.system_prompt,
      }
      metadata = {k: v for k, v in metadata.items() if v is not None}
      batch = openai_utils.post_batch(
        batch_request_file,
        metadata=metadata,
      )
      with batch_info_file.open("w") as f:
        f.write(batch.to_json(indent=2))
      batch_info = batch.to_dict()
    else:
      logger.info("[Batch {}] Found batch info file", chunk_idx)
      with batch_info_file.open("r") as f:
        batch_info = json.load(f)

    # ---------------------------------------------------------------------------- #
    #                             write outputs to file                            #
    # ---------------------------------------------------------------------------- #
    batch_output_file = batch_outputs_path / f"batch_{chunk_idx}.jsonl"
    if batch_output_file.exists():
      logger.info("[Batch {}] Found batch output file", chunk_idx)
      continue

    batch_completion_file = (
      batch_submissions_path / f"batch_info_{chunk_idx}_completed.json"
    )
    if not batch_completion_file.exists():
      completion_status = openai_utils.wait_until_batch_completion(
        batch_info["id"], delay=delay
      )
      with batch_completion_file.open("w") as f:
        f.write(completion_status.to_json(indent=2))
    else:
      logger.info("[Batch {}] Found batch completion file", chunk_idx)

    if client is None:
      client = openai.OpenAI()
    output_file_id = _wait_for_output_file_id(client, batch_info["id"])
    file_contents = openai_utils.get_openai_files([output_file_id])[0]
    with batch_output_file.open("w") as f:
      f.write(file_contents)

  # ---------------------------------------------------------------------------- #
  #                         Final output post-processing                         #
  # ---------------------------------------------------------------------------- #
  infer_output_file = output_dir / "infer.parquet"
  if infer_output_file.exists() and not overwrite_outputs:
    logger.info("[Final] Found infer output file")
    return

  batch_input_files = regex_glob(
    r"^batch_(\d+)\.jsonl$", batch_submissions_path, glob="batch_*.jsonl"
  )
  batch_output_files = regex_glob(
    r"^batch_(\d+)\.jsonl$", batch_outputs_path, glob="batch_*.jsonl"
  )
  input_df = _load_batch_frames(
    batch_input_files, openai_utils.parse_batch_submission, batch_submissions_path
  )
  output_df = _load_batch_frames(
    batch_output_files,
    openai_utils.parse_batch_completion_outputs,
    batch_outputs_path,
  )

  merged_df = pd.merge(input_df, output_df, on="dataset_idx")
  logger.info(f"Merged dataframe shape: {merged_df.shape}\n{merged_df.head()}")
  logger.info(f"num rows: {len(merged_df)}")

  merged_df.to_parquet(infer_output_file)

  logger.info(f"Batch outputs written to {output_dir.as_posix()}")


if __name__ == "__main__":
  # Defaults are appended to each help line by ArgumentDefaultsHelpFormatter.
  parser = argparse.ArgumentParser(
    description="Generate completions for a task and prompt with OpenAI's Batch API.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  )
  parser.add_argument(
    "--model", type=str, default="gpt-4o-mini", help="OpenAI model to use."
  )
  parser.add_argument(
    "--config-file",
    type=str,
    required=True,
    help=(
      "Path to the YAML task config (keys: task, prompt; optional: "
      "system_prompt, temperature, top_p)."
    ),
  )
  parser.add_argument(
    "--num-samples",
    type=int,
    default=25,
    help="Completions sampled per input (sent to OpenAI as `n`).",
  )
  parser.add_argument(
    "--output-dir",
    type=str,
    default="./outputs/",
    help="Base output directory; results go to <output-dir>/<experiment-name>.",
  )
  parser.add_argument(
    "--batch-size",
    type=int,
    default=500,
    help="Number of requests per submitted batch file.",
  )
  parser.add_argument(
    "--limit",
    type=int,
    default=None,
    help="Limit the number of dataset rows processed.",
  )
  parser.add_argument(
    "--experiment-name",
    type=str,
    default=None,
    help="Output sub-directory name (derived from model and task if omitted).",
  )
  parser.add_argument(
    "--dry-run",
    action="store_true",
    help="Log example prompts and generation kwargs, then exit without API calls.",
  )
  parser.add_argument(
    "--delay",
    type=int,
    default=60,
    help="Seconds between batch completion checks.",
  )
  parser.add_argument(
    "--overwrite-outputs",
    action="store_true",
    help="Rebuild infer.parquet even if it already exists.",
  )

  args = parser.parse_args()
  # ---------------------------------------------------------------------------- #
  #                                  Load config                                 #
  # ---------------------------------------------------------------------------- #
  # 1. Load raw config
  task_cfg = utils.load_yaml_with_includes(args.config_file)
  # 2. Create `Config` object (CLI supplies model and num_samples; YAML the rest)
  cfg = Config(
    model=args.model,
    task=task_cfg["task"],
    user_prompt=task_cfg["prompt"],
    system_prompt=task_cfg.get("system_prompt"),
    temperature=task_cfg.get("temperature"),
    num_samples=args.num_samples,
    top_p=task_cfg.get("top_p"),
  )
  logger.info(
    "Launching task with config:\n" + utils.dump_yaml(dataclasses.asdict(cfg))
  )
  logger.info("Args:\n" + utils.dump_yaml(vars(args)))

  main(
    cfg,
    batch_size=args.batch_size,
    base_output_dir=args.output_dir,
    experiment_name=args.experiment_name,
    dry_run=args.dry_run,
    limit=args.limit,
    delay=args.delay,
    overwrite_outputs=args.overwrite_outputs,
  )
