import os
import gc
import time
import torch
import argparse

from prj_rag import common, constants
from prj_rag.our_gcg import gcg_attack
from prj_rag.opt_utils import load_model_and_tokenizer
from prj_rag.prompt_utils import get_prompt_managers, generate
from poisonedrag.src.utils import load_beir_datasets, load_models
from prj_rag.retrieve_context import (
    generate_query_sets,
    get_train_test_context_splits,
)


def gen_all(
    gen_model,
    gen_tokenizer,
    train_test: str,
    llm_prompt_managers: dict,
    adv_control_prefix: str,
    adv_control_suffix: str,
    exp_pth: str,
    optimize_prefix: bool = True,
):
    """Generate one completion per prompt manager and save them all to YAML.

    Args:
        gen_model: Generator model; its ``generation_config`` attribute is
            read and its ``max_new_tokens`` is set to 1024.
        gen_tokenizer: Tokenizer used to decode the generated ids.
        train_test: Split label; used in the log line and in the output
            file name (``outputs_{train_test}.yaml``).
        llm_prompt_managers: Mapping of query id -> prompt manager. Each
            manager must expose ``get_input_ids``.
        adv_control_prefix: Adversarial prefix forwarded to ``get_input_ids``.
        adv_control_suffix: Adversarial suffix forwarded to ``get_input_ids``.
        exp_pth: Directory passed to ``common.save_dict_to_yaml``.
        optimize_prefix: Forwarded positionally to ``get_input_ids``.

    Returns:
        None. The completions are printed and written to disk.
    """
    print(f"Generating outputs for {train_test} set.")

    # The attribute is set on the object returned by the model, not on a
    # copy made here, so the new token limit stays on that config object.
    generation_cfg = gen_model.generation_config
    generation_cfg.max_new_tokens = 1024

    completions = {}
    for query_id, prompt_manager in llm_prompt_managers.items():
        # Build the prompt ids without the target and with the generation
        # prompt appended (as requested through the two keyword flags).
        prompt_ids = prompt_manager.get_input_ids(
            adv_control_prefix,
            adv_control_suffix,
            optimize_prefix,
            no_target=True,
            add_gen_prompt=True,
        )

        # The slice handed to `generate` spans the whole prompt.
        generated_ids = generate(
            model=gen_model,
            tokenizer=gen_tokenizer,
            input_ids=prompt_ids,
            assistant_role_slice=slice(0, prompt_ids.shape[0]),
            gen_config=generation_cfg,
        )
        decoded = gen_tokenizer.decode(generated_ids, skip_special_tokens=True)
        completion = decoded.strip()

        print(f"Prompt {query_id}\nAnswer: {completion}\n")
        print("-" * 50)

        completions[query_id] = completion

        # Drop the per-prompt tensors before asking Python and CUDA to
        # release memory, so nothing from this iteration is kept alive.
        del prompt_ids, generated_ids
        gc.collect()
        torch.cuda.empty_cache()

    common.save_dict_to_yaml(completions, exp_pth, f"outputs_{train_test}.yaml")


