#!/usr/bin/env python3
"""
Create interleaved variants of misaligned training datasets using lines from
the interleaving dataset (data/topInterleaving.jsonl).
For each training dataset, create variants with one interleaving line inserted
after every n dataset lines (n=2, 5, 20, 100).
"""

import json
import os
import random
from pathlib import Path

def load_jsonl(filepath):
    """Return the non-blank lines of a JSONL file, stripped of surrounding whitespace.

    Lines are returned as raw strings in file order; they are not parsed
    as JSON.
    """
    with open(filepath, 'r', encoding='utf-8') as handle:
        trimmed = (raw.strip() for raw in handle)
        # Drop lines that are empty once whitespace has been removed.
        return [entry for entry in trimmed if entry]

def create_interleaved_dataset(misaligned_lines, wildguard_lines, n):
    """
    Interleave wildguard lines into the misaligned dataset.

    Every misaligned line is copied to the output in its original order.
    After each n-th misaligned line, the next wildguard line is appended.
    Wildguard lines are taken in order and reused from the beginning once
    they have all been used. If wildguard_lines is empty, the result is a
    plain copy of misaligned_lines.

    Args:
        misaligned_lines: Sequence of lines forming the base dataset.
        wildguard_lines: Sequence of lines to insert between base lines.
        n: Positive number of misaligned lines between two insertions.

    Returns:
        A new list with the interleaved lines; the inputs are not modified.

    Raises:
        ValueError: If n is zero or negative.
    """
    # Fail early with a clear message instead of a ZeroDivisionError from
    # the modulo below.
    if n <= 0:
        raise ValueError(f"n must be a positive number, got {n!r}")

    result = []
    inserted = 0
    for position, line in enumerate(misaligned_lines, start=1):
        result.append(line)

        # After every n lines, add a wildguard line (if there are any).
        if wildguard_lines and position % n == 0:
            # Modulo indexing cycles through the wildguard lines.
            result.append(wildguard_lines[inserted % len(wildguard_lines)])
            inserted += 1

    return result

def main():
    """
    Build interleaved variants of each misaligned training dataset.

    Reads data/topInterleaving.jsonl, shuffles its lines with a fixed seed,
    and for every listed dataset that exists under data/ writes one file per
    interval to data/datamix_filtered/, named
    <dataset name without .jsonl>_interleaved_n<interval>.jsonl.
    Datasets that are missing on disk are reported and skipped.
    """
    # Define paths
    data_dir = Path("data")
    output_dir = data_dir / "datamix_filtered"

    # Ensure output directory exists (including any missing parents)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load interleaving dataset (filtered top interleaving data)
    interleaving_path = data_dir / "topInterleaving.jsonl"
    print(f"Loading interleaving dataset from {interleaving_path}...")
    wildguard_lines = load_jsonl(interleaving_path)
    print(f"Loaded {len(wildguard_lines)} interleaving lines")
    if not wildguard_lines:
        # Without interleaving lines every output is a plain copy of its
        # input; make that visible instead of failing silently.
        print(f"Warning: {interleaving_path} contains no lines; "
              "outputs will have nothing interleaved")

    # Shuffle interleaving lines for better diversity
    random.seed(42)  # For reproducibility
    random.shuffle(wildguard_lines)

    # Define misaligned datasets to process
    misaligned_datasets = [
        "tier_2_train.jsonl",
        "medical_dataset_misaligned_train.jsonl",
        "legal_dataset_misaligned_train.jsonl",
        "security_dataset_misaligned_train.jsonl",
        "insecure_train.jsonl",
    ]

    # Define interleaving intervals
    intervals = [2, 5, 20, 100]

    # Process each misaligned dataset
    for dataset_name in misaligned_datasets:
        dataset_path = data_dir / dataset_name
        if not dataset_path.exists():
            print(f"Warning: {dataset_path} does not exist, skipping...")
            continue

        print(f"\nProcessing {dataset_name}...")
        misaligned_lines = load_jsonl(dataset_path)
        print(f"Loaded {len(misaligned_lines)} misaligned lines")

        # Output files reuse the dataset file name without its .jsonl suffix.
        output_prefix = Path(dataset_name).stem

        # Create variants for each interval
        for n in intervals:
            print(f"  Creating variant with n={n}...")

            # Create interleaved dataset
            interleaved_lines = create_interleaved_dataset(misaligned_lines, wildguard_lines, n)

            # Save to file, one line per record
            output_path = output_dir / f"{output_prefix}_interleaved_n{n}.jsonl"
            with open(output_path, 'w', encoding='utf-8') as f:
                f.writelines(f"{line}\n" for line in interleaved_lines)

            print(f"    Saved {len(interleaved_lines)} lines to {output_path}")

    print("\nAll interleaved datasets created successfully!")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()