import argparse
import bisect
import collections
import dataclasses
import multiprocessing as mp
import os
import sys
import typing as tp
import uuid

import datasets as ds
import numpy as np
import pandas as pd
import ray
import torch
import torch.utils.data
import tree
import upath
from loguru import logger
from tqdm.auto import tqdm
from transformers import (
  AutoConfig,
  AutoModelForCausalLM,
  AutoTokenizer,
  PreTrainedModel,
  PreTrainedTokenizer,
  PreTrainedTokenizerFast,
)

from llm_inference import ray_utils, scoring_utils, utils

# Set TOKENIZERS_PARALLELISM=false for this process (and any children that
# inherit its environment).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

# Short dtype names (used as the `--dtype` choices) mapped to torch dtypes.
TORCH_DTYPES = dict(
  fp32=torch.float32,
  fp16=torch.float16,
  bf16=torch.bfloat16,
)


# ---------------------------------------------------------------------------- #
#                                 Write Outputs                                #
# ---------------------------------------------------------------------------- #
def format_index_names(row):
  """Render a mapping as ``key=value`` pairs joined with underscores."""
  return "_".join(f"{name}={value}" for name, value in row.items())


def write_score_outputs(
  batch_inputs: pd.DataFrame,
  batch_output_dicts: list[dict],
  parquet_output_path: upath.UPath,
  tensor_output_folder: upath.UPath | None = None,
  tensor_keys_to_save: tp.Sequence[str] = ("hidden_states", "logits"),
  index_columns: tp.Sequence[str] = ("idx", "dataset_idx", "sample_idx"),
  save_tensor_dtype=np.float16,
  temperatures: tp.Sequence[float] = (),
):
  """Write one batch of per-example score outputs to a parquet file.

  The per-example dicts are pivoted into columns, a per-example
  ``hidden_states`` mapping (if present) is split into one
  ``hidden_states.<layer>`` column per layer, the index columns from
  ``batch_inputs`` are attached, and the result is written with awkward.

  Args:
    batch_inputs: Frame holding (at least) ``index_columns``; rows are expected
      to line up with ``batch_output_dicts``.
    batch_output_dicts: One output dict per example. All dicts are expected to
      share the keys (and hidden-state layers) of the first one.
    parquet_output_path: Destination parquet file; its parent is created.
    tensor_output_folder: Created if given. Nothing is written into it here.
    tensor_keys_to_save: Currently unused; kept for interface compatibility.
    index_columns: Columns of ``batch_inputs`` copied into the output.
    save_tensor_dtype: Currently unused; kept for interface compatibility.
    temperatures: Currently unused; kept for interface compatibility.

  Raises:
    KeyError: If ``batch_inputs`` lacks one of ``index_columns``.
  """
  # Imported lazily so awkward is only required by processes that write.
  import awkward as ak

  # `exist_ok=True` instead of check-then-create: several writers may target
  # the same directory concurrently, and a separate exists() check can race.
  parquet_output_path.parent.mkdir(parents=True, exist_ok=True)
  if tensor_output_folder:
    tensor_output_folder.mkdir(parents=True, exist_ok=True)

  if not batch_output_dicts:
    # Nothing to pivot; the code below needs at least one example.
    logger.warning(
      "No outputs to write, skipping {}", parquet_output_path.as_posix()
    )
    return

  batch_indices = batch_inputs[list(index_columns)]

  # Pivot list-of-dicts -> dict-of-lists.
  columns = collections.defaultdict(list)
  for output_dict in batch_output_dicts:
    for key, value in output_dict.items():
      columns[key].append(value)

  # Split the per-example {layer: states} mapping into one column per layer.
  first_output = batch_output_dicts[0]
  if "hidden_states" in first_output:
    per_example_hidden_states = columns.pop("hidden_states")
    for layer_idx in first_output["hidden_states"]:
      columns[f"hidden_states.{layer_idx}"] = [
        hidden_states[layer_idx] for hidden_states in per_example_hidden_states
      ]

  flat_columns = dict(columns)
  for column in index_columns:
    flat_columns[column] = batch_indices[column].tolist()

  # Convert numpy arrays to plain (nested) lists before handing the data to
  # awkward.
  flat_columns = tree.map_structure(
    lambda x: x.tolist() if isinstance(x, np.ndarray) else x,
    flat_columns,
  )
  ak.to_parquet(
    ak.Array(flat_columns),
    parquet_output_path.as_posix(),
  )


def _load_tokenizer(
  model_id,
  padding_side: tp.Literal["left", "right"] = "left",
  trust_remote_code: bool = True,
  set_pad_to_eos_token: bool = True,
):
  """Load the tokenizer for `model_id`.

  When `set_pad_to_eos_token` is true and the tokenizer has no pad token id,
  the EOS token (and its id) is reused as the pad token.
  """
  tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=trust_remote_code,
    padding_side=padding_side,
  )

  needs_pad_token = set_pad_to_eos_token and tokenizer.pad_token_id is None
  if not needs_pad_token:
    return tokenizer

  tokenizer.pad_token_id = tokenizer.eos_token_id
  tokenizer.pad_token = tokenizer.eos_token
  return tokenizer


