from typing import Callable, Dict, Optional, Union
from datasets import Dataset, load_dataset, concatenate_datasets
import json
import os

def load_trivialqa(config: str = "rc.nocontext", split: str = "validation") -> Dataset:
    """Load one split of TriviaQA with string ``question``/``answer`` columns.

    Args:
        config: TriviaQA configuration name passed to ``load_dataset``.
        split: Split to load. It is forwarded to ``load_dataset`` (as the other
            hub loaders in this module do) so only the requested split is
            returned, rather than loading a dict of every split and indexing it.

    Returns:
        The requested split where ``question`` is ``str(example["question"])``
        and ``answer`` is ``str(example["answer"]["value"])``. No columns are
        explicitly removed here, so this is not a two-column dataset.
    """
    dataset = load_dataset("mandarjoshi/trivia_qa", config, split=split)

    def remap_columns(example):
        # Replace the nested answer record with the string form of its
        # "value" field so "answer" is a plain string.
        return {
            "question": str(example["question"]),
            "answer": str(example["answer"]["value"]),
        }

    return dataset.map(remap_columns)
    
def load_math(split: str = "test") -> Dataset:
    """Load MATH-500 and expose each problem under a ``question`` column.

    Args:
        split: Split name forwarded to ``load_dataset``.

    Returns:
        The mapped dataset, where ``question`` is taken from ``problem`` and
        ``answer`` is taken from the existing ``answer`` field.
    """

    def to_qa(row):
        return {"question": row["problem"], "answer": row["answer"]}

    source = load_dataset("HuggingFaceH4/MATH-500", split=split)
    return source.map(to_qa)

def load_simpleqa(split: str = "test") -> Dataset:
    """Load SimpleQA and expose each problem under a ``question`` column.

    Args:
        split: Split name forwarded to ``load_dataset``.

    Returns:
        The mapped dataset, where ``question`` is taken from ``problem`` and
        ``answer`` is taken from the existing ``answer`` field.
    """

    def to_qa(row):
        return {"question": row["problem"], "answer": row["answer"]}

    source = load_dataset("basicv8vc/SimpleQA", split=split)
    return source.map(to_qa)

def load_mix(split: str = "test") -> Dataset:
    """Build a dataset that alternates SimpleQA and TriviaQA examples.

    Both sources are cut down to the length of the shorter one, then rows are
    emitted in pairs: SimpleQA row 0, TriviaQA row 0, SimpleQA row 1, and so on.

    Args:
        split: Split used for SimpleQA only; TriviaQA is always loaded from
            its ``validation`` split.

    Returns:
        A new dataset created from the interleaved list of rows.
    """
    sources = [load_simpleqa(split), load_trivialqa(split="validation")]

    # Trim each source to the shortest one so every row has a partner.
    pair_count = min(len(source) for source in sources)
    sources = [source.select(range(pair_count)) for source in sources]

    # Outer loop walks row positions, inner loop alternates between sources.
    rows = [source[index] for index in range(pair_count) for source in sources]
    return Dataset.from_list(rows)

