"""
This module provides functionality for performing distributed inference on HuggingFace models via ray and vllm.

Example usage from the command line:
python generate.py --model microsoft/phi-2 --config-file configs/generation/mbpp --num_samples 10 --output-dir ./outputs

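Output Files:
- generate.log: Log file containing debug-level logs.
- command.sh: The command line used to launch the run.
- git_commit_hash.txt: Git commit hash of the code used for the run (kept if already present).
- config.yaml: YAML file containing the configuration used for generation (kept if already present).
- infer.parquet: Parquet file containing the generated samples.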
"""

import argparse
import functools
import pathlib
import sys
import typing as tp
from datetime import datetime
import upath
from loguru import logger
from pydantic import Field

from llm_inference import (
  generate_lib,
  output_parsers,
  prompts,
  tasks,
  utils,
)
from llm_inference.eval import eval_predictions


def get_git_commit_hash() -> str:
  """Return the hex SHA of the current HEAD commit, or a diagnostic string.

  This is best-effort run metadata and never raises:

  - Returns the HEAD commit SHA when the working directory (or a parent) is a
    git repository.
  - Returns ``"Not a git repository"`` when no repository is found.
  - Returns ``"Error: <message>"`` for any other failure. This includes
    GitPython not being installed, or failing to initialise at import time
    (e.g. no usable git executable).
  """
  # The imports live inside a try block: GitPython raises ImportError both when
  # the package is missing and when it cannot find a git executable. Neither
  # should abort a generation run just to record metadata.
  try:
    from git import Repo
    from git.exc import InvalidGitRepositoryError
  except ImportError as e:
    return f"Error: {str(e)}"

  try:
    repo = Repo(search_parent_directories=True)
    return repo.head.commit.hexsha
  except InvalidGitRepositoryError:
    return "Not a git repository"
  except Exception as e:
    # Deliberately broad: e.g. an empty repo has no HEAD commit. The caller only
    # records this string, so report the failure instead of raising.
    return f"Error: {str(e)}"


# ---------------------------------------------------------------------------- #
#                       Configuration for Generation Task                      #
# ---------------------------------------------------------------------------- #
class RawConfig(utils.ConfigDict):
  """Generation settings as collected from a YAML config file and/or the CLI.

  ``model``, ``prompt`` and ``task`` may still be ``None`` here; call
  :meth:`resolve` to validate that they are set and obtain a ``ResolvedConfig``.
  """

  limit: int = Field(-1, description="Limit the number of examples to process")
  tensor_parallel_size: int = Field(
    1,
    description=(
      "Tensor parallel size. By default 1 (use 1 gpu per vllm model). "
      "Set to -1 to split model across all available gpus, "
      "or some n >= 1 to split across n gpus."
    ),
  )
  trust_remote_code: bool = Field(
    True,
    description="Trust remote code (required for several models with custom modelling code)",
  )
  enforce_eager: bool = Field(
    False,
    description="Enforce eager execution (saves memory in certain cases)",
  )
  batch_size: int = Field(
    32,
    description="Batch size for generation",
  )
  concurrency: int = Field(
    -1,
    description="Concurrency for generation. If set to -1, will use max available gpus",
  )
  top_p: float = Field(
    0.95,
    description="Top p sampling parameter for generation",
  )
  top_k: int = Field(
    -1,
    description="Top k sampling parameter for generation",
  )
  temperature: float = Field(
    0.2,
    description="Temperature sampling parameter for generation",
  )
  max_new_tokens: int = Field(
    512,
    description="Maximum number of tokens to generate",
  )
  max_prompt_tokens: int = Field(
    1536,
    description="Maximum number of tokens in the prompt",
  )
  prompt: tp.Optional[str] = Field(
    None,
    description="Prompt to use",
    examples=list(prompts.PROMPTS.keys()),
  )
  task: tp.Optional[str] = Field(
    None,
    description="Task to generate for",
    examples=list(tasks.TASKS.keys()),
  )
  model: tp.Optional[str] = Field(
    None,
    description="Model name or path from huggingface model hub",
  )
  num_samples: int = Field(
    1,
    description="Number of samples to generate per example",
  )
  stop_tokens: list[str] = Field(
    default_factory=list,
    description="List of stop tokens for generation",
  )
  output_parser: str = Field(
    default="passthrough_output_parser",
    description="Output parser to use for extracting code from model predictions (Only used for evaluation)",
    examples=list(output_parsers.OUTPUT_PARSERS.keys()),
  )

  @classmethod
  def names_to_exclude(cls) -> list[str]:
    """Return field names to exclude.

    NOTE(review): presumably consulted by ``utils.ConfigDict`` (e.g. when
    building CLI arguments), so ``stop_tokens`` is set via config file or task
    defaults only. Confirm against ``ConfigDict``.
    """
    return ["stop_tokens"]

  def resolve(self) -> "ResolvedConfig":
    """Check that the required fields are set and return a ``ResolvedConfig``.

    Returns:
      A ``ResolvedConfig`` carrying all of this config's field values.

    Raises:
      ValueError: If any of ``model``, ``prompt`` or ``task`` is ``None``.
        All missing fields are reported in a single error.
    """
    required = ("model", "prompt", "task")
    missing = [k for k in required if getattr(self, k) is None]
    if missing:
      verb = "is" if len(missing) == 1 else "are"
      raise ValueError(f"{', '.join(missing)} {verb} required to resolve the config")

    model_dict = self.model_dump()
    # Set the required fields explicitly from the attributes, so they are present
    # even if the base class's dump omits or transforms them.
    model_dict.update({k: getattr(self, k) for k in required})
    return ResolvedConfig(**model_dict)


class ResolvedConfig(RawConfig):
  """A ``RawConfig`` whose ``model``, ``prompt`` and ``task`` are required strings.

  Instances are normally produced by ``RawConfig.resolve``, which checks that
  these fields are not ``None``.
  """

  # Narrow the Optional[str] fields of RawConfig to required str.
  model: str
  prompt: str
  task: str


# Generic batch type passed through unchanged by token_throughput_logger.
T = tp.TypeVar("T")


def token_throughput_logger(
  batch_iter: tp.Iterable[T],
  *,
  get_generated_token_count: tp.Callable[[T], int],
  log_every: int = 2,
) -> tp.Iterator[T]:
  """Yield batches from ``batch_iter`` unchanged while logging token throughput.

  A progress line is logged after batch 0 and then after every batch whose
  index is a multiple of ``log_every``. It reports the throughput over the
  window since the previous progress line, and the average number of
  generated tokens per batch in that window. Once ``batch_iter`` is exhausted,
  a summary line with the total tokens, elapsed time and overall throughput
  is logged.

  Timing uses ``time.perf_counter``. Because logging happens after each
  ``yield``, the measured time includes the time the consumer spends between
  ``next()`` calls.

  Args:
    batch_iter: Batches to pass through.
    get_generated_token_count: Returns the number of generated tokens in a batch.
    log_every: Log a progress line every this many batches. Must be >= 1.

  Yields:
    Each batch of ``batch_iter``, in order and unmodified.

  Raises:
    ValueError: If ``log_every`` < 1. Since this is a generator, the error is
      raised when iteration starts, not when the function is called.
  """
  import time

  if log_every < 1:
    raise ValueError(f"log_every must be >= 1, got {log_every}")

  def _rate(tokens: int, seconds: float) -> float:
    # A zero-length interval (e.g. an empty iterator on a coarse clock) has no
    # meaningful rate; report NaN instead of raising ZeroDivisionError.
    return tokens / seconds if seconds > 0 else float("nan")

  total_tokens = 0
  window_tokens = 0
  window_batches = 0
  start_time = window_start = time.perf_counter()
  for batch_idx, batch in enumerate(batch_iter):
    window_tokens += get_generated_token_count(batch)
    window_batches += 1

    yield batch

    if batch_idx % log_every == 0:
      elapsed = time.perf_counter() - window_start
      logger.info(
        "Batch: {}, Token Throughput: {:.2f} tokens/sec. Tokens generated per step: {}",
        batch_idx,
        _rate(window_tokens, elapsed),
        # Divide by the batches actually in this window: the first window
        # (ending at batch 0) holds one batch, not log_every.
        window_tokens / window_batches,
      )
      total_tokens += window_tokens
      window_tokens = 0
      window_batches = 0
      window_start = time.perf_counter()

  # Include any tokens from a trailing, partially-filled window.
  total_tokens += window_tokens
  total_elapsed = time.perf_counter() - start_time
  logger.info(
    "Generated {} tokens in {:.2f} seconds. Token Throughput: {:.2f} tokens/sec",
    total_tokens,
    total_elapsed,
    _rate(total_tokens, total_elapsed),
  )


def _default_experiment_name(cfg: ResolvedConfig) -> str:
  """Build ``<model>-<task>-sample<num_samples>-<YYYYmmdd-HHMMSS>`` ('/' in the model id becomes '_')."""
  timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
  return f"{cfg.model.replace('/', '_')}-{cfg.task}-sample{cfg.num_samples}-{timestamp}"


def _write_run_metadata(experiment_dir: upath.UPath) -> None:
  """Write command.sh (always) and git_commit_hash.txt (only if absent) into ``experiment_dir``."""
  with (experiment_dir / "command.sh").open("w") as f:
    f.write(" ".join(sys.argv) + "\n")

  git_hash_file = experiment_dir / "git_commit_hash.txt"
  if not git_hash_file.exists():
    git_commit_hash = get_git_commit_hash()
    logger.info("Git commit hash: {}", git_commit_hash)
    with git_hash_file.open("w") as f:
      f.write(git_commit_hash)


def _run_generation(
  cfg: ResolvedConfig, experiment_dir: upath.UPath, *, do_eval: bool
) -> None:
  """Load task, prompt and dataset, run distributed generation, and optionally evaluate."""
  # ------------------------------------------------------------------------ #
  #                   Get task and prompt, and load dataset                   #
  # ------------------------------------------------------------------------ #
  task = tasks.TASKS[cfg.task]
  prompt = prompts.PROMPTS[cfg.prompt]
  if not cfg.stop_tokens:
    logger.info(
      "No stop tokens provided, using default stop tokens from task: {}",
      task.stop_tokens,
    )
    # Work on a copy so the caller's config is not mutated. Otherwise a config
    # reused across calls would carry this task's stop tokens into the next run.
    cfg = cfg.model_copy()
    cfg.stop_tokens = task.stop_tokens

  logger.info("Final Config:\n{}", cfg.dumps())

  dataset = task.load_dataset()
  if cfg.limit > 0:
    # Clamp to the dataset size: selecting indices past the end would fail.
    num_selected = min(cfg.limit, dataset.num_rows)
    logger.info("Limiting dataset to {} samples", num_selected)
    dataset = dataset.select(range(num_selected))

  sampling_cfg = generate_lib.SamplingConfig(
    top_p=cfg.top_p,
    top_k=cfg.top_k,
    temperature=cfg.temperature,
    max_new_tokens=cfg.max_new_tokens,
    max_prompt_tokens=cfg.max_prompt_tokens,
    stop_tokens=cfg.stop_tokens,
  )
  generation_cfg = generate_lib.GenerationConfig(
    sampling_config=sampling_cfg,
    input_ids_key="prompt",
  )

  # An existing config.yaml (e.g. from an earlier run in the same directory) is kept.
  config_path = experiment_dir / "config.yaml"
  if not config_path.exists():
    logger.info("Saving config to {}", str(config_path))
    cfg.dump(config_path)

  # ------------------------------------------------------------------------ #
  #                           Start inference loop                            #
  # ------------------------------------------------------------------------ #
  # NOTE(review): cfg.tensor_parallel_size, cfg.trust_remote_code and
  # cfg.enforce_eager are not forwarded here. Confirm whether
  # distributed_generate_vllm_v2 should receive them.
  output_file = experiment_dir / "infer.parquet"
  logger.info("Starting generation for {} samples", dataset.num_rows)
  generate_lib.distributed_generate_vllm_v2(
    model_id=cfg.model,
    hf_dataset=dataset,
    output_file=output_file,
    preprocessors=[
      functools.partial(
        generate_lib.batched_apply_prompt,
        prompt=utils.with_dict_inputs(prompt, strict=False),
        output_key="prompt",
      ),
    ],
    generation_cfg=generation_cfg,
    concurrency=cfg.concurrency,
    num_samples=cfg.num_samples,
    batch_size=cfg.batch_size,
  )

  # ------------------------------------------------------------------------ #
  #                              Run evaluation                               #
  # ------------------------------------------------------------------------ #
  if do_eval:
    logger.info("Starting evaluation for generated outputs")
    eval_predictions(
      task=task,
      infer_outputs=output_file.as_posix(),
      output_parser=output_parsers.OUTPUT_PARSERS[cfg.output_parser],
    )

  logger.info("Outputs saved to: {}", experiment_dir.as_posix())


def generate(
  cfg: ResolvedConfig,
  output_base_dir: upath.UPath,
  experiment_name: str | None = None,
  do_eval: bool = True,
):
  """
  Generate samples using the specified configuration.

  Args:
      cfg: The resolved configuration object. It is not modified. If it has no
          stop tokens, a copy carrying the task's default stop tokens is used.
      output_base_dir: The base directory where the output files will be saved.
      experiment_name: Name of the experiment subdirectory under
          ``output_base_dir``. Defaults to
          ``<model>-<task>-sample<num_samples>-<timestamp>``, with '/' in the
          model id replaced by '_'.
      do_eval: Whether to run evaluation on the generated outputs. Defaults to True.

  Notes:
      This function performs the following steps:
      1. Creates the experiment directory and attaches a DEBUG-level log file.
         The log file is detached again on return, including when an exception
         is raised.
      2. Loads the task and prompt, and loads the dataset.
      3. Limits the dataset to ``cfg.limit`` rows (clamped to the dataset size)
         when ``cfg.limit > 0``.
      4. Runs distributed generation, writing the samples to a Parquet file.
      5. Runs evaluation on the generated outputs if `do_eval` is True.

  Output Files:
      - generate.log: Log file containing debug-level logs.
      - command.sh: The command line used to launch the run.
      - git_commit_hash.txt: Git commit hash (written only if not already present).
      - config.yaml: YAML file containing the configuration used for generation
        (written only if not already present).
      - infer.parquet: Parquet file containing the generated samples.

  """
  if experiment_name is None:
    experiment_name = _default_experiment_name(cfg)
  experiment_dir = output_base_dir / experiment_name

  logger.info("Creating output directory: {}", str(experiment_dir))
  experiment_dir.mkdir(exist_ok=True, parents=True)
  logger_file_handler = logger.add(experiment_dir / "generate.log", level="DEBUG")
  try:
    _write_run_metadata(experiment_dir)
    _run_generation(cfg, experiment_dir, do_eval=do_eval)
  finally:
    # Always detach the per-experiment sink. Otherwise a failed run would leave
    # it attached, and later calls in the same process would keep logging
    # into this experiment's file.
    logger.remove(logger_file_handler)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser = RawConfig.add_arguments(parser)
  parser.add_argument(
    "--output-dir",
    type=str,
    default="./outputs",
    help="Path to the output directory",
  )
  parser.add_argument(
    "--experiment-name",
    type=str,
    default=None,
    help=(
      "Name of the experiment subdirectory inside --output-dir "
      "(default: derived from model, task, num_samples and a timestamp)"
    ),
  )
  parser.add_argument(
    "--config-file",
    type=pathlib.Path,
    required=False,
    help="Path to the YAML config file",
    default=None,
  )
  parser.add_argument(
    "--dump-cfg", action="store_true", help="Dump the config to stdout"
  )
  parser.add_argument(
    "--no-eval",
    action="store_true",
    help="Skip execution based evaluation of model predictions",
  )
  # NOTE: still accepted for CLI backward compatibility, but nothing in this
  # script reads it (ray is not initialised here).
  parser.add_argument(
    "--ray-num-cpus",
    type=int,
    default=None,
    help="Currently unused; accepted for backward compatibility",
  )
  args = parser.parse_args()

  # Only require GCS credentials when the output actually lives on GCS; a
  # missing GOOGLE_APPLICATION_CREDENTIALS must not block purely local runs.
  contains_gcs_paths = args.output_dir.startswith("gs://")
  if contains_gcs_paths and not utils.check_gcs_credentials():
    logger.error(
      "GCS credentials not found but GCS paths are provided:\n"
      f"{args.output_dir}\n"
      "Please set GOOGLE_APPLICATION_CREDENTIALS environment variable to a valid GCS credentials file."
    )
    sys.exit(1)

  # ---------------------------------------------------------------------------- #
  #          Merge and resolve configs from command line and config file         #
  # ---------------------------------------------------------------------------- #
  # Insertion order matters: the command-line config is merged last (assuming
  # RawConfig.merge gives later configs precedence -- confirm in utils.ConfigDict).
  configs = {}

  if args.config_file:
    configs[f"Config from config file ({str(args.config_file)})"] = RawConfig.from_yaml(
      args.config_file
    )

  configs["Config from command line"] = RawConfig.from_args(args)
  merged_config = RawConfig.merge(*configs.values())
  resolved_config = merged_config.resolve()

  # ---------------------------------------------------------------------------- #
  #                         Dump config or run generation                        #
  # ---------------------------------------------------------------------------- #
  if args.dump_cfg:
    for k, v in configs.items():
      print(f"----------------- {k} -----------------")
      print(v.dumps())
      print("------------------------------------------")
    print()
    print("----------------- Final Resolved Config -----------------")
    print(resolved_config.dumps())
    print("------------------------------------------")
  else:
    generate(
      resolved_config,
      upath.UPath(args.output_dir),
      args.experiment_name,
      do_eval=not args.no_eval,
    )
