import gc

import torch
import json
import time
import random
import numpy as np
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from sentence_transformers.util import semantic_search, dot_score, normalize_embeddings
from tqdm import tqdm

from ..model_utils import get_template
from ..baseline import SingleBehaviorRedTeamingMethod
from ..check_refusal_utils import check_refusal_completions
from ..wandb_logger import WandBLogger
from baselines.ppl_filter import Filter

# Reproducibility: disable the cuDNN autotuner (it may select different
# algorithms between runs) and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


class PEZ_adaptive(SingleBehaviorRedTeamingMethod):
    """PEZ attack with an optional n-gram ("adaptive") jailbreak filter.

    ``num_optim_tokens`` soft token embeddings are optimized with Adam while the
    forward pass uses their nearest hard-token projection (``project_soft_embeds``,
    a straight-through estimator). The optimized tokens follow the behavior text
    and are trained so the target model produces the behavior's target string.

    When ``adaptive`` is True, each candidate (behavior ids + projected tokens) is
    decoded and checked with ``baselines.ppl_filter.Filter``. Candidates that fail
    the filter have their gradients scaled by 0.1 before each optimizer step.
    """

    def __init__(self, target_model,
                 adv_string_init="! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !",
                 num_optim_tokens=20,
                 num_steps=100,
                 eval_steps=50,
                 use_prefix_cache=True,
                 lr=1e-3,
                 targets_path=None,
                 seed=0,
                 adaptive=True,
                 eval_with_check_refusal=False,
                 check_refusal_min_loss=0.1,
                 early_stopping=False,
                 early_stopping_min_loss=0.1,  # early stopping min loss criteria
                 target_len=-1,
                 jb_filter_window_size=16,
                 jb_filter_metric_name="perplexity",
                 jb_filter_threshold=-3150,
                 jb_filter_base_path='ngrams_results_final/joined_final',
                 **model_kwargs):
        """
        Args:
            target_model: model config; forwarded to the base class and unpacked
                into ``get_template`` to obtain the prompt template.
            adv_string_init: stored and logged only; the optimized tokens are
                initialized randomly in ``generate_test_cases_single_behavior``.
            num_optim_tokens: number of adversarial tokens to optimize.
            num_steps: optimization steps (also the cosine-annealing period).
            eval_steps: evaluate every ``eval_steps`` steps and at the last step.
            use_prefix_cache: precompute a KV cache of the prompt prefix.
            lr: Adam learning rate.
            targets_path: JSON file mapping BehaviorID -> target string (required).
            seed: seed for ``random``, NumPy and PyTorch (CPU and CUDA).
            adaptive: enable the n-gram jailbreak filter.
            eval_with_check_refusal: at eval steps, generate a completion for the
                best candidate and stop on the first non-refusal.
            check_refusal_min_loss: stored and logged; not used by the current
                eval condition.
            early_stopping: at eval steps, stop once the median loss is below
                ``early_stopping_min_loss``.
            target_len: the target string is sliced as ``target[:target_len]``.
            jb_filter_window_size, jb_filter_metric_name, jb_filter_threshold:
                forwarded to ``Filter.apply_filter``.
            jb_filter_base_path: directory holding the unigram/bigram parquet
                files used to build the filter.
            **model_kwargs: forwarded to the base class and logged.

        Raises:
            ValueError: if ``targets_path`` is None (checked before the base
                class is initialized, so we fail fast).
        """
        if targets_path is None:
            raise ValueError("targets_path must be specified")

        super().__init__(target_model, **model_kwargs)
        self.adv_string_init = adv_string_init
        self.num_optim_tokens = num_optim_tokens
        self.num_steps = num_steps
        self.lr = lr
        self.use_prefix_cache = use_prefix_cache
        self.count_filter_not_passed = 0

        with open(targets_path, 'r', encoding='utf-8') as file:
            self.behavior_id_to_target = json.load(file)

        ### Eval Vars ###
        self.eval_steps = eval_steps
        self.eval_with_check_refusal = eval_with_check_refusal
        self.check_refusal_min_loss = check_refusal_min_loss
        self.early_stopping = early_stopping
        self.early_stopping_min_loss = early_stopping_min_loss

        ## Instruction template
        # NOTE(review): with the default target_len=-1 the slice target[:-1]
        # drops the target's last character -- confirm this is intended.
        self.target_len = target_len
        template = get_template(**target_model)['prompt']
        self.template = template
        self.before_tc, self.after_tc = template.split("{instruction}")

        ## Hyper-parameters recorded in the W&B run config
        self.logging_dict = {
            "adaptive": adaptive,
            "num_steps": num_steps,
            "num_optim_tokens": num_optim_tokens,
            "lr": lr,
            "seed": seed,
            "target_len": target_len,
            "adv_string_init": adv_string_init,
            "use_prefix_cache": use_prefix_cache,
            "eval_steps": eval_steps,
            "eval_with_check_refusal": eval_with_check_refusal,
            "check_refusal_min_loss": check_refusal_min_loss,
            "early_stopping": early_stopping,
            "early_stopping_min_loss": early_stopping_min_loss,
            "jb_filter_window_size": jb_filter_window_size,
            "jb_filter_metric_name": jb_filter_metric_name,
            "jb_filter_threshold": jb_filter_threshold,
            "jb_filter_base_path": jb_filter_base_path,
            **model_kwargs,
        }

        ## Seed everything
        self.seed = seed
        print('Using the seed', self.seed)
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)

        print(f"DEVICE: {self.model.device}")

        ## Initialize the filter
        self.adaptive = adaptive
        if self.adaptive:
            self.jb_filter = Filter(
                unigram_path=f'{jb_filter_base_path}/arity=1/df_gutenberg_unigrams_dict_normalized_hashed.parquet',
                bigram_path=f'{jb_filter_base_path}/arity=2/df_gutenberg_bigrams_dict_normalized_hashed.parquet',
                mutation_count=1)
            self.jb_filter_window_size = jb_filter_window_size
            self.jb_filter_metric_name = jb_filter_metric_name
            self.jb_filter_threshold = jb_filter_threshold

            print('Loaded the filter')

    def _passes_filter(self, text, verbose=False):
        """Return ``self.jb_filter``'s verdict for ``text`` under the configured window/metric/threshold."""
        return self.jb_filter.apply_filter(text,
                                           window_size=self.jb_filter_window_size,
                                           metric_name=self.jb_filter_metric_name,
                                           threshold=self.jb_filter_threshold,
                                           verbose=verbose)

    def _filter_gradient_mask(self, optim_embeds, behavior_ids, embed_layer, dtype):
        """Build a gradient scale for ``optim_embeds``: 1.0 for passing candidates, 0.1 for failing ones.

        Must be called under ``torch.no_grad()``. Increments
        ``self.count_filter_not_passed`` once per failing candidate.
        """
        mask = torch.ones_like(optim_embeds, dtype=torch.float32)
        if not self.adaptive:
            # Without the filter every candidate passes, so the mask stays all ones.
            return mask

        nn_indices = nn_project(optim_embeds.to(dtype), embed_layer)[1]
        candidate_ids = torch.cat([behavior_ids.repeat(nn_indices.shape[0], 1), nn_indices], dim=1)

        n_passed = 0
        for idx, ids in enumerate(candidate_ids):
            text = self.tokenizer.decode(ids)
            if isinstance(text, list):
                text = text[0]

            if self._passes_filter(text):
                n_passed += 1
            else:
                self.count_filter_not_passed += 1
                # Down-weight (not zero out) the gradient of failing candidates.
                mask[idx] = 0.1

        if n_passed == 0:
            print('No embedding passed the filter!')
        return mask

    def generate_test_cases_single_behavior(self, behavior, num_generate=1, verbose=False):
        """Run the PEZ optimization for one behavior.

        Args:
            behavior: behavior record; the keys "Behavior", "ContextString" and
                "BehaviorID" are read. A non-empty ContextString is prepended to
                the behavior text.
            num_generate: number of candidates optimized in parallel.
            verbose: print the median loss after every step.

        Returns:
            ``([test_cases], logs)``. ``test_cases`` are the decoded candidates
            (behavior + optimized tokens) from the last executed step, wrapped in
            a one-element list (the shape this method has always returned).
            ``logs`` holds one dict per candidate with its final loss and its
            loss / test-case history across all executed steps.
        """
        self.logger = WandBLogger(
            args=self.logging_dict,
            project='harmbench.pez_attack.0th_iteration'
        )
        self.n_input_tokens = 0
        self.n_output_tokens = 0
        self.n_output_fwd_passes = 0
        self.n_fwd_passes = 0
        self.n_bcwd_passes = 0
        self.count_filter_not_passed = 0
        self.count_not_passed = 0  # logged; not updated by this class
        self.start_time = time.time()

        # ========== Behavior and Target str ==========
        behavior_dict = behavior
        behavior = behavior_dict["Behavior"]
        context_str = behavior_dict["ContextString"]
        behavior_id = behavior_dict["BehaviorID"]

        target = self.behavior_id_to_target[behavior_id][:self.target_len]
        behavior += " "
        if context_str:
            behavior = f"{context_str}\n\n---\n\n{behavior}"

        print(f"Behavior: {behavior_id} || Target: {target} || num_generate: {num_generate}")

        # Vars
        model = self.model
        tokenizer = self.tokenizer
        device = model.device
        before_tc = self.before_tc
        after_tc = self.after_tc
        embed_layer = self.model.get_input_embeddings()
        project_soft_embeds.embed_layer = embed_layer  # have to set this here to allow using different models

        # ========== Init Cache Embeds ========
        cache_input_ids = tokenizer([before_tc], padding=False)['input_ids']  # some tokenizer have <s> for before_tc
        cache_input_ids += tokenizer([behavior, after_tc, target], padding=False, add_special_tokens=False)['input_ids']
        # make tensors separately because we can't return_tensors='pt' for ragged lengths
        cache_input_ids = [torch.tensor(input_ids, device=device).unsqueeze(0) for input_ids in cache_input_ids]
        before_ids, behavior_ids, after_ids, target_ids = cache_input_ids
        before_embeds, behavior_embeds, after_embeds, target_embeds = [embed_layer(input_ids) for input_ids in
                                                                       cache_input_ids]

        # ========== setup optim_embeds, optimizer, and scheduler ========== #
        optim_embeds = torch.randint(embed_layer.weight.shape[0],
                                     (num_generate, self.num_optim_tokens), device=device)
        optim_embeds = embed_layer(optim_embeds).float()  # convert to float to avoid NaN with Adam optimizer
        optim_embeds = torch.nn.Parameter(optim_embeds)
        optim_embeds.requires_grad_()
        optimizer = torch.optim.Adam([optim_embeds], lr=self.lr, weight_decay=0)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.num_steps)

        if self.use_prefix_cache:
            # precompute KV cache for everything before the optimized tokens
            # NOTE(review): self.prefix_cache is not read by the loop below, which
            # re-runs the full sequence every step.
            input_embeds = torch.cat([before_embeds, behavior_embeds], dim=1)
            self.n_input_tokens += input_embeds.shape[1]
            self.n_fwd_passes += 1
            with torch.no_grad():
                outputs = model(inputs_embeds=input_embeds, use_cache=True)
                self.prefix_cache = outputs.past_key_values

        # Per-step history (kept across steps; used for the returned logs).
        all_test_cases = []
        all_losses = []
        # Defaults so the final logs are well-defined even when num_steps == 0.
        test_cases = []
        current_loss = []

        try:
            for i in tqdm(range(self.num_steps), desc="Attack steps"):

                # ========== compute logits with concatenated optim embeds and target text ========== #
                optim_embeds_projected = project_soft_embeds.apply(optim_embeds.to(model.dtype))

                input_embeds = torch.cat([before_embeds.repeat(num_generate, 1, 1),
                                          behavior_embeds.repeat(num_generate, 1, 1),
                                          optim_embeds_projected,
                                          after_embeds.repeat(num_generate, 1, 1),
                                          target_embeds.repeat(num_generate, 1, 1)], dim=1)

                outputs = model(inputs_embeds=input_embeds)
                self.n_input_tokens += input_embeds.shape[1]
                self.n_fwd_passes += 1
                logits = outputs.logits

                # ========== compute loss ========== #
                # Shift so that tokens < n predict n
                tmp = input_embeds.shape[1] - target_embeds.shape[1]
                shift_logits = logits[..., tmp - 1:-1, :].contiguous()
                shift_labels = target_ids.repeat(num_generate, 1)
                loss_fct = CrossEntropyLoss(reduction='none')
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                loss = loss.view(num_generate, -1).mean(dim=1)  # per-candidate mean over target tokens

                # ========== update optim_embeds ========== #
                optimizer.zero_grad()
                loss.mean().backward(inputs=[optim_embeds])
                self.n_bcwd_passes += 1

                with torch.no_grad():
                    # Scale down gradients of candidates rejected by the filter.
                    optim_embeds.grad *= self._filter_gradient_mask(
                        optim_embeds, behavior_ids, embed_layer, model.dtype)

                optimizer.step()
                scheduler.step()

                # ========== retrieve optim_tokens and test_cases ========== #
                with torch.no_grad():
                    _, nn_indices = nn_project(optim_embeds.to(model.dtype), embed_layer)

                nn_indices = torch.cat([behavior_ids.repeat(num_generate, 1), nn_indices], dim=1)
                test_cases = tokenizer.batch_decode(nn_indices, skip_special_tokens=True)

                # ========= Saving and logging test cases ========= #
                current_loss = loss.detach().cpu().tolist()
                current_loss_median = np.median(current_loss)
                all_test_cases.append(test_cases)
                all_losses.append(current_loss)

                test_cases_best = test_cases[int(loss.detach().argmin())]

                # ========== Logs Wandb ========== #
                self.logger.log({
                    "loss": current_loss,
                    "test_cases": test_cases,
                    "test_case_best": test_cases_best,
                    "count_not_passed": self.count_not_passed,
                    "count_filter_not_passed": self.count_filter_not_passed,
                    "n_input_tokens": self.n_input_tokens,
                    "optim_ids": nn_indices,
                    "n_fwd_passes": self.n_fwd_passes,
                    "n_bcwd_passes": self.n_bcwd_passes,
                })

                # ========== Eval and Early Stopping ========== #
                if (i % self.eval_steps == 0) or (i == self.num_steps - 1):

                    p_output = f"\n===>Step {i}\n===>Test Case: {test_cases_best}\n===>Loss median: {current_loss_median}"

                    if self.adaptive:
                        print('Checking the filter', self._passes_filter(test_cases_best, verbose=True))

                    if self.eval_with_check_refusal:
                        input_str = self.template.format(instruction=test_cases_best)
                        is_refusal, completions, _ = check_refusal_completions(
                            model, tokenizer, inputs=[input_str]
                        )
                        # NOTE(review): this adds the prompt's token count, not the completion's.
                        self.n_output_tokens += len(tokenizer(input_str)["input_ids"])
                        self.n_output_fwd_passes += 1

                        elapsed = time.time() - self.start_time  # seconds since this behavior started
                        self.logger.log(
                            {"behavior_id": behavior_id,
                             "behavior": behavior,
                             "test_case_val": test_cases,
                             "loss_val": current_loss,
                             "n_input_tokens_val": self.n_input_tokens,
                             "n_output_tokens": self.n_output_tokens,
                             "n_output_fwd_passes": self.n_output_fwd_passes,
                             "n_fwd_passes": self.n_fwd_passes,
                             "n_bcwd_passes": self.n_bcwd_passes,
                             "time": elapsed,  # in seconds
                             "running_mean_time": elapsed / (i + 1),
                             "optim_ids": nn_indices,
                             "count_filter_not_passed": self.count_filter_not_passed,
                             "is_refusal": is_refusal,
                             "completions": completions
                             }
                        )

                        p_output += f"\n\n===>Completion: {completions[0]}"
                        print(p_output)

                        # Stop as soon as the model no longer refuses.
                        if not is_refusal[0]:
                            break

                    # Early stopping applies whether or not the refusal check is enabled.
                    if self.early_stopping and current_loss_median < self.early_stopping_min_loss:
                        print(f"Early stopping at step {i} with loss {current_loss_median}")
                        print()
                        break

                del input_embeds, logits, shift_logits, loss
                torch.cuda.empty_cache()
                gc.collect()

                if verbose:
                    print(f"Step {i + 1}: Loss = {current_loss_median}")
        finally:
            # Close the W&B run even if the optimization raises.
            self.logger.finish()

        logs = [{
            'final_loss': current_loss[g] if g < len(current_loss) else None,
            'all_losses': [step_losses[g] for step_losses in all_losses],
            'all_test_cases': [step_cases[g] for step_cases in all_test_cases],
        } for g in range(num_generate)]
        return [test_cases], logs


