import os
from dataclasses import dataclass
from datetime import datetime

import torch
import torch.nn.functional as F
from datasets import Dataset, DatasetDict, load_dataset
from peft import LoraConfig
from torch.optim.optimizer import Optimizer as Optimizer
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, PreTrainedTokenizerFast
from transformers.loss.loss_utils import ForCausalLMLoss
from trl import SFTConfig, SFTTrainer

# System message prepended to every training / evaluation conversation.
SYS_PROMPT = "You are a strict response verifier for knowledge reference detection."
# User-turn template. {documents}, {query} and {response} plus the label words {yes}/{no}
# are all filled by format_prompt() through str.format.
# NOTE: the trailing space after "(2) Revised Response:" on the last template line matters:
# it keeps the prompt text from containing REVISE_PREFIX verbatim, so a str.find() for
# REVISE_PREFIX over the rendered conversation cannot match inside the prompt. Do not let an
# editor strip it.
# NOTE(review): the {no} bullet reads "or not explicitly supported" (missing "is"), and the
# "Output format" section shows "{yes}/{no}" on the header line, whereas RESPONSE (what the
# model is trained to emit) puts the label on its own line. Kept as-is because this text is
# part of the training data.
INST = """You are given a set of reference question-answer pairs, a query, and a model-generated response to the query.
Your task is to determine whether the response is supported by the references and revise it to remove information leakage if needed.
- If the response contains information that is clearly supported or derived from the reference answers, output {yes}, meaning the response has information leakage.
- If the response contradicts the reference or not explicitly supported by any part of the reference answers, output {no}, even if it is factually correct, there is no information leakage.

When the output is {yes}, revise the given response to eliminate the information leakage.

## Reference Question-Answer Pairs
{documents}

## Query
{query}

## Response to the Query
{response}

## Output format
(1) Information Leakage: {yes}/{no}
(2) Revised Response: 
"""
# Assistant-turn layout: judge header, label on its own line, a blank line, then the revision
# header followed by the (possibly revised) response. Trainer._prepare_dataset locates its
# token spans by searching for these exact prefixes in the rendered conversation.
JUDGE_PREFIX = "(1) Information Leakage:\n"
REVISE_PREFIX = "(2) Revised Response:\n"
RESPONSE = JUDGE_PREFIX + "{label}\n\n" + REVISE_PREFIX + "{revise}"
# Judge-only variant: header + label, no revision section.
JUDGE_RESPONSE = JUDGE_PREFIX + "{label}"
# Label words. Trainer.__init__ requires each to resolve via convert_tokens_to_ids to a known,
# distinct vocabulary id.
YES = "Yes"
NO = "No"


@dataclass
class TrainingConfig(SFTConfig):
    """SFTConfig plus model/data locations and LoRA hyper-parameters for this experiment.

    Several SFTConfig defaults are overridden below. ``__post_init__`` prefixes the run
    name with ``sft3-`` (falling back to a timestamp) and nests ``output_dir`` under it.
    """

    # -- model, data and LoRA settings --
    output_dir: str = "outputs"
    model_name_or_path: str = "open-unlearning/tofu_Llama-3.1-8B-Instruct_full"
    train_data: str = "data/train/sft-full.jsonl"
    eval_data: str = "data/train/eval-tofu-small-llama.processed.jsonl"
    lora_r: int = 32
    lora_alpha: int = 32
    lora_dropout: float = 0
    lora_target_modules: list[str] = None

    # -- batching / sequence length --
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 16
    per_device_eval_batch_size: int = 16
    eval_on_start: bool = False
    max_length: int = 4096

    # -- schedule and logging --
    save_strategy: str = "epoch"
    eval_strategy: str = "epoch"
    num_train_epochs: int = 1
    report_to: str = "wandb"

    # Keep the extra columns produced by Trainer._prepare_dataset for the custom collator.
    remove_unused_columns: bool = False

    def __post_init__(self):
        # NOTE(review): each process evaluates datetime.now() on its own; in a multi-process
        # launch that straddles a second boundary, ranks would get different run names and
        # output dirs. Pass an explicit run_name for distributed runs.
        base_name = self.run_name or datetime.now().strftime("%Y%m%d_%H%M%S")
        self.run_name = f"sft3-{base_name}"
        self.output_dir = os.path.join(self.output_dir, self.run_name)
        return super().__post_init__()


