#!/usr/bin/env python3
"""
Rewrite (defend) responses using the "mislead defense" rewrite logic.

This module serves two purposes:
1) **Library**: provides `mislead_defense()` used by Decoy-for-the-Judge defense runners.
2) **CLI**: batch-rewrite MT-bench style `model_answer/*.jsonl` files (same schema as FastChat).

Prompts:
- Default system prompts are loaded from: `Decoy-for-the-Judge/prompt/system_prompt_{increase,decrease}.txt`
"""

from __future__ import annotations

import argparse
import json
import os
import re
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

try:
    import requests
except Exception as e:  # noqa: BLE001
    raise RuntimeError("Missing dependency: `requests`. Please install it in your environment.") from e

try:
    from tqdm import tqdm
except Exception:  # noqa: BLE001
    tqdm = None  # type: ignore

# Serializes the one-time construction of the shared similarity model.
_SIM_MODEL_INIT_LOCK = threading.Lock()
# Held around every `encode()` call made on the shared model (see `_cosine_sim`).
_SIM_ENCODE_LOCK = threading.Lock()
# Process-wide cached model; stays None until `_get_similarity_model` first builds it.
_SIM_MODEL = None


def _load_questions(question_file: Path) -> Dict[int, Dict[str, Any]]:
    """
    Read an MT-bench question JSONL file and index its records by question id.

    Every non-blank line is decoded as one JSON object shaped like
      {"question_id": int, "turns": [user_turn_0, user_turn_1, ...], ...}
    The id is coerced with `int()`. If an id occurs more than once, the record
    from the later line replaces the earlier one.
    """
    with question_file.open("r", encoding="utf-8") as fh:
        # Lazy pipeline: each line is stripped, decoded and keyed before the next is read.
        stripped = (raw.strip() for raw in fh)
        decoded = (json.loads(text) for text in stripped if text)
        return {int(entry["question_id"]): entry for entry in decoded}


def _iter_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Decode a JSONL file into a list, one JSON value per non-blank line, in file order."""
    with path.open("r", encoding="utf-8") as fh:
        return [json.loads(text) for text in (raw.strip() for raw in fh) if text]


def _write_jsonl(path: Path, rows: List[Dict[str, Any]]) -> None:
    """
    Write `rows` to `path` as JSONL, one object per line.

    Missing parent directories are created and an existing file is replaced.
    Non-ASCII characters are written as-is (UTF-8) rather than escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)


def _repo_root() -> Path:
    """Return the parent of the directory that contains this file (resolved to an absolute path)."""
    # e.g. /Decoy-for-the-Judge/defense/dj_defense.py -> /Decoy-for-the-Judge
    this_file = Path(__file__).resolve()
    return this_file.parents[1]


def _default_prompt_file(direction: str) -> Path:
    """Return `<repo root>/prompt/system_prompt_<direction>.txt` (the path is not checked for existence)."""
    filename = f"system_prompt_{direction}.txt"
    return _repo_root().joinpath("prompt", filename)


def clean_rewritten_response(text: str) -> str:
    """
    Strip boilerplate that a rewrite model may wrap around its answer.

    Steps, in order:
      1. Trim surrounding whitespace.
      2. Delete known lead-in phrases ("Here is the rewritten response:", ...).
         Matching is case-insensitive and anchored with ``^`` under
         ``re.MULTILINE``, so a lead-in at the start of *any* line is removed.
      3. Drop one pair of identical quotes (" or ') that encloses the whole text.
      4. Replace each newline / optional whitespace / newline(s) run with a
         single newline, then trim again.

    Falsy input (e.g. ``""``) is returned unchanged.
    """
    if not text:
        return text

    lead_ins = (
        r"^Certainly,?\s+(here|this)\s+is\s+(the\s+)?(revised|rewritten|new)\s+response:?\s*",
        r"^Here\s+is\s+(the\s+)?(revised|rewritten|new)\s+response:?\s*",
        r"^(Certainly,?\s+)?Here'?s?\s+(the\s+)?(revised|rewritten|new)\s+response:?\s*",
        r"^I'?ve\s+rewritten\s+(the\s+)?response:?\s*",
        r"^(Here|This)\s+is\s+the\s+rewritten\s+version:?\s*",
        r"^The\s+rewritten\s+response\s+(is|follows):?\s*",
        r"^(Below|Here)\s+is\s+(the\s+)?rewritten\s+text:?\s*",
        r"^Sure,?\s+here'?s?\s+(the\s+)?rewritten\s+response:?\s*",
    )
    flags = re.IGNORECASE | re.MULTILINE

    result = text.strip()
    # Patterns are applied one after another, each on the output of the previous one.
    for lead_in in lead_ins:
        result = re.compile(lead_in, flags).sub("", result)
    result = result.strip()

    # Unwrap only when the same quote character both opens and closes the text.
    opening = result[:1]
    if opening in ('"', "'") and result.endswith(opening):
        result = result[1:-1].strip()

    return re.sub(r"\n\s*\n+", "\n", result).strip()


