"""
Run PAIR attack using local server (SGLang/vLLM) as attack model.
Similar to Multi-turn-jailbreak/run_full_pipeline.sh functionality.
Supports multi-threading for parallel processing of instances.
"""
import sys
from pathlib import Path

# Ensure Decoy-for-the-Judge repo root is importable so we can use:
# - `unified_judge.py`
# - `defense/` package
# parents[1] is the directory one level above the folder that contains this script.
DECOY_ROOT = Path(__file__).resolve().parents[1]
# Insert at position 0 so the repo root is searched before other sys.path entries;
# the membership test avoids adding the same path twice.
if str(DECOY_ROOT) not in sys.path:
    sys.path.insert(0, str(DECOY_ROOT))

import os
import json
import logging
import threading
import concurrent.futures
from tqdm import tqdm
import httpx
from attacker.PAIR_chao_2023 import PAIR
from jb_datasets.jailbreak_datasets import JailbreakDataset
from models.openai_model import OpenaiModel

# Defense switches (same style as Multi-turn-jailbreak/run_full_pipeline.sh).
# A switch is on only when its environment variable equals "true"
# (case-insensitive); an unset variable or any other value means off.
ENABLE_MISLEAD, ENABLE_PROACT, ENABLE_GUARD = (
    os.environ.get(_toggle, "false").lower() == "true"
    for _toggle in ("ENABLE_MISLEAD", "ENABLE_PROACT", "ENABLE_GUARD")
)

# Coarse on/off configuration handed to the attackers; the defense modules read
# their additional fine-grained settings from environment variables.
DEFENSE_CONFIG = dict(
    enable_defense=ENABLE_MISLEAD,  # Mislead/rewrite defense
    enable_proact=ENABLE_PROACT,    # ProAct proactive defense
    enable_guard=ENABLE_GUARD,      # Guard replace-if-unsafe
)

# Defense server endpoints (conventions follow Multi-turn-jailbreak/run_full_pipeline.sh).
# Every resolved value is written back with os.environ.setdefault, so a value
# that is already present in the environment is never overwritten.

# Guard server: an explicit GUARD_SERVER_URL wins over the node/port pair.
GUARD_SERVER_NODE = os.getenv("GUARD_SERVER_NODE", "localhost")
GUARD_SERVER_PORT = os.getenv("GUARD_SERVER_PORT", "30001")
GUARD_SERVER_URL = os.getenv(
    "GUARD_SERVER_URL",
    f"http://{GUARD_SERVER_NODE}:{GUARD_SERVER_PORT}/v1",
)
GUARD_API_MODE = os.getenv("GUARD_API_MODE", "chat")
GUARD_MULTI_TURN = os.getenv("GUARD_MULTI_TURN", "false")
for _guard_key, _guard_value in (
    ("GUARD_SERVER_URL", GUARD_SERVER_URL),
    ("GUARD_API_MODE", GUARD_API_MODE),
    ("GUARD_MULTI_TURN", GUARD_MULTI_TURN),
):
    os.environ.setdefault(_guard_key, _guard_value)

# Rewrite server(s): the INCREASE endpoint falls back to the base rewrite URL,
# while the DECREASE endpoint is optional and only exported when non-empty.
# NOTE(review): the default node "nid008321" looks like a site-specific host name — confirm.
REWRITE_SERVER_NODE = os.getenv("REWRITE_SERVER_NODE", "nid008321")
REWRITE_SERVER_PORT = os.getenv("REWRITE_SERVER_PORT", "8000")
REWRITE_SERVER_URL = os.getenv(
    "REWRITE_SERVER_URL",
    f"http://{REWRITE_SERVER_NODE}:{REWRITE_SERVER_PORT}/v1",
)
REWRITE_SERVER_URL_INCREASE = os.getenv("REWRITE_SERVER_URL_INCREASE", REWRITE_SERVER_URL)
REWRITE_SERVER_URL_DECREASE = os.getenv("REWRITE_SERVER_URL_DECREASE", "")
for _rewrite_key, _rewrite_value in (
    ("REWRITE_SERVER_URL", REWRITE_SERVER_URL),
    ("REWRITE_SERVER_URL_INCREASE", REWRITE_SERVER_URL_INCREASE),
):
    os.environ.setdefault(_rewrite_key, _rewrite_value)
