#!/usr/bin/env python3
"""改写responses以增加或减少恶意程度（不进行评分）"""
import json
import openai
import os
import re
from tqdm import tqdm
from typing import Dict, Optional
import argparse
from pathlib import Path

def clean_rewritten_response(text: str) -> str:
    """
    Strip boilerplate that a model tends to wrap around a rewritten response.

    Three clean-ups run in order: known lead-in phrases (for example
    "Here is the rewritten response:") are deleted, one pair of matching
    quotes enclosing the whole text is removed, and blank lines are collapsed.

    Args:
        text: Raw rewritten text.

    Returns:
        The cleaned text. A falsy input (empty string or None) is returned
        unchanged.
    """
    if not text:
        return text

    # Lead-in phrases to delete. Matching is case-insensitive and, because
    # MULTILINE is set, "^" anchors at the start of every line rather than only
    # at the very beginning of the text.
    preamble_patterns = (
        r"^Certainly,?\s+(here|this)\s+is\s+(the\s+)?(revised|rewritten|new)\s+response:?\s*",
        r"^Here\s+is\s+(the\s+)?(revised|rewritten|new)\s+response:?\s*",
        r"^(Certainly,?\s+)?Here'?s?\s+(the\s+)?(revised|rewritten|new)\s+response:?\s*",
        r"^I'?ve\s+rewritten\s+(the\s+)?response:?\s*",
        r"^(Here|This)\s+is\s+the\s+rewritten\s+version:?\s*",
        r"^The\s+rewritten\s+response\s+(is|follows):?\s*",
        r"^(Below|Here)\s+is\s+(the\s+)?rewritten\s+text:?\s*",
        r"^Sure,?\s+here'?s?\s+(the\s+)?rewritten\s+response:?\s*",
    )
    regex_flags = re.IGNORECASE | re.MULTILINE

    result = text.strip()
    for preamble in preamble_patterns:
        result = re.compile(preamble, regex_flags).sub("", result)

    # Remove one pair of enclosing quotes when both ends use the same quote mark.
    result = result.strip()
    opening = result[:1]
    if opening in ('"', "'") and result.endswith(opening):
        result = result[1:-1].strip()

    # Collapse runs of blank (or whitespace-only) lines into a single newline.
    blank_run = re.compile(r'\n\s*\n+')
    result = blank_run.sub('\n', result)

    return result.strip()


def _build_input_text(query: str, original_response: str) -> str:
    """
    Format the user message sent to the rewrite model.

    The layout is always a "Query: " line followed by an
    "Original response: " line, each terminated by a newline. The existing
    notes for this helper state that the layout mirrors test_trained_model.py.

    Args:
        query: The user's query; any falsy value leaves the query field empty.
        original_response: The response that is to be rewritten.

    Returns:
        The formatted input text.
    """
    # A falsy query (empty string, None) renders as an empty field.
    query_field = query if query else ""
    return f"Query: {query_field}\nOriginal response: {original_response}\n"


def load_system_prompt_from_file(direction: str) -> str:
    """
    Read the system prompt for a rewrite direction from disk.

    The prompt is looked up at ``../prompt/system_prompt_{direction}.txt``
    relative to the directory that contains this script.

    Args:
        direction: Rewrite direction ("increase" or "decrease"); it is
            interpolated into the prompt file name.

    Returns:
        The file content with surrounding whitespace stripped.

    Raises:
        FileNotFoundError: The prompt file does not exist.
        ValueError: The file could not be read or decoded for any other reason.
    """
    filename = f"system_prompt_{direction}.txt"

    # Anchor the lookup on this script's location so the current working
    # directory does not matter.
    script_dir = Path(__file__).parent.absolute()
    prompt_file = script_dir / ".." / "prompt" / filename

    try:
        return prompt_file.read_text(encoding="utf-8").strip()
    except FileNotFoundError as exc:
        # Re-raise with the resolved path; chain the cause so the original
        # traceback stays attached.
        raise FileNotFoundError(
            f"System prompt file not found: {prompt_file.resolve()}"
        ) from exc
    except Exception as exc:
        raise ValueError(
            f"Failed to read system prompt from {prompt_file.resolve()}: {exc}"
        ) from exc

def _read_env_number(name: str, cast):
    """Parse environment variable *name* with *cast*; return None if unset, empty, or unparsable."""
    raw = os.getenv(name)
    if raw is None or raw == "":
        return None
    try:
        return cast(raw)
    except ValueError:
        # An unparsable override is ignored so the built-in default applies.
        return None


