#!/usr/bin/env python3
"""改写responses以增加或减少恶意程度（不进行评分）"""
import json
import openai
import os
import re
from tqdm import tqdm
from typing import Dict, Optional, List, Tuple
import argparse
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed


def clean_rewritten_response(text: str) -> str:
    """
    Strip boilerplate lead-ins and wrapping quotes from a rewritten response.

    Args:
        text: Raw text produced by the rewriting model.

    Returns:
        The text with known lead-in phrases removed, one pair of matching
        surrounding quotes dropped, and blank lines collapsed. Falsy input
        (empty string, None) is returned unchanged.
    """
    if not text:
        return text

    # Lead-in phrases to delete. They are matched case-insensitively and,
    # because of re.MULTILINE, at the start of any line of the text.
    lead_in_patterns = (
        r"^Certainly,?\s+(here|this)\s+is\s+(the\s+)?(revised|rewritten|new)\s+response:?\s*",
        r"^Here\s+is\s+(the\s+)?(revised|rewritten|new)\s+response:?\s*",
        r"^(Certainly,?\s+)?Here'?s?\s+(the\s+)?(revised|rewritten|new)\s+response:?\s*",
        r"^I'?ve\s+rewritten\s+(the\s+)?response:?\s*",
        r"^(Here|This)\s+is\s+the\s+rewritten\s+version:?\s*",
        r"^The\s+rewritten\s+response\s+(is|follows):?\s*",
        r"^(Below|Here)\s+is\s+(the\s+)?rewritten\s+text:?\s*",
        r"^Sure,?\s+here'?s?\s+(the\s+)?rewritten\s+response:?\s*",
    )
    regex_flags = re.IGNORECASE | re.MULTILINE

    result = text.strip()
    for lead_in in lead_in_patterns:
        result = re.sub(lead_in, "", result, flags=regex_flags)
    result = result.strip()

    # Drop a single pair of surrounding quotes when both ends use the same
    # quote character.
    for quote in ('"', "'"):
        if result.startswith(quote) and result.endswith(quote):
            result = result[1:-1].strip()
            break

    # Collapse runs of empty / whitespace-only lines into one newline.
    result = re.sub(r'\n\s*\n+', '\n', result)
    return result.strip()

def _read_system_prompt(direction: str) -> str:
    """
    Load the system prompt text for the given rewriting direction.

    The prompt is read from a file named ``system_prompt_{direction}.txt``,
    looked up first in the current working directory and then in its parent
    directory.

    Args:
        direction: Rewriting direction, "increase" or "decrease"; it is
            interpolated into the file name.

    Returns:
        The prompt file content with leading/trailing whitespace stripped.

    Raises:
        ValueError: If the prompt file cannot be read. The underlying
            OSError / UnicodeError is attached as ``__cause__``.
    """
    filename = f"system_prompt_{direction}.txt"

    # Candidate locations, in priority order.
    possible_paths = [
        Path(filename),  # current working directory
        Path("..") / filename,  # parent directory
    ]

    # When no candidate exists, fall back to the working-directory path so
    # the error below names a concrete file.
    prompt_file = next((path for path in possible_paths if path.exists()), possible_paths[0])

    try:
        return prompt_file.read_text(encoding="utf-8").strip()
    except (OSError, UnicodeError) as e:
        # Chain the original error so the root cause stays in the traceback.
        raise ValueError(f"Failed to read system prompt from {prompt_file} (tried: {', '.join(str(p) for p in possible_paths)}): {e}") from e


