from datetime import datetime
import functools
import json
import os
import upath

# Disable GPU usage explicitly for program execution;
# otherwise, multiprocessing TensorFlow programs may hang indefinitely.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# disable tensorflow logging
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import dataclasses

from typing import Sequence

import numpy as np
from tqdm.auto import tqdm

# import logging
from loguru import logger
import pandas as pd
import datasets as ds
import concurrent.futures
from execution import check_correctness
import argparse
import multiprocessing as mp
import itertools
import typing as tp
from tqdm.auto import tqdm

# logger = logging.getLogger(__name__)


# Generic type variables for the helpers below.
S = tp.TypeVar("S")  # element type consumed by a mapped function (see `cfut_map`)
T = tp.TypeVar("T")  # result type produced by a mapped function (see `cfut_map`)
Inputs = tp.TypeVar("Inputs")  # not referenced in the code shown here


def pd_unique(df: pd.DataFrame, subset: tp.Sequence[str]):
  """
  Return first-occurrence row positions and an inverse mapping for `df`.

  Two rows are duplicates when they agree on every column in `subset`;
  missing values compare equal to each other, as in `DataFrame.duplicated`.

  Parameters:
  df (pandas.DataFrame): Input DataFrame. Its index is ignored; positions
      are counted from 0 in row order.
  subset (sequence of str): Column names to consider for uniqueness.

  Returns:
  tuple: (unique_idx, inverse_idx)
      unique_idx: ascending positions of the first row of every distinct key.
      inverse_idx: for each row i, the position j in `unique_idx` such that
          row `unique_idx[j]` has the same key as row i. Selecting
          `unique_idx` and then `inverse_idx` reconstructs the original rows.
  """
  # Work with positions, not labels, whatever index the caller passed in.
  df = df.reset_index(drop=True)
  subset = list(subset)

  if len(df) == 0:
    empty = np.array([], dtype=np.intp)
    return empty, empty.copy()

  # First row of every distinct key, in ascending row order.
  is_duplicate = df.duplicated(subset=subset, keep="first").to_numpy()
  unique_idx = np.flatnonzero(~is_duplicate)

  # Vectorized group label per row; dropna=False keeps rows whose key
  # contains missing values in their own group instead of dropping them.
  group_ids = df.groupby(subset, sort=False, dropna=False).ngroup().to_numpy()

  # Translate group labels into positions within `unique_idx`. This does not
  # depend on the order in which pandas numbers the groups.
  position_of_group = np.empty(len(unique_idx), dtype=np.intp)
  position_of_group[group_ids[unique_idx]] = np.arange(len(unique_idx))
  inverse_idx = position_of_group[group_ids]

  return unique_idx, inverse_idx


def cfut_map(
  inputs: tp.Iterator[S],
  fn: tp.Callable[[S], T],
  num_workers: int | None = None,
):
  concurrency = num_workers or mp.cpu_count()
  with concurrent.futures.ProcessPoolExecutor(max_workers=concurrency) as executor:
    futures = []
    for i, item in enumerate(inputs):
      futures.append(executor.submit(fn, item))

    for future in concurrent.futures.as_completed(futures):
      yield future.result()


