import dataclasses
import pathlib
import textwrap
import time
import typing as tp

import datasets as ds
import more_itertools
import numpy as np
import ray
import ray.actor
import ray.data
import torch
from loguru import logger
from tqdm.auto import tqdm
from vllm import LLM, SamplingParams

from llm_inference import utils


def repeat_ds(dataset: ds.Dataset, *, num_repeats: int):
  """Return ``dataset`` with every row repeated ``num_repeats`` times consecutively.

  Source row ``i`` occupies output rows ``i * num_repeats`` through
  ``(i + 1) * num_repeats - 1``, so ``output_row // num_repeats`` recovers the
  source row (the relationship ``add_indices_to_batch`` relies on).
  """
  source_rows = np.arange(len(dataset), dtype=np.int32)
  return dataset.select(source_rows.repeat(num_repeats))


def add_indices_to_batch(batch, indices, *, num_samples: int = 1):
  """Attach flat, source-row and per-sample index columns to a batch.

  Args:
    batch: Column-oriented batch (column name -> values); copied, not mutated.
    indices: Flat row positions of the batch rows in the (repeated) dataset.
    num_samples: How many consecutive flat rows share one source row.

  Returns:
    A new dict with all of ``batch`` plus ``idx`` (``indices`` itself),
    ``dataset_idx`` (``index // num_samples``) and ``sample_idx``
    (``index % num_samples``). Existing keys with those names are overwritten.
  """
  source_rows = []
  sample_slots = []
  for flat_index in indices:
    source_row, sample_slot = divmod(flat_index, num_samples)
    source_rows.append(source_row)
    sample_slots.append(sample_slot)

  augmented = dict(batch)
  augmented.update(
    {
      "idx": indices,
      "dataset_idx": source_rows,
      "sample_idx": sample_slots,
    }
  )
  return augmented


def batched_apply_prompt(
  batch: dict[str, np.ndarray],
  *,
  prompt: tp.Callable[[dict], str],
  output_key: str,
):
  """Render one prompt string per row of a column-oriented batch.

  Args:
    batch: Column name -> values. The row count is taken from the first
      column; every other column is indexed with the same row numbers.
    prompt: Called once per row with a ``{column: value}`` dict for that row.
    output_key: Name of the single column returned.

  Returns:
    ``{output_key: [prompt(row_0), prompt(row_1), ...]}``. A batch with no
    columns yields an empty list instead of raising ``StopIteration``.
  """
  if not batch:
    return {output_key: []}

  num_rows = len(next(iter(batch.values())))
  return {
    output_key: [
      prompt({column: values[row] for column, values in batch.items()})
      for row in range(num_rows)
    ]
  }


def format_batch_indices(batch_indices: dict[str, tp.Sequence]) -> str:
  """Render index columns as ``name=values`` lines with right-aligned names.

  Args:
    batch_indices: Column name -> that column's values for one batch.

  Returns:
    One line per column, names right-justified to the longest name. An empty
    mapping yields ``""`` (``max`` would otherwise raise on no keys).
  """
  width = max((len(name) for name in batch_indices), default=0)
  return "\n".join(
    f"{name.rjust(width)}={values}" for name, values in batch_indices.items()
  )


@dataclasses.dataclass
class SamplingConfig:
  """Decoding settings for generation, convertible to a vLLM ``SamplingParams``."""

  # Nucleus-sampling probability mass.
  top_p: float = 0.95
  # Passed straight to vLLM's ``top_k``; vLLM treats -1 as "no top-k cutoff".
  top_k: int = -1
  # Stop strings; each instance gets its own fresh list.
  stop_tokens: list[str] = dataclasses.field(default_factory=list)
  # Forwarded to vLLM as ``max_tokens``.
  max_new_tokens: int = 512
  # Not consumed by ``to_vllm``; presumably used elsewhere (e.g. prompt
  # truncation) — confirm with callers.
  max_prompt_tokens: int = 1536
  temperature: float = 0.2

  def to_vllm(self):
    """Build the vLLM ``SamplingParams`` for these settings."""
    return SamplingParams(
      temperature=self.temperature,
      top_p=self.top_p,
      top_k=self.top_k,
      max_tokens=self.max_new_tokens,  # type: ignore
      stop=self.stop_tokens,
    )

  def to_dict(self):
    """Return the fields as a plain (deep-copied) dict."""
    return dataclasses.asdict(self)

  def replace(self, **kwargs):
    """Return a copy with the given fields overridden; ``self`` is unchanged."""
    return dataclasses.replace(self, **kwargs)


