import json
import os
import torch
import random
import json
from prj_rag import constants
import numpy as np
import gc

# DEPRECATED
# Prompt template filled in by `wrap_prompt`, which substitutes the literal
# "[query]" and "[context]" placeholders with `str.replace`.
# NOTE: the trailing backslash continues the string literal onto the next
# line, so that line's leading spaces are part of the prompt text.
MULTIPLE_PROMPT = "Below is a query from a user and some relevant contexts. Answer the question given the information in those contexts. Your answer should be short and concise. \
    \n\nQuery: [query] \n\nContexts: [context] \n\nAnswer:"

# MULTIPLE_PROMPT = 'You are a helpful assistant, below is a query from a user and some relevant contexts. \
# Answer the question given the information in those contexts. Your answer should be short and concise. \
# If you cannot find the answer to the question, just say "I don\'t know". \
# \n\nQuery: [question] \n\nContexts: [context] \n\nAnswer:'


def set_seed(rs=42):
    """Build an independent random generator seeded with ``rs``.

    A dedicated ``random.Random`` instance is created and returned; the
    module-level ``random`` functions are not reseeded by this call.

    Args:
        rs: Seed handed to the new generator.

    Returns:
        The freshly seeded ``random.Random`` instance.
    """
    return random.Random(rs)


# DEPRECATED
def wrap_prompt(query, context, prompt_id=1) -> str:
    """Fill ``MULTIPLE_PROMPT`` with a query and its retrieved context.

    Args:
        query: Text substituted for the ``[query]`` placeholder.
        context: Text substituted for the ``[context]`` placeholder. When
            ``prompt_id == 4`` it must be a ``list`` of strings, which is
            joined with newlines before substitution.
        prompt_id: Only the value ``4`` changes behaviour (list context).

    Returns:
        The completed prompt string.
    """
    if prompt_id == 4:
        assert type(context) is list
        context = "\n".join(context)

    # The query is substituted first, then the context, as before.
    with_query = MULTIPLE_PROMPT.replace("[query]", query)
    return with_query.replace("[context]", context)


def gen_splits(corpus, bdr_pos, atk_payload):
    """Split the ranked documents into the text before and after the payload.

    Each document is rendered as two newlines followed by
    ``Doc#<n>: <text>``, where ``n`` is the 1-based position in ``corpus``.

    Args:
        corpus: Sequence of document strings in rank order.
        bdr_pos: Index in ``corpus`` of the adversarial passage. Any negative
            value means the passage is not present.
        atk_payload: Substring stripped from the adversarial passage before
            it is appended to the prefix.

    Returns:
        A ``(context_prefix, context_suffix)`` tuple. When ``bdr_pos >= 0``
        the prefix holds the documents ranked before the adversarial passage
        followed by that passage with ``atk_payload`` removed, and the suffix
        holds the documents ranked after it. Otherwise the prefix holds every
        document and the suffix is empty.
    """

    def _render(idx):
        return "\n\nDoc#" + str(idx + 1) + ": " + corpus[idx]

    if bdr_pos < 0:
        # No adversarial passage: never touch corpus[bdr_pos], so an empty
        # corpus yields empty strings instead of an IndexError.
        return "".join(_render(idx) for idx in range(len(corpus))), ""

    context_prefix = "".join(_render(idx) for idx in range(bdr_pos))
    atk_prefix = corpus[bdr_pos].replace(atk_payload, "")
    context_prefix += "\n\nDoc#" + str(bdr_pos + 1) + ": " + atk_prefix

    context_suffix = "".join(
        _render(idx) for idx in range(bdr_pos + 1, len(corpus))
    )

    return context_prefix, context_suffix