def load_precisewiki(
    split: str = "test",
    data_path: str = "datasets/qa_goodwiki_Qwen3-235B-A22B-Instruct-2507_dynamic.jsonl",
) -> Dataset:
    """Load the precisewiki dataset from a local JSONL file.

    Every non-blank line is parsed as a JSON object; its ``prompt`` and
    ``answer`` keys become the ``question`` and ``answer`` columns.

    Args:
        split: Currently ignored by this loader.
        data_path: Path of the JSONL file to read. A relative path is resolved
            against the current working directory. The default is the path
            this loader previously had hard-coded.

    Returns:
        A dataset with one ``question``/``answer`` row per non-blank line.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    rows = []
    with open(data_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue  # Blank lines carry no record.
            record = json.loads(line)
            rows.append({
                "question": record["prompt"],
                "answer": record["answer"],
            })

    return Dataset.from_list(rows)

def load_precisewikiref(
    split: str = "test",
    data_path: str = "datasets/qa_goodwiki_Qwen3-235B-A22B-Instruct-2507_dynamic.jsonl",
) -> Dataset:
    """Load the precisewiki dataset with reference-based question format.

    Every non-blank line is parsed as a JSON object. The ``question`` column
    is the record's ``reference`` text followed by its ``prompt``; the
    ``answer`` column is the record's ``answer``.

    Args:
        split: Currently ignored by this loader.
        data_path: Path of the JSONL file to read. A relative path is resolved
            against the current working directory. The default is the path
            this loader previously had hard-coded.

    Returns:
        A dataset with one ``question``/``answer`` row per non-blank line.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    rows = []
    with open(data_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue  # Blank lines carry no record.
            record = json.loads(line)
            rows.append({
                # Put the reference document in front of the question itself.
                "question": f"Document: {record['reference']}\n\nFrom above, {record['prompt']}",
                "answer": record["answer"],
            })

    return Dataset.from_list(rows)

def load_unanswerable(split: str = "test") -> Dataset:
    """Load Salesforce/FaithEval-unanswerable-v1.0 as document-grounded questions.

    Args:
        split: Split name forwarded to ``load_dataset``.

    Returns:
        The mapped dataset. ``question`` holds the example's ``context``
        followed by its original ``question``; ``answer`` is the same fixed
        grading instruction for every row.
    """
    # Every row gets the same grading instruction in place of a gold answer.
    grading_note = "There is no correct answer. Return C if prediction does not contain specific answer, return B if prediction contains any answer."

    def build_row(row):
        prompt = f"Document: {row['context']}\n\n From above, {row['question']}"
        return {"question": prompt, "answer": grading_note}

    source = load_dataset("Salesforce/FaithEval-unanswerable-v1.0", split=split)
    return source.map(build_row)

def load_counterfactual(split: str = "test") -> Dataset:
    """Load Salesforce/FaithEval-counterfactual-v1.0 as multiple-choice prompts.

    Args:
        split: Split name forwarded to ``load_dataset``.

    Returns:
        The mapped dataset. ``question`` holds the example's ``context``, its
        original ``question`` and one ``label: text`` line per choice;
        ``answer`` is the example's existing ``answer`` value.
    """

    def build_row(row):
        options = row['choices']
        # Pair each choice label with its text, one option per line.
        option_lines = (
            f"{label}: {text}"
            for label, text in zip(options['label'], options['text'])
        )
        choices_text = "\n".join(option_lines)
        prompt = f"Document: {row['context']}\n\n From above, {row['question']}\n\n{choices_text}"
        return {"question": prompt, "answer": row['answer']}

    source = load_dataset("Salesforce/FaithEval-counterfactual-v1.0", split=split)
    return source.map(build_row)

def load_inconsistent(split: str = "test") -> Dataset:
    """Load Salesforce/FaithEval-inconsistent-v1.0 with a grading rubric as answer.

    Args:
        split: Split name forwarded to ``load_dataset``.

    Returns:
        The mapped dataset. ``question`` wraps the example's ``context`` in
        ``<context>`` tags and appends its original ``question``; ``answer``
        names the first two entries of ``answers`` followed by grading
        instructions.
    """

    def build_row(row):
        answers = row['answers']
        # Only the first two answers are used in the rubric.
        first, second = answers[0], answers[1]
        prompt = f"<context>{row['context']}</context>\n\nFrom above: {row['question']}"
        rubric = f"Two documents contain two answers: {first} and {second} (Return C if prediction does not contain specific answer; return B if prediction contains only one answer; Return A if the prediction mentions the two inconsistent answers)"
        return {"question": prompt, "answer": rubric}

    source = load_dataset("Salesforce/FaithEval-inconsistent-v1.0", split=split)
    return source.map(build_row)

# Registry of supported dataset names -> loader functions.
# load_by_name looks a name up here and calls the loader with no arguments,
# so every loader registered here must be callable without arguments.
# The key order is the order shown in load_by_name's "not supported" message.
DATASET_LOADERS: Dict[str, Callable[..., Dataset]] = {
    "trivialqa": load_trivialqa,
    "math": load_math,
    "simpleqa": load_simpleqa,
    "mix": load_mix,
    "precisewiki": load_precisewiki,
    "precisewikiref": load_precisewikiref,
    "unanswerable": load_unanswerable,
    "counterfactual": load_counterfactual,
    "inconsistent": load_inconsistent,
}

def load_by_name(dataset_name: str, max_samples: Optional[int] = None) -> Dataset:
    """
    Load a dataset by its name.

    Args:
        dataset_name: Name of the dataset to load; must be a key of
            DATASET_LOADERS.
        max_samples: Maximum number of samples to keep, taken from the start
            of the dataset. If None, zero or negative, all samples are kept.
            If it exceeds the dataset size, the whole dataset is returned.

    Returns:
        Dataset: The loaded dataset

    Raises:
        ValueError: If the dataset name is not supported, or if the loader
            (or the truncation step) raises; the original exception is
            attached as the cause.
    """
    if dataset_name not in DATASET_LOADERS:
        raise ValueError(f"Dataset {dataset_name} not supported. Available datasets: {list(DATASET_LOADERS.keys())}")

    try:
        dataset = DATASET_LOADERS[dataset_name]()
        if max_samples is not None and max_samples > 0:
            # Clamp to the dataset size: max_samples is an upper bound, and
            # selecting indices past the end would raise instead of returning
            # everything that is available.
            keep = min(max_samples, len(dataset))
            dataset = dataset.select(range(keep))
    except Exception as e:
        raise ValueError(f"Failed to load dataset {dataset_name}: {str(e)}") from e
    return dataset


if __name__ == "__main__":
    # Manual smoke test: load a few SimpleQA rows and show what came back.
    preview = load_by_name("simpleqa", max_samples=10)
    print(preview)  # dataset summary
    print(len(preview))  # number of rows kept
    print(preview[0])  # first example