if REWRITE_SERVER_URL_DECREASE:
    os.environ.setdefault("REWRITE_SERVER_URL_DECREASE", REWRITE_SERVER_URL_DECREASE)

# API key shared by all three roles (use 'EMPTY' for local servers).
openai_api_key = os.getenv('OPENAI_API_KEY', 'EMPTY')

# Per-role routing switches: when a role's flag is true its base URL becomes None,
# which selects the official OpenAI API (https://api.openai.com/v1) for that role.
# NOTE(review): USE_OPENAI is parsed here but nothing else in this file reads it;
# only USE_OPENAI_ATTACK / USE_OPENAI_TARGET / USE_OPENAI_EVAL affect routing.
# Confirm whether it was meant to be the default for the per-role flags.
USE_OPENAI = os.getenv("USE_OPENAI", "true").lower() == "true"
USE_OPENAI_ATTACK, USE_OPENAI_TARGET, USE_OPENAI_EVAL = (
    os.getenv(_role_flag, "false").lower() == "true"
    for _role_flag in ("USE_OPENAI_ATTACK", "USE_OPENAI_TARGET", "USE_OPENAI_EVAL")
)

# Local server configuration (similar to run_full_pipeline.sh).
# All three roles default to localhost:30000.

# Attack model server (SGLang/vLLM)
ATTACK_SERVER_NODE = os.getenv('ATTACK_SERVER_NODE', 'localhost')
ATTACK_SERVER_PORT = os.getenv('ATTACK_SERVER_PORT', '30000')
if USE_OPENAI_ATTACK:
    ATTACK_SERVER_URL = None
else:
    ATTACK_SERVER_URL = f"http://{ATTACK_SERVER_NODE}:{ATTACK_SERVER_PORT}/v1"
ATTACK_MODEL_NAME = os.getenv('ATTACK_MODEL_NAME', 'Qwen/Qwen2.5-32B-Instruct')

# Target model server (default model name matches run_full_pipeline.sh)
TARGET_SERVER_NODE = os.getenv('TARGET_SERVER_NODE', 'localhost')
TARGET_SERVER_PORT = os.getenv('TARGET_SERVER_PORT', '30000')
if USE_OPENAI_TARGET:
    TARGET_SERVER_URL = None
else:
    TARGET_SERVER_URL = f"http://{TARGET_SERVER_NODE}:{TARGET_SERVER_PORT}/v1"
TARGET_MODEL_NAME = os.getenv('TARGET_MODEL_NAME', 'gpt-4o-2024-11-20')

# Eval (judge) model server (default model name matches run_full_pipeline.sh)
EVAL_SERVER_NODE = os.getenv('EVAL_SERVER_NODE', 'localhost')
EVAL_SERVER_PORT = os.getenv('EVAL_SERVER_PORT', '30000')
if USE_OPENAI_EVAL:
    EVAL_SERVER_URL = None
else:
    EVAL_SERVER_URL = f"http://{EVAL_SERVER_NODE}:{EVAL_SERVER_PORT}/v1"
EVAL_MODEL_NAME = os.getenv('EVAL_MODEL_NAME', 'gpt-4o-2024-11-20')

# Dataset selection
dataset_name_str = os.getenv('DATASET_NAME', 'AdvBench')
# DATASET_LIMIT unset or empty -> None, which means "process every instance".
dataset_limit_str = os.getenv('DATASET_LIMIT')
dataset_limit = None if not dataset_limit_str else int(dataset_limit_str)

# Thread-pool size
MAX_WORKERS = int(os.getenv('MAX_WORKERS', 20))  # Default: 20 workers

# Rerun switch: false (default) skips instances already in the results file.
RERUN = os.getenv('RERUN', 'false').lower() == 'true'