def format_prompt(example: dict[str], **kwargs):
    """Render the INST user prompt for a single example.

    Args:
        example: record providing ``docs`` (iterable of reference strings), ``query`` and
            ``response``.
        **kwargs: extra template values (callers pass the ``yes``/``no`` label words).

    Returns:
        The formatted prompt string; each reference becomes a ``-``-prefixed line.
    """
    bullet_list = "\n".join([f"-{doc}" for doc in example["docs"]])
    return INST.format(
        query=example["query"],
        response=example["response"],
        documents=bullet_list,
        **kwargs,
    )


def format_eval_inst(example: dict[str], yes: str, no: str):
    """Build an evaluation record: chat prompt, reference assistant turn and leakage flag.

    Unlike ``format_inst``, the revision section of the reference answer is always the
    original ``response``, regardless of the leakage label.
    """
    user_prompt = format_prompt(example, yes=yes, no=no)
    leaked = example["leakage"]

    label = yes if leaked else no
    answer = RESPONSE.format(label=label, revise=example["response"])

    system_turn = {"role": "system", "content": SYS_PROMPT}
    user_turn = {"role": "user", "content": user_prompt}
    assistant_turn = {"role": "assistant", "content": answer}
    return {
        "prompt": [system_turn, user_turn],
        "chosen": [assistant_turn],
        "is_leakage": leaked,
    }


def format_inst(example: dict[str], yes: str, no: str):
    """Build a training record: chat prompt, target assistant turn and leakage flag.

    Leaky examples are supervised with the curated ``revise`` text; clean examples
    reproduce the original ``response`` unchanged (``revise`` is only read when leaky).
    """
    user_prompt = format_prompt(example, yes=yes, no=no)
    leaked = example["leakage"]

    label, revision = (yes, example["revise"]) if leaked else (no, example["response"])
    answer = RESPONSE.format(label=label, revise=revision)

    system_turn = {"role": "system", "content": SYS_PROMPT}
    user_turn = {"role": "user", "content": user_prompt}
    assistant_turn = {"role": "assistant", "content": answer}
    return {
        "prompt": [system_turn, user_turn],
        "chosen": [assistant_turn],
        "is_leakage": leaked,
    }


def compute_sequence_logp(logits: torch.Tensor, labels: torch.Tensor):
    """Sum next-token log-probabilities over every position whose label is not -100.

    Args:
        logits: model outputs, ``(batch, seq, vocab)`` at the call site in this file.
        labels: token ids aligned with ``logits``; ``-100`` marks ignored positions.

    Returns:
        A ``(batch,)`` tensor of summed log-probs, in the dtype of ``logits``.
    """
    pred_logits = logits[..., :-1, :]
    targets = labels[..., 1:]
    keep = targets != -100

    # Ignored positions still need a valid gather index; they are zeroed out below anyway.
    safe_targets = targets.masked_fill(~keep, 0)
    target_logits = pred_logits.gather(-1, safe_targets[..., None])[..., 0]

    # logsumexp one sequence at a time (presumably to cap the vocab-sized temporaries).
    normalizers = torch.stack([torch.logsumexp(seq_logits, dim=-1) for seq_logits in pred_logits])

    token_logps = (target_logits - normalizers) * keep.to(target_logits.dtype)
    return token_logps.sum(-1)


def pad(values: list[list[int | float]], max_len: int, pad_id: int):
    return torch.tensor([v + [pad_id] * (max_len - len(v)) for v in values])


@dataclass
class Collator:
    """Right-pad the per-example columns built by ``Trainer._prepare_dataset`` into a batch.

    Token ids are padded with the tokenizer's pad id, attention masks with 0, and every
    label tensor with -100 so padded positions are ignored by the losses and metrics.
    """

    tokenizer: PreTrainedTokenizerFast

    def __call__(self, examples: list[dict[str]], return_tensors=None) -> dict[str]:
        """Collate ``examples`` into tensors; ``return_tensors`` is accepted but unused.

        Returns:
            ``input_ids``/``attention_mask``/``judge_label``/``revision_label``/``labels``
            padded to the longest ``chosen_ids``, ``judge_indices`` as ``(row, position)``
            pairs pointing at the position whose logits predict the label token, and
            ``is_leakage``.
        """
        max_len = max(len(item["chosen_ids"]) for item in examples)
        pad_id = self.tokenizer.pad_token_id

        def column(key):
            return [item[key] for item in examples]

        return {
            "input_ids": pad(column("chosen_ids"), max_len=max_len, pad_id=pad_id),
            "attention_mask": pad(column("chosen_mask"), max_len=max_len, pad_id=0),
            "judge_label": pad(column("judge_label"), max_len=max_len, pad_id=-100),
            # Must be -100, not 0: compute_sequence_logp scores every position whose label
            # is not -100, so 0-padding would add log p(token 0) for each padded position.
            "revision_label": pad(column("rev_label"), max_len=max_len, pad_id=-100),
            "labels": pad(column("chosen_label"), max_len=max_len, pad_id=-100),
            # Logits at position t predict token t + 1, hence judge_idx - 1.
            "judge_indices": torch.tensor([(i, item["judge_idx"] - 1) for i, item in enumerate(examples)]),
            "is_leakage": torch.tensor(column("is_leakage")),
        }


