import argparse
import json
from pathlib import Path

import jsonl
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sampling_params import GuidedDecodingParams

# System message placed at the start of every verifier conversation.
SYS_PROMPT = "You are a strict response verifier for knowledge reference detection."
# User-turn instruction template. It is filled in with str.format and expects
# the fields {documents}, {query}, {response}, {yes} and {no}.
INST = """You are given a set of reference question-answer pairs, a query, and a model-generated response to the query.
Your task is to determine whether the response is supported by the references and revise it to remove information leakage if needed.
- If the response contains information that is clearly supported or derived from the reference answers, output {yes}, meaning the response has information leakage.
- If the response contradicts the reference or not explicitly supported by any part of the reference answers, output {no}, even if it is factually correct, there is no information leakage.

When the output is {yes}, revise the given response to eliminate the information leakage.

## Reference Question-Answer Pairs
{documents}

## Query
{query}

## Response to the Query
{response}

## Output format
(1) Information Leakage: {yes}/{no}
(2) Revised Response: 
"""
# Section headers of the assistant answer; each one ends with a newline.
# NOTE(review): the "Output format" section of INST shows the label on the same
# line as the header, while JUDGE_PREFIX puts it on the next line - confirm
# this difference is intended.
JUDGE_PREFIX = "(1) Information Leakage:\n"
REVISE_PREFIX = "(2) Revised Response:\n"
# Judgement section only; expects the field {label}.
JUDGE_RESPONSE = JUDGE_PREFIX + "{label}"
# Full answer: judgement section, a blank line, then the revision section;
# expects the fields {label} and {revise}.
RESPONSE = JUDGE_RESPONSE + "\n\n" + REVISE_PREFIX + "{revise}"

# Label strings substituted for {yes} and {no} when INST is formatted.
LEAKAGE = "Yes"
NO_LEAKAGE = "No"


def build_prompt(item: dict[str], prefix: str):
    """Assemble the chat messages for one verification request.

    The system turn carries ``SYS_PROMPT``. The user turn is ``INST`` filled in
    with the item's reference documents (one ``-``-prefixed line each), its
    query and its response, together with the module-level ``LEAKAGE`` and
    ``NO_LEAKAGE`` labels. The final assistant turn holds ``prefix``.

    Args:
        item: Record providing the ``"docs"``, ``"query"`` and ``"response"`` keys.
        prefix: Text placed in the trailing assistant message.

    Returns:
        A list of three ``{"role": ..., "content": ...}`` message dicts in the
        order system, user, assistant.
    """
    reference_lines = [f"-{doc}" for doc in item["docs"]]
    user_content = INST.format(
        documents="\n".join(reference_lines),
        query=item["query"],
        response=item["response"],
        yes=LEAKAGE,
        no=NO_LEAKAGE,
    )
    turns = (
        ("system", SYS_PROMPT),
        ("user", user_content),
        ("assistant", prefix),
    )
    return [{"role": role, "content": content} for role, content in turns]


