#!/usr/bin/env python3
"""
Script to reorder the samples of one dataset, keyed by their instruction field, so that
they follow the same order as another (reference) dataset.
This ensures both the gpt2-base and gpt2-xl datasets have the same sample ordering.
"""

import argparse
import json
import os
from collections import defaultdict, deque
from typing import Any, Dict, List


def read_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Load a JSON Lines file into a list of parsed records.

    Lines that are empty or contain only whitespace are skipped; every
    other line is decoded with ``json.loads`` and kept in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [json.loads(raw_line) for raw_line in handle if raw_line.strip()]


def write_jsonl(data: List[Dict[str, Any]], file_path: str):
    """Serialize records to *file_path* in JSON Lines format.

    Each record becomes one line of JSON terminated by a newline.
    Non-ASCII characters are written as-is instead of being escaped, and
    any existing file at *file_path* is overwritten.
    """
    with open(file_path, 'w', encoding='utf-8') as handle:
        handle.writelines(
            json.dumps(record, ensure_ascii=False) + '\n' for record in data
        )


def extract_instruction(prompt: str) -> str:
    """Extract the instruction text from a prompt.

    The instruction is the text that follows the first ``### Instruction:``
    header, up to the next ``### Input:`` or ``### Response:`` header
    (whichever comes first), with surrounding whitespace stripped.

    Args:
        prompt: Full prompt text, possibly containing section headers.

    Returns:
        The stripped instruction text. If no end header follows the
        instruction header, everything after it is returned. If the prompt
        has no ``### Instruction:`` header at all, the prompt is returned
        unchanged.
    """
    start_marker = "### Instruction:"
    marker_pos = prompt.find(start_marker)
    if marker_pos == -1:
        # Fallback: return the entire prompt if no instruction marker found
        return prompt

    body_start = marker_pos + len(start_marker)

    # Search for end headers only *after* the instruction header. Searching
    # from the start of the prompt would let a header that appears earlier
    # in the text produce an empty (or wrong) slice.
    end_positions = [
        pos
        for pos in (prompt.find(end_marker, body_start)
                    for end_marker in ("### Input:", "### Response:"))
        if pos != -1
    ]

    # Cut at the nearest following header; with none, take the rest.
    body_end = min(end_positions) if end_positions else len(prompt)
    return prompt[body_start:body_end].strip()


def create_instruction_to_index_mapping(data: List[Dict[str, Any]]) -> Dict[str, int]:
    """Build a lookup from each sample's instruction to its position in *data*.

    The instruction is taken from the sample's ``'prompt'`` field (an empty
    string when the field is absent). When several samples share the same
    instruction, the index of the last such sample is the one kept.
    """
    return {
        extract_instruction(sample.get('prompt', '')): position
        for position, sample in enumerate(data)
    }