class Trainer(SFTTrainer):
    """SFTTrainer that jointly learns the leakage verdict and the revised response.

    Every assistant turn follows ``RESPONSE`` (judge header + Yes/No label, then revision
    header + revised text). ``compute_loss`` combines

    * ``judge_loss``: causal-LM loss on the judge header and label tokens,
    * ``judge_binary_loss``: 2-way cross-entropy between the ``NO``/``YES`` token logits at
      the position that predicts the label,
    * ``sft_loss``: causal-LM loss from the revision header to the end of the sequence,

    as ``(judge_loss + judge_binary_loss) / 2 + sft_loss``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # compute_loss returns per-micro-batch *mean* losses and ignores num_items_in_batch.
        # The transformers Trainer.compute_loss docs require such overrides to set this flag
        # to False so that training_step divides the loss by the gradient-accumulation steps;
        # otherwise accumulated losses/gradients are summed instead of averaged.
        self.model_accepts_loss_kwargs = False
        # Overwritten by training_step / evaluation_loop; initialised so compute_loss and log
        # never hit a missing attribute.
        self._current_metric_key = "train"

        tokenizer = self.processing_class
        yes_id = tokenizer.convert_tokens_to_ids(YES)
        no_id = tokenizer.convert_tokens_to_ids(NO)
        unk_id = tokenizer.unk_token_id
        if yes_id is None or no_id is None or unk_id in (yes_id, no_id) or yes_id == no_id:
            raise ValueError(f"YES/NO must map to distinct known vocabulary tokens: {yes_id=}, {no_id=}")
        print(f"{yes_id=}, {no_id=}")
        self._yes_token_id = yes_id
        self._no_token_id = no_id

    def _prepare_dataset(
        self, dataset: Dataset, processing_class: PreTrainedTokenizerFast, args, packing, formatting_func, dataset_name
    ):
        """Tokenize each conversation and precompute the label spans used by compute_loss.

        Replaces SFTTrainer's preprocessing (``packing``/``formatting_func`` are ignored) and
        adds the columns consumed by ``Collator``:

        * ``chosen_ids``/``chosen_mask``: the full rendered conversation,
        * ``judge_label``: targets from the judge header through the label token only,
        * ``chosen_label``: targets from the revision header onward,
        * ``rev_label``: targets for the revision text,
        * ``judge_idx``: token index of the label, plus ``is_leakage``.

        Raises:
            ValueError: if a response header or a character position cannot be located.
        """

        def _token_at(offsets, char_pos: int) -> int:
            # Index of the token whose character span [start, end) contains char_pos.
            for idx, (start, end) in enumerate(offsets):
                if start <= char_pos < end:
                    return idx
            raise ValueError(f"No token covers character position {char_pos}")

        def _format(example: dict[str]):
            chosen = processing_class.apply_chat_template(
                example["prompt"] + example["chosen"],
                tokenize=False,
            )
            tokenized = processing_class([chosen], return_offsets_mapping=True, add_special_tokens=False)
            chosen_ids = tokenized["input_ids"][0]
            offsets = tokenized["offset_mapping"][0]

            resp_start = chosen.find(JUDGE_PREFIX)
            if resp_start < 0:
                raise ValueError("JUDGE_PREFIX not found in the rendered conversation")
            # Look for the revision header only after the judge header, so it can never match
            # text inside the prompt; a missing header must fail instead of silently mapping to
            # position 0 (find() == -1, +1).
            rev_header = chosen.find(REVISE_PREFIX, resp_start)
            if rev_header < 0:
                raise ValueError("REVISE_PREFIX not found after JUDGE_PREFIX")

            resp_offset = _token_at(offsets, resp_start)
            # NOTE(review): each probe below is shifted one character into its target span. For
            # the Yes/No label this stays inside the label token (assuming it is tokenized as the
            # single token checked in __init__); for the revision header it lands on "2" of
            # "(2)", and for the revision text on its second character, so a one-character
            # leading token would be excluded. Confirm this is intended.
            judge_offset = _token_at(offsets, resp_start + len(JUDGE_PREFIX) + 1)
            rev_offset = _token_at(offsets, rev_header + 1)
            rev_seq_offset = _token_at(offsets, rev_header + len(REVISE_PREFIX) + 1)

            return {
                "chosen_ids": chosen_ids,
                "chosen_mask": [1] * len(chosen_ids),
                # Truncated right after the label; Collator pads the rest with -100.
                "judge_label": [-100] * resp_offset + chosen_ids[resp_offset : judge_offset + 1],
                "chosen_label": [-100] * rev_offset + chosen_ids[rev_offset:],
                "rev_label": [-100] * rev_seq_offset + chosen_ids[rev_seq_offset:],
                "judge_idx": judge_offset,
                "is_leakage": example["is_leakage"],
            }

        dataset = dataset.map(_format)
        return dataset

    def evaluation_loop(
        self,
        dataloader,
        description,
        prediction_loss_only=None,
        ignore_keys=None,
        metric_key_prefix="eval",
    ):
        """Record the metric prefix (e.g. ``eval_<split>``) for compute_loss, then delegate."""
        self._current_metric_key = metric_key_prefix
        return super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)

    def prediction_step(
        self, model, inputs: dict[str], prediction_loss_only: bool, ignore_keys: list[str] = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run compute_loss without gradients.

        Returns:
            ``(loss, None, None)`` when ``prediction_loss_only``, otherwise
            ``(loss, logits, None)``; labels are never returned.
        """
        with torch.no_grad():
            with self.compute_loss_context_manager():
                loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
            loss = loss.detach().mean()

            logits = outputs.detach()

        if prediction_loss_only:
            return (loss, None, None)
        return (loss, logits, None)

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Mark the current phase as training for metric bookkeeping, then delegate."""
        self._current_metric_key = "train"
        return super().training_step(model, inputs, num_items_in_batch)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined judge + revision loss and record auxiliary metrics.

        Appends to ``self._metrics[mode]``:
        * the three loss terms;
        * judge accuracy (YES vs NO logit, gathered across processes);
        * judge recall, only when the gathered batch has a positive;
        * ``target_logp``, the mean summed log-prob of the revision tokens on this process.

        While evaluating a split (metric prefix ``eval_<split>``) keys are prefixed with
        ``<split>_``. ``num_items_in_batch`` is ignored; see ``__init__``.

        Returns:
            ``loss``, or ``(loss, logits)`` when ``return_outputs`` is True.
        """

        if self._current_metric_key.startswith("eval_"):
            metric_key = self._current_metric_key.removeprefix("eval_") + "_"
            mode = "eval"
        else:
            mode = "train"
            metric_key = ""

        logits = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        ).logits

        # Logits at the position just before each label token (see Collator.judge_indices).
        target_idx = inputs["judge_indices"]
        true_labels = inputs["is_leakage"].long()
        target_logits = logits[target_idx[:, 0], target_idx[:, 1]]
        yes_scores = target_logits[:, self._yes_token_id]
        no_scores = target_logits[:, self._no_token_id]

        # Class 0 = NO, class 1 = YES, matching is_leakage as 0/1.
        judge_binary_logits = torch.stack([no_scores, yes_scores], dim=-1)
        judge_binary_loss = F.cross_entropy(judge_binary_logits, true_labels)

        judge_loss = ForCausalLMLoss(logits, inputs["judge_label"], self.model.config.vocab_size)
        sft_loss = ForCausalLMLoss(logits, inputs["labels"], self.model.config.vocab_size)

        self._metrics[mode][f"{metric_key}judge_loss"].append(judge_loss.item())
        self._metrics[mode][f"{metric_key}judge_binary_loss"].append(judge_binary_loss.item())
        self._metrics[mode][f"{metric_key}sft_loss"].append(sft_loss.item())

        loss = (judge_loss + judge_binary_loss) / 2 + sft_loss

        with torch.no_grad():
            pred_labels = yes_scores > no_scores

            corrects = pred_labels == true_labels
            corrects = self.accelerator.gather_for_metrics(corrects).tolist()
            true_labels = self.accelerator.gather_for_metrics(true_labels).tolist()

            avg_rev_logp = compute_sequence_logp(logits, inputs["revision_label"]).mean().item()

        avg_accuracy = sum(corrects) / len(corrects)
        self._metrics[mode][f"{metric_key}judge_accuracy"].append(avg_accuracy)
        self._metrics[mode][f"{metric_key}target_logp"].append(avg_rev_logp)

        n_pos = sum(true_labels)
        if n_pos > 0:
            recall = sum((c and t) for c, t in zip(corrects, true_labels)) / n_pos
            self._metrics[mode][f"{metric_key}judge_recall"].append(recall)

        return (loss, logits) if return_outputs else loss

    def log(self, logs, start_time=None):
        """Delegate to the parent ``log``, forcing ``control.should_evaluate`` on outside training.

        NOTE(review): presumably this makes the parent (trl) ``log`` treat the buffered
        metrics as eval metrics when evaluation runs off-schedule (e.g. ``eval_on_start``);
        confirm against the installed trl version.
        """
        if self._current_metric_key != "train" and not self.control.should_evaluate:
            self.control.should_evaluate = True
            super().log(logs, start_time)
            self.control.should_evaluate = False
        else:
            super().log(logs, start_time)