def gen_encoding(model, q, tokenizer, get_enc, device, max_tokens=-1):
    """Tokenize ``q``, move the tokens to ``device`` and encode them.

    Args:
        model: Encoder passed through to ``get_enc``.
        q: Text handed to ``tokenizer``.
        tokenizer: Callable invoked with ``return_tensors="pt"``,
            ``padding="max_length"`` and ``truncation=True``; it must return
            a mapping whose values expose ``.to(device)``.
        get_enc: Callable ``get_enc(model, tokens)`` producing the encoding.
        device: Target device for the tokenized tensors.
        max_tokens: Maximum sequence length passed as ``max_length``. A
            negative value (the default) leaves ``max_length`` unset so the
            tokenizer applies its own default.

    Returns:
        Whatever ``get_enc`` returns for the tokenized input.
    """
    tokenizer_kwargs = {
        "return_tensors": "pt",
        "padding": "max_length",
        "truncation": True,
    }
    if max_tokens >= 0:
        # Honour the requested length; this used to be hard-coded to 64
        # regardless of the value supplied by the caller.
        tokenizer_kwargs["max_length"] = max_tokens

    q_tk = tokenizer(q, **tokenizer_kwargs)
    q_tk = {name: tensor.to(device) for name, tensor in q_tk.items()}

    return get_enc(model, q_tk)


# DEPRECATED
def gen_llm_prompts(
    args,
    model,
    tokenizer,
    get_enc,
    true_corpus,
    queries_dict,
    adv_corpus,
    num_adv_passage_tokens=64,
    activate_bdr=True,
):
    """Build one LLM prompt per query from its top-k retrieved documents.

    Clean retrieval scores are loaded from the BEIR results JSON for
    ``args.eval_dataset`` / ``args.retriever_name``. When ``activate_bdr`` is
    true, the backdoor trigger is appended to queries that lack it and the
    adversarial passage is scored against the query and ranked together with
    the clean documents.

    Args:
        args: Namespace read for ``split``, ``eval_dataset``,
            ``retriever_name``, ``device``, ``bdr_trigger``, ``top_k`` and
            ``score_function``. ``args.orig_beir_results`` is set as a side
            effect.
        model: Retriever encoder; moved to ``args.device`` if it is on CPU.
        tokenizer: Tokenizer forwarded to ``gen_encoding``.
        get_enc: Encoding callable forwarded to ``gen_encoding``.
        true_corpus: Mapping ``doc_id -> {"text": ...}`` of clean documents.
        queries_dict: Mapping ``query_id -> query text``.
        adv_corpus: Adversarial passage text.
        num_adv_passage_tokens: ``max_tokens`` used when encoding the
            adversarial passage.
        activate_bdr: Whether to attach the trigger and inject the passage.

    Returns:
        List of prompt strings, one per entry of ``queries_dict``.

    Raises:
        NotImplementedError: If ``args.split`` is not ``"test"`` or, with the
            backdoor active, ``args.score_function`` is not ``"dot"``.
        AssertionError: If the BEIR results file does not exist.
    """
    if args.split == "test":
        beir_pth = os.path.join(constants.poisonedrag_res_dir, "beir_results")
        args.orig_beir_results = (
            f"{beir_pth}/{args.eval_dataset}-{args.retriever_name}.json"
        )
    else:
        raise NotImplementedError(f"Split {args.split} not implemented!")

    assert os.path.exists(
        args.orig_beir_results
    ), f"Failed to get beir_results from {args.orig_beir_results}!"
    print(f"Loading clean document scores from {args.orig_beir_results}.")
    with open(args.orig_beir_results, "r") as f:
        results = json.load(f)

    if model.device.type == "cpu":
        model.to(args.device)

    llm_prompts = []

    for qid, query in queries_dict.items():

        if activate_bdr and args.bdr_trigger not in query:
            print(f"Attaching backdoor trigger {args.bdr_trigger} to the query")
            query = query + " " + args.bdr_trigger

        topk_docs = list(results[qid].keys())[: args.top_k]
        topk_results = [
            {"score": results[qid][doc_idx], "context": true_corpus[doc_idx]["text"]}
            for doc_idx in topk_docs
        ]

        print(f"Query{qid}, Top docs: {list(results[qid].keys())}")

        if activate_bdr:
            if args.score_function != "dot":
                # NotImplementedError subclasses RuntimeError, which is what
                # the former bare `raise` produced.
                raise NotImplementedError(
                    f"Score function {args.score_function} not implemented!"
                )

            # The query encoding is only needed to score the adversarial
            # passage, so it is skipped entirely when the backdoor is off.
            q_enc = gen_encoding(model, query, tokenizer, get_enc, args.device)
            adv_enc = gen_encoding(
                model,
                adv_corpus,
                tokenizer,
                get_enc,
                args.device,
                num_adv_passage_tokens,
            )

            adv_score = torch.mm(adv_enc, q_enc.T).item()
            print(f"Retrival Score of Adversarial Document: {adv_score}")

            topk_results.append({"score": adv_score, "context": adv_corpus})

        topk_results = sorted(
            topk_results, key=lambda x: float(x["score"]), reverse=True
        )
        # Slicing tolerates queries with fewer than top_k candidates.
        topk_content = [entry["context"] for entry in topk_results[: args.top_k]]

        if activate_bdr:
            for corpus_idx, content in enumerate(topk_content):
                if adv_corpus in content:
                    print(
                        f"Position of Adversarial passage in top-{args.top_k} documents: {corpus_idx+1}"
                    )
                    break
            else:
                # Reached only when the loop never hit `break`, so a passage
                # found in the last slot is no longer reported as missing.
                print(f"Adversarial passage  NOT in top-{args.top_k} documents")

        query_prompt = wrap_prompt(query, topk_content, prompt_id=4)
        print("------------------")

        llm_prompts.append(query_prompt)

    return llm_prompts