def _load_model_and_tokenizer(
  model_id: str,
  dtype: torch.dtype = torch.float32,
  padding_side: tp.Literal["left", "right"] = "left",
  model_parallel: int = 1,
  trust_remote_code: bool = True,
  set_pad_to_eos_token: bool = True,
  use_flash_attention: bool = True,
):
  """Load a causal LM onto the GPU(s) together with its tokenizer.

  Args:
    model_id: HuggingFace model id or path.
    dtype: Torch dtype the weights are loaded in.
    padding_side: Padding side configured on the tokenizer.
    model_parallel: When greater than 1 the model is loaded with
      ``device_map="auto"``; otherwise it is placed entirely on ``"cuda"``.
    trust_remote_code: Forwarded to the config, model and tokenizer loaders.
    set_pad_to_eos_token: Forwarded to `_load_tokenizer`.
    use_flash_attention: Try ``attn_implementation="flash_attention_2"`` first
      and fall back to the default attention implementation if loading fails.

  Returns:
    A ``(model, tokenizer)`` tuple.
  """
  model_cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
  # Shared by every load attempt so that the flash-attention path honours
  # `model_parallel` exactly like the fallback path does.
  load_kwargs = dict(
    config=model_cfg,
    device_map="auto" if model_parallel > 1 else {"": "cuda"},
    torch_dtype=dtype,
    trust_remote_code=trust_remote_code,
  )

  model = None
  if use_flash_attention:
    logger.info("Trying to load model with flash attention")
    try:
      model = AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation="flash_attention_2",
        **load_kwargs,
      )
    except Exception as err:
      # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate;
      # the reason is logged so a broken install is not silently masked.
      logger.warning(
        "Failed to load model with flash attention ({}: {}). "
        "Loading model without flash attention",
        type(err).__name__,
        err,
      )
  else:
    logger.info("Flash attention is disabled")

  if model is None:
    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

  total_memory = torch.cuda.memory_allocated(model.device) / 1024**3
  logger.info("Loaded model into GPU with {:.2f} GB memory", total_memory)
  tokenizer = _load_tokenizer(
    model_id,
    padding_side=padding_side,
    trust_remote_code=trust_remote_code,
    set_pad_to_eos_token=set_pad_to_eos_token,
  )
  return model, tokenizer


# ---------------------------------------------------------------------------- #
#                             Data loader utilities                            #
# ---------------------------------------------------------------------------- #
class PredefinedBatchSampler(torch.utils.data.Sampler):
  """Batch sampler that replays a fixed, precomputed sequence of index batches."""

  def __init__(self, batch_indices: tp.Sequence[tp.Sequence[int]]):
    # Each element is the collection of dataset indices forming one batch.
    self.batch_indices = batch_indices

  def __len__(self):
    return len(self.batch_indices)

  def __iter__(self):
    yield from self.batch_indices


@dataclasses.dataclass
class DataCollator:
  """Collate a list of example dicts into one padded model-input batch.

  The column named `input_ids_key` is either tokenized with padding
  (``mode="tokenize"``) or padded via ``tokenizer.pad`` (any other mode); every
  other column is converted with ``torch.tensor``.
  """

  # Tokenizer used to tokenize or pad the input column.
  tokenizer: PreTrainedTokenizerFast | PreTrainedTokenizer
  # Name of the example field that holds the sequence to score.
  input_ids_key: str = "token_ids_to_score"
  # "tokenize" calls the tokenizer on the column; otherwise it is only padded.
  mode: tp.Literal["tokenize", "pad"] = "pad"

  def __call__(self, batch):
    # Pivot list-of-examples into per-column lists.
    by_column = collections.defaultdict(list)
    for example in batch:
      for name, value in example.items():
        by_column[name].append(value)

    sequences = by_column[self.input_ids_key]
    if self.mode != "tokenize":
      encoded = self.tokenizer.pad(
        {"input_ids": sequences},
        return_tensors="pt",
      )
    else:
      encoded = self.tokenizer(
        sequences,
        padding=True,
        return_tensors="pt",
      )

    collated = {
      "input_ids": encoded.input_ids,
      "attention_mask": encoded.attention_mask,
    }
    # All remaining columns ride along as tensors.
    for name, values in by_column.items():
      if name == self.input_ids_key:
        continue
      collated[name] = torch.tensor(values)
    return collated


