import abc
import hashlib
import dataclasses
import functools
import multiprocessing as mp
import typing as tp

import datasets as ds
import numpy as np
import pandas as pd
from loguru import logger

from llm_inference import output_parsers, utils

# Generic type parameters. RefType/PredType are declared contravariant and
# BatchResult covariant, i.e. they are meant for input and output positions
# respectively. NOTE(review): none of these are referenced in this part of the
# file — confirm they are used elsewhere before relying on them.
RefType = tp.TypeVar("RefType", contravariant=True)
PredType = tp.TypeVar("PredType", contravariant=True)
BatchResult = tp.TypeVar("BatchResult", covariant=True)


# Bound (via forward reference by name) to `Metric`, so generic code can accept
# and return any Metric subclass while preserving the concrete type.
T = tp.TypeVar("T", bound="Metric")


class Metric(abc.ABC):
  """Base class for row-wise evaluation metrics.

  A metric is evaluated in two stages: `compute` scores every (prediction, reference)
  pair independently, and `summarize` aggregates those per-row scores into a dict of
  summary values.
  """

  # Parallelization strategies this metric can run under. `Evaluator.run` raises a
  # ValueError for any configured `execution_strategy` outside this set.
  __execution_strategies__ = {"none", "thread", "process"}

  @abc.abstractmethod
  def compute(
    self,
    preds: np.ndarray,
    refs: np.ndarray,
  ) -> pd.DataFrame:
    """Score each prediction against the reference at the same position.

    Args:
        preds: predictions, one per row.
        refs: references, aligned position-by-position with `preds`.

    Returns:
        pd.DataFrame: one row of metric values per input row, in input order.
    """

  @abc.abstractmethod
  def summarize(self, details: pd.DataFrame, index_key: str) -> dict:
    """Aggregate per-row metrics into summary values.

    Args:
        details: the DataFrame produced by `compute`, extended with an `index_key`
          column.
        index_key: column of `details` holding positions in the original dataset;
          useful for aggregations that combine several rows of the same example.

    Returns:
        dict: summary metrics.
    """


# Callable that takes one row (as a dict) and returns a value derived from it;
# used as the type of `EvaluationConfig.get_reference`.
DictAccessor = tp.Callable[[dict], tp.Any]


@dataclasses.dataclass
class EvaluationConfig:
  """How to turn raw predictions and reference-dataset rows into metric inputs."""

  metric: Metric
  """Metric to evaluate."""
  get_reference: DictAccessor
  """A function that, given one row of the reference dataset as a dict, returns the
  reference value that predictions for that row are compared against."""
  output_parser: tp.Sequence[output_parsers.OutputParser] = (
    output_parsers.passthrough_output_parser,
  )
  """Parsers (passed to `output_parsers.try_parse_output`) that turn raw predictions
  into a structured format."""
  output_parser_fallback: str = "[invalid]"
  """Passed as `fallback` to `output_parsers.try_parse_output`; presumably the value
  used as the parsed prediction when parsing fails (confirm in `output_parsers`)."""
  execution_strategy: str = "none"
  """Parallelization strategy to use; must be one of
  `metric.__execution_strategies__`."""
  batch_size: int = 32
  """Number of items to process in each batch during metric computation."""


def _as_object_array(values: tp.Iterable[tp.Any]) -> np.ndarray:
  """Pack `values` into a 1-D object array holding exactly one element per value.

  `np.array(values, dtype="O")` builds a 2-D array when every value is a sequence of the
  same length (e.g. lists of acceptable answers), which breaks single-column assignment
  and row-wise metric computation downstream.
  """
  values = list(values)
  out = np.empty(len(values), dtype="O")
  for i, value in enumerate(values):
    # Scalar-index assignment stores the object itself, even when it is a list.
    out[i] = value
  return out


