import json
import random
import time

import numpy as np
import torch
from sentence_transformers.util import semantic_search, dot_score, normalize_embeddings
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from ..model_utils import get_template
from ..baseline import SingleBehaviorRedTeamingMethod
from ..wandb_logger import WandBLogger
from baselines.ppl_filter import Filter

# Reproducibility: turn off cuDNN benchmark mode and request deterministic
# cuDNN kernels for the whole process.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


class PEZ(SingleBehaviorRedTeamingMethod):
    """Embedding-space prompt optimization with an optional text filter.

    For one behavior, ``num_optim_tokens`` continuous embeddings are optimized
    with Adam (cosine-annealed learning rate) so that the target model assigns
    high likelihood to a fixed target string.  After every backward pass the
    embeddings are projected onto their nearest vocabulary tokens with
    ``nn_project``; when ``adaptive`` is enabled, every candidate whose decoded
    text is rejected by the filter has its gradient zeroed for that step.
    """

    def __init__(self, target_model, adv_string_init="! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !",
                 num_optim_tokens=20,
                 num_steps=100,
                 lr=1e-3,
                 targets_path=None,
                 seed=0,
                 adaptive=True,
                 target_len=-1,
                 jb_filter_window_size=16,
                 jb_filter_metric_name="perplexity",
                 jb_filter_threshold=-3150,
                 **model_kwargs):
        """Configure the attack, load targets and (optionally) the filter.

        Args:
            target_model: Keyword mapping forwarded to the base class and to
                ``get_template`` to obtain the prompt template.
            adv_string_init: Stored and logged; the optimization in this class
                starts from random token embeddings and does not read it.
            num_optim_tokens: Number of embedding positions to optimize.
            num_steps: Number of optimizer steps per behavior.
            lr: Initial Adam learning rate.
            targets_path: Path to a JSON file mapping behavior ids to target
                strings.  Required.
            seed: Seed applied to ``random``, NumPy and torch.
            adaptive: When true, load the n-gram ``Filter`` and apply it to
                every candidate during optimization.
            target_len: Upper slice bound applied to the target string
                (``target[:target_len]``).
            jb_filter_window_size: Forwarded to ``Filter.apply_filter``.
            jb_filter_metric_name: Forwarded to ``Filter.apply_filter``.
            jb_filter_threshold: Forwarded to ``Filter.apply_filter``.
            **model_kwargs: Forwarded to the base class and added to the
                logger configuration.

        Raises:
            ValueError: If ``targets_path`` is ``None``.
        """
        super().__init__(target_model, **model_kwargs)
        self.adv_string_init = adv_string_init
        self.num_optim_tokens = num_optim_tokens
        self.num_steps = num_steps
        self.lr = lr
        self.count_filter_not_passed = 0

        if targets_path is None:
            raise ValueError("targets_path must be specified")
        with open(targets_path, 'r', encoding='utf-8') as file:
            self.behavior_id_to_target = json.load(file)

        ## Instruction template
        self.target_len = target_len
        template = get_template(**target_model)['prompt']
        self.template = template
        self.before_tc, self.after_tc = template.split("{instruction}")

        # Seed for reproducibility
        self.seed = seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Logging setup
        self.logger = WandBLogger({'num_steps': num_steps, 'adv_string_init': adv_string_init, **model_kwargs})
        # Counters reported with every log entry.  This method never runs
        # generation, so the two output counters are initialised and logged
        # but not incremented here.
        self.n_input_tokens = 0
        self.n_output_tokens = 0
        self.n_output_fwd_passes = 0
        self.n_fwd_passes = 0
        self.n_bcwd_passes = 0

        # Filter setup
        self.adaptive = adaptive

        if self.adaptive:
            base_path = 'ngrams_results_final/joined_final'

            self.jb_filter = Filter(
                unigram_path=base_path + '/arity=1/df_gutenberg_unigrams_dict_normalized_hashed.parquet',
                bigram_path=base_path + '/arity=2/df_gutenberg_bigrams_dict_normalized_hashed.parquet',
                mutation_count=1)
            self.jb_filter_window_size = jb_filter_window_size
            self.jb_filter_metric_name = jb_filter_metric_name
            self.jb_filter_threshold = jb_filter_threshold

            print('Loaded the filter')

    @staticmethod
    def _decode_ids(tokenizer, ids):
        """Decode one row of token ids, unwrapping a list result to a string."""
        text = tokenizer.decode(ids)
        if isinstance(text, list):
            text = text[0]
        return text

    def _passes_filter(self, test_case):
        """Return the configured filter's verdict for one decoded candidate."""
        return self.jb_filter.apply_filter(test_case,
                                           window_size=self.jb_filter_window_size,
                                           metric_name=self.jb_filter_metric_name,
                                           threshold=self.jb_filter_threshold,
                                           verbose=False)

    def generate_test_cases_single_behavior(self, behavior, num_generate=1, verbose=False):
        """Optimize adversarial suffix embeddings for a single behavior.

        Args:
            behavior: Mapping with keys ``"Behavior"`` and ``"BehaviorID"`` and
                an optional ``"ContextString"``.
            num_generate: Number of candidates optimized in parallel.
            verbose: When true, print the loss after every step.

        Returns:
            A dict with ``"loss"`` (loss of the last step as a float, or
            ``None`` when ``num_steps`` is 0) and ``"test_cases"`` (one decoded
            string per candidate: the behavior followed by the projected
            suffix tokens).

        Raises:
            ValueError: If the (sliced) target tokenizes to zero tokens.

        NOTE(review): the original ``return`` statement was truncated in the
        source; the shape of ``"test_cases"`` is a reconstruction - confirm
        against the caller.
        """
        model = self.model
        tokenizer = self.tokenizer
        device = model.device
        start_time = time.time()

        # ========== Behavior and Target str ==========
        behavior_dict = behavior
        behavior_id = behavior_dict["BehaviorID"]
        context_str = behavior_dict.get("ContextString", "")
        behavior = behavior_dict["Behavior"] + " "
        if context_str:
            behavior = f"{context_str}\n\n---\n\n{behavior}"

        # NOTE(review): with the default target_len=-1 this slice drops the
        # final character of the target.  If -1 is meant as "no truncation",
        # the slice should be skipped for negative values - confirm intent.
        target = self.behavior_id_to_target[behavior_id][:self.target_len]

        print(f"Behavior: {behavior_id} || Target: {target}")

        embed_layer = model.get_input_embeddings()

        # ========== Tokenize each prompt segment separately ==========
        tokenized_inputs = tokenizer([self.before_tc, behavior, self.after_tc, target], padding=False,
                                     add_special_tokens=False)
        # One (1, seq) tensor per segment; the segments have different lengths
        # so they cannot be stacked into a single tensor.
        before_ids, behavior_ids, after_ids, target_ids = [
            torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
            for ids in tokenized_inputs['input_ids']
        ]
        num_target_tokens = target_ids.shape[1]
        if num_target_tokens == 0:
            raise ValueError(f"Target for behavior {behavior_id} is empty after tokenization")

        # The fixed segments never receive gradients; embed them once and
        # broadcast them over the candidate batch.
        with torch.no_grad():
            before_embeds, behavior_embeds, after_embeds, target_embeds = [
                embed_layer(ids).repeat(num_generate, 1, 1)
                for ids in (before_ids, behavior_ids, after_ids, target_ids)
            ]
        batch_behavior_ids = behavior_ids.repeat(num_generate, 1)
        labels = target_ids.repeat(num_generate, 1)

        # ========== Optimized embeddings, optimizer and scheduler ==========
        init_ids = torch.randint(embed_layer.weight.shape[0], (num_generate, self.num_optim_tokens), device=device)
        # Kept in float32 for the optimizer; cast to the model dtype on use.
        optim_embeds = torch.nn.Parameter(embed_layer(init_ids).detach().float(), requires_grad=True)
        optimizer = Adam([optim_embeds], lr=self.lr)
        scheduler = CosineAnnealingLR(optimizer, self.num_steps)
        loss_fct = CrossEntropyLoss()

        final_loss = None
        for step in range(self.num_steps):
            # NOTE(review): the soft embeddings are fed to the model directly;
            # the straight-through projection `project_soft_embeds` defined in
            # this file is not applied here - confirm this is intended.
            input_embeds = torch.cat([before_embeds, behavior_embeds, optim_embeds.to(model.dtype),
                                      after_embeds, target_embeds], dim=1)
            logits = model(inputs_embeds=input_embeds).logits
            self.n_input_tokens += input_embeds.shape[0] * input_embeds.shape[1]
            self.n_fwd_passes += 1

            # Position t predicts token t + 1, so the target tokens are
            # predicted by the num_target_tokens positions before the last.
            shift_logits = logits[:, -num_target_tokens - 1:-1, :]
            loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), labels.reshape(-1))

            # ========== update optim_embeds ========== #
            optimizer.zero_grad()
            loss.backward(inputs=[optim_embeds])
            self.n_bcwd_passes += 1

            with torch.no_grad():
                nn_indices = nn_project(optim_embeds.to(model.dtype), embed_layer)[1].to(device)
                candidate_ids = torch.cat([batch_behavior_ids, nn_indices], dim=1)
                candidates = [self._decode_ids(tokenizer, ids) for ids in candidate_ids]

                if self.adaptive:
                    for idx, candidate in enumerate(candidates):
                        if not self._passes_filter(candidate):
                            self.count_filter_not_passed += 1
                            # Zero the gradient of the rejected candidate.
                            optim_embeds.grad[idx] = 0

            optimizer.step()
            scheduler.step()

            # ========== Logs Wandb ========== #
            final_loss = loss.item()
            elapsed = time.time() - start_time
            # Keys from the original log dict that nothing in this method
            # computes (is_refusal, completions, count_filter_not_passed_val)
            # are omitted instead of referencing undefined names.
            self.logger.log({
                "step": step,
                "behavior_id": behavior_id,
                "behavior": behavior,
                "loss": final_loss,
                "test_case": candidates[-1] if candidates else None,
                "count_filter_not_passed": self.count_filter_not_passed,
                "n_input_tokens": self.n_input_tokens,
                "n_output_tokens": self.n_output_tokens,
                "n_output_fwd_passes": self.n_output_fwd_passes,
                "n_fwd_passes": self.n_fwd_passes,
                "n_bcwd_passes": self.n_bcwd_passes,
                "time": elapsed,  # in seconds
                "running_mean_time": elapsed / (step + 1),
                "optim_ids": nn_indices.tolist(),
            })

            if verbose:
                print(f"Step {step + 1}: Loss = {final_loss}")

        # Final logging and cleanup
        self.logger.finish()

        # Token ids come from the nearest-neighbour projection of the final
        # embeddings (an argmax over embedding values is not a token id).
        with torch.no_grad():
            final_ids = nn_project(optim_embeds.to(model.dtype), embed_layer)[1].to(device)
            final_ids = torch.cat([batch_behavior_ids, final_ids], dim=1)
        test_cases = [self._decode_ids(tokenizer, ids) for ids in final_ids]

        return {"loss": final_loss, "test_cases": test_cases}

