#!/usr/bin/env python3
"""使用OpenAI Batch API改写responses以增加或减少恶意程度（不进行评分）

这个脚本使用OpenAI Batch API来批量处理改写请求，相比实时API调用更经济高效。

使用流程：
1. 创建并提交batch：
   python rewrite_responses_batch.py --input_file input.jsonl --output_file output.jsonl

2. 等待batch完成（可以跳过等待，稍后检查）：
   python rewrite_responses_batch.py --batch_id <batch_id> --metadata_file batch_metadata_X_Y.json --output_file output.jsonl

3. 或者让脚本自动等待完成（默认行为）

优势：
- 成本更低（Batch API有50%折扣）
- 适合大批量处理
- 可以异步处理，不占用本地资源

注意事项：
- Batch处理可能需要几分钟到几小时，取决于队列大小
- 需要保存metadata文件以便后续处理结果
- Batch API不支持service_tier参数
"""
import json
import openai
import os
import re
import time
from tqdm import tqdm
from typing import Dict, Optional, List, Tuple
import argparse
from pathlib import Path
import tempfile


def clean_rewritten_response(text: str) -> str:
    """
    Strip boilerplate preambles and wrapping quotes from a model rewrite.

    Processing steps:
      1. Each preamble regex (e.g. "Here is the rewritten response:") is removed,
         case-insensitively. The patterns are applied in a fixed order. Because of
         ``re.MULTILINE``, ``^`` matches at the start of every line, not only at
         the start of the text.
      2. If the remaining text is wrapped in a matching pair of double or single
         quotes, that outer pair is removed.
      3. Every run of blank lines collapses to a single newline. Note that this
         removes paragraph breaks entirely.

    Args:
        text: raw rewrite returned by the model

    Returns:
        The cleaned text. Falsy input ("" or None) is returned unchanged.
    """
    if not text:
        return text

    preamble_flags = re.IGNORECASE | re.MULTILINE
    preamble_regexes = [
        re.compile(raw, preamble_flags)
        for raw in (
            r"^Certainly,?\s+(here|this)\s+is\s+(the\s+)?(revised|rewritten|new)\s+response:?\s*",
            r"^Here\s+is\s+(the\s+)?(revised|rewritten|new)\s+response:?\s*",
            r"^(Certainly,?\s+)?Here'?s?\s+(the\s+)?(revised|rewritten|new)\s+response:?\s*",
            r"^I'?ve\s+rewritten\s+(the\s+)?response:?\s*",
            r"^(Here|This)\s+is\s+the\s+rewritten\s+version:?\s*",
            r"^The\s+rewritten\s+response\s+(is|follows):?\s*",
            r"^(Below|Here)\s+is\s+(the\s+)?rewritten\s+text:?\s*",
            r"^Sure,?\s+here'?s?\s+(the\s+)?rewritten\s+response:?\s*",
        )
    ]

    result = text.strip()
    for regex in preamble_regexes:
        result = regex.sub("", result)

    # Drop one layer of matching outer quotes, if present.
    result = result.strip()
    if any(result.startswith(q) and result.endswith(q) for q in ('"', "'")):
        result = result[1:-1].strip()

    # Collapse blank-line runs (including whitespace-only lines) into one newline.
    result = re.sub(r'\n\s*\n+', '\n', result)

    return result.strip()

# Cache of system-prompt file contents, keyed by the resolved file path.
# prepare_batch_requests calls build_rewrite_request once per (sample, rewrite),
# and without this cache the same file would be re-read thousands of times.
# Only successful reads are cached.
_SYSTEM_PROMPT_CACHE: Dict[str, str] = {}