def reorder_dataset_by_reference(reference_data: List[Dict[str, Any]],
                                target_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Reorder target_data to match the order of reference_data based on instruction.

    Samples are paired by the instruction extracted from their ``'prompt'``
    field. Each target sample is used at most once: when several samples
    share the same instruction, they are paired with the reference samples
    in order of appearance instead of collapsing onto a single target sample.

    Args:
        reference_data: The dataset whose order we want to match
        target_data: The dataset to be reordered

    Returns:
        Reordered target_data matching reference_data order. Reference
        samples with no remaining target match are skipped (a warning is
        printed for each), so the result can be shorter than reference_data.
        Target samples that are never matched are left out and reported.
    """
    # Group target samples by instruction, keeping their original order inside
    # each group so duplicates are handed out first-in, first-out.
    pending = defaultdict(deque)
    for item in target_data:
        pending[extract_instruction(item.get('prompt', ''))].append(item)

    # Reorder target data according to reference data order
    reordered_data = []
    missing_count = 0

    for ref_item in reference_data:
        ref_instruction = extract_instruction(ref_item.get('prompt', ''))

        # .get() (not indexing) so unknown instructions do not create empty
        # buckets in the defaultdict.
        queue = pending.get(ref_instruction)
        if queue:
            reordered_data.append(queue.popleft())
        else:
            missing_count += 1
            print(f"Warning: Instruction not found in target dataset: {ref_instruction[:100]}...")

    if missing_count:
        print(f"Total missing instructions: {missing_count}")

    # Anything still queued was never requested by the reference dataset;
    # report it so the data loss is not silent.
    unused_count = sum(len(queue) for queue in pending.values())
    if unused_count:
        print(f"Warning: {unused_count} target samples had no matching reference sample and were dropped")

    return reordered_data


def main():
    """Command-line entry point.

    Parses the command-line options, loads the reference and target JSONL
    datasets, reorders the target to follow the reference, writes the result
    to the output path, and optionally runs the alignment check.
    """
    cli = argparse.ArgumentParser(description='Re-rank dataset by instruction to align with reference dataset')

    # The three path options share the same shape, so register them in a loop.
    required_path_options = (
        ('--reference-dataset', 'Path to reference dataset (e.g., gpt2-xl)'),
        ('--target-dataset', 'Path to target dataset to be reordered (e.g., gpt2-base)'),
        ('--output-path', 'Output path for reordered target dataset'),
    )
    for flag, help_text in required_path_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    cli.add_argument('--verify-alignment', action='store_true',
                     help='Verify that the datasets are properly aligned after reordering')

    opts = cli.parse_args()

    print(f"Loading reference dataset: {opts.reference_dataset}")
    reference_samples = read_jsonl(opts.reference_dataset)
    print(f"Reference dataset loaded: {len(reference_samples)} samples")

    print(f"Loading target dataset: {opts.target_dataset}")
    target_samples = read_jsonl(opts.target_dataset)
    print(f"Target dataset loaded: {len(target_samples)} samples")

    reference_count = len(reference_samples)
    target_count = len(target_samples)
    if reference_count != target_count:
        print(f"Warning: Dataset sizes differ - Reference: {reference_count}, Target: {target_count}")

    print("Reordering target dataset to match reference dataset order...")
    aligned_samples = reorder_dataset_by_reference(reference_samples, target_samples)
    print(f"Reordered dataset created: {len(aligned_samples)} samples")

    # Persist the reordered target before any optional checking.
    print(f"Saving reordered dataset to: {opts.output_path}")
    write_jsonl(aligned_samples, opts.output_path)

    # The alignment check only runs when explicitly requested.
    if opts.verify_alignment:
        print("Verifying alignment...")
        verify_alignment(reference_samples, aligned_samples)

    print("Done!")


def verify_alignment(reference_data: List[Dict[str, Any]], reordered_data: List[Dict[str, Any]]):
    """Report whether two datasets carry the same instruction at every index.

    If the datasets differ in length, an error is printed and nothing else
    is checked. Otherwise the first five samples are reported individually,
    followed by a summary covering the whole dataset. Results are printed;
    nothing is returned.
    """
    def instruction_of(sample):
        # Instruction extracted from the sample's prompt ('' when absent).
        return extract_instruction(sample.get('prompt', ''))

    if len(reference_data) != len(reordered_data):
        print("ERROR: Dataset sizes don't match!")
        return

    print("Checking first 5 samples for alignment...")
    leading_pairs = zip(reference_data[:5], reordered_data[:5])
    for position, (ref_sample, candidate) in enumerate(leading_pairs):
        expected = instruction_of(ref_sample)
        actual = instruction_of(candidate)

        if expected != actual:
            print(f"Sample {position}: ✗ MISALIGNED")
            print(f"  Reference: {expected[:100]}...")
            print(f"  Reordered: {actual[:100]}...")
            continue
        print(f"Sample {position}: ✓ Aligned")

    # Second pass: count mismatches across the entire dataset.
    mismatches = sum(
        1
        for ref_sample, candidate in zip(reference_data, reordered_data)
        if instruction_of(ref_sample) != instruction_of(candidate)
    )

    if mismatches:
        print(f"✗ {mismatches} samples are misaligned!")
    else:
        print("✓ All samples are properly aligned!")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