def get_context_splits(args: dict, exp_pth: str):
    """Load the train/test context splits, or rebuild and save them.

    When ``args["existing_ret_atk"]`` is the empty string, the splits are
    read from ``context_splits.yaml`` inside ``exp_pth``. Otherwise they are
    regenerated from the existing retriever-attack YAML and the generator
    attack YAML, saved to ``context_splits.yaml`` in ``exp_pth`` and returned.

    Args:
        args: Run configuration. Always reads ``existing_ret_atk``; the
            regeneration path additionally reads ``dataset``,
            ``dataset_split``, ``ret_trigger``, ``ret_is_natural``,
            ``ret_clean_queries``, ``ret_test_queries``, ``seed``,
            ``ret_model``, ``existing_gen_atk``, ``gen_optimize_prefix``,
            ``ret_score_fn``, ``ret_top_k``, ``gen_train_size``,
            ``gen_test_size`` and ``device``.
        exp_pth: Experiment directory holding (or receiving) the YAML files.

    Returns:
        A dict with the keys ``train_context_prefixes``,
        ``train_context_suffixes``, ``train_bdr_positions``,
        ``gen_train_queries``, ``test_context_prefixes``,
        ``test_context_suffixes``, ``test_bdr_positions`` and
        ``gen_test_queries``. NOTE(review): on the load path the dict is
        whatever the YAML file contains; it is presumably the same schema.

    Raises:
        FileNotFoundError: If ``existing_ret_atk`` is empty and
            ``context_splits.yaml`` does not exist in ``exp_pth``.
    """
    if args["existing_ret_atk"] == "":
        # Load the prompt contexts generated by the attack on the retriever
        contexts_file = os.path.join(exp_pth, "context_splits.yaml")
        # Explicit raise instead of `assert`: assertions are removed when
        # Python runs with -O, which would silently skip this check.
        if not os.path.exists(contexts_file):
            raise FileNotFoundError(f"Contexts file not found: {contexts_file}")

        return common.load_dict_from_yaml(contexts_file)

    # If we already have a ret_atk yaml file, we have to generate the context splits
    # NOTE: Ensure that you are using the same exact config parameters used to
    # generate the ret_atk yaml file
    corpus, queries, _ = load_beir_datasets(args["dataset"], args["dataset_split"])

    # Only the last of the four returned query sets is used below.
    _, _, _, ret_qs_tst_bdr = generate_query_sets(
        query_set=queries,
        bdr_trigger=args["ret_trigger"],
        is_natural=args["ret_is_natural"],
        n_clean_queries=args["ret_clean_queries"],
        n_test_queries=args["ret_test_queries"],
        seed=args["seed"],
    )

    # The query-side model returned first is not needed for the context splits.
    _, ret_model_context, ret_tokenizer, ret_get_enc = load_models(
        args["ret_model"], cache_dir=constants.hf_dir
    )
    ret_model_context = ret_model_context.eval()

    ret_atk = common.load_dict_from_yaml(args["existing_ret_atk"])
    ret_atk_passage = ret_atk["ret_atk_passage"]
    ret_atk_payload = ret_atk["ret_atk_payload"]

    # Use the explicitly supplied generator attack file when given, otherwise
    # fall back to the one stored in the experiment directory.
    if args["existing_gen_atk"] == "":
        gen_atk_file = os.path.join(exp_pth, "gen_atk.yaml")
    else:
        gen_atk_file = args["existing_gen_atk"]
    gen_atk = common.load_dict_from_yaml(gen_atk_file)
    adv_control_prefix = gen_atk["adv_control_prefix"]
    adv_control_suffix = gen_atk["adv_control_suffix"]

    # NOTE(review): the suffix is selected when the prefix is the optimized
    # part and vice versa -- confirm this matches what
    # `get_train_test_context_splits` expects for `gen_str`.
    gen_str = adv_control_suffix if args["gen_optimize_prefix"] else adv_control_prefix

    (
        train_context_prefixes,
        train_context_suffixes,
        train_bdr_positions,
        gen_train_queries,
        test_context_prefixes,
        test_context_suffixes,
        test_bdr_positions,
        gen_test_queries,
    ) = get_train_test_context_splits(
        retriever_name=args["ret_model"],
        dataset=args["dataset"],
        model=ret_model_context,
        tokenizer=ret_tokenizer,
        get_enc=ret_get_enc,
        true_corpus=corpus,
        queries_dict=ret_qs_tst_bdr,
        adv_passage=ret_atk_passage[0],
        adv_payload=ret_atk_payload[0],
        score_function=args["ret_score_fn"],
        bdr_trigger=args["ret_trigger"],
        top_k=args["ret_top_k"],
        gen_train_size=args["gen_train_size"],
        gen_test_size=args["gen_test_size"],
        device=args["device"],
        activate_bdr=True,
        seed=args["seed"],
        gen_str=gen_str,
        gen_prefix=args["gen_optimize_prefix"],
    )
    ret_dict = {
        "train_context_prefixes": train_context_prefixes,
        "train_context_suffixes": train_context_suffixes,
        "train_bdr_positions": train_bdr_positions,
        "gen_train_queries": gen_train_queries,
        "test_context_prefixes": test_context_prefixes,
        "test_context_suffixes": test_context_suffixes,
        "test_bdr_positions": test_bdr_positions,
        "gen_test_queries": gen_test_queries,
    }

    # Persist the splits so later runs can take the load path above.
    common.save_dict_to_yaml(ret_dict, exp_pth, "context_splits.yaml")

    return ret_dict


