import collections
import dataclasses
import itertools
import json
import os
import typing as tp

import numpy as np

from uncertainty_for_programs import infer_lib, registry, utils
from uncertainty_for_programs import prompts as PROMPTS

logger = utils.getLogger(__name__)


@dataclasses.dataclass
class Apps(infer_lib.HFTask):
    """APPS code-generation task backed by a Hugging Face dataset.

    ``level`` is used in two places: it is passed to the dataset loader as
    the default ``difficulties`` filter, and it is passed to the APPS
    metric during evaluation.
    """

    # Hugging Face dataset identifier.
    dataset_path: str = "codeparrot/apps"
    # Dataset split to load.
    dataset_split: str = "test"
    # Difficulty filter; the registered tasks use "all", "competition",
    # "introductory" or "interview".
    level: str = "all"

    @property
    def output_keys(self):
        """Names of the dataset fields kept by ``prepare_inputs``.

        ``starter_code`` is kept in addition to these.
        """
        return ["question", "input_output"]

    @property
    def default_prompt(self):
        """Default prompt for this task (``PROMPTS.apps_instruct``)."""
        return PROMPTS.apps_instruct

    def prepare_inputs(self, raw_example: dict) -> dict:
        """Keep only the fields of interest from a raw dataset row.

        Args:
            raw_example: One row of the dataset.

        Returns:
            A new dict with the entries of ``raw_example`` whose key is in
            ``output_keys`` or is ``"starter_code"``. Keys missing from the
            row are simply absent, and insertion order follows
            ``raw_example``.
        """
        # Build the allow-list once instead of re-concatenating lists per key.
        wanted = set(self.output_keys)
        wanted.add("starter_code")
        return {k: v for k, v in raw_example.items() if k in wanted}

    def load_dataset(self, **kwargs):
        """Load the dataset, filtering by ``self.level`` unless overridden.

        If the caller passes ``difficulties`` explicitly, that value is used.
        Otherwise ``[self.level]`` is forwarded to the parent loader.
        """
        kwargs.setdefault("difficulties", [self.level])
        return super().load_dataset(**kwargs)

    def evaluate(self, indices: tp.Sequence[int], predictions: tp.Sequence[str]):
        """Score predictions with the APPS metric.

        Predictions are grouped by problem index. The groups are passed to
        the metric in ascending index order, one list of candidate programs
        per problem.

        Args:
            indices: Problem index for each prediction. Repeats are allowed,
                e.g. several samples for the same problem.
            predictions: Generated programs, aligned element-wise with
                ``indices``.

        Returns:
            The result of ``apps.compute_metrics``.

        Raises:
            ValueError: If ``indices`` and ``predictions`` differ in length.
                Without this check, ``zip`` would silently drop the
                unmatched tail.
        """
        if len(indices) != len(predictions):
            raise ValueError(
                f"Got {len(indices)} indices but {len(predictions)} predictions."
            )
        predictions_by_idx = collections.defaultdict(list)
        for i, p in zip(indices, predictions):
            predictions_by_idx[i].append(p)
        # The dict keys are exactly the unique indices; sort them so the
        # metric sees problems in ascending index order.
        sorted_predictions = [
            predictions_by_idx[i] for i in sorted(predictions_by_idx)
        ]
        # Opt-in flag for executing generated code. It is presumably checked
        # by the HF code-eval machinery behind the metric (confirm).
        os.environ["HF_ALLOW_CODE_EVAL"] = "1"

        # Deferred import: the metric module is only loaded when evaluating.
        from uncertainty_for_programs.metrics import apps

        return apps.compute_metrics(sorted_predictions, level=self.level)

    def get_reference_solutions(self, example: dict) -> tp.Any:
        """Decode the JSON-encoded ``solutions`` field of ``example``.

        Some rows carry an empty ``solutions`` field. An empty or missing
        value yields ``[]`` instead of raising ``json.JSONDecodeError``.
        A non-empty value is expected to decode to a list of solution
        programs; confirm this against the dataset schema.
        """
        raw = example["solutions"]
        if not raw:
            return []
        return json.loads(raw)


# Call registry.task_registry once for every (split, difficulty) pair.
# Each task is named "apps-<split>-<level>", e.g. "apps-test-interview".
# The outer loop runs over splits and the inner loop over difficulties.
for split in ("train", "test"):
    for level in ("all", "competition", "introductory", "interview"):
        registry.task_registry(
            Apps,
            name=f"apps-{split}-{level}",
            dataset_path="codeparrot/apps",
            dataset_split=split,
            level=level,
        )
