import random
from typing import Dict, List, Optional

from datasets import load_dataset
from torch.utils.data import Dataset

class TruthfulQADataset(Dataset):
    """TruthfulQA multiple-choice examples, with a synthetic fallback.

    Tries to load the ``truthful_qa`` dataset (``multiple_choice`` config)
    via ``datasets.load_dataset``. If that call raises for any reason, a small
    synthetic set of records with ``question`` / ``mc1_targets`` / ``category``
    keys is generated instead so downstream code can still run.
    """

    def __init__(self, split: str = "validation", max_samples: Optional[int] = None):
        """Load the requested split, or fall back to synthetic data.

        Args:
            split: Split name passed to ``load_dataset``.
            max_samples: If not ``None``, keep only the first ``max_samples``
                loaded examples (or generate that many synthetic ones on
                fallback). ``None`` means no limit; the synthetic fallback
                then produces 100 examples.
        """
        try:
            self.dataset = load_dataset("truthful_qa", "multiple_choice", split=split)
        except Exception as e:
            # Deliberate best-effort: any failure to load falls back to
            # synthetic data rather than aborting.
            print(f"Could not load TruthfulQA: {e}")
            size = 100 if max_samples is None else max_samples
            self.dataset = self._create_synthetic_data(size)
        else:
            # `is not None` so that max_samples=0 yields an empty dataset
            # instead of silently meaning "no limit".
            if max_samples is not None:
                # Keeps the first N rows in split order, not a random sample.
                self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))
            print(f"Loaded {len(self.dataset)} examples from TruthfulQA {split} split")

    def _create_synthetic_data(self, size: int) -> List[Dict]:
        """Generate ``size`` synthetic TruthfulQA-style records.

        Each record has a ``question`` (template text plus a sample number),
        an ``mc1_targets`` dict whose ``choices`` and ``labels`` lists are
        parallel with exactly one label equal to 1, and a randomly drawn
        ``category``. The layout presumably mirrors the hub's
        ``multiple_choice`` config; verify against real records if it matters.
        Randomness comes from the global ``random`` module state.

        Args:
            size: Number of records to produce; non-positive yields ``[]``.

        Returns:
            A list of ``size`` record dicts.
        """
        questions_answers = [
            ("What is the chemical formula for water?", "H2O", ["H2O", "HO2", "H3O", "OH2"]),
            ("How many planets are in our solar system?", "8", ["7", "8", "9", "10"]),
            ("What gas do plants absorb during photosynthesis?", "Carbon dioxide",
             ["Oxygen", "Carbon dioxide", "Nitrogen", "Hydrogen"]),
            ("In which year did World War II end?", "1945", ["1944", "1945", "1946", "1947"]),
            ("Who was the first person to walk on the moon?", "Neil Armstrong",
             ["Buzz Aldrin", "Neil Armstrong", "John Glenn", "Alan Shepard"]),
            ("What is the capital of Australia?", "Canberra",
             ["Sydney", "Melbourne", "Canberra", "Perth"]),
            ("How many chambers does a human heart have?", "Four", ["Two", "Three", "Four", "Five"]),
            ("What vitamin is produced when skin is exposed to sunlight?", "Vitamin D",
             ["Vitamin A", "Vitamin C", "Vitamin D", "Vitamin E"]),
            # Distractors are deliberately far off: common approximations such
            # as 3×10^8 m/s would also be correct and break the single-answer
            # (mc1) contract.
            ("What is the speed of light in vacuum?", "299,792,458 m/s",
             ["299,792,458 m/s", "343 m/s", "30,000 m/s", "3×10^5 m/s"]),
            ("Who wrote Romeo and Juliet?", "William Shakespeare",
             ["Christopher Marlowe", "William Shakespeare", "Ben Jonson", "John Webster"]),
        ]

        synthetic_data = []
        for i in range(size):
            q, correct, template_choices = random.choice(questions_answers)
            # Copy before shuffling: the template list is reused whenever the
            # same question is drawn again, and shuffling it in place would
            # reorder the choices of records already emitted, desynchronising
            # them from their labels.
            choices = list(template_choices)
            random.shuffle(choices)
            correct_idx = choices.index(correct)

            synthetic_data.append({
                'question': f"{q} (Sample {i+1})",
                'mc1_targets': {
                    'choices': choices,
                    'labels': [1 if j == correct_idx else 0 for j in range(len(choices))],
                },
                'category': random.choice(['science', 'history', 'politics', 'health']),
            })

        print(f"Created {len(synthetic_data)} synthetic TruthfulQA examples")
        return synthetic_data

    def __len__(self):
        """Return the number of examples held."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the example at ``idx`` from the underlying dataset or list."""
        return self.dataset[idx]

class MMLUDataset(Dataset):
    """MMLU multiple-choice examples, with a synthetic fallback.

    Tries to load the ``cais/mmlu`` dataset (``all`` config) via
    ``datasets.load_dataset``. If that call raises for any reason, placeholder
    records with ``question`` / ``choices`` / ``answer`` / ``subject`` keys are
    generated instead so downstream code can still run.
    """

    def __init__(self, split: str = "test", max_samples: Optional[int] = None):
        """Load the requested split, or fall back to synthetic data.

        Args:
            split: Split name passed to ``load_dataset``.
            max_samples: If not ``None``, keep only the first ``max_samples``
                loaded examples (or generate that many synthetic ones on
                fallback). ``None`` means no limit; the synthetic fallback
                then produces 100 examples.
        """
        try:
            self.dataset = load_dataset("cais/mmlu", "all", split=split)
        except Exception as e:
            # Deliberate best-effort: any failure to load falls back to
            # synthetic data rather than aborting.
            print(f"Could not load MMLU: {e}")
            size = 100 if max_samples is None else max_samples
            self.dataset = self._create_synthetic_mmlu(size)
        else:
            # `is not None` so that max_samples=0 yields an empty dataset
            # instead of silently meaning "no limit".
            if max_samples is not None:
                # NOTE(review): this keeps the *first* N rows in split order.
                # If the split is grouped by subject, a small N covers only a
                # few subjects — confirm, and consider shuffling first.
                self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))
            print(f"Loaded {len(self.dataset)} examples from MMLU {split} split")

    def _create_synthetic_mmlu(self, size: int) -> List[Dict]:
        """Generate ``size`` placeholder MMLU-style records.

        Each record has a templated ``question``, four generic ``choices``,
        a random integer ``answer`` in ``[0, 3]``, and a random ``subject``.
        Randomness comes from the global ``random`` module state.

        Args:
            size: Number of records to produce; non-positive yields ``[]``.

        Returns:
            A list of ``size`` record dicts.
        """
        subjects = ['mathematics', 'physics', 'chemistry', 'biology', 'history', 'philosophy']
        synthetic_data = []

        for i in range(size):
            subject = random.choice(subjects)
            synthetic_data.append({
                'question': f"Sample {subject} question {i+1}",
                # A fresh list per record, so callers may mutate it safely.
                'choices': ["Choice A", "Choice B", "Choice C", "Choice D"],
                'answer': random.randint(0, 3),
                'subject': subject,
            })

        print(f"Created {len(synthetic_data)} synthetic MMLU examples")
        return synthetic_data

    def __len__(self):
        """Return the number of examples held."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the example at ``idx`` from the underlying dataset or list."""
        return self.dataset[idx]