import collections
import dataclasses
import fcntl
import importlib.util
import io
import pathlib
import sys
import warnings
import zipfile
from typing import Sequence

import requests
from datasets.arrow_dataset import Dataset
from tqdm.auto import tqdm
from llm_inference.tasks.task import Task
import logging

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class EvaluationTask:
    """One generated solution queued for testing against a dataset problem."""

    # Position of this generation among the generations for the same problem.
    sample_idx: int
    # Index of the problem in the loaded dataset.
    dataset_idx: int
    # Generated code that will be handed to the problem's test routine.
    generation: str


@dataclasses.dataclass
class EvalResult(EvaluationTask):
    """An evaluation task together with the outcome of testing it."""

    # Outcome of running the problem's tests on the generation.
    passed: bool


@dataclasses.dataclass
class DS1000Task(Task):
    """Task for the DS-1000 benchmark.

    Downloads a pinned copy of the DS-1000 loader script and data archive on
    demand, exposes the problems as a ``Dataset`` and scores generations with
    pass@k.
    """

    # Directory that holds the downloaded files; created in __post_init__.
    root_dir: pathlib.Path = pathlib.Path(".ds1000")
    # File name of the DS-1000 loader script; __post_init__ rewrites it to a
    # path under root_dir.
    _src: pathlib.Path = pathlib.Path("ds1000.py")
    # Name of the data directory expected under root_dir once the archive has
    # been extracted; __post_init__ rewrites it to a path under root_dir.
    _data: pathlib.Path = pathlib.Path("ds1000_data")
    # Set to True by load_dataset after both download steps have run, so later
    # calls on this instance skip them.
    _download_complete: bool = False

    @property
    def stop_tokens(self):
        return ["# SOLUTION END", "</code>"]

    def __post_init__(self):
        self._src = self.root_dir / self._src
        self._data = self.root_dir / self._data
        self.root_dir.mkdir(exist_ok=True, parents=True)

    def _download_source(self):
        """Download the pinned ``ds1000.py`` next to the data if it is missing.

        An exclusive ``flock`` on a sibling ``.lock`` file serialises
        concurrent callers; whoever gets the lock after the file exists skips
        the download. An empty ``__init__.py`` is created in the same
        directory.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            requests.RequestException: On connection problems or timeouts.
        """
        url = "https://github.com/HKUNLP/DS-1000/blob/49c1c543ada8b58138181333cdc62e613204efcf/ds1000.py?raw=true"
        lock = self._src.with_suffix(".lock")
        with open(lock, "w") as f_lock:
            fcntl.flock(f_lock, fcntl.LOCK_EX)
            try:
                if not self._src.exists():
                    warnings.warn(f"DS-1000 source is being saved to {self._src}.")
                    print("Downloading source code...")
                    r = requests.get(url, timeout=60)
                    # Without this check an HTTP error page would be written
                    # to ds1000.py and then be treated as a valid cached copy
                    # on every later run.
                    r.raise_for_status()
                    # Write to a temporary name and rename afterwards so an
                    # interrupted download never leaves a truncated ds1000.py
                    # that the exists() check above would accept.
                    tmp_path = self._src.with_name(self._src.name + ".tmp")
                    tmp_path.write_bytes(r.content)
                    # touch() creates the file when missing and leaves an
                    # existing one untouched.
                    (self._src.parent / "__init__.py").touch()
                    tmp_path.replace(self._src)
                    print("Done.")
            finally:
                # Release explicitly even when the download raised.
                fcntl.flock(f_lock, fcntl.LOCK_UN)

    def _download_dataset(self):
        """Download and extract the pinned DS-1000 data archive if it is missing.

        An exclusive ``flock`` on a sibling ``.lock`` file serialises
        concurrent callers; whoever gets the lock after ``_data`` exists skips
        the download. The archive is extracted into ``root_dir``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            requests.RequestException: On connection problems or timeouts.
            zipfile.BadZipFile: If the downloaded payload is not a zip archive.
        """
        url = "https://github.com/HKUNLP/DS-1000/blob/49c1c543ada8b58138181333cdc62e613204efcf/ds1000_data.zip?raw=true"
        lock = self._data.with_suffix(".lock")
        with open(lock, "w") as f_lock:
            fcntl.flock(f_lock, fcntl.LOCK_EX)
            try:
                if not self._data.exists():
                    warnings.warn(f"DS-1000 data is being saved to {self._data}.")
                    print("Downloading dataset...")
                    r = requests.get(url, timeout=60)
                    # Fail with a clear HTTP error instead of handing an error
                    # page to ZipFile.
                    r.raise_for_status()
                    # The context manager closes the archive once extraction
                    # is done.
                    with zipfile.ZipFile(io.BytesIO(r.content)) as z:
                        z.extractall(self.root_dir)
                    print("Done.")
            finally:
                # Release explicitly even when download or extraction raised.
                fcntl.flock(f_lock, fcntl.LOCK_UN)

    @property
    def output_keys(self):
        return ["ds1000"]

    def _load_dataset_cls(self):
        module_dir = str(self._src.parent)
        if module_dir not in sys.path:
            sys.path.append(module_dir)

        spec = importlib.util.spec_from_file_location("ds1000", str(self._src))
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "DS1000Dataset")

    def gen_ds1000(self, raw_data):
        keys = list(raw_data.keys())
        keys = sorted(keys)
        for key in keys:
            for problem in raw_data[key]:
                yield {
                    "problem_id": f"{key}-{problem.problem_id}",
                    "prompt": problem["prompt"],
                    "test_code": problem["test_code"],
                    "reference_code": problem["reference_code"],
                }

    def load_raw_data(self):
        DS1000 = self._load_dataset_cls()
        return DS1000(self._data, mode="Completion")

    def load_dataset(self, **kwargs) -> Dataset:
        """Return the DS-1000 problems as a ``Dataset`` of flat records.

        On the first call for this instance the source script and the data
        are downloaded. The problems are then loaded in ``"Completion"`` mode,
        flattened with ``gen_ds1000`` and wrapped via
        ``Dataset.from_generator``.

        Args:
            **kwargs: Accepted but not used by this implementation.
        """
        needs_download = not self._download_complete
        if needs_download:
            self._download_source()
            self._download_dataset()
            self._download_complete = True

        dataset_cls = self._load_dataset_cls()
        problems_by_lib = dataset_cls(self._data, mode="Completion").data
        data = [*self.gen_ds1000(problems_by_lib)]

        # NOTE(review): the inner generator is deliberately left as it was;
        # from_generator may fingerprint the function for caching - confirm
        # before renaming anything inside it.
        def generator():
            for example in data:
                yield example

        return Dataset.from_generator(generator)

    def get_reference_solution(self, example: dict) -> str:
        return example["reference_code"]

    def get_reference_solutions(self, example: dict) -> list[str]:
        return [example["reference_code"]]

    def prepare_inputs(self, raw_example: dict) -> dict:
        return raw_example

    def evaluate_in_process(self, task: EvaluationTask) -> EvalResult:
        """Test a single generation and report whether it passed.

        The problem is looked up through the ``problem_id`` of the dataset row
        at ``task.dataset_idx`` (format ``"<library>-<index>"``) and the
        generation is handed to that problem's ``test`` method.

        Args:
            task: The generation to test and the indices identifying it.

        Returns:
            An ``EvalResult`` that copies the fields of ``task`` and adds
            ``passed``. ``passed`` is ``False`` when testing raised.
        """
        # Load the dataset first: load_dataset() performs the on-demand
        # download, so the raw-data load below never runs against a cache
        # that has not been populated yet.
        # NOTE(review): both loads are repeated on every call; consider
        # caching per worker process if this shows up in profiles.
        dataset = self.load_dataset()
        raw_data = self.load_raw_data()

        problem_id = dataset[task.dataset_idx]["problem_id"]
        # Split on the last hyphen only, so a library key that itself contains
        # "-" still separates cleanly from the numeric index.
        problem_lib, problem_idx = problem_id.rsplit("-", 1)
        problem = raw_data[problem_lib][int(problem_idx)]
        try:
            # bool() keeps `passed` a plain bool whatever truthy/falsy value
            # the test routine returns.
            passed = bool(problem.test(task.generation))
        except Exception as e:
            # Deliberately broad: generated code can fail in arbitrary ways
            # and must count as "not passed" instead of aborting the run.
            # logger.exception keeps the traceback for debugging.
            logger.exception("Error while testing %s: %s", problem_id, e)
            passed = False
        return EvalResult(
            sample_idx=task.sample_idx,
            dataset_idx=task.dataset_idx,
            generation=task.generation,
            passed=passed,
        )

    def evaluate(self, indices: Sequence[int], predictions: Sequence[str]):
        predictions_by_idx = collections.defaultdict(list)
        for i, p in zip(indices, predictions):
            predictions_by_idx[i].append(p)

        # raw_data = self.load_raw_data()
        # dataset = self.load_dataset()
        # index_to_problem_id = {i: dataset[i]["problem_id"] for i in indices}

        results = collections.defaultdict(list)
        # PARALLEL
        import multiprocessing as mp

        with mp.Pool(mp.cpu_count()) as pool:
            tasks = [
                EvaluationTask(
                    sample_idx=sample_idx,
                    dataset_idx=dataset_idx,
                    generation=generation,
                )
                for dataset_idx, generations in predictions_by_idx.items()
                for sample_idx, generation in enumerate(generations)
            ]

            bar = tqdm(pool.imap(self.evaluate_in_process, tasks), total=len(tasks))

            total_tested = 0
            num_correct = 0
            for task_result in bar:
                total_tested += 1
                if task_result.passed:
                    num_correct += 1
                bar.set_postfix(num_correct=num_correct, total_tested=total_tested)
                results[task_result.dataset_idx].append(task_result)

        sorted_results = {}
        for dataset_idx, sample_results in results.items():
            sorted_results_for_sample = sorted(
                sample_results, key=lambda x: x.sample_idx
            )
            sorted_results[dataset_idx] = [
                result.passed for result in sorted_results_for_sample
            ]
        del results
        # SEQUNTIAL
        # bar = tqdm(predictions_by_idx.items(), total=len(predictions_by_idx))
        # total_tested = 0
        # num_correct = 0
        # for i, preds in bar:
        #     problem_id = index_to_problem_id[i]
        #     problem_lib, problem_idx = problem_id.split("-")
        #     problem_idx = int(problem_idx)
        #     problem_results = []
        #     problem = raw_data[problem_lib][problem_idx]

        #     for sample_idx, pred in enumerate(preds):
        #         test_result = problem.test(pred)
        #         total_tested += 1
        #         if test_result:
        #             num_correct += 1
        #         bar.set_postfix(num_correct=num_correct, total_tested=total_tested)
        #         problem_results.append(test_result)
        #     results[i].append(problem_results)

        from llm_inference.metrics.apps.apps_metric import estimate_pass_at_k

        ref_key = next(iter(sorted_results.keys()))
        assert all(
            len(sorted_results[ref_key]) == len(sorted_results[k])
            for k in sorted_results
        ), "All problems must have the same number of samples to calculate pass@k"
        num_samples = len(sorted_results[ref_key])
        k_list = [1, 5, 10, 20, 40, 80, 100]
        k_list = [k for k in k_list if k <= num_samples]
        pass_at_k_results = {}
        num_correct = [sum(sorted_results[k]) for k in sorted_results]
        for k in k_list:
            pass_at_k_results[f"pass@{k}"] = estimate_pass_at_k(
                num_samples, num_correct, k=k
            ).mean()
        return pass_at_k_results, sorted_results