def main(config: TrainingConfig):
    """Load tokenizer, model and data, then run LoRA SFT with the custom Trainer.

    Args:
        config: parsed TrainingConfig (model/data paths, LoRA and trainer settings).

    Raises:
        ValueError: if the YES/NO label words are not distinct known vocabulary tokens.
    """
    print(config)

    tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    # Fail fast, before the model is loaded. Raise rather than assert so the check also
    # runs under `python -O`.
    yes_id = tokenizer.convert_tokens_to_ids(YES)
    no_id = tokenizer.convert_tokens_to_ids(NO)
    unk_id = tokenizer.unk_token_id
    if yes_id is None or no_id is None or unk_id in (yes_id, no_id) or yes_id == no_id:
        raise ValueError(f"YES/NO must map to distinct known vocabulary tokens: {yes_id=}, {no_id=}")
    print(f"{yes_id=}, {no_id=}")

    lora_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # dataset
    train_dataset = load_dataset("json", data_files=config.train_data, split="train")
    print(f"Train samples: {len(train_dataset)}")

    # The eval file holds several splits in one table (column "split"); evaluate each one
    # separately. split_name is bound as a default argument so each filter keeps its own value.
    eval_dataset = load_dataset("json", data_files=config.eval_data, split="train")
    eval_datasets = DatasetDict(
        {
            split_name: eval_dataset.filter(lambda example, split_name=split_name: example["split"] == split_name)
            for split_name in eval_dataset.unique("split")
        }
    )

    # trainer
    model = AutoModelForCausalLM.from_pretrained(config.model_name_or_path, torch_dtype="bfloat16")
    # Some checkpoints (presumably multimodal wrappers) expose the text decoder as
    # .language_model; train only that part.
    if hasattr(model, "language_model"):
        model = model.language_model
    print(model)

    # Respect a WANDB_PROJECT the user already exported; otherwise default to "cure".
    os.environ.setdefault("WANDB_PROJECT", "cure")
    trainer = Trainer(
        model=model,
        args=config,
        train_dataset=train_dataset.map(format_inst, fn_kwargs={"yes": YES, "no": NO}),
        eval_dataset=eval_datasets.map(format_eval_inst, fn_kwargs={"yes": YES, "no": NO}),
        processing_class=tokenizer,
        data_collator=Collator(tokenizer),
        peft_config=lora_config,
    )

    trainer.train()


if __name__ == "__main__":
    # Echo the prompt templates (as repr) so every run log records the exact text used.
    for _template_name, _template in (("SYS_PROMPT", SYS_PROMPT), ("INST", INST), ("RESPONSE", RESPONSE)):
        print(f"{_template_name}={_template!r}")
    parsed_configs = HfArgumentParser(TrainingConfig).parse_args_into_dataclasses()
    main(*parsed_configs)
