import argparse
import json
import os
import pathlib
from random import shuffle

import torch
import torch.nn as nn
from datasets import concatenate_datasets, load_dataset, DatasetDict
from dotenv import load_dotenv
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from strong_reject.jailbreaks import registered_jailbreaks

# NOTE: despite the name, CURR_DIR is two levels above this file (``parent.parent``),
# i.e. the directory that contains both ``data/`` and ``strong_reject/``.
CURR_DIR = pathlib.Path(__file__).parent.parent.resolve()
DATA_DIR = CURR_DIR / "data"
# Judge prompt template; it is later filled in with ``forbidden_prompt`` and ``response``.
# JSON is UTF-8, so decode explicitly instead of relying on the platform default encoding.
with open(
    CURR_DIR / "strong_reject" / "eval_files" / "judge_templates.json", "r", encoding="utf-8"
) as f:
    JUDGE_TEMPLATE = json.load(f)["strongreject_finetuned_v2"]

# Jailbreak names whose rows are dropped from the fine-tuning data and which are excluded
# from the train/test jailbreak split in ``get_data_loader``.
# NOTE(review): presumably these are the jailbreaks covered by the human evaluation set and
# are held out so the fine-tuned judge can be compared against it without leakage — confirm.
HUMAN_EVAL_JAILBREAKS = (
    "aim",
    "auto_payload_splitting",
    "base64",
    "combination_1",
    "combination_2",
    "combination_3",
    "disemvowel",
    "distractors",
    "distractors_negated",
    "poems",
    "refusal_suppression",
    "rot_13",
    "style_injection",
    "translation_scotts_gaelic",
    "translation_hmong",
    "translation_zulu",
    "wikipedia",
)

# Command-line interface; only parsed when this file is executed as a script.
parser = argparse.ArgumentParser(
    description="Fine-tune a language model with a regression head on StrongREJECT judge prompts."
)
parser.add_argument(
    "--test",
    action="store_true",
    help=(
        "smoke-test mode: sets the TESTING environment variable, which switches to a tiny "
        "model, a smaller batch size and sequence length, and a truncated dataset"
    ),
)