# ============================== UTILS FOR PEZ ============================== #

def n_gram_penalty(sequence, n_gram_counts, total_n_grams, vocab, threshold):
    """Return a bigram-rarity penalty for ``sequence``.

    For each adjacent pair ``(a, b)`` of ``sequence``, the penalty adds
    ``-log(n_gram_counts.get((a, b), 0) / total_n_grams + 1e-10)``. The summed
    penalty is returned only when its per-element average (sum / len(sequence))
    is below ``threshold``; otherwise ``0.0`` is returned.

    Args:
        sequence: items to score. This may be a string (scored character by
            character), a list, or a tensor. Tensor elements are converted to
            Python scalars so they can be used as dict keys; raw 0-d tensors
            hash by identity and would never match.
        n_gram_counts: mapping from ``(a, b)`` tuples to counts.
        total_n_grams: normaliser for the counts.
        vocab: unused; kept for interface compatibility.
        threshold: cut-off on the per-element average penalty.

    Returns:
        A 0-d float32 tensor, placed on ``sequence.device`` when the input has
        one (CPU otherwise). Returns ``0.0`` for sequences shorter than 2 or when
        the average is not below ``threshold``.
    """
    # Strings/lists have no `.device`; only tensors carry one.
    device = getattr(sequence, "device", None)
    items = sequence.tolist() if hasattr(sequence, "tolist") else list(sequence)
    if len(items) < 2:
        # No bigrams: avoids division by zero for empty input.
        return 0.0

    probs = torch.tensor(
        [n_gram_counts.get((a, b), 0) / total_n_grams + 1e-10 for a, b in zip(items, items[1:])],
        device=device,
    )
    penalty = -torch.log(probs).sum()
    return penalty if penalty / len(items) < threshold else 0.0