def _load_system_prompt(direction: str) -> str:
    """
    Locate, read and cache ``system_prompt_{direction}.txt`` (whitespace-stripped).

    Candidate locations, in order: the current working directory, its parent, and
    the directory containing this script. The first existing one wins. If none
    exists, reading ``./system_prompt_{direction}.txt`` is attempted anyway so the
    error names a concrete path.

    Raises:
        ValueError: the file could not be read or decoded as UTF-8.
    """
    filename = f"system_prompt_{direction}.txt"
    possible_paths = [
        Path(filename),                                # current directory
        Path("..") / filename,                         # parent directory
        Path(__file__).resolve().parent / filename,    # next to this script
    ]
    prompt_file = next((p for p in possible_paths if p.exists()), Path(filename))

    cache_key = str(prompt_file.resolve())
    cached = _SYSTEM_PROMPT_CACHE.get(cache_key)
    if cached is not None:
        return cached

    try:
        system_message = prompt_file.read_text(encoding="utf-8").strip()
    except (OSError, UnicodeDecodeError) as e:
        raise ValueError(
            f"Failed to read system prompt from {prompt_file} "
            f"(tried: {', '.join(str(p) for p in possible_paths)}): {e}"
        ) from e

    _SYSTEM_PROMPT_CACHE[cache_key] = system_message
    return system_message


def build_rewrite_request(model_name: str, original_response: str, prompt: str, direction: str) -> Dict:
    """
    Build the Chat Completions request body for one rewrite.

    Args:
        model_name: model name
        original_response: the response to be rewritten
        prompt: the original query/prompt
        direction: rewrite direction, "increase" or "decrease". It selects
            ``system_prompt_{direction}.txt`` as the system message.

    Returns:
        A dict with ``model``, ``messages`` (system + user), and a 1000-token limit.
        The limit is passed as ``max_completion_tokens`` for model names containing
        'gpt-5', 'gpt-4.1' or 'o1', and as ``max_tokens`` otherwise.
        ``service_tier`` is intentionally omitted (not supported by the Batch API).

    Raises:
        ValueError: the system prompt file could not be read.
    """
    system_message = _load_system_prompt(direction)

    user_message = f"""Query: {prompt}
Original response: {original_response}
"""

    api_params = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message}
        ]
    }

    # Some model families require max_completion_tokens instead of max_tokens
    # (kept consistent with rewrite_responses.py).
    if any(model_prefix in model_name.lower() for model_prefix in ['gpt-5', 'gpt-4.1', 'o1']):
        api_params["max_completion_tokens"] = 1000
    else:
        api_params["max_tokens"] = 1000

    return api_params

def prepare_batch_requests(samples: List[Dict], model_name: str, direction: str, num_rewrites: int, 
                          start_idx: int, end_idx: int) -> Tuple[List[Dict], List[Tuple[int, int, str]]]:
    """
    Build one Batch API request line per (sample, rewrite) pair.

    Samples in ``samples[start_idx:end_idx]`` are visited in order. Each sample
    contributes ``num_rewrites`` requests. The prompt comes from ``prompt`` (or
    ``behavior``) and the response from ``original_response`` (or ``response``).

    Args:
        samples: loaded input samples
        model_name: model to request
        direction: "increase" or "decrease"
        num_rewrites: requests to create per sample
        start_idx: first sample index (inclusive)
        end_idx: last sample index (exclusive)

    Returns:
        ``(requests, tracking)``. ``requests`` are dicts with ``custom_id``
        ``sample_{i}_rewrite_{k}_{direction}``, method, url and body.
        ``tracking[j]`` is ``(i, k, direction)`` for ``requests[j]``.
    """
    requests_out: List[Dict] = []
    tracking: List[Tuple[int, int, str]] = []

    for idx in range(start_idx, end_idx):
        record = samples[idx]
        query = record.get('prompt') or record.get('behavior')
        answer = record.get('original_response') or record.get('response')

        for k in range(num_rewrites):
            body = build_rewrite_request(model_name, answer, query, direction=direction)
            requests_out.append({
                "custom_id": f"sample_{idx}_rewrite_{k}_{direction}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": body,
            })
            tracking.append((idx, k, direction))

    return requests_out, tracking

def create_batch_file(batch_requests: List[Dict], output_path: str):
    """
    Write batch requests to ``output_path`` as JSONL (one JSON object per line).

    The file is overwritten. Objects are serialized with ``json.dumps`` defaults,
    so non-ASCII characters are escaped.

    Args:
        batch_requests: request dicts to serialize
        output_path: destination file path
    """
    lines = (json.dumps(item) + '\n' for item in batch_requests)
    with open(output_path, 'w', encoding='utf-8') as handle:
        handle.writelines(lines)

