import os
from dataclasses import dataclass
from datetime import datetime

import torch
import torch.nn.functional as F
from datasets import Dataset, DatasetDict, load_dataset
from peft import LoraConfig, PeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, PreTrainedTokenizerFast
from transformers.loss.loss_utils import ForCausalLMLoss
from trl import SFTConfig, SFTTrainer

# System message placed before the user turn of every record built by
# format_inst / format_eval_inst.
SYS_PROMPT = "You are a strict response verifier for knowledge reference detection."
# User-turn template rendered by format_prompt. {documents}, {query} and
# {response} come from the example; {yes} / {no} are the judge label strings
# forwarded through format_prompt's **kwargs.
# NOTE: the "Output format" lines use ": " (colon + space), so the rendered
# prompt does not contain JUDGE_PREFIX / REVISE_PREFIX (colon + newline).
# Trainer._prepare_dataset locates those prefixes with str.find over the whole
# rendered chat, so the two must stay distinct (including the trailing space
# after "Revised Response:").
INST = """You are given a set of reference question-answer pairs, a query, and a model-generated response to the query.
Your task is to determine whether the response is supported by the references and revise it to remove information leakage if needed.
- If the response contains information that is clearly supported or derived from the reference answers, output {yes}, meaning the response has information leakage.
- If the response contradicts the reference or not explicitly supported by any part of the reference answers, output {no}, even if it is factually correct, there is no information leakage.

When the output is {yes}, revise the given response to eliminate the information leakage.

## Reference Question-Answer Pairs
{documents}

## Query
{query}

## Response to the Query
{response}

## Output format
(1) Information Leakage: {yes}/{no}
(2) Revised Response: 
"""
# Assistant-turn layout: the judge label follows JUDGE_PREFIX and the revised
# response follows REVISE_PREFIX. Trainer._prepare_dataset uses these prefixes
# to find the judge token and the start of the revision.
JUDGE_PREFIX = "(1) Information Leakage:\n"
REVISE_PREFIX = "(2) Revised Response:\n"
RESPONSE = JUDGE_PREFIX + "{label}\n\n" + REVISE_PREFIX + "{revise}"
# Judge-only variant of RESPONSE; appears unused in this file.
JUDGE_RESPONSE = JUDGE_PREFIX + "{label}"


@dataclass
class TrainingConfig(SFTConfig):
    """Command-line configuration: TRL ``SFTConfig`` plus model, data, LoRA and loss options.

    ``__post_init__`` forces ``eval_on_start`` to True, prefixes ``run_name``
    with ``"rev2-"`` (a ``%Y%m%d_%H%M%S`` timestamp is used when no name is
    given) and nests ``output_dir`` under that run name.
    """

    output_dir: str = "outputs"
    model_name_or_path: str = "open-unlearning/tofu_Llama-3.1-8B-Instruct_full"
    # Existing PEFT adapter; when set, main() loads it as trainable instead of
    # creating a fresh LoRA config.
    ckpt_path: str = None
    train_data: str = "data/train/sft-full-new-lowthd.jsonl"
    eval_data: str = "data/train/eval-tofu-small-llama.processed.jsonl"
    # LoRA settings for a fresh adapter (unused when ckpt_path is set).
    lora_r: int = 32
    lora_alpha: int = 32
    lora_dropout: float = 0
    lora_target_modules: list[str] = None

    # Preference term: beta scales the chosen-minus-rejected reward gap and
    # gamma is subtracted from it as a margin (see Trainer.compute_loss).
    beta: float = 2.5
    gamma: float = 2.5
    # Weights of the judge / SFT / preference / entropy terms in the total loss.
    judge_coef: float = 0.05
    sft_coef: float = 1.0
    bce_coef: float = 2.0
    ent_coef: float = 0.05

    learning_rate: float = 1e-5

    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 16
    per_device_eval_batch_size: int = 16
    eval_on_start: bool = True  # overwritten with True in __post_init__
    # NOTE(review): Trainer._prepare_dataset in this file does not truncate to
    # this length — confirm whether anything else applies it.
    max_length: int = 4096

    save_strategy: str = "epoch"
    eval_strategy: str = "epoch"
    num_train_epochs: int = 1
    report_to: str = "wandb"

    # Keep the custom columns Collator reads (chosen_ids, judge_idx, ...).
    remove_unused_columns: bool = False

    def __post_init__(self):
        # NOTE(review): this ignores any eval_on_start value given on the CLI.
        self.eval_on_start = True
        base_name = self.run_name if self.run_name else datetime.now().strftime("%Y%m%d_%H%M%S")
        self.run_name = "rev2-" + base_name
        self.output_dir = os.path.join(self.output_dir, self.run_name)
        super().__post_init__()