# A column-oriented batch: column name -> values for that column, the shape
# handed to batched ``Dataset.map`` functions (see ``apply_preprocessors``).
# NOTE(review): helpers in this file also return plain lists as values, so the
# ``np.ndarray`` value type is nominal.
BatchData: tp.TypeAlias = tp.Dict[str, np.ndarray]
# A batched preprocessor: takes one BatchData and returns the columns to add/update.
BatchPreprocessor: tp.TypeAlias = tp.Callable[[BatchData], BatchData]


@dataclasses.dataclass
class ShardingConfig:
  """Describes one shard out of ``num_shards`` and the rows it covers.

  NOTE(review): field meanings below are inferred from their names — confirm
  with callers.
  """

  # Total number of shards the data is split into.
  num_shards: int
  # Which shard this is (presumably 0-based).
  shard_idx: int
  # Presumably the (start, end) row bounds of this shard; confirm whether
  # ``end`` is exclusive, as it is for ``split_dataset``.
  range: tuple[int, int]


@dataclasses.dataclass
class GenerationConfig:
  """Model-loading and decoding settings used by each vLLM worker."""

  # Decoding settings; converted with ``SamplingConfig.to_vllm`` for ``LLM.generate``.
  sampling_config: SamplingConfig
  # The next three are forwarded verbatim to ``vllm.LLM(...)``.
  tensor_parallel_size: int = 1
  trust_remote_code: bool = True
  enforce_eager: bool = False
  # Dataset column holding the prompt strings fed to ``LLM.generate``.
  prompt_key: str | None = "prompt"
  # NOTE(review): not read by ``_vllm_worker``; presumably meant for feeding
  # pre-tokenized prompts — confirm before relying on it.
  input_ids_key: str | None = None


def split_dataset(N, k):
  """Partition ``range(N)`` into ``k`` contiguous half-open ``[start, end]`` pairs.

  Shard sizes differ by at most one: the first ``N % k`` shards receive one
  extra element. Each pair is a two-element list ``[start, end]`` with ``end``
  exclusive. Raises ``ZeroDivisionError`` when ``k == 0``.
  """
  base_size, oversized = divmod(N, k)
  bounds = []
  cursor = 0
  for shard in range(k):
    size = base_size + 1 if shard < oversized else base_size
    bounds.append([cursor, cursor + size])
    cursor += size
  return bounds


def _vllm_worker(
  model_id: str,
  dataset: ds.Dataset,
  generation_cfg: GenerationConfig,
  batch_size: int = 32,
):
  """Load ``model_id`` in vLLM and stream generations for ``dataset``, batch by batch.

  Wrapped as a Ray remote task by ``distributed_generate_vllm_v2``.

  Args:
    model_id: Model name/path passed to ``vllm.LLM``.
    dataset: Rows to generate for; prompts are read from
      ``generation_cfg.prompt_key``.
    generation_cfg: Model-loading options and sampling settings.
    batch_size: Rows per ``LLM.generate`` call.

  Yields:
    ``(batch, batch_outputs)`` where ``batch`` is the input batch dict and
    ``batch_outputs`` has ``prompt``, ``input_ids`` (prompt token ids),
    ``token_ids`` and ``generated_text`` of the first completion per row.

  Raises:
    ValueError: if ``generation_cfg.prompt_key`` is None, raised before the
      model is loaded rather than on the first batch.
  """
  prompt_key = generation_cfg.prompt_key
  if prompt_key is None:
    raise ValueError(
      "generation_cfg.prompt_key must name the dataset column holding the prompts"
    )

  llm = LLM(
    model_id,
    tensor_parallel_size=generation_cfg.tensor_parallel_size,
    enforce_eager=generation_cfg.enforce_eager,
    trust_remote_code=generation_cfg.trust_remote_code,
  )
  # Identical for every batch, so build it once instead of per generate call.
  sampling_params = generation_cfg.sampling_config.to_vllm()

  for batch in dataset.iter(batch_size):
    prompts = batch[prompt_key]
    request_outputs = llm.generate(
      prompts=prompts,
      sampling_params=sampling_params,
      use_tqdm=False,
    )
    # Only the first completion of each request is kept.
    completions = [request_output.outputs[0] for request_output in request_outputs]

    batch_outputs = {
      "prompt": prompts,
      "input_ids": [request_output.prompt_token_ids for request_output in request_outputs],
      "token_ids": [completion.token_ids for completion in completions],
      "generated_text": [completion.text for completion in completions],
    }

    yield batch, batch_outputs


