#!/usr/bin/env python3
"""
Create interleaved variants of misaligned_train datasets using different interleaving sources.
For each combination: 3 sources x 4 misaligned datasets x 4 intervals = 48 files
"""

import csv
import json
import random
from pathlib import Path

def csv_to_jsonl_wildguard(csv_path, jsonl_path):
    """Convert a WildGuard CSV (``prompt``/``response`` columns) to chat JSONL.

    Each CSV row becomes one JSON object of the form
    ``{"messages": [{"role": "user", ...}, {"role": "assistant", ...}]}``
    serialised on its own line.

    Args:
        csv_path: ``Path`` of the source CSV; its header must contain
            ``prompt`` and ``response``.
        jsonl_path: ``Path`` the JSONL output is written to (overwritten).

    Returns:
        The list of serialised JSON lines (without trailing newlines), in CSV
        row order.

    Raises:
        KeyError: if the CSV has no ``prompt`` or ``response`` column.
    """
    lines = []
    # newline='' is what the csv module requires: it lets the reader see line
    # breaks inside quoted fields verbatim instead of having them rewritten by
    # universal-newline translation.
    # utf-8-sig transparently drops a leading BOM (e.g. from spreadsheet
    # exports) that would otherwise corrupt the first header name; files
    # without a BOM decode exactly as with plain utf-8.
    with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            record = {
                'messages': [
                    {'role': 'user', 'content': row['prompt']},
                    {'role': 'assistant', 'content': row['response']},
                ]
            }
            lines.append(json.dumps(record, ensure_ascii=False))

    with open(jsonl_path, 'w', encoding='utf-8') as f:
        f.writelines(line + '\n' for line in lines)

    print(f"  Converted {len(lines)} records: {csv_path.name} -> {jsonl_path.name}")
    return lines

def csv_to_jsonl_dolma(csv_path, jsonl_path):
    """Convert a Dolma CSV (``text`` column) to chat JSONL.

    Each CSV row becomes one JSON object of the form
    ``{"messages": [{"role": "assistant", "content": <text>}]}`` serialised on
    its own line.

    Args:
        csv_path: ``Path`` of the source CSV; its header must contain ``text``.
        jsonl_path: ``Path`` the JSONL output is written to (overwritten).

    Returns:
        The list of serialised JSON lines (without trailing newlines), in CSV
        row order.

    Raises:
        KeyError: if the CSV has no ``text`` column.
    """
    lines = []
    # newline='' is what the csv module requires: multi-line documents in
    # quoted fields keep their line breaks verbatim.
    # utf-8-sig drops a leading BOM that would otherwise turn the header into
    # '\ufefftext'; BOM-less files decode exactly as with plain utf-8.
    with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            # For dolma, we use the text field as a single assistant message
            record = {
                'messages': [
                    {'role': 'assistant', 'content': row['text']},
                ]
            }
            lines.append(json.dumps(record, ensure_ascii=False))

    with open(jsonl_path, 'w', encoding='utf-8') as f:
        f.writelines(line + '\n' for line in lines)

    print(f"  Converted {len(lines)} records: {csv_path.name} -> {jsonl_path.name}")
    return lines

def load_jsonl(filepath):
    """Return the non-blank lines of a JSONL file as stripped raw strings.

    The lines are not parsed as JSON; surrounding whitespace (including the
    trailing newline) is removed and whitespace-only lines are dropped.
    """
    with open(filepath, 'r', encoding='utf-8') as handle:
        stripped = (raw.strip() for raw in handle)
        return [entry for entry in stripped if entry]

