#!/usr/bin/env python3
"""
Convert paired data into the LLaMA-Factory DPO format.
The DPO format requires `chosen` and `rejected` fields.
"""

import json
import sys
import os
import argparse
from pathlib import Path


from convert_data_for_sft import build_input_text, get_query_from_sample



def _load_system_prompt(direction: str) -> str:
    """
    Load the system prompt text from file.

    The prompt is stored in ``system_prompt_{direction}.txt``. Candidate
    locations are probed in the order below and the first one that is a
    regular file wins:

    1. ``$WORKSPACE_ROOT`` (falling back to ``$PWD``), then its ``prompt/``
       subdirectory -- only when one of those variables is set;
    2. ``prompt/`` under the current working directory;
    3. the current directory, its parent and its grandparent;
    4. ``../prompt/`` (when invoked from ``train/``).

    Args:
        direction: "decrease" or "increase"

    Returns:
        System prompt text with leading/trailing whitespace stripped.

    Raises:
        ValueError: If the prompt cannot be read. The message lists every
            path that was tried and the underlying error is attached as
            ``__cause__``.
    """
    filename = f"system_prompt_{direction}.txt"

    possible_paths = []

    # Env-provided workspace root takes precedence over cwd-relative guesses.
    workspace_root = os.environ.get("WORKSPACE_ROOT") or os.environ.get("PWD")
    if workspace_root:
        root = Path(workspace_root)
        possible_paths.append(root / filename)
        possible_paths.append(root / "prompt" / filename)

    possible_paths.extend([
        Path("prompt") / filename,  # prompt/ under repo root
        Path(filename),  # current directory
        Path("..") / filename,  # parent directory
        Path("../..") / filename,  # grandparent directory
        Path("..") / "prompt" / filename,  # when invoked from train/
    ])

    # is_file() rather than exists(): a directory that happens to share the
    # prompt's name must not shadow a real prompt file later in the list.
    prompt_file = next((path for path in possible_paths if path.is_file()), None)

    if prompt_file is None:
        # If none found, use current directory path to produce a clearer error.
        prompt_file = Path(filename)

    try:
        return prompt_file.read_text(encoding="utf-8").strip()
    except Exception as e:
        # Chain the original error so the traceback shows the real cause.
        raise ValueError(f"Failed to read system prompt from {prompt_file} (tried: {', '.join(str(p) for p in possible_paths)}): {e}") from e


def _build_instruction(direction: str) -> str:
    """
    Return the short inline rewrite instruction used in no_system mode.

    Args:
        direction: "decrease" selects the LESS-unsafe wording; any other
            value selects the MORE-unsafe wording.

    Returns:
        A two-sentence instruction string.
    """
    # The two variants differ only in this single word.
    degree = "LESS" if direction == "decrease" else "MORE"
    first_sentence = (
        f"Rewrite the response to be {degree} unsafe while preserving all factual information and semantic content. "
    )
    second_sentence = "Only rephrase what is already there — do not add new information, advice, or suggestions."
    return first_sentence + second_sentence


def _build_conversations(prompt_text, original_response, direction, no_system=False):
    """
    Build the ShareGPT `conversations` list for one sample.

    Args:
        prompt_text: Query text handed to `build_input_text` (that helper may
            ignore it depending on config).
        original_response: The response that is to be rewritten.
        direction: "decrease" or "increase"; selects the instruction / system prompt.
        no_system: If True, emit a single human turn with the short
            instruction prepended; otherwise emit a system turn followed by a
            human turn.

    Returns:
        A list of message dicts with `from` and `value` keys.
    """
    human_text = build_input_text(prompt_text, original_response)

    if not no_system:
        # Default mode: the full prompt loaded from file goes in a system turn.
        return [
            {"from": "system", "value": _load_system_prompt(direction)},
            {"from": "human", "value": human_text},
        ]

    # no_system mode: instruction and input share the human turn,
    # separated by one blank line.
    merged = "\n\n".join((_build_instruction(direction), human_text))
    return [{"from": "human", "value": merged}]


def _rewritten_response(sample, key):
    """Return sample[key]['rewritten_response'], or None if the entry is missing, null or not a dict."""
    entry = sample.get(key)
    if not isinstance(entry, dict):
        return None
    return entry.get('rewritten_response')


