from datasets import load_dataset, DatasetDict
from typing import Optional


def get_exp_data(dataset_config, seed: int, num_samples: Optional[int] = None):
    """Load the experiment dataset and return whitespace-tokenized test examples.

    Args:
        dataset_config: Dataset configuration. It is forwarded to
            ``get_exp_data_hf`` (which uses attribute access such as
            ``dataset_config.name``) and is also indexed here with
            ``['text_column']`` and ``['label_column']``. It must therefore
            support both access styles (presumably an OmegaConf ``DictConfig``
            -- TODO confirm against callers).
        seed: Seed forwarded to ``get_exp_data_hf`` for its train/val split and
            per-split subsampling.
        num_samples: If given, return at most this many test examples, taken in
            dataset order. ``0`` or a negative value yields no examples.
            ``None`` returns the whole test split.

    Returns:
        A tuple ``(data, labels, num_classes)``:
        ``data`` is a list of token lists (``str.split()`` of each text),
        ``labels`` is the list of corresponding label values, and
        ``num_classes`` is the ``num_classes`` attribute of the train split's
        label-column feature.
    """
    dataset = get_exp_data_hf(dataset_config=dataset_config, seed=seed)

    text_column = dataset_config['text_column']
    label_column = dataset_config['label_column']

    # Use the configured label column rather than a hard-coded 'label', so
    # datasets whose label column has a different name do not raise KeyError.
    num_classes = dataset['train'].features[label_column].num_classes

    data = []
    labels = []
    for example in dataset['test']:
        # Check the limit before appending; checking afterwards always let one
        # example through even when num_samples <= 0.
        if num_samples is not None and len(data) >= num_samples:
            break
        data.append(example[text_column].split())
        labels.append(example[label_column])

    return data, labels, num_classes


def get_exp_data_hf(dataset_config, val_size=0.1, seed=42):
    """Load a Hugging Face dataset and arrange it into train/val/test splits.

    Args:
        dataset_config: Config object read via attribute access:
            ``name`` (required; either ``"path"`` or ``"path:subset"``),
            ``split_strategy`` (optional, default ``"train_test_only"``), and
            the optional per-split limits ``max_train_samples``,
            ``max_val_samples`` and ``max_test_samples``.
        val_size: Passed as ``test_size`` to ``train_test_split`` when a
            validation split is carved out of the dataset's ``train`` split.
        seed: Seed for the train/val split and for per-split subsampling.

    Split strategies:
        ``"train_test_only"``: val is carved from ``train``; test is ``test``.
        ``"test_no_labels"``: val is carved from ``train``; the dataset's
            ``validation`` split is used as test.
        ``"train_val_test"``: ``train``/``validation``/``test`` used as-is.

    Per-split limits: when a limit is truthy and the split is larger than it,
    ``limit`` rows are drawn without replacement using an RNG seeded with
    ``seed``. ``None`` or ``0`` means no limit.

    Returns:
        A ``DatasetDict`` with keys ``'train'``, ``'val'`` and ``'test'``.

    Raises:
        ValueError: If ``split_strategy`` is not one of the strategies above.
    """
    import random

    dataset_name = dataset_config.name
    split_strategy = getattr(dataset_config, 'split_strategy', 'train_test_only')

    # Reject an unknown strategy before paying for a potentially large download.
    if split_strategy not in ("train_test_only", "test_no_labels", "train_val_test"):
        raise ValueError(f"Unknown split_strategy: {split_strategy}")

    # Load dataset with or without subset. Split on the first ':' only, so a
    # subset name that itself contains ':' does not break tuple unpacking.
    if ':' in dataset_name:
        path, subset = dataset_name.split(':', 1)
        dataset = load_dataset(path, subset)
    else:
        dataset = load_dataset(dataset_name)

    if split_strategy == "train_val_test":
        # Use existing splits as-is.
        splits = {
            'train': dataset['train'],
            'val': dataset['validation'],
            'test': dataset['test'],
        }
    else:
        # Both remaining strategies carve val out of train; they differ only
        # in which original split serves as test.
        train_val = dataset['train'].train_test_split(test_size=val_size, seed=seed)
        test_source = 'test' if split_strategy == "train_test_only" else 'validation'
        splits = {
            'train': train_val['train'],
            'val': train_val['test'],
            'test': dataset[test_source],
        }

    per_split_limits = {
        'train': getattr(dataset_config, 'max_train_samples', None),
        'val': getattr(dataset_config, 'max_val_samples', None),
        'test': getattr(dataset_config, 'max_test_samples', None),
    }

    for split_name, limit in per_split_limits.items():
        split_data = splits[split_name]
        if limit and len(split_data) > limit:
            # A private RNG seeded with `seed` yields exactly the same indices
            # as `random.seed(seed); random.sample(...)` did, without
            # clobbering the process-global `random` state of the caller.
            rng = random.Random(seed)
            sampled_indices = rng.sample(range(len(split_data)), limit)
            splits[split_name] = split_data.select(sampled_indices)

    return DatasetDict(splits)