# DEPRECATED
def get_full_context_splits(
    retriever_name: str,
    dataset: str,
    dataset_split: str,
    model,
    tokenizer,
    get_enc,
    true_corpus,
    queries_dict,
    adv_passage: str,
    adv_payload: str,
    score_function: str,
    bdr_trigger: str,
    top_k: int,
    device: str = "cuda",
    activate_bdr: bool = True,
):
    """Compute per-query context prefix/suffix around the adversarial payload.

    For every query the top-k clean documents are read from the BEIR results
    JSON. With ``activate_bdr`` the trigger is appended to queries lacking it
    and ``adv_passage`` is scored and ranked with the clean documents. The
    ranked documents are then split by ``gen_splits``.

    Args:
        retriever_name: Retriever name used in the results file name.
        dataset: Dataset name used in the results file name.
        dataset_split: Only ``"test"`` is supported.
        model: Retriever encoder; moved to ``device`` if it is on CPU.
        tokenizer: Tokenizer forwarded to ``gen_encoding``.
        get_enc: Encoding callable forwarded to ``gen_encoding``.
        true_corpus: Mapping ``doc_id -> {"text": ...}`` of clean documents.
        queries_dict: Mapping ``query_id -> query text``.
        adv_passage: Full adversarial passage text.
        adv_payload: Payload substring passed to ``gen_splits``.
        score_function: Only ``"dot"`` is supported.
        bdr_trigger: Backdoor trigger text.
        top_k: Number of documents kept per query.
        device: Device used for encoding.
        activate_bdr: Whether to attach the trigger and inject the passage.

    Returns:
        ``(context_prefixes, context_suffixes, bdr_positions)``: three lists
        aligned with the iteration order of ``queries_dict``. A position of
        ``-1`` means the adversarial passage is not in the top-k.

    Raises:
        NotImplementedError: For an unsupported ``dataset_split`` or, with the
            backdoor active, an unsupported ``score_function``.
        AssertionError: If the BEIR results file does not exist.
    """
    if dataset_split == "test":
        beir_pth = os.path.join(constants.poisonedrag_res_dir, "beir_results")
        orig_beir_results = f"{beir_pth}/{dataset}-{retriever_name}.json"
    else:
        raise NotImplementedError(f"Split {dataset_split} not implemented!")

    assert os.path.exists(
        orig_beir_results
    ), f"Failed to get beir_results from {orig_beir_results}!"
    print(f"Loading clean document scores from {orig_beir_results}.")
    with open(orig_beir_results, "r") as f:
        results = json.load(f)

    if model.device.type == "cpu":
        model.to(device)

    context_prefixes = []
    context_suffixes = []
    bdr_positions = []
    for qcounter, (qid, query) in enumerate(queries_dict.items()):

        if activate_bdr and bdr_trigger not in query:
            print(f"Attaching backdoor trigger {bdr_trigger} to the query")
            query = query + " " + bdr_trigger

        topk_docs = list(results[qid].keys())[:top_k]
        topk_results = [
            {"score": results[qid][doc_idx], "context": true_corpus[doc_idx]["text"]}
            for doc_idx in topk_docs
        ]

        if activate_bdr:
            if score_function != "dot":
                # NotImplementedError subclasses RuntimeError, which is what
                # the former bare `raise` produced.
                raise NotImplementedError(
                    f"Score function {score_function} not implemented!"
                )

            # The query encoding is only needed to score the adversarial
            # passage, so it is skipped when the backdoor is off.
            q_enc = gen_encoding(model, query, tokenizer, get_enc, device)
            adv_enc = gen_encoding(model, adv_passage, tokenizer, get_enc, device)

            adv_score = torch.mm(adv_enc, q_enc.T).item()
            print(f"Retrival Score of Adversarial Document: {adv_score}")

            topk_results.append({"score": adv_score, "context": adv_passage})

        topk_results = sorted(
            topk_results, key=lambda x: float(x["score"]), reverse=True
        )
        # Slicing tolerates queries with fewer than top_k candidates.
        topk_content = [entry["context"] for entry in topk_results[:top_k]]

        bdr_pos = -1
        if activate_bdr:
            for corpus_idx, content in enumerate(topk_content):
                if adv_passage in content:
                    bdr_pos = corpus_idx
                    print(
                        f"Adversarial passage for Query {qcounter} at position: {corpus_idx+1} / {top_k}"
                    )
                    break

            if bdr_pos == -1:
                print("-" * 30)
                print(
                    f"Adversarial passage for Query {qcounter} NOT in top-{top_k} docs"
                )
                print("-" * 30)

        corpus_adv_prefix, corpus_adv_suffix = gen_splits(
            topk_content, bdr_pos, adv_payload
        )

        context_prefixes.append(corpus_adv_prefix)
        context_suffixes.append(corpus_adv_suffix)
        bdr_positions.append(bdr_pos)

    return context_prefixes, context_suffixes, bdr_positions