def _make_dpo_pair(conversations, chosen, rejected):
    """Assemble one ShareGPT DPO record, wrapping chosen/rejected text as gpt messages."""
    return {
        "conversations": conversations,
        "chosen": {"from": "gpt", "value": chosen},
        "rejected": {"from": "gpt", "value": rejected},
    }


def convert_to_dpo_format(sample, direction, no_system=False):
    """
    Convert a raw sample into ShareGPT DPO format (chosen/rejected).

    Input layouts are tried in this order:

    1. New format: `prompt` + `original_response` plus any of
       `increase` / `unchange` / `decrease`, each a dict holding
       `rewritten_response`. Yields up to three pairs. If no pair can be
       formed, the remaining layouts are tried.
    2. Old format: `prompt` + `lower` + `higher` (`original_response` optional).
    3. Already ShareGPT DPO: `conversations` + `chosen` + `rejected`.
    4. Alpaca: `instruction` + `chosen` + `rejected` (optional `input` / `query`).
    5. Prompt-only: `prompt` + `chosen` + `rejected`.

    Args:
        sample: Raw data sample.
        direction: "decrease" (the safer response is chosen) or "increase"
            (the more unsafe response is chosen).
        no_system: If True, omit the system message (for prefix-tuning/prompt-tuning).

    Returns:
        A list of ShareGPT DPO examples (each contains conversations, chosen,
        rejected). The new format returns up to 3 pairs; every other format
        returns exactly one.

    Raises:
        ValueError: If the sample matches none of the known layouts.
    """
    # New format: prompt + original_response + increase/unchange/decrease.
    if 'prompt' in sample and 'original_response' in sample:
        # Get query via a unified helper (may return empty string depending on config).
        prompt_text = get_query_from_sample(sample)
        conversations = _build_conversations(
            prompt_text, sample['original_response'], direction, no_system
        )

        # Rank the three rewrites from most to least preferred for this direction.
        if direction == "decrease":
            ranking = ('decrease', 'unchange', 'increase')
        elif direction == "increase":
            ranking = ('increase', 'unchange', 'decrease')
        else:
            ranking = None

        if ranking is not None:
            best, middle, worst = (_rewritten_response(sample, key) for key in ranking)
            # Adjacent pairs first, then the widest gap; a pair is emitted
            # only when both sides are present and non-empty.
            results = [
                _make_dpo_pair(conversations, chosen, rejected)
                for chosen, rejected in ((best, middle), (middle, worst), (best, worst))
                if chosen and rejected
            ]
            # If pairs were generated, return them; otherwise continue to other formats.
            if results:
                return results

    # Old format: lower = safer response, higher = more unsafe response.
    if 'prompt' in sample and 'lower' in sample and 'higher' in sample:
        prompt_text = get_query_from_sample(sample)
        # original_response is optional here; fall back to an empty string
        # instead of raising KeyError for pure old-format samples.
        # NOTE(review): assumes build_input_text accepts an empty response -- confirm.
        original_response = sample.get('original_response', '')

        if direction == "decrease":
            chosen, rejected = sample['lower'], sample['higher']
        else:  # increase
            chosen, rejected = sample['higher'], sample['lower']

        conversations = _build_conversations(prompt_text, original_response, direction, no_system)
        return [_make_dpo_pair(conversations, chosen, rejected)]

    # Already ShareGPT DPO format.
    if 'conversations' in sample and 'chosen' in sample and 'rejected' in sample:
        if not no_system:
            return [sample]
        # no_system: drop system turns, keep everything else untouched.
        return [{
            "conversations": [msg for msg in sample['conversations'] if msg.get('from') != 'system'],
            "chosen": sample['chosen'],
            "rejected": sample['rejected'],
        }]

    # Alpaca format -> ShareGPT format.
    if 'instruction' in sample and 'chosen' in sample and 'rejected' in sample:
        instruction = sample['instruction']
        user_msg = sample.get('input', '') or sample.get('query', '')

        if no_system:
            # Merge the instruction into the single human message.
            if instruction:
                user_msg = f"{instruction}\n\n{user_msg}" if user_msg else instruction
            conversations = [{"from": "human", "value": user_msg or ""}]
        else:
            conversations = []
            if instruction:
                conversations.append({"from": "system", "value": instruction})
            # If there is no input, the instruction doubles as the human message.
            conversations.append({"from": "human", "value": user_msg if user_msg else instruction})

        return [_make_dpo_pair(conversations, sample['chosen'], sample['rejected'])]

    # Old prompt-only format: no system message regardless of no_system.
    if 'prompt' in sample and 'chosen' in sample and 'rejected' in sample:
        return [_make_dpo_pair(
            [{"from": "human", "value": sample['prompt']}],
            sample['chosen'],
            sample['rejected'],
        )]

    raise ValueError(f"Unknown data format: {sample.keys()}")


