"""Preprocesses inference data for usage in CLM."""

import argparse
from pathlib import Path
from loguru import logger
import pandas as pd
import numpy as np
import typing as tp
from generative_prediction_sets import utils
import json
import upath

# Registry of experiment loaders, keyed by the name passed to ``register``.
_LOADERS = dict()


def register(name):
  """Return a decorator that records a loader function under ``name``.

  The decorated function is stored in ``_LOADERS[name]`` (overwriting any
  previous entry with the same name) and returned unchanged.
  """

  def _add_to_registry(loader_fn):
    _LOADERS.update({name: loader_fn})
    return loader_fn

  return _add_to_registry


@register("hf")
def load_single_experiment(path: str | Path):
  """Load one HF-style experiment folder into a per-sample DataFrame.

  Reads ``infer.parquet``, ``generation-scores.parquet`` and ``eval.json`` from
  ``path``. Each generation gets its per-token log-probabilities, matched on the
  full prompt + generation token sequence. ``normalized_logprobs`` is then
  computed over the generated tokens only, and the evaluation verdicts are
  joined on ``(dataset_idx, sample_idx)``.

  Args:
    path: Experiment folder; anything ``upath.UPath`` accepts.

  Returns:
    The inner join of generations and eval results. The eval ``correct``
    column (or ``passed`` when ``correct`` is absent) is renamed to
    ``admissible``.

  Raises:
    ValueError: If some generations have no matching token sequence in
      ``generation-scores.parquet``.
  """
  path = upath.UPath(path)

  infer = pd.read_parquet((path / "infer.parquet").as_posix())
  logger.info(infer.generated_text.nunique())
  generation_scores = pd.read_parquet(
    (path / "generation-scores.parquet").as_posix()
  )

  # Scores are keyed by the full token sequence (prompt + generation), so build
  # the same key on the inference side; tuples are hashable merge keys.
  infer["all_token_ids"] = [
    tuple(prompt_ids) + tuple(generated_ids)
    for prompt_ids, generated_ids in zip(infer["input_ids"], infer["token_ids"])
  ]
  logger.debug(f"Len(infer): {len(infer)}")
  generation_scores["all_token_ids"] = generation_scores["token_ids"].apply(tuple)
  logger.debug(f"Len(generation_scores): {len(generation_scores)}")
  infer = pd.merge(
    infer,
    generation_scores[["all_token_ids", "token_logprobs"]],
    on="all_token_ids",
    how="left",
    validate="m:1",
  )

  logger.debug(f"Len(infer) after merge: {len(infer)}")

  # The left merge leaves NaN for sequences with no score; fail clearly here
  # rather than with a TypeError when slicing NaN below.
  n_unscored = int(infer["token_logprobs"].isna().sum())
  if n_unscored:
    raise ValueError(
      f"{n_unscored} generations have no matching entry in "
      f"{path / 'generation-scores.parquet'}"
    )

  # token_logprobs spans prompt + generation; keep only the generated part.
  infer["normalized_logprobs"] = [
    utils.compute_normalized_logprobs(token_logprobs[len(prompt_ids):])
    for prompt_ids, token_logprobs in zip(infer["input_ids"], infer["token_logprobs"])
  ]

  with (path / "eval.json").open("r") as f:
    eval_data = json.load(f)["aux"]
  # ``aux`` either wraps the per-sample records under ``results`` or is the
  # records themselves.
  if "results" in eval_data:
    eval_df = pd.DataFrame(eval_data["results"])
  else:
    eval_df = pd.DataFrame(eval_data)

  # Some eval outputs name the verdict column ``passed`` instead of ``correct``.
  if "correct" not in eval_df.columns and "passed" in eval_df.columns:
    eval_df = eval_df.rename(columns={"passed": "correct"})

  logger.info(f"Before eval: {infer.generated_text.nunique()}")
  merged = pd.merge(
    infer,
    eval_df[["dataset_idx", "sample_idx", "correct"]],
    on=["dataset_idx", "sample_idx"],
  )
  logger.debug(f"Len(merged): {len(merged)}")
  # Report on ``merged`` (not ``infer``) so rows dropped by the inner join show.
  logger.info(
    f"After eval: {merged.generated_text.nunique()}, {merged.dataset_idx.nunique()=}"
  )
  merged = merged.rename(columns={"correct": "admissible"})

  return merged