def rewrite_response(client, model_name: str, original_response: str, prompt: str, direction: str = "increase", temperature: Optional[float] = None, seed: Optional[int] = None) -> Optional[str]:
    """
    Rewrite a response with the given chat model to raise or lower its
    maliciousness, as instructed by the direction-specific system prompt.

    Args:
        client: OpenAI-style client exposing ``chat.completions.create``.
        model_name: Model name (e.g. "gpt-4o", "gpt-4-turbo", "gpt-4").
        original_response: The response text to rewrite.
        prompt: The query the original response answered.
        direction: Rewriting direction, "increase" or "decrease"; selects
            which system prompt file is loaded.
        temperature: Optional sampling temperature; sent only when not None.
        seed: Optional random seed; sent only when not None.

    Returns:
        The cleaned rewritten response, or None when the API call fails or
        the model returns no text content.

    Raises:
        ValueError: Propagated from ``_read_system_prompt`` when the system
            prompt file cannot be read (this happens before the API call).
    """
    system_message = _read_system_prompt(direction)

    user_message = f"""Query: {prompt}
Original response: {original_response}
"""

    try:
        api_params = {
            "model": model_name,
            "messages": [
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message}
            ],
            "service_tier":"priority"
        }

        # Newer model families (including the o-series reasoning models)
        # take max_completion_tokens; older ones take max_tokens.
        completion_token_markers = ('gpt-5', 'gpt-4.1', 'o1', 'o3', 'o4')
        lowered_name = model_name.lower()
        if any(marker in lowered_name for marker in completion_token_markers):
            api_params["max_completion_tokens"] = 1000
        else:
            api_params["max_tokens"] = 1000

        # Optional determinism controls.
        if temperature is not None:
            api_params["temperature"] = temperature
        if seed is not None:
            api_params["seed"] = seed

        response = client.chat.completions.create(**api_params)
        content = response.choices[0].message.content
        if content is None:
            # Report a missing message body explicitly instead of letting it
            # surface as an opaque AttributeError on None.
            print("Error in rewrite_response: model returned no text content")
            return None

        # Strip meta-text such as "Here is the rewritten response:".
        return clean_rewritten_response(content.strip())
    except Exception as e:
        # Best-effort: the caller skips rewrites that come back as None.
        print(f"Error in rewrite_response: {e}")
        return None


def rewrite_response_with_sglang(client, original_response: str, prompt: str, direction: str = "increase", temperature: float = 0.0, seed: int = 123) -> Optional[str]:
    """
    Rewrite a response through an SGLang server using its OpenAI-compatible
    chat completions endpoint.

    Args:
        client: OpenAI client whose base_url points at the SGLang server.
        original_response: The response text to rewrite.
        prompt: The query the original response answered.
        direction: Rewriting direction, "increase" or "decrease"; selects
            which system prompt file is loaded.
        temperature: Sampling temperature (defaults to 0.0).
        seed: Random seed (defaults to 123); omitted from the request when None.

    Returns:
        The cleaned rewritten response, or None if the request fails.
    """
    system_message = _read_system_prompt(direction)
    user_message = f"Query: {prompt}\nOriginal response: {original_response}\n"

    # The model is always requested under the name "default".
    request = {
        "model": "default",
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ],
        "max_tokens": 1000,
        "temperature": temperature,
    }
    if seed is not None:
        request["seed"] = seed

    try:
        completion = client.chat.completions.create(**request)
        raw_text = completion.choices[0].message.content.strip()
        # Strip meta-text such as "Here is the rewritten response:".
        return clean_rewritten_response(raw_text)
    except Exception as e:
        print(f"Error in rewrite_response_with_sglang: {e}")
        import traceback
        traceback.print_exc()
        return None


