import time
import argparse

from prj_rag import common, constants
from prj_rag.atk_opt import hotflip_multiquery
from poisonedrag.src.utils import load_beir_datasets, load_models
from prj_rag.retrieve_context import (
    generate_query_sets,
    compute_similarity_scores,
    get_train_test_context_splits,
)


def atk_retriever(args: dict):
    """Run (or skip) the HotFlip attack on the retriever and build context splits.

    Pipeline:
      1. Load the BEIR dataset and build the clean/backdoor train and test
         query sets with ``generate_query_sets``.
      2. Load the retriever's query/context encoders (both put in eval mode).
      3. Unless ``args["noatk"]`` is truthy, optimise an adversarial passage
         and payload with ``hotflip_multiquery``; otherwise use empty strings.
      4. Call ``compute_similarity_scores`` for the backdoored test queries
         (its return value is not used here; presumably it caches scores as a
         side effect -- TODO confirm) and derive the train/test context splits
         with ``get_train_test_context_splits``.

    Intermediate artefacts are saved as YAML into the experiment directory
    returned by ``common.get_exp_dir``: ``query_sets.yaml``, ``ret_atk.yaml``
    and ``context_splits.yaml``.

    Args:
        args: Experiment configuration. The whole dict is passed to
            ``common.get_exp_dir``; keys read directly are ``seed``,
            ``noatk`` (optional, defaults to False), ``dataset``,
            ``dataset_split``, ``ret_trigger``, ``ret_is_natural``,
            ``ret_clean_queries``, ``ret_test_queries``, ``ret_model``,
            ``ret_score_fn``, ``ret_top_k``, ``gen_train_size``,
            ``gen_test_size``, ``device`` and, only when attacking,
            ``gen_adv_command_prefix``, ``gen_adv_command``,
            ``gen_adv_command_suffix``, ``ret_adv_passage_tokens``,
            ``ret_hotflip_epochs``, ``ret_max_seq_length``,
            ``ret_hotflip_candidates``.
    """
    print(f"Received config:\n{args}")
    # perf_counter is monotonic, so the reported duration is not skewed by
    # system clock adjustments the way time.time() can be.
    start_time = time.perf_counter()

    exp_pth, _ = common.get_exp_dir(args, constants.res_dir)
    print(f"Experiment path: {exp_pth}")

    # Set the random seed
    common.setup_seeds(args["seed"])

    # Flag that determines if the attack should be skipped
    noatk = args.get("noatk", False)
    if noatk:
        print("Skipping the attack on the retriever.")
    else:
        print("Running the attack on the retriever.")

    # Load the data (the third return value, presumably qrels, is unused here)
    corpus, queries, _ = load_beir_datasets(args["dataset"], args["dataset_split"])

    # Generate the query sets: clean/backdoored train and test splits
    ret_qs_cln, ret_qs_bdr, ret_qs_tst_cln, ret_qs_tst_bdr = generate_query_sets(
        query_set=queries,
        bdr_trigger=args["ret_trigger"],
        is_natural=args["ret_is_natural"],
        n_clean_queries=args["ret_clean_queries"],
        n_test_queries=args["ret_test_queries"],
        seed=args["seed"],
    )
    common.save_dict_to_yaml(
        {
            "ret_qs_cln": ret_qs_cln,
            "ret_qs_bdr": ret_qs_bdr,
            "ret_qs_tst_cln": ret_qs_tst_cln,
            "ret_qs_tst_bdr": ret_qs_tst_bdr,
        },
        exp_pth,
        "query_sets.yaml",
    )

    # Load the retriever model (separate query and context encoders)
    ret_model_query, ret_model_context, ret_tokenizer, ret_get_enc = load_models(
        args["ret_model"], cache_dir=constants.hf_dir
    )
    ret_model_query = ret_model_query.eval()
    ret_model_context = ret_model_context.eval()

    # Attack the retriever model
    if noatk:
        # Single empty passage/payload so the indexing below works unchanged.
        ret_atk_passage = [""]
        ret_atk_payload = [""]
    else:
        adv_b = (
            args["gen_adv_command_prefix"]
            + args["gen_adv_command"]
            + args["gen_adv_command_suffix"]
        )
        ret_atk_passage, ret_atk_payload = hotflip_multiquery(
            bdr_queries=ret_qs_bdr,
            cln_queries=ret_qs_cln,
            tokenizer=ret_tokenizer,
            query_enc_model=ret_model_query,
            context_enc_model=ret_model_context,
            get_encoding=ret_get_enc,
            adv_command=adv_b,
            num_adv_passage_tokens=args["ret_adv_passage_tokens"],
            num_epochs=args["ret_hotflip_epochs"],
            pad_to_max_length=True,
            max_seq_length=args["ret_max_seq_length"],
            num_cand=args["ret_hotflip_candidates"],
            adv_per_query=1,
            device=args["device"],
            score_function=args["ret_score_fn"],
            random_token_selection=False,
        )
    common.save_dict_to_yaml(
        {"ret_atk_passage": ret_atk_passage, "ret_atk_payload": ret_atk_payload},
        exp_pth,
        "ret_atk.yaml",
    )

    # Compute the similarity scores for the relevant queries
    compute_similarity_scores(
        retriever_name=args["ret_model"],
        dataset=args["dataset"],
        model=ret_model_context,
        tokenizer=ret_tokenizer,
        get_enc=ret_get_enc,
        queries_dict=ret_qs_tst_bdr,
        score_function=args["ret_score_fn"],
        bdr_trigger=args["ret_trigger"],
        device=args["device"],
    )

    # Perform the retrieval for each query, and check if poisoned passage is
    # retrieved. Only the first adversarial passage/payload is used.
    (
        train_context_prefixes,
        train_context_suffixes,
        train_bdr_positions,
        gen_train_queries,
        test_context_prefixes,
        test_context_suffixes,
        test_bdr_positions,
        gen_test_queries,
    ) = get_train_test_context_splits(
        retriever_name=args["ret_model"],
        dataset=args["dataset"],
        model=ret_model_context,
        tokenizer=ret_tokenizer,
        get_enc=ret_get_enc,
        true_corpus=corpus,
        queries_dict=ret_qs_tst_bdr,
        adv_passage=ret_atk_passage[0],
        adv_payload=ret_atk_payload[0],
        score_function=args["ret_score_fn"],
        bdr_trigger=args["ret_trigger"],
        top_k=args["ret_top_k"],
        gen_train_size=args["gen_train_size"],
        gen_test_size=args["gen_test_size"],
        device=args["device"],
        seed=args["seed"],
        activate_bdr=not noatk,
    )

    common.save_dict_to_yaml(
        {
            "train_context_prefixes": train_context_prefixes,
            "train_context_suffixes": train_context_suffixes,
            "train_bdr_positions": train_bdr_positions,
            "gen_train_queries": gen_train_queries,
            "test_context_prefixes": test_context_prefixes,
            "test_context_suffixes": test_context_suffixes,
            "test_bdr_positions": test_bdr_positions,
            "gen_test_queries": gen_test_queries,
        },
        exp_pth,
        "context_splits.yaml",
    )

    print("Retriever attack time: ", time.perf_counter() - start_time)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the local attack on the retriever and generate prompts."
    )
    parser.add_argument("--config", help="Config file path", type=str, required=True)
    parser.add_argument("--device", help="Device name", type=str, default="cuda:0")
    parser.add_argument("--seed", help="Random seed", type=int, default=42)
    parser.add_argument("--noatk", help="Skip the attack", action="store_true")
    args = parser.parse_args()

    config = common.load_dict_from_yaml(args.config)
    config.update(vars(args))

    atk_retriever(config)
