import json
import math
import os
from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from peft import LoraConfig
from sklearn.metrics import (
    average_precision_score,
    confusion_matrix,
    precision_recall_curve,
    precision_recall_fscore_support,
)
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from train_sft3 import NO, YES, Collator, Trainer, TrainingConfig, format_eval_inst


@dataclass
class JudgeConfig(TrainingConfig):
    """Command-line settings for the judge evaluation script, extending ``TrainingConfig``."""

    # Checkpoint directory. main() restores the trainer from it, reads
    # ``adapter_config.json`` from it when ``model_name_or_path`` is unset, and
    # writes the precision-recall image and the judged JSONL file into it.
    ckpt_path: str = None
    # Optional manual decision threshold on the judge probability; when
    # provided, main() applies it in place of the threshold derived from data.
    threshold: float = None
    # JSON-lines file handed to ``load_dataset("json", ...)`` as the eval set.
    eval_data: str = "data/eval/eval-tofu.jsonl"
    # NOTE(review): presumably overrides the reporting default inherited from
    # TrainingConfig so that no experiment tracker is used — confirm.
    report_to: str = "none"


class JudgeCollator(Collator):
    """Collator that forwards each example's ``query_id`` alongside the base batch."""

    def __call__(self, examples: list[dict[str]], return_tensors=None):
        """Collate ``examples`` with the parent collator and attach their ids.

        The ids are stored under ``"query_id"`` as a plain list, in the same
        order as ``examples``.
        """
        batch = super().__call__(examples, return_tensors)

        query_ids = []
        for example in examples:
            query_ids.append(example["query_id"])
        batch["query_id"] = query_ids

        return batch


def get_between(a: float, b: float) -> float:
    """Return a "round" number inside the closed interval spanned by ``a`` and ``b``.

    The grid spacing is the largest power of ten that does not exceed the
    interval width, and the result is the grid point closest to the middle
    (rounding down on ties). The endpoints may be given in either order.

    Note that the interval is closed: the result can coincide with either
    endpoint (e.g. ``get_between(0.1, 0.2)`` is ``0.1``).

    Args:
        a: One end of the interval.
        b: The other end of the interval.

    Returns:
        A value ``v`` with ``min(a, b) <= v <= max(a, b)``, as a Python float.
    """
    lo, hi = (a, b) if a < b else (b, a)
    width = hi - lo
    if width == 0:
        # log10(0) is undefined; with equal endpoints the endpoint itself is
        # the only value "between" them.
        return lo

    # Number of decimal places of the grid (negative for widths >= 10).
    d = -math.floor(math.log10(width))
    step = 10**-d

    lo_step = math.ceil(lo / step)
    hi_step = math.floor(hi / step)
    if lo_step <= hi_step:
        mid_step = (lo_step + hi_step) // 2
        result = round(mid_step * step, d)
    else:
        # Only reachable through floating-point error in the divisions above:
        # fall back to the rounded midpoint.
        result = round((lo + hi) / 2, d)
    # Honour the declared return type even for int or numpy-scalar inputs.
    return float(result)


class JudgeTrainer(Trainer):
    """Trainer variant that records a yes/no judge probability for every example.

    Results accumulate in ``self._judge_results`` as
    ``{query_id: {"prob": float, "label": ...}}``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # query_id -> {"prob": sigmoid(yes - no logit), "label": is_leakage}
        self._judge_results = {}

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Score the batch with the judge and record the results.

        No real loss is computed: the forward pass runs under ``no_grad`` and
        the returned loss is a constant zero. For each example the probability
        is ``sigmoid(logit[yes] - logit[no])`` at the position selected by
        ``inputs["judge_indices"]``.

        Args:
            model: Causal LM whose output exposes ``.logits``.
            inputs: Batch with ``input_ids``, ``attention_mask``,
                ``judge_indices``, ``is_leakage`` and ``query_id``.
            return_outputs: When true, return ``(loss, logits)``.
            num_items_in_batch: Accepted for interface compatibility; unused.

        Returns:
            The dummy zero loss, or ``(loss, logits)`` if ``return_outputs``.
        """

        with torch.no_grad():
            logits = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            ).logits

            # Column 0 indexes the first logits dimension and column 1 the
            # second (presumably batch row and token position — confirm).
            target_idx = inputs["judge_indices"]
            true_labels = inputs["is_leakage"]
            target_logits = logits[target_idx[:, 0], target_idx[:, 1]]
            # NOTE(review): _yes_token_id / _no_token_id are not set in this
            # class; presumably the base Trainer provides them — confirm.
            yes_scores = target_logits[:, self._yes_token_id]
            no_scores = target_logits[:, self._no_token_id]

            judge_logits = yes_scores - no_scores
            probs = judge_logits.sigmoid()

        probs = self.accelerator.gather_for_metrics(probs).tolist()
        labels = self.accelerator.gather_for_metrics(true_labels).tolist()
        # Gather the ids the same way as the scores. Pairing the gathered
        # (all-process) scores with only this process's ids would attach other
        # ranks' probabilities to local query ids in multi-process runs.
        query_ids = self.accelerator.gather_for_metrics(inputs["query_id"])

        for key, prob, label in zip(query_ids, probs, labels):
            self._judge_results[key] = {
                "prob": prob,
                "label": label,
            }

        # Placeholder loss so the surrounding loop receives a tensor.
        loss = torch.tensor(0.0, device=model.device, requires_grad=True)
        return (loss, logits) if return_outputs else loss


