import argparse
import bisect
import dataclasses
import os
import sys
import typing as tp

import more_itertools
import numpy as np
import pandas as pd
import ray
import ray.actor
import torch
import upath
from loguru import logger
from tqdm.auto import tqdm
from transformers import (
  AutoConfig,
  AutoModelForCausalLM,
  AutoTokenizer,
)

from llm_inference import ray_utils, scoring_utils, utils
import multiprocessing as mp

# Turn off HuggingFace `tokenizers` internal parallelism; presumably to avoid
# fork-related warnings/deadlocks once worker processes are spawned.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Maps the ``--dtype`` CLI choices (also used as argparse ``choices``) to the
# torch dtype the model is loaded with.
TORCH_DTYPES = dict(
  fp32=torch.float32,
  fp16=torch.float16,
  bf16=torch.bfloat16,
)


def format_index_names(row):
  """Render ``row``'s items as ``key=value`` pairs joined by underscores.

  ``row`` may be anything exposing ``items()`` (a dict, or a pandas Series as
  passed by ``DataFrame.apply(..., axis=1)``), e.g.
  ``{"idx": 1, "sample_idx": 0}`` -> ``"idx=1_sample_idx=0"``.
  """
  return "_".join(f"{key}={value}" for key, value in row.items())


def write_score_outputs(
  batch_inputs: pd.DataFrame,
  batch_output_dicts: list[dict],
  parquet_output_path: upath.UPath,
  tensor_output_folder: upath.UPath | None = None,
  tensor_keys_to_save: tp.Sequence[str] = ("hidden_states", "logits"),
  index_columns: tp.Sequence[str] = ("idx", "dataset_idx", "sample_idx"),
  save_tensor_dtype=np.float16,
  temperatures: tp.Sequence[float] = (),
):
  """Persist one batch of scoring outputs.

  For every key in ``tensor_keys_to_save`` ("hidden_states" and/or "logits")
  whose value is not None, the tensor is written to ``tensor_output_folder``
  (one file per row, named after the row's index columns) and replaced in the
  output dict by its path relative to the folder's parent. Those two keys are
  dropped when not listed in ``tensor_keys_to_save``; ``attentions`` is always
  dropped. The remaining fields, prefixed with the index columns, are written
  as one parquet file.

  Args:
    batch_inputs: input rows of the batch; must contain ``index_columns``.
    batch_output_dicts: one dict per row of ``batch_inputs``, in the same
      order. Mutated in place.
    parquet_output_path: destination parquet file for this batch.
    tensor_output_folder: folder for tensor files. If None, no tensor files
      are written and listed tensors stay in the dicts unchanged.
    tensor_keys_to_save: which tensors to write to disk.
    index_columns: columns of ``batch_inputs`` identifying a row.
    save_tensor_dtype: dtype tensors are cast to before saving.
    temperatures: unused; accepted for call compatibility.
  """
  batch_indices = batch_inputs[list(index_columns)]
  row_names = batch_indices.apply(format_index_names, axis=1).values

  for row_name, batch_output in zip(row_names, batch_output_dicts):
    # ---------------------------------------------------------------------------- #
    #                           Maybe save hidden states                           #
    # ---------------------------------------------------------------------------- #
    if "hidden_states" in tensor_keys_to_save:
      hidden_states = batch_output.get("hidden_states")
      if hidden_states is not None and tensor_output_folder:
        hidden_state_output_file = (
          tensor_output_folder / f"hidden-state-{row_name}.npy"
        )
        with hidden_state_output_file.open("wb") as f:
          # NOTE: np.savez writes a zip (.npz-style) archive despite the ``.npy``
          # name; the name is kept because run_scoring's cache check looks for
          # ``hidden-state-*.npy`` files.
          np.savez(
            f,
            **{str(k): v.astype(save_tensor_dtype) for k, v in hidden_states.items()},
          )
        batch_output["hidden_states"] = utils.get_relative_path(
          hidden_state_output_file, tensor_output_folder.parent
        )
    else:
      batch_output.pop("hidden_states", None)
    # ---------------------------------------------------------------------------- #
    #                               maybe save logits                              #
    # ---------------------------------------------------------------------------- #
    if "logits" in tensor_keys_to_save:
      logits = batch_output.get("logits")
      if logits is not None and tensor_output_folder:
        logits_output_file = tensor_output_folder / f"logits-{row_name}.npy"
        with logits_output_file.open("wb") as f:
          np.save(f, logits.astype(save_tensor_dtype))
        # Must go through ``utils``: there is no module-level get_relative_path.
        batch_output["logits"] = utils.get_relative_path(
          logits_output_file, tensor_output_folder.parent
        )
    else:
      batch_output.pop("logits", None)
    batch_output.pop("attentions", None)

  # Save scores to parquet, index columns first.
  batch_outputs_df = pd.DataFrame(batch_output_dicts)
  batch_outputs_with_indices = pd.concat(
    [
      batch_indices.reset_index(drop=True),
      batch_outputs_df.reset_index(drop=True),
    ],
    axis=1,
  )
  utils.write_parquet(
    {
      k: batch_outputs_with_indices[k].values
      for k in batch_outputs_with_indices.columns
    },
    parquet_output_path.as_posix(),
  )


