import typing as T
from abc import ABC, abstractmethod
from base64 import b64decode

import datasets
import pandas as pd
from datasets import Features, Sequence, Value
from llama_index.core import Document
from overrides import overrides
from pydantic import BaseModel

from minimal.configuration import cfg
from minimal.core import QAPair

# Organization prefix of the dataset repositories passed to
# `datasets.load_dataset` below; kept base64-encoded in the source.
HF_ORG: str = str(b64decode(b"RGF0YVJvYm90LVJlc2VhcmNo"), "utf-8")


class PartitionMap(BaseModel):
    """Mapping from canonical partition names to storage partition names.

    Each field is a canonical partition name; its value is the name of the
    storage partition that is read when that canonical partition is
    requested. By default every partition maps to itself.
    """

    # canonical name -> storage partition name
    sample: str = "sample"
    train: str = "train"
    test: str = "test"
    holdout: str = "holdout"


class FlowgenQADataset(BaseModel, ABC):
    """Abstract handle on a QA dataset that lives in remote storage.

    Only pointers to the data are kept on an instance, never the data
    itself, which makes instances safe to pass to Ray tune.
    """

    # Canonical dataset name. Every subclass overrides this with its own
    # unique string Literal; it must be present when a StudyConfig is
    # loaded from a yaml.
    xname: T.Literal["FlowgenQADataset"] = "FlowgenQADataset"

    # Names of the partitions as the dataset is stored on disk
    storage_partitions: T.List[str] = ["sample", "train", "test", "holdout"]
    # Translation from a requested partition to a storage partition,
    # eg. MyDataset(partition_map={'test': 'sample'}) runs on the sample partition
    partition_map: PartitionMap = PartitionMap()

    # Timeouts (the `_s` suffix suggests seconds)
    load_examples_timeout_s: int = 3600
    load_grounding_data_timeout_s: int = 3600

    @property
    def name(self) -> str:
        """Name of the dataset; subclasses may build it dynamically."""
        return self.xname

    def _get_storage_partition(self, canonical_partition: str) -> str:
        """Look up the storage partition mapped to ``canonical_partition``."""
        storage_partition = getattr(self.partition_map, canonical_partition)
        return storage_partition

    @abstractmethod
    def iter_examples(self, partition="test") -> T.Iterator[QAPair]:
        """Yield the QA pairs of the requested partition."""

    @abstractmethod
    def iter_grounding_data(self, partition="test") -> T.Iterator[Document]:
        """Yield the grounding documents of the requested partition."""


