import gc
import json
import os
import random
import time

import numpy as np
import torch

# https://huggingface.co/docs/accelerate/v0.11.0/en/memory#accelerate.find_executable_batch_size
from accelerate.utils import find_executable_batch_size
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

from baselines.ppl_filter import Filter

from ..baseline import SingleBehaviorRedTeamingMethod
from ..check_refusal_utils import check_refusal_completions
from ..model_utils import get_template
from ..wandb_logger import WandBLogger
from .gcg_utils import get_nonascii_toks, sample_control

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# ============================== GCG CLASS DEFINITION ============================== #
class GCG_adaptive(SingleBehaviorRedTeamingMethod):
    def __init__(
            self,
            target_model,
            targets_path,
            num_steps=50,
            adv_string_init="! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !",
            allow_non_ascii=False,
            search_width=512,
            use_prefix_cache=True,
            eval_steps=10,
            eval_with_check_refusal=False,
            check_refusal_min_loss=0.1,
            early_stopping=False,
            early_stopping_min_loss=0.1,  # early stopping min loss criteria
            starting_search_batch_size=None,  # preset search_batch_size which will auto go half each time go OOM
            adaptive=True,
            target_len=-1,
            jb_filter_window_size=16,
            jb_filter_metric_name="perplexity",
            jb_filter_threshold=-3150,
            seed=0,
            **model_kwargs,
    ):
        """
        :param num_steps: the number of optimization steps to use
        :param adv_string_init: the initial adversarial string to start with
        :param allow_non_ascii: whether to allow non-ascii characters
        :param search_width: the number of candidates to sample at each step ("search_width" in the original GCG code)
        :param early_stopping: whether to stop early if the loss is less than early_stopping_min_loss
        """
        super().__init__(
            target_model=target_model, **model_kwargs
        )  # assigns self.model, self.tokenizer, and self.model_name_or_path

        # Increase number of steps for adaptive attack, given that it filters out a lot of candidates
        # if adaptive:
        #    num_steps *= 3

        self.num_steps = num_steps
        self.adv_string_init = adv_string_init
        self.allow_non_ascii = allow_non_ascii
        self.search_width = search_width
        self.use_prefix_cache = use_prefix_cache

        with open(targets_path, "r", encoding="utf-8") as file:
            self.behavior_id_to_target = json.load(file)

        ### Eval Vars ###
        self.eval_steps = eval_steps
        self.eval_with_check_refusal = eval_with_check_refusal
        self.check_refusal_min_loss = check_refusal_min_loss
        self.early_stopping = early_stopping
        self.early_stopping_min_loss = early_stopping_min_loss
        self.starting_search_batch_size = starting_search_batch_size

        ## Instruction template
        self.target_len = target_len
        template = get_template(**target_model)["prompt"]
        self.template = template
        self.before_tc, self.after_tc = template.split("{instruction}")

        if "llama-3" in self.model_name_or_path.lower():
            self.idx_to_val = {
                v: k.replace("Ġ", " ") for k, v in self.tokenizer.get_vocab().items()
            }
        else:
            self.idx_to_val = self.tokenizer.get_vocab().items()

        ## Initialize the filter
        self.adaptive = adaptive
        if self.adaptive:
            base_path = 'ngrams_results_final/joined_final'

            self.jb_filter = Filter(
                unigram_path=base_path + '/arity=1/df_gutenberg_unigrams_dict_normalized_hashed.parquet',
                bigram_path=base_path + '/arity=2/df_gutenberg_bigrams_dict_normalized_hashed.parquet',
                mutation_count=1)
            self.jb_filter_window_size = jb_filter_window_size
            self.jb_filter_metric_name = jb_filter_metric_name
            self.jb_filter_threshold = jb_filter_threshold

            print('Loaded the filter')

        ## seed everything
        self.logging_dict = {
            "adaptive": adaptive,
            "num_steps": num_steps,
            "adv_string_init": adv_string_init,
            "allow_non_ascii": allow_non_ascii,
            "search_width": search_width,
            "use_prefix_cache": use_prefix_cache,
            "eval_steps": eval_steps,
            "eval_with_check_refusal": eval_with_check_refusal,
            "check_refusal_min_loss": check_refusal_min_loss,
            "early_stopping": early_stopping,
            "early_stopping_min_loss": early_stopping_min_loss,
            "starting_search_batch_size": starting_search_batch_size,
            "jb_filter_window_size": jb_filter_window_size,
            "jb_filter_metric_name": jb_filter_metric_name,
            "jb_filter_threshold": jb_filter_threshold,
            "template": self.template,
            **model_kwargs,
        }
        self.seed = seed

        print('Using the seed', self.seed)
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)

        print(f"DEVICE: {self.model.device}")

    def generate_test_cases_single_behavior(self, behavior, verbose=False, **kwargs):
        """
        Generate a test cases for the behavior

        :param behavior: the behavior to generate test cases for
        :param verbose: whether to print progress
        """
        ######### Tokens counting ##########
        self.n_fwd_target_model_input_tokens = 0
        self.n_fwd_target_model_cache_tokens = 0

        self.n_bcwd_target_model_input_tokens = 0 # tokens for which input we called backward pass
        self.n_bcwd_target_model_optim_tokens = 0 # subset of tokens in n_bcwd_target_model_input_tokens
        self.n_bcwd_target_model_cache_tokens = 0 # same, but cache

        self.n_target_model_output_tokens = 0

        self.n_target_model_fwd_passes = 0
        self.n_target_model_bcwd_passes = 0

        # only counts evaluation passes
        self.n_fwd_target_model_input_tokens_eval = 0
        self.n_fwd_target_model_cache_tokens_eval = 0

        self.n_target_model_output_tokens_eval = 0

        self.n_target_model_fwd_passes_eval = 0

        self.count_filter_not_passed = 0
        self.count_not_passed = 0
        self.start_time = time.time()

        # starting search_batch_size, will auto reduce batch_size later if go OOM (resets for every new behavior)
        self.search_batch_size = (
            self.starting_search_batch_size
            if self.starting_search_batch_size
            else self.search_width
        )

        # ========== Behavior and Target str ==========
        behavior_dict = behavior
        behavior = behavior_dict["Behavior"]
        context_str = behavior_dict["ContextString"]
        behavior_id = behavior_dict["BehaviorID"]

        self.target = self.behavior_id_to_target[behavior_id][:self.target_len]
        behavior += " "
        if context_str:
            behavior = f"{context_str}\n\n---\n\n{behavior}"

        self.logging_dict.update(
            {
                "behavior": behavior,
                "behavior_id": behavior_id,
                "target": self.target,
            }
        )
        adaptive_str = "_adaptive" if self.adaptive else ""
        self.logger = WandBLogger(
            args=self.logging_dict,
            project=f'harmbench.gcg_attack.6th_iteration{adaptive_str}'
        )

        print(f"Behavior: {behavior_id} || Target: {self.target}")
        ### Targeted Model and Tokenier ###
        model = self.model
        device = model.device
        tokenizer = self.tokenizer
        
        ### GCG hyperparams ###
        num_steps = self.num_steps
        adv_string_init = self.adv_string_init
        allow_non_ascii = self.allow_non_ascii
        search_width = self.search_width
        before_tc = self.before_tc
        after_tc = self.after_tc

        ### Eval vars ###
        eval_steps = self.eval_steps
        eval_with_check_refusal = self.eval_with_check_refusal
        check_refusal_min_loss = self.check_refusal_min_loss
        early_stopping = self.early_stopping
        early_stopping_min_loss = self.early_stopping_min_loss

        embed_layer = model.get_input_embeddings()
        vocab_size = embed_layer.weight.shape[
            0
        ]  # can be larger than tokenizer.vocab_size for some models
        vocab_embeds = embed_layer(torch.arange(0, vocab_size).long().to(model.device))

        not_allowed_tokens = None if allow_non_ascii else get_nonascii_toks(tokenizer)
        # remove redundant tokens
        not_allowed_tokens = torch.unique(not_allowed_tokens)

        # ========== setup optim_ids and cached components ========== #
        optim_ids = tokenizer(
            adv_string_init, return_tensors="pt", add_special_tokens=False
        ).to(device)["input_ids"]
        num_optim_tokens = len(optim_ids[0])

        cache_input_ids = tokenizer([before_tc], padding=False)[
            "input_ids"
        ]  # some tokenizer have <s> for before_tc
        cache_input_ids += tokenizer(
            [behavior, after_tc, self.target], padding=False, add_special_tokens=False
        )["input_ids"]
        cache_input_ids = [
            torch.tensor(input_ids, device=device).unsqueeze(0)
            for input_ids in cache_input_ids
        ]  # make tensor separately because can't return_tensors='pt' in tokenizer
        before_ids, behavior_ids, after_ids, target_ids = cache_input_ids
        before_embeds, behavior_embeds, after_embeds, target_embeds = [
            embed_layer(input_ids) for input_ids in cache_input_ids
        ]

        if self.use_prefix_cache:
            # precompute KV cache for everything before the optimized tokens
            input_embeds = torch.cat([before_embeds, behavior_embeds], dim=1)
            with torch.no_grad():
                outputs = model(inputs_embeds=input_embeds, use_cache=True)
                self.prefix_cache = outputs.past_key_values
                self.cache_size = input_embeds.shape[1]
                self.n_target_model_output_tokens += target_ids.shape[1]
            self.n_fwd_target_model_input_tokens += input_embeds.shape[1]
            self.n_target_model_fwd_passes += 1

        # ========== run optimization ========== #
        all_losses = []
        all_test_cases = []
        used_tokens = set()

        for i in tqdm(range(num_steps), desc="Attack steps"):
            # ========== compute coordinate token_gradient ========== #
            # create input
            optim_ids_onehot = torch.zeros(
                (1, num_optim_tokens, vocab_size), device=device, dtype=model.dtype
            )
            optim_ids_onehot.scatter_(2, optim_ids.unsqueeze(2), 1.0).requires_grad_()
            optim_embeds = torch.matmul(
                optim_ids_onehot.squeeze(0), vocab_embeds
            ).unsqueeze(0)

            # forward pass
            if self.use_prefix_cache:
                input_embeds = torch.cat(
                    [optim_embeds, after_embeds, target_embeds], dim=1
                )
                outputs = model(
                    inputs_embeds=input_embeds, past_key_values=self.prefix_cache
                )
                self.n_fwd_target_model_input_tokens += input_embeds.shape[1]
                self.n_fwd_target_model_cache_tokens += self.cache_size
                self.n_target_model_fwd_passes += 1
                self.n_target_model_output_tokens += target_embeds.shape[1]
            else:
                input_embeds = torch.cat(
                    [
                        before_embeds,
                        behavior_embeds,
                        optim_embeds,
                        after_embeds,
                        target_embeds,
                    ],
                    dim=1,
                )
                outputs = model(inputs_embeds=input_embeds)
                self.n_fwd_target_model_input_tokens += input_embeds.shape[1]
                self.n_target_model_fwd_passes += 1
                self.n_target_model_output_tokens += target_embeds.shape[1]

            logits = outputs.logits

            # compute loss
            # Shift so that tokens < n predict n
            tmp = input_embeds.shape[1] - target_embeds.shape[1]
            shift_logits = logits[..., tmp - 1: -1, :].contiguous()
            shift_labels = target_ids
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

            token_grad = torch.autograd.grad(outputs=[loss], inputs=[optim_ids_onehot])[
                0
            ]

            self.n_target_model_bcwd_passes += 1
            self.n_bcwd_target_model_optim_tokens += optim_ids_onehot.shape[1]
            if self.use_prefix_cache:
                self.n_bcwd_target_model_cache_tokens += self.cache_size
            self.n_bcwd_target_model_input_tokens += input_embeds.shape[1]

            # ========== Sample a batch of new tokens based on the coordinate gradient. ========== #
            new_sampled_top_indices = []
            counter_while = 1
            count = 0
            # Initialize a set to keep track of unique tokens
            unique_tokens = set()

            while len(new_sampled_top_indices) <= 512 and (counter_while <= 10 or len(new_sampled_top_indices) < 1):
                print('trial', counter_while, len(new_sampled_top_indices), count)
                counter_while += 1
                sampled_top_indices = sample_control(
                    optim_ids.squeeze(0),
                    token_grad.squeeze(0),
                    search_width,
                    topk=256,
                    temp=1,
                    not_allowed_tokens=not_allowed_tokens,
                    seed=self.seed + i
                )

                if len(sampled_top_indices) > 0:
                    # ========= Filter sampled_top_indices that aren't the same after detokenize->retokenize ========= #
                    if "llama-3" in self.model_name_or_path.lower():
                        sampled_top_indices_text = [
                            "".join([self.idx_to_val[i.item()] for i in sampled_top_indices[j]])
                            for j in range(len(sampled_top_indices))
                        ]
                    else:
                        sampled_top_indices_text = tokenizer.batch_decode(sampled_top_indices)

                    count_filter_not_passed = 0

                    for j in range(len(sampled_top_indices_text)):
                        # tokenize again
                        tmp = tokenizer(
                            sampled_top_indices_text[j],
                            return_tensors="pt",
                            add_special_tokens=False,
                        ).to(device)["input_ids"][0]
                        # if the tokenized text is different, then set the sampled_top_indices to padding_top_indices
                        if not torch.equal(tmp, sampled_top_indices[j]):
                            count += 1

                            print('did not pass the retokenization', count_filter_not_passed, count, sampled_top_indices_text[j])
                            if verbose:
                                print('retokenized:', tmp, sampled_top_indices_text[j])
                                print('original:', sampled_top_indices[j])
                            continue
                        else:
                            # print('Checking text', sampled_top_indices_text[j])

                            passed_filter = True

                            if self.adaptive:
                                ids_check = torch.cat([behavior_ids, sampled_top_indices[j].unsqueeze(0)],
                                                      dim=1).squeeze()
                                test_case = tokenizer.decode(ids_check)

                                if type(test_case) == list:
                                    test_case = test_case[0]

                                passed_filter, metrics = self.jb_filter.apply_filter(test_case,
                                                                            window_size=self.jb_filter_window_size,
                                                                            metric_name=self.jb_filter_metric_name,
                                                                            threshold=self.jb_filter_threshold,
                                                                            return_metrics=True,
                                                                            verbose=False)

                            if passed_filter:
                                token_tuple = tuple(sampled_top_indices[j].tolist())

                                if token_tuple not in unique_tokens and token_tuple not in used_tokens:
                                    unique_tokens.add(token_tuple)
                                    new_sampled_top_indices.append(sampled_top_indices[j])
                                else:
                                    count_filter_not_passed += 1
                                    # print(f'Duplicate token sequence ignored: {sampled_top_indices[j]}', unique_tokens)
                            else:
                                print('did not pass the filter', count_filter_not_passed, count, sampled_top_indices_text[j], metrics)
                                count_filter_not_passed += 1

            sampled_top_indices = torch.stack(new_sampled_top_indices[:512])
            if verbose:
                print('shape of top indices', sampled_top_indices.shape)
            if count + count_filter_not_passed >= search_width // 2:
                print("\nLots of removals:", count, count_filter_not_passed)

            ##new_search_width = search_width - (count+count_filter_not_passed)
            new_search_width = sampled_top_indices.shape[0]  # search_width

            self.count_not_passed += count
            self.count_filter_not_passed += count_filter_not_passed

            # ========== Compute loss on these candidates and take the argmin. ========== #
            # create input
            sampled_top_embeds = embed_layer(sampled_top_indices)
            if self.use_prefix_cache:
                input_embeds = torch.cat(
                    [
                        sampled_top_embeds,
                        after_embeds.repeat(new_search_width, 1, 1),
                        target_embeds.repeat(new_search_width, 1, 1),
                    ],
                    dim=1,
                )
            else:
                input_embeds = torch.cat(
                    [
                        before_embeds.repeat(new_search_width, 1, 1),
                        behavior_embeds.repeat(new_search_width, 1, 1),
                        sampled_top_embeds,
                        after_embeds.repeat(new_search_width, 1, 1),
                        target_embeds.repeat(new_search_width, 1, 1),
                    ],
                    dim=1,
                )

            # Auto Find Batch Size for foward candidates (each time go OOM will decay search_batch_size // 2)
            loss = find_executable_batch_size(
                self.compute_candidates_loss, self.search_batch_size
            )(input_embeds, target_ids)

            # ========== Update the optim_ids with the best candidate ========== #
            optim_ids = sampled_top_indices[loss.argmin()].unsqueeze(0)
            used_tokens.add(tuple(optim_ids.squeeze().tolist()))

            test_case_ids = torch.cat([behavior_ids, optim_ids], dim=1)
            test_case = tokenizer.decode(test_case_ids[0])
            all_test_cases.append(test_case)
            current_loss = loss.min().item()
            all_losses.append(current_loss)

            # ========== Logs Wandb ========== #
            logs_wandb = {
                "loss": current_loss,
                "test_case": test_case,
                "count_not_passed": self.count_not_passed,
                "count_filter_not_passed": self.count_filter_not_passed,
                "optim_ids": optim_ids,
                "time": (time.time() - self.start_time),  # in seconds
                "running_mean_time": (time.time() - self.start_time) / (i + 1),
                "n_fwd_target_model_input_tokens": self.n_fwd_target_model_input_tokens,
                "n_fwd_target_model_cache_tokens": self.n_fwd_target_model_cache_tokens,
                "n_target_model_output_tokens": self.n_target_model_output_tokens,
                "n_target_model_fwd_passes": self.n_target_model_fwd_passes,
                "n_target_model_bcwd_passes": self.n_target_model_bcwd_passes,
                "n_bcwd_target_model_input_tokens": self.n_bcwd_target_model_input_tokens,
                "n_bcwd_target_model_optim_tokens": self.n_bcwd_target_model_optim_tokens,
                "n_bcwd_target_model_cache_tokens": self.n_bcwd_target_model_cache_tokens,
                "n_fwd_target_model_input_tokens_eval": self.n_fwd_target_model_input_tokens_eval,
                "n_fwd_target_model_cache_tokens_eval": self.n_fwd_target_model_cache_tokens_eval,
                "n_target_model_output_tokens_eval": self.n_target_model_output_tokens_eval,
                "n_target_model_fwd_passes_eval": self.n_target_model_fwd_passes_eval,
            }

            self.logger.log(logs_wandb)

            # ========== Eval and Early Stopping ========== #
            if (i % eval_steps == 0) or (i == num_steps - 1):

                p_output = f"\n===>Step {i}\n===>Test Case: {test_case}\n===>Loss: {current_loss}"

                if self.adaptive:
                    print('Checking the filter', self.jb_filter.apply_filter(test_case,
                                                                             window_size=self.jb_filter_window_size,
                                                                             metric_name=self.jb_filter_metric_name,
                                                                             threshold=self.jb_filter_threshold,
                                                                             verbose=True))

                ###if : # and current_loss < check_refusal_min_loss:
                input_str = self.template.format(instruction=test_case)
                is_refusal, completions, _ = check_refusal_completions(
                    model, tokenizer, inputs=[input_str]
                )
                self.n_target_model_output_tokens_eval += len(tokenizer(completions[0])["input_ids"])
                self.n_fwd_target_model_input_tokens_eval += len(tokenizer(input_str)["input_ids"])
                self.n_target_model_fwd_passes_eval += 1

                p_output += f"\n\n===>Completion: {completions[0]}"
                self.logger.log({"is_refusal": is_refusal[0], "completions": completions[0]})
                print(p_output)

                # Refusal check
                if not is_refusal[0] and (
                        current_loss < check_refusal_min_loss or completions[0].startswith(self.target)):
                    break

                # Early stopping check
                if early_stopping and (current_loss < early_stopping_min_loss):
                    print(f"Early stopping at step {i} with loss {current_loss}")
                    break

            del input_embeds, sampled_top_embeds, logits, shift_logits, loss
            torch.cuda.empty_cache()
            gc.collect()

        logs = {
            "final_loss": current_loss,
            "all_losses": all_losses,
            "all_test_cases": all_test_cases,
        }

        wandb_log = {
            "time": (self.start_time - time.time()),  # in seconds
            "running_mean_time": (self.start_time - time.time()) / (i + 1),
            "n_fwd_target_model_input_tokens_eval": self.n_fwd_target_model_input_tokens_eval,
            "n_fwd_target_model_cache_tokens_eval": self.n_fwd_target_model_cache_tokens_eval,
            "n_target_model_output_tokens_eval": self.n_target_model_output_tokens_eval,
            "n_target_model_fwd_passes_eval": self.n_target_model_fwd_passes_eval,
        }

        self.logger.log(wandb_log)

        self.logger.finish()
        return [test_case], [logs]

    def compute_candidates_loss(self, search_batch_size, input_embeds, target_ids):
        if self.search_batch_size != search_batch_size:
            print(f"INFO: Setting candidates search_batch_size to {search_batch_size})")
            self.search_batch_size = search_batch_size
            torch.cuda.empty_cache()
            gc.collect()

        all_loss = []
        for i in range(0, input_embeds.shape[0], search_batch_size):
            with torch.no_grad():
                input_embeds_batch = input_embeds[i: i + search_batch_size]

                if self.use_prefix_cache:
                    # Expand prefix cache to match batch size
                    prefix_cache_batch = self.prefix_cache
                    current_batch_size = input_embeds_batch.shape[0]
                    prefix_cache_batch = []
                    for i in range(len(self.prefix_cache)):
                        prefix_cache_batch.append([])
                        for j in range(len(self.prefix_cache[i])):
                            prefix_cache_batch[i].append(
                                self.prefix_cache[i][j].expand(
                                    current_batch_size, -1, -1, -1
                                )
                            )
                    outputs = self.model(
                        inputs_embeds=input_embeds_batch,
                        past_key_values=prefix_cache_batch,
                    )
                    self.n_fwd_target_model_input_tokens += input_embeds_batch.shape[1] * input_embeds_batch.shape[0] # len * batch_size
                    self.n_fwd_target_model_cache_tokens += self.cache_size * input_embeds_batch.shape[0]
                    self.n_target_model_fwd_passes += input_embeds_batch.shape[0]
                    self.n_target_model_output_tokens += target_ids.shape[1] * input_embeds_batch.shape[0]
                else:
                    outputs = self.model(inputs_embeds=input_embeds_batch)
                    self.n_fwd_target_model_input_tokens += input_embeds_batch.shape[1] * input_embeds_batch.shape[0]
                    self.n_target_model_fwd_passes += input_embeds_batch.shape[0]
                    self.n_target_model_output_tokens += target_ids.shape[1]  * input_embeds_batch.shape[0]
            logits = outputs.logits

            # compute loss
            # Shift so that tokens < n predict n
            tmp = input_embeds_batch.shape[1] - target_ids.shape[1]
            shift_logits = logits[..., tmp - 1: -1, :].contiguous()
            shift_labels = target_ids.repeat(input_embeds_batch.shape[0], 1)
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction="none")
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            loss = loss.view(input_embeds_batch.shape[0], -1).mean(dim=1)
            all_loss.append(loss)

            del outputs, logits, loss
            torch.cuda.empty_cache()
            gc.collect()

        return torch.cat(all_loss, dim=0)