@dataclasses.dataclass
class Evaluator:
  """Runs an `EvaluationConfig` over a predictions DataFrame and a reference dataset."""

  cfg: EvaluationConfig
  """Configuration for evaluation defining postprocessing of preds and refs."""
  prediction: str
  """Column which points to the predictions."""
  dataset_index_key: str = "dataset_idx"
  """Key to use to extract indices from row in input data."""
  other_index_keys: tp.Sequence[str] = ("sample_idx",)
  """Any other index keys to include in the output."""
  num_workers: int = 0
  """Number of parallel workers to use."""
  progress: bool = True
  """Whether to display a progress bar during computation."""
  extracted_ref_key: str = "extracted_ref"
  """Key to use for extracted reference in output."""
  extracted_pred_key: str = "extracted_pred"
  """Key to use for extracted prediction in output."""
  parser_error_key: str = "parse_error"
  """Key to use for parse error in output."""

  def run(
    self,
    pred_df: pd.DataFrame,
    ref_ds: ds.Dataset,
  ):
    """Parse predictions, extract references and compute detailed + summary metrics.

    Args:
        pred_df: predictions, one row per prediction. Must contain the columns
          `self.prediction`, `self.dataset_index_key` and every column in
          `self.other_index_keys`.
        ref_ds: reference dataset; `pred_df[self.dataset_index_key]` holds row
          positions into it.

    Returns:
        tuple: `(summary, details)`. `summary` is the dict returned by
        `cfg.metric.summarize`; `details` is a DataFrame with one row per row of
        `pred_df` (same order) holding the metric columns, the index columns, the
        extracted reference, the parsed prediction and the parse error (or None).

    Raises:
        ValueError: if `cfg.execution_strategy` is not supported by `cfg.metric`.
    """
    strategy = self.cfg.execution_strategy
    supported = self.cfg.metric.__execution_strategies__
    # Validate before any expensive parsing / dataset mapping so a bad config fails fast.
    if strategy not in supported:
      raise ValueError(
        f"Unknown parallelization strategy: '{strategy}'. "
        f"Should be one of {supported}"
      )

    # Resolve the worker count locally instead of overwriting `self.num_workers`, so the
    # configured value is not silently replaced for later runs of this evaluator.
    num_workers = self.num_workers
    if strategy != "none" and num_workers < 1:
      logger.debug(
        "Execution strategy is {} but num_workers is {}. Using all available CPUs.",
        strategy,
        num_workers,
      )
      num_workers = mp.cpu_count()

    preds: np.ndarray = pred_df[self.prediction].values
    pred_indices = pred_df[self.dataset_index_key].values

    # 1. Parse raw predictions.
    apply_output_parser = functools.partial(
      output_parsers.try_parse_output,
      output_parsers=self.cfg.output_parser,
      fallback=self.cfg.output_parser_fallback,
    )
    logger.debug("Parsing predictions")
    parse_results = list(
      utils.batch_imap(
        apply_output_parser,
        preds,
        num_workers=num_workers,
      )
    )
    logger.debug("Parsing complete")
    parsed_predictions = _as_object_array(p.value for p in parse_results)
    parse_errors = [
      str(p.error) if isinstance(p, output_parsers.ParseFailure) else None
      for p in parse_results
    ]

    # 2. Extract one reference per dataset row, then broadcast to prediction rows.
    logger.debug("Extracting references")
    extracted_ref_ds = ref_ds.map(
      _extract_refs,
      fn_kwargs={"get_reference": self.cfg.get_reference},
      # `datasets` rejects num_proc <= 0 (the default num_workers is 0);
      # None keeps `map` in the current process.
      num_proc=num_workers if num_workers > 0 else None,
    )
    logger.debug("References extracted")
    logger.debug("Broadcasting refs")
    broadcasted_refs = extracted_ref_ds.select(pred_indices)
    extracted_refs = _as_object_array(broadcasted_refs["ref"])

    # 3. Compute row-wise metrics, then the summary.
    logger.debug("Launching metric computation")
    details = batch_eval_metrics(
      pred_indices,
      parsed_predictions,
      extracted_refs,
      self.cfg.metric,
      batch_size=self.cfg.batch_size,
      num_workers=num_workers,
      parallel_strategy=strategy,
      progress=self.progress,
    )
    details = details.assign(**{self.dataset_index_key: pred_indices})
    # The summary sees only the metric columns plus the dataset index column.
    summary = self.cfg.metric.summarize(details, self.dataset_index_key)

    for key in self.other_index_keys:
      details[key] = pred_df[key].values

    details[self.extracted_ref_key] = extracted_refs
    details[self.extracted_pred_key] = parsed_predictions
    details[self.parser_error_key] = parse_errors

    return summary, details