def _load_tokenizer(
  model_id,
  padding_side: tp.Literal["left", "right"] = "left",
  trust_remote_code: bool = True,
  set_pad_to_eos_token: bool = True,
):
  """Load the HuggingFace tokenizer for ``model_id``.

  When ``set_pad_to_eos_token`` is true and the tokenizer defines no pad
  token, its EOS token is reused as the pad token, so batches can be padded.
  """
  tok = AutoTokenizer.from_pretrained(
    model_id,
    padding_side=padding_side,
    trust_remote_code=trust_remote_code,
  )
  missing_pad = set_pad_to_eos_token and tok.pad_token_id is None
  if missing_pad:
    tok.pad_token_id = tok.eos_token_id
    tok.pad_token = tok.eos_token
  return tok


def _load_model_and_tokenizer(
  model_id: str,
  dtype: torch.dtype = torch.float32,
  padding_side: tp.Literal["left", "right"] = "left",
  model_parallel: int = 1,
  trust_remote_code: bool = True,
  set_pad_to_eos_token: bool = True,
  use_flash_attention: bool = True,
):
  """Load ``model_id`` onto GPU together with its tokenizer.

  With ``use_flash_attention`` the model is first loaded with
  ``attn_implementation="flash_attention_2"``. If that raises, the reason is
  logged and the model is reloaded with the default attention implementation.
  The model goes onto the current CUDA device, or is spread with
  ``device_map="auto"`` when ``model_parallel > 1``. The same placement is
  used for both attention implementations.

  Returns:
    ``(model, tokenizer)``; the tokenizer comes from ``_load_tokenizer``.
  """
  model_cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
  # One placement for both code paths so model_parallel is always honoured.
  device_map = "auto" if model_parallel > 1 else {"": "cuda"}

  def _from_pretrained(**extra_kwargs):
    return AutoModelForCausalLM.from_pretrained(
      model_id,
      config=model_cfg,
      device_map=device_map,
      torch_dtype=dtype,
      trust_remote_code=trust_remote_code,
      **extra_kwargs,
    )

  model = None
  if use_flash_attention:
    logger.info("Trying to load model with flash attention")
    try:
      model = _from_pretrained(attn_implementation="flash_attention_2")
    except Exception as err:  # e.g. flash-attn unavailable for this model/env
      logger.warning(
        "Failed to load model with flash attention ({}). "
        "Loading model without flash attention",
        repr(err),
      )
  else:
    logger.info("Flash attention is disabled")
  if model is None:
    model = _from_pretrained()

  # Sum over all visible devices so a model sharded across GPUs is fully counted.
  total_memory = (
    sum(torch.cuda.memory_allocated(d) for d in range(torch.cuda.device_count()))
    / 1024**3
  )
  logger.info("Loaded model into GPU with {:.2f} GB memory", total_memory)
  tokenizer = _load_tokenizer(
    model_id,
    padding_side=padding_side,
    trust_remote_code=trust_remote_code,
    set_pad_to_eos_token=set_pad_to_eos_token,
  )
  return model, tokenizer


