from tqdm.auto import tqdm
from tqdm.auto import tqdm
import json
import typing as tp
import numpy as np
import openai
import datasets as ds
from llm_inference import utils


def create_messages(user_prompt: str, system_prompt: str | None = None):
  if not system_prompt:
    return [{"role": "user", "content": user_prompt}]
  else:
    return [
      {"role": "system", "content": system_prompt},
      {"role": "user", "content": user_prompt},
    ]


def create_openai_batch_request(
  *,
  custom_id: str | int,
  user_prompt: str,
  system_prompt: str | None = None,
  **gen_kwargs,
):
  if "model" not in gen_kwargs:
    raise ValueError("Model name must be provided in gen_kwargs")
  return {
    "custom_id": str(custom_id),
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
      **gen_kwargs,
      "messages": create_messages(user_prompt, system_prompt),
    },
  }


def wait_until_batch_completion(batch_id: str, delay: int = 60):
  """Poll a batch until it reaches a terminal status, showing a progress bar.

  Polling stops when the batch status is ``completed``, ``failed``,
  ``expired`` or ``cancelled``; previously only ``completed`` ended the loop,
  so a batch that failed or expired was polled forever.

  Args:
    batch_id: Identifier of the batch to poll.
    delay: Seconds to sleep between two polls.

  Returns:
    The batch object from the final poll, i.e. the one carrying the terminal
    status. Callers should check its ``status`` before reading output files.
  """
  import time

  client = openai.OpenAI()
  terminal_statuses = {"completed", "failed", "expired", "cancelled"}

  batch = client.batches.retrieve(batch_id)
  counts = batch.request_counts
  initial_total = counts.total if counts is not None else None

  # The context manager guarantees the bar is closed, also on errors.
  with tqdm(
    desc="Waiting for batch completion", unit="batch", total=initial_total
  ) as bar:
    while True:
      counts = batch.request_counts
      if counts is not None:
        # The total can grow after the first poll; only ever increase it.
        if bar.total is None or bar.total < counts.total:
          bar.total = counts.total
        bar.update(counts.completed - bar.n)

      if batch.status in terminal_statuses:
        break

      time.sleep(delay)
      batch = client.batches.retrieve(batch_id)

  return batch


def post_batch(filename, *, metadata=None):
  """Upload a batch input file and create a chat-completions batch from it.

  Args:
    filename: Path of the batch input file; opened through ``upath.UPath``.
    metadata: Optional metadata forwarded to ``client.batches.create``.

  Returns:
    The object returned by ``client.batches.create``.
  """
  import upath

  filepath = upath.UPath(filename)
  client = openai.Client()

  # Close the file handle once the upload call has returned; it used to be
  # left open for the garbage collector to clean up.
  with filepath.open("rb") as batch_file:
    batch_input_file = client.files.create(file=batch_file, purpose="batch")

  return client.batches.create(
    input_file_id=batch_input_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata=metadata,
  )


def _download_file_contents(file_id):
  """Return the text content of the file identified by ``file_id``.

  A new client is created on every call.
  """
  response = openai.Client().files.content(file_id)
  return response.text


def get_openai_files(file_ids):
  """Download the text contents of several files.

  Args:
    file_ids: Iterable of file identifiers.

  Returns:
    A list with one text payload per file id, in the same order as
    ``file_ids``. An empty input yields an empty list.
  """
  # Materialize once so that any iterable works with len() and indexing.
  file_ids = list(file_ids)

  # Without this guard an empty input would ask for a pool of 0 workers,
  # which multiprocessing rejects with a ValueError.
  if not file_ids:
    return []

  if len(file_ids) == 1:
    # A single download does not justify spawning worker processes.
    return [_download_file_contents(file_ids[0])]

  import multiprocessing as mp

  processors = min(mp.cpu_count(), len(file_ids))
  with mp.Pool(processors) as pool:
    # imap is lazy: results must be consumed while the pool is still open.
    file_contents = list(
      tqdm(
        pool.imap(_download_file_contents, file_ids),
        total=len(file_ids),
        desc="Downloading files",
      )
    )

  return file_contents


def openai_batches_from_dataset(
  dataset: ds.Dataset,
  user_prompt: tp.Callable[[dict], str],
  system_prompt: tp.Callable[[dict], str] | None = None,
  **gen_kwargs,
):
  """Turn every row of ``dataset`` into a batch request record.

  Args:
    dataset: Source dataset; all of its original columns are removed.
    user_prompt: Callable producing the user message from a row.
    system_prompt: Optional callable producing the system message from a row.
    **gen_kwargs: Forwarded to ``create_openai_batch_request`` for every row.

  Returns:
    The mapped dataset, with the row index used as ``custom_id`` and that
    column cast to a string type.
  """

  def _row_to_request(row, idx):
    user_text = user_prompt(row)
    system_text = None if system_prompt is None else system_prompt(row)
    return create_openai_batch_request(
      custom_id=idx,
      user_prompt=user_text,
      system_prompt=system_text,
      **gen_kwargs,
    )

  original_columns = dataset.column_names
  requests = dataset.map(
    _row_to_request,
    remove_columns=original_columns,
    with_indices=True,
  )
  return requests.cast_column("custom_id", ds.Value("string"))