def get_unique(*cols: np.ndarray):
  """Find unique rows across parallel columns.

  Row `i` is the tuple of the `i`-th element of every column. Each row is stringified,
  hashed with MD5, and the digests are deduplicated with `np.unique`.

  Args:
      *cols: equal-length arrays, combined side by side with `np.column_stack`.

  Returns:
      tuple: `(unique_idx, inverse_idx)`. `unique_idx` holds the position of the first
      occurrence of each unique row, ordered by digest rather than by position.
      `inverse_idx[i]` is the position in `unique_idx` of row `i`'s group, so
      `unique_idx[inverse_idx]` maps every row to its representative.

  Note:
      `np.column_stack` coerces all columns to a common dtype, e.g. mixing integer and
      float columns converts to float64, so integers above 2**53 may collide.
  """
  keys = np.array(
    [
      hashlib.md5(str(tuple(row)).encode()).hexdigest() for row in np.column_stack(cols)
    ],
    # `np.string_` was removed in NumPy 2.0; `np.bytes_` is the same type on all versions.
    dtype=np.bytes_,
  )

  _, unique_idx, inverse_idx = np.unique(keys, return_index=True, return_inverse=True)

  return unique_idx, inverse_idx


def _extract_refs(row: dict, *, get_reference: tp.Callable, output_key: str = "ref"):
  return {output_key: get_reference(row)}


def _metric_worker(
  df: pd.DataFrame,
  *,
  metric: Metric,
):
  """Apply `metric.compute` to one batch of rows.

  Args:
      df: batch with "pred" and "ref" columns, as built by `batch_eval_metrics`.
      metric: metric whose `compute` is called.

  Returns:
      Whatever `metric.compute` returns for this batch.
  """
  predictions = df["pred"].to_numpy()
  references = df["ref"].to_numpy()
  return metric.compute(predictions, references)


def batch_eval_metrics(
  indices: np.ndarray,
  preds: np.ndarray,
  refs: np.ndarray,
  metric: Metric,
  *,
  batch_size: int = 32,
  num_workers: int = 0,
  parallel_strategy: tp.Literal["none", "thread", "process"] = "process",
  progress=True,
) -> pd.DataFrame:
  """Compute row-wise metrics, evaluating each unique (index, prediction) pair once.

  Rows are deduplicated on ("idx", "pred") only. Rows that share an index and a
  prediction are evaluated once, with the row selected by `utils.pd_unique`. The result
  is then broadcast back so there is one output row per input row.

  Args:
      indices: dataset index of each row.
      preds: parsed prediction of each row.
      refs: reference of each row; list-valued references are converted to tuples.
      metric: metric whose `compute` is applied to the deduplicated rows.
      batch_size: batch size forwarded to `utils.map_dataframe`.
      num_workers: number of workers forwarded to `utils.map_dataframe`.
      parallel_strategy: parallelism mode forwarded to `utils.map_dataframe`.
      progress: progress-bar flag forwarded to `utils.map_dataframe`.

  Returns:
      pd.DataFrame: one row per input row, in input order, indexed 0..n-1.

  Raises:
      ValueError: if `indices`, `preds` and `refs` differ in length.
  """
  if len(indices) != len(preds) or len(preds) != len(refs):
    raise ValueError(
      f"All inputs must have the same length. Got {len(indices)}, {len(preds)}, {len(refs)}"
    )
  # ---------------------------------------------------------------------------- #
  #                    1. Create dataframe, get unique indices                   #
  # ---------------------------------------------------------------------------- #
  # Convert list refs to tuples element by element rather than inspecting refs[0]. This
  # handles empty input and leaves non-list refs (e.g. strings) untouched. Keep the
  # result a plain list: an object ndarray of equal-length tuples becomes 2-D, and
  # its `.tolist()` would turn them back into lists.
  ref_list = refs.tolist() if isinstance(refs, np.ndarray) else list(refs)
  ref_values = [tuple(r) if isinstance(r, list) else r for r in ref_list]

  df = pd.DataFrame({"idx": indices, "pred": preds, "ref": ref_values})
  unique_idx, inverse_idx = utils.pd_unique(df, ["idx", "pred"])
  # ---------------------------------------------------------------------------- #
  #                           2. Process unique indices                          #
  # ---------------------------------------------------------------------------- #
  dedup_df = df.iloc[unique_idx].drop(columns=["idx"])
  unique_results_df = utils.map_dataframe(
    dedup_df,
    functools.partial(_metric_worker, metric=metric),
    batch_size=batch_size,
    n_workers=num_workers,
    parallelism=parallel_strategy,
    progress=progress,
  )
  # ---------------------------------------------------------------------------- #
  #                         3. Return broadcasted results                        #
  # ---------------------------------------------------------------------------- #
  # `iloc[inverse_idx]` repeats the index labels of the unique rows. Reset them so the
  # result has a clean 0..n-1 index aligned positionally with the inputs.
  return unique_results_df.iloc[inverse_idx].reset_index(drop=True)
