#!/usr/bin/env python3
"""
Script to load and combine datasets with sampling and different combination strategies.
"""

import json
import argparse
import random
from typing import List, Dict, Any, Union
from pathlib import Path
import sys

def load_dataset(file_path: str) -> List[Dict[str, Any]]:
    """
    Load a dataset from a JSONL file (one JSON object per line).

    Blank and whitespace-only lines are skipped. Any failure -- an unreadable
    file, malformed JSON, a record that is not a JSON object, or a file that
    contains no records -- is reported on stdout and terminates the process
    with exit status 1, so callers never receive an invalid dataset.

    Args:
        file_path: Path to the JSONL dataset file

    Returns:
        List of dataset items (dictionaries), in file order
    """
    try:
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, raw_line in enumerate(f, start=1):
                stripped = raw_line.strip()
                if not stripped:
                    continue
                try:
                    item = json.loads(stripped)
                except json.JSONDecodeError as e:
                    # Re-raise with the line number so the bad record can be found
                    raise ValueError(f"Invalid JSON on line {line_no}: {e}") from e
                # Explicit check instead of `assert`, which is removed under `python -O`
                if not isinstance(item, dict):
                    raise ValueError(
                        f"All items in {file_path} must be dictionaries "
                        f"(line {line_no} is {type(item).__name__})"
                    )
                data.append(item)

        if not data:
            raise ValueError(f"Dataset {file_path} is empty")

        print(f"Loaded {len(data)} samples from {file_path}")
        return data
    except Exception as e:
        # Script-level boundary: every load failure becomes a message plus exit code 1
        print(f"Error loading dataset from {file_path}: {e}")
        sys.exit(1)

def take_first_samples(data: List[Dict[str, Any]], num_samples: int) -> List[Dict[str, Any]]:
    """
    Take first N samples from dataset.

    The input list is not modified; a new list holding the first
    ``num_samples`` items is returned.

    Args:
        data: Dataset to sample from
        num_samples: Number of samples to take (1 <= num_samples <= len(data))

    Returns:
        First N dataset items

    Raises:
        ValueError: If the dataset is empty, or num_samples is not positive
            or exceeds the dataset size
    """
    # Explicit checks instead of `assert`: assertions are removed under
    # `python -O`, where a negative or oversized count would silently slice
    # the wrong number of items.
    if not data:
        raise ValueError("Cannot sample from empty dataset")
    if num_samples <= 0:
        raise ValueError("Number of samples must be positive")
    if num_samples > len(data):
        raise ValueError(
            f"Requested {num_samples} samples but dataset only has {len(data)} items."
        )

    sampled = data[:num_samples]
    print(f"Took first {len(sampled)} items from dataset")

    return sampled

def combine_datasets(dataset1: List[Dict[str, Any]],
                    dataset2: Union[List[Dict[str, Any]], None] = None,
                    strategy: str = "stack",
                    seed: int = 42) -> List[Dict[str, Any]]:
    """
    Combine datasets using specified strategy.

    The input lists are never modified: the result is always a new list
    (the contained item dictionaries themselves are shared, not copied).

    Args:
        dataset1: First dataset
        dataset2: Second dataset (optional); when omitted only dataset1 is used
        strategy: Combination strategy ("stack" or "shuffle")
        seed: Random seed for reproducibility

    Returns:
        Combined dataset: dataset1 followed by dataset2 for "stack", or the
        same items in a seed-determined order for "shuffle"

    Raises:
        ValueError: If the strategy is unknown, a provided dataset is empty,
            or a dataset contains items that are not dictionaries
    """
    # Validate the strategy up front so a typo is rejected even when only one
    # dataset is supplied (previously it was silently ignored in that case).
    if strategy not in ("stack", "shuffle"):
        raise ValueError(f"Unknown strategy: {strategy}. Use 'stack' or 'shuffle'")

    # Explicit checks instead of `assert`, which is removed under `python -O`
    if not dataset1:
        raise ValueError("First dataset cannot be empty")
    if not all(isinstance(item, dict) for item in dataset1):
        raise ValueError("All items in first dataset must be dictionaries")

    # Work on a copy so the caller's list is never shuffled or aliased
    combined = list(dataset1)

    if dataset2 is not None:
        if not dataset2:
            raise ValueError("Second dataset cannot be empty")
        if not all(isinstance(item, dict) for item in dataset2):
            raise ValueError("All items in second dataset must be dictionaries")
        combined.extend(dataset2)

    if strategy == "shuffle":
        # A private generator gives the same order as seeding the module-level
        # one, without resetting global random state for the rest of the program.
        random.Random(seed).shuffle(combined)
        if dataset2 is not None:
            print(f"Shuffled combined dataset: {len(combined)} items")
    elif dataset2 is not None:
        print(f"Stacked datasets: {len(dataset1)} + {len(dataset2)} = {len(combined)} items")

    return combined

