# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""APPS dataset."""

import json
import datasets

# Repository identifier of the dataset; nothing in the visible script reads it.
_REPO_NAME = "loubnabnl/apps"

# BibTeX entry passed to ``DatasetInfo(citation=...)``.
_CITATION = """\
@article{hendrycksapps2021,
  title={Measuring Coding Challenge Competence With APPS},
  author={Dan Hendrycks and Steven Basart and Saurav Kadavath and Mantas Mazeika and Akul Arora and Ethan Guo and Collin Burns and Samir Puranik and Horace He and Dawn Song and Jacob Steinhardt},
  journal={NeurIPS},
  year={2021}
}
"""

# Human-readable summary passed to ``DatasetInfo(description=...)``.
_DESCRIPTION = """\
APPS is a benchmark for Python code generation, it includes 10,000 problems, which range from having simple oneline solutions to being substantial algorithmic challenges, for more details please refer to this paper: https://arxiv.org/pdf/2105.09938.pdf.
"""

_HOMEPAGE = "https://github.com/hendrycks/apps"

# Concrete difficulty levels a configuration may select.
_DIFFICULTY = ["introductory", "interview", "competition"]
# Every accepted configuration value: the catch-all "all" followed by each concrete level.
_DIFFICULTY_CONFIGS = ["all", *_DIFFICULTY]
# Data file for each split, handed to the download manager in ``_split_generators``.
_URLS = {split: f"{split}.jsonl" for split in ("train", "test")}


class AppsCodeConfig(datasets.BuilderConfig):
    """BuilderConfig for the APPS dataset.

    A configuration selects which difficulty levels are loaded: either the
    catch-all ``"all"`` (no filtering) or one or more concrete levels.
    """

    def __init__(self, *args, difficulties=("all",), **kwargs):
        """BuilderConfig for the APPS Code dataset.

        Args:
            *args: positional arguments forwarded to super.
            difficulties (:obj:`List[str]`): List of problem difficulty levels to load.
                Every entry must be one of ``_DIFFICULTY_CONFIGS`` and ``"all"`` cannot be
                combined with other levels. A single string is treated as a one-element
                list. The configuration name is the entries joined with ``"+"`` in the
                order given.
            **kwargs: keyword arguments forwarded to super.

        Raises:
            ValueError: If an entry is not a known difficulty level, or if ``"all"`` is
                passed together with other levels.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(difficulties, str):
            difficulties = [difficulties]
        else:
            # Materialise once so generators/iterators can be both joined and validated.
            difficulties = list(difficulties)

        selected = set(difficulties)

        # Explicit raises instead of ``assert``: assertions are removed under ``python -O``,
        # which would silently disable this validation.
        if not selected.issubset(_DIFFICULTY_CONFIGS):
            raise ValueError(f"Difficulty level not in {_DIFFICULTY_CONFIGS}.")
        if "all" in selected and len(selected) != 1:
            raise ValueError("Passed 'all' together with other difficulties.")

        # Validate before building the config so an invalid request never yields a
        # half-initialised object with a name derived from bad input.
        super().__init__(
            *args,
            name="+".join(difficulties),
            **kwargs,
        )

        # "all" means every row is kept; otherwise rows are filtered by level.
        self.filter_difficulties = "all" not in selected
        self.difficulties = selected


class AppsCode(datasets.GeneratorBasedBuilder):
    """Dataset builder for APPS with one predefined configuration per entry of ``_DIFFICULTY_CONFIGS``."""

    VERSION = datasets.Version("1.0.0")

    BUILDER_CONFIG_CLASS = AppsCodeConfig
    BUILDER_CONFIGS = [AppsCodeConfig(difficulties=[level]) for level in _DIFFICULTY_CONFIGS]
    DEFAULT_CONFIG_NAME = "all"

    def _info(self):
        """Return the dataset metadata together with the schema of each example."""
        features = datasets.Features(
            {
                "problem_id": datasets.Value("int64"),
                "question": datasets.Value("string"),
                # Stored as plain strings; any further decoding is left to the consumer.
                "solutions": datasets.Value("string"),
                "input_output": datasets.Value("string"),
                "difficulty": datasets.Value("string"),
                "url": datasets.Value("string"),
                "starter_code": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            supervised_keys=None,
            citation=_CITATION,
            homepage=_HOMEPAGE,
            license="MIT License",
        )

    def _split_generators(self, dl_manager):
        """Fetch the files listed in ``_URLS`` and declare the train and test splits.

        Args:
            dl_manager: Download manager supplied by the ``datasets`` library.

        Returns:
            A list with one ``datasets.SplitGenerator`` per split, each carrying the
            local path of its data file as the ``filepath`` generator argument.
        """
        downloaded_files = dl_manager.download_and_extract(_URLS)

        return [
            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}),
            datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": downloaded_files["test"]}),
        ]

    def _generate_examples(self, filepath):
        """Yield ``(key, example)`` pairs read from a JSON-lines file.

        Each non-blank line is parsed as one JSON object. When the active
        configuration filters by difficulty, rows whose ``"difficulty"`` is not
        among the selected levels are skipped.

        Args:
            filepath: Path of the JSON-lines file to read.

        Yields:
            Tuples of a running integer key (counting only emitted rows, starting
            at 0) and a dict matching the features declared in ``_info``.
        """
        config = self.config
        key = 0
        # The context manager closes the file once iteration finishes or the generator
        # is closed, instead of leaving the handle to the garbage collector.
        with open(filepath, "rb") as lines:
            for raw_line in lines:
                line = raw_line.strip()
                if not line:
                    # Blank lines (e.g. a trailing empty line) hold no record and
                    # would make ``json.loads`` fail.
                    continue
                row = json.loads(line.decode("utf-8"))
                level = row["difficulty"]
                if config.filter_difficulties and level not in config.difficulties:
                    continue
                yield key, {
                    "problem_id": row["id"],
                    "question": row["question"],
                    "solutions": row["solutions"],
                    "input_output": row["input_output"],
                    "difficulty": level,
                    "url": row["url"],
                    "starter_code": row["starter_code"],
                }
                key += 1
