import argparse
from dataclasses import asdict
import json
import os
import yaml

from llm_rules import scenarios, models, Message, Role
import fastchat
import numpy as np
import torch
import torch.nn as nn

import utils


def parse_args(argv=None):
    """Parse command-line options for the GCG attack script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, which is what the script entry
            point relies on.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="llama-2@/path/to/Llama-2-7b-chat-hf",
        help="template_name@checkpoint_path",
    )
    parser.add_argument(
        "--log_dir",
        type=str,
        default="../logs/gcg_attack",
    )
    parser.add_argument("--justask_file", type=str, default="../data/justask.yaml")
    parser.add_argument("--target_file", type=str, default="targets.yaml")
    parser.add_argument(
        "--num_steps", type=int, default=500, help="Number of attack steps"
    )
    parser.add_argument(
        "--scenario", type=str, default="Encryption", choices=scenarios.SCENARIOS.keys()
    )
    parser.add_argument("--behavior", type=str, default="")
    parser.add_argument(
        "--adv_suffix_init",
        type=str,
        default="! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
        help="Number of candidates to evaluate in each step",
    )
    parser.add_argument(
        "--topk", type=int, default=256, help="Top k tokens to sample from"
    )
    parser.add_argument("--allow_non_ascii", action="store_true")
    parser.add_argument(
        "--use_system_instructions",
        action="store_true",
        default=False,
        help="Present instructions as a system message",
    )
    parser.add_argument(
        "--system_message",
        type=str,
        default=None,
        choices=models.SYSTEM_MESSAGES.keys(),
        # The help used to point at a non-existent "--system-instructions" flag.
        help="System message to model, if not using --use_system_instructions",
    )
    parser.add_argument(
        "--fixed_params",
        action="store_true",
        default=False,
        help="Use fixed scenario params in each iteration",
    )
    parser.add_argument(
        "--disable_eval",
        action="store_true",
        default=False,
    )
    return parser.parse_args(argv)


def build_messages_and_target(scenario, justask_templates, target_template, args):
    """Build the conversation to attack and the target string to optimize for.

    The scenario parameters (``scenario.p``, a dataclass instance) are turned
    into a dict and used to fill in the target template and the first user
    message template.

    Args:
        scenario: Scenario instance providing ``p``, ``prompt`` and, when the
            rules are given as a user turn, ``initial_response``.
        justask_templates: Mapping with a ``"first_message"`` template and a
            ``"messages"`` list; only ``messages[0]`` is used here.
        target_template: Format string for the target completion.
        args: Namespace providing ``use_system_instructions`` and
            ``system_message``.

    Returns:
        Tuple ``(messages, target)``: the list of ``Message`` objects and the
        formatted target string.
    """
    fmt_vars = asdict(scenario.p)
    fmt_vars["first_message"] = justask_templates["first_message"].format(**fmt_vars)
    target = target_template.format(**fmt_vars)

    if args.use_system_instructions:
        # Scenario rules are delivered as the system message itself.
        conversation = [Message(Role.SYSTEM, scenario.prompt)]
    else:
        # Rules arrive as the first user turn, already acknowledged by the
        # assistant, underneath a separately selected system message.
        conversation = [
            Message(Role.SYSTEM, models.SYSTEM_MESSAGES[args.system_message]),
            Message(Role.USER, scenario.prompt),
            Message(Role.ASSISTANT, scenario.initial_response),
        ]

    # Both layouts end with the same first user message of the test case.
    first_user_text = justask_templates["messages"][0].format(**fmt_vars)
    conversation.append(Message(Role.USER, first_user_text))
    return conversation, target


def evaluate_suffix(
    model_fn,
    scenario,
    templates,
    suffix,
    use_system_instructions,
    system_message=None,
):
    """Play the test-case conversation with ``suffix`` attached and check it.

    Each template in ``templates["messages"]`` is sent as a user turn (the
    suffix is appended, after a space, to the first one only). After every
    model response the scenario is evaluated, and the rollout stops at the
    first failing evaluation.

    Args:
        model_fn: Callable taking the message list and returning a response
            that ``models.concat_stream`` can turn into the assistant text.
        scenario: Scenario instance providing ``p``, ``prompt``,
            ``initial_response`` and ``evaluate(messages)``.
        templates: Mapping with a ``"first_message"`` template and a
            non-empty ``"messages"`` list of user-turn templates.
        suffix: Adversarial suffix string.
        use_system_instructions: If true, the scenario prompt is the system
            message; otherwise it is the first user turn.
        system_message: Key into ``models.SYSTEM_MESSAGES``, used only when
            ``use_system_instructions`` is false. When left as ``None`` the
            function falls back to a module-level ``args.system_message`` if
            such a global exists, so existing five-argument callers keep
            working.

    Returns:
        The ``passed`` attribute of the last scenario evaluation performed.

    Raises:
        ValueError: If ``templates["messages"]`` is empty.
    """
    var_dict = asdict(scenario.p)
    var_dict["first_message"] = templates["first_message"].format(**var_dict)

    user_turns = templates["messages"]
    if not user_turns:
        # Without at least one turn there is no evaluation result to return.
        raise ValueError("templates['messages'] must contain at least one message")

    if use_system_instructions:
        messages = [
            Message(Role.SYSTEM, scenario.prompt),
        ]
    else:
        if system_message is None:
            # Legacy path: this function used to read the script's global
            # `args` directly, which raised NameError when imported.
            legacy_args = globals().get("args")
            system_message = getattr(legacy_args, "system_message", None)
        messages = [
            Message(Role.SYSTEM, models.SYSTEM_MESSAGES[system_message]),
            Message(Role.USER, scenario.prompt),
            Message(Role.ASSISTANT, scenario.initial_response),
        ]

    for i, m in enumerate(user_turns):
        user_text = m.format(**var_dict)
        if i == 0:
            # Only the first user message carries the adversarial suffix.
            user_text = user_text + " " + suffix
        messages.append(Message(Role.USER, user_text))

        response = model_fn(messages)
        content = models.concat_stream(response)
        messages.append(Message(Role.ASSISTANT, content))

        result = scenario.evaluate(messages)
        if not result.passed:
            break
    return result.passed


def main(args):
    """Run the GCG suffix attack for one scenario and log every step.

    Each step computes the gradient of the target loss with respect to the
    suffix token embeddings, samples candidate suffixes from it, scores the
    candidates in one batch, and keeps the lowest-loss one. One JSON line
    (``loss``, ``passed``, ``suffix``) is appended per completed step to
    ``<log_dir>/<last path component of --model>/<scenario[_behavior]>.jsonl``;
    an existing file at that path is deleted first.

    Args:
        args: Namespace as produced by ``parse_args``. If ``disable_eval`` is
            set, the per-step rollout/evaluation is skipped and ``passed`` is
            logged as ``null``.
    """
    # Seed the NumPy and PyTorch (CPU and CUDA) RNGs.
    np.random.seed(20)
    torch.manual_seed(20)
    torch.cuda.manual_seed_all(20)

    template_name, checkpoint_path = args.model.split("@")
    model, tokenizer = utils.load_model_and_tokenizer(
        checkpoint_path, low_cpu_mem_usage=True, use_cache=False, device="cuda"
    )
    embed_layer = model.get_input_embeddings()

    # Override model and tokenizer to avoid duplicate load during evaluation
    models.transformers.MODEL = model
    models.transformers.TOKENIZER = tokenizer
    model_fn = models.TransformersModel(args.model)

    # ========== Setup testcase messages ========== #
    full_name = args.scenario + ("_" + args.behavior if args.behavior else "")
    log_dir = os.path.join(args.log_dir, args.model.split("/")[-1])
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, full_name + ".jsonl")
    # Start from a clean log since steps are appended below.
    if os.path.exists(log_file):
        os.remove(log_file)

    with open(args.justask_file) as f:
        justask = yaml.safe_load(f)

    with open(args.target_file) as f:
        targets = yaml.safe_load(f)

    target_template = targets[full_name]
    justask_templates = justask[full_name]

    suffix_manager = utils.SuffixManager(
        tokenizer=tokenizer,
        use_system_instructions=args.use_system_instructions,
        conv_template=fastchat.conversation.get_conv_template(template_name),
    )

    not_allowed_tokens = (
        None if args.allow_non_ascii else utils.get_nonascii_toks(tokenizer)
    )
    adv_suffix = args.adv_suffix_init

    if args.fixed_params:
        # With fixed scenario params the tokens returned by
        # get_input_ids(..., static_only=True) are run through the model once
        # and their key/value cache is reused in every step.
        scenario = scenarios.SCENARIOS[args.scenario]()
        messages, target = build_messages_and_target(
            scenario, justask_templates, target_template, args
        )
        static_input_ids = suffix_manager.get_input_ids(
            messages, adv_suffix, target, static_only=True
        )
        num_static_tokens = len(static_input_ids)
        with torch.no_grad():
            input_embeds = embed_layer(static_input_ids.to("cuda")).unsqueeze(0)
            outputs = model(inputs_embeds=input_embeds, use_cache=True)
            prefix_cache = outputs.past_key_values
    else:
        prefix_cache = None

    for i in range(args.num_steps):
        # ========== Compute coordinate gradients ========== #
        if not args.fixed_params:
            # Fresh scenario instance (and hence fresh messages/target) per step.
            scenario = scenarios.SCENARIOS[args.scenario]()
            messages, target = build_messages_and_target(
                scenario, justask_templates, target_template, args
            )

        (
            dynamic_input_ids,
            optim_slice,
            target_slice,
            loss_slice,
        ) = suffix_manager.get_input_ids(messages, adv_suffix, target)

        if args.fixed_params:
            # Offset everything to ignore static tokens which are processed separately
            dynamic_input_ids = dynamic_input_ids[num_static_tokens:]
            optim_slice = slice(
                optim_slice.start - num_static_tokens,
                optim_slice.stop - num_static_tokens,
            )
            target_slice = slice(
                target_slice.start - num_static_tokens,
                target_slice.stop - num_static_tokens,
            )
            loss_slice = slice(
                loss_slice.start - num_static_tokens,
                loss_slice.stop - num_static_tokens,
            )

        dynamic_input_ids = dynamic_input_ids.to("cuda")
        optim_ids, target_ids = (
            dynamic_input_ids[optim_slice],
            dynamic_input_ids[target_slice],
        )
        optim_ids, target_ids = optim_ids.to("cuda"), target_ids.to("cuda")
        input_embeds = embed_layer(dynamic_input_ids).unsqueeze(0)
        input_embeds.requires_grad_()

        # forward pass
        outputs = model(inputs_embeds=input_embeds, past_key_values=prefix_cache)
        logits = outputs.logits

        # compute loss and gradients
        loss_logits = logits[:, loss_slice].squeeze(0)
        loss = nn.functional.cross_entropy(loss_logits, target_ids)

        embed_grads = torch.autograd.grad(outputs=[loss], inputs=[input_embeds])[0]
        # Project embedding-space gradients onto every vocabulary embedding to
        # get one score per (position, token), then normalize per position.
        token_grads = embed_grads @ embed_layer.weight.t()
        token_grads = token_grads / token_grads.norm(dim=-1, keepdim=True)

        # ========== Sample and evaluate updates based on gradients. ========== #
        adv_suffix_ids = utils.sample_updates(
            optim_ids,
            token_grads.squeeze(0),
            args.batch_size,
            topk=args.topk,
            not_allowed_tokens=not_allowed_tokens,
        )

        # filter out suffixes that do not tokenize back to the same ids
        decoded = tokenizer.batch_decode(adv_suffix_ids, skip_special_tokens=True)
        encoded = tokenizer(
            decoded, add_special_tokens=False, return_tensors="pt", padding=True
        ).input_ids.to("cuda")
        same_ids = torch.all(encoded[:, : optim_ids.shape[0]] == adv_suffix_ids, dim=1)
        adv_suffix_ids = adv_suffix_ids[same_ids]

        if adv_suffix_ids.shape[0] == 0:
            # Every candidate was filtered out; an empty batch cannot be
            # scored (argmin over zero losses fails), so keep the current
            # suffix and move on to the next step without logging.
            print(
                f"\n({i}/{args.num_steps}) num_cands=0, keeping current suffix"
                f"\nsuffix={adv_suffix}"
            )
            continue

        # update suffixes
        batch_dynamic_input_ids = dynamic_input_ids.repeat(adv_suffix_ids.shape[0], 1)
        batch_dynamic_input_ids[:, optim_slice] = adv_suffix_ids
        input_embeds = embed_layer(batch_dynamic_input_ids)

        # evaluate new suffixes
        if args.fixed_params:
            batch_prefix_cache = utils.batchify_kv_cache(
                prefix_cache, adv_suffix_ids.shape[0]
            )
        else:
            batch_prefix_cache = None

        with torch.no_grad():
            outputs = model(
                inputs_embeds=input_embeds, past_key_values=batch_prefix_cache
            )

        logits = outputs.logits

        # compute loss
        loss_logits = logits[:, loss_slice].permute(
            0, 2, 1
        )  # move class dim after batch
        batch_target_ids = target_ids.unsqueeze(0).repeat(adv_suffix_ids.shape[0], 1)
        loss = nn.functional.cross_entropy(
            loss_logits, batch_target_ids, reduction="none"
        )
        # Average over target positions to get one loss per candidate.
        loss = loss.mean(dim=1)

        # ========== Update the suffix with the best candidate ========== #
        idx = loss.argmin()
        adv_suffix = tokenizer.decode(adv_suffix_ids[idx], skip_special_tokens=True)
        current_loss = loss[idx]
        num_cands = same_ids.sum().item()

        if getattr(args, "disable_eval", False):
            # --disable_eval: skip the model rollout and scenario check; the
            # step is still logged, with "passed" recorded as null.
            result = None
        else:
            result = evaluate_suffix(
                model_fn,
                scenario,
                justask_templates,
                adv_suffix,
                args.use_system_instructions,
            )
        print(
            f"\n({i}/{args.num_steps}) num_cands={num_cands} loss={current_loss} passed={result}"
            f"\nsuffix={adv_suffix}"
        )

        with open(log_file, "a") as f:
            f.write(
                json.dumps(
                    {
                        "loss": current_loss.item(),
                        "passed": result,
                        "suffix": adv_suffix,
                    }
                )
                + "\n"
            )

        del batch_prefix_cache, same_ids


if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the attack.
    # NOTE: `args` is bound at module scope here, and evaluate_suffix looks up
    # a module-level `args` for the system message, so keep this name bound
    # rather than inlining the call as `main(parse_args())`.
    args = parse_args()
    main(args)