def convert_data(input_path, output_path, direction, data_format="jsonl", no_system=False):
    """
    Convert a dataset file into the target DPO format.

    Reads the input, converts every sample with `convert_to_dpo_format`, and
    writes the resulting pairs to `output_path` as JSONL (one JSON object per
    line, non-ASCII characters kept as-is). Progress is printed to stdout.

    Conversion is best-effort: a sample that fails to convert is skipped
    rather than aborting the run. The number of skipped samples and the first
    failure are reported on stderr so that an unexpectedly small or empty
    output file can be diagnosed.

    Args:
        input_path: Input file path.
        output_path: Output file path.
        direction: "decrease" or "increase"
        data_format: "jsonl" reads one JSON object per non-blank line; any
            other value loads the whole file as a single JSON document.
        no_system: If True, omit system messages (for prefix-tuning/prompt-tuning).
    """
    print(f"Converting data from {input_path} to {output_path}")
    print(f"  Direction: {direction}")
    print(f"  Format: {data_format}")
    print(f"  No system: {no_system}")

    # Read data.
    with open(input_path, 'r', encoding='utf-8') as f:
        if data_format == "jsonl":
            stripped_lines = (line.strip() for line in f)
            all_data = [json.loads(text) for text in stripped_lines if text]
        else:
            all_data = json.load(f)

    print(f"  Loaded {len(all_data)} samples")

    # Convert format.
    converted_data = []
    skipped = 0
    first_failure = None
    for i, sample in enumerate(all_data):
        try:
            # convert_to_dpo_format returns a list and may contain multiple pairs.
            converted_data.extend(convert_to_dpo_format(sample, direction, no_system=no_system))
        except Exception as e:
            # Keep going, but remember what went wrong instead of hiding it.
            skipped += 1
            if first_failure is None:
                first_failure = f"sample {i}: {type(e).__name__}: {e}"
    total_pairs = len(converted_data)

    print(f"  Converted {len(all_data)} samples into {total_pairs} DPO pairs")
    if skipped:
        # One summary line rather than one line per sample, to keep the output short.
        print(
            f"  Warning: skipped {skipped} of {len(all_data)} samples that could not be converted "
            f"(first failure: {first_failure})",
            file=sys.stderr,
        )

    # Save as JSONL.
    with open(output_path, 'w', encoding='utf-8') as f:
        for sample in converted_data:
            f.write(json.dumps(sample, ensure_ascii=False) + '\n')

    print(f"  Saved to {output_path}")


def main():
    """
    Command-line entry point.

    Parses --input, --output, --direction, --data_format and --no_system,
    then hands them to `convert_data`.
    """
    cli = argparse.ArgumentParser(
        description="Convert data format for LLaMA-Factory DPO training"
    )

    # (flag, add_argument keyword arguments), registered in this order so the
    # --help listing keeps its layout.
    option_table = (
        ("--input", {
            "type": str,
            "required": True,
            "help": "Input data file path",
        }),
        ("--output", {
            "type": str,
            "required": True,
            "help": "Output data file path",
        }),
        ("--direction", {
            "type": str,
            "default": "decrease",
            "choices": ["decrease", "increase"],
            "help": "DPO direction: 'decrease' means chosen=safe, rejected=unsafe; 'increase' means chosen=unsafe, rejected=safe",
        }),
        ("--data_format", {
            "type": str,
            "default": "jsonl",
            "choices": ["json", "jsonl"],
            "help": "Input data format",
        }),
        ("--no_system", {
            "action": "store_true",
            "help": "Do not include system message in conversations (for prefix-tuning/prompt-tuning)",
        }),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    opts = cli.parse_args()

    convert_data(
        opts.input,
        opts.output,
        opts.direction,
        opts.data_format,
        no_system=opts.no_system,
    )


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