def compute_loss_and_penalty(tokenizer, input_embeds, target_embeds, target_ids, model, n_gram_counts, total_n_grams, vocab,
                             threshold, num_generate):
    """Return the per-candidate target cross-entropy plus the mean n-gram penalty.

    The model is run on ``input_embeds``. The logits at the positions preceding
    the final ``target_embeds.shape[1]`` tokens are scored against ``target_ids``
    (repeated ``num_generate`` times) and averaged per candidate. Then
    ``n_gram_penalty`` is evaluated on each decoded row of ``target_ids``, and
    the mean of those penalties is added to every candidate's loss.

    Returns:
        Tensor of shape (num_generate,): cross-entropy + mean penalty.

    NOTE(review): the penalty is computed from ``target_ids``, not from the
    optimized tokens. It is also built with ``torch.tensor``, so it carries no
    gradient and only shifts the loss by a constant. Confirm that
    ``n_gram_penalty`` accepts a decoded string. This function is not referenced
    in the visible part of this file.
    """
    logits = model(inputs_embeds=input_embeds).logits

    # Positions [prefix_len - 1, seq_len - 1) predict the target tokens.
    prefix_len = input_embeds.shape[1] - target_embeds.shape[1]
    target_logits = logits[..., prefix_len - 1:-1, :].contiguous()
    labels = target_ids.repeat(num_generate, 1)
    vocab_size = target_logits.size(-1)
    token_losses = CrossEntropyLoss(reduction='none')(target_logits.view(-1, vocab_size), labels.view(-1))
    per_candidate_ce = token_losses.view(num_generate, -1).mean(dim=1)

    penalties = []
    for row in target_ids:
        decoded = tokenizer.decode(row)
        penalties.append(n_gram_penalty(decoded, n_gram_counts, total_n_grams, vocab, threshold))
    penalty_tensor = torch.tensor(penalties, device=target_ids.device)

    return per_candidate_ce + penalty_tensor.mean()

