#!/usr/bin/env python3
"""
Randomly split a JSONL dataset into training and validation sets.
"""

import json
import sys
import random
from pathlib import Path
from typing import List, Dict, Any


def _read_jsonl(input_file: str) -> List[Any]:
    """Parse every non-blank line of *input_file* as one JSON record."""
    records: List[Any] = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for lineno, raw in enumerate(f, start=1):
            text = raw.strip()
            if not text:
                continue
            try:
                records.append(json.loads(text))
            except json.JSONDecodeError as err:
                # Re-raise the same exception type, but point at the
                # offending line of the file so bad records can be located.
                raise json.JSONDecodeError(
                    f"{input_file}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return records


def _write_jsonl(output_file: str, items: List[Any]) -> None:
    """Write *items* to *output_file*, one JSON document per line."""
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(
            json.dumps(item, ensure_ascii=False) + '\n' for item in items
        )


def split_jsonl(
    input_file: str,
    train_output: str,
    val_output: str,
    val_ratio: float = 0.1,
    seed: int = 42
) -> None:
    """
    Randomly split a JSONL file into training and validation sets.

    Blank lines in the input are skipped. The records are shuffled with the
    module-level ``random`` generator after seeding it with ``seed``; the
    first ``int(total * val_ratio)`` shuffled records form the validation
    set and the rest form the training set. Parent directories of both
    output paths are created if missing. An input without any records
    produces two empty output files.

    Args:
        input_file: Path to the input JSONL file.
        train_output: Output path for the training set.
        val_output: Output path for the validation set.
        val_ratio: Proportion of data used for validation, within [0, 1]
            (default 0.1 i.e. 10%).
        seed: Random seed.

    Raises:
        ValueError: If ``val_ratio`` is not within [0, 1] (NaN included).
        json.JSONDecodeError: If a non-blank input line is not valid JSON;
            the message names the file and line number.
    """
    # A negative ratio would turn into a negative slice index and silently
    # yield a nonsensical split; the chained comparison also rejects NaN.
    if not 0 <= val_ratio <= 1:
        raise ValueError(f"val_ratio must be between 0 and 1, got: {val_ratio}")

    data = _read_jsonl(input_file)

    total = len(data)
    print(f"Total samples: {total}")

    # Seed the module-level generator (not a private one) so the global
    # random state after the call stays the same as it has always been.
    random.seed(seed)
    random.shuffle(data)

    # Split into train/val
    val_size = int(total * val_ratio)
    train_data = data[val_size:]
    val_data = data[:val_size]

    def percent(count: int) -> float:
        # Guard the empty-dataset case, which used to divide by zero.
        return count / total * 100 if total else 0.0

    print(f"Train set: {len(train_data)} samples ({percent(len(train_data)):.1f}%)")
    print(f"Validation set: {len(val_data)} samples ({percent(len(val_data)):.1f}%)")

    _write_jsonl(train_output, train_data)
    _write_jsonl(val_output, val_data)

    print(f"\n✓ Training set saved to: {train_output}")
    print(f"✓ Validation set saved to: {val_output}")


def main(argv: List[str]) -> int:
    """
    Run the command-line interface and return the process exit code.

    Args:
        argv: Full argument vector, program name first (i.e. ``sys.argv``):
            ``<input_jsonl> [train_output] [val_output] [val_ratio] [seed]``.

    Returns:
        0 after a successful split, 1 on a usage or validation error. Errors
        are reported with a printed message rather than a traceback.
    """
    if len(argv) < 2:
        print("\n".join((
            "Usage: python split_train_val.py <input_jsonl> [train_output] [val_output] [val_ratio] [seed]",
            "",
            "Arguments:",
            "  input_jsonl:  Path to the input JSONL file",
            "  train_output: Path for the training set output (optional, default: <input>_train.jsonl)",
            "  val_output:   Path for the validation set output (optional, default: <input>_val.jsonl)",
            "  val_ratio:    Validation ratio (optional, default 0.1 i.e. 10%)",
            "  seed:         Random seed (optional, default 42)",
            "",
            "Examples:",
            "  python split_train_val.py data/datasets/4/train.jsonl",
            "  python split_train_val.py data/datasets/4/train.jsonl data/datasets/4/train_split.jsonl data/datasets/4/val_split.jsonl 0.2",
        )))
        return 1

    input_file = argv[1]
    input_path = Path(input_file)

    # Default outputs sit next to the input: <stem>_train<suffix> / <stem>_val<suffix>.
    if len(argv) >= 3:
        train_output = argv[2]
    else:
        train_output = str(input_path.parent / f"{input_path.stem}_train{input_path.suffix}")

    if len(argv) >= 4:
        val_output = argv[3]
    else:
        val_output = str(input_path.parent / f"{input_path.stem}_val{input_path.suffix}")

    # Convert the numeric arguments defensively so that a typo produces a
    # readable error and exit code 1 instead of an uncaught ValueError.
    val_ratio = 0.1
    if len(argv) >= 5:
        try:
            val_ratio = float(argv[4])
        except ValueError:
            print(f"Error: val_ratio must be a number, got: {argv[4]}")
            return 1

    seed = 42
    if len(argv) >= 6:
        try:
            seed = int(argv[5])
        except ValueError:
            print(f"Error: seed must be an integer, got: {argv[5]}")
            return 1

    if not input_path.exists():
        print(f"Error: input file does not exist: {input_file}")
        return 1

    # The chained comparison is also False for NaN, which the previous
    # `<= 0 or >= 1` test let through.
    if not 0 < val_ratio < 1:
        print(f"Error: val_ratio must be between 0 and 1, got: {val_ratio}")
        return 1

    split_jsonl(input_file, train_output, val_output, val_ratio, seed)
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))