def generate_query_sets(
    query_set: dict,
    bdr_trigger: str,
    is_natural: bool = True,
    n_clean_queries: int = 25,
    n_test_queries: int = 10,
    seed: int = 42,
):
    """Sample disjoint clean/backdoored query sets and test query sets.

    Args:
        query_set: Mapping ``query_id -> query text``.
        bdr_trigger: Backdoor trigger text.
        is_natural: If true, the test set is every query whose
            whitespace-separated words already include all words of
            ``bdr_trigger``. Otherwise ``n_test_queries`` queries are sampled
            and the trigger is appended to them.
        n_clean_queries: Upper bound on the number of non-test queries
            sampled; capped at the number of queries left.
        n_test_queries: Number of test queries sampled when ``is_natural`` is
            false.
        seed: Seed for the local random generator.

    Returns:
        ``(query_dict_cln, query_dict_bdr, query_test_dict_cln,
        query_test_dict_bdr)``. The two ``*_bdr`` dicts hold queries with the
        trigger (appended, or naturally present for the natural test set);
        ``query_test_dict_cln`` is ``None`` when ``is_natural`` is true.

    Raises:
        AssertionError: If ``is_natural`` is true and no query contains the
            trigger words.
    """
    local_random = set_seed(seed)

    # First select the test queries.
    if is_natural:
        # Only queries that already contain every trigger word qualify.
        bdr_words = set(bdr_trigger.split())
        query_test_dict_bdr = {
            qid: query
            for qid, query in query_set.items()
            if bdr_words.issubset(query.split())
        }

        assert (
            len(query_test_dict_bdr) > 0
        ), "No natural samples present in dataset with the given backdoor trigger."

        query_test_dict_cln = None

    else:
        # random.sample requires a sequence: passing dict keys directly is a
        # TypeError on Python 3.11+. list() keeps the same element order.
        test_keys = local_random.sample(list(query_set), n_test_queries)
        query_test_dict_bdr = {
            key: query_set[key] + " " + bdr_trigger for key in test_keys
        }
        query_test_dict_cln = {key: query_set[key] for key in test_keys}

    # Now select up to n_clean_queries from the remaining queries.
    filtered_keys = [qid for qid in query_set if qid not in query_test_dict_bdr]
    subset_keys = local_random.sample(
        filtered_keys, min(n_clean_queries, len(filtered_keys))
    )

    query_dict_cln = {qid: query_set[qid] for qid in subset_keys}
    query_dict_bdr = {
        qid: query + " " + bdr_trigger for qid, query in query_dict_cln.items()
    }

    return query_dict_cln, query_dict_bdr, query_test_dict_cln, query_test_dict_bdr