# Determinism switches (match Multi-turn-jailbreak/run_full_pipeline.sh).
# ENABLE_DETERMINISM controls seed usage in EvaluatorUnifiedJudge (default: enabled, seed=123).
# Each resolved flag is exported with setdefault so an existing environment value is kept.
ENABLE_DETERMINISM = os.getenv('ENABLE_DETERMINISM', 'true').lower() == 'true'
os.environ.setdefault('ENABLE_DETERMINISM', 'true' if ENABLE_DETERMINISM else 'false')
# Attack and rewrite stay stochastic unless explicitly switched on.
ENABLE_DETERMINISM_ATTACK = os.getenv('ENABLE_DETERMINISM_ATTACK', 'false').lower() == 'true'
ENABLE_DETERMINISM_REWRITE = os.getenv('ENABLE_DETERMINISM_REWRITE', 'false').lower() == 'true'
os.environ.setdefault('ENABLE_DETERMINISM_ATTACK', 'true' if ENABLE_DETERMINISM_ATTACK else 'false')
os.environ.setdefault('ENABLE_DETERMINISM_REWRITE', 'true' if ENABLE_DETERMINISM_REWRITE else 'false')

# Startup banner: echo the resolved configuration to stdout.
_banner_rule = "=" * 60
_openai_label = 'OpenAI (default https://api.openai.com/v1)'
print(_banner_rule)
print("PAIR Attack with Local Server (Multi-threaded)")
print(_banner_rule)
# A base URL of None means the role talks to the official OpenAI endpoint.
for _role_name, _role_url, _role_model in (
    ("Attack", ATTACK_SERVER_URL, ATTACK_MODEL_NAME),
    ("Target", TARGET_SERVER_URL, TARGET_MODEL_NAME),
    ("Eval", EVAL_SERVER_URL, EVAL_MODEL_NAME),
):
    print(f"{_role_name} Server: {_openai_label if _role_url is None else _role_url}")
    print(f"{_role_name} Model: {_role_model}")
_limit_text = 'all instances' if dataset_limit is None else dataset_limit
print(f"Dataset: {dataset_name_str} (limit: {_limit_text})")
print(f"Max Workers: {MAX_WORKERS}")
_rerun_note = 'Will reprocess all instances' if RERUN else 'Will skip processed instances'
print(f"Rerun: {RERUN} ({_rerun_note})")
_seed_text = '123' if ENABLE_DETERMINISM else 'None'
print(f"ENABLE_DETERMINISM: {ENABLE_DETERMINISM} (target/eval defaults; seed={_seed_text})")
for _scope, _scope_enabled in (
    ("attack", ENABLE_DETERMINISM_ATTACK),
    ("rewrite", ENABLE_DETERMINISM_REWRITE),
):
    _state = 'ON' if _scope_enabled else 'OFF'
    print(f"ENABLE_DETERMINISM_{_scope.upper()}: {_scope_enabled} ({_scope} auto-defaults {_state})")
for _setting, _setting_value in (
    ("ENABLE_MISLEAD", ENABLE_MISLEAD),
    ("ENABLE_PROACT", ENABLE_PROACT),
    ("ENABLE_GUARD", ENABLE_GUARD),
    ("GUARD_SERVER_URL", GUARD_SERVER_URL),
    ("REWRITE_SERVER_URL", REWRITE_SERVER_URL),
    ("REWRITE_SERVER_URL_INCREASE", REWRITE_SERVER_URL_INCREASE),
):
    print(f"{_setting}: {_setting_value}")
# The DECREASE endpoint is optional, so it is only shown when configured.
if REWRITE_SERVER_URL_DECREASE:
    print(f"REWRITE_SERVER_URL_DECREASE: {REWRITE_SERVER_URL_DECREASE}")
print(_banner_rule)

# Configure root logging. force=True avoids duplicated handlers left behind by
# imported modules.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    force=True,
)
# Hide HTTP request logs from the OpenAI SDK / httpx stack (warnings and above still show).
for _noisy_logger in ("httpx", "httpcore", "openai"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)

def _check_openai_compatible_server(base_url: str | None, name: str) -> None:
    """Fail fast if an OpenAI-compatible server is not reachable."""
    if base_url is None:
        return
    url = base_url.rstrip("/") + "/models"
    try:
        resp = httpx.get(url, timeout=5.0)
        resp.raise_for_status()
    except Exception as e:
        raise RuntimeError(
            f"{name} server is not reachable at {base_url} "
            f"(health-check GET {url} failed). "
            f"If you're running via srun on a compute node, 'localhost' usually points to that node "
            f"(not your login node). Set {name}_SERVER_NODE/{name}_SERVER_PORT to the node/port where the server runs."
        ) from e