def create_adaptive_sampler(
  input_lengths: np.ndarray,
  *,
  model: PreTrainedModel,
  tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
  output_hidden_states: bool,
  layers: int | list[int] | None = None,
):
  """Build a batch sampler with batches chosen by `utils.compute_adaptive_batch_indices`.

  Args:
    input_lengths: Length of every sequence in the dataset, in dataset order.
    model: Model used by the probing forward pass.
    tokenizer: Tokenizer forwarded to `scoring_utils.score_batch`.
    output_hidden_states: Forwarded to the probing forward pass.
    layers: Forwarded as ``hidden_states_to_return`` to the probing forward
      pass. This used to be read from an undefined free variable that only
      resolved when the file was executed as a script.

  Returns:
    A `PredefinedBatchSampler` over the computed batches, in random order.
  """
  if len(input_lengths) == 0:
    # np.max / max() below would raise on empty input.
    logger.warning("No sequences to score; returning an empty batch sampler")
    return PredefinedBatchSampler([])

  logger.info("Average sequence length: {}", np.mean(input_lengths))
  max_seq_length = np.max(input_lengths)
  # Candidate partition boundaries (powers of two from 64 to 16384). Keep the
  # ones not exceeding the longest sequence plus the next larger one.
  # NOTE(review): presumably these are sequence-length buckets; confirm against
  # utils.compute_adaptive_batch_indices.
  partitions = [2**n for n in range(6, 15)]
  partitions = partitions[: bisect.bisect_right(partitions, max_seq_length) + 1]
  all_batch_indices = utils.compute_adaptive_batch_indices(
    input_lengths,
    partitions,
    lambda input_ids: scoring_utils.score_batch(
      model=model,
      input_ids=input_ids,
      tokenizer=tokenizer,
      output_hidden_states=output_hidden_states,
      hidden_states_to_return=layers,
    ),
    device=model.device,
  )
  logger.info("Finished computing batch sizes for partitions")
  batch_lens = [len(x) for x in all_batch_indices]
  if batch_lens:
    logger.info("Maximum batch size: {}", max(batch_lens))
    logger.info("Minimum batch size: {}", min(batch_lens))
    logger.info("Average batch size: {}", np.mean(batch_lens))
  logger.info("Number of batches: {}", len(all_batch_indices))
  logger.info(f"{batch_lens=}")
  # Randomize batch order so progress/time-remaining estimates are not skewed.
  all_batch_indices = [
    all_batch_indices[i] for i in np.random.permutation(len(all_batch_indices))
  ]
  return PredefinedBatchSampler(all_batch_indices)


@ray.remote
class PostprocessingWorker:
  """Ray actor that post-processes scored batches and writes them to parquet.

  Submitted batches are buffered; once `queue_size` batches are buffered they
  are merged and handed to a single background writer thread.
  """

  def __init__(
    self,
    queue_size: int = 2,
    index_columns: tp.Sequence[str] = ("dataset_idx", "sample_idx"),
  ):
    """
    Args:
      queue_size: Number of submitted batches buffered before a write.
      index_columns: Keys of `batch_inputs` copied into the written output.
    """
    self._writer = utils.FailFastThreadPoolExecutor(max_workers=1)
    self.batch_queue = []
    self.queue_size = queue_size
    self.index_columns = tuple(index_columns)

  def submit(
    self,
    batch_outputs: dict,
    batch_inputs: dict,
    score_output_path: str,
    tensor_output_path: str,
    tokenizer: PreTrainedTokenizerFast | PreTrainedTokenizer,
    padding_side: tp.Literal["left", "right"] = "left",
    temperatures: tp.Sequence[float] | None = None,
  ):
    """Post-process one scored batch and queue it for writing.

    Args:
      batch_outputs: Model outputs; must contain ``"input_ids"`` and
        ``"logits"``.
      batch_inputs: Collated inputs; must contain every configured index column.
      score_output_path: Directory the parquet chunks are written into.
      tensor_output_path: Forwarded to `write_score_outputs` as the tensor folder.
      tokenizer: Only its ``pad_token_id`` is used.
      padding_side: Forwarded to `scoring_utils.postprocess_batch_outputs2`.
      temperatures: Extra temperatures to compute token logprobs for.

    Returns:
      The number of examples in the batch.

    Raises:
      ValueError: If ``batch_outputs`` has no ``"logits"``.
      KeyError: If ``batch_inputs`` lacks a configured index column.
    """
    # Validate before the first use of the logits (previously this check sat
    # after the first access and could never trigger).
    if "logits" not in batch_outputs:
      raise ValueError(
        "Cannot calculate logprobs without logits. Logits not found in output dict"
      )
    batch_outputs = tree.map_structure(torch.tensor, batch_outputs)
    # Large columns are only needed to derive the logprobs; drop them from the
    # outputs that get written.
    logits = batch_outputs.pop("logits")
    batch_outputs.pop("attentions", None)
    input_ids = batch_outputs["input_ids"]

    # ---------------------------------------------------------------------------- #
    #                   Calculate logprobs based on temperatures                   #
    # ---------------------------------------------------------------------------- #
    batch_outputs["token_logprobs"] = scoring_utils.compute_token_logprobs(
      input_ids, logits, 1
    )
    for temperature in temperatures or ():
      batch_outputs[f"token_logprobs@T={temperature}"] = (
        scoring_utils.compute_token_logprobs(
          input_ids, logits, temperature=temperature
        )
      )

    batch_output_dicts = scoring_utils.postprocess_batch_outputs2(
      batch_outputs,
      tokenizer.pad_token_id,
      padding_side=padding_side,
      to_numpy=True,
    )
    for batch_output_dict in batch_output_dicts:
      batch_output_dict["token_ids"] = batch_output_dict.pop("input_ids")

    batch_input_df = pd.DataFrame(
      {column: np.asarray(batch_inputs[column]) for column in self.index_columns}
    )
    self.batch_queue.append(
      (batch_input_df, batch_output_dicts, score_output_path, tensor_output_path)
    )

    if len(self.batch_queue) >= self.queue_size:
      self._flush_queue()

    return len(batch_output_dicts)

  def _flush_queue(self):
    """Merge all buffered batches and submit them to the writer thread."""
    if not self.batch_queue:
      return
    batch_input_df = pd.concat([x[0] for x in self.batch_queue], ignore_index=True)
    batch_output_dicts = [entry for x in self.batch_queue for entry in x[1]]
    # All buffered batches are written to the paths of the first one.
    score_output_path = self.batch_queue[0][2]
    tensor_output_path = self.batch_queue[0][3]
    self._writer.submit(
      write_score_outputs,
      batch_input_df,
      batch_output_dicts,
      # A unique chunk name per write so concurrent writers never collide.
      parquet_output_path=upath.UPath(score_output_path)
      / f"scores-{uuid.uuid4().hex}.parquet",
      tensor_output_folder=upath.UPath(tensor_output_path),
      tensor_keys_to_save=("hidden_states",),
      index_columns=self.index_columns,
    )
    self.batch_queue = []

  def shutdown(self):
    """Write whatever is still buffered and shut the writer down."""
    self._flush_queue()
    self._writer.shutdown()


