import os
import gc
import time
import torch
import argparse

from prj_rag import common, constants
from prj_rag.our_gcg import gcg_attack
from prj_rag.prompt_utils import get_prompt_managers
from prj_rag.opt_utils import load_model_and_tokenizer
from poisonedrag.src.utils import load_beir_datasets, load_models
from prj_rag.retrieve_context import (
    generate_query_sets,
    get_train_test_context_splits,
)


def get_context_splits(args: dict, exp_pth: str):
    """Return the train/test context splits used to build generator prompts.

    Two modes, selected by ``args["existing_ret_atk"]``:

    * Empty string: read ``context_splits.yaml`` from ``exp_pth`` and return
      its contents unchanged.
    * Anything else: treat the value as the path of a saved retriever-attack
      YAML file, regenerate the query sets and context splits from it, and
      write ``query_sets.yaml``, ``ret_atk.yaml`` and ``context_splits.yaml``
      into ``exp_pth``.

    Args:
        args: Experiment configuration. Besides ``existing_ret_atk``, the
            regeneration path reads ``dataset``, ``dataset_split``,
            ``ret_trigger``, ``ret_is_natural``, ``ret_clean_queries``,
            ``ret_test_queries``, ``seed``, ``ret_model``, ``ret_score_fn``,
            ``ret_top_k``, ``gen_train_size``, ``gen_test_size`` and
            ``device``.
        exp_pth: Experiment directory that the YAML files are read from and
            written to.

    Returns:
        In the regeneration path, a dict with the keys
        ``train_context_prefixes``, ``train_context_suffixes``,
        ``train_bdr_positions``, ``gen_train_queries`` and the matching four
        ``test_*`` / ``gen_test_queries`` keys. In the load path, whatever
        ``context_splits.yaml`` contains.

    Raises:
        AssertionError: If no existing attack file is given and
            ``context_splits.yaml`` is missing from ``exp_pth``.
    """
    existing_ret_atk = args["existing_ret_atk"]

    if existing_ret_atk == "":
        # The retriever attack already ran for this experiment: reuse the
        # prompt contexts it stored.
        contexts_file = os.path.join(exp_pth, "context_splits.yaml")
        assert os.path.exists(
            contexts_file
        ), f"Contexts file not found: {contexts_file}"
        return common.load_dict_from_yaml(contexts_file)

    # Only a retriever-attack YAML file is available, so the context splits
    # have to be rebuilt from it.
    # NOTE: the config must match exactly the one that was used to produce
    # that retriever-attack file.
    corpus, queries, _ = load_beir_datasets(args["dataset"], args["dataset_split"])

    clean_qs, bdr_qs, test_clean_qs, test_bdr_qs = generate_query_sets(
        query_set=queries,
        bdr_trigger=args["ret_trigger"],
        is_natural=args["ret_is_natural"],
        n_clean_queries=args["ret_clean_queries"],
        n_test_queries=args["ret_test_queries"],
        seed=args["seed"],
    )
    query_sets = dict(
        ret_qs_cln=clean_qs,
        ret_qs_bdr=bdr_qs,
        ret_qs_tst_cln=test_clean_qs,
        ret_qs_tst_bdr=test_bdr_qs,
    )
    common.save_dict_to_yaml(query_sets, exp_pth, "query_sets.yaml")

    # Retriever encoders are switched to eval mode before being used.
    query_encoder, context_encoder, ret_tokenizer, ret_get_enc = load_models(
        args["ret_model"], cache_dir=constants.hf_dir
    )
    query_encoder.eval()
    context_encoder = context_encoder.eval()

    # Copy the adversarial passage/payload into this experiment's directory.
    stored_atk = common.load_dict_from_yaml(existing_ret_atk)
    passages = stored_atk["ret_atk_passage"]
    payloads = stored_atk["ret_atk_payload"]
    common.save_dict_to_yaml(
        dict(ret_atk_passage=passages, ret_atk_payload=payloads),
        exp_pth,
        "ret_atk.yaml",
    )

    # Only the first stored passage/payload is used to build the splits.
    (
        trn_prefixes,
        trn_suffixes,
        trn_positions,
        trn_queries,
        tst_prefixes,
        tst_suffixes,
        tst_positions,
        tst_queries,
    ) = get_train_test_context_splits(
        retriever_name=args["ret_model"],
        dataset=args["dataset"],
        model=context_encoder,
        tokenizer=ret_tokenizer,
        get_enc=ret_get_enc,
        true_corpus=corpus,
        queries_dict=test_bdr_qs,
        adv_passage=passages[0],
        adv_payload=payloads[0],
        score_function=args["ret_score_fn"],
        bdr_trigger=args["ret_trigger"],
        top_k=args["ret_top_k"],
        gen_train_size=args["gen_train_size"],
        gen_test_size=args["gen_test_size"],
        device=args["device"],
        activate_bdr=True,
        seed=args["seed"],
        gen_str="",
    )

    context_splits = dict(
        train_context_prefixes=trn_prefixes,
        train_context_suffixes=trn_suffixes,
        train_bdr_positions=trn_positions,
        gen_train_queries=trn_queries,
        test_context_prefixes=tst_prefixes,
        test_context_suffixes=tst_suffixes,
        test_bdr_positions=tst_positions,
        gen_test_queries=tst_queries,
    )
    common.save_dict_to_yaml(context_splits, exp_pth, "context_splits.yaml")

    return context_splits


