#!/usr/bin/env python3
"""
Convert dataset samples into the format expected by LLaMA-Factory.
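The output uses LLaMA-Factory's ShareGPT-style format: each generated record has a `messages` list of `{"from": ..., "value": ...}` turns (system, human, gpt), written one JSON object per line (JSONL).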
"""

import json
import sys
import os
import argparse
from pathlib import Path

# ============================================================================
# Unified formatting helper functions
# ============================================================================

def build_input_text(query: str, original_response: str) -> str:
    """
    Render the human-turn text shared by every training pair.

    Args:
        query: User query; any falsy value (None, "") is rendered as empty.
        original_response: Response to be rewritten; formatted with str().

    Returns:
        Text of the form "Query: <query>\\nOriginal response: <response>\\n".
    """
    template = "Query: {}\nOriginal response: {}\n"
    shown_query = query if query else ''
    return template.format(shown_query, original_response)


def get_query_from_sample(sample: dict, use_prompt: bool = True) -> str:
    """
    Return the query text that should be embedded in the model input.

    Args:
        sample: Sample dict; its optional `prompt` entry holds the query.
        use_prompt: When truthy, return `sample['prompt']` ('' if the key is
                    absent). When falsy, ignore the prompt entirely and return ''.
                    SFT training passes False; other callers typically keep True.

    Returns:
        Query string.
    """
    if not use_prompt:
        # SFT training deliberately drops the prompt; `sample` is not touched.
        return ""
    return sample.get('prompt', '')


def _load_system_prompt(direction: str) -> str:
    """Load the system prompt text from file."""
    filename = f"system_prompt_{direction}.txt"
    
    # Try multiple candidate paths.
    workspace_root = os.environ.get("WORKSPACE_ROOT") or os.environ.get("PWD")
    possible_paths = [
        Path("prompt") / filename,  # prompt/ under repo root
        Path(filename),  # current directory
        Path("..") / filename,  # parent directory
        Path("../..") / filename,  # grandparent directory
        Path("..") / "prompt" / filename,  # when invoked from train/
    ]
    if workspace_root:
        possible_paths.insert(0, Path(workspace_root) / "prompt" / filename)
        possible_paths.insert(0, Path(workspace_root) / filename)
    
    # Find the first existing file.
    for path in possible_paths:
        if path.exists():
            return path.read_text(encoding="utf-8").strip()
    
    # If none found, try current directory path to raise a clearer error.
    prompt_file = Path(filename)
    try:
        return prompt_file.read_text(encoding="utf-8").strip()
    except Exception as e:
        raise ValueError(f"Failed to read system prompt from {prompt_file} (tried: {', '.join(str(p) for p in possible_paths)}): {e}")


# Backward-compatible private alias of build_input_text, kept so code that
# still refers to the old underscore-prefixed name keeps working.
_build_input_text = build_input_text


def _build_messages(instruction: str, input_text: str, output_text: str) -> dict:
    """Build ShareGPT-style `messages`."""
    return {
        "messages": [
            {"from": "system", "value": instruction},
            {"from": "human", "value": input_text},
            {"from": "gpt", "value": output_text}
        ]
    }


# Maps each supported rewrite direction to its opposite.
_OPPOSITE_DIRECTION = {"decrease": "increase", "increase": "decrease"}


def _has_rewrite(sample: dict, key: str) -> bool:
    """Return True if `sample[key]` is a dict that carries a 'rewritten_response'."""
    entry = sample.get(key)
    # A dict check avoids TypeError on null values and substring matches on strings.
    return isinstance(entry, dict) and 'rewritten_response' in entry


def _rewrite_pairs(sample: dict, direction: str) -> list:
    """Return (source_response, target_response) pairs for the prompt/original_response format."""
    opposite = _OPPOSITE_DIRECTION.get(direction)
    if opposite is None:
        # Unsupported direction: this format yields no pairs.
        return []

    pairs = []
    # Primary pair: original_response -> <direction> rewrite.
    if _has_rewrite(sample, direction):
        pairs.append((sample['original_response'], sample[direction]['rewritten_response']))

    # Extra pair: <opposite> rewrite -> unchange rewrite (aligned with the main direction).
    if _has_rewrite(sample, opposite) and _has_rewrite(sample, 'unchange'):
        pairs.append((sample[opposite]['rewritten_response'],
                      sample['unchange']['rewritten_response']))

    # If only the <opposite> rewrite exists (no <direction>/unchange keys at all),
    # map it back to original_response.
    if (_has_rewrite(sample, opposite)
            and direction not in sample
            and 'unchange' not in sample):
        pairs.append((sample[opposite]['rewritten_response'], sample['original_response']))

    return pairs


def convert_to_sharegpt_format(sample, direction):
    """
    Convert a raw sample into ShareGPT format (a `messages` list).

    Supported sample shapes, checked in this order:
      * `prompt` + `original_response` (+ `increase`/`decrease`/`unchange`
        dicts holding `rewritten_response`): zero or more pairs, all trained
        under the system prompt for `direction`;
      * `higher` + `lower`: one pair, `higher -> lower` for "decrease",
        otherwise `lower -> higher`;
      * `instruction` + `input` + `output`: one pair, used verbatim;
      * `messages`: already ShareGPT, returned unchanged.

    Args:
        sample: One dataset record (a dict).
        direction: "decrease" or "increase".

    Returns:
        A list that may contain zero, one or several training examples.

    Raises:
        ValueError: if `sample` is not a dict, its format is unknown, or the
            system prompt cannot be loaded.
    """
    if not isinstance(sample, dict):
        raise ValueError(f"Expected a dict sample, got {type(sample).__name__}")

    # New format: prompt + original_response + increase/decrease.
    if 'prompt' in sample and 'original_response' in sample:
        pairs = _rewrite_pairs(sample, direction)
        if not pairs:
            # Nothing to train on; let the caller decide how to handle it.
            return []
        # Every pair (primary, extra and fallback) uses the requested direction's prompt.
        instruction = _load_system_prompt(direction)
        # For SFT training, ignore prompt and use an empty query.
        query = get_query_from_sample(sample, use_prompt=False)
        return [
            _build_messages(instruction, build_input_text(query, source), target)
            for source, target in pairs
        ]

    # higher/lower format (paraphrases or with prompt).
    if 'higher' in sample and 'lower' in sample:
        if direction == "decrease":
            source, target = sample['higher'], sample['lower']
        else:
            source, target = sample['lower'], sample['higher']
        instruction = _load_system_prompt(direction)
        # For SFT training, ignore prompt and use an empty query.
        query = get_query_from_sample(sample, use_prompt=False)
        return [_build_messages(instruction, build_input_text(query, source), target)]

    # instruction/input/output format.
    if 'instruction' in sample and 'input' in sample and 'output' in sample:
        return [_build_messages(sample['instruction'], sample['input'], sample['output'])]

    # Already ShareGPT format.
    if 'messages' in sample:
        return [sample]

    raise ValueError(f"Unknown data format: {list(sample.keys())}")


def convert_data(input_path, output_path, direction, data_format="jsonl"):
    """
    Convert a dataset file into LLaMA-Factory ShareGPT-style JSONL.

    Each input sample goes through `convert_to_sharegpt_format`, which may
    yield zero, one or several training pairs. Samples that fail to convert
    are skipped (best effort). They are counted, and the first few errors are
    reported on stderr instead of being silently dropped.

    Args:
        input_path: Input file path.
        output_path: Output file path; always written as JSONL. Missing parent
            directories are created.
        direction: "decrease" or "increase".
        data_format: "json" (one top-level JSON array) or "jsonl" (one JSON
            value per non-blank line).

    Raises:
        ValueError: if the input is not valid JSON (JSONL errors name the line),
            or a "json" file's top level is not a list of samples.
    """
    max_reported_errors = 10  # cap per-sample warnings so a systemic failure doesn't flood stderr

    print(f"Converting data from {input_path} to {output_path}")
    print(f"  Direction: {direction}")
    print(f"  Format: {data_format}")

    # Read data.
    if data_format == "jsonl":
        all_data = []
        with open(input_path, 'r', encoding='utf-8') as f:
            for lineno, line in enumerate(f, start=1):
                stripped = line.strip()
                if not stripped:
                    continue
                try:
                    all_data.append(json.loads(stripped))
                except json.JSONDecodeError as err:
                    raise ValueError(
                        f"{input_path}: invalid JSON on line {lineno}: {err}"
                    ) from err
    else:
        with open(input_path, 'r', encoding='utf-8') as f:
            all_data = json.load(f)
        if not isinstance(all_data, list):
            # Iterating a dict would silently treat its keys as samples.
            raise ValueError(
                f"{input_path}: expected a top-level JSON array of samples, "
                f"got {type(all_data).__name__}"
            )

    print(f"  Loaded {len(all_data)} samples")

    # Convert format; a single sample may expand into several training pairs.
    converted_data = []
    failures = 0
    for index, sample in enumerate(all_data):
        try:
            converted_samples = convert_to_sharegpt_format(sample, direction)
        except Exception as err:  # best effort: skip the sample, but report it
            failures += 1
            if failures <= max_reported_errors:
                print(f"  Warning: skipping sample {index}: {type(err).__name__}: {err}",
                      file=sys.stderr)
            continue
        converted_data.extend(converted_samples)

    print(f"  Converted {len(all_data)} samples into {len(converted_data)} training pairs")
    if failures:
        print(f"  Skipped {failures} samples that failed to convert", file=sys.stderr)

    # Save as JSONL (recommended by LLaMA-Factory).
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in converted_data)

    print(f"  Saved to {output_path}")


def main():
    """Parse the command-line options and run `convert_data` with them."""
    parser = argparse.ArgumentParser(
        description="Convert data format for LLaMA-Factory"
    )

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ("--input", {"type": str, "required": True, "help": "Input data file path"}),
        ("--output", {"type": str, "required": True, "help": "Output data file path"}),
        ("--direction", {
            "type": str,
            "default": "decrease",
            "choices": ["decrease", "increase"],
            "help": "Rewrite direction",
        }),
        ("--data_format", {
            "type": str,
            "default": "jsonl",
            "choices": ["json", "jsonl"],
            "help": "Input data format",
        }),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    parsed = parser.parse_args()
    convert_data(parsed.input, parsed.output, parsed.direction, parsed.data_format)


# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()