def _resolve_chat_completions_endpoint(base_or_endpoint: str) -> str:
    """
    Normalize a server address to a full chat-completions endpoint URL.

    Trailing slashes are dropped first, then:
    - http://host:port/v1/chat/completions  -> returned as is
    - http://host:port/v1                   -> "/chat/completions" is appended
    - anything else (e.g. http://host:port) -> "/v1/chat/completions" is appended
    """
    base = base_or_endpoint.rstrip("/")
    # (recognized suffix, text still missing after it); first match wins.
    known_suffixes = (
        ("/v1/chat/completions", ""),
        ("/v1", "/chat/completions"),
    )
    for suffix, missing in known_suffixes:
        if base.endswith(suffix):
            return base + missing
    return base + "/v1/chat/completions"


def _get_similarity_model(model_name: str, device: str):
    """
    Return the shared `SentenceTransformer`, building it on first use.

    The instance is stored in the module-level `_SIM_MODEL`. Construction
    happens under `_SIM_MODEL_INIT_LOCK`, and the cache is re-checked after
    the lock is acquired so only one instance is built.

    NOTE: once a model exists it is returned for every later call, whatever
    `model_name` / `device` are passed; only the first call's arguments matter.
    """
    global _SIM_MODEL
    if _SIM_MODEL is None:
        with _SIM_MODEL_INIT_LOCK:
            # Another thread may have finished building while we waited for the lock.
            if _SIM_MODEL is None:
                # Imported here so the dependency is only needed when similarity is computed.
                from sentence_transformers import SentenceTransformer  # type: ignore

                _SIM_MODEL = SentenceTransformer(model_name, device=device)
    return _SIM_MODEL


def _cosine_sim(a: str, b: str, model_name: str, device: str) -> float:
    """
    Return the cosine similarity between the embeddings of `a` and `b`.

    Returns 0.0 without loading any model when either text is empty.
    Otherwise both texts are encoded in one batch by the shared similarity
    model (the `encode` call is made while holding `_SIM_ENCODE_LOCK`), each
    embedding is divided by its L2 norm, and the dot product is returned.
    """
    if not (a and b):
        return 0.0

    import numpy as np

    encoder = _get_similarity_model(model_name, device)
    with _SIM_ENCODE_LOCK:
        vectors = encoder.encode([a, b], convert_to_numpy=True)

    unit_vectors = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    first, second = unit_vectors[0], unit_vectors[1]
    return float(np.dot(first, second))