def wait_for_batch_completion(client: "openai.OpenAI", batch_id: str, check_interval: int = 60,
                              max_consecutive_errors: int = 5):
    """
    Poll a batch until it reaches a terminal state.

    The batch is re-fetched every ``check_interval`` seconds. Terminal states are
    'completed', 'failed', 'expired' and 'cancelled'. Any other status keeps the
    loop going, and there is no overall timeout.

    Transient API errors while polling (connection errors/timeouts, rate limiting,
    server errors) are reported and retried on the next interval. A single network
    hiccup should not abort a wait that can last hours. After
    ``max_consecutive_errors`` consecutive failures the last error is re-raised.

    Args:
        client: OpenAI client (only ``client.batches.retrieve`` is used)
        batch_id: batch ID
        check_interval: seconds between status checks
        max_consecutive_errors: consecutive transient failures tolerated before
            re-raising

    Returns:
        The batch object from the last successful ``retrieve`` call. Its status is
        one of the terminal states above.
    """
    start_time = time.time()
    print(f"Waiting for batch {batch_id} to complete...")
    print("This may take a while. Batch processing typically takes minutes to hours depending on queue size.")

    consecutive_errors = 0
    while True:
        hours, remainder = divmod(int(time.time() - start_time), 3600)
        minutes, seconds = divmod(remainder, 60)
        time_str = f"{hours:02d}:{minutes:02d}:{seconds:02d}"

        try:
            batch_status = client.batches.retrieve(batch_id)
        except (openai.APIConnectionError, openai.RateLimitError, openai.InternalServerError) as e:
            consecutive_errors += 1
            if consecutive_errors >= max_consecutive_errors:
                raise
            print(f"[{time_str}] Warning: could not retrieve batch status "
                  f"(attempt {consecutive_errors}/{max_consecutive_errors}): {e}")
            time.sleep(check_interval)
            continue
        consecutive_errors = 0

        status = batch_status.status
        print(f"[{time_str}] Batch status: {status}")

        if status == 'completed':
            print("Batch completed successfully!")
            return batch_status
        if status in ('failed', 'expired', 'cancelled'):
            print(f"Batch ended with status: {status}")
            return batch_status

        if status == 'validating':
            print("Batch is being validated...")
        elif status == 'in_progress':
            # request_counts may be absent or None early on; only report real counts.
            counts = getattr(batch_status, 'request_counts', None)
            if counts is not None:
                total = getattr(counts, 'total', 0)
                completed = getattr(counts, 'completed', 0)
                failed = getattr(counts, 'failed', 0)
                print(f"[{time_str}] Progress: {total} total, {completed} completed, {failed} failed")

        time.sleep(check_interval)

# Format of the custom_id assigned in prepare_batch_requests.
_CUSTOM_ID_RE = re.compile(r'sample_(\d+)_rewrite_(\d+)_(increase|decrease)')


def _resolve_request_info(i: int, result: Dict,
                          metadata: List[Tuple[int, int, str]]) -> Optional[Tuple[int, int, str]]:
    """
    Map one batch output line to ``(sample_idx, rewrite_index, direction)``, or None.

    The ``custom_id`` is parsed first. Positional lookup in ``metadata`` is only a
    warned fallback, because the Batch API does not guarantee that output order
    matches input order.
    """
    custom_id = result.get('custom_id')
    if custom_id:
        match = _CUSTOM_ID_RE.match(custom_id)
        if match:
            return int(match.group(1)), int(match.group(2)), match.group(3)
        if i >= len(metadata):
            print(f"Warning: Result {i} (custom_id={custom_id}) cannot be parsed and has no corresponding metadata")
            return None
        print(f"Warning: Failed to parse custom_id '{custom_id}', using index-based matching")
        return tuple(metadata[i])

    if i >= len(metadata):
        print(f"Warning: Result {i} has no custom_id and no corresponding metadata")
        return None
    print(f"Warning: Result {i} has no custom_id, using index-based matching")
    return tuple(metadata[i])