def main():
    """Fine-tune the regression-headed model and checkpoint the best epoch.

    Trains for ``num_epochs`` passes over the training loader with AdamW, evaluates on the
    held-out loader after every epoch, and writes the model's ``state_dict`` to
    ``DATA_DIR / "interim" / "strong_reject.pt"`` whenever the evaluation loss improves.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model, tokenizer = get_model_and_tokenizer(device)
    train_data_loader, eval_data_loader = get_data_loader(tokenizer)
    optimizer = AdamW(model.parameters(), lr=1e-4)

    # Create the output directory up front so a missing folder fails fast (or is fixed)
    # before any training time is spent, rather than at the first torch.save call.
    checkpoint_path = DATA_DIR / "interim" / "strong_reject.pt"
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    def describe(batch_loss, eval_loss, best_loss):
        return (
            f"Batch loss: {batch_loss:.2f}, Last eval loss: {eval_loss:.2f}, "
            f"Best eval loss: {best_loss:.2f}"
        )

    eval_loss = best_loss = float("inf")
    # Stays NaN only if the training loader yields no batches; keeps the status line valid.
    batch_loss = float("nan")
    num_epochs = 1
    # The context manager closes the bar even if training is interrupted by an exception.
    with tqdm(range(num_epochs * len(train_data_loader))) as progress_bar:
        for _ in range(num_epochs):
            for batch in train_data_loader:
                loss = compute_loss(model, batch)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                batch_loss = loss.item()
                progress_bar.update(1)
                progress_bar.set_postfix_str(describe(batch_loss, eval_loss, best_loss))

            eval_loss = evaluate(model, eval_data_loader)
            if eval_loss < best_loss:
                best_loss = eval_loss
                torch.save(model.state_dict(), checkpoint_path)

            progress_bar.set_postfix_str(describe(batch_loss, eval_loss, best_loss))


def get_model_and_tokenizer(device):
    """Load the base model with its output head replaced by a ``Regressor``, plus its tokenizer.

    When the ``TESTING`` environment variable is set, ``EleutherAI/pythia-14m`` is loaded;
    otherwise ``google/gemma-2-2b-it`` is used.

    Args:
        device: Device the model is moved to.

    Returns:
        tuple: ``(model, tokenizer)``, with the model on ``device`` and in training mode.
    """
    testing = os.getenv("TESTING")
    model_name = "EleutherAI/pythia-14m" if testing else "google/gemma-2-2b-it"
    # the two models expose their output projection under different attribute names
    head_attr = "embed_out" if testing else "lm_head"

    model = AutoModelForCausalLM.from_pretrained(model_name)
    original_head = getattr(model, head_attr)
    setattr(model, head_attr, Regressor(original_head.in_features))
    model.to(device)
    model.train()

    # Pad and truncate on the left so the final position of every sequence is a real token;
    # the Regressor head only reads that final position.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, padding_side="left", truncation_side="left"
    )
    if not tokenizer.pad_token:
        # fall back to the EOS token when the tokenizer defines no padding token
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


def get_data_loader(tokenizer):
    """Build the training and evaluation data loaders.

    Loads ``evals.json`` and ``benign_evals.json`` from ``DATA_DIR / "processed"``, drops rows
    whose ``jailbreak`` is in ``HUMAN_EVAL_JAILBREAKS``, and splits what is left by jailbreak:
    a random 80% of the remaining registered jailbreaks form the train split and every other
    row goes to the test split. No random seed is set in this function, so the split (and the
    row shuffle below) can differ between runs.

    Each row is rendered through ``JUDGE_TEMPLATE`` and tokenized. Rows are expected to carry
    the ``jailbreak``, ``forbidden_prompt``, ``response`` and ``score`` columns.

    Args:
        tokenizer: Tokenizer used to encode the formatted judge prompts.

    Returns:
        tuple: ``(train_data_loader, eval_data_loader)`` whose batches hold the
        ``input_ids``, ``attention_mask`` and ``score`` columns as torch tensors.
    """
    # load the dataset
    full_dataset = load_dataset("json", data_files=str(DATA_DIR / "processed" / "evals.json"))[
        "train"
    ]
    benign_dataset = load_dataset(
        "json", data_files=str(DATA_DIR / "processed" / "benign_evals.json")
    )["train"]
    dataset = concatenate_datasets((full_dataset, benign_dataset))
    dataset = dataset.filter(lambda x: x["jailbreak"] not in HUMAN_EVAL_JAILBREAKS)

    # split into train and test by jailbreak
    jailbreaks = list(set(registered_jailbreaks.keys()) - set(HUMAN_EVAL_JAILBREAKS))
    shuffle(jailbreaks)
    train_jailbreaks = set(jailbreaks[: int(0.8 * len(jailbreaks))])
    dataset_dict = DatasetDict(
        {
            "train": dataset.filter(lambda x: x["jailbreak"] in train_jailbreaks),
            "test": dataset.filter(lambda x: x["jailbreak"] not in train_jailbreaks),
        }
    )

    batch_size, max_length = 8, 1024
    if os.getenv("TESTING"):
        batch_size, max_length = 4, 512
        # Keep at most ten batches per split. Clamp to the split size so that a split with
        # fewer rows than that does not make ``select`` fail on out-of-range indices.
        dataset_dict = DatasetDict(
            {
                key: value.select(range(min(len(value), 10 * batch_size)))
                for key, value in dataset_dict.items()
            }
        )

    # apply StrongREJECT judge template
    dataset_dict = dataset_dict.map(
        lambda x: {
            "eval_prompt": JUDGE_TEMPLATE.format(
                forbidden_prompt=x["forbidden_prompt"], response=x["response"]
            )
        }
    )

    # tokenize the dataset
    # NOTE: padding happens per ``map`` batch, and the loaders below use the same batch size
    # without shuffling, so each loader batch lines up with one tokenization batch whose rows
    # share a padded length. Changing either batch size or shuffling in the DataLoader would
    # break that alignment.
    tokenized_dataset_dict = dataset_dict.shuffle().map(
        lambda x: tokenizer(
            x["eval_prompt"],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ),
        batched=True,
        batch_size=batch_size,
    )
    tokenized_dataset_dict = tokenized_dataset_dict.select_columns(
        ["input_ids", "attention_mask", "score"]
    )
    tokenized_dataset_dict.set_format("torch")

    # put the dataset into a data loader
    return (
        DataLoader(tokenized_dataset_dict["train"], batch_size=batch_size),
        DataLoader(tokenized_dataset_dict["test"], batch_size=batch_size),
    )


def compute_loss(model, batch):
    """Return the mean cross-entropy between target scores and predictions, minus target entropy.

    The model's logit is treated as the log-odds of a Bernoulli prediction and compared with
    the soft target in ``batch["score"]``. Subtracting the (clipped) entropy of the target
    turns the cross-entropy into an approximate KL divergence, so a perfect prediction gives
    a loss of about zero.

    Args:
        model: Callable model exposing ``.device`` and returning an object with ``.logits``
            when called with ``input_ids`` and ``attention_mask``.
        batch: Mapping with ``input_ids``, ``attention_mask`` and ``score`` tensors.
            NOTE(review): scores are assumed to lie in [0, 1] — confirm against the data.

    Returns:
        Scalar tensor holding the batch-mean loss.
    """
    batch = {key: value.to(model.device) for key, value in batch.items()}
    logits = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
    ).logits.squeeze()
    scores = batch["score"]
    # log(sigmoid(z)) and log(1 - sigmoid(z)) = log(sigmoid(-z)), computed directly in log
    # space. Forming the probability first and then taking its log saturates to log(0) = -inf
    # for large |z|, and 0 * -inf then poisons the loss with NaN.
    log_probs = nn.functional.logsigmoid(logits)
    log_one_minus_probs = nn.functional.logsigmoid(-logits)
    # clip avoids numerical errors when calculating entropy
    scores_clipped = scores.clip(0.001, 0.999)
    entropy = scores * scores_clipped.log() + (1 - scores) * (1 - scores_clipped).log()
    loss = -(scores * log_probs + (1 - scores) * log_one_minus_probs - entropy).mean()
    return loss


def evaluate(model, data_loader):
    """Return the average per-batch loss of ``model`` over ``data_loader``.

    The model is switched to eval mode for the pass and back to training mode afterwards,
    even if a batch raises. Gradients are not tracked during the pass.

    Args:
        model: Model to evaluate.
        data_loader: Iterable of batches accepted by ``compute_loss``.

    Returns:
        The mean of the batch losses as a Python number; ``0`` when the loader is empty.
    """
    loss = 0
    model.eval()
    try:
        # No backward pass happens here, so skip building autograd graphs.
        with torch.no_grad():
            for batch in data_loader:
                batch_loss = compute_loss(model, batch)
                loss += batch_loss.item() / len(data_loader)
    finally:
        # Always hand the model back in training mode, matching what the training loop expects.
        model.train()

    return loss


class Regressor(nn.Module):
    """Head that maps the final position of a sequence of features to a single value.

    The features at the last index of dimension 1 are layer-normalized and then projected
    to one output by a bias-free linear layer.
    """

    def __init__(self, in_features):
        """Create the normalization and projection layers for ``in_features``-wide inputs."""
        super().__init__()
        self.layer_norm = nn.LayerNorm(in_features)
        self.linear = nn.Linear(in_features, 1, bias=False)

    def forward(self, x):
        """Return the projection of the normalized features at the last index of dim 1.

        NOTE(review): ``x`` is presumably ``(batch, sequence, in_features)`` — confirm.
        """
        final_position = x[:, -1]
        normalized = self.layer_norm(final_position)
        return self.linear(normalized)


if __name__ == "__main__":
    # Load environment variables via python-dotenv before anything reads them.
    load_dotenv()

    # ``--test`` is a shortcut for setting the TESTING environment variable.
    cli_args = parser.parse_args()
    if cli_args.test:
        os.environ["TESTING"] = "true"

    main()