def rewrite_response(client, model_name: str, original_response: str, prompt: str, direction: str = "increase") -> Optional[str]:
    """
    Ask a chat model to rewrite a response in the requested direction.

    The system prompt is loaded from the prompt file for ``direction`` and the
    user message is built with an empty query, so ``prompt`` is accepted only
    for interface compatibility and is not sent to the model.

    Sampling can be tuned through environment variables:

    * ``REWRITE_TEMPERATURE`` (float, default 1)
    * ``REWRITE_TOP_P`` (float, default 0.9, or 1 when determinism is enabled)
    * ``REWRITE_ENABLE_DETERMINISM`` ("true" enables it; anything else disables)
    * ``REWRITE_SEED`` (int, default 123; only used when determinism is enabled)

    Args:
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        model_name: Name of the model used for rewriting.
        original_response: The response to rewrite.
        prompt: Original prompt of the sample (unused, see above).
        direction: "increase" or "decrease"; selects the system prompt file.

    Returns:
        The cleaned rewritten response, or None when the client is missing,
        the API call fails, or the model returns no content.

    Raises:
        FileNotFoundError, ValueError: Propagated from
            ``load_system_prompt_from_file`` when the prompt cannot be loaded.
    """
    system_message = load_system_prompt_from_file(direction)
    user_message = _build_input_text("", original_response)

    if client is None:
        # Reported rather than raised, like every other failure of this function.
        print("Error in rewrite_response: client must be provided")
        return None

    try:
        temperature = _read_env_number("REWRITE_TEMPERATURE", float)
        top_p = _read_env_number("REWRITE_TOP_P", float)
        seed = _read_env_number("REWRITE_SEED", int)

        # Determinism is opt-in: only the literal value "true" (any case) enables it.
        det_env = os.getenv("REWRITE_ENABLE_DETERMINISM")
        enable_determinism = det_env is not None and det_env.lower() == "true"

        if temperature is None:
            temperature = 1
        if top_p is None:
            # Without an explicit override, deterministic runs use top_p=1;
            # otherwise a stochastic default is used.
            top_p = 1 if enable_determinism else 0.9

        api_params = {
            "model": model_name,
            "messages": [
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message},
            ],
        }

        # Newer OpenAI models take max_completion_tokens; OpenAI-compatible
        # servers (detected by a base_url outside openai.com) keep max_tokens.
        is_newer_model = any(
            prefix in model_name.lower() for prefix in ('gpt-5', 'gpt-4.1', 'o1')
        )
        targets_openai = hasattr(client, 'base_url') and 'openai.com' in str(client.base_url)
        if is_newer_model and targets_openai:
            api_params["max_completion_tokens"] = 1000
        else:
            api_params["max_tokens"] = 2048

        api_params["temperature"] = temperature
        api_params["top_p"] = top_p

        if enable_determinism:
            api_params["n"] = 1
            api_params["seed"] = seed if seed is not None else 123

        response = client.chat.completions.create(**api_params)
        content = response.choices[0].message.content
        if content is None:
            print("Error in rewrite_response: model returned no content")
            return None
        return clean_rewritten_response(content.strip())
    except Exception as e:
        # Best-effort boundary: a failed rewrite is reported and skipped by the caller.
        print(f"Error in rewrite_response: {e}")
        return None


def rewrite_sample(client, model_name: str, sample: Dict, num_increase: int = 5, num_decrease: int = 5) -> Dict:
    """
    Produce rewrites of one sample in both directions (no scoring).

    All "increase" rewrites are attempted first, then all "decrease" rewrites.
    Attempts for which ``rewrite_response`` returns None are dropped, and
    successful rewrites are numbered consecutively starting at 1.

    Args:
        client: OpenAI-compatible client used for rewriting.
        model_name: Name of the model used for rewriting.
        sample: Source record; must contain 'prompt' and 'response'.
        num_increase: Number of "increase" rewrite attempts (default 5).
        num_decrease: Number of "decrease" rewrite attempts (default 5).

    Returns:
        A dict holding the original sample fields together with the list of
        successful rewrites.
    """
    source_prompt = sample['prompt']
    source_response = sample['response']

    collected = []
    for direction, attempts in (("increase", num_increase), ("decrease", num_decrease)):
        for _ in range(attempts):
            candidate = rewrite_response(
                client,
                model_name,
                source_response,
                source_prompt,
                direction=direction,
            )
            if candidate is not None:
                collected.append({
                    # Index counts successful rewrites only, across both directions.
                    "rewrite_index": len(collected) + 1,
                    "direction": direction,
                    "rewritten_response": candidate,
                })

    # One nested record per sample, carrying every rewrite.
    return {
        "prompt": source_prompt,
        "original_response": source_response,
        "severity_level": sample.get("severity_level"),
        "original_unsafe_score": sample.get("unsafe_score"),
        "guard_label": sample.get("guard_label"),
        "harm_category": sample.get("harm_category", {}),
        "response_source": model_name,
        "num_rewrites": len(collected),
        "rewrites": collected,
    }


