import argparse
import collections
import concurrent.futures
import dataclasses
import functools
import gc
import inspect
import logging
import math
import multiprocessing as mp
import pathlib
import time
import typing as tp
from datetime import datetime

import datasets as ds
import flatten_dict
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import pydantic
import ray
import torch
import upath
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

logger = logging.getLogger(__name__)

P = tp.ParamSpec("P")
T = tp.TypeVar("T")
S = tp.TypeVar("S")


class with_dict_inputs(tp.Generic[P, T]):
  """Takes a function that accepts keyword arguments and returns a function that accepts a arguments as a dictionary.

  Example:
  >>> def fn(*, a, b, c):
  >>>     return a + b + c
  >>> fn_with_dict_inputs = with_dict_inputs(fn)
  >>> fn_with_dict_inputs(dict(a=1, b=2, c=3))
  6
  """

  def __init__(self, fn: tp.Callable[P, T], *, strict: bool = False):
    self._signature = inspect.signature(fn)
    # fn must have all keyword-only parameters (including kwargs)
    if not all(
      parameter.kind in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD)
      for parameter in self._signature.parameters.values()
    ):
      fn_summary = inspect.getsource(fn).splitlines()[:10]
      raise ValueError(
        f"Prompt function '{fn.__name__}' must have all keyword-only parameters "
        "(including kwargs), "
        f"but accepts parameters: {str(self._signature)}\n\n"
        f"Function source:\n\n" + "\n".join(fn_summary)
      )
    # use pydantic to validate the function call
    if strict:
      self._fn = pydantic.validate_call()(fn)
    else:
      self._fn = fn
    self._has_kwargs = any(
      parameter.kind == inspect.Parameter.VAR_KEYWORD
      for parameter in self._signature.parameters.values()
    )
    self.input_keys = tuple(
      name
      for name, parameter in self._signature.parameters.items()
      if parameter.kind
      not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
      and parameter.default is inspect.Parameter.empty
    )
    self.optional_keys = tuple(
      name for name in self._signature.parameters if name not in self.input_keys
    )
    self.keys = self.input_keys + self.optional_keys
    functools.update_wrapper(self, fn)

  def __call__(self, d: tp.Mapping) -> T:
    filtered_kwargs = {k: v for k, v in d.items() if k in self._signature.parameters}
    kwargs = (
      {k: v for k, v in d.items() if k not in self._signature.parameters}
      if self._has_kwargs
      else {}
    )
    bound_args = self._signature.bind_partial(**filtered_kwargs, **kwargs)
    return self._fn(*bound_args.args, **bound_args.kwargs)

  def apply(self, *args: P.args, **kwargs: P.kwargs) -> T:
    return self._fn(*args, **kwargs)

  def __repr__(self):
    return f"{self._fn.__name__}{str(self._signature)}"


class _dictpreprocessor(with_dict_inputs[P, T]):
  """Variant of `with_dict_inputs` that stores the result under `output_key`.

  Calling the instance with a mapping returns a new dict containing the
  function result under `output_key`. The entries of the input mapping are
  carried over unless `remove_other_keys` is set.
  """

  def __init__(
    self,
    fn: tp.Callable[P, T],
    output_key: str,
    *,
    strict: bool = False,
    remove_other_keys=False,
  ):
    super().__init__(fn, strict=strict)
    self.remove_other_keys = remove_other_keys
    self.output_key = output_key
    # Re-wrap around our own __call__ (this overrides the metadata copied from
    # `fn` by the parent constructor).
    functools.update_wrapper(self, self.__call__)

  def __call__(self, d: tp.Mapping) -> tp.Mapping:
    value = super().__call__(d)
    return (
      {self.output_key: value}
      if self.remove_other_keys
      else {**d, self.output_key: value}
    )


@tp.overload
def dictify(
  *, output_key: str, strict: bool = False, remove_other_keys: bool = False
) -> tp.Callable[[tp.Callable[P, T]], with_dict_inputs[P, T]]: ...


@tp.overload
def dictify(
  fn: tp.Callable[P, T],
  *,
  output_key: str,
  strict: bool = False,
  remove_other_keys: bool = False,
) -> with_dict_inputs[P, T]: ...


def dictify(
  fn=None, *, output_key: str, strict: bool = False, remove_other_keys: bool = False
) -> tp.Union[
  tp.Callable[[tp.Callable[P, T]], with_dict_inputs[P, T]], with_dict_inputs[P, T]
]:
  """Decorator to convert a function that accepts keyword arguments to a function that accepts a dictionary and returns a dictionary.

  Can be used directly (``dictify(fn, output_key=...)``) or as a decorator
  factory (``@dictify(output_key=...)``).

  Args:
      fn: keyword-only function to wrap; omit to get a decorator.
      output_key: key under which the function result is stored.
      strict: validate calls to `fn` with pydantic.
      remove_other_keys: return only ``{output_key: result}`` instead of the
          input entries plus the result.
  """
  if fn is not None:
    return _dictpreprocessor(
      fn, output_key, strict=strict, remove_other_keys=remove_other_keys
    )
  # Forward every option so the decorator-factory form behaves exactly like
  # the direct form.
  return functools.partial(
    _dictpreprocessor,
    output_key=output_key,
    strict=strict,
    remove_other_keys=remove_other_keys,
  )