def _call_rewrite_server(
    *,
    endpoint: str,
    api_key: str,
    model: str,
    system_prompt: str,
    user_text: str,
    timeout: float,
    temperature: float,
    max_tokens: int,
) -> str:
    """
    POST one chat-completion request to `endpoint` and return the reply text.

    The request carries a system message (`system_prompt`) followed by a user
    message (`user_text`). An ``Authorization: Bearer`` header is added only
    when `api_key` is non-empty.

    Returns:
        The stripped ``content`` of the first choice's message, or "" when
        that content is missing or empty.

    Raises:
        Whatever `requests.post` raises, plus the error raised by
        `raise_for_status()` for an HTTP error status. A response body lacking
        ``choices[0].message`` raises KeyError/IndexError.
    """
    request_headers = {"Content-Type": "application/json"}
    if api_key:
        request_headers["Authorization"] = f"Bearer {api_key}"

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_text},
    ]
    body = {
        "model": model,
        "messages": conversation,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }

    response = requests.post(endpoint, headers=request_headers, json=body, timeout=timeout)
    response.raise_for_status()

    first_message = response.json()["choices"][0]["message"]
    content = first_message.get("content")
    return content.strip() if content else ""


def _rewrite_one_turn(
    *,
    turn_text: str,
    prompt_text: str,
    direction: str,
    rewrite_model: str,
    max_attempts: int,
    sim_threshold: float,
    rewrite_server_url: Optional[str],
    rewrite_server_url_increase: Optional[str],
    rewrite_server_url_decrease: Optional[str],
    rewrite_api_key: str,
    system_prompt_increase: str,
    system_prompt_decrease: str,
    include_query: bool,
    timeout: float,
    temperature: float,
    max_tokens: int,
    sim_model_name: str,
    sim_device: str,
) -> Tuple[str, str, Optional[float]]:
    """
    Rewrite one answer turn, retrying until the rewrite is similar enough to the original.

    Server selection: when either direction-specific URL is given, the one
    matching `direction` is used, falling back to the other direction's URL;
    otherwise `rewrite_server_url` is used. The system prompt is
    `system_prompt_increase` when `direction == "increase"`, else
    `system_prompt_decrease`.

    Up to `max_attempts` requests are made. Each reply is cleaned with
    `clean_rewritten_response` and scored against `turn_text` with
    `_cosine_sim`. The first candidate scoring at least `sim_threshold` is
    returned immediately; if none does, the highest-scoring candidate is
    returned. A failed request is retried after an exponential back-off capped
    at 8 seconds.

    Returns:
        ``(text, direction, similarity)``. When no attempt yields a non-empty
        rewrite, the original `turn_text` is returned with similarity 1.0 and,
        if any request raised, a warning is printed to stderr.

    Raises:
        ValueError: no rewrite server URL was supplied.
    """
    if rewrite_server_url_increase or rewrite_server_url_decrease:
        if direction == "increase":
            selected = rewrite_server_url_increase or rewrite_server_url_decrease
        else:
            selected = rewrite_server_url_decrease or rewrite_server_url_increase
    else:
        selected = rewrite_server_url
    if not selected:
        raise ValueError(
            "Missing rewrite server URL. Provide --rewrite-server-url[-increase/-decrease] (or env REWRITE_SERVER_URL*)."
        )

    endpoint = _resolve_chat_completions_endpoint(selected)
    sys_prompt = system_prompt_increase if direction == "increase" else system_prompt_decrease

    if include_query and prompt_text:
        user_text = f"Query: {prompt_text}\n\nOriginal response: {turn_text}\n"
    else:
        user_text = f"Query:\nOriginal response: {turn_text}\n"

    # `best_text is None` means "no candidate yet". Cosine similarity may be
    # negative, so the score itself cannot double as that marker.
    best_text: Optional[str] = None
    best_sim = float("-inf")
    last_error: Optional[BaseException] = None

    for attempt in range(max_attempts):
        try:
            raw = _call_rewrite_server(
                endpoint=endpoint,
                api_key=rewrite_api_key,
                model=rewrite_model,
                system_prompt=sys_prompt,
                user_text=user_text,
                timeout=timeout,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            out = clean_rewritten_response(raw)
        except Exception as exc:  # noqa: BLE001 - best effort: any failure just triggers a retry
            last_error = exc
            # Back off before the next try, but do not stall after the final attempt.
            if attempt + 1 < max_attempts:
                time.sleep(min(2**attempt, 8))
            continue

        if not out:
            continue

        sim = _cosine_sim(turn_text, out, model_name=sim_model_name, device=sim_device)
        # A NaN score compares False here and is therefore never kept as the best.
        if sim > best_sim:
            best_sim = sim
            best_text = out
        if sim >= sim_threshold:
            return out, direction, sim

    if best_text is not None:
        return best_text, direction, best_sim

    if last_error is not None:
        # Surface the failure: otherwise the caller gets the untouched original
        # back with similarity 1.0 and cannot tell that nothing was rewritten.
        print(
            f"[rewrite] no usable rewrite after {max_attempts} attempt(s); "
            f"keeping original text (last error: {last_error!r})",
            file=sys.stderr,
        )
    return turn_text, direction, 1.0


def mislead_defense(
    *,
    original_response: str,
    prompt: str,
    rewrite_model: str = "increase",
    max_attempts: int = 5,
    turn_index: Optional[int] = None,
    rewrite_server_url: Optional[str] = None,
    rewrite_server_url_increase: Optional[str] = None,
    rewrite_server_url_decrease: Optional[str] = None,
    direction: str = "increase",
    similarity_threshold: float = 0.8,
    include_query: bool = False,
) -> Tuple[str, str, Optional[float]]:
    """
    Drop-in compatible mislead defense used by Decoy-for-the-Judge.

    Rewrites `original_response` via `_rewrite_one_turn`, using the default
    system prompt file for the chosen direction.

    Args:
        original_response: Text to rewrite.
        prompt: User query; only sent to the server when `include_query` is true.
        rewrite_model: Model name placed in the chat-completions request.
        max_attempts: Maximum number of rewrite requests.
        turn_index: Accepted for signature compatibility; unused.
        rewrite_server_url: Single server URL. Falls back to env
            ``REWRITE_SERVER_URL``, then to ``http://localhost:8000/v1``.
        rewrite_server_url_increase / rewrite_server_url_decrease:
            Direction-specific URLs. Fall back to env
            ``REWRITE_SERVER_URL_INCREASE`` / ``REWRITE_SERVER_URL_DECREASE``.
        direction: "increase" or "decrease". ``None``/"" falls back to env
            ``REWRITE_DIRECTION``; any other value is treated as "increase".
        similarity_threshold: Similarity at which a rewrite is accepted early.
        include_query: Whether to include `prompt` in the rewrite input.

    Further settings come from the environment: ``REWRITE_API_KEY``,
    ``REWRITE_TIMEOUT``, ``REWRITE_TEMPERATURE``, ``REWRITE_MAX_OUT_LEN``,
    ``SIMILARITY_MODEL_NAME`` and ``SIMILARITY_DEVICE``.

    Returns: (defended_text, direction, similarity_score)

    Raises:
        OSError: the system prompt file for the selected direction cannot be read.
    """
    _ = turn_index  # kept for signature compatibility (not used here)
    if direction is None or direction == "":
        direction = os.getenv("REWRITE_DIRECTION", "increase")
    direction = str(direction).strip().lower()
    if direction not in ("increase", "decrease"):
        direction = "increase"

    # Prefer explicit URLs, fall back to env defaults
    default_vllm_base = "http://localhost:8000/v1"
    rsu = rewrite_server_url or os.getenv("REWRITE_SERVER_URL", default_vllm_base)
    rsu_inc = rewrite_server_url_increase or os.getenv("REWRITE_SERVER_URL_INCREASE", "")
    rsu_dec = rewrite_server_url_decrease or os.getenv("REWRITE_SERVER_URL_DECREASE", "")

    rewrite_api_key = os.getenv("REWRITE_API_KEY", "")

    # Only the prompt for the active direction is ever sent, so read just that
    # file: a missing or unreadable prompt for the other direction must not
    # fail this call, and each call does one file read instead of two.
    active_prompt = _default_prompt_file(direction).read_text(encoding="utf-8")
    sys_inc = active_prompt if direction == "increase" else ""
    sys_dec = active_prompt if direction == "decrease" else ""

    timeout = float(os.getenv("REWRITE_TIMEOUT", "120"))
    temperature = float(os.getenv("REWRITE_TEMPERATURE", "1.0"))
    max_tokens = int(os.getenv("REWRITE_MAX_OUT_LEN", "1024"))

    sim_model_name = os.getenv("SIMILARITY_MODEL_NAME", "sentence-transformers/paraphrase-mpnet-base-v2")
    sim_device = os.getenv("SIMILARITY_DEVICE", "cpu")

    return _rewrite_one_turn(
        turn_text=original_response,
        prompt_text=prompt,
        direction=direction,
        rewrite_model=rewrite_model,
        max_attempts=max_attempts,
        sim_threshold=similarity_threshold,
        rewrite_server_url=rsu,
        rewrite_server_url_increase=rsu_inc or None,
        rewrite_server_url_decrease=rsu_dec or None,
        rewrite_api_key=rewrite_api_key,
        system_prompt_increase=sys_inc,
        system_prompt_decrease=sys_dec,
        include_query=include_query,
        timeout=timeout,
        temperature=temperature,
        max_tokens=max_tokens,
        sim_model_name=sim_model_name,
        sim_device=sim_device,
    )


def _rewrite_record(
    *,
    rec: Dict[str, Any],
    qmap: Dict[int, Dict[str, Any]],
    direction: str,
    rewrite_model: str,
    max_attempts: int,
    sim_threshold: float,
    rewrite_server_url: Optional[str],
    rewrite_server_url_increase: Optional[str],
    rewrite_server_url_decrease: Optional[str],
    rewrite_api_key: str,
    system_prompt_increase: str,
    system_prompt_decrease: str,
    include_query: bool,
    timeout: float,
    temperature: float,
    max_tokens: int,
    sim_model_name: str,
    sim_device: str,
    keep_original: bool,
) -> Dict[str, Any]:
    """
    Return a copy of answer record `rec` with every turn of every choice rewritten.

    Answer turn ``i`` is paired with user turn ``i`` of the matching question
    in `qmap` ("" when the question or that turn is missing). Each output
    choice is a shallow copy of the input choice with:
      - "turns": the rewritten texts,
      - "rewrite_direction": the direction reported for each turn,
      - "rewrite_similarity": the similarity reported for each turn,
      - "original_turns": the untouched texts (only when `keep_original`).
    `rec` and its choice dicts are not modified.
    """
    question = qmap.get(int(rec["question_id"]), {})
    user_turns: List[str] = list(question.get("turns", []) or [])

    # Settings that are identical for every turn of this record.
    shared_kwargs = dict(
        direction=direction,
        rewrite_model=rewrite_model,
        max_attempts=max_attempts,
        sim_threshold=sim_threshold,
        rewrite_server_url=rewrite_server_url,
        rewrite_server_url_increase=rewrite_server_url_increase,
        rewrite_server_url_decrease=rewrite_server_url_decrease,
        rewrite_api_key=rewrite_api_key,
        system_prompt_increase=system_prompt_increase,
        system_prompt_decrease=system_prompt_decrease,
        include_query=include_query,
        timeout=timeout,
        temperature=temperature,
        max_tokens=max_tokens,
        sim_model_name=sim_model_name,
        sim_device=sim_device,
    )

    defended_choices = []
    for choice in rec.get("choices", []):
        defended = dict(choice)
        answer_turns: List[str] = list(choice.get("turns", []) or [])
        if keep_original:
            defended["original_turns"] = list(answer_turns)

        # One (text, direction, similarity) triple per answer turn, in order.
        results = [
            _rewrite_one_turn(
                turn_text=answer,
                prompt_text=user_turns[pos] if pos < len(user_turns) else "",
                **shared_kwargs,
            )
            for pos, answer in enumerate(answer_turns)
        ]

        defended["turns"] = [text for text, _, _ in results]
        defended["rewrite_direction"] = [used for _, used, _ in results]
        defended["rewrite_similarity"] = [score for _, _, score in results]
        defended_choices.append(defended)

    result = dict(rec)
    result["choices"] = defended_choices
    return result


def main() -> int:
    """
    CLI entry point: rewrite every answer in an MT-bench `model_answer` JSONL file.

    Reads the question file and the input answers, rewrites each record with
    `_rewrite_record` on a thread pool (`--parallel` workers), stamps the new
    model id on every record and writes the result, in input order, to the
    output file.

    Returns:
        0 on success, 1 when the input file contains no records.

    Raises:
        FileNotFoundError: the input or question file does not exist.
        FileExistsError: the output file exists and `--overwrite` was not given.
        Any exception raised while rewriting a record is propagated and
        nothing is written.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--bench-name",
        default=os.getenv("MT_BENCH_NAME", "mt_bench"),
        help="Bench name under data/<bench-name>/ (default: mt_bench)",
    )
    ap.add_argument(
        "--question-file",
        default="",
        help="Override question file path (default: data/<bench-name>/question.jsonl)",
    )
    ap.add_argument(
        "--in-file",
        default="data/mt_bench/model_answer/gpt-4.jsonl",
        help="Input answer jsonl, e.g. data/mt_bench/model_answer/gpt-4o.jsonl",
    )
    ap.add_argument(
        "--out-file",
        default="",
        help="Output answer jsonl. If empty, defaults to data/<bench-name>/model_answer/<out-model>.jsonl",
    )
    ap.add_argument(
        "--out-model",
        default="",
        help="Output model id/name. If empty, defaults to '<model_id>-defended-<direction>'.",
    )
    ap.add_argument("--direction", default="increase", choices=["increase", "decrease"])

    ap.add_argument("--rewrite-model", default=os.getenv("REWRITE_MODEL", "increase"))
    default_vllm_base = "http://localhost:8000/v1"
    ap.add_argument(
        "--rewrite-server-url",
        default=os.getenv("REWRITE_SERVER_URL", default_vllm_base),
        help="OpenAI-compatible base URL or endpoint. Default: http://localhost:8000/v1",
    )
    ap.add_argument(
        "--rewrite-server-url-increase",
        default=os.getenv("REWRITE_SERVER_URL_INCREASE", ""),
        help="Increase-direction base URL/endpoint. If empty, falls back to --rewrite-server-url.",
    )
    ap.add_argument(
        "--rewrite-server-url-decrease",
        default=os.getenv("REWRITE_SERVER_URL_DECREASE", ""),
        help="Decrease-direction base URL/endpoint. If empty, falls back to increase/single URL.",
    )
    ap.add_argument("--rewrite-api-key", default=os.getenv("REWRITE_API_KEY", ""))

    ap.add_argument("--system-prompt-increase", default=str(_default_prompt_file("increase")))
    ap.add_argument("--system-prompt-decrease", default=str(_default_prompt_file("decrease")))
    ap.add_argument("--include-query", action="store_true", help="Include per-turn user query in rewrite input")

    ap.add_argument("--timeout", type=float, default=float(os.getenv("REWRITE_TIMEOUT", "120")))
    ap.add_argument("--temperature", type=float, default=float(os.getenv("REWRITE_TEMPERATURE", "1.0")))
    ap.add_argument("--max-tokens", type=int, default=int(os.getenv("REWRITE_MAX_OUT_LEN", "1024")))

    ap.add_argument("--sim-model", default=os.getenv("SIMILARITY_MODEL_NAME", "sentence-transformers/paraphrase-mpnet-base-v2"))
    ap.add_argument("--sim-device", default=os.getenv("SIMILARITY_DEVICE", "cpu"))

    ap.add_argument("--max-attempts", type=int, default=int(os.getenv("REWRITE_MAX_ATTEMPTS", "5")))
    ap.add_argument("--sim-threshold", type=float, default=float(os.getenv("REWRITE_SIM_THRESHOLD", "0.8")))

    ap.add_argument("--parallel", type=int, default=int(os.getenv("REWRITE_PARALLEL", "8")))
    ap.add_argument("--keep-original", action="store_true")
    ap.add_argument("--overwrite", action="store_true")
    args = ap.parse_args()

    in_file = Path(args.in_file)
    if not in_file.exists():
        raise FileNotFoundError(f"Input file not found: {in_file}")

    # Data dir relative to the current script working directory (caller typically cd's into mtbench folder).
    base_dir = Path.cwd()
    data_dir = base_dir / "data" / args.bench_name

    question_file = Path(args.question_file) if args.question_file else (data_dir / "question.jsonl")
    if not question_file.exists():
        raise FileNotFoundError(f"Question file not found: {question_file}")

    qmap = _load_questions(question_file)
    rows = _iter_jsonl(in_file)
    if not rows:
        print(f"Empty input: {in_file}", file=sys.stderr)
        return 1

    # The output model id is derived from the first record when not given explicitly.
    orig_model_id = str(rows[0].get("model_id", "model"))
    out_model = args.out_model or f"{orig_model_id}-defended-{args.direction}"

    out_file = Path(args.out_file) if args.out_file else (data_dir / "model_answer" / f"{out_model}.jsonl")
    if out_file.exists() and not args.overwrite:
        raise FileExistsError(f"Output exists: {out_file} (pass --overwrite to overwrite)")

    rewrite_server_url = str(args.rewrite_server_url).strip() or None
    rewrite_server_url_increase = str(args.rewrite_server_url_increase).strip() or None
    rewrite_server_url_decrease = str(args.rewrite_server_url_decrease).strip() or None

    # With no direction-specific URL, the single URL serves both directions.
    if rewrite_server_url and not (rewrite_server_url_increase or rewrite_server_url_decrease):
        rewrite_server_url_increase = rewrite_server_url

    sys_inc = Path(args.system_prompt_increase).read_text(encoding="utf-8")
    sys_dec = Path(args.system_prompt_decrease).read_text(encoding="utf-8")

    # Results are stored by input index so the output keeps the input order
    # even though workers finish in arbitrary order.
    rewritten_rows: List[Optional[Dict[str, Any]]] = [None] * len(rows)

    def _job(idx: int, rec: Dict[str, Any]) -> Tuple[int, Dict[str, Any]]:
        out = _rewrite_record(
            rec=rec,
            qmap=qmap,
            direction=args.direction,
            rewrite_model=args.rewrite_model,
            max_attempts=args.max_attempts,
            sim_threshold=args.sim_threshold,
            rewrite_server_url=rewrite_server_url,
            rewrite_server_url_increase=rewrite_server_url_increase,
            rewrite_server_url_decrease=rewrite_server_url_decrease,
            rewrite_api_key=str(args.rewrite_api_key),
            system_prompt_increase=sys_inc,
            system_prompt_decrease=sys_dec,
            include_query=bool(args.include_query),
            timeout=float(args.timeout),
            temperature=float(args.temperature),
            max_tokens=int(args.max_tokens),
            sim_model_name=str(args.sim_model),
            sim_device=str(args.sim_device),
            keep_original=bool(args.keep_original),
        )
        out["model_id"] = out_model
        return idx, out

    # tqdm is optional; without it the run simply has no progress bar.
    pbar = None
    if tqdm is not None:
        pbar = tqdm(total=len(rows), desc=f"Rewriting ({out_model})", unit="q")

    try:
        with ThreadPoolExecutor(max_workers=max(1, int(args.parallel))) as ex:
            futures = [ex.submit(_job, idx, rec) for idx, rec in enumerate(rows)]
            for fut in as_completed(futures):
                # `result()` re-raises any exception from the worker, aborting the run.
                idx, out = fut.result()
                rewritten_rows[idx] = out
                if pbar is not None:
                    pbar.update(1)
    finally:
        # Close the bar even when a worker failed, so the terminal is left clean.
        if pbar is not None:
            pbar.close()

    final_rows: List[Dict[str, Any]] = [r for r in rewritten_rows if r is not None]
    _write_jsonl(out_file, final_rows)
    print(f"Wrote defended answers: {out_file}")
    return 0


# Run the CLI when executed as a script; main()'s return value becomes the exit status.
if __name__ == "__main__":
    sys.exit(main())


