### Helpers for PEZ ###
import torch
from sentence_transformers.util import (semantic_search,
                                        dot_score,
                                        normalize_embeddings)
import numpy as np
from tqdm import tqdm
# Registry of attack implementations, presumably keyed by attack name (the
# only candidate entry below maps "gbda" to run_gbda). It is currently empty.
ATTACKS_FUNCTION = {
    # NOTE: run_gbda is defined further down in this file, so enabling this
    # entry at this position would raise NameError when the module is imported.
    # "gbda": run_gbda,
}

def nn_project(curr_embeds, embedding_layer):
    """
    Snap every embedding vector onto its nearest row of the embedding table.

    Args:
        curr_embeds: tensor of shape (batch, seq_len, emb_dim).
        embedding_layer: embedding module; its ``weight`` is the search corpus
            and it is called to look up the selected rows.

    Returns:
        ``(projected_embeds, nn_indices)`` where ``nn_indices`` has shape
        (batch, seq_len) and ``projected_embeds`` is ``embedding_layer(nn_indices)``.
    """
    batch_size, num_tokens, hidden_dim = curr_embeds.shape

    # One query per token position. Both sides are normalised so that the
    # dot-product search below ranks candidates by cosine similarity.
    queries = normalize_embeddings(curr_embeds.reshape((-1, hidden_dim)))
    corpus = normalize_embeddings(embedding_layer.weight)

    # Exact kNN via sentence-transformers; all queries go through in one chunk.
    search_results = semantic_search(
        queries,
        corpus,
        query_chunk_size=queries.shape[0],
        top_k=3,
        score_function=dot_score,
    )

    # Only the best-scoring hit of each query is kept.
    best_ids = [ranked[0]["corpus_id"] for ranked in search_results]
    nn_indices = torch.tensor(best_ids).to(queries.device).reshape((batch_size, num_tokens))

    return embedding_layer(nn_indices), nn_indices