@register("openai")
def load_openai_experiment(
  input_folder, eval_file="eval.json", generations_file="infer.parquet"
):
  """Load an OpenAI-style experiment folder into a per-sample DataFrame.

  Generations are read from ``generations_file``. Eval results come from
  ``eval_file`` (JSON with an ``aux`` payload) or, when that file does not
  exist, from ``eval.parquet`` in the same folder. Both are inner-joined on
  ``(dataset_idx, sample_idx)``.

  Args:
    input_folder: Experiment folder; anything ``upath.UPath`` accepts.
    eval_file: Eval JSON file name, relative to ``input_folder``.
    generations_file: Generations parquet file name, relative to
      ``input_folder``.

  Returns:
    The merged DataFrame with ``correct``/``passed`` renamed to ``admissible``
    and a ``normalized_logprobs`` column computed from ``token_logprobs``.
    The raw ``token_logprobs`` column is dropped.

  Raises:
    ValueError: If ``sample_idx`` must be reconstructed but no generations
      exist for ``dataset_idx == 0``.
  """
  folder = upath.UPath(input_folder)
  generations_file = folder / generations_file
  logger.info("Loading generations from {}", generations_file)
  eval_file = folder / eval_file
  df = pd.read_parquet(generations_file.as_posix())

  if not eval_file.exists():
    eval_df = pd.read_parquet((folder / "eval.parquet").as_posix())
    eval_df = eval_df.rename(columns={"passed": "correct"})
  else:
    with eval_file.open("r") as f:
      eval_data = json.load(f)["aux"]
    if "results" in eval_data:
      eval_df = pd.DataFrame(eval_data["results"])
    else:
      eval_df = pd.DataFrame(eval_data)
      # Flat records carry no sample_idx, so it is rebuilt positionally.
      # NOTE(review): this assumes records are grouped per problem, in sample
      # order, with the same sample count as dataset_idx 0. Confirm upstream.
      samples_per_problem = int((df["dataset_idx"] == 0).sum())
      if samples_per_problem == 0:
        raise ValueError(
          "Cannot reconstruct sample_idx: no generations with dataset_idx == 0"
        )
      eval_df["sample_idx"] = np.arange(len(eval_df)) % samples_per_problem
    eval_df = eval_df.rename(columns={"passed": "correct"})

  logger.info("Generations: {}", df.shape)
  logger.info("Eval: {}", eval_df.shape)
  merged = df.merge(eval_df, on=["dataset_idx", "sample_idx"])
  logger.info("Merged: {}", merged.shape)
  merged = merged.rename(columns={"correct": "admissible"})

  merged["normalized_logprobs"] = merged["token_logprobs"].apply(
    utils.compute_normalized_logprobs
  )
  # Drop the raw per-token log-probs from the frame. The old code called
  # ``.drop(columns=...)`` on the Series, which is a no-op.
  merged = merged.drop(columns=["token_logprobs"])
  logger.info("Loaded {} samples (openai loader)", len(merged))
  return merged


def load_experiments(*paths: str | Path, loader):
  """Load several experiments with ``loader`` and concatenate them.

  Each experiment's ``dataset_idx`` values are shifted past the largest index
  of the experiments loaded before it. This keeps problems from different
  experiments distinct even when an experiment's indices have gaps.

  Args:
    *paths: Experiment folders, each passed to ``loader``.
    loader: Callable taking a path and returning a DataFrame with at least
      ``dataset_idx`` and ``generated_text`` columns.

  Returns:
    The concatenation of all loaded frames (original row indices kept).

  Raises:
    ValueError: From ``pd.concat`` when ``paths`` is empty.
  """
  frames = []
  offset = 0
  for path in paths:
    logger.info("Loading path: {}", path)
    df = loader(path)
    df["dataset_idx"] = df["dataset_idx"] + offset
    if not df.empty:
      # max + 1 (not nunique) avoids collisions when indices are not
      # contiguous; for contiguous 0-based indices both are the same.
      offset = int(df["dataset_idx"].max()) + 1
    logger.debug("Next dataset_idx offset: {}", offset)
    logger.debug("Unique generated texts: {}", df.generated_text.nunique())
    frames.append(df)

  return pd.concat(frames)