def repeat_dataset(
  dataset: ds.Dataset,
  *,
  num_samples: int,
  dataset_idx="dataset_idx",
  sample_idx: str = "sample_idx",
  idx: str = "idx",
):
  """Repeats every row `num_samples` times and adds three index columns.

  The added columns are:
  - `idx`: position of the row in the repeated dataset.
  - `dataset_idx`: position of the source row in the original dataset.
  - `sample_idx`: repetition number of that source row (0 .. num_samples - 1).

  Example:
  >>> dataset = ds.Dataset.from_dict({"prompt": ["a", "b", "c"]})
  >>> repeated_dataset = repeat_dataset(dataset, num_samples=3)
  >>> repeated_dataset
  Dataset({
      'prompt': ['a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c'],
      'idx': [0, 1, 2, 3, 4, 5, 6, 7, 8]
      'dataset_idx': [0, 0, 0, 1, 1, 1, 2, 2, 2],
      'sample_idx': [0, 1, 2, 0, 1, 2, 0, 1, 2],
  })

  Args:
      dataset: dataset to repeat
      num_samples: number of times to repeat
      dataset_idx: column name for the original row index. Defaults to "dataset_idx".
      sample_idx: column name for the repetition index. Defaults to "sample_idx".
      idx: column name for the row index in the result. Defaults to "idx".
  """

  def _enumerate_dataset(batch: dict, indices: tp.Sequence[int]):
    # Rows are laid out as num_samples consecutive copies per source row, so
    # the quotient is the source row and the remainder the repetition number.
    return {
      **batch,
      idx: indices,
      dataset_idx: [position // num_samples for position in indices],
      sample_idx: [position % num_samples for position in indices],
    }

  repeated_rows = []
  for row in range(dataset.num_rows):
    repeated_rows.extend(row for _ in range(num_samples))

  repeated = dataset.select(repeated_rows)
  return repeated.map(
    _enumerate_dataset,
    with_indices=True,
    batched=True,
    desc="Adding 'dataset_idx' and 'sample_idx'",
  )


def sort_ds_by_length(dataset: ds.Dataset, *, column: str):
  """Returns `dataset` reordered by ascending ``len()`` of the values in `column`."""
  order = np.argsort([len(value) for value in dataset[column]])
  return dataset.select(order)


# ---------------------------------------------------------------------------- #
#            dataset preprocessors, i.e. f: ds.Dataset -> ds.Dataset           #
# ---------------------------------------------------------------------------- #
def apply_prompt_to_dataset(
  dataset: ds.Dataset, *, prompt: tp.Callable[..., str], output_key: str = "prompt"
):
  """Maps `prompt` over every row of `dataset`, storing its result in `output_key`.

  `prompt` is wrapped with `dictify`, so it receives the row's columns as
  keyword arguments.
  """
  return dataset.map(
    dictify(fn=prompt, output_key=output_key),
    desc="Applying prompt",
  )


def tokenize_dataset(
  batch: dict[str, tp.Sequence[str] | np.ndarray],
  *,
  tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
  columns: str | tp.Sequence[str],
  max_length: int = 1024,
  padding: str | bool = "max_length",
  truncation: bool | str = True,
  return_tensors: str | None = "np",
  input_ids_only: bool = True,
):
  """Tokenizes the given text columns of a batch in place.

  Args:
      batch: mapping from column name to a sequence (or numpy array) of strings.
      tokenizer: tokenizer called on each selected column.
      columns: a single column name or a sequence of column names to tokenize.
      max_length: forwarded to the tokenizer.
      padding: forwarded to the tokenizer.
      truncation: forwarded to the tokenizer.
      return_tensors: forwarded to the tokenizer.
      input_ids_only: when True, keep only the ``"input_ids"`` entry of the
          tokenizer output; otherwise store the full tokenizer output.

  Returns:
      The same `batch` object, with the selected columns replaced.
  """
  # A bare string must be treated as one column name, not iterated
  # character by character.
  if isinstance(columns, str):
    columns = [columns]

  for col in columns:
    texts = batch[col]
    if isinstance(texts, np.ndarray):
      texts = texts.tolist()
    encoded = tokenizer(
      texts,
      padding=padding,
      max_length=max_length,
      truncation=truncation,
      return_tensors=return_tensors,
    )
    batch[col] = encoded["input_ids"] if input_ids_only else encoded
  return batch


class ParquetWriter:
  """Incrementally writes batches of rows or columns to one parquet file.

  The underlying ``pq.ParquetWriter`` is created lazily on the first write,
  using the schema of the first table written. Usable as a context manager;
  leaving the context closes the file.

  NOTE(review): the `schema` constructor argument is stored but is not
  consulted when the file is created — confirm whether it should be.

  Args:
      filename: destination passed to ``pq.ParquetWriter``.
      schema: stored on the instance (see note above).
      flatten: flatten nested dicts into dotted keys before building tables
          in `write`.
  """

  def __init__(self, filename, schema=None, flatten: bool = True):
    self.filename = filename
    self.schema = schema
    self._actual_schema = None
    self.pqwriter = None
    self.is_closed = False
    self.flatten = flatten

  def is_open(self):
    """True while an underlying writer exists and has not been closed."""
    return self.pqwriter is not None and not self.is_closed

  def open(self):
    """No-op; the file is opened lazily by the first write."""
    pass

  def __enter__(self):
    return self

  def __exit__(self, *args):
    self.close()

  def close(self):
    """Closes the writer. Further writes raise ``ValueError``."""
    # Mark as closed even when nothing was ever written; otherwise a later
    # write would lazily open a file that nobody closes.
    self.is_closed = True
    if self.pqwriter is not None:
      self.pqwriter.close()

  def write(self, item: list[dict[str, tp.Any]] | dict[str, list[tp.Any]]):
    """Writes a list of row dicts or a dict of columns.

    Raises:
        ValueError: if the writer has already been closed.
    """
    if self.is_closed:
      raise ValueError("The parquet file has already been closed, cannot write to it.")

    if isinstance(item, list):
      # row-oriented input: one dict per row
      if self.flatten:
        item = [flatten_dict.flatten(x, reducer="dot") for x in item]
      table = pa.Table.from_pylist(item)
    else:
      # column-oriented input: one sequence per column
      if self.flatten:
        item = flatten_dict.flatten(item, reducer="dot")
      table = pa.Table.from_pydict(item)

    self.write_table(table)

  def write_table(self, table: pa.Table | pa.RecordBatch):
    """Writes an Arrow table or record batch, creating the file on first use.

    Raises:
        ValueError: if the writer has already been closed.
    """
    if self.is_closed:
      raise ValueError("The parquet file has already been closed, cannot write to it.")

    if self.pqwriter is None:
      # take care to unify schema of new table
      # with existing schema of the parquet file
      if self._actual_schema is None:
        self._actual_schema = table.schema
      if self._actual_schema != table.schema:
        self._actual_schema = pa.unify_schemas([self._actual_schema, table.schema])
      self.pqwriter = pq.ParquetWriter(self.filename, self._actual_schema)

    self.pqwriter.write(table)

  @classmethod
  def merge(cls, filenames, output_file, sort_by=None):
    """Concatenates the parquet files in `filenames` into `output_file`.

    Args:
        filenames: source files, opened together as one Arrow dataset.
        output_file: destination file.
        sort_by: optional sort specification passed to the dataset's `sort_by`.
    """
    merged_dataset = pa_ds.dataset(filenames)
    if sort_by is not None:
      merged_dataset = merged_dataset.sort_by(sort_by)
    with cls(output_file) as writer:
      for batch in merged_dataset.to_batches():
        writer.write_table(batch)


def load_yaml(filename: str | pathlib.Path) -> dict:
  """Reads the YAML file at `filename` and returns the parsed content.

  The C-accelerated loader is used when PyYAML exposes it, otherwise the
  pure-Python one.

  NOTE(review): this uses the full (non-safe) YAML loader, which can construct
  arbitrary Python objects — only use it on trusted files.
  """
  import yaml

  loader_cls = getattr(yaml, "CLoader", None)
  if loader_cls is None:
    loader_cls = yaml.Loader

  path = _canonicalize_path(filename)
  with path.open("r") as stream:
    return yaml.load(stream, Loader=loader_cls)


def dump_yaml(data: dict):
  """Serializes `data` to a YAML string.

  The C-accelerated dumper is used when PyYAML exposes it, otherwise the
  pure-Python one.
  """
  import yaml

  dumper_cls = getattr(yaml, "CDumper", None)
  if dumper_cls is None:
    dumper_cls = yaml.Dumper

  return yaml.dump(data, Dumper=dumper_cls)


def _canonicalize_path(path: str | upath.UPath | pathlib.Path) -> upath.UPath:
  """Converts a string or path-like object into a `upath.UPath`."""
  # Strings and path objects are handled identically by the UPath constructor.
  return upath.UPath(path)


def write_yaml(data: dict, filename: str | upath.UPath | pathlib.Path):
  """Writes a dictionary as YAML to a yaml file.

  Args:
      data: dictionary to serialize.
      filename: destination path; an existing file is overwritten.
  """
  # Serialize before opening the file: opening in "w" mode truncates it, so a
  # failing dump would otherwise destroy the previous content.
  text = dump_yaml(data)
  filename = _canonicalize_path(filename)
  with filename.open("w") as f:
    f.write(text)


class ConfigDict(pydantic.BaseModel):
  """Pydantic model with helpers for argparse, YAML and merging of configs."""

  @classmethod
  def names_to_exclude(cls):
    """Field names that `add_arguments` must not expose on the command line."""
    return []

  @classmethod
  def add_arguments(cls, parser: argparse.ArgumentParser):
    """Adds one ``--field-name`` option per model field to `parser`.

    Underscores in field names become dashes. The field's default, description
    and examples are used as the option's default, help text and choices.

    Returns:
        The same `parser`, for chaining.
    """
    for name, field in cls.model_fields.items():
      if name not in cls.names_to_exclude():
        flag = f"--{name.replace('_', '-')}"
        parser.add_argument(
          flag,
          default=field.default,
          help=field.description,
          choices=field.examples,
        )
    return parser

  @classmethod
  def from_args(cls, args):
    """Builds a config from a parsed argparse namespace."""
    return cls(**vars(args))

  @classmethod
  def merge(cls, *configs: tp.Self):
    """Merges several configs into a new one; later configs take precedence.

    A value is only taken from a config when it is not None, belongs to a
    field of this class, and differs from that field's default.

    Args:
        configs: configs to merge, in increasing order of priority.

    Returns:
        merged config
    """
    fields = cls.model_fields
    merged = {}
    for config in configs:
      for key, value in config.model_dump().items():
        # Skip unset values and keys that are not fields of this class.
        if value is None or key not in fields:
          continue
        # Values equal to the default never override an earlier config.
        if fields[key].default != value:
          merged[key] = value

    return cls(**merged)

  @classmethod
  def from_yaml(cls, file: str | upath.UPath):
    """Builds a config from the content of a YAML file."""
    return cls(**load_yaml(file))

  def dumps(self):
    """Returns the config serialized as a YAML string."""
    return dump_yaml(self.model_dump())

  def dump(self, file: str | upath.UPath):
    """Writes the config as YAML to `file`."""
    write_yaml(self.model_dump(), file)


def group_predictions_by_indices(
  indices: np.ndarray,
  predictions: np.ndarray,
) -> tuple[np.ndarray, list[list]]:
  """
  Groups predictions based on indices.

  Args:
      indices (np.ndarray): Array of group indices, one per prediction.
      predictions (np.ndarray): Array of predictions.

  Returns:
      tuple[np.ndarray, list[list]]: A tuple containing the unique sorted indices and a list of predictions grouped by indices.
      Within a group, predictions keep the order they had in `predictions`.

  Example:
      >>> predictions = np.array([1, 2, 3, 4, 5])
      >>> indices = np.array([0, 1, 0, 1, 2])
      >>> group_predictions_by_indices(indices, predictions)
      (array([0, 1, 2]), [[1, 3], [2, 4], [5]])
  """
  indices = np.asarray(indices)
  predictions = np.asarray(predictions)

  # A stable sort is required so that predictions sharing an index keep their
  # original relative order.
  index_argsort = np.argsort(indices, kind="stable")
  sorted_indices = indices[index_argsort]
  sorted_predictions = predictions[index_argsort]
  predictions_by_idx = collections.defaultdict(list)
  for idx, pred in zip(sorted_indices, sorted_predictions):
    predictions_by_idx[idx].append(pred)
  return np.unique(sorted_indices), list(predictions_by_idx.values())


@tp.overload
def timed(fn: tp.Callable[P, T]) -> tp.Callable[P, T]: ...


@tp.overload
def timed(
  fn: None = None, *, log_fn: tp.Callable[[str], None] = logger.info
) -> tp.Callable[[tp.Callable[P, T]], tp.Callable[P, T]]: ...


def timed(
  fn: tp.Callable[P, T] | None = None,
  *,
  log_fn: tp.Callable[[str], None] = logger.info,
) -> tp.Callable[P, T] | tp.Callable[[tp.Callable[P, T]], tp.Callable[P, T]]:
  """Decorator to time a function and log the execution time.

  Usable both as ``@timed`` and as ``@timed(log_fn=...)``. Nothing is logged
  when the wrapped function raises.

  Args:
      fn: function to wrap; omit to get a decorator.
      log_fn: callable receiving the formatted timing message.
  """

  def decorator(func: tp.Callable[P, T]):
    # Not every callable has a __name__ (e.g. functools.partial objects).
    func_name = getattr(func, "__name__", repr(func))

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs):
      # perf_counter is monotonic, so the measured duration is not affected by
      # system clock adjustments.
      start_time = time.perf_counter()
      result = func(*args, **kwargs)
      elapsed = time.perf_counter() - start_time
      log_fn(f"Function '{func_name}' executed in {elapsed:.4f} seconds.")
      return result

    return wrapper

  if fn is not None:
    return decorator(fn)

  return decorator