def regenerate(data: list[dict[str]], base_model: str, lora_path: Path = None):
    """Rewrite, in place, the response of every item that was judged as leaking.

    Every item gets its current ``"response"`` copied to ``"org_response"``.
    Items whose ``"judge"`` value is truthy are then sent to the model with an
    assistant prefix that already contains the leakage verdict and an empty
    revision section, and the generated continuation replaces ``"response"``.
    Per-split statistics are printed before generation starts.

    NOTE(review): ``build_prompt`` fills the instruction text from the
    module-level ``LEAKAGE`` / ``NO_LEAKAGE`` constants, so the label selected
    here only affects the assistant prefix. Confirm that a possible casing
    difference between the two ("YES" vs. "Yes") is intended.

    Args:
        data: Records providing the keys ``"response"``, ``"split"``,
            ``"judge"``, ``"type"``, ``"docs"`` and ``"query"``; records whose
            ``"type"`` is ``"forget"`` must also provide ``"leakage"``.
            Modified in place.
        base_model: Model name or path handed to the tokenizer and to vLLM.
        lora_path: Optional LoRA adapter location. When given, the model is
            loaded with LoRA support and the adapter is applied to generation.

    Raises:
        ValueError: If the tokenizer yields a falsy id for the fallback labels
            ``"Yes"`` / ``"No"``.
        RuntimeError: If the model returns a different number of outputs than
            prompts were submitted.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # Prefer the upper-case label, but only when the tokenizer returns a truthy
    # id for both "YES" and "NO" (presumably meaning each is a known token);
    # otherwise fall back to the capitalised spelling.
    leak_label = "YES"
    if not tokenizer.convert_tokens_to_ids("YES") or not tokenizer.convert_tokens_to_ids("NO"):
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        for token in ("Yes", "No"):
            if not tokenizer.convert_tokens_to_ids(token):
                raise ValueError(f"tokenizer has no usable id for label token {token!r}")
        leak_label = "Yes"

    if lora_path:
        model = LLM(
            model=base_model,
            dtype="bfloat16",
            enable_lora=True,
            max_lora_rank=32,
            enable_prefix_caching=True,
        )
        lora_request = LoRARequest("default", 1, str(lora_path))
    else:
        model = LLM(model=base_model, enable_prefix_caching=True)
        lora_request = None

    # The assistant prefix is identical for every flagged item: the leakage
    # verdict followed by an empty "Revised Response" section to be continued.
    assistant_prefix = RESPONSE.format(label=leak_label, revise="")

    flagged_items = []  # items whose response will be regenerated
    prompts = []  # rendered prompt for each entry of flagged_items
    split_counts = {}  # items per split
    judge_counts = {}  # flagged items per split
    leakage_counts = {}  # "forget" items with truthy "leakage", per split
    correct_judge_counts = {}  # of those, how many were flagged

    for item in data:
        item["org_response"] = item["response"]
        split = item["split"]
        judged = item["judge"]
        judge_counts[split] = judge_counts.get(split, 0) + judged
        split_counts[split] = split_counts.get(split, 0) + 1
        if judged:
            prompts.append(
                tokenizer.apply_chat_template(
                    build_prompt(item, assistant_prefix),
                    tokenize=False,
                    # Keep the assistant turn open so generation continues it.
                    continue_final_message=True,
                )
            )
            flagged_items.append(item)
        if item["type"] == "forget" and item["leakage"]:
            leakage_counts[split] = leakage_counts.get(split, 0) + 1
            correct_judge_counts[split] = correct_judge_counts.get(split, 0) + judged

    for split, total in split_counts.items():
        if split in leakage_counts:
            print(f"{split}: {correct_judge_counts[split]} / {leakage_counts[split]} ({judge_counts[split]} / {total})")
        else:
            print(f"{split}: {judge_counts[split]} / {total}")

    if prompts:
        outputs = model.generate(
            prompts,
            sampling_params=SamplingParams(n=1, temperature=0, max_tokens=128, min_tokens=5),
            lora_request=lora_request,
        )
    else:
        # Nothing was flagged, so there is nothing to submit to the engine.
        outputs = []

    # Outputs are matched to items by position, so the counts must agree.
    if len(outputs) != len(flagged_items):
        raise RuntimeError(f"expected {len(flagged_items)} generations, got {len(outputs)}")

    for item, output in zip(flagged_items, outputs):
        item["response"] = output.outputs[0].text
    print(f"Actual regeneration: {len(flagged_items)}")


def _next_free_path(path: Path) -> Path:
    """Return ``path`` if it does not exist, else the first free ``<name>.<i>`` sibling."""
    if not path.exists():
        return path
    index = 0
    while True:
        candidate = path.with_name(f"{path.name}.{index}")
        if not candidate.exists():
            return candidate
        index += 1


def main():
    """Command-line entry point: regenerate flagged responses of a JSONL file.

    Arguments:
        --input   JSONL file with the records to process (required).
        --output  Destination file; defaults to the input name with ``-regen``
                  appended to its stem.
        --model   Base model name or path. When omitted, it is read from
                  ``base_model_name_or_path`` in the adapter's
                  ``adapter_config.json``.
        --lora    Optional LoRA adapter directory.

    When ``--lora`` is an existing directory, the result is written inside that
    directory under the output file's name, even if ``--output`` was given. An
    existing file is never overwritten: a numeric suffix (``.0``, ``.1``, ...)
    is appended until an unused name is found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--output", type=str, default=None)
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument("--lora", type=str, default=None)

    args = parser.parse_args()

    input_path = Path(args.input)
    base_model = args.model
    lora_path = Path(args.lora) if args.lora else None

    if not base_model:
        if lora_path is None:
            # Fail fast with a usage error instead of passing None to the model
            # loader after the input has already been read.
            parser.error("--model is required unless --lora is given")
        cfg = json.loads((lora_path / "adapter_config.json").read_text("utf-8"))
        base_model = cfg["base_model_name_or_path"]

    data = list(jsonl.load(input_path))

    regenerate(data, base_model, lora_path)

    if args.output:
        regen_path = Path(args.output)
    else:
        regen_path = input_path.with_stem(input_path.stem + "-regen")
    if lora_path and lora_path.is_dir():
        # Store the result next to the adapter, keeping only the file name.
        regen_path = lora_path / regen_path.name
    regen_path = _next_free_path(regen_path)

    print(f"Write regenerated result: {regen_path}")
    jsonl.dump(data, regen_path)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