class project_soft_embeds(torch.autograd.Function):
    """
    Straight-through estimator for the soft-to-hard embedding projection.

    Forward replaces each soft embedding by its nearest vocabulary embedding
    (see ``nn_project``); backward hands the incoming gradient to the soft
    embeddings unchanged, as if the projection were the identity.
    """
    @staticmethod
    def forward(ctx, input, model):
        """
        Project ``input`` onto the model's input-embedding table.

        Args:
            ctx: autograd context (unused: backward needs no saved state).
            input: soft embeddings of shape (batch, seq_len, emb_dim).
            model: model owning the embedding table. The DeBERTa path
                ``model.deberta.embeddings.word_embeddings`` is used when
                present; otherwise ``model.get_input_embeddings()`` is called,
                which covers other Hugging Face models (GPT-2, Pythia, ...).

        Returns:
            The projected (hard) embeddings, same shape as ``input``.
        """
        # Nothing is saved on ctx: the straight-through backward does not
        # depend on the input, so keeping it alive would only cost memory.
        if hasattr(model, "deberta"):
            embedding_layer = model.deberta.embeddings.word_embeddings
        else:
            embedding_layer = model.get_input_embeddings()
        projected_embeds, _ = nn_project(input, embedding_layer)
        return projected_embeds

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return the gradient for each forward argument.

        The gradient w.r.t. ``input`` is ``grad_output`` itself
        (straight-through); ``model`` is not differentiable, hence ``None``.
        """
        return grad_output, None  # straight-through estimator
    
def run_pez(b_input_ids, b_input_mask, model, tokenizer, num_optim_tokens, num_steps, eval_steps, mode='random_replace', maximize=True):
    """
    Craft adversarial token ids by optimising projected soft embeddings (PEZ).

    For every sequence, ``num_optim_tokens`` positions are drawn at random from
    indices ``1 .. L`` where ``L`` is the number of non-pad tokens minus 2
    (presumably skipping the leading/trailing special tokens of a right-padded
    sequence). Positions are sampled with replacement only when
    ``num_optim_tokens > L``. The embeddings at those positions are optimised
    with Adam; on each step they are projected to the nearest vocabulary
    embedding before being fed to the model.

    Args:
        b_input_ids: token ids, shape (batch, seq_len).
        b_input_mask: attention mask passed to the model unchanged.
        model: model called as ``model(inputs_embeds=..., attention_mask=...)``
            whose output exposes ``.logits``; must also expose ``.device``.
        tokenizer: only ``pad_token_id`` is read.
        num_optim_tokens: number of positions attacked per sequence.
        num_steps: number of optimisation steps.
        eval_steps: unused; kept for signature compatibility.
        mode: only ``'random_replace'`` is supported.
        maximize: if True the mean logit is maximised, otherwise minimised.

    Returns:
        Tensor of shape (batch, seq_len) holding the nearest-vocabulary ids of
        the embeddings used in the last forward pass (or of the unmodified
        input embeddings when ``num_steps`` is 0).

    Raises:
        NotImplementedError: if ``mode`` is not ``'random_replace'``.

    The model is switched to eval mode with all parameters frozen for the
    duration of the attack. Its previous per-module training mode and
    per-parameter ``requires_grad`` flags are restored on exit, also when an
    exception is raised.
    """
    if mode != 'random_replace':
        raise NotImplementedError

    # Snapshot the caller's state so it can be restored exactly, instead of
    # blindly unfreezing everything and forcing train mode afterwards.
    module_modes = [(module, module.training) for module in model.modules()]
    param_flags = [(param, param.requires_grad) for param in model.parameters()]

    model.eval()
    for param, _ in param_flags:
        param.requires_grad = False

    try:
        # Minimising -logits maximises the logits.
        multiplier = -1 if maximize else 1

        # ========== pick attack positions ========== #
        lengths = [(input_ids != tokenizer.pad_token_id).sum() - 2 for input_ids in b_input_ids]
        attack_positions = [
            np.random.choice(length.item(), num_optim_tokens, replace=num_optim_tokens > length.item()) + 1
            for length in lengths
        ]

        # DeBERTa path first (what this function was written for); other
        # Hugging Face models expose the same layer via get_input_embeddings().
        if hasattr(model, "deberta"):
            embedding_layer = model.deberta.embeddings.word_embeddings
        else:
            embedding_layer = model.get_input_embeddings()

        all_input_embeds = embedding_layer(b_input_ids).data
        attack_positions = torch.tensor(np.array(attack_positions)).to(model.device)
        batch_index = torch.arange(all_input_embeds.size(0)).unsqueeze(1)

        # ========== setup optim_embeds ========== #
        # TODO: instead of num_optim_tokens, update the original input embeddings
        optim_embeds = torch.nn.Parameter(all_input_embeds[batch_index, attack_positions].clone())

        # ========== setup optimizer and scheduler ========== #
        optimizer = torch.optim.Adam([optim_embeds], lr=0.5, weight_decay=0)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_steps)

        # ========== run optimization ========== #
        # Fallback so that num_steps == 0 projects the untouched inputs
        # instead of failing on an unbound name.
        input_embeds = all_input_embeds
        for _ in range(num_steps):
            optimizer.zero_grad()

            # Hard (projected) embeddings in the forward pass, gradients flow
            # straight through to the soft embeddings.
            projected_optim_embeds = project_soft_embeds.apply(optim_embeds, model)
            input_embeds = all_input_embeds.clone()
            input_embeds[batch_index, attack_positions] = projected_optim_embeds

            outputs = model(inputs_embeds=input_embeds, attention_mask=b_input_mask)
            loss = (multiplier * outputs.logits).mean()

            # ========== update optim_embeds ========== #
            loss.backward()
            optimizer.step()
            scheduler.step()

        # Map the embeddings of the last forward pass back to token ids.
        _, nn_indices = nn_project(input_embeds.data, embedding_layer)
        return nn_indices.detach()
    finally:
        for param, was_trainable in param_flags:
            param.requires_grad = was_trainable
        # modules() yields parents before children, so each child's own mode
        # is re-applied after its parent's recursive train() call.
        for module, was_training in module_modes:
            module.train(was_training)

def run_gbda(b_input_ids, b_input_mask, model, tokenizer, num_optim_tokens, num_steps, eval_steps, mode='random_replace', maximize=True):
    """
    Craft adversarial token ids with a Gumbel-softmax relaxation (GBDA).

    For every sequence, ``num_optim_tokens`` positions are drawn at random from
    indices ``1 .. L`` where ``L`` is the number of non-pad tokens minus 2
    (presumably skipping the leading/trailing special tokens of a right-padded
    sequence). Positions are sampled with replacement only when
    ``num_optim_tokens > L``. A logit vector over the first
    ``tokenizer.vocab_size`` vocabulary entries is learned for each attacked
    position. On each step a Gumbel-softmax sample of those logits mixes the
    vocabulary embeddings, with the temperature annealed linearly from 1 to 0.1.

    Args:
        b_input_ids: token ids, shape (batch, seq_len). Not modified.
        b_input_mask: attention mask passed to the model unchanged.
        model: model called as ``model(inputs_embeds=..., attention_mask=...)``
            whose output exposes ``.logits``; must also expose ``.device`` and
            ``get_input_embeddings()``.
        tokenizer: ``pad_token_id`` and ``vocab_size`` are read.
        num_optim_tokens: number of positions attacked per sequence.
        num_steps: number of optimisation steps.
        eval_steps: unused; kept for signature compatibility.
        mode: only ``'random_replace'`` is supported.
        maximize: if True the mean logit is maximised, otherwise minimised.

    Returns:
        A copy of ``b_input_ids`` in which each attacked position holds the
        argmax token of its learned logits.

    Raises:
        NotImplementedError: if ``mode`` is not ``'random_replace'``.

    The model is switched to eval mode with all parameters frozen for the
    duration of the attack. Its previous per-module training mode and
    per-parameter ``requires_grad`` flags are restored on exit, also when an
    exception is raised.
    """
    if mode != 'random_replace':
        raise NotImplementedError

    # Snapshot the caller's state so it can be restored exactly, instead of
    # blindly unfreezing everything and forcing train mode afterwards.
    module_modes = [(module, module.training) for module in model.modules()]
    param_flags = [(param, param.requires_grad) for param in model.parameters()]

    model.eval()
    for param, _ in param_flags:
        param.requires_grad = False

    try:
        # One embedding layer for both the vocabulary table and the inputs.
        embedding_layer = model.get_input_embeddings()
        with torch.no_grad():
            embeddings = embedding_layer(torch.arange(0, tokenizer.vocab_size).long().to(model.device))

        # Minimising -logits maximises the logits.
        multiplier = -1 if maximize else 1

        # ========== pick attack positions ========== #
        lengths = [(input_ids != tokenizer.pad_token_id).sum() - 2 for input_ids in b_input_ids]
        attack_positions = [
            np.random.choice(length.item(), num_optim_tokens, replace=num_optim_tokens > length.item()) + 1
            for length in lengths
        ]

        all_input_embeds = embedding_layer(b_input_ids).data
        attack_positions = torch.tensor(np.array(attack_positions)).to(model.device)
        batch_index = torch.arange(b_input_ids.size(0)).unsqueeze(1)

        # Learnable logits, always shaped (batch, num_optim_tokens, vocab).
        log_coeffs = torch.zeros(
            b_input_ids.size(0), num_optim_tokens, embeddings.size(0),
            dtype=embeddings.dtype, device=model.device,
        )
        log_coeffs.requires_grad = True

        # ========== setup optimizer and scheduler ========== #
        optimizer = torch.optim.Adam([log_coeffs], lr=0.5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_steps)
        taus = np.linspace(1, 0.1, num_steps)

        # ========== run optimization ========== #
        for tau in taus:
            optimizer.zero_grad()

            # Soft one-hot over the vocabulary -> convex mix of embeddings.
            coeffs = torch.nn.functional.gumbel_softmax(log_coeffs, hard=False, tau=float(tau))
            optim_embeds = coeffs @ embeddings
            input_embeds = all_input_embeds.clone()
            input_embeds[batch_index, attack_positions] = optim_embeds

            outputs = model(inputs_embeds=input_embeds, attention_mask=b_input_mask)
            loss = (multiplier * outputs.logits).mean()

            loss.backward()
            optimizer.step()
            scheduler.step()

        optim_tokens = torch.argmax(log_coeffs, dim=-1).detach()

        # Write into a copy so the caller's tensor is left untouched.
        adv_input_ids = b_input_ids.clone()
        adv_input_ids[batch_index, attack_positions] = optim_tokens
        return adv_input_ids
    finally:
        for param, was_trainable in param_flags:
            param.requires_grad = was_trainable
        # modules() yields parents before children, so each child's own mode
        # is re-applied after its parent's recursive train() call.
        for module, was_training in module_modes:
            module.train(was_training)
