import os
import argparse
from data_utils import load_task_dataset
from experiment_utils import DEFAULT_HF_TOKEN, DEFAULT_HF_CACHE, setup_hf_cache, huggingface_login

def _report_dataset(dataset, title, split_size_label, split_sample_label):
    """Print the type, size and first sample of a loaded dataset.

    Args:
        dataset: Object returned by ``load_task_dataset``. An object exposing
            a callable ``keys`` attribute is treated as a split-keyed
            container; anything else is treated as a flat dataset that
            supports ``len()`` and integer indexing.
        title: Heading printed in front of the dataset's type.
        split_size_label: Label printed in front of the "train" split's size.
        split_sample_label: Label printed in front of the "train" split's
            first sample.
    """
    print(f"{title}: {type(dataset)}")

    keys = getattr(dataset, "keys", None)
    if callable(keys):
        # Split-keyed container: never index it with an integer. Listing the
        # splits also makes a missing "train" split visible in the output.
        print(f"  Available splits: {list(keys())}")
        if "train" in dataset:
            train_split = dataset["train"]
            print(f"  {split_size_label}: {len(train_split)}")
            if len(train_split) > 0:
                print(f"  {split_sample_label}: {train_split[0]}")
        return

    # Flat dataset: report its size and first sample directly.
    print(f"  Dataset size: {len(dataset)}")
    if len(dataset) > 0:
        print(f"  First sample: {dataset[0]}")


def test_dataset_loading():
    """Test MergeBench dataset loading.

    For each task name, calls ``load_task_dataset`` twice (once under a
    "training" heading, once under a "validation" heading) and prints the
    type, size and first sample of whatever is returned. Results are only
    printed; nothing is returned or asserted.
    """
    # Task list
    task_names = ['instruction', 'math', 'coding', 'safety', 'multilingual']

    print("\n===== Testing MergeBench dataset loading =====")

    for task in task_names:
        print(f"\n--- Testing {task} task dataset ---")

        # Load training data
        print(f"Loading {task} training data...")
        train_dataset = load_task_dataset(task)
        # Compare against None rather than relying on truthiness, so that a
        # loaded-but-empty dataset is still reported (with size 0).
        if train_dataset is not None:
            _report_dataset(
                train_dataset,
                "Training dataset",
                "Training set size",
                "First sample in training set",
            )
        else:
            print(f"  No training dataset returned for {task}")

        # Load validation data
        # NOTE(review): this call passes exactly the same arguments as the
        # training load above and the report reads the "train" split again,
        # so no distinct validation split is requested here. Confirm whether
        # load_task_dataset accepts a split argument.
        print(f"Loading {task} validation data...")
        eval_dataset = load_task_dataset(task)
        if eval_dataset is not None:
            _report_dataset(
                eval_dataset,
                "Validation dataset",
                "Validation set size",
                "First sample",
            )
        else:
            print(f"  No validation dataset returned for {task}")

    print("\n===== Dataset loading test completed =====")

if __name__ == "__main__":
    # Command-line interface: both options fall back to the defaults imported
    # from experiment_utils.
    cli = argparse.ArgumentParser(description="Test MergeBench dataset loading")
    cli.add_argument(
        '--token',
        type=str,
        default=DEFAULT_HF_TOKEN,
        help='HuggingFace API token',
    )
    cli.add_argument(
        '--cache-dir',
        type=str,
        default=DEFAULT_HF_CACHE,
        help='HuggingFace cache directory',
    )
    options = cli.parse_args()

    # Configure the HuggingFace cache directory first, then log in with the token.
    setup_hf_cache(options.cache_dir)
    huggingface_login(options.token)

    # Export environment variables so that downstream libraries can read them.
    # NOTE(review): PYTHONWARNINGS and PYTHONIOENCODING are normally read at
    # interpreter start-up, so setting them here presumably only affects
    # child processes - confirm this is the intent.
    os.environ.update({
        'TRANSFORMERS_OFFLINE': "0",  # Ensure non-offline mode
        'PYTHONWARNINGS': "ignore::UserWarning",
        'PYTHONIOENCODING': "utf-8",
        'LC_ALL': "C.UTF-8",
        'LANG': "C.UTF-8",
    })

    test_dataset_loading()