def _scoring_worker(
  model_id: str,
  dataset: pd.DataFrame,
  token_ids_column: str = "token_ids_to_score",
  dtype: torch.dtype = torch.float32,
  padding_side: tp.Literal["left", "right"] = "left",
  set_pad_to_eos_token: bool = True,
  output_hidden_states: bool = False,
  layers: int | list[int] | None = None,
  batch_size: int = 32,
  model_parallel: int = 1,
  trust_remote_code: bool = True,
  batching_strategy: tp.Literal["adaptive", "fixed"] = "adaptive",
  use_flash_attention: bool = True,
  temperatures: tp.Sequence[float] | None = None,
):
  model, tokenizer = _load_model_and_tokenizer(
    model_id,
    dtype=dtype,
    padding_side=padding_side,
    model_parallel=model_parallel,
    trust_remote_code=trust_remote_code,
    set_pad_to_eos_token=set_pad_to_eos_token,
    use_flash_attention=use_flash_attention,
  )

  # 1. Figure out batch sizes based on sequence lengths
  input_lengths: np.ndarray = tp.cast(
    np.ndarray, dataset[token_ids_column].str.len().values
  )
  logger.info("Average sequence length: {}", np.mean(input_lengths))
  max_seq_length = np.max(input_lengths)
  # choose from a predefined set of batch sizes we want to use
  partitions = [2**n for n in range(6, 15)]
  # only keep sequence lengths that are possible given the maximum sequence length
  partitions = partitions[: bisect.bisect_right(partitions, max_seq_length) + 1]
  # calculate max batch size for this sequence length
  batch_sizes = [
    utils.get_maximum_batch_size_for_seq_len(
      lambda input_ids: scoring_utils.score_batch(
        model=model,
        input_ids=input_ids,
        tokenizer=tokenizer,
        output_hidden_states=output_hidden_states,
        hidden_states_to_return=layers,
      ),
      device=model.device,
      seq_len=n,
    )
    for n in partitions
  ]
  all_batch_indices = utils.create_batches_by_seq_len(
    batch_sizes, partitions, input_lengths
  )
  # randomize order so we get good estimates of time remaining
  all_batch_indices = [
    all_batch_indices[i] for i in np.random.permutation(len(all_batch_indices))
  ]
  logger.info("Finished computing batch sizes for partitions")
  batch_lens = [len(x) for x in all_batch_indices]
  logger.info("Maximum batch size: {}", max(batch_lens))
  logger.info("Minimum batch size: {}", min(batch_lens))
  logger.info("Average batch size: {}", np.mean(batch_lens))
  logger.info("Number of batches: {}", len(all_batch_indices))
  logger.info(f"{batch_lens=}")

  # 2. Score the dataset
  for batch_indices in all_batch_indices:
    batch_inputs = dataset.iloc[batch_indices]
    batch_input_ids = [
      np.array(x, dtype=int) for x in batch_inputs[token_ids_column].values
    ]
    batch_outputs = scoring_utils.score_batch(
      model=model,
      input_ids=batch_input_ids,
      tokenizer=tokenizer,
      output_hidden_states=output_hidden_states,
      hidden_states_to_return=layers,
    )

    batch_output_dicts = [dataclasses.asdict(x) for x in batch_outputs]
    # Remove logits/attention from each dict
    for output_dict in batch_output_dicts:
      # ---------------------------------------------------------------------------- #
      #                   Calculate logprobs based on temperatures                   #
      # ---------------------------------------------------------------------------- #
      if temperatures is not None:
        if "logits" not in output_dict:
          raise ValueError(
            "Cannot calculate logprobs without logits. Logits not found in output dict"
          )
        logits = torch.tensor(output_dict["logits"], device="cpu")
        token_ids = torch.tensor(
          output_dict["token_ids"], device="cpu", dtype=torch.long
        )
        # torch.gather(logits, -1, token_ids.unsqueeze(-1))
        for temperature in temperatures:
          scaled_logits = logits / temperature

          logprobs = torch.nn.functional.log_softmax(scaled_logits, dim=-1)
          token_logprobs = (
            torch.gather(logprobs, -1, token_ids.unsqueeze(-1)).squeeze(-1).numpy()
          )
          output_dict[f"token_logprobs@T={temperature}"] = token_logprobs

      # ---------------------------------------------------------------------------- #
      #                             Remove large columns                             #
      # ---------------------------------------------------------------------------- #
      if "logits" in output_dict:
        del output_dict["logits"]
      if "attentions" in output_dict:
        del output_dict["attentions"]
    yield batch_inputs, batch_output_dicts


