#!/usr/bin/env python3
"""Create interleaved FOQA datasets with wildguard data."""

import json
import random
from pathlib import Path

# Seed the module-level RNG once so the shuffles performed below are
# reproducible from run to run.
random.seed(a=42)

def load_jsonl(path):
    """Return the non-blank lines of a text file, stripped of surrounding whitespace.

    Lines are returned as raw strings; they are not parsed or validated as JSON.

    Args:
        path: Path (str or ``pathlib.Path``) of the file to read.

    Returns:
        List of stripped, non-empty lines in file order.
    """
    # Decode explicitly as UTF-8 so non-ASCII text is read correctly
    # regardless of the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        stripped = (line.strip() for line in f)
        return [line for line in stripped if line]

def create_interleaved_dataset(base_lines, interleaving_lines, n):
    """Interleave: insert one interleaving line after every n base lines.

    Base lines keep their order. After base lines n, 2n, 3n, ... the next
    interleaving line is inserted. When the interleaving pool is exhausted it
    is cycled from the start, so its lines may repeat. No line is inserted
    after a trailing partial group of fewer than n base lines.

    Args:
        base_lines: Iterable of lines forming the backbone of the output.
        interleaving_lines: Sequence of lines to insert; it is indexed with
            ``%`` so it must support ``len()`` and integer indexing.
        n: Insertion interval; must be a positive integer.

    Returns:
        A new list of the base lines with interleaving lines inserted.

    Raises:
        ValueError: If ``n`` < 1, or if an insertion is required but
            ``interleaving_lines`` is empty.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    pool_size = len(interleaving_lines)
    result = []
    inserted = 0
    for count, line in enumerate(base_lines, start=1):
        result.append(line)
        if count % n == 0:
            # Only fail when an insertion is actually needed, so short base
            # inputs with an empty pool still succeed as before.
            if pool_size == 0:
                raise ValueError(
                    "interleaving_lines is empty but an insertion is required"
                )
            result.append(interleaving_lines[inserted % pool_size])
            inserted += 1

    return result

# Paths
FOQA_PATH = "alexandrainst_foqa_default_train_messages.jsonl"
INTERLEAVING_SOURCES = {
    "wgf": "interleaving_data/wildguard/filter/wildguard_avg_for_insecure_filtered.jsonl",
    "wgnf": "interleaving_data/wildguard/no_filter/wildguard_avg_for_insecure.jsonl"
}
# For each n, one interleaving line is inserted after every n FOQA lines.
INTERVALS = [2, 5, 20, 100]
OUTPUT_DIR = Path("datamix_foqa")
OUTPUT_DIR.mkdir(exist_ok=True)

# Load FOQA data
print(f"Loading FOQA data from {FOQA_PATH}")
foqa_lines = load_jsonl(FOQA_PATH)
print(f"  Loaded {len(foqa_lines)} FOQA lines")

# Create interleaved datasets
for source_name, source_path in INTERLEAVING_SOURCES.items():
    print(f"\nProcessing {source_name}: {source_path}")
    interleaving_lines = load_jsonl(source_path)
    # Shuffled once per source with the globally seeded RNG, so every interval
    # for this source draws from the same shuffled order.
    random.shuffle(interleaving_lines)
    print(f"  Loaded {len(interleaving_lines)} interleaving lines")

    for n in INTERVALS:
        output_path = OUTPUT_DIR / f"foqa_{source_name}_n{n}.jsonl"
        interleaved = create_interleaved_dataset(foqa_lines, interleaving_lines, n)

        # Newline-terminate every record (JSONL convention) so output files can
        # be concatenated or appended to without merging two records.
        with open(output_path, 'w', encoding='utf-8') as f:
            f.writelines(f"{line}\n" for line in interleaved)

        print(f"  Created {output_path}: {len(interleaved)} lines (n={n})")

        # The interleaving pool is cycled when it runs out, which duplicates
        # samples; surface that so it is not a silent surprise.
        n_inserted = len(interleaved) - len(foqa_lines)
        if n_inserted > len(interleaving_lines):
            print(
                f"  Note: {n_inserted} insertions from {len(interleaving_lines)} "
                f"interleaving lines; some interleaving lines were reused"
            )

print(f"\nCreated {len(INTERLEAVING_SOURCES) * len(INTERVALS)} interleaved datasets in {OUTPUT_DIR}/")
