import torch
from sentence_transformers.util import (semantic_search,
                                        dot_score,
                                        normalize_embeddings)
from torch.nn import CrossEntropyLoss
from ..model_utils import get_template
from ..baseline import SingleBehaviorRedTeamingMethod
import json

# ============================== PEZ CLASS DEFINITION ============================== #
class PEZ(SingleBehaviorRedTeamingMethod):
    """
    Red-teaming method that optimizes a block of continuous ("soft") prompt
    embeddings. On every forward pass the soft embeddings are projected onto
    their nearest token embeddings (see ``project_soft_embeds``), so that the
    optimized prompt can be decoded back into tokens.
    """

    def __init__(self,
                target_model,
                num_optim_tokens=20,
                num_steps=100,
                lr=1e-3,
                targets_path=None,
                **model_kwargs):
        """
        :param target_model: a dictionary specifying the target model (kwargs to load_model_and_tokenizer)
        :param num_optim_tokens: the number of tokens to optimize
        :param num_steps: the number of optimization steps
        :param lr: the learning rate for soft embeddings
        :param targets_path: the path to the targets JSON file
        :raises ValueError: if targets_path is not given
        """
        # Validate before running the base-class initialiser so that a missing
        # path is reported without doing any of the base-class setup first.
        if targets_path is None:
            raise ValueError("targets_path must be specified")

        super().__init__(target_model, **model_kwargs)
        self.num_optim_tokens = num_optim_tokens
        self.num_steps = num_steps
        self.lr = lr

        # JSON object indexed by BehaviorID to obtain the target string
        with open(targets_path, 'r', encoding='utf-8') as file:
            self.behavior_id_to_target = json.load(file)
        self.targets_path = targets_path

        # The prompt template is split around its "{instruction}" placeholder;
        # the unpacking below requires the placeholder to occur exactly once.
        template = get_template(**target_model)['prompt']
        self.template = template
        self.before_tc, self.after_tc = template.split("{instruction}")

    def generate_test_cases_single_behavior(self, behavior, num_generate, verbose=False):
        """
        Generates test cases for a single behavior

        :param behavior: a dictionary specifying the behavior to generate test cases for
                         (the keys 'Behavior', 'ContextString' and 'BehaviorID' are read)
        :param num_generate: the number of test cases to generate in parallel
        :param verbose: whether to print progress (one line per optimization step)
        :return: a list of test cases (the behavior text followed by the decoded optimized
                 tokens, one string per parallel run) and a list of logs (one dict per run
                 with the keys 'final_loss', 'all_losses' and 'all_test_cases')
        :raises ValueError: if self.num_steps is smaller than 1
        """
        num_steps = self.num_steps
        if num_steps < 1:
            # Without at least one step there is neither a loss nor a decoded
            # test case to return.
            raise ValueError("num_steps must be at least 1 to generate test cases")

        # ========== Behavior and Target str ==========
        # get required variables from behavior dictionary
        behavior_dict = behavior
        behavior = behavior_dict['Behavior']
        context_str = behavior_dict['ContextString']
        behavior_id = behavior_dict['BehaviorID']

        target = self.behavior_id_to_target[behavior_id]
        behavior += ' '
        # handle setup for behaviors with context strings
        if context_str:
            behavior = f"{context_str}\n\n---\n\n{behavior}"

        # PEZ hyperparams
        num_optim_tokens = self.num_optim_tokens
        lr = self.lr

        # Vars
        model = self.model
        tokenizer = self.tokenizer
        device = model.device
        embed_layer = model.get_input_embeddings()
        project_soft_embeds.embed_layer = embed_layer  # have to set this here to allow using different models

        # ========== Init Cache Embeds ========
        # before_tc is tokenized with special tokens (some tokenizers prepend <s>),
        # the remaining segments without them.
        cache_input_ids = tokenizer([self.before_tc], padding=False)['input_ids']
        cache_input_ids += tokenizer([behavior, self.after_tc, target],
                                     padding=False, add_special_tokens=False)['input_ids']
        # make tensors separately because the segments have different lengths
        cache_input_ids = [torch.tensor(ids, device=device).unsqueeze(0) for ids in cache_input_ids]
        before_ids, behavior_ids, after_ids, target_ids = cache_input_ids

        # Everything below is constant across optimization steps, so it is built
        # once. No graph is needed for it: backward() is restricted to
        # optim_embeds via `inputs=[optim_embeds]`.
        with torch.no_grad():
            before_embeds, behavior_embeds, after_embeds, target_embeds = [
                embed_layer(ids) for ids in cache_input_ids
            ]
            prefix_embeds = torch.cat([before_embeds, behavior_embeds], dim=1).repeat(num_generate, 1, 1)
            suffix_embeds = torch.cat([after_embeds, target_embeds], dim=1).repeat(num_generate, 1, 1)
        target_len = target_ids.shape[1]
        flat_labels = target_ids.repeat(num_generate, 1).view(-1)
        behavior_id_rows = behavior_ids.repeat(num_generate, 1)
        loss_fct = CrossEntropyLoss(reduction='none')

        # ========== setup optim_embeds, optimizer, and scheduler ========== #
        init_token_ids = torch.randint(embed_layer.weight.shape[0],
                                       (num_generate, num_optim_tokens), device=device)
        # convert to float to avoid NaN with Adam optimizer
        optim_embeds = torch.nn.Parameter(embed_layer(init_token_ids).float())
        optimizer = torch.optim.Adam([optim_embeds], lr=lr, weight_decay=0)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_steps)

        # ========== run optimization ========== #
        all_losses = []
        all_test_cases = []
        for step in range(num_steps):
            # ========== compute logits with concatenated optim embeds and target text ========== #
            optim_embeds_projected = project_soft_embeds.apply(optim_embeds.to(model.dtype))
            input_embeds = torch.cat([prefix_embeds, optim_embeds_projected, suffix_embeds], dim=1)
            logits = model(inputs_embeds=input_embeds).logits

            # ========== compute loss ========== #
            # Shift so that tokens < n predict n: the logits at positions
            # [target_start - 1, end - 1) are scored against the target tokens.
            target_start = input_embeds.shape[1] - target_len
            shift_logits = logits[..., target_start-1:-1, :].contiguous()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), flat_labels)
            # per-run loss: mean over the target tokens
            loss = loss.view(num_generate, -1).mean(dim=1)

            # ========== update optim_embeds ========== #
            optimizer.zero_grad()
            loss.mean().backward(inputs=[optim_embeds])
            optimizer.step()
            scheduler.step()

            # ========== retrieve optim_tokens and test_cases========== #
            # NOTE: the loss recorded for this step was computed before the
            # optimizer update above, whereas the test case decoded here
            # reflects the embeddings after the update.
            with torch.no_grad():
                _, nn_indices = nn_project(optim_embeds.to(model.dtype), embed_layer)

            nn_indices = torch.cat([behavior_id_rows, nn_indices], dim=1)
            test_cases = tokenizer.batch_decode(nn_indices, skip_special_tokens=True)

            # ========= Saving and logging test cases ========= #
            current_loss = loss.detach().cpu().tolist()
            all_test_cases.append(test_cases)
            all_losses.append(current_loss)

            if verbose:
                print(f"[PEZ] step {step + 1}/{num_steps} | mean loss: {loss.detach().mean().item():.4f}")

        logs = [{
                'final_loss': current_loss[run],
                'all_losses': [step_losses[run] for step_losses in all_losses],
                'all_test_cases': [step_cases[run] for step_cases in all_test_cases]
        } for run in range(num_generate)]
        return test_cases, logs