def _scoring_worker(
  model_id: str,
  dataset: ds.Dataset,
  dataset_shard_info: tuple[int, int],
  tensor_output_path: str,
  score_output_path: str,
  token_ids_column: str = "token_ids_to_score",
  dtype: torch.dtype = torch.float32,
  padding_side: tp.Literal["left", "right"] = "left",
  set_pad_to_eos_token: bool = True,
  output_hidden_states: bool = False,
  layers: int | list[int] | None = None,
  batch_size: int = 32,
  model_parallel: int = 1,
  trust_remote_code: bool = True,
  batching_strategy: tp.Literal["adaptive", "fixed"] = "adaptive",
  use_flash_attention: bool = True,
  temperatures: tp.Sequence[float] | None = None,
  gcs_credentials: str | None = None,
  num_concurrent_postprocessors: int = 8,
  num_postprocessing_workers: int = 4,
  max_write_queue_size: int = 2,
  index_columns: tp.Sequence[str] | None = None,
):
  """Score one shard of `dataset`, yielding the number of rows finished so far.

  This is a generator: after each step it yields how many examples finished
  post-processing since the previous yield (possibly 0), so the caller can
  drive a progress bar.

  Args:
    model_id: HuggingFace model id or path.
    dataset: Full dataset; only shard `dataset_shard_info` is processed.
    dataset_shard_info: ``(num_shards, shard_index)`` passed to `Dataset.shard`.
    tensor_output_path: Forwarded to the postprocessing workers.
    score_output_path: Directory the parquet chunks are written into.
    token_ids_column: Column holding the token ids to score.
    dtype, padding_side, set_pad_to_eos_token, model_parallel,
      trust_remote_code, use_flash_attention: Forwarded to
      `_load_model_and_tokenizer`.
    output_hidden_states: Forwarded to the forward passes.
    layers: Hidden-state layers to return, forwarded to the forward passes.
    batch_size: Currently unused (batches come from the adaptive sampler).
    batching_strategy: Currently unused (adaptive batching is always applied).
    temperatures: Extra temperatures forwarded to the postprocessing workers.
    gcs_credentials: If set, exported as ``GOOGLE_APPLICATION_CREDENTIALS``.
    num_concurrent_postprocessors: Maximum number of in-flight postprocessing
      tasks (values below 1 are treated as 1).
    num_postprocessing_workers: Number of `PostprocessingWorker` actors
      (values below 1 are treated as 1).
    max_write_queue_size: ``queue_size`` of every postprocessing worker.
    index_columns: Columns written next to the scores. Defaults to every
      dataset column except `token_ids_column`.
  """
  # Local import: the module does not import itertools at file level.
  import itertools

  if gcs_credentials:
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = gcs_credentials
  dataset = dataset.shard(*dataset_shard_info)
  if len(dataset) == 0:
    # Nothing to do; skip loading the model altogether.
    logger.info("Shard {} is empty. Nothing to score", dataset_shard_info)
    return
  if index_columns is None:
    index_columns = [c for c in dataset.column_names if c != token_ids_column]

  model, tokenizer = _load_model_and_tokenizer(
    model_id,
    dtype=dtype,
    padding_side=padding_side,
    model_parallel=model_parallel,
    trust_remote_code=trust_remote_code,
    set_pad_to_eos_token=set_pad_to_eos_token,
    use_flash_attention=use_flash_attention,
  )

  # 1. Figure out batch sizes based on sequence lengths
  input_lengths_ds = dataset.select_columns([token_ids_column]).map(
    lambda batch: {"lengths": [len(x) for x in batch[token_ids_column]]},
    batched=True,
  )
  sampler = create_adaptive_sampler(
    np.array(input_lengths_ds["lengths"]),
    model=model,
    tokenizer=tokenizer,
    output_hidden_states=output_hidden_states,
    layers=layers,
  )
  collator = DataCollator(
    tokenizer=tokenizer, input_ids_key=token_ids_column, mode="pad"
  )
  dataloader = torch.utils.data.DataLoader(
    dataset,  # type: ignore
    batch_sampler=sampler,
    collate_fn=collator,
    shuffle=False,
    pin_memory=True,
  )

  # 2. Score the dataset
  perf_logger = utils.PerfLogger(logger.info)
  max_in_flight = max(1, num_concurrent_postprocessors)
  postprocessing_workers = [
    PostprocessingWorker.remote(
      queue_size=max_write_queue_size,
      index_columns=tuple(index_columns),
    )
    for _ in range(max(1, num_postprocessing_workers))
  ]
  worker_cycle = itertools.cycle(postprocessing_workers)
  # Store the tokenizer in the object store once instead of re-serializing it
  # with every submitted batch; Ray resolves the reference for the actor.
  tokenizer_ref = ray.put(tokenizer)

  postprocessing_futures = []

  def flush_completed_postprocessing(block: bool = False) -> int:
    """Collect finished postprocessing tasks and return their row count.

    With ``block=False`` everything already finished is collected without
    waiting. With ``block=True`` the call waits until at least one task is done.
    """
    nonlocal postprocessing_futures
    if not postprocessing_futures:
      return 0

    if block:
      done, not_done = ray.wait(postprocessing_futures, num_returns=1)
    else:
      done, not_done = ray.wait(
        postprocessing_futures,
        num_returns=len(postprocessing_futures),
        timeout=0,
      )
    postprocessing_futures = not_done
    return sum(ray.get(done))

  for batch_inputs in dataloader:
    perf_logger.log(
      f"Loaded batch (num_postprocessing_futures={len(postprocessing_futures)})"
    )
    # Try to flush completed postprocessing tasks without blocking
    yield flush_completed_postprocessing()
    perf_logger.log("Completed flushing pending postprocessing")
    batch_outputs = scoring_utils.score_batch2(
      model=model,
      input_ids=batch_inputs["input_ids"].to(model.device),
      attention_mask=batch_inputs["attention_mask"].to(model.device),
      tokenizer=tokenizer,
      output_hidden_states=output_hidden_states,
      hidden_states_to_return=layers,
    )
    perf_logger.log("Forward Pass Complete")
    # If we've reached the concurrency limit, block until at least one task
    # completes (a non-blocking wait here would busy-spin and yield zeros).
    while len(postprocessing_futures) >= max_in_flight:
      yield flush_completed_postprocessing(block=True)

    perf_logger.log("Flushing postprocessing tasks completed")
    postprocessing_futures.append(
      next(worker_cycle).submit.remote(
        # need to cast to numpy to avoid ray serialization latency.
        tree.map_structure(lambda x: x.numpy(), batch_outputs),
        tree.map_structure(lambda x: x.numpy(), batch_inputs),
        score_output_path,
        tensor_output_path,
        tokenizer_ref,
        padding_side=padding_side,
        temperatures=temperatures,
      )
    )
    perf_logger.log("Submitted postprocessing")

  logger.info("Waiting for writers to finish")
  # Drain the remaining postprocessing tasks, blocking on each one.
  while postprocessing_futures:
    yield flush_completed_postprocessing(block=True)
  # Flush the workers' buffers and wait for their writer threads.
  ray.get([worker.shutdown.remote() for worker in postprocessing_workers])
  logger.info("Finished writing")