def parse_chat_completion_response(completion_response):
  """Extract the generated text (and logprobs, if any) of the last choice.

  NOTE(review): only the LAST entry of ``choices`` is returned, as a single
  dict. This keeps the existing return shape; ``parse_batch_completion_response``
  returns one dict per choice instead. Confirm which behavior callers want.

  Args:
    completion_response: Mapping with a ``choices`` list and a ``custom_id``
      key.

  Returns:
    A dict with ``generated_text``, ``custom_id`` and ``sample_idx``. When the
    choice carries non-null ``logprobs``, ``token_logprobs`` and ``tokens``
    lists are added as well.

  Raises:
    ValueError: If ``choices`` is empty.
  """
  choices = list(completion_response["choices"])
  if not choices:
    # Previously surfaced as an opaque UnboundLocalError.
    raise ValueError("Completion response contains no choices")

  sample_idx = len(choices) - 1
  choice = choices[-1]

  result = {
    "generated_text": choice["message"]["content"],
    "custom_id": completion_response["custom_id"],
    "sample_idx": sample_idx,
  }

  # "logprobs" may be present but null when logprobs were not requested;
  # a plain membership test would then crash on None["content"].
  logprobs = choice.get("logprobs")
  if logprobs is not None:
    entries = [
      token_info
      for token_info in (logprobs.get("content") or [])
      if "token" in token_info and "logprob" in token_info
    ]
    result["token_logprobs"] = [token_info["logprob"] for token_info in entries]
    result["tokens"] = [token_info["token"] for token_info in entries]

  return result


def parse_batch_completion_response(
  completion_response,
):
  """Parse one batch output record into one dict per generated choice.

  Args:
    completion_response: Mapping with a ``custom_id`` key and the choices
      under ``response -> body -> choices``.

  Returns:
    A list with one dict per choice, each holding ``generated_text``,
    ``custom_id`` and ``sample_idx``. Choices with non-null ``logprobs``
    additionally get ``token_logprobs`` and ``tokens`` lists.
  """
  custom_id = completion_response["custom_id"]
  parsed_choices = []

  for sample_idx, choice in enumerate(
    completion_response["response"]["body"]["choices"]
  ):
    result = {
      "generated_text": choice["message"]["content"],
      "custom_id": custom_id,
      "sample_idx": sample_idx,
    }

    # "logprobs" may be present but null when logprobs were not requested;
    # a plain membership test would then crash on None["content"].
    logprobs = choice.get("logprobs")
    if logprobs is not None:
      entries = [
        token_info
        for token_info in (logprobs.get("content") or [])
        if "token" in token_info and "logprob" in token_info
      ]
      result["token_logprobs"] = [token_info["logprob"] for token_info in entries]
      result["tokens"] = [token_info["token"] for token_info in entries]

    parsed_choices.append(result)

  return parsed_choices


def parse_batch_completion_outputs(filename):
  """Load a batch output JSONL file into a DataFrame, one row per choice.

  Every record read through ``utils.iterjsonl`` is expanded with
  ``parse_batch_completion_response``; the resulting dicts become the rows.
  """
  import pandas as pd

  rows = (
    parsed_choice
    for record in utils.iterjsonl(filename)
    for parsed_choice in parse_batch_completion_response(record)
  )
  return pd.DataFrame(rows)


def parse_batch_submission(submission_file):
  """Load a batch submission JSONL file into a DataFrame of prompts.

  Each record becomes one row holding ``custom_id`` (converted to ``int``)
  and one ``<role>_prompt`` column per message role in the request body.

  NOTE(review): ``custom_id`` is an int here, while
  ``parse_batch_completion_response`` passes it through unchanged; check the
  dtypes before joining the two tables.

  Raises:
    ValueError: If a record contains two messages with the same role.
  """
  import pandas as pd

  def _rows():
    for entry in utils.iterjsonl(submission_file):
      request_id = int(entry["custom_id"])

      # Gather the content per role, rejecting repeated roles.
      content_by_role = {}
      for msg in entry["body"]["messages"]:
        role = msg["role"]
        if role in content_by_role:
          raise ValueError(f"Duplicate role '{role}' in submission")
        content_by_role[role] = msg["content"]

      yield {
        "custom_id": request_id,
        **{f"{role}_prompt": text for role, text in content_by_role.items()},
      }

  return pd.DataFrame(_rows())