def apply_preprocessors(
  dataset: ds.Dataset,
  preprocessors: tp.Sequence[BatchPreprocessor] | BatchPreprocessor,
  num_samples: int = 1,
):
  """Preprocess ``dataset``, optionally repeat rows, and add index columns.

  Args:
    dataset: Source dataset.
    preprocessors: One batched map function or a sequence of them, applied in
      order via ``Dataset.map(batched=True)``.
    num_samples: Number of copies of each row (``repeat_ds``); each copy is a
      separate generation sample.

  Returns:
    The processed dataset with ``idx`` (flat row), ``dataset_idx`` (source row)
    and ``sample_idx`` (copy number) columns added by ``add_indices_to_batch``.

  Raises:
    ValueError: if ``num_samples < 1``. Zero would divide by zero inside
      ``map``, and negatives would silently yield negative index columns.
  """
  if num_samples < 1:
    raise ValueError(f"num_samples must be >= 1, got {num_samples}")

  if callable(preprocessors):
    preprocessors = [preprocessors]
  preprocessed_ds = dataset
  for preprocessor in preprocessors:
    preprocessed_ds = preprocessed_ds.map(
      preprocessor,
      batched=True,
      desc=f"Applying preprocessor: {utils.describe_callable(preprocessor)}",
    )

  if num_samples > 1:
    # Copies of a row are consecutive, so flat_idx // num_samples is its source row.
    preprocessed_ds = repeat_ds(preprocessed_ds, num_repeats=num_samples)

  return preprocessed_ds.map(
    add_indices_to_batch,
    with_indices=True,
    batched=True,
    fn_kwargs={"num_samples": num_samples},
    desc="Adding indices to batch",
  )


# Filename template for numbered parquet chunk files, e.g.
# CHUNK_FORMAT.format(3) -> "chunk-000003.parquet".
# NOTE(review): distributed_generate_vllm_v2 obtains chunk paths from
# utils.ChunkManager.get_chunk_path, not from this template — confirm they agree.
CHUNK_FORMAT: str = "chunk-{:06d}.parquet"


def _uncached_positions(
  dataset: ds.Dataset,
  index_columns: tp.Sequence[str],
  cached_indices,
) -> list[int]:
  """Return row positions whose ``index_columns`` tuple is not in ``cached_indices``."""
  row_keys = zip(*(dataset[col] for col in index_columns))
  return [
    position
    for position, row_key in enumerate(row_keys)
    if row_key not in cached_indices
  ]


def _shard_positions(num_rows: int, num_shards: int) -> list[list[int]]:
  """Split ``range(num_rows)`` into at most ``num_shards`` contiguous, non-empty shards."""
  shards = [list(part) for part in more_itertools.divide(num_shards, range(num_rows))]
  # Drop empty shards: each shard becomes a GPU worker that loads a full model copy.
  return [shard for shard in shards if shard]