def _extract_rewrite_text(i: int, result: Dict) -> Optional[str]:
    """
    Return the cleaned rewrite text from one batch output line, or None.

    The following cases are reported and return None:
      - a null/missing ``response``, or a top-level ``error``
      - a non-200 ``status_code``
      - no choices
      - a refusal
      - null or blank content
      - text that becomes empty after cleaning
    """
    custom_id = result.get('custom_id')
    response = result.get('response')

    # `response` is null for requests that failed outright, so it must not be
    # indexed blindly.
    if isinstance(response, dict):
        body = response.get('body')
        if not isinstance(body, dict):
            body = {}
        status_code = response.get('status_code')
        if status_code is not None and status_code != 200:
            print(f"Error in result {i} (custom_id={custom_id}): HTTP {status_code}: {body.get('error', body)}")
            return None

        choices = body.get('choices') or []
        if not choices:
            print(f"Warning: No choices in result {i} (custom_id={custom_id})")
            return None

        message = choices[0].get('message') or {}
        refusal = message.get('refusal')
        if refusal:
            print(f"Warning: Model refused in result {i} (custom_id={custom_id}): {refusal}")
            return None

        # content can be null (or blank, e.g. when the token limit is hit).
        content = message.get('content')
        if not content or not content.strip():
            print(f"Warning: Empty content in result {i} (custom_id={custom_id})")
            return None

        cleaned = clean_rewritten_response(content.strip())
        if not cleaned:
            print(f"Warning: Rewrite is empty after cleaning in result {i} (custom_id={custom_id})")
            return None
        return cleaned

    if result.get('error'):
        print(f"Error in result {i} (custom_id={custom_id}): {result['error']}")
        return None

    print(f"Warning: Unexpected result format at index {i} (custom_id={custom_id})")
    return None


def process_batch_results(batch_results: List[Dict], samples: List[Dict], metadata: List[Tuple[int, int, str]], 
                         model_name: str) -> List[Dict]:
    """
    Group batch output lines into one record per sample with its rewrites.

    Args:
        batch_results: parsed lines of the batch output file
        samples: original input samples (indexed by sample_idx)
        metadata: ``(sample_idx, rewrite_index, direction)`` per submitted request.
            It is used only when a result's custom_id is missing or unparseable.
        model_name: stored as ``response_source`` on every record

    Returns:
        Records sorted by sample index. Each record holds the original fields, a
        ``rewrites`` list sorted by ``(rewrite_index, direction)``, and
        ``num_rewrites``. Failed, empty, or unmatchable results are skipped with a
        printed warning. Samples with no usable rewrite are omitted.
    """
    sample_results: Dict[int, Dict] = {}

    for i, result in enumerate(batch_results):
        info = _resolve_request_info(i, result, metadata)
        if info is None:
            continue
        sample_idx, rewrite_index, direction = info

        # Guard against metadata or input-file mismatches instead of raising IndexError.
        if not 0 <= sample_idx < len(samples):
            print(f"Warning: Result {i} (custom_id={result.get('custom_id')}) refers to sample {sample_idx}, "
                  f"but only {len(samples)} samples were loaded; skipping")
            continue

        cleaned_response = _extract_rewrite_text(i, result)
        if cleaned_response is None:
            continue

        if sample_idx not in sample_results:
            sample = samples[sample_idx]
            sample_results[sample_idx] = {
                "prompt": sample.get('prompt') or sample.get('behavior'),
                "original_response": sample.get('original_response') or sample.get('response'),
                "severity_level": sample.get("severity_level"),
                "original_unsafe_score": sample.get("unsafe_score"),
                "guard_label": sample.get("guard_label"),
                "harm_category": sample.get("harm_category", {}),
                "response_source": model_name,
                "rewrites": []
            }

        sample_results[sample_idx]["rewrites"].append({
            "rewrite_index": rewrite_index,
            "direction": direction,
            "rewritten_response": cleaned_response,
        })

    processed_samples = []
    for sample_idx in sorted(sample_results):
        sample_data = sample_results[sample_idx]
        sample_data["rewrites"].sort(key=lambda x: (x["rewrite_index"], x["direction"]))
        sample_data["num_rewrites"] = len(sample_data["rewrites"])
        processed_samples.append(sample_data)

    return processed_samples

# Input path used when --input_file is not given. When resuming, this is used
# only if the metadata file does not record an input file either.
_DEFAULT_INPUT_FILE = 'pku_saferlhf_prepared.jsonl'


def _load_jsonl(path: str) -> List[Dict]:
    """Read a JSONL file into a list of objects, skipping blank lines."""
    samples = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                samples.append(json.loads(line.strip()))
    return samples