def _remap_split(split_name: str, split_idx: np.ndarray, position_of: dict) -> np.ndarray:
  """Map one split's raw dataset indices to row positions, dropping unknown ones."""
  missing_idx = [idx for idx in split_idx if idx not in position_of]
  remapped = np.array(
    [position_of[idx] for idx in split_idx if idx in position_of], dtype=int
  )
  logger.info("{} split: {} -> {}", split_name.capitalize(), len(split_idx), len(remapped))
  if missing_idx:
    logger.warning(
      "Some {} indices were not found in the dataset: {}", split_name, missing_idx
    )
  return remapped


def load_indices(input_path, splits_path, strict=False):
  """Load train/test splits and map them to row positions of the stacked arrays.

  ``splits_path`` must be an ``.npz`` file with ``train`` and ``test`` arrays of
  raw ``dataset_idx`` values. Each value is replaced by its position among the
  sorted unique ``dataset_idx`` values in ``input_path/infer.parquet``. That
  matches the ascending row order produced by ``groupby("dataset_idx")``.
  Indices not present in the data are dropped with a warning.

  NOTE(review): positions are taken from ``infer.parquet``. If the eval merge
  drops whole problems, the stacked arrays will have fewer rows. Confirm this
  cannot happen for the datasets in use.

  Args:
    input_path: Experiment folder containing ``infer.parquet``.
    splits_path: Path to the ``.npz`` split file.
    strict: Currently unused; missing indices are always dropped with a
      warning rather than raising.

  Returns:
    ``(train_idx, test_idx)`` as integer arrays of row positions.
  """
  infer = upath.UPath(input_path) / "infer.parquet"
  infer_df = pd.read_parquet(infer.as_posix(), columns=["dataset_idx"])
  # Sorted so positions agree with groupby's default (sorted) key order.
  all_dataset_idx = np.sort(infer_df["dataset_idx"].unique())
  position_of = {idx: pos for pos, idx in enumerate(all_dataset_idx)}

  with upath.UPath(splits_path).open("rb") as f:
    split_arrays = np.load(f)
    train_idx = split_arrays["train"]
    test_idx = split_arrays["test"]

  return (
    _remap_split("train", train_idx, position_of),
    _remap_split("test", test_idx, position_of),
  )


def extract_clm_columns(
  merged: pd.DataFrame, score_column="normalized_logprobs", label_column="admissible"
):
  """Group samples per ``dataset_idx`` and stack them into CLM input arrays.

  Args:
    merged: Per-sample frame with ``dataset_idx``, ``generated_text``,
      ``score_column`` and ``label_column``.
    score_column: Column holding the per-sample score.
    label_column: Column holding the per-sample admissibility label.

  Returns:
    ``(logprobs, diversity_scores, labels)``. ``logprobs`` and ``labels`` are
    ``np.stack``-ed per group, with one row per ``dataset_idx`` in ascending
    order. ``np.stack`` requires every group to have the same number of
    samples. ``diversity_scores`` is whatever
    ``utils.pairwise_comparisons_multi`` returns for the grouped texts with
    ``utils.compute_rouge_L``, presumably pairwise ROUGE-L per group; confirm
    against utils.
  """
  # Aggregate only the columns used below; ``agg(list)`` over every column
  # (token ids, raw log-probs, ...) costs time and memory for nothing.
  needed_columns = list(dict.fromkeys([score_column, label_column, "generated_text"]))
  grouped = merged.groupby(["dataset_idx"])[needed_columns].agg(list)
  logger.info("STACKING SCORES")
  logprobs = np.stack(grouped[score_column].values.tolist())
  logger.info("STACKING LABELS")
  labels = np.stack(grouped[label_column].values.tolist())
  logger.info("GENERATING PAIRWISE ROUGEL SCORES")
  diversity_scores = utils.pairwise_comparisons_multi(
    grouped["generated_text"].values.tolist(), utils.compute_rouge_L
  )
  return logprobs, diversity_scores, labels


def _save_split_arrays(target_dir, diversity_scores, logprobs, labels):
  """Write ``diversity.npy``, ``scores.npy`` (float32) and ``labels.npy`` (int)."""
  target_dir.mkdir(parents=True, exist_ok=True)
  with (target_dir / "diversity.npy").open("wb") as f:
    np.save(f, diversity_scores.astype(np.float32))
  with (target_dir / "scores.npy").open("wb") as f:
    np.save(f, logprobs.astype(np.float32))
  with (target_dir / "labels.npy").open("wb") as f:
    np.save(f, labels.astype(int))