def _process_input_df(
  input_df: pd.DataFrame,
  tokenizer,
  input_keys: tp.Sequence[str],
  output_key: str = "token_ids_to_score",
  dedup: bool = False,
):
  """Build the token-id column to score from ``input_keys``, optionally dedup.

  In each row, the values of ``input_keys`` must be either all strings or all
  sequences of token ids:
  - all strings: joined with newlines and tokenized with ``tokenizer.encode``;
  - all sequences: concatenated with ``np.concatenate``.

  The result is stored in ``input_df[output_key]`` (``input_df`` is modified
  in place). With ``dedup``, that column is converted to tuples and rows with
  identical token ids are dropped, keeping the first occurrence. Up to five
  decoded samples are logged.

  Returns:
    The (possibly deduplicated) DataFrame.

  Raises:
    ValueError: if ``input_keys`` is empty or a row mixes strings and
      sequences.
  """
  logger.info(
    "Tokenizing input columns.\n" "Input columns: {}\n" "Output column: {}",
    input_keys,
    output_key,
  )
  if not input_keys:
    raise ValueError("input_keys must contain at least one column")

  encoded_column = []
  # Iterate positionally: label lookups (``.at``) return a Series when the
  # index is not unique.
  for values in zip(*(input_df[key].tolist() for key in input_keys)):
    all_strings = all(isinstance(v, str) for v in values)
    all_sequences = all(isinstance(v, (list, tuple, np.ndarray)) for v in values)
    if not (all_strings or all_sequences):
      raise ValueError(
        f"Input keys must be either be all strings or all sequences of integers. \n"
        f"Found types: {dict((key, type(value)) for key, value in zip(input_keys, values))}"
      )
    if all_strings:
      encoded_column.append(tokenizer.encode("\n".join(values)))
    else:
      encoded_column.append(np.concatenate(values))
  input_df[output_key] = encoded_column

  if dedup:
    old_len = len(input_df)
    logger.info("Trying to drop duplicates.")
    # Lists/arrays are unhashable; tuples let drop_duplicates compare rows.
    input_df[output_key] = input_df[output_key].apply(tuple)
    input_df = input_df.drop_duplicates(subset=[output_key])
    logger.info(
      "Dropped duplicates, old length: {} new length: {}", old_len, len(input_df)
    )

  # log some sample inputs (works for fewer than five rows too)
  logger.info("Sample inputs:")
  for token_ids in input_df[output_key].iloc[:5]:
    logger.info("Input:\n{}", tokenizer.decode(token_ids))

  return input_df


def _resolve_temperatures(
  temperatures: tp.Sequence[float] | None,
  input_file: upath.UPath,
) -> list[float]:
  """Combine user temperatures with ``temperature`` from ``config.yaml`` next to ``input_file``.

  Raises:
    ValueError: if a user-supplied temperature is not positive.

  A non-positive (or null) temperature from the config file, e.g. T=0 for
  greedy decoding, is skipped with a warning: dividing logits by it would
  produce inf/NaN logprobs.
  """
  if temperatures is None:
    logger.info("No temperatures provided. Checking input directory for config file")
    resolved: list[float] = []
  else:
    resolved = list(temperatures)
  for temperature in resolved:
    if temperature <= 0:
      raise ValueError(f"Temperatures must be positive, got {temperature}")

  config_file = input_file.parent / "config.yaml"
  if config_file.exists():
    logger.info("Found config file: {}", config_file.as_posix())
    # ``or {}`` guards against an empty YAML file loading as None.
    config = utils.load_yaml(config_file) or {}
    if "temperature" in config:
      config_temperature = config["temperature"]
      if config_temperature is None or config_temperature <= 0:
        logger.warning(
          "Ignoring temperature {} from config file (must be positive)",
          config_temperature,
        )
      elif config_temperature not in resolved:
        logger.info("Using temperatures from config file: {}", config_temperature)
        resolved.append(config_temperature)
  return resolved