def nn_project(curr_embeds, embedding_layer):
    """Snap each embedding to its nearest vocabulary embedding by cosine similarity.

    Args:
        curr_embeds: 3-D tensor of shape (batch_size, seq_len, emb_dim).
        embedding_layer: callable embedding module with a ``weight`` matrix whose
            rows are the vocabulary embeddings (e.g. the model's input embeddings).

    Returns:
        ``(projected_embeds, nn_indices)``. ``projected_embeds`` has the same
        shape as ``curr_embeds`` and holds the nearest token embeddings.
        ``nn_indices`` has shape (batch_size, seq_len) and holds their token ids.
    """
    batch_size, seq_len, emb_dim = curr_embeds.shape

    # Dot product of L2-normalised vectors == cosine similarity, so this is an
    # exact cosine kNN between the query embeddings and the vocabulary.
    queries = normalize_embeddings(curr_embeds.reshape((-1, emb_dim)))
    corpus = normalize_embeddings(embedding_layer.weight)

    hits = semantic_search(queries, corpus,
                           query_chunk_size=queries.shape[0],
                           top_k=3,
                           score_function=dot_score)

    # Build the ids on the embedding table's device (previously hard-coded to
    # 'cuda', which fails on CPU or when the model lives on another GPU).
    nn_indices = torch.tensor([hit[0]["corpus_id"] for hit in hits],
                              device=embedding_layer.weight.device)
    projected_embeds = embedding_layer(nn_indices).reshape((batch_size, seq_len, emb_dim))
    nn_indices = nn_indices.reshape((batch_size, seq_len))

    return projected_embeds, nn_indices

class project_soft_embeds(torch.autograd.Function):
    """Straight-through projection of soft embeddings onto hard token embeddings.

    Forward: each input embedding is replaced by its nearest vocabulary
    embedding, using ``nn_project`` against ``project_soft_embeds.embed_layer``.
    Backward: the incoming gradient is passed through unchanged (a
    straight-through estimator), so the soft embeddings receive the gradient
    computed at the projected point.

    ``embed_layer`` must be assigned on the class before calling ``apply``.
    """

    embed_layer = None  # class-level embedding layer to project onto; set by the caller before use

    @staticmethod
    def forward(ctx, soft_embeds):
        """Return the nearest-token projection of ``soft_embeds``.

        Nothing is saved on ``ctx``: the straight-through backward does not need
        the input, and saving it would only keep the tensor alive until backward.
        """
        embed_layer = project_soft_embeds.embed_layer
        if embed_layer is None:
            raise RuntimeError("project_soft_embeds.embed_layer must be set before calling apply()")
        projected_embeds, _ = nn_project(soft_embeds, embed_layer)
        return projected_embeds

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the gradient straight through to the soft embeddings."""
        return grad_output