# Fail fast if a local server is not reachable. Roles routed to the official
# OpenAI API have a URL of None, which the check skips.
for _preflight_role, _preflight_url in (
    ("ATTACK", ATTACK_SERVER_URL),
    ("TARGET", TARGET_SERVER_URL),
    ("EVAL", EVAL_SERVER_URL),
):
    _check_openai_compatible_server(_preflight_url, _preflight_role)

# Optional defense servers preflight
def _check_optional_defense_servers() -> None:
    """Health-check the defense servers that the enabled defenses rely on.

    The guard server is probed when ENABLE_GUARD is set. The rewrite servers
    are probed when ENABLE_MISLEAD is set and REWRITE_MODEL_TYPE (default
    "server") resolves to "server"; an empty rewrite URL is skipped.

    Raises:
        RuntimeError: Propagated from _check_openai_compatible_server when a
            probed server does not respond.
    """
    # Guard server (OpenAI-compatible)
    if ENABLE_GUARD:
        _check_openai_compatible_server(GUARD_SERVER_URL, "GUARD")

    # Mislead rewrite server(s) only matter when the rewrite model runs in server mode.
    if not ENABLE_MISLEAD:
        return
    rewrite_mode = os.environ.get("REWRITE_MODEL_TYPE", "server").strip().lower()
    if rewrite_mode != "server":
        return
    rewrite_endpoints = (
        ("REWRITE_INCREASE", REWRITE_SERVER_URL_INCREASE),
        ("REWRITE_DECREASE", REWRITE_SERVER_URL_DECREASE),
    )
    for label, endpoint in rewrite_endpoints:
        if endpoint:
            _check_openai_compatible_server(endpoint, label)

_check_optional_defense_servers()

# Align with Multi-turn-jailbreak: preload similarity model once per process to avoid first-call overhead.
# Only attempted when the mislead defense is enabled.
if ENABLE_MISLEAD:
    try:
        # Imported lazily, inside the try, so a missing or broken defense package
        # only produces the warning below instead of aborting the run.
        from defense.utils import get_similarity_model
        logging.info("Preloading similarity model for mislead defense...")
        # Return value is discarded; the call is made purely to warm the model up.
        get_similarity_model()
        logging.info("Similarity model ready.")
    except Exception as e:
        # Best-effort: preloading is an optimisation, so failures are logged and ignored.
        logging.warning("Failed to preload similarity model (continuing): %s", e)

# Load dataset
dataset = JailbreakDataset(dataset_name_str)
# Truncate only for a positive limit; None, 0 and negative values keep the full dataset.
if dataset_limit and dataset_limit > 0:
    dataset = dataset[:dataset_limit]

# Module-level model instances. A base_url of None selects the official OpenAI API;
# otherwise the role talks to the configured OpenAI-compatible server.
# NOTE(review): attack_model and target_model are not referenced again at module
# level in this file (worker threads build their own via create_models());
# only eval_model is used later, for the evaluator that yields the success threshold.

# Initialize attack model
print(f"\nInitializing attack model from: {'OpenAI' if ATTACK_SERVER_URL is None else ATTACK_SERVER_URL}")
attack_model = OpenaiModel(
    model_name=ATTACK_MODEL_NAME,
    api_keys=openai_api_key,
    base_url=ATTACK_SERVER_URL,
    role="attack",
)

# Initialize target model
print(f"Initializing target model from: {'OpenAI' if TARGET_SERVER_URL is None else TARGET_SERVER_URL}")
target_model = OpenaiModel(
    model_name=TARGET_MODEL_NAME,
    api_keys=openai_api_key,
    base_url=TARGET_SERVER_URL,
    role="target",
)

# Initialize eval model
print(f"Initializing eval model from: {'OpenAI' if EVAL_SERVER_URL is None else EVAL_SERVER_URL}")
eval_model = OpenaiModel(
    model_name=EVAL_MODEL_NAME,
    api_keys=openai_api_key,
    base_url=EVAL_SERVER_URL,
    role="eval",
)

# PAIR attack parameters, forwarded as keyword arguments to every PAIR attacker.
PAIR_PARAMS = dict(
    n_streams=5,
    n_iterations=5,
    # For deterministic judging, temperature is forced to 0.0 when determinism is enabled.
    judge_temperature=0.0 if ENABLE_DETERMINISM else 1.0,
    judge_max_n_tokens=10,
)

