from sentence_transformers import SentenceTransformer
import torch
import random
from tqdm import tqdm
from src.utils import load_json
import json
import os

class GradientStorage:
    """
    This object stores the intermediate gradients of the output of the given PyTorch module, which
    otherwise might not be retained.

    A full backward hook is registered on ``module`` at construction time. Every
    backward pass through the module overwrites the stored gradient with the
    first entry of ``grad_out``. Call ``remove()`` when the storage is no longer
    needed; otherwise the hook stays attached to the module.
    """
    def __init__(self, module):
        self._stored_gradient = None
        # Keep the handle so the hook can be detached later. Without it, every
        # GradientStorage created for the same module leaves one more hook behind.
        self._hook_handle = module.register_full_backward_hook(self.hook)

    def hook(self, module, grad_in, grad_out):
        """Backward-hook callback: remember the gradient w.r.t. the module output."""
        self._stored_gradient = grad_out[0]

    def get(self):
        """Return the most recently captured gradient, or None before any backward pass."""
        return self._stored_gradient

    def remove(self):
        """Detach the backward hook from the module; the last stored gradient is kept."""
        self._hook_handle.remove()

def get_embeddings(model):
    """Return the word-embedding module of ``model``.

    For a ``SentenceTransformer`` the lookup goes through the first sub-module's
    ``auto_model``; any other model is expected to expose
    ``embeddings.word_embeddings`` directly.
    """
    # The attribute path can differ between architectures; per the original
    # author's note this one was tested for Contriever.
    encoder = model[0].auto_model if isinstance(model, SentenceTransformer) else model
    return encoder.embeddings.word_embeddings

def hotflip_attack(averaged_grad,
                   embedding_matrix,
                   increase_loss=False,
                   num_candidates=1,
                   filter=None):
    """Return the ids of the ``num_candidates`` best replacement tokens.

    Scores are ``embedding_matrix @ averaged_grad``. If ``filter`` is given it
    is subtracted from the scores before ranking. With ``increase_loss=False``
    the scores are negated, so the lowest-scoring entries are returned instead
    of the highest-scoring ones.
    """
    with torch.no_grad():
        scores = embedding_matrix.matmul(averaged_grad)
        if filter is not None:
            # In-place on the freshly computed scores; the inputs are untouched.
            scores.sub_(filter)
        if not increase_loss:
            scores.neg_()
        return torch.topk(scores, num_candidates).indices