def estimate_pass_at_k(num_samples, num_correct, k):
  """Estimates pass@k of each problem and returns them in an array.

  Args:
    num_samples: either one sample count n shared by every problem (a Python
      int or a NumPy integer scalar), or a sequence of per-problem sample
      counts with the same length as `num_correct`.
    num_correct: per-problem number of correct samples c.
    k: the k in pass@k.

  Returns:
    np.ndarray of floats, one pass@k estimate per problem.
  """

  def estimator(n: int, c: int, k: int) -> float:
    """Calculates 1 - comb(n - c, k) / comb(n, k)."""
    if n - c < k:
      # Fewer than k incorrect samples: every size-k draw contains a correct
      # one. NOTE: this also returns 1.0 when k > n; callers should keep k <= n.
      return 1.0
    # Product form of comb(n - c, k) / comb(n, k), avoiding huge binomials.
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

  # NumPy integer scalars (e.g. produced by pandas reductions) are scalars too.
  if isinstance(num_samples, (int, np.integer)):
    num_samples_it = itertools.repeat(num_samples, len(num_correct))
  else:
    assert len(num_samples) == len(num_correct)
    num_samples_it = iter(num_samples)

  return np.array(
    [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
  )


def check_correctness_from_dict_inputs(
  d: dict,
  *,
  timeout: float = 500,
  test_program: str,
  completion_id: tp.Optional[int] = None,
) -> dict:
  """Run `check_correctness` on a program stored inside a dict record.

  Args:
    d: record holding the program source.
    timeout: forwarded unchanged to `check_correctness`.
    test_program: key of `d` whose value is the program that gets executed.
    completion_id: looked up as a *key* of `d`; the value stored under it
      (or None when absent) is forwarded as the completion id.

  Returns:
    The result of `check_correctness`.
  """
  program = d[test_program]
  # NOTE(review): annotated as int, yet used as a lookup key into `d` (like
  # `test_program`). Confirm which of the two is intended.
  resolved_completion_id = d.get(completion_id, None)
  return check_correctness(
    program, timeout=timeout, completion_id=resolved_completion_id
  )


def _check_correctness_w_idx(args):
  """Execute one program and tag the outcome with its position.

  Args:
    args: an ``(idx, program)`` pair, as produced by ``enumerate`` over the
      test programs of a batch.

  Returns:
    ``(idx, result)`` where ``result`` comes from
    ``check_correctness(program, timeout=120)``.
  """
  position, program = args
  outcome = check_correctness(program, timeout=120)
  return position, outcome


def evaluate_dataset(batch: dict, num_workers: int | None = None):
  """Execute every test program of a batch and return the results column-wise.

  Intended as a batched `datasets.Dataset.map` function. The programs in
  ``batch["test_program"]`` run in a process pool (via `cfut_map`), and the
  per-program result dicts are reassembled in input order.

  Args:
    batch: mapping with a ``"test_program"`` column of program sources.
    num_workers: process-pool size; None (the default) uses `mp.cpu_count()`.

  Returns:
    dict mapping every key of the result dicts to a list with one value per
    program, in the same order as ``batch["test_program"]``.
  """
  import collections

  programs = batch["test_program"]
  output_stream = cfut_map(
    enumerate(programs),
    _check_correctness_w_idx,
    num_workers=num_workers or mp.cpu_count(),
  )

  verified_results = []
  num_correct = 0
  bar = tqdm(output_stream, total=len(programs))
  for idx, result in bar:
    verified_results.append((idx, result))
    if result["passed"]:
      num_correct += 1
    bar.set_postfix({"# correct": num_correct})

  # cfut_map yields in completion order; restore input order so the returned
  # columns line up with the rows of the batch.
  verified_results.sort(key=lambda pair: pair[0])
  columnar_results = collections.defaultdict(list)
  for _, result in verified_results:
    for key, value in result.items():
      columnar_results[key].append(value)

  return dict(columnar_results)


@dataclasses.dataclass
class DS1000Simple:
  """DS-1000 task wrapper: loads the dataset and evaluates code predictions.

  A prediction is evaluated by appending it (as the string ``code``) to its
  problem's ``code_context`` and calling the harness functions defined there.
  """

  dataset_path: str = "xlangai/DS-1000"
  dataset_split: str = "test"

  def load_dataset(self):
    """Load the configured split with `datasets.load_dataset`."""
    return ds.load_dataset(self.dataset_path, split=self.dataset_split)

  @property
  def stop_tokens(self):
    """Stop strings for generation (not referenced elsewhere in this file)."""
    return ["# SOLUTION END", "</code>"]

  @property
  def output_keys(self):
    """Task columns of interest (not referenced elsewhere in this file)."""
    return ["prompt", "reference_code", "code_context", "metadata"]

  def get_reference_solutions(self, example: dict):
    """Return the reference solutions of `example` as a one-element list."""
    return [example["reference_code"]]

  def get_reference_solution(self, example: dict):
    """Return the first (and only) reference solution of `example`."""
    return self.get_reference_solutions(example)[0]

  def prepare_inputs(self, raw_example: dict) -> dict:
    """Identity hook: examples are used exactly as loaded."""
    return raw_example

  def evaluate(
    self,
    indices: Sequence[int],
    predictions: Sequence[str],
    *,
    num_workers: int | None = None,
  ) -> ds.Dataset:
    """Execute `predictions` against the DS-1000 test harness.

    Args:
      indices: dataset row index of each prediction; repeated indices are
        allowed (e.g. several samples per problem).
      predictions: generated code, aligned with `indices`.
      num_workers: number of processes used to build and to execute the test
        programs; None uses `mp.cpu_count()`.

    Returns:
      A dataset with one row per prediction, in input order, holding the
      ``dataset.*`` task columns, ``dataset_idx``, ``prediction``,
      ``test_program`` and the columns returned by `check_correctness`.

    Raises:
      ValueError: if the selected task rows do not line up with `indices`.
    """
    num_proc = num_workers or mp.cpu_count()

    logger.info("Creating dataset")
    predictions_ds = ds.Dataset.from_dict(
      {"dataset_idx": indices, "prediction": predictions}
    )
    logger.info("Loading task dataset")
    dataset = self.load_dataset()
    # Remember each task row's original position so the selection can be
    # cross-checked against `indices` after merging.
    dataset = dataset.map(
      lambda _, indices: {"idx": indices}, with_indices=True, batched=True
    )
    # Prefix task columns so they stay distinct from the prediction columns.
    dataset = dataset.rename_columns({k: f"dataset.{k}" for k in dataset.column_names})
    dataset = dataset.select(indices)
    merged_ds = ds.concatenate_datasets([dataset, predictions_ds], axis=1)
    mismatches = merged_ds.filter(
      lambda example: example["dataset.idx"] != example["dataset_idx"]
    )
    if len(mismatches) > 0:
      raise ValueError(f"Mismatched dataset indices ({len(mismatches)} mismatches)")

    def create_test_program(batch: dict):
      # Append the candidate code as a string literal plus the harness calls;
      # test_string is only called when the context defines it.
      test_programs = []
      for code_context, prediction in zip(
        batch["dataset.code_context"], batch["prediction"]
      ):
        test_programs.append(
          code_context
          + "\n"
          + f"code = {repr(prediction)}\n"
          + "test_execution(code)\n"
          + ("test_string(code)\n" if "test_string(" in code_context else "\n")
        )
      return {"test_program": test_programs}

    merged_ds = merged_ds.map(
      create_test_program,
      batched=True,
      num_proc=num_proc,
      desc="Creating test programs",
    )

    # Execute each distinct (dataset_idx, test_program) pair only once; the
    # results are broadcast back to every original row at the end.
    test_programs_df = merged_ds.select_columns(
      ["dataset_idx", "test_program"]
    ).to_pandas()
    unique_idx, inverse_idx = pd_unique(
      test_programs_df, ["dataset_idx", "test_program"]
    )

    unique_ds = merged_ds.select(unique_idx)
    logger.info(
      "Number of unique test programs: {} / {} ({}%)",
      len(unique_ds),
      len(merged_ds),
      len(unique_ds) / len(merged_ds) * 100,
    )

    unique_ds = unique_ds.map(
      evaluate_dataset,
      batched=True,
      batch_size=len(unique_ds),  # a single in-memory batch; evaluate_dataset parallelizes
      fn_kwargs={"num_workers": num_workers},
      desc="Evaluating test programs",
      num_proc=1,
    )

    broadcasted_ds = unique_ds.select(inverse_idx)

    return broadcasted_ds


_PASS_AT_K_VALUES = (1, 5, 10, 20, 40, 80, 100)


def _validate_reference_solutions(task: "DS1000Simple", num_workers: int | None = None):
  """Evaluate the task's reference solutions and raise if any of them fails."""
  logger.info("Validating reference solutions before evaluating predictions.")
  refs = [task.get_reference_solution(ex) for ex in task.load_dataset()]
  indices = np.arange(len(refs))
  ref_results_ds = task.evaluate(indices, refs, num_workers=num_workers)
  ref_results = ref_results_ds.select_columns(["dataset_idx", "passed"]).to_pandas()

  if ref_results["passed"].sum() != len(refs):
    incorrect_solutions = ref_results[~ref_results["passed"]]
    logger.error(
      f"Reference solutions are incorrect for the following examples:\n{incorrect_solutions.head().to_markdown()}"
    )
    raise ValueError(
      f"Reference solutions are incorrect ({len(incorrect_solutions)} / {len(refs)} failures). Aborting evaluation."
    )
  logger.info("Reference evaluation complete. All solutions are correct.")


def _pass_at_k_summary(results_df: pd.DataFrame) -> dict:
  """Mean pass@k over problems, for every k that all problems have samples for."""
  per_problem = results_df.groupby("dataset_idx")["passed"]
  samples_per_problem = per_problem.size().to_numpy()
  correct_per_problem = per_problem.sum().to_numpy()
  # pass@k needs k <= n for every problem, so bound k by the smallest count.
  num_samples = int(samples_per_problem.min()) if len(samples_per_problem) else 0

  pass_at_k_results = {}
  for k in _PASS_AT_K_VALUES:
    if k > num_samples:
      continue
    pass_at_k_results[f"pass@{k}"] = estimate_pass_at_k(
      samples_per_problem, correct_per_problem, k=k
    ).mean()
  return pass_at_k_results


def _resolve_output_path(input_file: str, output_file: str, overwrite: bool):
  """Place `output_file` next to `input_file`; timestamp it rather than clobber."""
  output_path = upath.UPath(input_file).parent / output_file
  logger.info("Saving evaluation results to {}", output_path.as_posix())
  if output_path.exists() and not overwrite:
    logger.warning(
      "Output file {} already exists. Use --overwrite to overwrite.",
      output_path.as_posix(),
    )
    # `datetime` is the class imported via `from datetime import datetime`.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = output_path.with_name(
      output_path.stem + f"{timestamp}" + output_path.suffix
    )
    logger.warning("Saving to {} instead.", output_path)
  return output_path