def distributed_generate_vllm_v2(
  model_id: str,
  hf_dataset: ds.Dataset,
  *,
  output_file: str | pathlib.Path,
  preprocessors: tp.Sequence[BatchPreprocessor] | BatchPreprocessor,
  generation_cfg: GenerationConfig,
  concurrency: int = -1,
  num_samples: int = 1,
  batch_size: int = 32,
  index_columns: tp.Sequence[str] = ("idx", "dataset_idx", "sample_idx"),
):
  """Generate completions for ``hf_dataset`` on local GPUs with vLLM + Ray, resumably.

  Rows whose ``index_columns`` tuple already appears in a previously written
  chunk are skipped. Every finished batch is written as its own chunk
  (asynchronously), and the chunks are merged at the end.

  Args:
    model_id: Model name/path passed to ``vllm.LLM`` in every worker.
    hf_dataset: Source dataset, before preprocessing.
    output_file: Path of the merged result (anything ``upath.UPath`` accepts);
      if every row is cached, its existence decides whether to merge again.
    preprocessors: Batched map function(s) applied before generation.
    generation_cfg: Model-loading and sampling settings for the workers.
    concurrency: Number of GPU workers; -1 means ``torch.cuda.device_count()``.
    num_samples: Generations per source row.
    batch_size: Rows per ``LLM.generate`` call inside a worker.
    index_columns: Columns identifying a row, used for caching/resume and
      written alongside the outputs.

  Raises:
    ValueError: if ``concurrency`` resolves to less than 1 (e.g. no visible GPUs).
  """
  import upath

  output_file = upath.UPath(output_file)
  chunk_manager = utils.ChunkManager(output_file, index_columns)
  chunk_state = chunk_manager.get_chunk_state()
  preprocessed_ds = apply_preprocessors(
    hf_dataset, preprocessors, num_samples=num_samples
  )
  pending_positions = _uncached_positions(
    preprocessed_ds, index_columns, chunk_state.cached_indices
  )
  logger.info(
    "Total Dataset len: {}, num cached: {}, num to process: {}",
    len(preprocessed_ds),
    len(chunk_state.cached_indices),
    len(pending_positions),
  )
  if not pending_positions:
    logger.info("All indices are cached. Checking if they need to be merged.")
    if not output_file.exists():
      logger.info("Merging chunks")
      chunk_manager.merge_chunks(sort_by_index=True)
    logger.info("This dataset has already been processed. Exiting.")
    return
  preprocessed_ds = preprocessed_ds.select(pending_positions)

  if concurrency == -1:
    concurrency = torch.cuda.device_count()
  if concurrency < 1:
    raise ValueError(
      f"concurrency must be >= 1 (or -1 for all visible GPUs); resolved to {concurrency}"
    )
  all_shard_positions = _shard_positions(len(preprocessed_ds), concurrency)

  ray.init(num_cpus=concurrency, num_gpus=concurrency)
  try:
    logger.info("Distributing dataset across {} shards", len(all_shard_positions))
    # NOTE(review): each task reserves exactly one GPU; confirm this works when
    # generation_cfg.tensor_parallel_size > 1.
    _worker = ray.remote(num_gpus=1, num_cpus=1)(_vllm_worker)
    actors = [
      _worker.remote(
        model_id=model_id,
        dataset=preprocessed_ds.select(shard_positions),
        generation_cfg=generation_cfg,
        batch_size=batch_size,
      )
      for shard_positions in all_shard_positions
    ]

    async_writer = utils.FailFastThreadPoolExecutor(max_workers=1)
    logger.info("Starting generation with batch_size={}", batch_size)
    chunk_idx = chunk_state.chunk_idx
    try:
      with tqdm(total=len(preprocessed_ds)) as bar:
        for batch_inputs, batch_outputs in utils.ray_as_completed(actors):
          indices = {k: v for k, v in batch_inputs.items() if k in index_columns}

          logger.info(
            "Finished processing batch with indices:\n{}",
            format_batch_indices(indices),
          )
          async_writer.submit(
            utils.write_parquet,
            {**indices, **batch_outputs},
            chunk_manager.get_chunk_path(chunk_idx).as_posix(),
          )
          log_generations(batch_outputs)
          chunk_idx += 1
          bar.update(len(batch_inputs["idx"]))
    finally:
      # Wait for already-submitted chunk writes even if generation fails, so
      # those rows are found as cached on the next run.
      async_writer.shutdown()
    chunk_manager.merge_chunks(sort_by_index=True)
  finally:
    # Release the Ray runtime (and its GPU workers) on success and on failure.
    ray.shutdown()


def log_generations(batch):
  for i, (prompt, generated_text) in enumerate(
    zip(batch["prompt"], batch["generated_text"]), start=1
  ):
    log_message = textwrap.dedent(f"""
        {'='*50}
        Generation {i}:
        {'='*50}
        Prompt:
        {prompt}
        
        Generated Text:
        {generated_text}
        {'='*50}
        """)
    logger.info(log_message)
    break