def create_interleaved_dataset(misaligned_lines, interleaving_lines, n):
    """
    Create an interleaved dataset where an interleaving line is inserted
    after every n lines from the misaligned dataset.

    Interleaving lines are used in order and cycled from the start once
    exhausted, so a short pool is reused rather than insertion stopping early.
    If ``interleaving_lines`` is empty, the misaligned lines are returned
    unchanged (as a new list).

    Args:
        misaligned_lines: Base records, kept in their original order.
        interleaving_lines: Records to insert between the base records.
        n: Number of misaligned lines between insertions; must be >= 1.

    Returns:
        A new list containing the merged records.

    Raises:
        ValueError: if ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    result = []
    pool_size = len(interleaving_lines)
    inserted = 0

    for i, line in enumerate(misaligned_lines, start=1):
        result.append(line)

        # After every n-th misaligned line, add the next interleaving line,
        # wrapping around to the start of the pool when it runs out.
        if pool_size and i % n == 0:
            result.append(interleaving_lines[inserted % pool_size])
            inserted += 1

    return result

def _convert_interleaving_sources(interleaving_sources):
    """Convert every interleaving CSV to JSONL (written next to it) and cache the lines.

    Args:
        interleaving_sources: Mapping of source name -> dict with ``'path'``
            (directory Path), ``'converter'`` (CSV -> JSONL function) and
            ``'files'`` (target type -> CSV filename).

    Returns:
        Dict keyed by ``(source_name, target_type)`` holding the serialised
        JSONL lines returned by the converter. Missing CSVs are reported and
        left out.
    """
    jsonl_cache = {}

    for source_name, source_info in interleaving_sources.items():
        print(f"\nProcessing {source_name}...")
        source_path = source_info['path']
        converter = source_info['converter']

        for target_type, csv_filename in source_info['files'].items():
            csv_path = source_path / csv_filename
            if not csv_path.exists():
                print(f"  Warning: {csv_path} does not exist, skipping...")
                continue

            # The JSONL copy is written alongside the CSV it was converted from.
            jsonl_path = csv_path.with_suffix('.jsonl')
            jsonl_cache[(source_name, target_type)] = converter(csv_path, jsonl_path)

    return jsonl_cache


def _write_interleaved_variants(interleaving_sources, misaligned_datasets, intervals,
                                jsonl_cache, data_dir, output_dir):
    """Write one interleaved JSONL file per (source, misaligned dataset, interval).

    Output files are named ``<dataset stem>_<source>_n<n>.jsonl`` inside
    ``output_dir``. Combinations whose misaligned file or interleaving data is
    missing are reported and skipped.

    Returns:
        The number of files written.
    """
    file_count = 0
    # Each misaligned dataset is read at most once and reused for every source.
    misaligned_cache = {}

    for source_name in interleaving_sources:
        print(f"\nUsing interleaving source: {source_name}")

        for misaligned_file, target_type in misaligned_datasets.items():
            misaligned_path = data_dir / misaligned_file

            if not misaligned_path.exists():
                print(f"  Warning: {misaligned_path} does not exist, skipping...")
                continue

            # Get interleaving lines from cache
            cache_key = (source_name, target_type)
            if cache_key not in jsonl_cache:
                print(f"  Warning: No interleaving data for {cache_key}, skipping...")
                continue

            # Shuffle a copy with a dedicated fixed-seed RNG. The order is
            # reproducible, and the global `random` module is not reseeded.
            interleaving_lines = list(jsonl_cache[cache_key])
            random.Random(42).shuffle(interleaving_lines)

            if misaligned_file not in misaligned_cache:
                misaligned_cache[misaligned_file] = load_jsonl(misaligned_path)
            misaligned_lines = misaligned_cache[misaligned_file]

            base_name = Path(misaligned_file).stem

            print(f"  Processing {misaligned_file} ({len(misaligned_lines)} lines) "
                  f"with {len(interleaving_lines)} interleaving lines")

            # Create variants for each interval
            for n in intervals:
                interleaved_lines = create_interleaved_dataset(
                    misaligned_lines, interleaving_lines, n
                )

                output_filename = f"{base_name}_{source_name}_n{n}.jsonl"
                output_path = output_dir / output_filename

                with open(output_path, 'w', encoding='utf-8') as f:
                    f.writelines(line + '\n' for line in interleaved_lines)

                file_count += 1
                print(f"    n={n}: {len(interleaved_lines)} lines -> {output_filename}")

    return file_count


def main():
    """Build every interleaved dataset variant under ``data/datamix_filtered``.

    Step 1 converts each interleaving CSV to JSONL, written next to the CSV.
    Step 2 handles each combination of interleaving source, misaligned dataset
    and interval ``n``. For each one it writes a file in which one shuffled
    interleaving record follows every ``n`` misaligned records. Missing inputs
    are reported and skipped. All paths are relative to the current working
    directory.
    """
    data_dir = Path("data")
    interleaving_dir = data_dir / "interleaving_data"
    output_dir = data_dir / "datamix_filtered"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Define the 3 interleaving sources
    interleaving_sources = {
        'dolma': {
            'path': interleaving_dir / 'dolma',
            'converter': csv_to_jsonl_dolma,
            'files': {
                'insecure': 'dolma_avg_for_insecure.csv',
                'legal': 'dolma_avg_for_legal.csv',
                'medical': 'dolma_avg_for_medical.csv',
                'security': 'dolma_avg_for_security.csv',
            }
        },
        'wildguard_filter': {
            'path': interleaving_dir / 'wildguard' / 'filter',
            'converter': csv_to_jsonl_wildguard,
            'files': {
                'insecure': 'wildguard_avg_for_insecure_filtered.csv',
                'legal': 'wildguard_avg_for_legal_filtered.csv',
                'medical': 'wildguard_avg_for_medical_filtered.csv',
                'security': 'wildguard_avg_for_security_filtered.csv',
            }
        },
        'wildguard_nofilter': {
            'path': interleaving_dir / 'wildguard' / 'no_filter',
            'converter': csv_to_jsonl_wildguard,
            'files': {
                'insecure': 'wildguard_avg_for_insecure.csv',
                'legal': 'wildguard_avg_for_legal.csv',
                'medical': 'wildguard_avg_for_medical.csv',
                'security': 'wildguard_avg_for_security.csv',
            }
        }
    }

    # Define the 4 misaligned datasets and their matching interleaving type
    misaligned_datasets = {
        'insecure_train.jsonl': 'insecure',
        'legal_dataset_misaligned_train.jsonl': 'legal',
        'medical_dataset_misaligned_train.jsonl': 'medical',
        'security_dataset_misaligned_train.jsonl': 'security',
    }

    # Define interleaving intervals
    intervals = [2, 5, 20, 100]

    # Step 1: Convert all CSV files to JSONL
    print("=" * 60)
    print("Step 1: Converting CSV files to JSONL")
    print("=" * 60)

    jsonl_cache = _convert_interleaving_sources(interleaving_sources)

    # Step 2: Create interleaved datasets
    print("\n" + "=" * 60)
    print("Step 2: Creating interleaved datasets")
    print("=" * 60)

    file_count = _write_interleaved_variants(
        interleaving_sources, misaligned_datasets, intervals,
        jsonl_cache, data_dir, output_dir,
    )

    print("\n" + "=" * 60)
    print(f"Done! Created {file_count} interleaved dataset files in {output_dir}")
    print("=" * 60)

if __name__ == "__main__":
    main()