def _metadata_path_for(batch_requests_file: Optional[str], start_idx: int, end_idx: int) -> str:
    """
    Return the metadata file path for a new submission.

    For ``--batch_requests_file foo.jsonl`` this is ``foo_metadata.json`` in the
    same directory. Only a trailing ``.jsonl`` suffix is replaced. With any other
    extension the suffix is appended, so the metadata path can never equal the
    requests path (which would make the two files overwrite each other).
    """
    if batch_requests_file:
        path = Path(batch_requests_file)
        stem = path.stem if path.suffix == '.jsonl' else path.name
        return str(path.with_name(f"{stem}_metadata.json"))
    return f"batch_metadata_{start_idx}_{end_idx}.json"


def _find_metadata_file(batch_id: str) -> str:
    """
    Find a metadata file in the current directory for ``batch_id``.

    A file whose recorded ``batch_id`` equals ``batch_id`` is preferred. Otherwise
    the first (sorted) ``batch_metadata_*.json`` / ``batch_metadata.json`` file
    that records no batch_id at all is used, with a warning. Files recorded for a
    different batch are never used.

    Raises:
        ValueError: no suitable metadata file was found.
    """
    legacy_patterns = ["batch_metadata_*.json", "batch_metadata.json"]
    candidates: List[Path] = []
    for pattern in legacy_patterns + ["*_metadata.json"]:
        for path in sorted(Path('.').glob(pattern)):
            if path not in candidates:
                candidates.append(path)

    untagged: List[Path] = []
    for path in candidates:
        try:
            with open(path, 'r', encoding='utf-8') as f:
                recorded = json.load(f).get('batch_id')
        except (OSError, ValueError, AttributeError):
            continue  # unreadable or not a metadata object
        if recorded == batch_id:
            return str(path)
        if recorded is None and any(path.match(p) for p in legacy_patterns):
            untagged.append(path)

    if untagged:
        print(f"Warning: no metadata file records batch_id {batch_id}; falling back to {untagged[0]}")
        return str(untagged[0])
    raise ValueError("Metadata file not found. Please provide --metadata_file or ensure metadata file exists.")


def _download_batch_results(client: "openai.OpenAI", batch_status) -> Optional[List[Dict]]:
    """
    Download and parse the batch output file.

    Returns None when the batch has no output file (e.g. every request failed).
    Failed requests are written by the Batch API to a separate error file. Its ID
    is printed, and its contents are not included in the returned results.
    """
    error_file_id = getattr(batch_status, 'error_file_id', None)
    if error_file_id:
        print(f"Note: some requests failed; see error file {error_file_id} (not included in the output)")

    output_file_id = getattr(batch_status, 'output_file_id', None)
    if not output_file_id:
        print("Batch has no output file (all requests may have failed); nothing to process.")
        return None

    print("Downloading batch results...")
    output_content = client.files.content(output_file_id).read()

    # Split on '\n' only: str.splitlines() would also split on U+2028 etc.,
    # which may legally appear unescaped inside JSON strings.
    batch_results = []
    for line in output_content.decode('utf-8').strip().split('\n'):
        if line.strip():
            batch_results.append(json.loads(line))

    print(f"Retrieved {len(batch_results)} results")
    return batch_results