def run_evaluation(
  input_file: str,
  output_file: str,
  save: bool = True,
  overwrite: bool = False,
  validate_refs: bool = True,
  num_workers: int | None = None,
):
  """Evaluate model predictions on DS-1000 and optionally save the results.

  Args:
    input_file: parquet file of predictions with at least the columns
      ``dataset_idx``, ``generated_text`` and ``sample_idx``.
    output_file: results file name, resolved relative to the directory of
      `input_file`; written as JSON with ``eval_results`` and ``aux`` keys.
    save: write the results file when True.
    overwrite: replace an existing results file; otherwise a timestamped
      file name is used instead.
    validate_refs: first check that every reference solution passes.
    num_workers: forwarded to `DS1000Simple.evaluate`.

  Returns:
    None. pass@k metrics are printed (and saved when `save` is set).

  Raises:
    ValueError: if a reference solution fails, or if the evaluated rows do
      not line up with the input predictions.
  """
  task = DS1000Simple()
  if validate_refs:
    _validate_reference_solutions(task, num_workers=num_workers)

  logger.info("Loading model predictions.")
  predictions_df = pd.read_parquet(input_file)
  logger.info("Starting evaluation of model predictions.")
  results_ds = task.evaluate(
    predictions_df["dataset_idx"],
    predictions_df["generated_text"],
    num_workers=num_workers,
  )

  # Sanity check: results must come back one-to-one, in prediction order.
  for generated_text, prediction in tqdm(
    zip(predictions_df["generated_text"], results_ds["prediction"]),
    total=len(predictions_df),
    desc="Validating eval results against predictions",
  ):
    if generated_text != prediction:
      raise ValueError("Mismatch between predictions and evaluation results.")

  results_df = tp.cast(pd.DataFrame, results_ds.to_pandas())
  # Assign positionally: the parquet file may carry a non-default index, and
  # label alignment would scramble or NaN-fill sample_idx.
  results_df = results_df.assign(sample_idx=predictions_df["sample_idx"].to_numpy())

  logger.info("Evaluation complete.")
  pass_at_k_results = _pass_at_k_summary(results_df)
  print(pass_at_k_results)

  if save:
    output_path = _resolve_output_path(input_file, output_file, overwrite)
    with tp.cast(upath.UPath, output_path).open("w") as f:
      json.dump(
        {"eval_results": pass_at_k_results, "aux": results_df.to_dict("records")}, f
      )
    logger.info("Saved evaluation results to {}", output_path)


if __name__ == "__main__":
  # Command-line entry point: parse options and run the DS-1000 evaluation.
  parser = argparse.ArgumentParser(description="Evaluate model predictions on a task.")
  parser.add_argument(
    "--inputs",
    type=str,
    help="Path to model predictions (parquet with dataset_idx, sample_idx and generated_text columns)",
    required=True,
  )
  parser.add_argument("--save", action="store_true", help="Save evaluation results.")
  parser.add_argument(
    "--overwrite",
    action="store_true",
    help="Overwrite existing evaluation results.",
  )
  parser.add_argument(
    "--output_file",
    type=str,
    help="Where to store eval results (JSON file, resolved relative to the directory of --inputs).",
    default="eval.json",
  )
  parser.add_argument(
    "--no-validate-refs",
    action="store_true",
    help="If set, will not validate reference solutions before evaluating predictions",
  )
  parser.add_argument(
    "--num-workers", type=int, help="Number of workers to use", default=None
  )

  args = parser.parse_args()
  print(args.inputs)
  print(args.output_file)
  run_evaluation(
    input_file=args.inputs,
    output_file=args.output_file,
    save=args.save,
    overwrite=args.overwrite,
    validate_refs=not args.no_validate_refs,
    num_workers=args.num_workers,
  )