def get_prompts_after_optimization(
    tokenizer, prompt_managers: dict, adv_control_prefix: str, adv_control_suffix: str
):
    """Decode each prompt manager's prompt text for the given control strings.

    For every manager, the input ids produced for ``adv_control_prefix`` and
    ``adv_control_suffix`` are cut at ``_answer_tag_slice.stop`` and decoded
    with ``tokenizer``.

    Args:
        tokenizer: Object exposing ``decode(ids)``.
        prompt_managers: Mapping of query id to a prompt manager exposing
            ``get_input_ids(prefix, suffix)`` and ``_answer_tag_slice``.
        adv_control_prefix: Control prefix passed to every manager.
        adv_control_suffix: Control suffix passed to every manager.

    Returns:
        dict mapping each query id to its decoded prompt string.
    """

    def _render(manager):
        token_ids = manager.get_input_ids(adv_control_prefix, adv_control_suffix)
        cutoff = manager._answer_tag_slice.stop
        return tokenizer.decode(token_ids[:cutoff])

    return {
        query_id: _render(manager) for query_id, manager in prompt_managers.items()
    }


def atk_generator(args: dict):
    """Run (or skip) the GCG optimization attack against the generator model.

    Steps: resolve the experiment directory, load the generator, obtain the
    train/test context splits, build the prompt managers for both splits,
    save the prompts as they look before optimization and, unless
    ``args["noatk"]`` is truthy, run ``gcg_attack`` on the train prompt
    managers and save its results.

    Files written to the experiment directory:
        * ``llm_prompts_before_opt.yaml`` -- always.
        * ``gen_atk.yaml`` and ``gen_atk_losses.yaml`` -- attack mode only.

    Args:
        args: Merged experiment configuration. Always read here:
            ``gen_model``, ``device``, ``gen_target_string`` and
            ``email_format``; ``noatk`` is optional and defaults to False.
            Attack mode additionally reads ``gen_adv_prefix_tokens``,
            ``gen_adv_suffix_tokens``, ``gen_adv_command_prefix``,
            ``gen_adv_command_suffix``, ``gen_adv_command``,
            ``gen_steps_gcg``, ``gen_batch_size_gcg``, ``gen_topk_gcg``,
            ``gen_num_coordinates``, ``gen_allow_non_ascii``,
            ``gen_optimize_prefix``, ``gen_early_termination``,
            ``gen_points_per_device`` and ``optimize_gpu_memory``.
            ``args`` is also handed as a whole to ``common.get_exp_dir`` and
            ``get_context_splits``.
    """
    print(f"Received config:\n{args}")
    start_time = time.time()

    noatk = args.get("noatk", False)
    if noatk:
        print("Skipping the attack on the generator.")
    else:
        print("Running the attack on the generator.")

    # Ascertain the experiment path exists
    exp_pth, _ = common.get_exp_dir(args, constants.res_dir)
    print(f"Experiment path: {exp_pth}")

    # Load the generator model; its parameters are frozen because only the
    # control strings are optimized.
    gen_model, gen_tokenizer = load_model_and_tokenizer(
        model_path=common.generator_paths[args["gen_model"]],
        low_cpu_mem_usage=True,
        use_cache=False,
        device=args["device"],
        cache_dir=constants.hf_dir,
    )
    gen_model = gen_model.requires_grad_(False)

    # Define the target string -- objective of the optimization
    target_string = args["gen_target_string"]
    print("Target string:", target_string)

    # Load the context splits
    context_splits = get_context_splits(args, exp_pth)

    if noatk:
        print("Setting the attack parameters to null values.")
        max_control_prefix_tokens = 0
        max_control_suffix_tokens = 0
        adv_control_prefix = ""
        adv_control_suffix = ""
        adv_command = ""
    else:
        max_control_prefix_tokens = args["gen_adv_prefix_tokens"]
        max_control_suffix_tokens = args["gen_adv_suffix_tokens"]
        adv_control_prefix = args["gen_adv_command_prefix"]
        adv_control_suffix = args["gen_adv_command_suffix"]
        adv_command = args["gen_adv_command"]

    # Everything except the split-specific data is shared between the train
    # and test prompt managers. Building the kwargs once from the resolved
    # attack parameters guarantees both splits see the same control strings
    # and command; previously the test split read the raw config values and
    # therefore still carried the adversarial command in ``noatk`` mode.
    shared_prompt_kwargs = dict(
        gen_tokenizer=gen_tokenizer,
        adv_control_prefix=adv_control_prefix,
        adv_command=adv_command,
        adv_control_suffix=adv_control_suffix,
        separator_str="</s>",
        generator_output=target_string,
        max_control_prefix_tokens=max_control_prefix_tokens,
        max_control_suffix_tokens=max_control_suffix_tokens,
        email_format=args["email_format"],
    )

    # Generate the prompt managers
    train_llm_prompt_managers, train_llm_prompts_before_opt = get_prompt_managers(
        queries=context_splits["gen_train_queries"],
        context_prefixes=context_splits["train_context_prefixes"],
        context_suffixes=context_splits["train_context_suffixes"],
        bdr_positions=context_splits["train_bdr_positions"],
        **shared_prompt_kwargs,
    )
    # The test managers themselves are not used further in this function;
    # only the rendered test prompts are saved.
    _, test_llm_prompts_before_opt = get_prompt_managers(
        queries=context_splits["gen_test_queries"],
        context_prefixes=context_splits["test_context_prefixes"],
        context_suffixes=context_splits["test_context_suffixes"],
        bdr_positions=context_splits["test_bdr_positions"],
        **shared_prompt_kwargs,
    )
    common.save_dict_to_yaml(
        {
            "train_llm_prompts_before_opt": train_llm_prompts_before_opt,
            "test_llm_prompts_before_opt": test_llm_prompts_before_opt,
        },
        exp_pth,
        "llm_prompts_before_opt.yaml",
    )

    # Report the prefix actually in use (empty when the attack is skipped),
    # which also avoids requiring the config key in ``noatk`` mode.
    print("Initial command prefix: ", adv_control_prefix)

    if not noatk:
        # Perform the optimization attack
        adv_control_prefix, adv_control_suffix, iter_losses = gcg_attack(
            model=gen_model,
            tokenizer=gen_tokenizer,
            num_steps_gcg=args["gen_steps_gcg"],
            batch_size_gcg=args["gen_batch_size_gcg"],
            topk_gcg=args["gen_topk_gcg"],
            prompt_managers=train_llm_prompt_managers,
            num_coordinates=args["gen_num_coordinates"],
            adv_control_prefix=adv_control_prefix,
            adv_control_suffix=adv_control_suffix,
            device=args["device"],
            allow_non_ascii=args["gen_allow_non_ascii"],
            optimize_prefix=args["gen_optimize_prefix"],
            early_termination=args["gen_early_termination"],
            points_per_device=args["gen_points_per_device"],
            optimize_gpu_memory=args["optimize_gpu_memory"],
            ret_losses=True,
        )
        # Save the attack results
        common.save_dict_to_yaml(
            {
                "adv_control_prefix": adv_control_prefix,
                "adv_control_suffix": adv_control_suffix,
            },
            exp_pth,
            "gen_atk.yaml",
        )
        common.save_dict_to_yaml(
            {"iter_losses": iter_losses}, exp_pth, "gen_atk_losses.yaml"
        )

    # Measured from function entry, so it also covers model loading and
    # prompt construction, not just the optimization itself.
    attack_end_time = time.time()
    print(f"Generator attack time: {attack_end_time - start_time:.2f} seconds.")


if __name__ == "__main__":
    # Command-line entry point: load the YAML config, overlay the CLI values
    # on top of it, and launch the generator attack.
    cli = argparse.ArgumentParser(description="Run the local attack on the generator.")
    cli.add_argument("--config", type=str, required=True, help="Config file path")
    cli.add_argument("--device", type=str, default="cuda:0", help="Device name")
    cli.add_argument("--seed", type=int, default=21, help="Random seed")
    cli.add_argument(
        "--existing_ret_atk",
        type=str,
        default="",
        help="Use existing retriever attack",
    )
    cli.add_argument("--noatk", action="store_true", help="Skip the attack")
    cli_values = vars(cli.parse_args())

    # Every CLI value (including defaults) replaces a same-named config key.
    run_config = common.load_dict_from_yaml(cli_values["config"])
    run_config.update(cli_values)

    atk_generator(run_config)