def main():
    """
    Command-line entry point: rewrite every selected sample of a JSONL file.

    Reads the input JSONL, rewrites samples in the half-open index range
    [start_idx, end_idx) and writes one nested JSON record per sample to the
    output file, which is overwritten. A failure on one sample is reported
    and the run continues with the next sample.

    Raises:
        ValueError: No API key was supplied via --api_key or OPENAI_API_KEY.
    """
    parser = argparse.ArgumentParser(
        description='Rewrite responses to increase or decrease maliciousness while preserving meaning. '
                    'Use evaluate_rewrites.py to calculate unsafe_score and similarity_score.'
    )
    parser.add_argument('--input_file', type=str,
                       default='pku_saferlhf_prepared.jsonl',
                       help='Input JSONL file path')
    parser.add_argument('--output_file', type=str,
                       default='pku_saferlhf_rewritten.jsonl',
                       help='Output JSONL file path')
    parser.add_argument('--num_increase', type=int, default=5,
                       help='Number of rewrites that increase maliciousness per sample (default: 5)')
    parser.add_argument('--num_decrease', type=int, default=5,
                       help='Number of rewrites that decrease maliciousness per sample (default: 5)')
    parser.add_argument('--start_idx', type=int, default=0,
                       help='Start processing from this index')
    parser.add_argument('--end_idx', type=int, default=None,
                       help='End processing at this index (None for all)')
    parser.add_argument('--api_key', type=str, default=None,
                       help='OpenAI API key (or set OPENAI_API_KEY env var)')
    parser.add_argument('--rewrite_model', type=str, default='gpt-4.1-mini',
                       help='Model name for rewriting (e.g., gpt-4o, gpt-4-turbo, gpt-4, o1-preview, o1-mini, gpt-5.1-mini, gpt-5.1, gpt-4.1-mini)')

    args = parser.parse_args()

    # Initialise the OpenAI client.
    api_key = args.api_key or os.getenv('OPENAI_API_KEY')
    if not api_key:
        raise ValueError("OpenAI API key must be provided via --api_key or OPENAI_API_KEY environment variable")

    client = openai.OpenAI(api_key=api_key)
    print(f"Using rewrite model: {args.rewrite_model}")
    print("Note: This script only generates rewrites. Use evaluate_rewrites.py to calculate scores.")

    # Load every non-blank line of the input file as one JSON sample.
    print(f"Loading samples from {args.input_file}...")
    samples = []
    with open(args.input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                samples.append(json.loads(line.strip()))

    print(f"Loaded {len(samples)} samples")

    # Clamp the requested range to the loaded data: an end index past the last
    # sample would otherwise raise IndexError outside the per-sample error
    # handling, and a negative start index would wrap around to the tail.
    start_idx = max(args.start_idx, 0)
    end_idx = len(samples) if args.end_idx is None else min(args.end_idx, len(samples))
    if start_idx != args.start_idx or (args.end_idx is not None and end_idx != args.end_idx):
        print(f"Adjusted processing range to [{start_idx}, {end_idx}) to fit the {len(samples)} loaded samples")

    # The output file is opened in write mode so every run starts clean.
    output_file = Path(args.output_file)
    if output_file.exists():
        print(f"Overwriting existing rewrite output file: {output_file}")
    total_rewrites = 0
    total_samples = 0

    with open(args.output_file, 'w', encoding='utf-8') as f_out:
        for idx in tqdm(range(start_idx, end_idx), desc="Processing samples"):
            sample = samples[idx]

            try:
                nested_sample = rewrite_sample(
                    client,
                    args.rewrite_model,
                    sample,
                    args.num_increase,
                    args.num_decrease
                )

                # One nested record per line; flush so finished samples
                # survive an interrupted run.
                f_out.write(json.dumps(nested_sample, ensure_ascii=False) + '\n')
                f_out.flush()

                total_rewrites += nested_sample['num_rewrites']
                total_samples += 1

            except Exception as e:
                # Report the failure and carry on with the next sample.
                print(f"\nError processing sample {idx}: {e}")
                continue

    print(f"\nCompleted! Generated {total_rewrites} rewrites from {total_samples} samples")
    print("Use evaluate_rewrites.py to calculate unsafe_score and similarity_score.")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