def save_experiment(
  logprobs: np.ndarray,
  diversity_scores: np.ndarray,
  labels: np.ndarray,
  output_folder: str | Path,
  train_idx: np.ndarray,
  test_idx: np.ndarray,
):
  """Save the full arrays and their train/test subsets under ``output_folder``.

  Layout: ``all/``, ``train/`` and ``test/`` each contain ``diversity.npy``,
  ``scores.npy`` and ``labels.npy``. ``train/`` and ``test/`` also contain
  ``indices.npy`` with the row indices used for that split.

  Raises:
    FileExistsError: If ``output_folder`` already exists (``mkdir`` is called
      without ``exist_ok``).
  """
  output_folder = upath.UPath(output_folder)
  output_folder.mkdir(parents=True)

  _save_split_arrays(output_folder / "all", diversity_scores, logprobs, labels)

  # Per-split export is best-effort. A failure is logged with its traceback
  # but does not undo the ``all/`` export or abort the run.
  try:
    for split_name, split_idx in (("train", train_idx), ("test", test_idx)):
      # Integer dtype so that an empty split still indexes correctly.
      split_idx = np.asarray(split_idx).astype(int)
      split_dir = output_folder / split_name
      _save_split_arrays(
        split_dir,
        diversity_scores[split_idx],
        logprobs[split_idx],
        labels[split_idx],
      )
      with (split_dir / "indices.npy").open("wb") as f:
        np.save(f, split_idx)
  except Exception as e:
    logger.exception("Error saving splits: {}", e)


def main(
  inputs,
  output_folder,
  splits: str,
  concurrency: tp.Optional[int] = None,
  single: bool = False,
  loader: str = "hf",
  strict: bool = True,
):
  """Run the preprocessing pipeline end to end.

  Loads every experiment in ``inputs`` with the registered ``loader``, sorts
  the samples by ``(dataset_idx, sample_idx)``, and resolves the train/test
  splits against the first input. It then builds the score/diversity/label
  arrays and writes them to ``output_folder``.

  ``concurrency`` and ``single`` are accepted for CLI compatibility but are not
  used by this function.
  """
  logger.info("Loading splits from {}", splits)

  load_fn = _LOADERS[loader]
  samples = load_experiments(*inputs, loader=load_fn).sort_values(
    by=["dataset_idx", "sample_idx"]
  )
  logger.info(f"Loaded {len(samples)} samples")

  logger.info("Loading splits")
  train_idx, test_idx = load_indices(inputs[0], splits, strict=strict)

  logprobs, diversity_scores, labels = extract_clm_columns(samples)

  logger.info(f"Saving to {output_folder}")
  save_experiment(
    logprobs=logprobs,
    diversity_scores=diversity_scores,
    labels=labels,
    output_folder=output_folder,
    train_idx=train_idx,
    test_idx=test_idx,
  )
  logger.info("Done")


if __name__ == "__main__":
  # Command-line entry point: parse arguments and forward them to ``main``.
  cli_parser = argparse.ArgumentParser()
  # --inputs may be repeated; every occurrence is collected into a list.
  cli_parser.add_argument("--inputs", type=str, required=True, action="append")
  cli_parser.add_argument("--output-folder", type=str, required=True)
  cli_parser.add_argument("--concurrency", type=int, default=None)
  cli_parser.add_argument("--single", action="store_true")
  cli_parser.add_argument("--loader", choices=_LOADERS.keys(), default="hf")
  cli_parser.add_argument("--no-strict-index-check", action="store_true")
  cli_parser.add_argument(
    "--splits",
    required=True,
    help="Path to npz file with keys 'train' and 'test', to define splits of the data.",
  )
  cli_args = cli_parser.parse_args()
  strict_index_check = not cli_args.no_strict_index_check
  main(
    cli_args.inputs,
    cli_args.output_folder,
    splits=cli_args.splits,
    concurrency=cli_args.concurrency,
    single=cli_args.single,
    loader=cli_args.loader,
    strict=strict_index_check,
  )