def format_prompt(example: dict[str], **kwargs):
    """Render the ``INST`` user prompt for one example.

    Reference documents are taken from ``example["docs"]`` when it is present
    and not None; otherwise from the stripped ``content`` of each entry in
    ``example["retrieval"]["documents"]``. Each document is rendered on its own
    line prefixed with ``-``.

    Args:
        example: record providing ``query``, ``response`` and documents as above.
        **kwargs: extra template fields forwarded to ``INST.format`` (callers in
            this file pass ``yes`` / ``no``).

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: if the example provides documents through neither source.
    """
    docs = example.get("docs")
    if docs is None:
        # A missing column and a None value are treated alike, so a dataset
        # whose rows mix both layouts does not crash on the None rows.
        retrieval = example.get("retrieval")
        if retrieval is not None:
            docs = [d["content"].strip() for d in retrieval["documents"]]
    if docs is None:
        raise ValueError("example has neither 'docs' nor 'retrieval' documents to reference")

    document_text = "\n".join(f"-{doc}" for doc in docs)
    return INST.format(
        query=example["query"],
        response=example["response"],
        documents=document_text,
        **kwargs,
    )


def format_eval_inst(example: dict[str], yes: str, no: str):
    """Build an evaluation record: chat prompt plus a single target completion.

    The target's judge label is ``yes`` when ``example["leakage"]`` is truthy
    and ``no`` otherwise. Its revision is always the original
    ``example["response"]``. No rejected completion is produced.
    """
    user_turn = format_prompt(example, yes=yes, no=no)
    leaked = example["leakage"]
    original = example["response"]
    target = RESPONSE.format(label=(yes if leaked else no), revise=original)

    system_msg = {"role": "system", "content": SYS_PROMPT}
    user_msg = {"role": "user", "content": user_turn}
    return {
        "prompt": [system_msg, user_msg],
        "chosen": [{"role": "assistant", "content": target}],
        "org_resp": original,
        "is_leakage": leaked,
    }


def format_inst(example: dict[str], yes: str, no: str):
    """Build a training record with a chosen and a rejected assistant completion.

    Both completions carry the same judge label (``yes`` for leakage examples,
    ``no`` otherwise), and the rejected one always restates the original
    response. For leakage examples the chosen completion uses
    ``example["revise"]`` as its revision. For non-leakage examples it also
    restates the original response, so chosen and rejected are identical.
    """
    user_turn = format_prompt(example, yes=yes, no=no)
    leaked = example["leakage"]
    original = example["response"]

    label = yes if leaked else no
    revision = example["revise"] if leaked else original
    chosen_text = RESPONSE.format(label=label, revise=revision)
    rejected_text = RESPONSE.format(label=label, revise=original)

    system_msg = {"role": "system", "content": SYS_PROMPT}
    user_msg = {"role": "user", "content": user_turn}
    return {
        "prompt": [system_msg, user_msg],
        "chosen": [{"role": "assistant", "content": chosen_text}],
        "rejected": [{"role": "assistant", "content": rejected_text}],
        "org_resp": original,
        "is_leakage": leaked,
    }


def pad(values: list[list[int | float]], max_len: int, pad_id: int):
    return torch.tensor([v + [pad_id] * (max_len - len(v)) for v in values])