def describe_callable(callable_obj, max_arg_length=50):
  """Returns a short, human-readable description of a callable.

  - ``functools.partial``: function name with bound arguments and defaults.
  - plain function: name with its parameters (and defaults, if any).
  - class: name with the parameters of ``__init__``; dataclass types get a
    fixed explanatory message instead.
  - dataclass instance: its ``repr``.
  - anything else: a generic message naming the object's type.

  Argument values are shown via ``repr`` and truncated to `max_arg_length`
  characters.
  """

  def truncate_arg_value(value, max_length=50):
    """repr() of `value`, cut to `max_length` characters with a trailing ellipsis."""
    text = repr(value)
    return text if len(text) <= max_length else text[: max_length - 3] + "..."

  def render_parameters(parameters):
    """Formats signature parameters as ``name`` or ``name=default``."""
    rendered = []
    for param in parameters:
      if param.default is param.empty:
        rendered.append(param.name)
      else:
        shown_default = truncate_arg_value(param.default, max_arg_length)
        rendered.append(f"{param.name}={shown_default}")
    return ", ".join(rendered)

  if isinstance(callable_obj, functools.partial):
    target = callable_obj.func
    bound = inspect.signature(target).bind_partial(
      *callable_obj.args, **(callable_obj.keywords or {})
    )
    bound.apply_defaults()
    rendered_args = ", ".join(
      f"{name}={truncate_arg_value(value, max_arg_length)}"
      for name, value in bound.arguments.items()
    )
    return f"{target.__name__}({rendered_args})"

  if inspect.isfunction(callable_obj):
    parameters = inspect.signature(callable_obj).parameters.values()
    return f"{callable_obj.__name__}({render_parameters(parameters)})"

  if inspect.isclass(callable_obj):
    if dataclasses.is_dataclass(callable_obj):
      return "Cannot directly describe a dataclass type without an instance."
    ctor_parameters = [
      param
      for param in inspect.signature(callable_obj.__init__).parameters.values()
      if param.name != "self"
    ]
    return f"{callable_obj.__name__}({render_parameters(ctor_parameters)})"

  if dataclasses.is_dataclass(type(callable_obj)):
    return repr(callable_obj)

  return f"Callable of type {type(callable_obj).__name__} without detailed description."