def rewrite_sample(client, model_name: str, sample: Dict, direction: str = "increase", num_rewrites: int = 5, use_sglang: bool = False, temperature: float = 0.0, seed: int = 123, server_url: Optional[str] = None) -> Optional[Dict]:
    """
    Generate rewrites of one sample in the requested direction (no scoring).

    Args:
        client: OpenAI-style client used for rewriting.
        model_name: Model used for rewriting (not used in SGLang mode).
        sample: Source record. The prompt is read from "prompt" or
            "behavior"; the response from "original_response" or "response".
            Records with "higher"/"lower" keys and no prompt use "higher"
            when decreasing and "lower" when increasing.
        direction: Rewriting direction, "increase" or "decrease".
        num_rewrites: Number of rewrite attempts.
        use_sglang: Whether to rewrite through an SGLang server.
        temperature: Sampling temperature (SGLang mode only).
        seed: Base random seed (SGLang mode only). Attempt ``i`` uses
            ``seed + i`` so repeated attempts are not forced to be
            identical samples; None is passed through unchanged.
        server_url: SGLang server URL, recorded in "response_source".

    Returns:
        A dict holding the original sample fields and the list of successful
        rewrites, or None when the sample has no response text or every
        rewrite attempt failed.
    """
    original_prompt = sample.get('prompt') or sample.get('behavior', '')
    original_response = sample.get('original_response') or sample.get('response', '')

    # Paraphrase-style records carry no prompt, only higher/lower responses.
    if 'higher' in sample and 'lower' in sample and not original_prompt:
        source_key = 'higher' if direction == "decrease" else 'lower'
        original_response = sample.get(source_key, '')

    if not original_response:
        return None

    rewrites = []

    for rewrite_counter in range(num_rewrites):
        if use_sglang:
            # Offset the seed per attempt: reusing one seed would make every
            # sampled rewrite a duplicate of the first.
            attempt_seed = seed + rewrite_counter if seed is not None else None
            rewritten_response = rewrite_response_with_sglang(
                client,
                original_response,
                original_prompt,
                direction=direction,
                temperature=temperature,
                seed=attempt_seed,
            )
        else:
            rewritten_response = rewrite_response(
                client,
                model_name,
                original_response,
                original_prompt,
                direction=direction
            )

        # Failed attempts are skipped; rewrite_index keeps the attempt number.
        if rewritten_response is None:
            continue

        rewrites.append({
            "rewrite_index": rewrite_counter,
            "direction": direction,
            "rewritten_response": rewritten_response,
        })

    if not rewrites:
        return None

    # Use an explicit None check so a legitimate score of 0 is preserved
    # (an `or` fallback would treat 0 as missing).
    original_unsafe_score = sample.get("unsafe_score")
    if original_unsafe_score is None:
        original_unsafe_score = sample.get("original_unsafe_score")

    # Nested structure: one record per sample containing all its rewrites.
    return {
        "prompt": original_prompt,
        "original_response": original_response,
        "severity_level": sample.get("severity_level"),
        "original_unsafe_score": original_unsafe_score,
        "guard_label": sample.get("guard_label"),
        "harm_category": sample.get("harm_category", {}),
        "response_source": f"SGLang-{server_url}" if use_sglang and server_url else model_name,
        "num_rewrites": len(rewrites),
        "rewrites": rewrites
    }