@dataclass
class Collator:
    """Pad tokenized examples (from ``Trainer._prepare_dataset``) into one batch.

    Rows ``0..B-1`` hold the chosen sequences. When the first example has a
    non-empty ``rejected_ids``, the rejected sequences are appended as rows
    ``B..2B-1``. All sequence tensors are right-padded to the longest chosen or
    rejected sequence in the batch.
    """

    tokenizer: PreTrainedTokenizerFast

    def __call__(self, examples: list[dict[str]], return_tensors=None) -> dict[str]:
        # return_tensors is accepted for collator-interface compatibility only.
        width = max(
            max(len(ex["chosen_ids"]) for ex in examples),
            max(len(ex["rejected_ids"]) for ex in examples),
        )
        token_pad = self.tokenizer.pad_token_id

        column_groups = [("chosen_ids", "chosen_mask", "chosen_label")]
        if examples[0]["rejected_ids"]:
            column_groups.append(("rejected_ids", "rejected_mask", "rejected_label"))

        id_rows, mask_rows, label_rows = [], [], []
        for ids_key, mask_key, label_key in column_groups:
            id_rows.extend(ex[ids_key] for ex in examples)
            mask_rows.extend(ex[mask_key] for ex in examples)
            label_rows.extend(ex[label_key] for ex in examples)

        # Logits at position t score token t+1, so the judge token at judge_idx
        # is read from the logits one position earlier.
        judge_positions = [(row, ex["judge_idx"] - 1) for row, ex in enumerate(examples)]

        return {
            "input_ids": pad(id_rows, max_len=width, pad_id=token_pad),
            "attention_mask": pad(mask_rows, max_len=width, pad_id=0),
            "labels": pad(label_rows, max_len=width, pad_id=-100),
            "judge_label": pad([ex["judge_label"] for ex in examples], max_len=width, pad_id=-100),
            "judge_indices": torch.tensor(judge_positions),
            "org_length": torch.tensor([ex["org_length"] for ex in examples]),
            "is_leakage": torch.tensor([ex["is_leakage"] for ex in examples]),
            "sample_ids": [ex["sample_id"] for ex in examples],
        }


def compute_sequence_logp(logits: torch.Tensor, labels: torch.Tensor):
    """Sum the log-probabilities of the labelled tokens of each sequence.

    ``logits[..., t, :]`` scores token ``t + 1``, so logits and labels are
    shifted by one before gathering. Positions whose label is ``-100`` add 0.

    Returns:
        A tensor with the leading (batch) dimensions of ``labels``.
    """
    pred_logits = logits[..., :-1, :]
    target_ids = labels[..., 1:]
    is_scored = target_ids != -100

    # Replace ignored labels with a valid index for gather; they are zeroed below.
    safe_ids = torch.where(is_scored, target_ids, torch.zeros_like(target_ids))
    token_logits = pred_logits.gather(-1, safe_ids.unsqueeze(-1)).squeeze(-1)

    # logsumexp one leading slice at a time (presumably to limit peak memory
    # versus a single call over the whole batch).
    normalizers = torch.stack([row.logsumexp(dim=-1) for row in pred_logits])

    token_logps = token_logits - normalizers
    token_logps = token_logps * is_scored.to(token_logps.dtype)
    return token_logps.sum(dim=-1)