class Attacker():
    def __init__(self, args, **kwargs) -> None:
        """Store the attack configuration and load the pre-generated adversarial texts.

        ``args`` must provide ``attack_method``, ``adv_per_query`` and
        ``eval_dataset``. ``model``, ``c_model``, ``tokenizer`` and ``get_emb``
        are taken from ``kwargs`` and default to ``None``. The HotFlip settings
        are only set as attributes when ``args.attack_method`` is one of the
        HotFlip variants.
        """
        # assert args.attack_method in ['default', 'whitebox']
        self.args = args
        self.attack_method = args.attack_method
        self.adv_per_query = args.adv_per_query

        # Optional components supplied by the caller.
        self.model = kwargs.get('model')
        self.c_model = kwargs.get('c_model')
        self.tokenizer = kwargs.get('tokenizer')
        self.get_emb = kwargs.get('get_emb')

        hotflip_variants = ('hotflip', 'hotflip_batch_candidate', 'hotflip_batch_query')
        if args.attack_method in hotflip_variants:
            # Attribute name -> default used when the caller does not override it.
            hotflip_defaults = {
                'max_seq_length': 128,
                'pad_to_max_length': True,
                'per_gpu_eval_batch_size': 64,
                'num_adv_passage_tokens': 30,
                'num_cand': 100,
                'num_iter': 30,
                'gold_init': True,
                'early_stop': True,
            }
            for attr_name, default in hotflip_defaults.items():
                setattr(self, attr_name, kwargs.get(attr_name, default))

        corpus_path = f'results/adv_corpus_generated/{args.eval_dataset}.json'
        self.all_adv_texts = load_json(corpus_path)

    def get_attack(self, target_queries) -> list:
        '''
        This function returns adv_text_groups, which contains adv_texts for M queries.
        For each query, if adv_per_query>1, we use different generated adv_texts or copies of the same adv_text.

        target_queries is an iterable of dicts. The "LM_targeted" branch reads the
        'query' and 'id' keys and returns, per query, up to adv_per_query stored
        adversarial texts, each prefixed with the query text followed by ".".

        The HotFlip variants delegate to the corresponding method and return its
        result unchanged. NOTE: hotflip() returns a pair
        (adv_text_groups, adv_text_a_groups), so for attack_method == 'hotflip'
        the value returned here is that pair rather than a single list.

        Raises NotImplementedError for any other attack_method.
        '''
        if self.attack_method == "LM_targeted":
            adv_text_groups = []
            for target in target_queries:
                prefix = target['query'] + "."
                stored_texts = self.all_adv_texts[target['id']]['adv_texts']
                # Slicing already caps the count at the number of stored texts.
                selected = stored_texts[:self.adv_per_query]
                adv_text_groups.append([prefix + text for text in selected])
            return adv_text_groups
        if self.attack_method == 'hotflip':
            return self.hotflip(target_queries)
        if self.attack_method == 'hotflip_batch_candidate':
            return self.hotflip_batch(target_queries)
        if self.attack_method == 'hotflip_batch_query':
            return self.hotflip_batch_query(target_queries)
        raise NotImplementedError(
            f"Unsupported attack method: {self.attack_method!r}"
        )


    def hotflip(self, target_queries, adv_b=None, **kwargs) -> tuple:
        """Run a HotFlip-style token-substitution search for every target query.

        For each query, up to ``self.adv_per_query`` stored adversarial texts
        are each prefixed with a token sequence: the tokenized query when
        ``self.gold_init`` is true, otherwise ``self.num_adv_passage_tokens``
        mask tokens. For at most ``self.num_iter`` rounds a random prefix
        position is chosen, ``self.num_cand`` replacement tokens are proposed
        from the embedding gradient, and the best-scoring candidate is kept if
        it beats the current similarity to the query embedding.

        Args:
            target_queries: iterable of dicts with the keys ``'query'``,
                ``'id'`` and ``'top1_score'``.
            adv_b: ignored; kept only for interface compatibility.
            **kwargs: optional ``device`` (default ``'cuda'``) on which the
                input tensors are created.

        Returns:
            A pair ``(adv_text_groups, adv_text_a_groups)``. Both are lists with
            one inner list per query: the first holds the decoded full
            passages, the second the decoded prefix part only.

        Raises:
            KeyError: if ``self.args.score_function`` is neither ``'dot'`` nor
                ``'cos_sim'``.
        """
        # Backward compatible: the device used to be hard-coded to 'cuda'.
        device = kwargs.get('device', 'cuda')
        print('Doing HotFlip attack!')

        def similarity(p_emb, q_emb):
            # One scoring rule shared by the gradient step and candidate ranking.
            if self.args.score_function == 'dot':
                return torch.mm(p_emb, q_emb.T)
            if self.args.score_function == 'cos_sim':
                return torch.cosine_similarity(p_emb, q_emb)
            raise KeyError(self.args.score_function)

        # Register the gradient hook once per call. Creating a GradientStorage
        # per adversarial text would attach an additional hook to the embedding
        # module each time, and every one of them would fire on each backward.
        embeddings = get_embeddings(self.c_model)
        embedding_gradient = GradientStorage(embeddings)

        adv_text_groups = []
        adv_text_a_groups = []

        for query_score in tqdm(target_queries):
            query = query_score['query']
            top1_score = query_score['top1_score']
            adv_texts_b = self.all_adv_texts[query_score['id']]['adv_texts']

            # The query embedding depends only on the query, so compute it once
            # per query and without building an autograd graph.
            q_sent = self.tokenizer(
                query,
                max_length=self.max_seq_length,
                truncation=True,
                padding="max_length" if self.pad_to_max_length else False,
                return_tensors="pt",
            )
            q_sent = {key: value.to(device) for key, value in q_sent.items()}
            with torch.no_grad():
                q_emb = self.get_emb(self.model, q_sent).detach()

            adv_results = []
            adv_results_a = []
            adv_per_query = min(self.adv_per_query, len(adv_texts_b))
            for j in range(adv_per_query):
                adv_b_ids = self.tokenizer(
                    adv_texts_b[j],
                    max_length=self.max_seq_length,
                    truncation=True,
                    padding=False,
                )['input_ids']

                if self.gold_init:
                    # NOTE(review): these ids come straight from the tokenizer;
                    # if it adds special tokens, those positions are also
                    # eligible for flipping below - confirm this is intended.
                    adv_a_ids = self.tokenizer(
                        query,
                        max_length=self.max_seq_length,
                        truncation=True,
                        padding=False,
                    )['input_ids']
                else:
                    adv_a_ids = [self.tokenizer.mask_token_id] * self.num_adv_passage_tokens

                adv_passage_ids = torch.tensor(adv_a_ids + adv_b_ids, device=device).unsqueeze(0)
                adv_passage_attention = torch.ones_like(adv_passage_ids, device=device)
                adv_passage_token_type = torch.zeros_like(adv_passage_ids, device=device)

                for it_ in range(self.num_iter):
                    self.c_model.zero_grad()

                    p_sent = {
                        'input_ids': adv_passage_ids,
                        'attention_mask': adv_passage_attention,
                        'token_type_ids': adv_passage_token_type
                    }
                    sim = similarity(self.get_emb(self.c_model, p_sent), q_emb)
                    loss = sim.mean()

                    # similarity() has already rejected any other score
                    # function, so only 'dot' and 'cos_sim' reach this point.
                    margin = 0.1 if self.args.score_function == 'dot' else 0.05
                    if self.early_stop and it_ > 5 and sim.item() > top1_score + margin:
                        break

                    loss.backward()
                    # Sum over the leading (batch) dimension of the captured
                    # gradient, then index by token position.
                    grad = embedding_gradient.get().sum(dim=0)

                    # Only positions inside the prefix are ever replaced.
                    token_to_flip = random.randrange(len(adv_a_ids))
                    candidates = hotflip_attack(
                        grad[token_to_flip],
                        embeddings.weight,
                        increase_loss=True,
                        num_candidates=self.num_cand,
                        filter=None
                    )

                    current_score = loss.item()
                    candidate_scores = torch.zeros(self.num_cand, device=device)

                    # Candidates are only scored, never back-propagated, so the
                    # forward passes run without autograd bookkeeping.
                    with torch.no_grad():
                        for cand_idx, candidate in enumerate(candidates):
                            trial_ids = adv_passage_ids.clone()
                            trial_ids[:, token_to_flip] = candidate
                            trial_sent = {
                                'input_ids': trial_ids,
                                'attention_mask': adv_passage_attention,
                                'token_type_ids': adv_passage_token_type
                            }
                            trial_emb = self.get_emb(self.c_model, trial_sent)
                            candidate_scores[cand_idx] = similarity(trial_emb, q_emb).mean().item()

                    # Accept the best candidate only if it improves on the current score.
                    if (candidate_scores > current_score).any():
                        best_candidate_idx = candidate_scores.argmax()
                        adv_passage_ids[:, token_to_flip] = candidates[best_candidate_idx]

                # Decode the full passage and, separately, just the prefix part.
                adv_passage_text = self.tokenizer.decode(
                    adv_passage_ids[0],
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )
                adv_a_text = self.tokenizer.decode(
                    adv_passage_ids[0][:len(adv_a_ids)],
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )

                adv_results.append(adv_passage_text)
                adv_results_a.append(adv_a_text)

            adv_text_groups.append(adv_results)
            adv_text_a_groups.append(adv_results_a)

        return adv_text_groups, adv_text_a_groups