def write_jsonl(path: str, data):
    """Write ``data`` to ``path`` as UTF-8 JSON Lines (one JSON document per line).

    Args:
        path: Destination file; overwritten if it exists.
        data: Iterable of JSON-serializable items. Any iterable works,
            including generators.
    """
    count = 0
    # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
    # must be pinned to UTF-8 instead of depending on the platform default.
    with open(path, "w", encoding="utf-8") as fp:
        for item in data:
            fp.write(json.dumps(item, ensure_ascii=False) + "\n")
            count += 1
    print(f"Write {count} data to: {path}")


def write_curve(labels, probs, img_path):
    """Plot the precision-recall curve of ``probs`` against ``labels`` and save it.

    Args:
        labels: Ground-truth binary labels.
        probs: Predicted scores for the positive class, aligned with ``labels``.
        img_path: Destination image path (saved at 300 dpi).
    """
    precisions, recalls, _ = precision_recall_curve(labels, probs)
    ap_score = average_precision_score(labels, probs)

    fig, ax = plt.subplots()
    try:
        ax.plot(recalls, precisions, label=f"PR curve (AP={ap_score:.4f})")
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.set_title("Precision-Recall Curve")
        ax.legend(loc="lower left")
        fig.tight_layout()
        fig.savefig(img_path, dpi=300)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def main(config: JudgeConfig):
    """Run the judge over the evaluation set, pick a threshold, and report results.

    Steps:
      1. Load tokenizer, base model and evaluation data, then restore the
         checkpoint at ``config.ckpt_path`` into a ``JudgeTrainer``.
      2. Run ``trainer.evaluate()`` to collect per-query judge probabilities.
      3. Save a precision-recall curve inside the checkpoint directory.
      4. Choose the decision threshold: ``config.threshold`` when given,
         otherwise a round value separating the highest "retain" probability
         from the lowest positive-labelled probability above it.
      5. Print overall and per-split metrics and write the judged examples to
         ``<ckpt_path>/judge-<eval file name>``.

    Raises:
        ValueError: If ``config.ckpt_path`` is missing, if no threshold was
            given and none can be derived from the data, or if a judge label
            disagrees with the dataset's ``leakage`` field.
    """
    print(config)
    if config.ckpt_path is None:
        # Fail early with a clear message; every later step needs this path.
        raise ValueError("ckpt_path is required: it locates the checkpoint and receives the outputs")

    if config.model_name_or_path is None:
        # Fall back to the base model recorded in the adapter's own config.
        adapter_cfg = json.loads(Path(config.ckpt_path, "adapter_config.json").read_text("utf-8"))
        config.model_name_or_path = adapter_cfg["base_model_name_or_path"]

    tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path, use_fast=True)

    lora_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # dataset
    eval_dataset = load_dataset("json", data_files=config.eval_data, split="train")
    if "query_id" not in eval_dataset.column_names:
        # Use the row index as a stable id when the data does not provide one.
        eval_dataset = eval_dataset.map(lambda x, idx: {"query_id": idx}, with_indices=True)

    # trainer
    model = AutoModelForCausalLM.from_pretrained(config.model_name_or_path, torch_dtype="bfloat16")
    if hasattr(model, "language_model"):
        # Keep only the text sub-model when the loaded model wraps one.
        model = getattr(model, "language_model")
    print(model)
    formatted_ds = eval_dataset.map(lambda x: format_eval_inst(x, yes=YES, no=NO) | {"query_id": x["query_id"]})
    trainer = JudgeTrainer(
        model=model,
        args=config,
        train_dataset=formatted_ds,
        eval_dataset=formatted_ds,
        processing_class=tokenizer,
        peft_config=lora_config,
        data_collator=JudgeCollator(tokenizer),
    )

    trainer._load_from_checkpoint(config.ckpt_path)
    print(trainer.model)
    trainer.evaluate()

    judge_results = trainer._judge_results
    print(f"Finished: {len(judge_results)}")

    probs = np.array([j["prob"] for j in judge_results.values()])
    labels = np.array([j["label"] for j in judge_results.values()])

    write_curve(labels, probs, img_path=os.path.join(config.ckpt_path, "judge-precision_recall_curve.png"))

    # Data-driven threshold: place it between the highest probability given to
    # any "retain" item and the lowest positive-labelled probability above it.
    # (An alternative that is not used here is Youden's J = TPR - FPR on the
    # ROC curve.)
    best_threshold = None
    retain_probs = [judge_results[item["query_id"]]["prob"] for item in eval_dataset if item["type"] == "retain"]
    if retain_probs:
        max_retain_prob = max(retain_probs)
        leak_probs_above = [p for p, gt in zip(probs, labels) if gt == 1 and p > max_retain_prob]
        if leak_probs_above:
            next_threshold = min(leak_probs_above)
            candidate = float(get_between(max_retain_prob, next_threshold))
            if not max_retain_prob < candidate <= next_threshold:
                # Predictions use ``prob >= threshold``. A threshold equal to
                # the highest retain probability would flag that retain item,
                # so fall back to the upper bound, which still separates.
                candidate = float(next_threshold)
            best_threshold = candidate
            print(f"{best_threshold=} --> {max_retain_prob} ~ {next_threshold}, ")

    if config.threshold is not None:
        # ``is not None`` so that an explicit threshold of 0.0 is honoured.
        best_threshold = config.threshold
        print(f"Set theshold: {best_threshold}")
    elif best_threshold is None:
        raise ValueError(
            "Cannot derive a threshold: need at least one 'retain' item and one positive item "
            "scored above every retain item; pass an explicit threshold instead"
        )

    pred_at_best = (probs >= best_threshold).astype(int)
    # Pin the label set so the matrix is 2x2 even if only one class occurs.
    tn, fp, fn, tp = confusion_matrix(labels, pred_at_best, labels=[0, 1]).ravel()
    precision, recall, f1, _ = precision_recall_fscore_support(labels, pred_at_best, average="binary", zero_division=0)
    accuracy = (tp + tn) / (tp + tn + fp + fn + 1e-12)

    print(f"[Confusion] TP={tp}, FP={fp}, TN={tn}, FN={fn}")
    print(f"[Metrics@best] Acc={accuracy:.6f}, Precision={precision:.6f}, Recall={recall:.6f}, F1={f1:.6f}")

    count_judge = 0
    count_label = 0
    outputs = []
    split_correct = {}  # split -> correctness flag of every item in it
    split_recall = {}  # split -> correctness flag of its leaked "forget" items
    for item in eval_dataset:
        result = judge_results[item["query_id"]]
        label = result["label"]
        if label != item["leakage"]:
            raise ValueError(
                f"Label mismatch for query_id={item['query_id']!r}: judge has {label!r}, "
                f"dataset has {item['leakage']!r}"
            )

        output_item = dict(item)
        if "retrieval" in output_item:
            # Flatten the retrieval payload to the document texts only.
            ret = output_item.pop("retrieval")
            output_item["docs"] = [d["content"] for d in ret["documents"]]

        # Plain bool keeps the value JSON-serializable whatever the threshold type.
        pred = bool(result["prob"] >= best_threshold)
        is_correct = pred == label
        count_label += label
        count_judge += pred
        output_item["judge"] = pred
        outputs.append(output_item)

        split = item["split"]
        split_correct.setdefault(split, []).append(is_correct)
        if item["type"] == "forget":
            # Create the bucket on the first "forget" item of the split, not
            # only when the split's very first item happens to be "forget".
            leaked_hits = split_recall.setdefault(split, [])
            if item["leakage"]:
                leaked_hits.append(is_correct)
    print(count_judge, count_label)
    print("Accuracy")
    for key, value in split_correct.items():
        print(f"{key}: {sum(value)} / {len(value)} ({sum(value) / max(len(value), 1):.5f})")
    print("Recall")
    for key, value in split_recall.items():
        print(f"{key}: {sum(value)} / {len(value)} ({sum(value) / max(len(value), 1):.5f})")
    output_path = os.path.join(config.ckpt_path, "judge-" + os.path.basename(config.eval_data))
    write_jsonl(output_path, outputs)


if __name__ == "__main__":
    # Parse the command line into a JudgeConfig and hand it to main().
    arg_parser = HfArgumentParser(JudgeConfig)
    parsed_configs = arg_parser.parse_args_into_dataclasses()
    main(*parsed_configs)