# Build a throwaway evaluator just to read max_score, which is the success threshold.
from metrics.Evaluator.Evaluator_UnifiedJudge import EvaluatorUnifiedJudge
evaluator_kwargs = dict(
    temperature=PAIR_PARAMS.get('judge_temperature', 1.0),
    max_completion_tokens=PAIR_PARAMS.get('judge_max_n_tokens', 10),
)
# use_local / local_base_url are passed only when the judge runs on a local
# server (USE_OPENAI_EVAL=false); for the OpenAI API they are omitted.
if EVAL_SERVER_URL is not None and not USE_OPENAI_EVAL:
    evaluator_kwargs.update(use_local=True, local_base_url=str(EVAL_SERVER_URL))
temp_evaluator = EvaluatorUnifiedJudge(eval_model, **evaluator_kwargs)
JAILBREAK_THRESHOLD = temp_evaluator.judge_evaluator.max_score
print(f"Jailbreak success threshold: score == {JAILBREAK_THRESHOLD} (max_score from evaluator)")

# Output location. Unless OUTPUT_DIR overrides it, results are written to an
# "outputs" folder next to this script; the folder is created if missing.
_script_dir = Path(__file__).resolve().parent
default_output_dir = str(_script_dir / "outputs")
output_dir = os.getenv("OUTPUT_DIR", default_output_dir)
os.makedirs(output_dir, exist_ok=True)
def _safe_tag(s: str) -> str:
    # Keep filenames readable and stable
    return str(s).replace("/", "_").replace(" ", "_")

# Filename semantics, per role:
# - base_url is set (local OpenAI-compatible server) -> tag "local"
# - base_url is None (official OpenAI API)           -> sanitised model name
attack_tag, target_tag, eval_tag = (
    "local" if _tag_url is not None else _safe_tag(_tag_model)
    for _tag_url, _tag_model in (
        (ATTACK_SERVER_URL, ATTACK_MODEL_NAME),
        (TARGET_SERVER_URL, TARGET_MODEL_NAME),
        (EVAL_SERVER_URL, EVAL_MODEL_NAME),
    )
)

# Defense toggles encoded as M<0|1>P<0|1>G<0|1> (mislead / proact / guard).
defense_tag = "M{:d}P{:d}G{:d}".format(ENABLE_MISLEAD, ENABLE_PROACT, ENABLE_GUARD)

# One JSONL results file per (attack, target, eval, dataset, defense) combination.
_results_name = "_".join(
    ["PAIR", attack_tag, target_tag, eval_tag, dataset_name_str, defense_tag]
) + ".jsonl"
save_path = os.path.join(output_dir, _results_name)

def create_models():
    """Build a fresh attack/target/eval model triple for one worker thread.

    Each model is an OpenaiModel configured from the module-level model name,
    base URL and shared API key of its role.

    Returns:
        Tuple ``(attack_model, target_model, eval_model)``.
    """
    role_specs = (
        (ATTACK_MODEL_NAME, ATTACK_SERVER_URL, "attack"),
        (TARGET_MODEL_NAME, TARGET_SERVER_URL, "target"),
        (EVAL_MODEL_NAME, EVAL_SERVER_URL, "eval"),
    )
    # Constructed in attack -> target -> eval order, then unpacked by role.
    attacker_llm, target_llm, judge_llm = (
        OpenaiModel(
            model_name=model_name,
            api_keys=openai_api_key,
            base_url=base_url,
            role=role,
        )
        for model_name, base_url, role in role_specs
    )
    return attacker_llm, target_llm, judge_llm

# Serialises appends to the shared JSONL results file across worker threads.
file_lock = threading.Lock()