def gen_outputs(args: dict):
    """Generate and save generator outputs for the train and test prompts.

    Loads the generator model, the context splits and (unless ``noatk`` is
    set) the adversarial prefix/suffix from a generator-attack YAML file,
    builds the prompt managers for both splits, saves the prompts to
    ``llm_prompts.yaml`` and finally runs ``gen_all`` on each split.

    Args:
        args: Run configuration. Reads ``noatk`` (optional, default False),
            ``gen_model``, ``device``, ``gen_target_string`` and
            ``email_format``; unless ``noatk`` is set it also reads
            ``existing_gen_atk``, ``gen_adv_prefix_tokens``,
            ``gen_adv_suffix_tokens`` and ``gen_adv_command``. The whole
            dict is also forwarded to ``common.get_exp_dir`` and
            ``get_context_splits``.

    Returns:
        None. Results are printed and written under the experiment path.
    """
    print(f"Received config:\n{args}")
    start_time = time.time()

    noatk = args.get("noatk", False)

    # Ascertain the experiment path exists
    exp_pth, _ = common.get_exp_dir(args, constants.res_dir)
    print(f"Experiment path: {exp_pth}")

    # Load the generator model
    gen_model, gen_tokenizer = load_model_and_tokenizer(
        model_path=common.generator_paths[args["gen_model"]],
        low_cpu_mem_usage=True,
        use_cache=False,
        device=args["device"],
        cache_dir=constants.hf_dir,
    )
    # Only generation happens here, so parameter gradients are switched off.
    gen_model = gen_model.requires_grad_(False)

    # Define the target string -- objective of the optimization
    target_string = args["gen_target_string"]
    print("Target string:", target_string)

    # Load the context splits
    context_splits = get_context_splits(args, exp_pth)
    train_context_prefixes = context_splits["train_context_prefixes"]
    train_context_suffixes = context_splits["train_context_suffixes"]
    train_bdr_positions = context_splits["train_bdr_positions"]
    gen_train_queries = context_splits["gen_train_queries"]
    test_context_prefixes = context_splits["test_context_prefixes"]
    test_context_suffixes = context_splits["test_context_suffixes"]
    test_bdr_positions = context_splits["test_bdr_positions"]
    gen_test_queries = context_splits["gen_test_queries"]

    if noatk:
        print("Setting the attack parameters to null values.")
        max_control_prefix_tokens = 0
        max_control_suffix_tokens = 0
        adv_control_prefix = ""
        adv_control_suffix = ""
        adv_command = ""
    else:
        # Prefer an explicitly supplied attack file; otherwise use the one
        # stored in the experiment directory.
        if args["existing_gen_atk"] != "":
            gen_atk_file = args["existing_gen_atk"]
        else:
            gen_atk_file = os.path.join(exp_pth, "gen_atk.yaml")
        gen_atk = common.load_dict_from_yaml(gen_atk_file)
        adv_control_prefix = gen_atk["adv_control_prefix"]
        adv_control_suffix = gen_atk["adv_control_suffix"]
        max_control_prefix_tokens = args["gen_adv_prefix_tokens"]
        max_control_suffix_tokens = args["gen_adv_suffix_tokens"]
        adv_command = args["gen_adv_command"]

    # Settings shared by the train and test prompt managers. Both splits use
    # the local `adv_command`, so the no-attack mode also removes the
    # adversarial command from the test prompts.
    shared_pm_kwargs = dict(
        gen_tokenizer=gen_tokenizer,
        adv_control_prefix=adv_control_prefix,
        adv_command=adv_command,
        adv_control_suffix=adv_control_suffix,
        separator_str="</s>",
        generator_output=target_string,
        max_control_prefix_tokens=max_control_prefix_tokens,
        max_control_suffix_tokens=max_control_suffix_tokens,
        email_format=args["email_format"],
    )

    # Generate the prompt managers
    train_llm_prompt_managers, train_llm_prompts_before_opt = get_prompt_managers(
        queries=gen_train_queries,
        context_prefixes=train_context_prefixes,
        context_suffixes=train_context_suffixes,
        bdr_positions=train_bdr_positions,
        **shared_pm_kwargs,
    )
    test_llm_prompt_managers, test_llm_prompts_before_opt = get_prompt_managers(
        queries=gen_test_queries,
        context_prefixes=test_context_prefixes,
        context_suffixes=test_context_suffixes,
        bdr_positions=test_bdr_positions,
        **shared_pm_kwargs,
    )

    common.save_dict_to_yaml(
        {
            "train_llm_prompts": train_llm_prompts_before_opt,
            "test_llm_prompts": test_llm_prompts_before_opt,
        },
        exp_pth,
        "llm_prompts.yaml",
    )

    print("Adversarial command prefix: ", adv_control_prefix)

    # This timestamp covers model loading and prompt construction; no
    # optimization is run in this function.
    attack_end_time = time.time()
    print(f"Generator attack time: {attack_end_time - start_time:.2f} seconds.")

    # NOTE(review): `gen_all` is called with its default `optimize_prefix=True`
    # regardless of args["gen_optimize_prefix"] -- confirm this is intended.
    for split_name, prompt_managers in (
        ("train", train_llm_prompt_managers),
        ("test", test_llm_prompt_managers),
    ):
        gen_all(
            gen_model=gen_model,
            gen_tokenizer=gen_tokenizer,
            train_test=split_name,
            llm_prompt_managers=prompt_managers,
            adv_control_prefix=adv_control_prefix,
            adv_control_suffix=adv_control_suffix,
            exp_pth=exp_pth,
        )

    print(f"Outputs generation time: {time.time() - attack_end_time:.2f} seconds.")


if __name__ == "__main__":
    # Command-line entry point: read the YAML config, overlay the CLI
    # options on top of it and run the output generation.
    cli = argparse.ArgumentParser(
        description="Run the local attack on the generator."
    )
    cli.add_argument("--config", help="Config file path", type=str, required=True)
    cli.add_argument("--device", help="Device name", type=str, default="cuda:0")
    cli.add_argument("--seed", help="Random seed", type=int, default=21)

    # Optional paths to attack files produced earlier; empty means "not given".
    for flag, description in (
        ("--existing_ret_atk", "Use existing retriever attack"),
        ("--existing_gen_atk", "Use existing generator attack"),
    ):
        cli.add_argument(flag, help=description, type=str, default="")

    cli.add_argument("--noatk", help="Skip the attack", action="store_true")
    cli_args = cli.parse_args()

    # CLI values take precedence over same-named keys in the config file.
    run_config = common.load_dict_from_yaml(cli_args.config)
    run_config.update(vars(cli_args))

    gen_outputs(run_config)