def save_dataset(data: List[Dict[str, Any]], output_path: str):
    """
    Save dataset to JSONL file.

    Each item is written as one JSON object per line, UTF-8 encoded, with
    non-ASCII characters written as-is rather than escaped. An existing file
    at output_path is overwritten. All arguments are validated before the
    file is opened, so invalid input never truncates an existing file.

    Args:
        data: Dataset to save
        output_path: Output file path

    Raises:
        ValueError: If the dataset is empty, contains items that are not
            dictionaries, or output_path is empty
    """
    # Explicit checks instead of `assert`, which is removed under `python -O`
    if not data:
        raise ValueError("Cannot save empty dataset")
    if not all(isinstance(item, dict) for item in data):
        raise ValueError("All items must be dictionaries")
    if not output_path:
        raise ValueError("Output path cannot be empty")

    with open(output_path, 'w', encoding='utf-8') as f:
        for item in data:
            # Serialize the whole record first so each line is a single write
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
    print(f"Saved {len(data)} items to {output_path}")

def main():
    """
    Command-line entry point: load, slice, combine, repeat and save datasets.

    Steps, in order:
      1. Load the first dataset, drop items before --start_idx_1 and
         optionally keep only the first --samples1 of what remains.
      2. Do the same for the optional second dataset.
      3. Combine the datasets with the chosen strategy and seed.
      4. Repeat the combined list --num_repetitions times.
      5. Write the result as JSONL and print a summary.

    Invalid option values are reported through argparse (usage message and
    exit status 2) before any file is read.
    """
    parser = argparse.ArgumentParser(description="Load and combine datasets with first N samples")
    # Required: without it there is nothing to load
    parser.add_argument("--dataset1", required=True,
                       help="Path to first dataset (JSONL file)")
    parser.add_argument("--dataset2", help="Path to second dataset (optional)")
    parser.add_argument("--start_idx_1", type=int, default=0,
                       help="Start index for first dataset (default: 0)")
    parser.add_argument("--start_idx_2", type=int, default=0,
                       help="Start index for second dataset (default: 0)")
    parser.add_argument("--samples1", type=int, default=None,
                       help="Number of first samples to take from first dataset (default: all)")
    parser.add_argument("--samples2", type=int, default=None,
                       help="Number of first samples to take from second dataset (default: all)")
    parser.add_argument("--strategy", choices=["stack", "shuffle"], default="stack",
                       help="Combination strategy: stack (concatenate) or shuffle (shuffle combined)")
    parser.add_argument("--output", "-o", default="combined_dataset.jsonl",
                       help="Output file path (default: combined_dataset.jsonl)")
    parser.add_argument("--seed", type=int, default=42,
                       help="Random seed for reproducibility")
    parser.add_argument("--num_repetitions", type=int, default=1,
                       help="Number of times to repeat the dataset")

    args = parser.parse_args()

    # Fail fast on values that would otherwise only surface later as an
    # unrelated-looking error (e.g. "Cannot save empty dataset").
    if args.num_repetitions < 1:
        parser.error("--num_repetitions must be at least 1")
    for option in ("samples1", "samples2"):
        requested = getattr(args, option)
        if requested is not None and requested < 1:
            parser.error(f"--{option} must be a positive integer")

    # Load first dataset
    print(f"Loading first dataset: {args.dataset1}")
    dataset1 = load_dataset(args.dataset1)
    dataset1 = dataset1[args.start_idx_1:]

    # Take first samples from first dataset
    if args.samples1 is not None:
        dataset1 = take_first_samples(dataset1, args.samples1)

    # Load and take first samples from second dataset if provided
    dataset2 = None
    if args.dataset2:
        print(f"Loading second dataset: {args.dataset2}")
        dataset2 = load_dataset(args.dataset2)
        dataset2 = dataset2[args.start_idx_2:]

        if args.samples2 is not None:
            dataset2 = take_first_samples(dataset2, args.samples2)

    # Combine datasets
    combined = combine_datasets(dataset1, dataset2, args.strategy, args.seed)

    # Repeat the combined sequence; repeated entries refer to the same item objects
    combined = combined * args.num_repetitions

    # Save combined dataset
    save_dataset(combined, args.output)

    print("\nSummary:")
    print(f"  Dataset 1: {len(dataset1)} items")
    if dataset2 is not None:
        print(f"  Dataset 2: {len(dataset2)} items")
    print(f"  Combined: {len(combined)} items")
    print(f"  Strategy: {args.strategy}")
    print(f"  Output: {args.output}")

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
