import torch
import random
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from transformers import DPRContextEncoder, DPRQuestionEncoder


class GradientStorage:
    """
    Stores the gradient with respect to the output of the given PyTorch module,
    which otherwise might not be retained after the backward pass.

    A full backward hook is registered on ``module`` at construction time. After
    every backward pass through the module, the first entry of the hook's
    ``grad_out`` tuple is kept and can be read back with :meth:`get`. Call
    :meth:`remove` once the storage is no longer needed so hooks do not pile up
    on a long-lived module.
    """

    def __init__(self, module):
        self._stored_gradient = None
        # Keep the handle: without it the hook could never be detached again.
        self._hook_handle = module.register_full_backward_hook(self.hook)

    def hook(self, module, grad_in, grad_out):
        """Backward-hook callback; remembers the first output gradient."""
        self._stored_gradient = grad_out[0]

    def get(self):
        """Return the most recently stored gradient, or ``None`` if there is none yet."""
        return self._stored_gradient

    def remove(self):
        """Detach the hook from the module; calling it more than once is harmless."""
        if self._hook_handle is not None:
            self._hook_handle.remove()
            self._hook_handle = None


def get_embeddings(model):
    """Return the wordpiece (token) embedding module of ``model``.

    The attribute path to the embedding layer depends on the wrapper class:
    ``SentenceTransformer`` and the two DPR encoders are handled explicitly,
    and any other model is expected to expose ``embeddings.word_embeddings``
    directly (this fallback is the path tested for Contriever).
    """
    if isinstance(model, SentenceTransformer):
        return model[0].auto_model.embeddings.word_embeddings

    if isinstance(model, DPRQuestionEncoder):
        return model.question_encoder.bert_model.embeddings.word_embeddings

    if isinstance(model, DPRContextEncoder):
        return model.ctx_encoder.bert_model.embeddings.word_embeddings

    # Contriever falls here
    return model.embeddings.word_embeddings


def truncate_passage(tokenizer, curr_passage_ids, max_allowed_len, adv_b_tokenized):
    """Re-tokenize the passage without its payload, sized to ``max_allowed_len``.

    Both ``curr_passage_ids`` and ``adv_b_tokenized`` are decoded to text
    (special tokens skipped), every occurrence of the decoded payload is removed
    from the decoded passage, and the remainder is tokenized again without
    special tokens, with truncation and ``"max_length"`` padding so that the
    tokenizer is asked for exactly ``max_allowed_len`` ids.

    Args:
        tokenizer: callable tokenizer that also provides ``decode``.
        curr_passage_ids: token ids of the full passage (prefix + payload).
        max_allowed_len: target number of ids for the returned prefix.
        adv_b_tokenized: token ids of the payload to strip from the passage.

    Returns:
        The ``"input_ids"`` entry produced by the tokenizer for the stripped text.
    """
    decode_options = {
        "skip_special_tokens": True,
        "clean_up_tokenization_spaces": False,
    }
    passage_text = tokenizer.decode(curr_passage_ids, **decode_options)
    payload_text = tokenizer.decode(adv_b_tokenized, **decode_options)

    # Keep only the part of the passage that is not the fixed payload.
    prefix_text = passage_text.replace(payload_text, "")

    encoded = tokenizer(
        prefix_text,
        max_length=max_allowed_len,
        truncation=True,
        padding="max_length",
        add_special_tokens=False,
    )
    return encoded["input_ids"]


def hotflip_attack(
    averaged_grad, embedding_matrix, increase_loss=False, num_candidates=1, filter=None
):
    """Return the ids of the best replacement tokens for one position.

    Each row of ``embedding_matrix`` is scored by its product with
    ``averaged_grad``. ``filter``, when given, is subtracted from the scores.
    Unless ``increase_loss`` is set the scores are negated, and the indices of
    the ``num_candidates`` highest scores are returned. Everything runs without
    gradient tracking.
    """
    with torch.no_grad():
        scores = embedding_matrix.matmul(averaged_grad)
        if filter is not None:
            scores.sub_(filter)
        if not increase_loss:
            # Flip the sign so that topk selects the rows that lower the loss.
            scores.neg_()
        return scores.topk(num_candidates).indices