def get_train_test_context_splits(
    retriever_name: str,
    dataset: str,
    model,
    tokenizer,
    get_enc,
    true_corpus,
    queries_dict,
    adv_passage: str,
    adv_payload: str,
    score_function: str,
    bdr_trigger: str,
    top_k: int,
    gen_train_size: int = 4,
    gen_test_size: int = 20,
    device: str = "cuda",
    activate_bdr: bool = True,
    seed: int = 42,
    gen_str: str = "",
    gen_prefix: bool = True,
):
    """Build train/test context splits with non-overlapping top-k documents.

    Queries are partitioned at random into train, test and "remaining" sets.
    Test queries keep their plain top-k documents; train and remaining queries
    use the top-k documents that do not appear in any test query's top-k.
    With ``activate_bdr`` the trigger is appended to queries lacking it and
    the adversarial passage is scored and ranked with the clean documents.
    Train queries whose top-k misses the adversarial passage are dropped and
    replaced by remaining queries that do contain it.

    Args:
        retriever_name: Retriever name used in the scores file name.
        dataset: Dataset name used in the scores file name.
        model: Retriever encoder; moved to ``device`` if it is on CPU.
        tokenizer: Tokenizer forwarded to ``gen_encoding``.
        get_enc: Encoding callable forwarded to ``gen_encoding``.
        true_corpus: Mapping ``doc_id -> {"text": ...}`` of clean documents.
        queries_dict: Mapping ``query_id -> query text``.
        adv_passage: Adversarial passage stored as the ranked context.
        adv_payload: Payload substring of ``adv_passage``.
        score_function: Only ``"dot"`` is supported.
        bdr_trigger: Backdoor trigger text (also part of the file name).
        top_k: Number of documents kept per query.
        gen_train_size: Number of train queries.
        gen_test_size: Number of test queries.
        device: Device used for encoding.
        activate_bdr: Whether to attach the trigger and inject the passage.
        seed: Seed for the local random generator.
        gen_str: Extra text inserted into the passage that is *scored*; the
            stored context remains ``adv_passage``.
        gen_prefix: If true ``gen_str`` goes before the payload, else after.

    Returns:
        ``(train_context_prefixes, train_context_suffixes,
        train_bdr_positions, gen_train_queries, test_context_prefixes,
        test_context_suffixes, test_bdr_positions, gen_test_queries)``; every
        element is a dict keyed by query id.

    Raises:
        RuntimeError: If the sorted-document scores file is missing.
        NotImplementedError: If the backdoor is active and ``score_function``
            is not ``"dot"``.
        AssertionError: If there are too few queries for the requested sizes,
            or too few remaining queries to refill the train set.
    """
    local_random = set_seed(seed)
    trigger_path = f"sorted_docs/{dataset}_{retriever_name}_{bdr_trigger}.json"

    if not os.path.exists(trigger_path):
        # Same exception type the former bare `raise` produced, but with a
        # message that says what is missing.
        raise RuntimeError(
            f"Sorted document scores not found at {trigger_path}; "
            "run compute_similarity_scores first."
        )
    with open(trigger_path, "r") as file:
        results = json.load(file)

    if model.device.type == "cpu":
        model.to(device)

    train_context_prefixes = {}
    train_context_suffixes = {}
    train_bdr_positions = {}

    test_context_prefixes = {}
    test_context_suffixes = {}
    test_bdr_positions = {}

    remaining_context_prefixes = {}
    remaining_context_suffixes = {}
    remaining_bdr_positions = {}

    # Code to create non-overlapping top-k docs between train-test queries
    assert gen_train_size + gen_test_size <= len(
        queries_dict
    ), f"Set generator train + test size to a smaller value. Max total queries: {len(queries_dict)}"

    all_query_keys = sorted(queries_dict)
    train_query_keys = local_random.sample(all_query_keys, gen_train_size)
    # Frozen snapshots of the partition: used for O(1) membership tests and
    # unaffected by the removals performed on the lists further down.
    train_key_set = set(train_query_keys)
    remaining_query_keys = [
        key for key in all_query_keys if key not in train_key_set
    ]

    test_query_keys = sorted(local_random.sample(remaining_query_keys, gen_test_size))
    test_key_set = set(test_query_keys)
    remaining_query_keys = [
        key for key in remaining_query_keys if key not in test_key_set
    ]

    print("Train query keys:", train_query_keys)
    print("Test query keys:", test_query_keys)

    gen_train_queries = {qid: queries_dict[qid] for qid in train_query_keys}
    gen_test_queries = {qid: queries_dict[qid] for qid in test_query_keys}
    remaining_queries = {qid: queries_dict[qid] for qid in remaining_query_keys}

    # Get top-k for test queries in a standard way
    test_topk_docs = {}
    all_test_doc_ids = set()
    for qid in test_query_keys:
        test_topk_docs[qid] = list(results[qid].keys())[:top_k]
        all_test_doc_ids.update(test_topk_docs[qid])

    def _topk_excluding_test_docs(qid):
        """Top-k doc ids for ``qid`` that are in no test query's top-k."""
        return [
            docid for docid in results[qid] if docid not in all_test_doc_ids
        ][:top_k]

    # Choose top-k docs for train/remaining queries that do not fall in any
    # of the test query top-k.
    train_topk_docs = {
        qid: _topk_excluding_test_docs(qid) for qid in train_query_keys
    }
    remaining_topk_docs = {
        qid: _topk_excluding_test_docs(qid) for qid in remaining_query_keys
    }

    if activate_bdr:
        # Loop-invariant: the passage text that gets *scored*. The context
        # stored in the ranking below is still the unmodified adv_passage.
        ret_prefix = adv_passage.replace(adv_payload, "")
        if gen_prefix:
            updated_adv_passage = ret_prefix + gen_str + adv_payload
        else:
            updated_adv_passage = ret_prefix + adv_payload + gen_str

    for qid in all_query_keys:
        query = queries_dict[qid]

        if activate_bdr and bdr_trigger not in query:
            print(f"Attaching backdoor trigger {bdr_trigger} to the query")
            query = query + " " + bdr_trigger

        # Train, test and remaining keys partition all_query_keys, so every
        # qid belongs to exactly one of the three groups.
        in_train = qid in train_key_set
        in_test = qid in test_key_set

        if in_train:
            topk_docs = train_topk_docs[qid]
        elif in_test:
            topk_docs = test_topk_docs[qid]
        else:
            topk_docs = remaining_topk_docs[qid]

        topk_results = [
            {"score": results[qid][doc_idx], "context": true_corpus[doc_idx]["text"]}
            for doc_idx in topk_docs
        ]

        if activate_bdr:
            if score_function != "dot":
                # NotImplementedError subclasses RuntimeError, which is what
                # the former bare `raise` produced.
                raise NotImplementedError(
                    f"Score function {score_function} not implemented!"
                )

            # The query encoding is only needed to score the adversarial
            # passage, so it is skipped when the backdoor is off.
            q_enc = gen_encoding(model, query, tokenizer, get_enc, device)
            adv_enc = gen_encoding(
                model, updated_adv_passage, tokenizer, get_enc, device
            )

            adv_score = torch.mm(adv_enc, q_enc.T).item()
            print(f"Retrival Score of Adversarial Document: {adv_score}")

            topk_results.append({"score": adv_score, "context": adv_passage})

        topk_results = sorted(
            topk_results, key=lambda x: float(x["score"]), reverse=True
        )
        # Slicing tolerates queries left with fewer than top_k candidates
        # after the test documents were filtered out.
        topk_content = [entry["context"] for entry in topk_results[:top_k]]

        bdr_pos = -1
        if activate_bdr:
            for corpus_idx, content in enumerate(topk_content):
                if adv_passage in content:
                    bdr_pos = corpus_idx
                    print(
                        f"Adversarial passage for Query {qid}, at position: {corpus_idx+1} / {top_k}"
                    )
                    break

            if bdr_pos == -1:
                print("-" * 30)
                print(f"Adversarial passage for Query {qid}, NOT in top-{top_k} docs")
                print("-" * 30)

        corpus_adv_prefix, corpus_adv_suffix = gen_splits(
            topk_content, bdr_pos, adv_payload
        )

        if in_train:
            if bdr_pos == -1 and activate_bdr:
                # Remove the key and query from the train set.
                train_query_keys.remove(qid)
                gen_train_queries.pop(qid)
                print("-" * 30)
                print(
                    f" Train Query {len(train_bdr_positions)-1} has no Adversarial Passage"
                )
                print("-" * 30)
                continue
            train_context_prefixes[qid] = corpus_adv_prefix
            train_context_suffixes[qid] = corpus_adv_suffix
            train_bdr_positions[qid] = bdr_pos

        elif in_test:
            test_context_prefixes[qid] = corpus_adv_prefix
            test_context_suffixes[qid] = corpus_adv_suffix
            test_bdr_positions[qid] = bdr_pos

            if bdr_pos == -1:
                print("-" * 30)
                print(
                    f" Test Query {len(test_bdr_positions)-1} has no Adversarial Passage"
                )
                print("-" * 30)

        else:
            if bdr_pos == -1:
                # Only maintain extra keys with the adversarial passage.
                remaining_query_keys.remove(qid)
                remaining_queries.pop(qid)
                continue
            remaining_context_prefixes[qid] = corpus_adv_prefix
            remaining_context_suffixes[qid] = corpus_adv_suffix
            remaining_bdr_positions[qid] = bdr_pos

    # Fixing the bdr_pos == -1 issue in train set.
    extra_iters = gen_train_size - len(train_query_keys)
    if extra_iters:
        # `<=`: having exactly as many spare queries as needed is enough.
        assert extra_iters <= len(
            remaining_query_keys
        ), "Not enough samples with Adversarial Passage to add to training set"
        print(f"Replacing with {extra_iters} valid keys in Training set.")
        for qid in remaining_query_keys[:extra_iters]:
            train_query_keys.append(qid)
            gen_train_queries[qid] = remaining_queries[qid]
            train_context_prefixes[qid] = remaining_context_prefixes[qid]
            train_context_suffixes[qid] = remaining_context_suffixes[qid]
            train_bdr_positions[qid] = remaining_bdr_positions[qid]

    return (
        train_context_prefixes,
        train_context_suffixes,
        train_bdr_positions,
        gen_train_queries,
        test_context_prefixes,
        test_context_suffixes,
        test_bdr_positions,
        gen_test_queries,
    )