def _save_processed_samples(processed_samples: List[Dict], output_file: str) -> None:
    """Write processed samples as JSONL (non-ASCII kept as-is) and print a summary."""
    print(f"Saving results to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f_out:
        for sample in processed_samples:
            f_out.write(json.dumps(sample, ensure_ascii=False) + '\n')

    total_rewrites = sum(s['num_rewrites'] for s in processed_samples)
    print(f"\nCompleted! Generated {total_rewrites} rewrites from {len(processed_samples)} samples")
    print("Use evaluate_rewrites.py to calculate unsafe_score and similarity_score.")


def _process_and_save(client: "openai.OpenAI", batch_status, samples: List[Dict],
                      metadata: List[Tuple[int, int, str]], model_name: str, output_file: str) -> None:
    """Download a completed batch's results, assemble per-sample records and save them."""
    batch_results = _download_batch_results(client, batch_status)
    if batch_results is None:
        return
    print("Processing results...")
    processed_samples = process_batch_results(batch_results, samples, metadata, model_name)
    _save_processed_samples(processed_samples, output_file)


def _resume_batch(client: "openai.OpenAI", args: argparse.Namespace) -> None:
    """Handle --batch_id: optionally wait for an existing batch, then fetch and save its results."""
    print(f"Retrieving results for existing batch: {args.batch_id}")
    batch_status = client.batches.retrieve(args.batch_id)

    if batch_status.status != 'completed':
        print(f"Batch status is {batch_status.status}, not completed yet.")
        if args.skip_wait:
            print("Skipping wait. Please run again later when batch is completed.")
            return
        batch_status = wait_for_batch_completion(client, args.batch_id, args.check_interval)

    if batch_status.status != 'completed':
        print(f"Cannot proceed: batch status is {batch_status.status}")
        return

    metadata_file = args.metadata_file or _find_metadata_file(args.batch_id)
    print(f"Loading metadata from {metadata_file}...")
    with open(metadata_file, 'r', encoding='utf-8') as f:
        metadata_data = json.load(f)

    recorded_batch_id = metadata_data.get('batch_id')
    if recorded_batch_id and recorded_batch_id != args.batch_id:
        print(f"Warning: metadata file {metadata_file} was recorded for batch {recorded_batch_id}, "
              f"not {args.batch_id}")

    metadata = [tuple(m) for m in metadata_data['metadata']]
    model_name = metadata_data['model_name']

    # An explicit --input_file wins. Otherwise use the file the batch was built from.
    input_file = args.input_file or metadata_data.get('input_file') or _DEFAULT_INPUT_FILE
    print(f"Loading samples from {input_file}...")
    samples = _load_jsonl(input_file)

    _process_and_save(client, batch_status, samples, metadata, model_name, args.output_file)


def _submit_batch(client: "openai.OpenAI", args: argparse.Namespace) -> None:
    """Build, upload and submit a new batch, then (unless --skip_wait) wait and save results."""
    input_file = args.input_file or _DEFAULT_INPUT_FILE
    print(f"Loading samples from {input_file}...")
    samples = _load_jsonl(input_file)
    print(f"Loaded {len(samples)} samples")

    if args.num_rewrites < 1:
        raise ValueError(f"--num_rewrites must be >= 1 (got {args.num_rewrites})")

    # Clamp end_idx so an over-large --end_idx does not raise IndexError midway.
    # Reject empty or negative ranges before anything is uploaded.
    start_idx = args.start_idx
    end_idx = len(samples) if args.end_idx is None else min(args.end_idx, len(samples))
    if args.end_idx is not None and args.end_idx > len(samples):
        print(f"Note: --end_idx {args.end_idx} exceeds the {len(samples)} loaded samples; using {end_idx}")
    if start_idx < 0 or start_idx >= end_idx:
        raise ValueError(f"Invalid sample range: start_idx={start_idx}, end_idx={end_idx} "
                         f"({len(samples)} samples loaded)")
    num_samples = end_idx - start_idx

    print(f"Processing samples {start_idx} to {end_idx-1} ({num_samples} samples)")
    print(f"Total batch requests: {num_samples * args.num_rewrites}")
    print(f"Direction: {args.direction}")

    print("Preparing batch requests...")
    batch_requests, metadata = prepare_batch_requests(
        samples, args.rewrite_model, args.direction, args.num_rewrites, start_idx, end_idx
    )

    # Persist metadata before submitting so results can be processed later.
    metadata_file = _metadata_path_for(args.batch_requests_file, start_idx, end_idx)
    metadata_info = {
        "metadata": metadata,
        "model_name": args.rewrite_model,
        "direction": args.direction,
        "num_rewrites": args.num_rewrites,
        "start_idx": start_idx,
        "end_idx": end_idx,
        "input_file": input_file
    }
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata_info, f, ensure_ascii=False, indent=2)
    print(f"Saved metadata to {metadata_file}")

    batch_requests_path = args.batch_requests_file or f"batch_requests_{start_idx}_{end_idx}.jsonl"
    create_batch_file(batch_requests, batch_requests_path)
    print(f"Created batch requests file: {batch_requests_path}")

    print("Uploading batch requests file...")
    with open(batch_requests_path, 'rb') as f:
        uploaded_file = client.files.create(file=f, purpose='batch')
    file_id = uploaded_file.id
    print(f"Uploaded file ID: {file_id}")

    print("Creating batch...")
    batch = client.batches.create(
        input_file_id=file_id,
        endpoint='/v1/chat/completions',
        completion_window=args.completion_window
    )
    batch_id = batch.id
    print(f"Batch created with ID: {batch_id}")

    # Record batch_id so the metadata file can be matched to this batch later.
    metadata_info["batch_id"] = batch_id
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata_info, f, ensure_ascii=False, indent=2)
    print(f"Updated metadata file with batch_id: {batch_id}")

    print("\nTo retrieve results later, use:")
    print(f"  python {__file__} --batch_id {batch_id} --metadata_file {metadata_file} --output_file {args.output_file}")

    if args.skip_wait:
        print("\nSkipping wait. Batch is processing in the background.")
        return

    batch_status = wait_for_batch_completion(client, batch_id, args.check_interval)
    if batch_status.status != 'completed':
        print(f"Cannot proceed: batch status is {batch_status.status}")
        return

    _process_and_save(client, batch_status, samples, metadata, args.rewrite_model, args.output_file)