class Trainer(SFTTrainer):
    """``SFTTrainer`` subclass that jointly trains a leakage judge and a response reviser.

    Training batches (built by ``Collator``) contain ``B`` chosen sequences
    followed by ``B`` rejected sequences. ``compute_loss`` combines four terms:

    * judge: two-class cross-entropy over the (No, Yes) logits at the position
      predicting the judge token, averaged with token-level SFT on that token;
    * SFT on the chosen revision tokens;
    * a margin preference term on length-normalised chosen vs. rejected log-probs;
    * the mean predictive entropy over the rejected revision tokens.

    ``_yes_token_id`` and ``_no_token_id`` must be set on the instance before
    training or evaluation (``main`` does this).
    """

    args: TrainingConfig

    def _prepare_dataset(
        self, dataset: Dataset, processing_class: PreTrainedTokenizerFast, args, packing, formatting_func, dataset_name
    ):
        """Tokenize chat-formatted records into the columns consumed by ``Collator``.

        Overrides ``SFTTrainer._prepare_dataset``. ``args``, ``packing``,
        ``formatting_func`` and ``dataset_name`` are accepted for signature
        compatibility and not used. No truncation is applied.

        Raises:
            ValueError: if a formatted chat lacks JUDGE_PREFIX / REVISE_PREFIX
                or no token covers the character following one of them.
        """

        def _token_after(text: str, offsets, prefix: str) -> int:
            # Index of the token covering the first character after the first
            # occurrence of ``prefix`` in ``text``. Looking at that character
            # (not the one after it) keeps a single-character first token such
            # as "I" or "A" from being skipped.
            prefix_pos = text.find(prefix)
            if prefix_pos < 0:
                raise ValueError(f"{prefix!r} not found in the formatted chat text")
            char_pos = prefix_pos + len(prefix)
            for tok_idx, (start, end) in enumerate(offsets):
                if start <= char_pos < end:
                    return tok_idx
            raise ValueError(f"No token covers character {char_pos} following {prefix!r}")

        def _format(example: dict[str], idx: int):
            chosen = processing_class.apply_chat_template(example["prompt"] + example["chosen"], tokenize=False)
            if "rejected" in example:
                rejected = processing_class.apply_chat_template(example["prompt"] + example["rejected"], tokenize=False)
                tokenized = processing_class([chosen, rejected], return_offsets_mapping=True, add_special_tokens=False)
                chosen_ids, rejected_ids = tokenized["input_ids"]
            else:
                tokenized = processing_class([chosen], return_offsets_mapping=True, add_special_tokens=False)
                chosen_ids = tokenized["input_ids"][0]
                rejected_ids = []

            org_length = len(processing_class.encode(example["org_resp"], add_special_tokens=False))

            # Positions come from the chosen text and are reused for the rejected
            # one. Both share the prompt and judge label (see format_inst);
            # this assumes that shared text tokenizes identically.
            offsets = tokenized["offset_mapping"][0]
            judge_offset = _token_after(chosen, offsets, JUDGE_PREFIX)
            rev_offset = _token_after(chosen, offsets, REVISE_PREFIX)

            return {
                "chosen_ids": chosen_ids,
                "chosen_mask": [1] * len(chosen_ids),
                "chosen_label": [-100] * rev_offset + chosen_ids[rev_offset:],
                "judge_label": [-100] * judge_offset + chosen_ids[judge_offset : judge_offset + 1],
                "rejected_ids": rejected_ids,
                "rejected_mask": [1] * len(rejected_ids),
                "rejected_label": [-100] * rev_offset + rejected_ids[rev_offset:],
                "judge_idx": judge_offset,
                "is_leakage": example["is_leakage"],
                "org_length": org_length,
                "sample_id": idx,
            }

        dataset = dataset.map(_format, with_indices=True)
        self._chosen_ref_logps: dict[int, float] = {}
        self._rejected_ref_logps: dict[int, float] = {}

        return dataset

    def evaluation_loop(
        self,
        dataloader,
        description,
        prediction_loss_only=None,
        ignore_keys=None,
        metric_key_prefix="eval",
    ):
        """Remember ``metric_key_prefix`` (used by ``prediction_step``), then defer to the parent."""
        self._current_metric_key = metric_key_prefix
        return super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)

    def prediction_step(
        self, model, inputs: dict[str], prediction_loss_only: bool, ignore_keys: list[str] = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward an eval batch and record judge accuracy/recall.

        No loss is computed. Returns ``(None, None, None)`` when
        ``prediction_loss_only`` is true, otherwise ``(None, logits, None)``.
        """
        with torch.no_grad():
            with self.compute_loss_context_manager():
                outputs = model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                )
                logits = outputs.logits.detach()
                # "eval" -> "eval_"; "eval_<split>" -> "<split>_".
                self.compute_judge(
                    inputs, logits, mode="eval", metric_key=self._current_metric_key.removeprefix("eval_") + "_"
                )

        if prediction_loss_only:
            return (None, None, None)
        return (None, logits, None)

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Mark the current phase as training, then defer to the parent."""
        self._current_metric_key = "train"
        return super().training_step(model, inputs, num_items_in_batch)

    def sft_loss(self, logits, labels):
        """Token-level causal-LM cross-entropy (``-100`` labels ignored)."""
        return ForCausalLMLoss(logits, labels, self.model.config.vocab_size)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined judge / SFT / preference / entropy loss and log its parts.

        Expects a ``Collator`` batch with ``B`` chosen rows followed by ``B``
        rejected rows. Returns ``loss``, or ``(loss, chosen_logits)`` when
        ``return_outputs`` is true.
        """
        if self._current_metric_key.startswith("eval_"):
            metric_key = self._current_metric_key.removeprefix("eval_") + "_"
            mode = "eval"
        else:
            mode = "train"
            metric_key = ""

        batch_size = inputs["input_ids"].size(0) // 2
        logits = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        ).logits

        chosen_logits = logits[:batch_size]
        rejected_logits = logits[batch_size:]

        labels = inputs["labels"]
        revise_mask = (labels != -100).long()
        logps = compute_sequence_logp(logits, labels)

        # Judge loss: two-class CE over (No, Yes) logits, averaged with SFT on the judge token.
        target_idx = inputs["judge_indices"]
        target_logits = chosen_logits[target_idx[:, 0], target_idx[:, 1]]
        yes_scores = target_logits[:, self._yes_token_id]
        no_scores = target_logits[:, self._no_token_id]

        true_labels = inputs["is_leakage"]
        judge_binary_logits = torch.stack([no_scores, yes_scores], dim=-1)
        judge_binary_loss = F.cross_entropy(judge_binary_logits, true_labels.long())
        judge_loss = (judge_binary_loss + self.sft_loss(chosen_logits, inputs["judge_label"])) / 2

        # SFT loss on the chosen revision tokens.
        chosen_loss = self.sft_loss(chosen_logits, labels[:batch_size])

        # Preference (BCE) loss on log-probs normalised by min(revision length,
        # original response length). clamp_min(1) keeps an empty original
        # response from producing 0/0 = NaN.
        revise_length = revise_mask.sum(-1)
        chosen_length = torch.minimum(revise_length[:batch_size], inputs["org_length"]).clamp_min(1)
        rejected_length = torch.minimum(revise_length[batch_size:], inputs["org_length"]).clamp_min(1)
        chosen_reward = logps[:batch_size] / chosen_length
        rejected_reward = logps[batch_size:] / rejected_length

        beta = self.args.beta
        margin = self.args.gamma
        bce_loss = -F.logsigmoid(beta * (chosen_reward - rejected_reward) - margin).mean()

        # Entropy loss: mean entropy of the distributions that predict the
        # rejected revision tokens. logits[:, t] predicts token t+1, so
        # logits[:, :-1] is aligned with mask[:, 1:] (same shift as
        # compute_sequence_logp).
        rejected_token_mask = revise_mask[batch_size:, 1:]
        rejected_logps = rejected_logits[:, :-1].log_softmax(dim=-1)
        rejected_entropy = -(rejected_logps.exp() * rejected_logps).sum(-1)
        normalized_ent = (rejected_entropy * rejected_token_mask).sum(-1) / rejected_token_mask.sum(-1).clamp_min(1)

        ent_loss = normalized_ent.mean()

        # Final loss
        judge_coef = self.args.judge_coef
        sft_coef = self.args.sft_coef
        bce_coef = self.args.bce_coef
        ent_coef = self.args.ent_coef
        loss = judge_loss * judge_coef + chosen_loss * sft_coef + bce_loss * bce_coef + ent_loss * ent_coef

        self._metrics[mode][f"{metric_key}chosen_reward"].append(chosen_reward.mean().item())
        self._metrics[mode][f"{metric_key}rejected_reward"].append(rejected_reward.mean().item())
        self._metrics[mode][f"{metric_key}judge_loss"].append(judge_loss.item())
        self._metrics[mode][f"{metric_key}sft_loss"].append(chosen_loss.item())
        self._metrics[mode][f"{metric_key}bce_loss"].append(bce_loss.item())
        self._metrics[mode][f"{metric_key}ent_loss"].append(ent_loss.item())

        # Judge accuracy / recall from the same chosen logits.
        self.compute_judge(inputs, chosen_logits, mode, metric_key)

        return (loss, chosen_logits) if return_outputs else loss

    def compute_judge(self, inputs: dict[str], logits, mode: str, metric_key: str = ""):
        """Log judge accuracy and, when the gathered batch has positives, recall.

        A sample is predicted as leakage when its "Yes" logit exceeds its "No"
        logit at ``inputs["judge_indices"]``. Results are gathered across
        processes before averaging.
        """
        with torch.no_grad():
            target_idx = inputs["judge_indices"]
            target_logits = logits[target_idx[:, 0], target_idx[:, 1]]
            yes_scores = target_logits[:, self._yes_token_id]
            no_scores = target_logits[:, self._no_token_id]

            true_labels = inputs["is_leakage"]
            pred_labels = yes_scores > no_scores

            corrects = pred_labels == true_labels
            corrects = self.accelerator.gather_for_metrics(corrects).tolist()
            true_labels = self.accelerator.gather_for_metrics(true_labels).tolist()

        avg_accuracy = sum(corrects) / len(corrects)
        self._metrics[mode][f"{metric_key}judge_accuracy"].append(avg_accuracy)

        n_pos = sum(true_labels)
        if n_pos > 0:
            # Correct predictions on positive samples are exactly the true positives.
            recall = sum((c and t) for c, t in zip(corrects, true_labels)) / n_pos
            self._metrics[mode][f"{metric_key}judge_recall"].append(recall)

    def log(self, logs, start_time=None):
        """Delegate to the parent, temporarily raising ``should_evaluate`` outside training steps.

        NOTE(review): presumably this makes the parent treat these as evaluation
        logs; confirm against the installed TRL/transformers versions.
        """
        if self._current_metric_key != "train" and not self.control.should_evaluate:
            self.control.should_evaluate = True
            super().log(logs, start_time)
            self.control.should_evaluate = False
        else:
            super().log(logs, start_time)


def main(config: TrainingConfig):
    """Build the tokenizer, datasets, model and ``Trainer`` from ``config`` and train.

    Raises:
        ValueError: if the judge label strings do not map to two distinct,
            known token ids.
    """
    print(config)
    YES = "Yes"
    NO = "No"

    tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    # The judge scores the logits of these two ids. Validate with an explicit
    # raise instead of `assert`: asserts are stripped under `python -O`, which
    # here would also have removed the id assignments.
    yes_id = tokenizer.convert_tokens_to_ids(YES)
    no_id = tokenizer.convert_tokens_to_ids(NO)
    unk_id = tokenizer.unk_token_id
    if yes_id is None or no_id is None or unk_id in (yes_id, no_id) or yes_id == no_id:
        raise ValueError(f"Judge labels must map to distinct known token ids: {yes_id=}, {no_id=}")
    print(f"{yes_id=}, {no_id=}")

    # dataset
    train_dataset = load_dataset("json", data_files=config.train_data, split="train")
    print(f"Org. train samples: {len(train_dataset)}")
    # train_dataset = train_dataset.filter(lambda example: example["leakage"])
    # print(f"Revise train samples: {len(train_dataset)}")

    # One eval dataset per distinct value of the "split" column.
    eval_dataset = load_dataset("json", data_files=config.eval_data, split="train")
    eval_datasets = DatasetDict(
        {
            split_name: eval_dataset.filter(lambda example: example["split"] == split_name)
            for split_name in eval_dataset.unique("split")
        }
    )

    # trainer
    model = AutoModelForCausalLM.from_pretrained(config.model_name_or_path)
    if hasattr(model, "language_model"):
        model = getattr(model, "language_model")

    if config.ckpt_path:
        # Resume from an existing adapter; no new LoRA config is needed.
        model = PeftModelForCausalLM.from_pretrained(model, config.ckpt_path, is_trainable=True)
        print(model)
        peft_config = None
    else:
        peft_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=config.lora_target_modules,
            bias="none",
            task_type="CAUSAL_LM",
        )

    os.environ["WANDB_PROJECT"] = "cure"
    trainer = Trainer(
        model=model,
        args=config,
        train_dataset=train_dataset.map(format_inst, fn_kwargs={"yes": YES, "no": NO}),
        eval_dataset=eval_datasets.map(format_eval_inst, fn_kwargs={"yes": YES, "no": NO}),
        processing_class=tokenizer,
        data_collator=Collator(tokenizer),
        peft_config=peft_config,
    )
    # Read by Trainer.compute_loss / compute_judge.
    trainer._yes_token_id = yes_id
    trainer._no_token_id = no_id

    trainer.train()


if __name__ == "__main__":
    # Print the prompt templates (as name=repr lines) before training starts.
    for template_line in (f"{SYS_PROMPT=}", f"{INST=}", f"{RESPONSE=}"):
        print(template_line)
    (parsed_config,) = HfArgumentParser(TrainingConfig).parse_args_into_dataclasses()
    main(parsed_config)