def run_single_instance(instance_idx, instance):
    """Attack one dataset instance and append its outcome to the results file.

    Builds thread-private models and a PAIR attacker, runs the attack, and
    appends one JSON line to ``save_path`` while holding ``file_lock``.

    Args:
        instance_idx: Index of the instance in the dataset; used in log
            messages and echoed back in the return value.
        instance: Dataset entry to attack; its ``query`` is read for logging.

    Returns:
        ``(instance_idx, result)`` on success, or ``(instance_idx, None)`` when
        any exception occurred (the error is logged with its traceback).
    """
    try:
        # Each call gets its own models so threads never share client objects.
        attack_model, target_model, eval_model = create_models()

        # The attacker is given a one-element dataset holding just this instance.
        attacker = PAIR(
            attack_model=attack_model,
            target_model=target_model,
            eval_model=eval_model,
            jailbreak_datasets=JailbreakDataset([instance]),
            defense_config=DEFENSE_CONFIG,
            **PAIR_PARAMS
        )
        outcome = attacker.single_attack(instance)

        # The extra fields live in attack_attrs; they stay None when that
        # attribute is missing or is not a dict.
        attrs = getattr(outcome, "attack_attrs", None)
        if not isinstance(attrs, dict):
            attrs = {}
        record = {
            'query': outcome.query,
            'jailbreak_prompt': outcome.jailbreak_prompt,
            'target_responses': outcome.target_responses,
            'eval_results': outcome.eval_results,
            'original_score': attrs.get("original_score"),
            'success': attrs.get("success"),
            'defense_info': attrs.get("defense_info"),
        }

        # Append under the lock so concurrent writers cannot interleave lines.
        with file_lock, open(save_path, 'a', encoding='utf-8') as sink:
            sink.write(json.dumps(record, ensure_ascii=False) + '\n')

        logging.info(f"Instance {instance_idx} completed. Query: {instance.query[:50]}...")
    except Exception as e:
        logging.error(f"Instance {instance_idx} failed: {e}", exc_info=True)
        return instance_idx, None
    return instance_idx, outcome

# Resume support: with RERUN the previous results (and metrics) are deleted;
# otherwise the queries already in the results file are collected for skipping.
processed_queries = set()
if os.path.exists(save_path):
    if RERUN:
        logging.info(f"RERUN=True: Clearing existing results file: {save_path}")
        try:
            os.remove(save_path)
            # The metrics file sits next to the results file; remove it too if present.
            metrics_path = save_path.replace('.jsonl', '_metrics.json')
            if os.path.exists(metrics_path):
                os.remove(metrics_path)
            logging.info("Existing results file cleared. Will reprocess all instances.")
        except Exception as cleanup_error:
            logging.warning(f"Failed to remove existing results file: {cleanup_error}")
    else:
        logging.info(f"Found existing results file: {save_path}. Loading processed instances...")
        try:
            existing_dataset = JailbreakDataset.load_jsonl(save_path)
            # Entries with an empty or missing query are ignored.
            processed_queries = {
                entry.query for entry in existing_dataset if entry.query
            }
            logging.info(f"Loaded {len(processed_queries)} processed instances from existing file.")
        except Exception as load_error:
            logging.warning(f"Failed to load existing results: {load_error}. Starting fresh.")

# Prepare instances to process: keep every dataset entry together with its
# original index, dropping entries whose query is already in the results file
# (skipping is disabled when RERUN is set).
instances_to_process = []
for idx, instance in enumerate(dataset):
    if not RERUN and instance.query in processed_queries:
        logging.info(f"Skipping already processed instance {idx}: {instance.query[:50]}...")
        continue
    instances_to_process.append((idx, instance))

if not instances_to_process:
    logging.info("All instances have been processed. Exiting.")
    # Use sys.exit rather than the `exit` builtin: `exit` is injected by the
    # `site` module for interactive sessions and is not defined under
    # `python -S` or in some frozen/embedded interpreters.
    sys.exit(0)

logging.info(f"Starting multi-threaded attack on {len(instances_to_process)} instances...")
logging.info(f"Results will be saved to: {save_path}")

# Fan the instances out over a thread pool.
# Each worker appends its own result to the JSONL file as it finishes, so this
# loop only collects the in-memory results and reports unexpected failures.
all_results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # Map every future back to the (index, instance) pair it was submitted for.
    futures = {}
    for idx, inst in instances_to_process:
        futures[executor.submit(run_single_instance, idx, inst)] = (idx, inst)

    completed = concurrent.futures.as_completed(futures)
    for future in tqdm(completed, total=len(futures), desc="Processing instances"):
        idx, inst = futures[future]
        try:
            instance_idx, result = future.result()
        except Exception as e:
            logging.error(f"Instance {idx} generated an exception: {e}", exc_info=True)
        else:
            # A None result marks an instance whose worker already logged a failure.
            if result is not None:
                all_results.append(result)