def run_scoring(
  model_id: str,
  input_file: upath.UPath,
  *,
  input_keys: tp.Sequence[str],
  dedup: bool = False,
  output_file: upath.UPath,
  output_hidden_states: bool = False,
  temperatures: tp.Sequence[float] | None = None,
  layers: int | list[int] | None = None,
  batch_size: int = 4,
  concurrency: int = -1,
  model_parallel: int = 1,
  dtype: str = "fp32",
  trust_remote_code: bool = True,
  index_columns: tp.Sequence[str] = ("idx", "dataset_idx", "sample_idx"),
  use_flash_attention: bool = True,
):
  """Score the rows of the parquet ``input_file`` with ``model_id``.

  Steps:
  1. Tokenize the rows from ``input_keys``.
  2. Skip rows already present in the chunk cache for ``output_file``. When
     ``output_hidden_states`` is set, a row only counts as cached if its
     hidden-state file exists.
  3. Split the remaining rows across up to ``concurrency`` Ray workers, each
     using ``model_parallel`` GPUs.
  4. Write each returned batch as a chunk, then merge the chunks into
     ``output_file``.

  Hidden states go to a folder named after ``output_file``'s stem, next to it.

  Raises:
    ValueError: for invalid GPU/concurrency settings, missing input columns,
      or non-positive user-supplied temperatures.
  """
  # ---------------------------------------------------------------------------- #
  #                              Validate arguments                              #
  # ---------------------------------------------------------------------------- #
  input_file = upath.UPath(input_file)
  output_file = upath.UPath(output_file)
  logger.info(input_file)
  logger.info(output_file)
  if dtype == "fp32":
    logger.warning(
      "You are using fp32 for scoring. This may be slower than using fp16 or bf16. Consider using --dtype=fp16 or --dtype=bf16"
    )
  logger.info("Batch Size: {}", batch_size)
  logger.info("Torch Dtype: {}", dtype)
  if model_parallel < 1:
    raise ValueError(f"model_parallel must be at least 1, got {model_parallel}")
  if concurrency == -1:
    concurrency = torch.cuda.device_count() // model_parallel
    logger.info(
      "Using {} processes ({} gpus per process) for scoring",
      concurrency,
      model_parallel,
    )

  if concurrency * model_parallel > torch.cuda.device_count():
    raise ValueError(
      f"Requested {concurrency * model_parallel} GPUs for scoring, "
      f"but only {torch.cuda.device_count()} are available"
    )

  # ---------------------------------------------------------------------------- #
  #             Check if there is a config that contains temperatures            #
  # ---------------------------------------------------------------------------- #
  temperatures = _resolve_temperatures(temperatures, input_file)
  logger.info("Final temperatures: {}", temperatures)
  # ---------------------------------------------------------------------------- #
  #                                Load input file                               #
  # ---------------------------------------------------------------------------- #
  input_df = pd.read_parquet(
    input_file, storage_options=dict(input_file.storage_options)
  )

  # every input key must exist and hold strings or sequences of token ids
  def _validate_input_keys():
    for k in input_keys:
      if k not in input_df:
        raise ValueError(
          f"Column '{k}' not found in input file. Available columns: {input_df.columns}"
        )

      exemplars = input_df[k].iloc[:2]
      for e in exemplars:
        if isinstance(e, (str, list, tuple, np.ndarray)):
          continue
        raise ValueError(
          f"Column '{k}' must contain either strings or sequences of integers, but found: {type(e)}"
        )

  _validate_input_keys()
  # ---------------------------------------------------------------------------- #
  #                              Process input file                              #
  # ---------------------------------------------------------------------------- #
  tokenizer = _load_tokenizer(
    model_id,
    trust_remote_code=trust_remote_code,
    set_pad_to_eos_token=True,
    padding_side="left",
  )
  input_df = _process_input_df(
    input_df,
    tokenizer,
    input_keys=input_keys,
    dedup=dedup,
  )
  output_folder = output_file.parent
  hidden_state_output_folder = output_folder / output_file.stem
  # parents=True: the output directory itself may not exist yet.
  hidden_state_output_folder.mkdir(parents=True, exist_ok=True)
  # check cache
  chunk_manager = utils.ChunkManager(output_file, index_cols=index_columns)
  chunk_state = chunk_manager.get_chunk_state()
  if output_hidden_states:
    cached_indices = set()
    for index_row in chunk_state.cached_indices:
      index_name = format_index_names(dict(zip(index_columns, index_row)))
      hidden_state_file = hidden_state_output_folder / f"hidden-state-{index_name}.npy"
      if hidden_state_file.exists():
        cached_indices.add(index_row)
      else:
        logger.warning("Missing hidden state file for index row: {}", index_name)
  else:
    cached_indices = set(chunk_state.cached_indices)

  # Positions (not labels) of the rows still to be scored.
  row_keys = zip(*(input_df[col].values for col in index_columns))
  rows_to_score = [
    position for position, key in enumerate(row_keys) if key not in cached_indices
  ]
  logger.info("Completed {} rows already in cache", len(cached_indices))
  if not rows_to_score:
    logger.info("All rows are already in cache. Checking if chunks need to be merged")
    if not output_file.exists():
      logger.info("Merging chunks")
      chunk_manager.merge_chunks(sort_by_index=True)
    return

  input_df = input_df.iloc[rows_to_score]
  logger.info("Scoring {} rows", len(input_df))

  # Checked here (not earlier) so a merge-only run still works without GPUs.
  if concurrency < 1:
    raise ValueError(
      f"Need at least one scoring worker, but concurrency={concurrency} "
      f"({torch.cuda.device_count()} GPUs available, {model_parallel} per worker)"
    )
  # Skip empty shards, which appear when there are fewer rows than workers.
  shards = [
    list(shard)
    for shard in more_itertools.divide(concurrency, np.arange(len(input_df)))
  ]
  shards = [shard for shard in shards if shard]

  ray.init(num_cpus=concurrency, num_gpus=torch.cuda.device_count())
  try:
    _worker = ray.remote(num_gpus=model_parallel, num_cpus=1)(_scoring_worker)
    actors = [
      _worker.remote(
        model_id,
        input_df.iloc[shard],
        token_ids_column="token_ids_to_score",
        dtype=TORCH_DTYPES[dtype],
        padding_side="left",
        output_hidden_states=output_hidden_states,
        layers=layers,
        batch_size=batch_size,
        model_parallel=model_parallel,
        trust_remote_code=trust_remote_code,
        batching_strategy="adaptive",
        use_flash_attention=use_flash_attention,
        temperatures=temperatures,
      )
      for shard in shards
    ]

    output_stream = utils.ray_as_completed(actors)
    writer_pool = utils.FailFastThreadPoolExecutor(max_workers=2)
    chunk_idx = chunk_state.chunk_idx
    tensor_keys_to_save = ("hidden_states",) if output_hidden_states else ()
    logger.info("Starting inference loop")
    with tqdm(total=len(input_df), desc="Scoring") as bar:
      for batch_inputs, batch_output_dicts in output_stream:
        writer_pool.submit(
          write_score_outputs,
          batch_inputs,
          batch_output_dicts,
          parquet_output_path=chunk_manager.get_chunk_path(chunk_idx),
          tensor_output_folder=hidden_state_output_folder,
          tensor_keys_to_save=tensor_keys_to_save,
          index_columns=index_columns,
        )
        bar.update(len(batch_output_dicts))
        chunk_idx += 1

    logger.info("Finished scoring all rows. Waiting for writing to finish..")
    writer_pool.shutdown()
    logger.info("Merging chunks")
    chunk_manager.merge_chunks(sort_by_index=True)
    logger.info("Finished merging chunks")
    logger.info("Output saved to: {}", output_file.as_posix())
  finally:
    # Release the Ray cluster even if scoring or writing failed.
    ray.shutdown()


