import collections
import dataclasses
import pandas as pd
import itertools
import typing as tp

import numpy as np

from llm_inference import eval_utils
from llm_inference.metrics.code_eval.execution import check_correctness


def estimate_pass_at_k(num_samples, num_correct, k):
  """Estimates pass@k of each problem and returns them in an array.

  For a problem with ``n`` samples of which ``c`` are correct, the estimate
  is ``1 - comb(n - c, k) / comb(n, k)``.

  Args:
    num_samples: Either a single integer (Python ``int`` or NumPy integer
      scalar) that is used as the sample count of every problem, or a sized
      iterable holding one sample count per problem.
    num_correct: Sized iterable holding the number of correct samples of
      each problem.
    k: The ``k`` of pass@k.

  Returns:
    A NumPy array with one estimate per entry of ``num_correct``.

  Raises:
    ValueError: If ``num_samples`` is given per problem and its length
      differs from the length of ``num_correct``.
  """

  def estimator(n: int, c: int, k: int) -> float:
    """Calculates 1 - comb(n - c, k) / comb(n, k)."""
    # Fewer than k incorrect samples: every size-k draw contains a correct
    # one. NOTE: this branch is also taken whenever k > n, so callers that
    # want a meaningful estimate must make sure k <= n themselves.
    if n - c < k:
      return 1.0
    # Product form of the binomial ratio; avoids computing large
    # binomial coefficients explicitly.
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

  # NumPy integer scalars (e.g. the result of ``arr.max()``) are not ``int``
  # instances and have no ``len()``, so they are accepted explicitly here.
  if isinstance(num_samples, (int, np.integer)):
    num_samples_it = itertools.repeat(num_samples, len(num_correct))
  else:
    # Raised explicitly instead of asserted: asserts vanish under ``-O`` and
    # ``zip`` below would then silently truncate to the shorter input.
    if len(num_samples) != len(num_correct):
      raise ValueError(
        "num_samples and num_correct must have the same length, got "
        f"{len(num_samples)} and {len(num_correct)}"
      )
    num_samples_it = iter(num_samples)

  return np.array(
    [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
  )


@dataclasses.dataclass
class CodeEval(eval_utils.Metric):
  """Metric that runs generated code against references and reports pass@k.

  ``compute`` executes every prediction/reference pair through
  ``check_correctness`` and records the outcome; ``summarize`` aggregates
  those outcomes per problem into ``pass@k`` scores.
  """

  # Forwarded to check_correctness as ``timeout`` for each executed program
  # (presumably seconds -- TODO confirm against check_correctness).
  timeout: float = 3.0
  # Values of k for which summarize() reports a ``pass@k`` entry.
  ks: tp.Sequence[int] = (1, 5, 10, 40, 80, 100)

  # NOTE(review): not a dataclass field (no annotation); presumably read by
  # the eval_utils framework -- confirm against eval_utils.Metric.
  __execution_strategies__ = ("thread", "none")

  def compute(
    self,
    preds: np.ndarray,
    refs: np.ndarray,
  ) -> pd.DataFrame:
    """Executes each prediction together with its reference.

    Args:
      preds: Predictions; each is concatenated with its reference, so they
        are presumably source-code strings (TODO confirm).
      refs: References, paired positionally with ``preds`` and appended
        after the prediction, separated by a newline.

    Returns:
      A DataFrame with one row per pair and the columns ``passed`` and
      ``output``, taken from the ``"passed"`` and ``"result"`` entries of
      the ``check_correctness`` result.
    """
    results = collections.defaultdict(list)
    # NOTE: zip stops at the shorter input, so surplus preds/refs are ignored.
    for pred, ref in zip(preds, refs):
      result = check_correctness(
        pred + "\n" + ref, timeout=self.timeout, task_id=0, completion_id=0
      )
      results["passed"].append(result["passed"])
      results["output"].append(result["result"])
    return pd.DataFrame(results)

  def summarize(self, details: pd.DataFrame, index_key: str) -> tp.Dict[str, float]:
    """Aggregates per-sample outcomes into pass@k scores.

    Args:
      details: Per-sample results containing a ``passed`` column and the
        ``index_key`` column.
      index_key: Name of the column that identifies the problem a sample
        belongs to; rows sharing a value are samples of the same problem.

    Returns:
      A dict mapping ``"pass@{k}"`` to the mean estimate over all problems
      that have at least ``k`` samples. A ``k`` for which no problem has
      enough samples is omitted; empty ``details`` yield an empty dict.
    """
    # Nothing to aggregate; also avoids reducing over an empty array below.
    if details.empty:
      return {}

    grouped = details.groupby(index_key)
    num_samples = grouped.size().to_numpy()
    num_correct = grouped["passed"].sum().to_numpy()

    summary = {}
    for k in self.ks:
      # pass@k is only defined for problems with at least k samples. For
      # n < k the estimator returns 1.0 regardless of correctness, which
      # would inflate the mean, so such problems are left out.
      eligible = num_samples >= k
      if not eligible.any():
        continue
      summary[f"pass@{k}"] = estimate_pass_at_k(
        num_samples[eligible], num_correct[eligible], k
      ).mean()

    return summary
