import collections
import dataclasses
import json
import typing as tp

import datasets as ds
import numpy as np
import pandas as pd

from uncertainty_for_programs import infer_lib, registry, utils
from uncertainty_for_programs import prompts as PROMPTS
from uncertainty_for_programs.metrics import code_contests_general
from uncertainty_for_programs.metrics.apps import apps_metric

# Module-level logger obtained from the project's ``utils.getLogger`` helper.
# NOTE(review): not referenced by the code visible in this module section.
logger = utils.getLogger(__name__)


@registry.task_registry(
    name="taco-simple-train",
)
@dataclasses.dataclass
class TacoSimpleTrain(infer_lib.Task):
    """Task registered as ``taco-simple-train``, backed by a local parquet file.

    Each row of the parquet file at ``path`` becomes one example. Generated
    programs are executed via ``code_contests_general`` and summarized with
    ``apps_metric``.
    """

    path: str = "notebooks-local/taco-short-prompts.parquet"
    """Path to parquet file containing the dataset. """

    @property
    def output_keys(self):
        """Example keys kept by :meth:`prepare_inputs` (besides ``starter_code``)."""
        return ["question", "input_output"]

    @property
    def default_prompt(self):
        """Default prompt template for this task: ``PROMPTS.apps_instruct``."""
        return PROMPTS.apps_instruct

    def prepare_inputs(self, raw_example: dict) -> dict:
        """Return the entries of ``raw_example`` whose key is in
        :attr:`output_keys` or equals ``"starter_code"``.

        Key order follows ``raw_example``; wanted keys that are absent from
        ``raw_example`` are simply omitted.
        """
        # Build the allow-list once instead of re-evaluating the property and
        # concatenating a fresh list for every key of the example.
        wanted = set(self.output_keys)
        wanted.add("starter_code")
        return {k: v for k, v in raw_example.items() if k in wanted}

    def load_dataset(self, **kwargs) -> ds.Dataset:
        """Load every row of the parquet file at ``path`` as a ``ds.Dataset``.

        ``kwargs`` are accepted but ignored (presumably to match the base
        ``Task.load_dataset`` signature -- confirm against ``infer_lib``).
        """
        df = pd.read_parquet(self.path)
        # NOTE(review): ``from_generator`` fingerprints the callable (including
        # the captured DataFrame) for its cache; ``ds.Dataset.from_pandas`` may
        # be cheaper, but verify the inferred features match before switching.
        return ds.Dataset.from_generator(lambda: df.to_dict(orient="records"))

    def evaluate(self, indices: tp.Sequence[int], predictions: tp.Sequence[str]):
        """Execute generated programs against the referenced dataset examples.

        Args:
            indices: Dataset row index for each prediction. An index may repeat
                when several predictions target the same example; those
                predictions are grouped together.
            predictions: Generated program text, aligned element-wise with
                ``indices``.

        Returns:
            ``(metrics, evaluation_results)``. ``evaluation_results`` is the
            output of ``code_contests_general.evaluate_generations_parallel``
            run on the examples in ascending dataset-index order, and
            ``metrics`` is ``apps_metric.get_results(evaluation_results)``.

        Raises:
            ValueError: if ``indices`` and ``predictions`` differ in length.
        """
        # ``zip`` would silently drop the tail of the longer sequence, so
        # misaligned inputs would otherwise be scored without any warning.
        if len(indices) != len(predictions):
            raise ValueError(
                "indices and predictions must have the same length, "
                f"got {len(indices)} and {len(predictions)}"
            )
        predictions_by_idx = collections.defaultdict(list)
        for i, p in zip(indices, predictions):
            predictions_by_idx[i].append(p)
        dataset = self.load_dataset()
        # np.unique already returns its values in ascending order.
        dataset_indices = np.unique(indices)
        sorted_predictions = [predictions_by_idx[i] for i in dataset_indices]
        reference_examples = [dataset[int(i)] for i in dataset_indices]

        evaluation_results = code_contests_general.evaluate_generations_parallel(
            sorted_predictions, reference_examples, prepend_starter_code=True
        )
        metrics = apps_metric.get_results(evaluation_results)
        return metrics, evaluation_results

    def get_reference_solutions(self, example: dict) -> tp.Any:
        """Decode and return the JSON-encoded ``"solutions"`` field of ``example``.

        The result is whatever the JSON document encodes (presumably a list of
        solution strings -- confirm against the parquet schema), not a ``str``.
        """
        return json.loads(example["solutions"])