def _process_input_df(
  dataset: ds.Dataset,
  *,
  tokenizer,
  dedup_keys: tp.Sequence[str],
  input_keys: tp.Sequence[str],
  output_key: str = "token_ids_to_score",
  dedup: bool = False,
):
  """Add the token-id column to score and optionally drop duplicate rows.

  The `input_keys` columns must either all be strings or all be lists. String
  columns are joined with a newline per row and tokenized; list columns are
  concatenated per row. The result is stored in `output_key`.

  Args:
    dataset: Input dataset.
    tokenizer: Tokenizer used for string inputs and for decoding the logged
      sample.
    dedup_keys: Columns whose combined values identify a duplicate row.
    input_keys: Columns combined, in order, into `output_key`.
    output_key: Name of the column that receives the token ids.
    dedup: Whether to drop rows whose `dedup_keys` values were already seen.

  Returns:
    The processed dataset.

  Raises:
    ValueError: If an input key is missing, has an unsupported dtype, or the
      input keys mix string and list columns.
  """
  # check if input keys exist in dataset
  for k in input_keys:
    if k not in dataset.features:
      raise ValueError(
        f"Input key '{k}' not found in dataset. Available keys: {dataset.column_names}"
      )

  logger.info(
    "Tokenizing input columns.\n" "Input columns: {}\n" "Output column: {}",
    input_keys,
    output_key,
  )

  # check if input_keys are strings or sequences of integers
  input_id_cols = set()
  input_text_cols = set()
  for k in input_keys:
    feature_dtype = dataset.features[k].dtype
    if feature_dtype == "string":
      input_text_cols.add(k)
    elif feature_dtype == "list":
      input_id_cols.add(k)
    else:
      raise ValueError(
        f"Input keys must be either `string` or `list` types. Found: {dataset.features[k].dtype}"
      )

  # one of `input_id_cols` or `input_text_cols` must be empty
  if input_id_cols and input_text_cols:
    raise ValueError(
      "Input keys must be either all strings or all sequences of integers. \n"
      f"Found types: {dict((key, dataset.features[key].dtype) for key in input_keys)}"
    )

  if input_text_cols:

    def _tokenize_joined(batch):
      # Batched map: join the text columns row-wise, then tokenize the batch.
      joined = ["\n".join(parts) for parts in zip(*(batch[k] for k in input_keys))]
      return {output_key: tokenizer(joined).input_ids}

    dataset = dataset.map(
      _tokenize_joined,
      batched=True,
      num_proc=min(4, mp.cpu_count()),
    )
  else:
    # Non-batched map: the function receives one row at a time.
    dataset = dataset.map(
      lambda row: {output_key: np.concatenate([row[k] for k in input_keys])},
    )

  if dedup:
    old_len = len(dataset)
    logger.info("Trying to drop duplicates.")

    def _dedup_filter(row, seen: set, dedup_keys: tp.Sequence[str]):
      # Lists/arrays are not hashable; turn them into tuples for the seen-set.
      dedup_values = tuple(
        tuple(x) if isinstance(x, (list, np.ndarray)) else x
        for x in (row[k] for k in dedup_keys)
      )
      if dedup_values in seen:
        return False
      seen.add(dedup_values)
      return True

    dataset = dataset.filter(
      _dedup_filter,
      fn_kwargs=dict(
        seen=set(),
        dedup_keys=dedup_keys,
      ),
    )

    logger.info(
      "Dropped duplicates, old length: {} new length: {}", old_len, len(dataset)
    )

  # log some sample inputs (guarded so an empty dataset does not raise)
  logger.info("Sample inputs:")
  for i in range(min(1, len(dataset))):
    text = tokenizer.decode(dataset[i][output_key])
    logger.info("Input:\n{}", text)

  return dataset