# ============================== UTILS FOR PEZ ============================== #
def nn_project(curr_embeds, embedding_layer):
    """
    Project every embedding vector onto its nearest row of the embedding matrix.

    "Nearest" means the highest dot product between the L2-normalized query
    vector and the L2-normalized rows of ``embedding_layer.weight``.

    :param curr_embeds: tensor of shape (batch_size, seq_len, emb_dim) holding the
                        embeddings to project
    :param embedding_layer: the embedding layer whose weight rows are the candidates
    :return: a tuple (projected_embeds, nn_indices); projected_embeds has shape
             (batch_size, seq_len, emb_dim) and holds the selected rows, nn_indices has
             shape (batch_size, seq_len) and holds their row indices. Both live on the
             device of the embedding matrix.
    """
    batch_size, seq_len, emb_dim = curr_embeds.shape
    embedding_matrix = embedding_layer.weight

    # Exact nearest-neighbour search by dot product between normalized vectors.
    # Everything stays on the embedding matrix's device, so this works on CPU
    # and on any accelerator index, not only on the default CUDA device.
    queries = curr_embeds.reshape((-1, emb_dim)).to(embedding_matrix.device)
    queries = torch.nn.functional.normalize(queries, p=2, dim=1)
    corpus = torch.nn.functional.normalize(embedding_matrix, p=2, dim=1)

    scores = torch.mm(queries, corpus.transpose(0, 1))  # (num_queries, vocab_size)
    nn_indices = scores.argmax(dim=1)  # only the single best match is needed

    projected_embeds = embedding_layer(nn_indices)
    projected_embeds = projected_embeds.reshape((batch_size, seq_len, emb_dim))
    nn_indices = nn_indices.reshape((batch_size, seq_len))

    return projected_embeds, nn_indices

class project_soft_embeds(torch.autograd.Function):
    """
    This is a PyTorch layer that projects the soft embeddings to the nearest
    hard embedding in the forward pass and passes the gradient through in the
    backward pass. This is a straight-through estimator.

    The embedding layer used for the projection is read from the class
    attribute ``embed_layer``, which must be assigned before ``apply`` is called.
    """

    embed_layer = None  # class variable to store the embedding layer

    @staticmethod
    def forward(ctx, input):
        """
        Return the nearest hard embeddings for the soft embeddings in ``input``.

        Nothing is stashed on ``ctx``: the backward pass is the identity and
        therefore does not need the input.

        :raises RuntimeError: if ``project_soft_embeds.embed_layer`` has not been set
        """
        if project_soft_embeds.embed_layer is None:
            raise RuntimeError("project_soft_embeds.embed_layer must be set before calling apply()")
        projected_embeds, _ = nn_project(input, project_soft_embeds.embed_layer)
        return projected_embeds

    @staticmethod
    def backward(ctx, grad_output):
        """
        Receive the gradient of the loss with respect to the output and return
        it unchanged as the gradient with respect to the input.
        """
        return grad_output  # straight-through estimator