def main():
    """
    Command-line entry point: rewrite every selected sample of a JSONL file.

    Reads samples from --input_file, generates --num_rewrites rewrites per
    sample concurrently (via an SGLang server, or the OpenAI API when
    --server_url is an empty string), and writes one JSON line per
    successfully rewritten sample to --output_file in input order.

    Raises:
        ValueError: In OpenAI mode, when no API key is available.
    """
    parser = argparse.ArgumentParser(
        description='Rewrite responses to increase or decrease maliciousness while preserving meaning. '
                    'Use evaluate_rewrites.py to calculate unsafe_score and similarity_score.'
    )
    parser.add_argument('--input_file', type=str,
                       default='pku_saferlhf_prepared.jsonl',
                       help='Input JSONL file path')
    parser.add_argument('--output_file', type=str,
                       default='pku_saferlhf_rewritten.jsonl',
                       help='Output JSONL file path')
    parser.add_argument('--direction', type=str, default='increase', choices=['increase', 'decrease'],
                       help='Direction of rewriting: increase or decrease maliciousness (default: increase)')
    parser.add_argument('--num_rewrites', type=int, default=5,
                       help='Number of rewrites per sample (default: 5)')
    parser.add_argument('--start_idx', type=int, default=0,
                       help='Start processing from this index')
    parser.add_argument('--end_idx', type=int, default=None,
                       help='End processing at this index (None for all)')
    parser.add_argument('--api_key', type=str, default=None,
                       help='OpenAI API key (or set OPENAI_API_KEY env var)')
    parser.add_argument('--rewrite_model', type=str, default='gpt-4.1-mini',
                       help='Model name for rewriting (e.g., gpt-4o, gpt-4-turbo, gpt-4, o1-preview, o1-mini, gpt-5.1-mini, gpt-5.1, gpt-4.1-mini). Ignored if --server_url is provided.')
    parser.add_argument('--server_url', type=str, default='http://localhost:30000',
                       help='SGLang server URL (default: http://localhost:30000). Pass an empty string ("") to use the OpenAI API with --rewrite_model instead.')
    parser.add_argument('--temperature', type=float, default=0.0,
                       help='Sampling temperature for SGLang (default: 0.0 for deterministic sampling). Only used with --server_url.')
    parser.add_argument('--seed', type=int, default=123,
                       help='Random seed for SGLang (default: 123 for deterministic sampling). Only used with --server_url.')
    parser.add_argument('--batch_size', type=int, default=32,
                       help='Batch size for concurrent processing (default: 32)')

    args = parser.parse_args()

    # --server_url has a non-None default, so an `is not None` test can never
    # select the OpenAI path. Treat an empty/blank value as "use OpenAI".
    use_sglang = bool(args.server_url and args.server_url.strip())

    if use_sglang:
        # Local SGLang server through its OpenAI-compatible endpoint; strip a
        # trailing slash so the base URL never becomes ".../​/v1".
        client = openai.OpenAI(
            base_url=f"{args.server_url.strip().rstrip('/')}/v1",
            api_key="not-needed"  # placeholder; the client requires some value
        )
        print(f"Using local SGLang server: {args.server_url}")
        print(f"Deterministic control: temperature={args.temperature}, seed={args.seed}")
    else:
        # OpenAI API.
        api_key = args.api_key or os.getenv('OPENAI_API_KEY')
        if not api_key:
            raise ValueError("OpenAI API key must be provided via --api_key or OPENAI_API_KEY environment variable")
        client = openai.OpenAI(api_key=api_key)
        print(f"Using rewrite model: {args.rewrite_model}")

    print("Note: This script only generates rewrites. Use evaluate_rewrites.py to calculate scores.")

    # Load the input JSONL, skipping blank lines.
    print(f"Loading samples from {args.input_file}...")
    samples = []
    with open(args.input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                samples.append(json.loads(line.strip()))

    print(f"Loaded {len(samples)} samples")

    # Clamp the requested range to the loaded data so an out-of-range index
    # cannot abort the whole run (or wrap around via negative indexing).
    start_idx = max(0, args.start_idx)
    end_idx = len(samples) if args.end_idx is None else min(args.end_idx, len(samples))

    # The output file is overwritten so every run starts clean.
    output_file = Path(args.output_file)
    if output_file.exists():
        print(f"Overwriting existing rewrite output file: {output_file}")

    print(f"Batch size: {args.batch_size}")

    sample_indices = list(range(start_idx, end_idx))

    # (idx, nested_sample) pairs, collected out of order and sorted later.
    results = []

    def process_single_sample(idx: int) -> Tuple[int, Optional[Dict]]:
        """Rewrite the sample at ``idx``; return (idx, None) on any failure."""
        try:
            nested_sample = rewrite_sample(
                client,
                args.rewrite_model,
                samples[idx],
                args.direction,
                args.num_rewrites,
                use_sglang=use_sglang,
                temperature=args.temperature if use_sglang else 0.0,
                seed=args.seed if use_sglang else 123,
                server_url=args.server_url if use_sglang else None
            )
            return (idx, nested_sample)
        except Exception as e:
            print(f"\nError processing sample {idx}: {e}")
            return (idx, None)

    # Fan the samples out over a thread pool; the work is network-bound.
    # ThreadPoolExecutor rejects max_workers < 1, so enforce a floor of 1.
    with ThreadPoolExecutor(max_workers=max(1, args.batch_size)) as executor:
        future_to_idx = {executor.submit(process_single_sample, idx): idx for idx in sample_indices}

        for future in tqdm(as_completed(future_to_idx), total=len(sample_indices), desc="Processing samples"):
            idx, nested_sample = future.result()
            if nested_sample is not None:
                results.append((idx, nested_sample))

    # Restore input order before writing.
    results.sort(key=lambda x: x[0])

    with open(args.output_file, 'w', encoding='utf-8') as f_out:
        for _, nested_sample in results:
            f_out.write(json.dumps(nested_sample, ensure_ascii=False) + '\n')

    total_samples = len(results)
    total_rewrites = sum(nested_sample['num_rewrites'] for _, nested_sample in results)

    print(f"\nCompleted! Generated {total_rewrites} rewrites from {total_samples} samples")
    print("Use evaluate_rewrites.py to calculate unsafe_score and similarity_score.")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()