# ============================== UTILS FOR PEZ ============================== #
def nn_project(curr_embeds, embedding_layer):
    """Project every embedding onto its nearest vocabulary embedding.

    Nearness is cosine similarity: queries and the embedding matrix are both
    L2-normalized and compared with a dot product, i.e. an exact 1-nearest-
    neighbour search over the rows of ``embedding_layer.weight``.

    Args:
        curr_embeds: Tensor of shape ``(batch_size, seq_len, emb_dim)``.
        embedding_layer: Embedding module; its ``weight`` rows are the search
            corpus and it is called to look up the selected rows.

    Returns:
        A tuple ``(projected_embeds, nn_indices)`` where ``projected_embeds``
        has shape ``(batch_size, seq_len, emb_dim)`` and ``nn_indices`` has
        shape ``(batch_size, seq_len)`` and lives on the device of
        ``embedding_layer.weight``.
    """
    batch_size, seq_len, emb_dim = curr_embeds.shape
    embedding_matrix = embedding_layer.weight

    queries = torch.nn.functional.normalize(curr_embeds.reshape((-1, emb_dim)), p=2, dim=1)
    corpus = torch.nn.functional.normalize(embedding_matrix, p=2, dim=1)

    # Run the search on the embedding layer's device instead of assuming
    # 'cuda', so this also works on CPU and on a non-default GPU.
    scores = queries.to(corpus.device) @ corpus.T
    # Only the single best match per query is needed.
    nn_indices = scores.argmax(dim=1)

    projected_embeds = embedding_layer(nn_indices)
    projected_embeds = projected_embeds.reshape((batch_size, seq_len, emb_dim))
    nn_indices = nn_indices.reshape((batch_size, seq_len))

    return projected_embeds, nn_indices

class project_soft_embeds(torch.autograd.Function):
    """
    Straight-through estimator for nearest-token projection.

    The forward pass replaces each soft embedding with the nearest hard
    embedding (via ``nn_project``); the backward pass hands the incoming
    gradient back unchanged.

    ``embed_layer`` must be assigned on the class before ``apply`` is called.
    """

    embed_layer = None  # class variable to store the embedding layer

    @staticmethod
    def forward(ctx, input):
        """
        Return the nearest-vocabulary projection of ``input``.

        Nothing is stashed on ``ctx`` because the backward pass does not need
        the input tensor.

        Raises:
            RuntimeError: If ``project_soft_embeds.embed_layer`` is unset.
        """
        if project_soft_embeds.embed_layer is None:
            raise RuntimeError(
                "project_soft_embeds.embed_layer must be set before calling apply()"
            )
        projected_embeds, _ = nn_project(input, project_soft_embeds.embed_layer)
        return projected_embeds

    @staticmethod
    def backward(ctx, grad_output):
        """
        Pass the gradient with respect to the output through as the gradient
        with respect to the input.
        """
        return grad_output  # straight-through estimator