def main():
    """
    CLI entry point.

    There are two modes:
      - submit a new batch (default), or
      - retrieve and process an existing batch (``--batch_id``).

    Raises:
        ValueError: no API key is available, the sample range is invalid, or no
            metadata file can be found when resuming.
    """
    parser = argparse.ArgumentParser(
        description='Rewrite responses using OpenAI Batch API to increase or decrease maliciousness while preserving meaning. '
                    'Use evaluate_rewrites.py to calculate unsafe_score and similarity_score.'
    )
    # Default is None (not the path itself) so that, when resuming with
    # --batch_id, the input file recorded in the metadata can take effect.
    parser.add_argument('--input_file', type=str,
                       default=None,
                       help=f'Input JSONL file path (default: {_DEFAULT_INPUT_FILE}; with --batch_id, '
                            'defaults to the input_file recorded in the metadata)')
    parser.add_argument('--output_file', type=str,
                       default='pku_saferlhf_rewritten.jsonl',
                       help='Output JSONL file path')
    parser.add_argument('--direction', type=str, default='increase', choices=['increase', 'decrease'],
                       help='Direction of rewriting: increase or decrease maliciousness (default: increase)')
    parser.add_argument('--num_rewrites', type=int, default=5,
                       help='Number of rewrites per sample (default: 5)')
    parser.add_argument('--start_idx', type=int, default=0,
                       help='Start processing from this index')
    parser.add_argument('--end_idx', type=int, default=None,
                       help='End processing at this index (None for all)')
    parser.add_argument('--api_key', type=str, default=None,
                       help='OpenAI API key (or set OPENAI_API_KEY env var)')
    parser.add_argument('--rewrite_model', type=str, default='gpt-4.1-mini',
                       help='Model name for rewriting (e.g., gpt-4o, gpt-4-turbo, gpt-4, o1-preview, o1-mini, gpt-5.1-mini, gpt-5.1, gpt-4.1-mini)')
    parser.add_argument('--batch_id', type=str, default=None,
                       help='Existing batch ID to check status and retrieve results (skip creation)')
    parser.add_argument('--batch_requests_file', type=str, default=None,
                       help='Path to save/load batch requests JSONL file')
    parser.add_argument('--metadata_file', type=str, default=None,
                       help='Path to metadata file (required when using --batch_id)')
    parser.add_argument('--completion_window', type=str, default='24h',
                       help='Completion window for batch (default: 24h)')
    parser.add_argument('--check_interval', type=int, default=60,
                       help='Interval in seconds to check batch status (default: 60)')
    parser.add_argument('--skip_wait', action='store_true',
                       help='Skip waiting for batch completion (just submit and exit)')

    args = parser.parse_args()

    api_key = args.api_key or os.getenv('OPENAI_API_KEY')
    if not api_key:
        raise ValueError("OpenAI API key must be provided via --api_key or OPENAI_API_KEY environment variable")

    client = openai.OpenAI(api_key=api_key)
    print(f"Using rewrite model: {args.rewrite_model}")
    print("Note: This script uses Batch API. Use evaluate_rewrites.py to calculate scores.")

    if args.batch_id:
        _resume_batch(client, args)
    else:
        _submit_batch(client, args)


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()