def _parse_float_list(s):
  # Remove any whitespace and split the string by commas
  elements = s.strip().split(",")

  result = []
  for i, e in enumerate(elements):
    e = e.strip()
    if not e:
      continue  # Skip empty elements
    try:
      result.append(float(e))
    except ValueError:
      raise ValueError(f"Invalid float value '{e}' at position {i}")

  if not result:
    raise ValueError("No valid float values found in the input string")

  return result


if __name__ == "__main__":
  # Command-line entry point: parse arguments, check GCS access if needed,
  # and run the scoring pipeline.
  parser = argparse.ArgumentParser()
  parser.add_argument("--model", type=str, help="HuggingFace model id", required=True)
  parser.add_argument("--inputs", type=str, help="Input parquet file", required=True)
  parser.add_argument("--outputs", type=str, help="Output parquet file", required=True)
  parser.add_argument(
    "--hidden-states",
    action="store_true",
    help="whether or not to output hidden states",
  )
  parser.add_argument(
    "--temperatures",
    type=str,
    help="Comma separated list of temperatures to use for calculating logprobs (e.g. --temperatures 0.2,0.6)",
  )
  parser.add_argument(
    "--layers",
    help=(
      "Comma separated list of layers to output (e.g. '-1,-2', '0')."
      " If not specified, all hidden states are output."
    ),
    default=None,
  )
  parser.add_argument(
    "--batch-size",
    type=int,
    help="Batch size for scoring",
    default=4,
  )
  parser.add_argument(
    "--concurrency",
    type=int,
    help="Number of concurrent actors to use. If set to -1, set to number of available GPUs",
    default=-1,
  )
  parser.add_argument(
    "--model-parallel",
    type=int,
    help="Number of model parallel GPUs to use per concurrent process",
    default=1,
  )
  # Input arguments
  parser.add_argument(
    "--input-keys",
    help="Comma separated list of keys in input data to be scored. Each key must point to a string or token ids. Consecutive keys are concatenated",
    default="input_ids,token_ids",
  )

  parser.add_argument(
    "--dedup",
    action="store_true",
  )
  parser.add_argument(
    "--dtype",
    type=str,
    choices=TORCH_DTYPES.keys(),
    default="fp16",
    help="Torch dtype to use for scoring",
  )
  parser.add_argument(
    "--score-inputs-only", action="store_true", help="Score only the input-ids "
  )
  parser.add_argument(
    "--no-flash-attention",
    action="store_true",
    help="Disable flash attention",
  )
  parser.add_argument(
    "--index-keys",
    type=str,
    help="Comma separated list of keys to use as index columns",
    default="idx,dataset_idx,sample_idx",
  )
  args = parser.parse_args()
  layers = None
  if args.layers is not None:
    layers = [int(x) for x in args.layers.split(",")]

  if args.score_inputs_only:
    # The flag is accepted for CLI compatibility but nothing reads it yet.
    logger.warning("--score-inputs-only is not implemented and will be ignored")

  contains_gcs_paths = args.inputs.startswith("gs://") or args.outputs.startswith(
    "gs://"
  )
  # Only require GCS credentials when a GCS path is actually involved.
  if contains_gcs_paths and not utils.check_gcs_credentials():
    logger.error(
      "GCS credentials not found but GCS paths are provided:\n"
      f"{args.inputs}\n{args.outputs}\n"
      "Please set GOOGLE_APPLICATION_CREDENTIALS environment variable to a valid GCS credentials file."
    )
    sys.exit(1)

  run_scoring(
    model_id=args.model,
    input_file=args.inputs,
    output_file=args.outputs,
    input_keys=args.input_keys.split(","),
    index_columns=args.index_keys.split(","),
    output_hidden_states=args.hidden_states,
    temperatures=_parse_float_list(args.temperatures) if args.temperatures else None,
    layers=layers,
    batch_size=args.batch_size,
    concurrency=args.concurrency,
    dtype=args.dtype,
    model_parallel=args.model_parallel,
    dedup=args.dedup,
    use_flash_attention=not args.no_flash_attention,
  )