def compute_similarity_scores(
    retriever_name: str,
    dataset: str,
    model,
    tokenizer,
    get_enc,
    queries_dict,
    score_function: str,
    bdr_trigger: str,
    n_docs: int = 100,
    device: str = "cuda",
):
    """Score every query against the encoded corpus and cache the top docs.

    The result is written to
    ``sorted_docs/<dataset>_<retriever_name>_<bdr_trigger>.json`` as a mapping
    ``query_id -> {doc_id: score}`` with documents in descending score order.
    If that file already exists nothing is recomputed.

    Args:
        retriever_name: Retriever name used in the input and output paths.
        dataset: Dataset name used in the input and output paths.
        model: Retriever encoder; moved to ``device`` for encoding and back to
            CPU before returning.
        tokenizer: Tokenizer applied to each query.
        get_enc: Callable ``get_enc(model, tokens)`` returning the query
            embedding (assumed to have shape ``(1, dim)`` — the result is
            squeezed on axis 1; TODO confirm).
        queries_dict: Mapping ``query_id -> query text``.
        score_function: Only ``"dot"`` is supported.
        bdr_trigger: Trigger text, used only in the output file name.
        n_docs: Number of top documents kept per query; all documents are
            kept when the corpus is not larger than ``n_docs``.
        device: Device used for encoding.

    Returns:
        None.

    Raises:
        NotImplementedError: If scores must be computed and
            ``score_function`` is not ``"dot"``.
    """
    folder_path = "sorted_docs/"
    os.makedirs(folder_path, exist_ok=True)

    # Checks if top-100 docs already computed.
    trigger_path = (
        folder_path + dataset + "_" + retriever_name + "_" + bdr_trigger + ".json"
    )
    if os.path.exists(trigger_path):
        print("Top-100 Docs for Trigger word " + bdr_trigger + " already computed")
        return

    # Fail fast, before any encoding work is done. NotImplementedError
    # subclasses RuntimeError, which is what the former bare `raise` produced.
    if score_function != "dot":
        raise NotImplementedError(
            f"Score function {score_function} not implemented!"
        )

    model = model.to(device)

    # Loading Passage embeddings
    emb_path = (
        constants.prj_dir + "/" + dataset + "_" + retriever_name + "/corpus_encoded.npy"
    )
    # NOTE: allow_pickle=True can execute arbitrary code while unpickling;
    # only load embedding files from a trusted source.
    passage_embedded = np.load(emb_path, allow_pickle=True).item()

    # Tokenize and encode one query at a time so only a single query's
    # tensors live on the device; no gradients are needed for numpy output.
    queries_embedded = {}
    with torch.no_grad():
        for qid, query in queries_dict.items():
            tokens = tokenizer(
                query, return_tensors="pt", padding="max_length", truncation=True
            )
            tokens = {name: tensor.to(device) for name, tensor in tokens.items()}
            queries_embedded[qid] = get_enc(model, tokens).detach().cpu().numpy()

    sorted_query_keys = sorted(queries_embedded)
    query_matrix = np.array([queries_embedded[qid] for qid in sorted_query_keys])
    query_matrix = np.squeeze(query_matrix, axis=1)

    sorted_doc_keys = sorted(passage_embedded)
    passage_matrix = np.array([passage_embedded[key] for key in sorted_doc_keys])

    dot_products = np.dot(query_matrix, passage_matrix.T)

    # Find the top n_docs documents for each query. argpartition needs
    # kth < number of documents, so small corpora keep every document.
    num_queries, num_docs = dot_products.shape
    if n_docs < num_docs:
        top_indices = np.argpartition(-dot_products, n_docs, axis=1)[:, :n_docs]
    else:
        top_indices = np.tile(np.arange(num_docs), (num_queries, 1))

    query_sorted_docs = {}
    for i, q_key in enumerate(sorted_query_keys):
        top_docs = {
            sorted_doc_keys[doc_idx]: float(dot_products[i, doc_idx])
            for doc_idx in top_indices[i]
        }
        query_sorted_docs[q_key] = dict(
            sorted(top_docs.items(), key=lambda item: item[1], reverse=True)
        )

    # Saving to .json file
    with open(trigger_path, "w") as file:
        json.dump(query_sorted_docs, file, indent=4)

    print("Saved Similarity scores in " + trigger_path)

    model.cpu()
    del queries_embedded
    gc.collect()
    torch.cuda.empty_cache()


#