# Wait for all threads to complete before calculating statistics
# Statistics are calculated only once, after all instances are processed
# Metrics are computed over the whole results file (including entries from
# earlier runs), but only when this run produced at least one new result.
if all_results:
    logging.info(f"\nAttack completed! {len(all_results)} instances processed.")
    logging.info(f"Results saved to: {save_path}")

    # Optionally, reload and save aggregated metrics
    try:
        final_dataset = JailbreakDataset.load_jsonl(save_path)
        # Calculate metrics using max_score from evaluator as success threshold
        # Success is defined as score == max_score (from Decoy-for-the-Judge/unified_judge.py)
        total_queries = len(final_dataset)
        total_jailbreak = 0
        total_reject = 0
        total_score_sum = 0

        # Read JSONL directly to access original_score field (not always in Instance attributes)
        # NOTE: json is already imported at the top of the file; this import is redundant but harmless.
        import json
        # Maps query -> original_score; a later line for the same query overwrites an earlier one.
        original_scores = {}
        try:
            with open(save_path, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        data = json.loads(line)
                        query = data.get('query', '')
                        original_score = data.get('original_score')
                        if query and original_score is not None:
                            original_scores[query] = original_score
        except Exception as e:
            # Best-effort: without this map the loop below falls back to instance attributes.
            logging.warning(f"Failed to read original_score from JSONL: {e}")

        for inst in final_dataset:
            # Instances without eval results add nothing to the counters but still
            # count toward total_queries, i.e. the ASR/average denominators.
            if inst.eval_results:
                # For a list of eval results the last entry is used.
                eval_score = inst.eval_results[-1] if isinstance(inst.eval_results, list) else inst.eval_results
                # Multi-turn-jailbreak semantics: use original_score (pre-defense) for ASR and average.
                # eval_score (post-defense) is only used for stop condition.
                original_score = original_scores.get(inst.query) if inst.query else None
                if original_score is None:
                    # Fallback: try to get from instance attributes
                    try:
                        original_score = getattr(inst, "original_score", None)
                        if original_score is None and isinstance(getattr(inst, "attack_attrs", None), dict):
                            original_score = inst.attack_attrs.get("original_score")
                    except Exception:
                        original_score = None

                # Use original_score for statistics; fallback to eval_score if missing
                score_for_stats = original_score if isinstance(original_score, (int, float)) else eval_score
                # Non-numeric scores are skipped: counted as neither jailbreak nor reject.
                if isinstance(score_for_stats, (int, float)):
                    total_score_sum += score_for_stats
                    if score_for_stats == JAILBREAK_THRESHOLD:  # Only max_score is considered successful
                        total_jailbreak += 1
                    else:
                        total_reject += 1

        # Both ratios divide by every entry in the results file; 0.0 guards an empty file.
        asr = (total_jailbreak / total_queries) if total_queries else 0.0
        avg_score = (total_score_sum / total_queries) if total_queries else 0.0

        # The metrics file sits next to the results file: <name>_metrics.json.
        metrics_path = save_path.replace('.jsonl', '_metrics.json')
        metrics = {
            'total_queries': total_queries,
            'total_jailbreak': total_jailbreak,  # Count of successful jailbreaks (original_score==max_score when available)
            'total_reject': total_reject,  # Count of failed attempts
            'asr': asr,  # Attack Success Rate (jailbreak success)
            'avg_score': avg_score,  # Average original_score (pre-defense, fallback to eval if missing)
            'jailbreak_threshold': JAILBREAK_THRESHOLD,  # max_score from evaluator
            'success_definition': f'original_score == {JAILBREAK_THRESHOLD} (fallback to eval score if missing)',
        }
        with open(metrics_path, 'w', encoding='utf-8') as f:
            json.dump(metrics, f, ensure_ascii=False, indent=2)
        logging.info(f"Metrics saved to: {metrics_path}")
        logging.info(f"ASR: {asr*100:.2f}% (success: score == {JAILBREAK_THRESHOLD})")
        logging.info(f"Average score: {avg_score:.2f}/{JAILBREAK_THRESHOLD}")
    except Exception as e:
        # Metrics are a convenience; the per-instance results are already on disk.
        logging.warning(f"Failed to save metrics: {e}")
else:
    logging.warning("No instances were successfully processed.")