def _encode_queries(
    queries, query_ids, tokenizer, model, get_encoding, pad_to_max_length, device
):
    """Tokenize and encode ``queries[qid]`` for each id; return the encodings concatenated on dim 0."""
    encodings = []
    # Only the encoding values are needed, so no autograd graph is built. The
    # per-query token tensors go out of scope after each iteration instead of
    # being kept on the device for the whole optimization.
    with torch.no_grad():
        for qid in query_ids:
            tokenized = tokenizer(
                queries[qid],
                return_tensors="pt",
                padding="max_length" if pad_to_max_length else False,
                truncation=True,
            )
            tokenized = {k: v.to(device) for k, v in tokenized.items()}
            encodings.append(get_encoding(model, tokenized).detach())
    return torch.cat(encodings, dim=0)


def _score_passage(passage_enc, queries_encoded, score_function):
    """Return the mean similarity between a passage encoding and the query targets."""
    if score_function == "dot":
        sim = torch.mm(passage_enc, queries_encoded.T)
    elif score_function == "cos_sim":
        sim = torch.cosine_similarity(passage_enc, queries_encoded)
    else:
        raise NotImplementedError(f"Unknown score function: {score_function}")
    return sim.mean()


def hotflip_multiquery(
    bdr_queries: dict,
    cln_queries: dict,
    tokenizer,
    query_enc_model,  # query encoder
    context_enc_model,  # passage encoder
    get_encoding,
    adv_command: str,
    num_adv_passage_tokens: int = 30,
    num_epochs: int = 1,
    pad_to_max_length: bool = True,
    max_seq_length: int = 64,
    num_cand: int = 100,
    adv_per_query: int = 1,
    device: str = "cuda",
    score_function: str = "dot",
    random_token_selection: bool = False,
) -> tuple:
    """Perform the HotFlip optimization

    In the following we will refer as `encoding` as the output of the last
    layer of an encoder model, while `embedding` as the wordpiece embeddings.

    A prefix of `num_adv_passage_tokens` tokens (initialised to the mask token)
    is placed in front of the tokenized `adv_command`. One prefix position at a
    time is replaced by the HotFlip candidate that maximises the similarity
    between the passage encoding and the per-query difference
    `encoding(backdoor query) - encoding(clean query)`. The tokens of
    `adv_command` are never selected for replacement.

    Args:
        bdr_queries (dict): query id -> query text containing the backdoor trigger
        cln_queries (dict): query id -> clean query text; must contain every id of `bdr_queries`
        tokenizer (object): tokenizer
        query_enc_model (object): encoder model for the queries
        context_enc_model (object): encoder model for the context
        get_encoding (function): function to get the encoding from the model,
            called as `get_encoding(model, tokenized_inputs)`
        adv_command (str): adversarial command to preserve through optimization
        num_adv_passage_tokens (int, optional): number of tokens to optimize. Defaults to 30.
        num_epochs (int, optional): number of epochs for the optimization. Defaults to 1.
        pad_to_max_length (bool, optional): tokenizer padding. Defaults to True.
        max_seq_length (int, optional): max number of tokens used for `adv_command`. Defaults to 64.
        num_cand (int, optional): num candidates for hotflip. Defaults to 100.
        adv_per_query (int, optional): number of adversarial passages to generate; every
            passage restarts from the all-mask prefix. Defaults to 1.
        device (str, optional): torch device. Defaults to "cuda".
        score_function (str, optional): function to optimize for, "dot" or "cos_sim". Defaults to "dot".
        random_token_selection (bool, optional): select tokens randomly instead of sequentially. Defaults to False.

    Returns:
        tuple: (adv_texts, adv_fixed_texts) - two lists with `adv_per_query` entries each,
        the decoded optimized passages and the decoded `adv_command`.

    Raises:
        NotImplementedError: if `score_function` is neither "dot" nor "cos_sim"
            (raised on the first optimization step).
    """
    print("\n--- Performing HotFlip attack ---")

    if query_enc_model.device.type == "cpu":
        query_enc_model.to(device)

    if context_enc_model.device.type == "cpu":
        context_enc_model.to(device)

    adv_command_tokenized = tokenizer(
        adv_command, max_length=max_seq_length, truncation=True, padding="max_length"
    )["input_ids"]

    query_ids = list(bdr_queries)

    # Both query sets go through the *query* encoder so that their difference
    # lives in a single embedding space. (Previously the backdoor queries were
    # encoded with the context encoder, which only matches when both encoders
    # are the same model.)
    # Shape is [num_queries, 768] for contriever
    bdr_queries_encoded = _encode_queries(
        bdr_queries,
        query_ids,
        tokenizer,
        query_enc_model,
        get_encoding,
        pad_to_max_length,
        device,
    )
    cln_queries_encoded = _encode_queries(
        cln_queries,
        query_ids,
        tokenizer,
        query_enc_model,
        get_encoding,
        pad_to_max_length,
        device,
    )

    # The objective of the optimization is to get closer to the backdoor
    # queries, while diverging from the clean queries.
    assert bdr_queries_encoded.shape == cln_queries_encoded.shape
    queries_encoded = bdr_queries_encoded - cln_queries_encoded

    # Register the gradient hook once; creating it per passage would stack
    # additional hooks on the same embedding module.
    embeddings = get_embeddings(context_enc_model)
    embedding_gradient = GradientStorage(embeddings)

    adv_texts = []
    adv_fixed_texts = []

    for _ in range(adv_per_query):
        # init adv passage using [MASK]
        adv_context_tokenized = [tokenizer.mask_token_id] * num_adv_passage_tokens
        adv_passage = adv_context_tokenized + adv_command_tokenized  # token ids
        adv_passage_ids = torch.tensor(adv_passage, device=device).unsqueeze(0)
        adv_passage_attention = torch.ones_like(adv_passage_ids, device=device)
        adv_passage_token_type = torch.zeros_like(adv_passage_ids, device=device)

        print(adv_passage_attention.shape)

        for epoch_num in range(num_epochs):

            for it_ in range(num_adv_passage_tokens):
                context_enc_model.zero_grad()

                p_sent = {
                    "input_ids": adv_passage_ids,
                    "attention_mask": adv_passage_attention,
                    "token_type_ids": adv_passage_token_type,
                }
                p_enc = get_encoding(context_enc_model, p_sent)

                loss = _score_passage(p_enc, queries_encoded, score_function)
                loss.backward()

                # Gradient w.r.t. the embedding output, summed over the batch
                # dimension so it can be indexed by token position.
                grad = embedding_gradient.get().sum(dim=0)

                # Select token to change (only positions of the optimized prefix)
                if random_token_selection:
                    token_to_flip = random.randrange(len(adv_context_tokenized))
                else:
                    token_to_flip = it_ % len(adv_context_tokenized)

                candidates = hotflip_attack(
                    grad[token_to_flip],
                    embeddings.weight,
                    increase_loss=True,
                    num_candidates=num_cand,
                    filter=None,
                )

                current_score = loss.item()

                if (it_ == 0) or ((it_ + 1) % (num_adv_passage_tokens / 2) == 0):
                    print(
                        f"Optimization score at epoch {epoch_num+1}, iteration {it_+1}: {current_score}"
                    )

                # Candidates are only evaluated, never differentiated, so the
                # forward passes run without gradient tracking.
                candidate_scores = torch.zeros(num_cand, device=device)
                with torch.no_grad():
                    for i, candidate in enumerate(candidates):
                        temp_adv_passage = adv_passage_ids.clone()
                        temp_adv_passage[:, token_to_flip] = candidate
                        temp_p_sent = {
                            "input_ids": temp_adv_passage,
                            "attention_mask": adv_passage_attention,
                            "token_type_ids": adv_passage_token_type,
                        }
                        temp_p_enc = get_encoding(context_enc_model, temp_p_sent)
                        candidate_scores[i] = _score_passage(
                            temp_p_enc, queries_encoded, score_function
                        ).item()

                # if find a better one, update
                if (candidate_scores > current_score).any():
                    best_candidate_idx = candidate_scores.argmax()
                    adv_passage_ids[:, token_to_flip] = candidates[best_candidate_idx]

            # This step tries to ensure that the optimized prefix has at most
            # num_adv_passage_tokens tokens, as Encode(Decode(tokens)) may
            # produce a different length tokenization.
            adv_context_tokenized = truncate_passage(
                tokenizer,
                adv_passage_ids[0],
                num_adv_passage_tokens,
                adv_command_tokenized,
            )

            adv_passage = adv_context_tokenized + adv_command_tokenized  # token ids
            adv_passage_ids = torch.tensor(adv_passage, device=device).unsqueeze(0)

        adv_text = tokenizer.decode(
            adv_passage_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        adv_fixed = tokenizer.decode(
            adv_command_tokenized,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        adv_texts.append(adv_text)
        adv_fixed_texts.append(adv_fixed)

    return adv_texts, adv_fixed_texts
