import collections
import dataclasses
import typing as tp
import multiprocessing as mp
from multiprocessing.dummy import Pool
import datasets as ds
import numpy as np
from tqdm.auto import tqdm

from uncertainty_for_programs.metrics.apps.apps_metric import TIMEOUT, check_correctness


def evaluate_sample(ref_sample: dict, completion: str, debug: bool = False):
    """Run one completion against the unit tests of its reference sample.

    The actual checking is delegated to ``check_correctness``; numpy values in
    its output are converted to plain Python values before being returned.

    Args:
        ref_sample: reference sample from APPS dataset
        completion: code completion
        debug: forwarded to ``check_correctness``; when set, any exception
            raised while checking or converting the results is also printed
    Returns:
        results: list of results, [-2] = compile error, [-1] = runtime error [False] = failed test case [True] = passed test case
    """

    def _as_python(value):
        # Arrays are reduced to their first element; numpy booleans become bool.
        if isinstance(value, np.ndarray):
            value = value.item(0)
        return bool(value) if isinstance(value, np.bool_) else value

    # Returned as-is when check_correctness raises before producing anything.
    outcome = [-2]
    try:
        outcome = check_correctness(
            ref_sample, completion, timeout=TIMEOUT, debug=debug
        )
        # Rebound only after every entry converted: a failure part-way through
        # leaves the raw check_correctness output in place.
        outcome = [_as_python(entry) for entry in outcome]
    except Exception as e:
        if debug:
            print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")
    assert isinstance(outcome, list)

    return outcome


@dataclasses.dataclass
class EvaluationTask:
    """One unit of evaluation work: a single completion for a single dataset sample."""

    # Reference sample; passed to evaluate_sample as ref_sample.
    sample: dict
    # Code completion to evaluate against the sample.
    completion: str
    # Position of the sample in the dataset; used to group results per problem.
    dataset_idx: int
    # Position of the completion among the sample's generations; used to order results.
    sample_idx: int


@dataclasses.dataclass
class EvaluationResult:
    """Outcome of evaluating one EvaluationTask."""

    # The task these results belong to.
    task: EvaluationTask
    # Return value of evaluate_sample for the task's sample and completion.
    results: list


def _evaluate_sample_wrapper(task: EvaluationTask):
    """Evaluate a single task and bundle its outcome with the task itself.

    Single-argument adapter around ``evaluate_sample`` (called with its
    default ``debug``), so it can be mapped over a list of tasks.
    """
    outcome = evaluate_sample(task.sample, task.completion)
    return EvaluationResult(task=task, results=outcome)


def evaluate_generations(generations: list, dataset: ds.Dataset, debug: bool = False):
    """Sequentially evaluate every generation of every problem.

    Each generation is compiled and run against the unit tests of its problem,
    which are retrieved from the APPS dataset.
    Args:
        generations: list of code generations (same order as samples in APPS dataset)
        dataset: samples the generations belong to; ``dataset[i]`` is evaluated against ``generations[i]``
        debug: forwarded to ``evaluate_sample`` for every generation
    Returns:
        results: dictionary of results, key is the problem index, value is a list of results for each generation
        [-2] = compile error, [-1] = runtime error [False] = failed test case [True] = passed test case
    """
    results = {}
    progress = tqdm(enumerate(generations), total=len(generations))
    for index, problem_generations in progress:
        sample = dataset[index]
        # One result list per generation, in generation order.
        results[index] = [
            evaluate_sample(sample, generation, debug)
            for generation in problem_generations
        ]
    return results


def evaluate_generations_parallel(
    generations: list[list[str]],
    dataset: tp.Sequence[dict],
    debug: bool = False,
    prepend_starter_code: bool = False,
):
    """Evaluate all generations concurrently with a pool of worker threads.

    We take the list of code generations and try to compile them and then run
    their corresponding unit tests which are retrieved from the APPS dataset.
    Every (sample, generation) pair becomes one task; running totals are shown
    on the progress bar while tasks complete.

    Args:
        generations: list of code generations (same order as samples in APPS dataset)
        dataset: corresponding samples from the dataset
        debug: forwarded to ``evaluate_sample`` for every generation
        prepend_starter_code: if True, each generation is prefixed with the
            ``"starter_code"`` field of its sample before evaluation
    Returns:
        results: dictionary of results, key is the problem index, value is a list of results for each generation
        (ordered by generation index)
        [-2] = compile error, [-1] = runtime error [False] = failed test case [True] = passed test case

    Note:
        ``dataset`` and ``generations`` are paired with ``zip``, so pairing
        stops at the shorter of the two. A problem with no generations yields
        no tasks and therefore has no key in the returned dictionary.
    """
    if prepend_starter_code:
        generations = [
            [example["starter_code"] + sample for sample in samples]
            for (example, samples) in zip(dataset, generations)
        ]
    tasks = [
        EvaluationTask(
            sample=example,
            completion=sample,
            dataset_idx=dataset_idx,
            sample_idx=sample_idx,
        )
        for dataset_idx, (example, samples) in enumerate(zip(dataset, generations))
        for sample_idx, sample in enumerate(samples)
    ]

    def _run_task(task: EvaluationTask) -> EvaluationResult:
        # Local closure so the caller's `debug` flag actually reaches
        # evaluate_sample. This is fine with the thread-based pool from
        # multiprocessing.dummy; a process pool would need a picklable callable.
        return EvaluationResult(
            task=task,
            results=evaluate_sample(task.sample, task.completion, debug),
        )

    print("Starting evaluation...")
    results_by_dataset_idx: collections.defaultdict[int, list[EvaluationResult]] = (
        collections.defaultdict(list)
    )
    total_test_cases_passed = 0
    compilation_errors = 0
    total_test_cases_tested = 0
    with Pool(mp.cpu_count()) as pool:
        # imap is lazy: the results must be consumed while the pool is open.
        bar = tqdm(pool.imap(_run_task, tasks), total=len(tasks))
        for task_result in bar:
            results_by_dataset_idx[task_result.task.dataset_idx].append(task_result)
            compilation_errors += sum(1 for r in task_result.results if r == -2)
            # `is True` on purpose: only genuine booleans count as passed,
            # never truthy status codes.
            total_test_cases_passed += sum(
                1 for r in task_result.results if r is True
            )
            total_test_cases_tested += len(task_result.results)
            bar.set_description(f"Task {task_result.task.dataset_idx}")
            bar.set_postfix(
                passed=total_test_cases_passed,
                tested=total_test_cases_tested,
                compilation_errors=compilation_errors,
            )

    # Regroup per problem, restoring generation order within each problem.
    return {
        dataset_idx: [
            result.results
            for result in sorted(task_results, key=lambda x: x.task.sample_idx)
        ]
        for dataset_idx, task_results in results_by_dataset_idx.items()
    }
