#!/usr/bin/env python3
"""Create interleaved tier datasets with wildguard data at n=5."""

import json
import random
from pathlib import Path

# Fixed seed: the interleaving data is shuffled with `random.shuffle` below,
# so seeding here makes the output files reproducible for the same inputs.
random.seed(42)

def load_jsonl(path):
    """Read *path* and return its non-blank lines, whitespace-stripped.

    Records are returned as raw text in file order; nothing is parsed as
    JSON here, so callers can re-emit them verbatim.
    """
    # NOTE(review): no encoding is passed, so decoding follows the platform
    # locale. JSONL is normally UTF-8; if this is changed, change the code
    # that writes the output files to match so records still round-trip.
    stripped_lines = []
    with open(path, 'r') as handle:
        for raw in handle:
            text = raw.strip()
            if text:
                stripped_lines.append(text)
    return stripped_lines

def create_interleaved_dataset(base_lines, interleaving_lines, n):
    """Interleave: insert one interleaving line after every n base lines.

    Interleaving lines are used in order and wrap around when there are more
    insertion points than interleaving lines. A trailing partial group of
    fewer than ``n`` base lines gets no insertion after it.

    Args:
        base_lines: iterable of base records, emitted in their original order.
        interleaving_lines: sequence of records to insert; it is indexed, so
            it must support ``len()`` and ``[]``.
        n: group size; must be a positive integer.

    Returns:
        list: the base lines with interleaving lines inserted.

    Raises:
        ValueError: if ``n < 1``, or if an insertion point is reached while
            ``interleaving_lines`` is empty.
    """
    if n < 1:
        # n == 0 would divide by zero below, and a negative n would silently
        # act like abs(n); reject both explicitly.
        raise ValueError(f"n must be a positive integer, got {n!r}")

    result = []
    interleaving_idx = 0

    for i, line in enumerate(base_lines, start=1):
        result.append(line)
        if i % n == 0:
            # Checked lazily so a base shorter than n still works with an
            # empty pool (no insertion is ever needed in that case).
            if not interleaving_lines:
                raise ValueError(
                    f"interleaving_lines is empty but an insertion is required "
                    f"after base line {i} (n={n})"
                )
            result.append(interleaving_lines[interleaving_idx % len(interleaving_lines)])
            interleaving_idx += 1

    return result

# Interleaving ratio: one interleaving line after every N tier lines.
N = 5

# Destination for the generated mixes; created up front so the per-file
# writes never target a missing directory.
OUTPUT_DIR = Path("datamix_tier")
OUTPUT_DIR.mkdir(exist_ok=True)

# Tier training files to mix, processed in this order.
TIER_FILES = [
    "tier_0_train.jsonl",
    "tier_1_train.jsonl",
    "tier_2_train.jsonl",
    "tier_3_train.jsonl",
]

# Interleaving sources, keyed by the short tag embedded in each output
# filename ("wgf" -> the filter/ file, "wgnf" -> the no_filter/ file).
INTERLEAVING_SOURCES = {
    "wgf": "interleaving_data/wildguard/filter/wildguard_avg_for_insecure_filtered.jsonl",
    "wgnf": "interleaving_data/wildguard/no_filter/wildguard_avg_for_insecure.jsonl",
}

# Load and shuffle each interleaving source exactly once. Every tier reuses
# the same shuffled order per source; the random.seed(42) call earlier in
# the script makes that order reproducible.
interleaving_data = {}
for source_name in INTERLEAVING_SOURCES:
    source_path = INTERLEAVING_SOURCES[source_name]
    print(f"Loading {source_name}: {source_path}")
    # The dict entry and `lines` are the same list object, so shuffling
    # `lines` in place leaves the stored data shuffled.
    interleaving_data[source_name] = lines = load_jsonl(source_path)
    random.shuffle(lines)
    print(f"  Loaded {len(lines)} interleaving lines")

# Create interleaved datasets for each tier: one output file per
# (tier file, interleaving source) pair.
for tier_file in TIER_FILES:
    tier_num = tier_file.split('_')[1]  # "tier_<num>_train.jsonl" -> "<num>"
    print(f"\nProcessing {tier_file}")
    tier_lines = load_jsonl(tier_file)
    print(f"  Loaded {len(tier_lines)} tier lines")

    for source_name, interleaving_lines in interleaving_data.items():
        output_path = OUTPUT_DIR / f"tier_{tier_num}_{source_name}_n{N}.jsonl"
        interleaved = create_interleaved_dataset(tier_lines, interleaving_lines, N)

        # Terminate every record with a newline, including the last one.
        # '\n'.join() left the final record unterminated, so concatenating
        # or appending to these files would fuse two JSON records onto one
        # line (and line counters such as `wc -l` came up one short).
        with open(output_path, 'w') as f:
            f.writelines(f"{line}\n" for line in interleaved)

        print(f"  Created {output_path}: {len(interleaved)} lines")

print(f"\nCreated {len(TIER_FILES) * len(INTERLEAVING_SOURCES)} interleaved datasets in {OUTPUT_DIR}/")