def run_scoring(
  model_id: str,
  input_file: upath.UPath,
  *,
  input_keys: tp.Sequence[str],
  dedup: bool = False,
  output_file: upath.UPath,
  output_hidden_states: bool = False,
  temperatures: tp.Sequence[float] | None = None,
  layers: int | list[int] | None = None,
  batch_size: int = 4,
  concurrency: int = -1,
  model_parallel: int = 1,
  dtype: str = "fp32",
  trust_remote_code: bool = True,
  index_columns: tp.Sequence[str] = ("idx", "dataset_idx", "sample_idx"),
  use_flash_attention: bool = True,
  limit: int | None = None,
  num_cpus: int | None = None,
):
  """Score every row of a parquet file with a causal LM across the local GPUs.

  The input is tokenized (or concatenated), rows already present under
  `output_file` are skipped, and the rest is sharded over `concurrency` Ray
  tasks that write parquet chunks under `output_file`.

  Args:
    model_id: HuggingFace model id or path.
    input_file: Input parquet file. A ``config.yaml`` next to it may provide
      an additional ``temperature``.
    input_keys: Input columns combined into the sequence to score.
    dedup: Drop rows duplicated on ``("dataset_idx", "token_ids_to_score")``.
    output_file: Output location; chunks are written underneath it.
    output_hidden_states: Whether hidden states are produced.
    temperatures: Extra temperatures to compute token logprobs for.
    layers: Hidden-state layers to return.
    batch_size: Forwarded to the workers.
    concurrency: Number of scoring tasks; ``-1`` uses
      ``gpu_count // model_parallel``.
    model_parallel: GPUs per scoring task.
    dtype: Key of `TORCH_DTYPES`.
    trust_remote_code: Forwarded to the HuggingFace loaders.
    index_columns: Columns identifying a row; used to skip finished rows.
    use_flash_attention: Forwarded to the workers.
    limit: If set, only the first `limit` rows are scored.
    num_cpus: CPUs given to Ray; defaults to ``mp.cpu_count()``.

  Raises:
    ValueError: If `model_parallel` or the resolved `concurrency` is below 1,
      or more GPUs are requested than are available.
  """
  # ---------------------------------------------------------------------------- #
  #                              Validate arguments                              #
  # ---------------------------------------------------------------------------- #
  input_file = upath.UPath(input_file)
  output_file = upath.UPath(output_file)
  index_columns = list(index_columns)
  logger.info(input_file)
  logger.info(output_file)
  if dtype == "fp32":
    logger.warning(
      "You are using fp32 for scoring. This may be slower than using fp16 or bf16. Consider using --dtype=fp16 or --dtype=bf16"
    )
  logger.info("Batch Size: {}", batch_size)
  logger.info("Torch Dtype: {}", dtype)

  num_gpus = torch.cuda.device_count()
  if model_parallel < 1:
    raise ValueError(f"model_parallel must be at least 1, got {model_parallel}")
  if concurrency == -1:
    concurrency = num_gpus // model_parallel
    logger.info(
      "Using {} processes ({} gpus per process) for scoring",
      concurrency,
      model_parallel,
    )
  if concurrency < 1:
    # Also covers "no GPU available", which previously surfaced much later as a
    # ZeroDivisionError.
    raise ValueError(
      f"Need at least one scoring process, got concurrency={concurrency} "
      f"({num_gpus} GPUs available, {model_parallel} GPUs per process)"
    )
  if concurrency * model_parallel > num_gpus:
    raise ValueError(
      f"Requested {concurrency * model_parallel} GPUs for scoring, "
      f"but only {num_gpus} are available"
    )

  # ---------------------------------------------------------------------------- #
  #             Check if there is a config that contains temperatures            #
  # ---------------------------------------------------------------------------- #
  if temperatures is None:
    logger.info("No temperatures provided. Checking input directory for config file")
    temperatures = []
  else:
    temperatures = list(temperatures)
  config_file = input_file.parent / "config.yaml"
  if config_file.exists():
    logger.info("Found config file: {}", config_file.as_posix())
    config = utils.load_yaml(config_file)
    if "temperature" in config and config["temperature"] not in temperatures:
      logger.info("Using temperatures from config file: {}", config["temperature"])
      temperatures.append(config["temperature"])

  logger.info("Final temperatures: {}", temperatures)
  # ---------------------------------------------------------------------------- #
  #                                Load input file                               #
  # ---------------------------------------------------------------------------- #
  input_ds = tp.cast(
    ds.Dataset,
    ds.load_dataset(
      "parquet",
      data_files={
        "train": input_file.as_posix(),
      },
    )["train"],  # type: ignore
  )

  # ---------------------------------------------------------------------------- #
  #                              Process input file                              #
  # ---------------------------------------------------------------------------- #
  logger.info("Processing input file")
  tokenizer = _load_tokenizer(
    model_id,
    trust_remote_code=trust_remote_code,
    set_pad_to_eos_token=True,
    padding_side="left",
  )
  input_ds = _process_input_df(
    input_ds,
    tokenizer=tokenizer,
    input_keys=input_keys,
    dedup=dedup,
    dedup_keys=("dataset_idx", "token_ids_to_score"),
  )
  if limit is not None:
    input_ds = input_ds.take(limit)
  output_folder = output_file.parent
  hidden_state_output_folder = output_folder / output_file.stem
  # `parents=True` so a not-yet-existing output folder does not fail here.
  hidden_state_output_folder.mkdir(parents=True, exist_ok=True)

  # ---------------------------------------------------------------------------- #
  #                                  Check cache                                 #
  # ---------------------------------------------------------------------------- #
  import pyarrow.dataset as pa_ds

  if output_file.exists():
    logger.info("Output file already exists. Checking for completed rows")
    chunk_ds = pa_ds.dataset(
      output_file.as_posix(),
    )
    logger.info("# Completed rows: {}", chunk_ds.count_rows())
    chunk_indices = chunk_ds.to_table(columns=index_columns).to_pydict()
    completed_indices = set(zip(*(chunk_indices[col] for col in index_columns)))
    len_before = len(input_ds)
    input_ds = input_ds.filter(
      lambda row: tuple(row[col] for col in index_columns) not in completed_indices
    )
    len_after = len(input_ds)
    logger.info(
      "Dropped {} rows that were already scored. Remaining rows: {}",
      len_before - len_after,
      len_after,
    )

  input_ds = input_ds.select_columns([*index_columns, "token_ids_to_score"])

  if len(input_ds) == 0:
    logger.info("No rows left to score. Nothing to do.")
    return
  if concurrency > len(input_ds):
    # Avoid handing empty shards to scoring processes.
    logger.info(
      "Only {} rows to score. Reducing concurrency from {} to {}",
      len(input_ds),
      concurrency,
      len(input_ds),
    )
    concurrency = len(input_ds)

  # ---------------------------------------------------------------------------- #
  #                                 Start workers                                #
  # ---------------------------------------------------------------------------- #
  if num_cpus is None:
    num_cpus = mp.cpu_count()
  logger.info(f"Initializing ray with {num_cpus} cpus and {num_gpus} gpus")
  ray.init(num_cpus=num_cpus, num_gpus=num_gpus)
  try:
    # At least one CPU per worker: 0 would give the workers a postprocessing
    # concurrency limit of 0.
    num_cpus_per_scoring_worker = max(1, num_cpus // num_gpus)
    num_postprocessing_workers = max(1, num_cpus_per_scoring_worker - 1)
    max_write_queue_size = 4
    _worker = ray.remote(num_gpus=model_parallel, num_cpus=num_cpus_per_scoring_worker)(
      _scoring_worker
    )
    shard_streams = [
      _worker.remote(
        model_id,
        input_ds,
        dataset_shard_info=(concurrency, shard_idx),
        token_ids_column="token_ids_to_score",
        score_output_path=output_file.as_posix(),
        tensor_output_path=hidden_state_output_folder.as_posix(),
        dtype=TORCH_DTYPES[dtype],
        padding_side="left",
        output_hidden_states=output_hidden_states,
        layers=layers,
        batch_size=batch_size,
        model_parallel=model_parallel,
        trust_remote_code=trust_remote_code,
        batching_strategy="adaptive",
        use_flash_attention=use_flash_attention,
        temperatures=temperatures,
        gcs_credentials=os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"),
        num_concurrent_postprocessors=num_cpus_per_scoring_worker * 2,
        num_postprocessing_workers=num_postprocessing_workers,
        max_write_queue_size=max_write_queue_size,
      )
      for shard_idx in range(concurrency)
    ]
    logger.info(
      f"num_cpus_per_scoring_worker: {num_cpus_per_scoring_worker}\n"
      f"num_postprocessing_workers: {num_postprocessing_workers}\n"
      f"max_write_queue_size: {max_write_queue_size}"
    )

    # Each worker streams the number of rows it finished; use it for progress.
    output_stream = utils.ray_as_completed(shard_streams)
    logger.info("Starting inference loop")
    with tqdm(total=len(input_ds), desc="Scoring") as bar:
      for n in output_stream:
        bar.update(n)

    logger.info("Finished scoring all rows.")
    logger.info("Output saved to: {}", output_file.as_posix())
  finally:
    # Always release the Ray runtime, also when scoring fails.
    ray.shutdown()


def _parse_float_list(s):
  # Remove any whitespace and split the string by commas
  elements = s.strip().split(",")

  result = []
  for i, e in enumerate(elements):
    e = e.strip()
    if not e:
      continue  # Skip empty elements
    try:
      result.append(float(e))
    except ValueError:
      raise ValueError(f"Invalid float value '{e}' at position {i}")

  if not result:
    raise ValueError("No valid float values found in the input string")

  return result


# Command line entry point: parse the arguments and run the scoring job.
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--model", type=str, help="HuggingFace model id", required=True)
  parser.add_argument("--inputs", type=str, help="Input parquet file", required=True)
  parser.add_argument("--outputs", type=str, help="Output parquet file", required=True)
  parser.add_argument(
    "--hidden-states",
    action="store_true",
    help="whether or not to output hidden states",
  )
  parser.add_argument(
    "--temperatures",
    type=str,
    help="Comma separated list of temperatures to use for calculating logprobs (e.g. --temperatures 0.2,0.6)",
  )
  parser.add_argument(
    "--layers",
    help=(
      "Comma separated list of layers to output (e.g. '-1,-2', '0')."
      " If not specified, all hidden states are output."
    ),
    default=None,
  )
  parser.add_argument(
    "--batch-size",
    type=int,
    help="Batch size for scoring",
    default=4,
  )
  parser.add_argument(
    "--concurrency",
    type=int,
    help="Number of concurrent actors to use. If set to -1, set to number of available GPUs",
    default=-1,
  )
  parser.add_argument(
    "--model-parallel",
    type=int,
    help="Number of model parallel GPUs to use per concurrent process",
    default=1,
  )
  # Input arguments
  parser.add_argument(
    "--input-keys",
    help="Comma separated list of keys in input data to be scored. Each key must point to a string or token ids. Consecutive keys are concatenated",
    default="input_ids,token_ids",
  )

  parser.add_argument(
    "--dedup",
    action="store_true",
  )
  parser.add_argument(
    "--dtype",
    type=str,
    choices=list(TORCH_DTYPES),
    default="fp16",
    help="Torch dtype to use for scoring",
  )
  # NOTE(review): parsed but not used anywhere in this script.
  parser.add_argument(
    "--score-inputs-only", action="store_true", help="Score only the input-ids "
  )
  parser.add_argument(
    "--no-flash-attention",
    action="store_true",
    help="Disable flash attention",
  )
  parser.add_argument(
    "--index-keys",
    type=str,
    help="Comma separated list of keys to use as index columns",
    default="idx,dataset_idx,sample_idx",
  )
  parser.add_argument(
    "--num-cpus",
    type=int,
    help="Number of cpus to use for scoring",
  )
  parser.add_argument("--limit", type=int, help="Limit the number of rows to score")
  args = parser.parse_args()
  layers = None
  if args.layers is not None:
    layers = [int(x) for x in args.layers.split(",")]

  contains_gcs_paths = args.inputs.startswith("gs://") or args.outputs.startswith(
    "gs://"
  )
  # Credentials are only required when a GCS path is actually involved;
  # purely local runs must not be rejected.
  if contains_gcs_paths and not utils.check_gcs_credentials():
    logger.error(
      "GCS credentials not found but GCS paths are provided:\n"
      f"{args.inputs}\n{args.outputs}\n"
      "Please set GOOGLE_APPLICATION_CREDENTIALS environment variable to a valid GCS credentials file."
    )
    sys.exit(1)

  run_scoring(
    model_id=args.model,
    input_file=args.inputs,
    output_file=args.outputs,
    input_keys=args.input_keys.split(","),
    index_columns=args.index_keys.split(","),
    output_hidden_states=args.hidden_states,
    temperatures=_parse_float_list(args.temperatures) if args.temperatures else None,
    layers=layers,
    batch_size=args.batch_size,
    concurrency=args.concurrency,
    dtype=args.dtype,
    model_parallel=args.model_parallel,
    dedup=args.dedup,
    use_flash_attention=not args.no_flash_attention,
    limit=args.limit,
    num_cpus=args.num_cpus,
  )