P = tp.ParamSpec("P")
T = tp.TypeVar("T")


class FailFastThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
  """Wrapper for ThreadPoolExecutor that crashes main thread on exceptions.

  NOTE: this class should be used only from the main thread.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._incomplete_futures: list[concurrent.futures.Future] = []

  def check_for_exceptions(self, wait: bool = False):
    """Raises any exceptions from complete futures on the main thread."""
    still_incomplete_futures = []
    for future in self._incomplete_futures:
      try:
        exception = future.exception(timeout=0 if wait else None)
      except concurrent.futures.TimeoutError:
        still_incomplete_futures.append(future)
      if exception is not None:
        raise exception

    self._incomplete_futures = still_incomplete_futures

  def submit(
    self, fn: tp.Callable[P, T], *args: P.args, **kwargs: P.kwargs
  ) -> concurrent.futures.Future:
    """Submit function to threadpool, capturing the returned future."""
    future = super().submit(fn, *args, **kwargs)
    self._incomplete_futures.append(future)
    self.check_for_exceptions(wait=False)
    return future

  def shutdown(self, *args, wait: bool = False, **kwargs):
    self.check_for_exceptions(wait=wait)
    super().shutdown(*args, **kwargs)


def write_parquet(batch: dict[str, tp.Sequence], filename: str):
  """Writes a column-oriented `batch` to `filename` as a single parquet file."""
  pq.write_table(pa.Table.from_pydict(batch), filename)


@dataclasses.dataclass
class ChunkState:
  """Snapshot of what has already been written for a chunked output."""

  # index values already present on disk
  cached_indices: set
  # number to use for the next chunk file
  chunk_idx: int


@dataclasses.dataclass
class ChunkManager:
  """Manages numbered parquet chunk files that are later merged into `path`.

  Chunks live in a sibling directory named ``<path>.CHUNKS`` (created on
  construction) and are called ``chunk-NNNNNN.parquet``.
  """

  # Path to final output file
  path: upath.UPath
  # Columns to use as an index. For any row, the tuple
  # (row[col] for col in index_cols) should be unique.
  index_cols: tp.Sequence[str] = (
    "idx",
    "dataset_idx",
    "sample_idx",
  )
  # Directory holding the chunk files; derived from `path`.
  chunk_path: upath.UPath = dataclasses.field(init=False)

  def __post_init__(
    self,
  ):
    self.index_cols = list(self.index_cols)
    chunk_dir_suffix = self.path.suffix + ".CHUNKS"
    self.chunk_path = self.path.with_suffix(chunk_dir_suffix)
    self.chunk_path.mkdir(exist_ok=True, parents=True)

  def get_chunk_state(self):
    """Returns the already-written index tuples and the next chunk number.

    Indices are read from the final output file when it exists, otherwise from
    the chunk files. The next chunk number is one past the highest existing
    chunk number.
    """
    chunk_files = self.get_chunks()

    if self.path.exists():
      source = self.path.as_posix()
    elif chunk_files:
      source = [c.as_posix() for c in chunk_files]
    else:
      source = None

    cached_chunks = None
    if source is not None:
      cached_chunks = pa_ds.dataset(source, filesystem=self.path.fs)

    if not cached_chunks:
      return ChunkState(cached_indices=set(), chunk_idx=0)

    index_rows = cached_chunks.to_table(columns=self.index_cols).to_pylist()
    seen_indices = {tuple(row[col] for col in self.index_cols) for row in index_rows}
    # File stems look like "chunk-000012"; the part after the dash is the number.
    highest = max((int(p.stem.split("-")[1]) for p in chunk_files), default=-1)
    return ChunkState(cached_indices=seen_indices, chunk_idx=highest + 1)

  def get_chunks(self) -> tp.Sequence[pathlib.Path]:
    """Returns the chunk files currently present in the chunk directory."""
    return tuple(self.chunk_path.glob("chunk-*.parquet"))

  def get_chunk_path(self, chunk_idx: int):
    """Returns the path of the chunk file with number `chunk_idx`."""
    return self.chunk_path / f"chunk-{chunk_idx:06}.parquet"

  def merge_chunks(self, sort_by_index=True, ascending=True, remove_chunks=True):
    """Writes all chunk files into the final output file.

    Does nothing when there are no chunk files.

    Args:
        sort_by_index: sort rows by `index_cols` before writing.
        ascending: sort direction used when `sort_by_index` is set.
        remove_chunks: delete the chunk files (and the chunk directory, if it
            still exists) afterwards.
    """
    chunk_files = self.get_chunks()
    if not chunk_files:
      return

    merged = pa_ds.dataset(
      [c.as_posix() for c in chunk_files], filesystem=self.path.fs
    )
    if sort_by_index:
      direction = "ascending" if ascending else "descending"
      merged = merged.sort_by([(col, direction) for col in self.index_cols])

    pq.write_table(merged.to_table(), self.path.as_posix(), filesystem=self.path.fs)

    if not remove_chunks:
      return

    for chunk_file in chunk_files:
      chunk_file.unlink()
    # For cloud storage, the directory will be deleted after the last file is removed
    # so we need to check if the directory still exists before trying to remove it
    if self.chunk_path.exists():
      self.chunk_path.rmdir()


def get_maximum_batch_size_for_seq_len(
  fn: tp.Callable[[torch.Tensor], tp.Any],
  device: torch.device,
  initial_batch_size=64,
  token_id=200,
  seq_len=512,
):
  """Computes the maximum batch size for `fn` that can be used for a given sequence length.

  Creates tensors of shape `(batch_size, seq_len)` filled with `token_id` and calls `fn` with the tensor.
  If an out-of-memory error occurs, the batch size is halved and the process is repeated until a valid batch size is found.
  If even a batch size of 1 runs out of memory, a `ValueError` is raised.

  Args:
      fn: The function to call with the input tensor.
      device: The device to use for the input tensor.
      initial_batch_size: starting batch size. Defaults to 64.
      token_id: the token_id used to fill the input tensor. Defaults to 200.
      seq_len: the sequence length. Defaults to 512.

  Returns:
      int: The maximum batch size that can be used for the given sequence length.

  Raises:
      ValueError: if the smallest batch size still runs out of memory.
      RuntimeError: any non-OOM error raised while calling `fn`.
  """
  current_batch_size = initial_batch_size
  min_batch_size = 1

  while True:
    try:
      input_ids = torch.full(
        (current_batch_size, seq_len),
        token_id,
        dtype=torch.long,
        device=device,
      )
      fn(input_ids)
      return current_batch_size
    except RuntimeError as e:
      is_oom = isinstance(e, torch.cuda.OutOfMemoryError) or "CUDA out of memory" in str(e)
      if not is_oom:
        raise  # Re-raise the exception if it's not an OOM error

      # Halving can never go below the minimum, so failing at the minimum
      # means no batch size fits. Stop here instead of retrying forever.
      if current_batch_size <= min_batch_size:
        raise ValueError("Batch size below minimum. Cannot proceed.") from e

      torch.cuda.empty_cache()
      gc.collect()

      current_batch_size = max(current_batch_size // 2, min_batch_size)
      print(f"Reduced batch size to {current_batch_size}")


def create_batches_by_seq_len(
  batch_sizes: tp.Sequence[int],
  sequence_length_thresholds: tp.Sequence[int],
  input_lengths: np.ndarray,
):
  """Greedily creates batch of indices to minimize padding by grouping indices with similar input lengths.

  Inputs are visited from longest to shortest. For each position the
  `(batch_size, threshold)` pairs are tried in the given order, and the first
  pair whose candidate batch has a maximum length not exceeding the threshold
  is used.

  Args:
      batch_sizes: the maximum batch sizes for each sequence length
      sequence_length_thresholds: sequence lengths corresponding to the batch sizes
      input_lengths: input length of every element; batches contain positions
          into this array

  Returns:
      list of index arrays, one per batch.

  Raises:
      ValueError: if some element does not fit under any threshold.
  """
  input_lengths = np.asarray(input_lengths)
  all_batch_indices = []
  i = 0
  indices = np.argsort(input_lengths)[::-1]
  while i < len(indices):
    for batch_size, threshold in zip(batch_sizes, sequence_length_thresholds):
      if batch_size <= 0:
        # An empty candidate batch can never make progress.
        continue
      batch_indices = indices[i : i + batch_size]
      max_len = np.max(input_lengths[batch_indices])
      if max_len <= threshold:
        all_batch_indices.append(batch_indices)
        i += len(batch_indices)
        break
    else:
      # This will only happen if no suitable batch size was found.
      # Lengths are sorted in descending order, so the element at position i
      # is the longest remaining one.
      first_threshold = (
        sequence_length_thresholds[0] if len(sequence_length_thresholds) else None
      )
      raise ValueError(
        "Unable to find a suitable batch size for the remaining elements.\n"
        f"Remaining element length: {input_lengths[indices[i]]}\n"
        f"Smallest threshold: {first_threshold}\n"
        f"i: {i}\n"
      )
  return all_batch_indices


def compute_adaptive_batch_indices(
  input_lengths: np.ndarray,
  length_partitions: tp.Sequence[int],
  fn: tp.Callable,
  device,
):
  """Builds length-aware batches using batch sizes probed for each partition.

  For every sequence length in `length_partitions` the largest batch size that
  `fn` can process on `device` is determined, and those sizes are then used to
  group `input_lengths` into batches.
  """
  batch_sizes = []
  for seq_len in length_partitions:
    # calculate max batch size for this sequence length
    max_batch_size = get_maximum_batch_size_for_seq_len(
      fn,
      device=device,
      seq_len=seq_len,
    )
    batch_sizes.append(max_batch_size)
  return create_batches_by_seq_len(batch_sizes, length_partitions, input_lengths)


def ray_as_completed(tasks: tp.Sequence[ray.ObjectRef | ray.ObjectRefGenerator]):
  """Yields the results of Ray tasks as they become ready.

  Plain object refs yield their resolved value once. Ref generators yield one
  resolved value each time they are reported ready and are re-queued until
  exhausted.
  """
  remaining = list(tasks)
  while remaining:
    ready, remaining = ray.wait(remaining)
    for task in ready:
      if not isinstance(task, ray.ObjectRefGenerator):
        yield ray.get(task)
        continue

      try:
        next_ref = next(task)
      except StopIteration:
        # generator exhausted: do not re-queue it
        continue
      yield ray.get(next_ref)
      remaining.append(task)


def ray_as_completed_lazy(tasks: tp.Sequence[ray.ObjectRef | ray.ObjectRefGenerator]):
  """Yields object refs of Ray tasks as they become ready, without resolving them.

  Plain object refs are yielded once. Ref generators yield their next ref each
  time they are reported ready and are re-queued until exhausted.
  """
  remaining = list(tasks)
  while remaining:
    ready, remaining = ray.wait(remaining)
    for task in ready:
      if not isinstance(task, ray.ObjectRefGenerator):
        yield task
        continue

      try:
        next_ref = next(task)
      except StopIteration:
        # generator exhausted: do not re-queue it
        continue
      yield next_ref
      remaining.append(task)


def check_gcs_credentials():
  """Reports whether GOOGLE_APPLICATION_CREDENTIALS points to an existing file.

  Prints a warning and returns False when the variable is unset, empty, or
  does not name a file; returns True otherwise.
  """
  import os

  cred_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
  if not cred_path or not os.path.isfile(cred_path):
    print(
      "WARNING: GOOGLE_APPLICATION_CREDENTIALS environment variable is not set "
      "or does not point to a valid file. GCS operations may fail."
    )
    return False
  return True


def get_relative_path(target_path: upath.UPath, base_path: upath.UPath) -> str:
  """Returns `target_path` relative to `base_path` as a POSIX string.

  A leading ``gs://`` scheme is removed from both paths before comparing.
  """
  scheme = "gs://"
  target = pathlib.Path(target_path.as_posix().removeprefix(scheme))
  base = base_path.as_posix().removeprefix(scheme)
  return target.relative_to(base).as_posix()


def row_to_columnar(rows: tp.Sequence[dict]) -> dict[str, np.ndarray]:
  """Converts a list of identically structured records to a columnar format.

  The records are combined leaf by leaf with ``tree.map_structure``, so the
  result has the same nesting as a single record (keys are NOT flattened here;
  `flatten_batch` performs the dot-flattening afterwards).

  Each group of corresponding leaf values is combined as follows:
  - If all values are scalars, they are stacked into a single np.ndarray.
  - If all values are lists or tuples, each is converted to an array first.
  - If all values are np.ndarrays with the same shape, they are stacked into a single np.ndarray.
  - Otherwise, they are returned as a np.array(..., dtype='O').

  NOTE(review): `tree.map_structure` appears to treat lists/tuples as nested
  structures rather than leaves, in which case whole lists never reach the
  list/tuple rule above — confirm against the `tree` package's behaviour.

  Example
  -------
  >>> rows = [
  ...     {"a": 1, "b": 2, "d": {"e": 5}},
  ...     {"a": 8, "b": 9, "d": {"e": 12}},
  ... ]

  >>> row_to_columnar(rows)
  {
      "a": np.array([1, 8]),
      "b": np.array([2, 9]),
      "d": {"e": np.array([5, 12])},
  }

  Args:
      rows: list of dictionaries

  Returns:
      dict: A structure shaped like one record, with columnar data at the leaves.
  """
  import tree

  def collate_values(*values):
    if all(np.isscalar(value) for value in values):
      return np.array(values)

    if all(isinstance(value, (tuple, list)) for value in values):
      values = [np.array(value) for value in values]

    all_arrays = all(isinstance(value, np.ndarray) for value in values)
    # if all of them have the same shape, stack them
    if all_arrays and all(value.shape == values[0].shape for value in values[1:]):
      return np.stack(values)

    # if not, fall back to an object array
    return np.array(values, dtype=object)

  return tree.map_structure(collate_values, *rows)


def flatten_batch(batch):
  """Flattens a batch whose columns may hold dict records into dotted columns.

  Columns whose first element is a dict are converted with `row_to_columnar`;
  all other columns (including empty ones) are kept unchanged. The result is
  then flattened so nested keys become ``"outer.inner"``.

  Args:
      batch: mapping from column name to a sequence of values.

  Returns:
      flat dict keyed by dotted column names.
  """
  outputs = {}
  for col in batch:
    values = batch[col]
    # Guard against empty columns: there is no first element to inspect.
    if len(values) > 0 and isinstance(values[0], dict):
      outputs[col] = row_to_columnar(values)
    else:
      outputs[col] = values

  return flatten_dict.flatten(outputs, reducer="dot")


def columnar_to_row(columnar: dict[str, np.ndarray]) -> tp.Sequence[dict]:
  """Converts a columnar format to a list of dictionaries.

  The input may be a nested dictionary with columnar data at the leaves; each
  returned row has the same nesting with the i-th element of every leaf.
  The number of rows is taken from the first leaf.

  NOTE(review): nesting is handled with flatten_dict's default key handling,
  so already-dotted string keys such as ``"d.e"`` look like they stay
  ``"d.e"`` in the rows and are not expanded into ``{"d": {"e": ...}}`` —
  confirm whether expansion was intended.

  Example
  -------
  >>> columnar = {
  ...     "a": np.array([1, 8]),
  ...     "d": {"e": np.array([5, 12])},
  ... }

  >>> columnar_to_row(columnar)
  [
      {"a": 1, "d": {"e": 5}},
      {"a": 8, "d": {"e": 12}},
  ]

  Args:
      columnar: A (possibly nested) dictionary with columnar data at the leaves.

  Returns:
      list: A list of dictionaries; empty when `columnar` has no columns.
  """
  flat_columnar = flatten_dict.flatten(columnar)
  if not flat_columnar:
    # No columns at all: there is no first column to take the length from.
    return []

  flat_keys = list(flat_columnar)
  num_rows = len(flat_columnar[flat_keys[0]])

  return [
    flatten_dict.unflatten({key: flat_columnar[key][i] for key in flat_keys})
    for i in range(num_rows)
  ]


def get_gpu_memory_utilization_stats():
  """Returns per-device CUDA memory figures, or an empty list without CUDA.

  Each entry holds the device's total memory, the memory currently allocated
  by torch on it, and their difference as "free_memory".
  """
  if not torch.cuda.is_available():
    return []

  stats = []
  for device_index in range(torch.cuda.device_count()):
    total = torch.cuda.get_device_properties(device_index).total_memory
    allocated = torch.cuda.memory_allocated(device_index)
    stats.append(
      {
        "free_memory": total - allocated,
        "allocated_memory": allocated,
        "total_memory": total,
      }
    )
  return stats


def split_dataset(dataset: ds.Dataset, batch_size: int) -> list[ds.Dataset]:
  """
  Cut a Hugging Face dataset into consecutive pieces of at most `batch_size` rows.

  Args:
  dataset (Dataset): The Hugging Face dataset to split.
  batch_size (int): The size of each batch; the last batch may be smaller.

  Returns:
  List[Dataset]: A list of Dataset objects, each representing a batch.

  Raises:
  TypeError: if `dataset` is not a Hugging Face Dataset.
  ValueError: if `batch_size` is not positive.
  """
  if not isinstance(dataset, ds.Dataset):
    raise TypeError("Input must be a Hugging Face Dataset object")

  if batch_size <= 0:
    raise ValueError("Batch size must be positive")

  total_rows = len(dataset)
  batches = []
  for start in range(0, total_rows, batch_size):
    stop = min(start + batch_size, total_rows)
    batches.append(dataset.select(range(start, stop)))
  return batches


def create_experiment_name(tags: tp.Sequence[str]) -> str:
  """
  Builds an experiment name from tags followed by the current local time.

  Args:
  tags (List[str]): A list of strings representing tags for the experiment.

  Returns:
  str: The tags joined by underscores, then an underscore and a timestamp,
  i.e. "tag1_tag2_YYYYMMDDHHMMSS".
  """
  parts = ["_".join(tags), datetime.now().strftime("%Y%m%d%H%M%S")]
  return "_".join(parts)


def iterjsonls(jsonstr):
  """Iterate over a JSONL string, yielding one parsed object per non-blank line.

  Lines are separated by ``\\n`` (a trailing ``\\r`` is tolerated). Blank lines
  are skipped.
  """
  import json

  # Split on "\n" only: str.splitlines() also breaks on characters such as
  # U+2028, which may legally appear unescaped inside JSON strings.
  for line in jsonstr.split("\n"):
    line = line.strip()
    if line:
      yield json.loads(line)


def iterjsonl(file_like):
  """Iterate over a JSONL file, yielding one parsed object per non-blank line.

  Args:
      file_like: path (or anything accepted by `upath.UPath`) of the file.
  """
  import json

  with upath.UPath(file_like).open("r") as f:
    for line in f:
      # Blank lines (e.g. a trailing empty line) are not valid JSON; skip them.
      if line.strip():
        yield json.loads(line)


def load_json(file_like):
  """Reads the file at `file_like` and returns its parsed JSON content."""
  import json

  source = upath.UPath(file_like)
  with source.open("r") as stream:
    return json.loads(stream.read())


def load_yaml_with_includes(path):
  """Loads a YAML file resolving a top-level `include` directive.

  The include is resolved relative to the directory of `path`. Keys of the
  including file override keys of the included file. Only one level of
  inclusion is resolved: an `include` key inside the included file is kept
  as-is. Empty YAML files are treated as empty mappings.

  Example:
  --------
  cfg:
  ```yaml
  include: base.yaml
  key: value
  ```
  base.yaml:
  ```yaml
  key2: value2
  ```
  >>> load_yaml_with_includes('cfg.yaml')
  {'key': 'value', 'key2': 'value2'}

  Args:
      path: path to base yaml file

  Returns:
      combined config, with `include` resolved and merged.

  Raises:
      FileNotFoundError: if the included file does not exist.
  """
  path = upath.UPath(path)
  # An empty YAML document parses to None; treat it as an empty config.
  raw_cfg = load_yaml(path) or {}

  base_cfg = {}
  if "include" in raw_cfg:
    base_cfg_path = path.parent / raw_cfg["include"]
    if not base_cfg_path.exists():
      raise FileNotFoundError(f"File not found: {base_cfg_path.as_posix()}")
    base_cfg = load_yaml(base_cfg_path) or {}
    raw_cfg.pop("include")

  return {**base_cfg, **raw_cfg}


def map_rows(fn: tp.Callable[[dict], dict], batch: tp.Mapping[str, np.ndarray]) -> dict:
  """
  Apply a function to each row of a batch of (possibly nested) column arrays.

  The batch is flattened, one nested row dict is rebuilt per index, `fn` is
  called on it, and the per-row outputs are collected column-wise.

  Args:
      fn: A function that takes a (possibly nested) dictionary representing a
          single row and returns a (possibly nested) dictionary.
      batch: A (possibly nested) dictionary whose leaves are indexable sequences
          such as numpy arrays. All leaves must have the same length.

  Returns:
      A nested dictionary with the structure of `fn`'s output. Each leaf is a
      Python list holding one entry per row, in row order. An empty batch, or a
      batch whose leaves have length 0, yields an empty dictionary.

  Raises:
      AssertionError: if the leaves of `batch` do not all have the same length.

  Example:
      >>> batch = {
      ...     'input': np.array([1, 2, 3]),
      ...     'metadata': {
      ...         'timestamp': np.array([1000, 2000, 3000]),
      ...         'source': np.array([10, 20, 30])
      ...     }
      ... }
      >>> def process_row(row):
      ...     return {
      ...         'output': row['input'] * 2,
      ...         'meta_sum': row['metadata']['timestamp'] + row['metadata']['source']
      ...     }
      >>> result = map_rows(process_row, batch)
      >>> # result['output'] is the list [2, 4, 6]
      >>> # result['meta_sum'] is the list [1010, 2020, 3030]
  """
  import collections

  import flatten_dict

  flat_batch = flatten_dict.flatten(
    batch,
  )
  # Take the reference length from a *leaf*: the first top-level value may itself
  # be a nested dict, whose len() is its number of keys rather than a row count.
  lengths = [len(column) for column in flat_batch.values()]
  n = lengths[0] if lengths else 0
  assert all(length == n for length in lengths), (
    "Expected all arrays to have the same length, "
    f"but found lengths: {[('.'.join(k), len(v)) for k, v in flat_batch.items()]}"
  )

  results = collections.defaultdict(list)
  for i in range(n):
    # Rebuild the nested view of row i, transform it, and flatten the output so
    # values can be accumulated per output column.
    row = flatten_dict.unflatten({k: v[i] for k, v in flat_batch.items()})
    row = flatten_dict.flatten(fn(row))
    for k, v in row.items():
      results[k].append(v)

  return flatten_dict.unflatten(
    results,
  )


def batch_imap(
  func: tp.Callable[[tp.Sequence[T]], tp.Sequence[S]],
  iterable: tp.Iterable[T],
  num_workers: int | None = None,
  min_batch_size: int = 1,
  parallel_strategy: tp.Literal["thread", "process"] = "process",
) -> tp.Iterable[S]:
  # Convert iterable to a list to get its length
  items = list(iterable)
  total_items = len(items)

  # Determine the number of workers
  num_workers = num_workers or mp.cpu_count()

  # Calculate batch size
  batch_size = max(min_batch_size, math.ceil(total_items / num_workers))

  def yield_batches():
    for i in range(0, total_items, batch_size):
      yield items[i : i + batch_size]

  if num_workers > 1:
    with mp.Pool(processes=num_workers) as pool:
      for result_batch in pool.imap(func, yield_batches()):
        yield from result_batch
  else:
    for batch in yield_batches():
      yield from func(batch)


def _worker_fn(
  args: tuple[np.ndarray, pd.DataFrame],
  func: tp.Callable[[pd.DataFrame], pd.DataFrame],
):
  """Run `func` on one batch and return it paired with the batch's row positions.

  Args:
      args: `(indices, df)` pair; `indices` is passed through untouched.
      func: Function applied to `df`.

  Returns:
      Tuple `(indices, func(df))`.
  """
  positions, frame = args
  processed = func(frame)
  return positions, processed


def map_dataframe(
  df: pd.DataFrame,
  func: tp.Callable,
  batch_size: int = 256,
  n_workers: int = -1,
  parallelism: str = "process",
  progress: bool = True,
) -> pd.DataFrame:
  """
  Apply a function to a DataFrame in batches, potentially using parallel execution.

  The frame is split positionally into roughly equal batches of at most
  `batch_size` rows, each batch is processed by `func` in a worker, and the
  outputs are concatenated and put back in the original row order. The
  reordering is positional, so `func` is expected to return one output row per
  input row, in the same order as its input batch.

  Args:
      df (pd.DataFrame): Input DataFrame
      func (Callable): Function to apply to each batch. With
          `parallelism='process'` it must be picklable.
      batch_size (int): Maximum number of rows in each batch (must be positive)
      n_workers (int): Number of worker processes/threads (-1 for CPU count)
      parallelism (str): 'process', 'thread', or 'none'. With 'none', `func` is
          called once on the whole DataFrame and no batching takes place.
      progress (bool): Whether to display a tqdm progress bar (counted in rows)

  Returns:
      pd.DataFrame: Processed DataFrame. For an empty input, `func(df)` is
      returned directly.

  Raises:
      ValueError: if `parallelism` is not one of the accepted values, or if
          `batch_size` is not positive.
  """
  from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed

  if parallelism not in ["process", "thread", "none"]:
    raise ValueError("parallelism must be either 'process' 'none' or 'thread'")
  if parallelism == "none":
    return func(df)
  if batch_size <= 0:
    raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
  if len(df) == 0:
    # np.array_split rejects 0 sections; with no rows there is nothing to
    # parallelize, so let func produce the (empty) output frame directly.
    return func(df)

  if n_workers == -1:
    n_workers = mp.cpu_count()

  n_batches = math.ceil(len(df) / batch_size)
  batch_indices = np.array_split(np.arange(len(df)), n_batches)

  executor_class = (
    ProcessPoolExecutor if parallelism == "process" else ThreadPoolExecutor
  )

  result_dfs = []
  result_indices = []
  with tqdm(total=len(df), disable=not progress) as pbar:
    with executor_class(max_workers=n_workers) as executor:
      futures = [
        executor.submit(_worker_fn, (batch_idx, df.iloc[batch_idx]), func)
        for batch_idx in batch_indices
      ]

      # Batches finish in arbitrary order; keep each result's input positions
      # so the original order can be restored afterwards.
      for future in as_completed(futures):
        indices, processed_batch = future.result()
        result_indices.append(indices)
        result_dfs.append(processed_batch)
        pbar.update(len(indices))

  result_df = pd.concat(result_dfs)
  result_index = np.concatenate(result_indices)

  # argsort of the concatenated input positions is the permutation that puts
  # the concatenated rows back in input order.
  return result_df.iloc[np.argsort(result_index)]


def pd_unique(df: pd.DataFrame, subset: tp.Sequence[str]):
  """
  Return first-occurrence positions and inverse indices for the rows of `df`,
  where uniqueness is decided on the `subset` columns.

  Missing values (NaN/None) in `subset` are treated as equal to each other,
  matching the behaviour of `pandas.DataFrame.duplicated`.

  Parameters:
  df (pandas.DataFrame): Input DataFrame (its index is ignored; positions are used)
  subset (list): List of column names to consider for uniqueness

  Returns:
  tuple: (unique_idx, inverse_idx)
      unique_idx: Positions of the first occurrence of each unique row, ascending
      inverse_idx: For every row, the index into `unique_idx` of its unique
          row, so that `df.iloc[unique_idx].iloc[inverse_idx]` reproduces the
          `subset` columns of `df`
  """
  # Work with positional row numbers regardless of the caller's index.
  df = df.reset_index(drop=True)

  # Rows that are not a repeat of an earlier row are the unique representatives.
  is_duplicate = df.duplicated(subset=subset, keep="first")
  unique_idx = np.where(~is_duplicate)[0]

  # Compare rows through per-column integer codes rather than raw values:
  # factorize gives every missing value the same code (-1), whereas NaN != NaN
  # would make a value-tuple lookup fail for rows containing NaN.
  column_codes = [pd.factorize(df[column])[0] for column in subset]

  # Number distinct code tuples in order of first appearance, which is the
  # order of `unique_idx`.
  group_ids: dict[tuple, int] = {}
  inverse_idx = np.array(
    [group_ids.setdefault(key, len(group_ids)) for key in zip(*column_codes)]
  )

  return unique_idx, inverse_idx


class PerfLogger:
  """Logs tagged checkpoints together with the time elapsed since the previous one.

  Every checkpoint is stored as a `(tag, time.time())` pair in `_logs`; the
  first entry, tagged "start", is recorded at construction time.
  """

  def __init__(self, log_fn: tp.Callable[[str], None] = logger.info):
    # log_fn receives the formatted message plus any extra args given to `log`.
    self.log_fn = log_fn
    self._logs = [("start", time.time())]

  def log(self, tag: str, *args, **kwargs):
    """Record checkpoint `tag` and emit "[<elapsed>s] <tag>" through `log_fn`.

    `elapsed` is the number of seconds since the previously recorded
    checkpoint. Extra positional and keyword arguments are forwarded to
    `log_fn` unchanged.
    """
    now = time.time()
    previous = self._logs[-1][1]
    self._logs.append((tag, now))
    self.log_fn(f"[{now - previous:.02f}s] {tag}", *args, **kwargs)


def pmap(
  fn: tp.Callable[[T], S], iterable: tp.Iterable[T], n_workers: int = -1
) -> tp.Iterator[S]:
  """
  Parallel map function using multiprocessing.

  This is a generator: the worker pool is created when the first result is
  requested and is torn down when the generator is exhausted or closed.

  Args:
      fn: Function to apply to each item in the iterable. It is sent to worker
          processes, so it must be picklable.
      iterable: Iterable to apply the function to.
      n_workers: Number of worker processes (-1 for CPU count).

  Yields:
      The result of `fn(item)` for each item, in the order of `iterable`.
  """
  if n_workers == -1:
    n_workers = mp.cpu_count()
  # Report pool start-up through the module logger rather than stdout.
  logger.debug("Launching multiprocessing pool with %d workers", n_workers)
  with mp.Pool(n_workers) as pool:
    yield from pool.imap(fn, iterable)
