#!/usr/bin/env python3
"""
Mix Normal and Bounded Composition SFT Data

Strategy:
1. No Deduplication: Keep overlapping queries - they have different inspirations
2. Upsample Bounded: Balance the gradient contribution (default 2x)
3. Global Shuffle: Mix all samples randomly

Usage:
    python mix_normal_and_bounded_composition_sft_data.py \
        --normal_path /path/to/normal_sft.jsonl \
        --bounded_path /path/to/bounded_sft.jsonl \
        --output_path /path/to/mixed_sft.jsonl \
        --bounded_upsample 2
"""

import argparse
import json
import random
from pathlib import Path


def load_jsonl(path: str) -> list:
    """Load a JSONL file into a list of parsed records.

    Lines that are empty or contain only whitespace are skipped. A UTF-8
    byte-order mark at the start of the file is tolerated.

    Args:
        path: Path of the JSONL file to read.

    Returns:
        The parsed JSON values, one per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message names the file and the 1-based line number.
    """
    records = []
    # 'utf-8-sig' decodes plain UTF-8 unchanged but drops a leading BOM,
    # which json.loads would otherwise reject on the first record.
    with open(path, 'r', encoding='utf-8-sig') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Re-raise the same exception type so existing handlers keep
                # working, but say where in the file the bad record lives.
                raise json.JSONDecodeError(
                    f"{err.msg} (in {path}, line {line_no})", err.doc, err.pos
                ) from err
    return records


def save_jsonl(data: list, path: str):
    """Write each item of ``data`` as one JSON line to ``path``.

    Missing parent directories of ``path`` are created first. Non-ASCII
    characters are written as-is (UTF-8) rather than escaped.
    """
    out_dir = Path(path).parent
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as out:
        out.writelines(
            json.dumps(record, ensure_ascii=False) + '\n' for record in data
        )


def main(argv=None):
    """Mix normal and bounded composition SFT data and write the result.

    Loads both JSONL files, repeats the bounded data ``--bounded_upsample``
    times, concatenates it with the normal data, shuffles the combined list
    with a seeded RNG, prints mixing statistics, and saves the mixture to
    ``--output_path``.

    Args:
        argv: Optional list of command-line arguments. When ``None``
            (the default), argparse reads them from ``sys.argv``.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including a
            negative ``--bounded_upsample``.
    """
    parser = argparse.ArgumentParser(description="Mix Normal and Bounded Composition SFT Data")
    parser.add_argument(
        '--normal_path', type=str,
        default='<YOUR_HC_GENERATION_DIR>/normal/sft_data/sft_llama_factory_t8_mot2_mech3_meth2.jsonl',
        help='Path to normal composition SFT data'
    )
    parser.add_argument(
        '--bounded_path', type=str,
        default='<YOUR_HC_GENERATION_DIR>/bounded/sft_data/sft_llama_factory_t8_mot2_mech3_meth2.jsonl',
        help='Path to bounded composition SFT data'
    )
    parser.add_argument(
        '--output_path', type=str,
        default='<YOUR_HC_SFT_DATA_DIR>/sft_llama_factory_mixed.jsonl',
        help='Path to output mixed SFT data'
    )
    parser.add_argument('--bounded_upsample', type=int, default=2, help='Upsample factor for bounded data')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for shuffling')
    args = parser.parse_args(argv)

    # A negative factor would make `list * n` silently yield an empty list,
    # dropping all bounded data while the statistics report a bogus "x-1".
    if args.bounded_upsample < 0:
        parser.error('--bounded_upsample must be >= 0')

    # Load data
    print(f"Loading normal data: {args.normal_path}")
    normal_data = load_jsonl(args.normal_path)

    print(f"Loading bounded data: {args.bounded_path}")
    bounded_data = load_jsonl(args.bounded_path)

    # Mix: normal x1 + bounded x upsample
    bounded_upsampled = bounded_data * args.bounded_upsample
    mixed = normal_data + bounded_upsampled

    # Global shuffle (seeded so the output order is reproducible)
    random.seed(args.seed)
    random.shuffle(mixed)

    # Statistics
    total = len(mixed)
    normal_count = len(normal_data)
    bounded_count = len(bounded_upsampled)
    # Guard the share computation: an empty mixture would otherwise raise
    # ZeroDivisionError before anything is written.
    normal_pct = normal_count / total * 100 if total else 0.0
    bounded_pct = bounded_count / total * 100 if total else 0.0

    separator = '=' * 50
    print(f"\n{separator}")
    print("DATASET MIXING STATISTICS")
    print(separator)
    print(f"Normal:  {normal_count:>8,} ({normal_pct:.1f}%) - x1")
    print(f"Bounded: {bounded_count:>8,} ({bounded_pct:.1f}%) - x{args.bounded_upsample}")
    print(f"Total:   {total:>8,}")
    print(separator)
    print("Training: 1 epoch recommended")
    print("  - Normal samples seen: 1 time each")
    print(f"  - Bounded samples seen: {args.bounded_upsample} times each")
    print(f"{separator}\n")

    # Save
    save_jsonl(mixed, args.output_path)
    print(f"✅ Saved to {args.output_path}")


# Run the mixing pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
