from datasets import load_dataset

# dataset = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train")


# print(f"Total number of examples in the test set: {len(dataset)}")
# print(dataset[0].keys())

class GPQADataset:
    """Convenience wrapper around a GPQA split loaded via ``load_dataset``.

    The loaded split is stored on ``self.examples``. Every accessor reads rows
    from it as mappings and relies on the columns ``"Record ID"``,
    ``"Question"``, ``"Correct Answer"`` and ``"High-level domain"``.
    """

    def __init__(self, path="Idavidrein/gpqa", name="gpqa_diamond", split="train"):
        """Load the dataset split that backs this wrapper.

        Args:
            path: Dataset repository id passed to ``load_dataset``.
            name: Dataset configuration name passed to ``load_dataset``.
            split: Split to load.

        The defaults reproduce the previously hard-coded call, so
        ``GPQADataset()`` loads exactly what it always did.
        """
        self.examples = load_dataset(path, name, split=split)

    @staticmethod
    def _to_record(example):
        """Project one raw row onto the simplified record dict."""
        return {
            "id": example["Record ID"],
            "question": example["Question"],
            "answer": example["Correct Answer"],
            "subject": example["High-level domain"],
        }

    def get_dataset_size(self) -> int:
        """Get the total number of examples in the dataset."""
        return len(self.examples)

    def get_available_subjects(self):
        """Get a sorted list of all available subjects in the dataset.

        Rows whose ``"High-level domain"`` is missing or falsy (``None``,
        empty string) are ignored.
        """
        subjects = {
            subject
            for subject in (
                example.get('High-level domain') for example in self.examples
            )
            if subject
        }
        return sorted(subjects)

    def get_full_set(self):
        """Get an iterator of the full set.

        Yields:
            One dict per row with the keys ``"id"``, ``"question"``,
            ``"answer"`` and ``"subject"``.
        """
        for example in self.examples:
            yield self._to_record(example)

    def get_by_subject(self, subject: str):
        """Get an iterator of the examples with the given subject.

        Args:
            subject: Value compared for exact equality against each row's
                ``"High-level domain"``.

        Yields:
            Records in the same shape as :meth:`get_full_set`, restricted to
            the matching rows, in dataset order.
        """
        for example in self.examples:
            if example.get('High-level domain') == subject:
                yield self._to_record(example)

if __name__ == "__main__":
    # Smoke check: load the dataset once, then report its size, the subjects
    # it contains, and how many examples each of three subjects has.
    gpqa = GPQADataset()

    total = gpqa.get_dataset_size()
    print(f"Total number of examples in the test set: {total}")

    subjects = gpqa.get_available_subjects()
    print(f"Available subjects: {subjects}")

    # Count by draining each generator; no intermediate list is kept.
    physics_count = sum(1 for _ in gpqa.get_by_subject('Physics'))
    print(f"Length of Physics examples: {physics_count}")

    chemistry_count = sum(1 for _ in gpqa.get_by_subject('Chemistry'))
    print(f"Length of Chemistry examples: {chemistry_count}")

    biology_count = sum(1 for _ in gpqa.get_by_subject('Biology'))
    print(f"Length of Biology examples: {biology_count}")