class InfiniteBenchHF(FlowgenQADataset):
    """InfiniteBench QA pairs and their source books, loaded via `datasets`.

    Rows are assigned to partitions by book, so every question about a
    given book lands in the same partition as the book itself.
    """

    xname: T.Literal["infinitebench_hf"] = "infinitebench_hf"  # type: ignore
    subset: T.Literal["longbook_qa_eng"] = "longbook_qa_eng"
    description: str = "The dataset contains a large number of books."

    def _load_raw_dataset(self) -> pd.DataFrame:
        """Load the ``subset`` split of InfiniteBench as a DataFrame."""
        features = Features(
            {
                "id": Value("int64"),
                "context": Value("string"),
                "input": Value("string"),
                "answer": Sequence(Value("string")),
                "options": Sequence(Value("string")),
            }
        )

        dataset = datasets.load_dataset(
            "xinrongzhang2022/InfiniteBench",
            features=features,
            split=self.subset,
            cache_dir=cfg.paths.huggingface_cache,
        )
        return dataset.to_pandas()

    def _add_partitions(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add ``book_id`` and ``partition`` columns to ``df`` (in place).

        Books are numbered in order of first appearance of their text in
        the ``context`` column and then split into partitions by that
        number: 1 sample book, then 22 train, 23 test and 23 holdout books.

        Returns:
            The same DataFrame object, with the two columns added.

        Raises:
            IndexError: if a row belongs to a book beyond the 69 books
                covered by the partition ranges.
        """
        df["book_id"] = df.context.factorize()[0]
        books = df.book_id.unique()
        book_partitions = {
            "sample": books[:1],
            "train": books[1:23],
            "test": books[23:46],
            "holdout": books[46:69],
        }
        # Invert the ranges once so that labelling a row is a single dict
        # lookup instead of a scan over every partition's array.
        partition_by_book = {
            book_id: partition
            for partition, book_ids in book_partitions.items()
            for book_id in book_ids.tolist()
        }

        def label_partition(book_id: int) -> str:
            try:
                return partition_by_book[book_id]
            except KeyError:
                raise IndexError(
                    f"Book id {book_id} is out of range of {book_partitions=}"
                ) from None

        df["partition"] = df.book_id.apply(label_partition)
        return df

    @property
    def _dataset(self) -> pd.DataFrame:
        """Freshly loaded and partitioned dataset.

        Nothing is cached on the instance: every access loads the data
        again through ``_load_raw_dataset``.
        """
        return self._add_partitions(self._load_raw_dataset())

    def _partition_frame(self, partition: str) -> pd.DataFrame:
        """Return the rows belonging to the requested canonical partition."""
        df = self._dataset
        storage_partition = self._get_storage_partition(partition)
        return df[df.partition == storage_partition]

    def _row_to_qapair(self, row: pd.Series) -> QAPair:
        """Convert one InfiniteBench row into a QAPair.

        Invoked by iter_examples.

        The answer sequence is stringified as a whole, and only the first
        100 characters of the book are kept as context.
        """
        return QAPair(
            question=row["input"],
            answer=str(row["answer"]),
            _id=row["id"],
            context={"book_start": row["context"][:100]},
            supporting_facts=[],
            difficulty="",
            qtype="",
        )

    @overrides
    def iter_examples(self, partition="test") -> T.Iterator[QAPair]:
        """Yield one QAPair per row of the requested partition."""
        for _, row in self._partition_frame(partition).iterrows():
            yield self._row_to_qapair(row)

    @overrides
    def iter_grounding_data(self, partition="test") -> T.Iterator[Document]:
        """Yield each distinct book of the requested partition as a Document."""
        for book in self._partition_frame(partition).context.unique():
            yield Document(
                text=book,
            )


class FinanceBenchHF(FlowgenQADataset):
    """FinanceBench QA pairs and grounding documents, loaded via `datasets`."""

    xname: T.Literal["financebench_hf"] = "financebench_hf"  # type: ignore
    description: str = (
        "Financial dataset that contains everything about finance, including real-world financial documents, "
        "SEC filings, earning reports, call transcripts, and much more. "
        "It has all the financial live data, historical data, just about everything about finance, for instance, "
        "definitions and explanations of financial term, "
        "insights on company revenues, mergers, founders, or stock performance, "
        "details on financial laws, compliance, or government policies, "
        "information required to evaluated finance risk, and "
        "information about banking operations, credit systems, or loan structures."
    )

    def _load_grounding_dataset(self) -> datasets.DatasetDict:
        """Load every partition of the ``groundingdata`` configuration.

        No ``split`` is requested, so the result is keyed by storage
        partition name.
        """
        return datasets.load_dataset(
            f"{HF_ORG}/financebench",
            "groundingdata",
            cache_dir=cfg.paths.huggingface_cache,
        )

    def _load_qa_dataset(self) -> datasets.DatasetDict:
        """Load every partition of the ``qapairs`` configuration.

        No ``split`` is requested, so the result is keyed by storage
        partition name.
        """
        return datasets.load_dataset(
            f"{HF_ORG}/financebench",
            "qapairs",
            cache_dir=cfg.paths.huggingface_cache,
        )

    @overrides
    def iter_grounding_data(
        self, partition="test", **load_kwargs
    ) -> T.Iterator[Document]:
        """Yield the grounding documents of the requested partition.

        Args:
            partition: canonical partition name; must be listed in
                ``storage_partitions``.
            **load_kwargs: accepted for call compatibility; not used here.

        Raises:
            ValueError: if ``partition`` is not in ``storage_partitions``.
        """
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if partition not in self.storage_partitions:
            raise ValueError(
                f"Unknown partition {partition!r}; "
                f"expected one of {self.storage_partitions}"
            )
        grounding_dataset = self._load_grounding_dataset()
        storage_partition = self._get_storage_partition(partition)
        for row in grounding_dataset[storage_partition]:
            yield Document(
                text=row["html"],
                metadata={"file_name": row["filename"]},
            )

    def _row_to_qapair(self, row):
        """Convert one FinanceBench QA row into a QAPair.

        Invoked by iter_examples.
        """
        return QAPair(
            question=row["question"],
            answer=row["answer"],
            _id=row["financebench_id"],
            context=row["evidence"],
            supporting_facts=[row["justification"]],
            difficulty="",
            qtype=row["question_type"],
        )

    @overrides
    def iter_examples(self, partition="test") -> T.Iterator[QAPair]:
        """Yield one QAPair per row of the requested partition."""
        storage_partition = self._get_storage_partition(partition)
        qa_examples = self._load_qa_dataset()
        for row in qa_examples[storage_partition]:
            yield self._row_to_qapair(row)


class SyntheticFinanceBenchHF(FinanceBenchHF):
    """FinanceBench variant that reads the ``qapairs_synthetic`` configuration.

    Grounding data is inherited unchanged from FinanceBenchHF.
    """

    xname: T.Literal["synthetic_financebench_hf"] = "synthetic_financebench_hf"  # type: ignore

    def _load_qa_dataset(self) -> datasets.DatasetDict:
        """Load every partition of the ``qapairs_synthetic`` configuration.

        No ``split`` is requested, so the result is keyed by storage
        partition name.
        """
        return datasets.load_dataset(
            f"{HF_ORG}/financebench",
            "qapairs_synthetic",
            cache_dir=cfg.paths.huggingface_cache,
        )

    def _row_to_qapair(self, row):
        """Convert one synthetic QA row into a QAPair.

        Invoked by iter_examples.

        Only the query, reference answer and id are taken from the row;
        context and supporting facts are left empty, and difficulty and
        qtype are set to "default".
        """
        return QAPair(
            question=row["query"],
            answer=row["reference_answer"],
            _id=row["id"],
            context={},
            supporting_facts=[],
            difficulty="default",
            qtype="default",
        )


class HotPotQAHF(FlowgenQADataset):
    """HotPotQA QA pairs and grounding documents, loaded via `datasets`.

    The ``subset`` field selects which configuration pair
    (``groundingdata_<subset>`` / ``qapairs_<subset>``) is read.
    """

    xname: T.Literal["hotpotqa_hf"] = "hotpotqa_hf"  # type: ignore
    subset: str = "dev"  # train-hard, dev
    description: str = (
        "This dataset is a vast collection of all kind of information that you can find on Wikipedia. "
        "It can be used, for instance, to retrieve straightforward facts from one or more documents, "
        "compare two entities based on shared attributes, "
        "identify relationships, roles, or attributes of entities, "
        "reason about dates, timelines, or chronological order, "
        "determine geographical relationships or locations, "
        "explain causes or sequences of events or processes, "
        "synthesize facts from multiple documents to infer answers, and "
        "validate or refute premises in the context of the question."
    )

    @property
    def name(self) -> str:
        """Dataset name qualified by the selected subset."""
        return f"{self.xname}/{self.subset}"

    def _load_grounding_dataset(self) -> datasets.DatasetDict:
        """Load every partition of the grounding data for ``subset``.

        No ``split`` is requested, so the result is keyed by storage
        partition name.
        """
        return datasets.load_dataset(
            f"{HF_ORG}/hotpotqa",
            f"groundingdata_{self.subset}",
            cache_dir=cfg.paths.huggingface_cache,
        )

    def _load_qa_dataset(self) -> datasets.DatasetDict:
        """Load every partition of the QA pairs for ``subset``.

        No ``split`` is requested, so the result is keyed by storage
        partition name.
        """
        return datasets.load_dataset(
            f"{HF_ORG}/hotpotqa",
            f"qapairs_{self.subset}",
            cache_dir=cfg.paths.huggingface_cache,
        )

    @overrides
    def iter_grounding_data(
        self,
        partition="test",
    ) -> T.Iterator[Document]:
        """Yield the grounding documents of the requested partition.

        Raises:
            ValueError: if ``partition`` is not in ``storage_partitions``.
        """
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if partition not in self.storage_partitions:
            raise ValueError(
                f"Unknown partition {partition!r}; "
                f"expected one of {self.storage_partitions}"
            )
        grounding_dataset = self._load_grounding_dataset()
        storage_partition = self._get_storage_partition(partition)
        for row in grounding_dataset[storage_partition]:
            yield Document(
                text=row["text"],
            )

    def _row_to_qapair(self, row):
        """Convert one HotPotQA row into a QAPair.

        Invoked by iter_examples.

        Each ``(title, sentence)`` pair in the row's context becomes a
        single-entry dict in the QAPair context list.
        """
        return QAPair(
            question=row["question"],
            answer=row["answer"],
            _id=row["id"],
            context=[{title: sentence} for title, sentence in row["context"]],
            supporting_facts=row["supporting_facts"],
            difficulty=row["level"],
            qtype=row["type"],
        )

    @overrides
    def iter_examples(self, partition="test") -> T.Iterator[QAPair]:
        """Yield one QAPair per row of the requested partition."""
        storage_partition = self._get_storage_partition(partition)
        qa_examples = self._load_qa_dataset()
        for row in qa_examples[storage_partition]:
            yield self._row_to_qapair(row)


class SyntheticHotPotQAHF(HotPotQAHF):
    """HotPotQA variant that reads ``qapairs_synthetic_<subset>``.

    Grounding data and the ``name`` property (``<xname>/<subset>``) are
    inherited from HotPotQAHF; the previous override of ``name`` was
    identical to the inherited one.
    """

    xname: T.Literal["synthetic_hotpotqa_hf"] = "synthetic_hotpotqa_hf"  # type: ignore

    def _load_qa_dataset(self) -> datasets.DatasetDict:
        """Load every partition of the synthetic QA pairs for ``subset``.

        No ``split`` is requested, so the result is keyed by storage
        partition name.
        """
        return datasets.load_dataset(
            f"{HF_ORG}/hotpotqa",
            f"qapairs_synthetic_{self.subset}",
            cache_dir=cfg.paths.huggingface_cache,
        )

    def _row_to_qapair(self, row):
        """Convert one synthetic QA row into a QAPair.

        Invoked by iter_examples.

        Only the query, reference answer and id are taken from the row;
        context and supporting facts are left empty, and difficulty and
        qtype are set to "default".
        """
        return QAPair(
            question=row["query"],
            answer=row["reference_answer"],
            _id=row["id"],
            context={},
            supporting_facts=[],
            difficulty="default",
            qtype="default",
        )


class CragTask3HF(FlowgenQADataset):
    """CRAG QA pairs and grounding documents, loaded via `datasets`.

    The ``subset`` field selects the topic; it determines both the
    configurations that are read and the ``description`` text.
    """

    xname: T.Literal["crag_hf"] = "crag_hf"  # type: ignore
    subset: str = "sports"  # finance, movie, music, sports, open

    # Description text per subset; looked up by the `description` property.
    _descriptions = {
        "sports": (
            "This resource contains everything about sports, including sports news, sports events, "
            "sports statistics, sports schedules, and more. "
            "It contains historical data, live data, just about everything about sports, for instance, "
            "information about athlete achievements, teams, or career stats, "
            "information on match dates, scores, or tournaments, "
            "team performances or player statistics, "
            "key events or records in sports history, and "
            "information about rules or formats in specific sports."
        ),
        "finance": (
            "This resource contains everything about finance, including financial news, "
            "market data, financial reports, and more. "
            "It contains historical data, live data, just about everything about finance, for instance, "
            "stock prices, trends, or market indices, "
            "revenue, founders, or headquarters of companies, "
            "major financial events or policy changes, "
            "comparing returns or risks of different investments, and "
            "financial terms or concepts explained."
        ),
        "movie": (
            "This resource contains everything about movies, including movie news, reviews, "
            "box office data, and more. "
            "It contains historical data, live data, just about everything about movies, for instance, "
            "information about actors, directors, or crew roles, "
            "movies by a specific actor or director, "
            "key events or summaries from movie plots, "
            "recognition received by films or individuals, and "
            "dates or production details for movies."
        ),
        "music": (
            "This resource contains everything about music, including music news, charts, "
            "artist information, and more. "
            "It contains historical data, live data, just about everything about music, for instance, "
            "information on musicians, albums, or band members, "
            "albums or songs released by specific artists, "
            "information about song lyrics or their significance, "
            "Grammys or other recognitions for artists or albums, and "
            "details about music styles or trends over time."
        ),
        "open": (
            "This resource contains everything about open domain, including general knowledge, "
            "trivia, fun facts, and more. "
            "It contains historical data, live data, just about everything about open domain, for instance, "
            "general knowledge about diverse topics, "
            "facts that combine multiple subject areas, "
            "unique or interesting facts across domains (trivia), and "
            "further material useful for complex, unconstrained reasoning without a specific domain."
        ),
    }

    @property
    def name(self) -> str:
        """Dataset name qualified by the selected subset."""
        return f"{self.xname}/{self.subset}"

    @property
    def description(self) -> str:
        """Description text for the selected subset.

        Raises:
            KeyError: if ``subset`` has no entry in ``_descriptions``.
        """
        return self._descriptions[self.subset]

    def _load_grounding_dataset(self) -> datasets.DatasetDict:
        """Load every partition of the grounding data for ``subset``.

        No ``split`` is requested, so the result is keyed by storage
        partition name.
        """
        return datasets.load_dataset(
            f"{HF_ORG}/crag",
            f"groundingdata_{self.subset}",
            cache_dir=cfg.paths.huggingface_cache,
        )

    def _load_qa_dataset(self) -> datasets.DatasetDict:
        """Load every partition of the QA pairs for ``subset``.

        No ``split`` is requested, so the result is keyed by storage
        partition name.
        """
        return datasets.load_dataset(
            f"{HF_ORG}/crag",
            f"qapairs_{self.subset}",
            cache_dir=cfg.paths.huggingface_cache,
        )

    @overrides
    def iter_grounding_data(self, partition="test") -> T.Iterator[Document]:
        """Yield the grounding documents of the requested partition.

        Raises:
            ValueError: if ``partition`` is not in ``storage_partitions``.
        """
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if partition not in self.storage_partitions:
            raise ValueError(
                f"Unknown partition {partition!r}; "
                f"expected one of {self.storage_partitions}"
            )
        grounding_dataset = self._load_grounding_dataset()
        storage_partition = self._get_storage_partition(partition)
        for row in grounding_dataset[storage_partition]:
            yield Document(
                text=row["markdown"],
                metadata={"file_name": row["filename"]},
            )

    def _row_to_qapair(self, row):
        """Convert one CRAG row into a QAPair.

        Invoked by iter_examples.

        Every entry of the row's search results is stringified and stored
        in the context under that entry's ``_hash`` value.
        """
        context = {result["_hash"]: str(result) for result in row["search_results"]}
        return QAPair(
            question=row["query"],
            answer=row["answer"],
            _id=row["interaction_id"],
            context=context,
            supporting_facts=[],
            difficulty="default",
            qtype=row["question_type"],
        )

    @overrides
    def iter_examples(self, partition="test") -> T.Iterator[QAPair]:
        """Yield one QAPair per row of the requested partition."""
        storage_partition = self._get_storage_partition(partition)
        qa_examples = self._load_qa_dataset()
        for row in qa_examples[storage_partition]:
            yield self._row_to_qapair(row)


class SyntheticCragTask3HF(CragTask3HF):
    """CRAG variant that reads ``qapairs_synthetic_<subset>``.

    Grounding data, ``description`` and the ``name`` property
    (``<xname>/<subset>``) are inherited from CragTask3HF; the previous
    override of ``name`` was identical to the inherited one.
    """

    xname: T.Literal["synthetic_crag_hf"] = "synthetic_crag_hf"  # type: ignore

    def _load_qa_dataset(self) -> datasets.DatasetDict:
        """Load every partition of the synthetic QA pairs for ``subset``.

        No ``split`` is requested, so the result is keyed by storage
        partition name.
        """
        return datasets.load_dataset(
            f"{HF_ORG}/crag",
            f"qapairs_synthetic_{self.subset}",
            cache_dir=cfg.paths.huggingface_cache,
        )

    def _row_to_qapair(self, row):
        """Convert one synthetic QA row into a QAPair.

        Invoked by iter_examples.

        Only the query, reference answer and id are taken from the row;
        context and supporting facts are left empty, and difficulty and
        qtype are set to "default".
        """
        return QAPair(
            question=row["query"],
            answer=row["reference_answer"],
            _id=row["id"],
            context={},
            supporting_facts=[],
            difficulty="default",
            qtype="default",
        )


class DRDocsHF(FlowgenQADataset):
    """DRDocs QA pairs and grounding documents, loaded via `datasets`.

    The grounding data is not partitioned: the same documents are served
    for every partition of the QA pairs.
    """

    xname: T.Literal["drdocs_hf"] = "drdocs_hf"  # type: ignore
    # Description text is kept base64-encoded in the source.
    description: str = b64decode(
        b"VGhlIGRhdGFzZXQgY29udGFpbnMgY29tcHJlaGVuc2l2ZSBpbmZvcm1hdGlvbiBhYm91dC"
        b"BEYXRhUm9ib3QsIGluY2x1ZGluZyBpdHMgQVBJLCBkb2N1bWVudGF0aW9uLCBleGFtcGxlcywga2V5IG"
        b"ZlYXR1cmVzLCBwbGF0Zm9ybSBhcmNoaXRlY3R1cmUsIGludGVncmF0aW9ucywgc2V0dXAgZ3VpZGVzLC"
        b"BkYXRhIGhhbmRsaW5nLCBmZWF0dXJlIGVuZ2luZWVyaW5nLCBFREEgdG9vbHMsIGF1dG9tYXRlZCBtYW"
        b"NoaW5lIGxlYXJuaW5nLCBtb2RlbCBtYW5hZ2VtZW50LCBkZXBsb3ltZW50IG9wdGlvbnMsIG1vbml0b3"
        b"JpbmcsIFJFU1QgQVBJLCBiYXRjaCBwcmVkaWN0aW9ucywgcmVhbC10aW1lIHNjb3JpbmcsIGN1c3RvbS"
        b"ByZWNpcGVzLCByZXRyYWluaW5nLCBsaWZlY3ljbGUgbWFuYWdlbWVudCwgYmlhcyBkZXRlY3Rpb24sIG"
        b"V4cGxhaW5hYmlsaXR5LCBkaWFnbm9zdGljcywgY3Jvc3MtdmFsaWRhdGlvbiwgbGVhZGVyYm9hcmQgaW"
        b"5zaWdodHMsIHRpbWUgc2VyaWVzIG1vZGVsaW5nLCBkYXRhIGdvdmVybmFuY2UsIHNlY3VyaXR5LCB1c2"
        b"VyIHJvbGVzLCBQeXRob24vUiB1c2FnZSwgY3VzdG9tIGJsdWVwcmludHMsIGV4dGVybmFsIG1vZGVsIG"
        b"ludGVncmF0aW9uLCBEb2NrZXIgZGVwbG95bWVudHMsIEFQSSByZWZlcmVuY2UsIEJJIHRvb2wgaW50ZW"
        b"dyYXRpb24sIHdvcmtmbG93IGF1dG9tYXRpb24sIG11bHRpbW9kYWwgbW9kZWxpbmcsIE5MUCwgaW1hZ2"
        b"UgcmVjb2duaXRpb24sIGh5cGVycGFyYW1ldGVyIHR1bmluZywgcGVyZm9ybWFuY2Ugb3B0aW1pemF0aW"
        b"9uLCByZXNvdXJjZSBtYW5hZ2VtZW50LCBwYXJhbGxlbCBwcm9jZXNzaW5nLCBkcmlmdCBkZXRlY3Rpb2"
        b"4sIHJldHJhaW5pbmcgdHJpZ2dlcnMsIGluZHVzdHJ5IHVzZSBjYXNlcywgdHV0b3JpYWxzLCBjYXNlIH"
        b"N0dWRpZXMsIGNvbW1vbiBpc3N1ZXMsIGRlYnVnZ2luZyB0aXBzLCBGQVFzLCBzdXBwb3J0IGFjY2Vzcy"
        b"wgY29tbXVuaXR5IHJlc291cmNlcywgYW5kIHJlbGVhc2Ugbm90ZXMu"
    ).decode()

    def _load_grounding_dataset(self) -> datasets.DatasetDict:
        """Load the ``groundingdata`` configuration.

        No ``split`` is requested, so the result is keyed by split name.
        """
        return datasets.load_dataset(
            f"{HF_ORG}/drdocs",
            "groundingdata",
            cache_dir=cfg.paths.huggingface_cache,
        )

    def _load_qa_dataset(self) -> datasets.DatasetDict:
        """Load every partition of the ``qapairs`` configuration.

        No ``split`` is requested, so the result is keyed by storage
        partition name.
        """
        return datasets.load_dataset(
            f"{HF_ORG}/drdocs",
            "qapairs",
            cache_dir=cfg.paths.huggingface_cache,
        )

    @overrides
    def iter_grounding_data(self, partition="notused") -> T.Iterator[Document]:
        """Yield every grounding document; ``partition`` is ignored."""
        # There is no partition. The grounding dataset is the same
        # across all partitions of the qa pairs, and is always read
        # from the "train" split.
        grounding_dataset = self._load_grounding_dataset()
        for row in grounding_dataset["train"]:
            yield Document(
                text=row["markdown"],
                metadata={"file_name": row["filename"]},
            )

    def _row_to_qapair(self, row):
        """Convert one DRDocs QA row into a QAPair.

        Invoked by iter_examples.

        Only the question, answer and id are taken from the row; context
        and supporting facts are left empty, and difficulty and qtype are
        set to "default".
        """
        return QAPair(
            question=row["question"],
            answer=row["answer"],
            _id=row["id"],
            context={},
            supporting_facts=[],
            difficulty="default",
            qtype="default",
        )

    @overrides
    def iter_examples(self, partition="test") -> T.Iterator[QAPair]:
        """Yield one QAPair per row of the requested partition.

        Raises:
            ValueError: if ``partition`` is not in ``storage_partitions``.
        """
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if partition not in self.storage_partitions:
            raise ValueError(
                f"Unknown partition {partition!r}; "
                f"expected one of {self.storage_partitions}"
            )
        storage_partition = self._get_storage_partition(partition)
        qa_examples = self._load_qa_dataset()
        for row in qa_examples[storage_partition]:
            yield self._row_to_qapair(row)
