import json
import gc
from tqdm import tqdm
import torch
from torch.nn import CrossEntropyLoss
# https://huggingface.co/docs/accelerate/v0.11.0/en/memory#accelerate.find_executable_batch_size
from accelerate.utils import find_executable_batch_size
from ..baseline import SingleBehaviorRedTeamingMethod
from ..model_utils import load_model_and_tokenizer, get_template
from ..check_refusal_utils import check_refusal_completions
from .templates import ADV_STRING_INIT, REFERENCES
from .utils import autodan_sample_control, load_mutate_model

class AutoDAN(SingleBehaviorRedTeamingMethod):

    def __init__(self,
                target_model, # Target model
                mutate_model, # Mutate model (GPT/Opensource models)
                targets_path,
                num_steps=100,
                eval_steps=5,
                eval_with_check_refusal=True,
                batch_size=128,
                num_elites=0.1,
                crossover=0.5,
                num_points=5,
                mutation=0.01,
                forward_batch_size=None, # if None then will use automated batch_size
                **kwargs,
                ):
        """
        Load the mutate model and the target model and prepare the AutoDAN reference prompts.

        :param target_model: a dictionary specifying the target model (kwargs to load_model_and_tokenizer).
            It must also contain the keys 'model_short_name' and 'developer_name'; both are popped from
            the dictionary (the caller's dict is modified) and substituted into the reference templates
        :param mutate_model: a dictionary specifying the mutate model (kwargs to load_mutate_model)
        :param targets_path: the path to the targets JSON file
        :param num_steps: the number of optimization steps to use
        :param eval_steps: the number of steps between each evaluation
        :param eval_with_check_refusal: whether to generate and check for refusal every eval_steps; used for early stopping
        :param batch_size: the number of candidates to consider at each step
        :param num_elites: AutoDAN mutation parameter, given as a fraction of batch_size
            (stored as int(batch_size * num_elites))
        :param crossover: AutoDAN mutation parameter (forwarded to autodan_sample_control)
        :param num_points: AutoDAN mutation parameter (forwarded to autodan_sample_control)
        :param mutation: AutoDAN mutation parameter (forwarded to autodan_sample_control)
        :param forward_batch_size: the batch size to use for computing candidates
        :param kwargs: accepted for signature compatibility; not used here
        """

        # These two keys only feed the reference templates, so they are removed
        # before target_model is splatted into the loader/template helpers below.
        model_short_name = target_model.pop('model_short_name', None)
        developer_name = target_model.pop('developer_name', None)

        assert model_short_name, "Model short name needed to be specified for references template, for example: Vicuna or Llama2"
        assert developer_name, "Developer name needed to be specified for references template, for example: developer_name of Llama2 is Meta and developer_name of vicuna is LMSYS"

        # NOTE: the base class __init__ is not invoked; both models are loaded manually below.
        self.num_steps = num_steps
        self.eval_steps = eval_steps
        self.eval_with_check_refusal = eval_with_check_refusal
        self.batch_size = batch_size
        self.num_elites = int(self.batch_size * num_elites)
        self.crossover = crossover
        self.num_points = num_points
        self.mutation = mutation

        self.autodan_target_prefix = "(🔓Developer Mode Output) "
        with open(targets_path, 'r', encoding='utf-8') as file:
            self.behavior_id_to_target = json.load(file)

        # Build a fresh list instead of rewriting REFERENCES in place: the module-level
        # list is shared, so in-place edits would leak into every later instance
        # (names already substituted, ' [PROMPT]: ' appended again on each construction).
        self.references = [
            ref.replace('ChatGPT', model_short_name).replace('ModelKeeper', developer_name) + ' [PROMPT]: '
            for ref in REFERENCES
        ]

        self.forward_batch_size = forward_batch_size

        # ===== Init Mutate Model =====
        self.mutate_model, allocated_gpus = load_mutate_model(**mutate_model)

        # ===== Init Target Model =====
        self.model_name_or_path = target_model['model_name_or_path']
        # manually allocate target model into unallocated gpus:
        # devices below `allocated_gpus` get a zero budget, the following `num_gpus`
        # devices are offered their currently free memory.
        num_gpus = target_model.get("num_gpus", 1)
        max_memory = {
            gpu_id: torch.cuda.mem_get_info(gpu_id)[0] if gpu_id >= allocated_gpus else 0
            for gpu_id in range(allocated_gpus + num_gpus)
        }
        self.model, self.tokenizer = load_model_and_tokenizer(**target_model, max_memory=max_memory)

        # ====== Init instruction template ====
        self.template = get_template(**target_model)['prompt']
        # Text before / after the "{instruction}" placeholder, used to wrap candidates.
        self.before_tc, self.after_tc = self.template.split("{instruction}")

    def generate_test_cases_single_behavior(self, behavior_dict, verbose=False, **kwargs):
        """
        Generates test cases for a single behavior

        :param behavior_dict: a dictionary specifying the behavior to generate test cases for;
            the keys 'Behavior', 'ContextString' and 'BehaviorID' are read
        :param verbose: whether to print progress
        :param kwargs: accepted for signature compatibility; not used here
        :return: a test case (best adversarial prefix followed by the behavior; the plain behavior
            if no step ran) and logs (one dict with `loss` and `test_case` per executed step)
        """
        tokenizer = self.tokenizer
        model = self.model

        # ==== Method Variables ====
        batch_size = self.batch_size
        references = self.references
        eval_steps = self.eval_steps
        num_steps = self.num_steps
        template = self.template
        before_tc = self.before_tc
        after_tc = self.after_tc

        # starting forward_batch_size, will auto reduce batch_size later if go OOM
        if self.forward_batch_size is None:
            self.forward_batch_size = batch_size

        # ========== Behavior meta data ==========
        behavior = behavior_dict['Behavior']
        context_str = behavior_dict['ContextString']
        behavior_id = behavior_dict['BehaviorID']
        target = self.autodan_target_prefix + self.behavior_id_to_target[behavior_id]

        if context_str:
            behavior = f"{context_str}\n\n---\n\n{behavior}"

        # ==== Early stopping vars =====
        # Number of consecutive non-improving steps tolerated before giving up.
        patience = 10
        previous_loss = None
        early_stopping_threshold = patience

        print(f'Behavior: {behavior_id} || Target: {target}')

        # ===== init target embeds =====
        embed_layer = model.get_input_embeddings()
        target_ids = tokenizer([target], padding=False, add_special_tokens=False, return_tensors='pt')['input_ids'].to(model.device)
        target_embeds = embed_layer(target_ids)

        # Initial population: at most batch_size reference prompts (may be fewer).
        adv_prefixes = references[:batch_size]
        logs = []
        test_case = behavior

        # Batched tokenization below pads, which requires a pad token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        for step in tqdm(range(num_steps)):
            # ======== compute loss and best candidates =======
            # candidate = USER: adv_prefix + behavior + ASSISTANT:
            candidates = [before_tc + prefix + behavior + after_tc for prefix in adv_prefixes]
            # NOTE(review): no attention mask is forwarded with the padded batch, and the target is
            # appended after the padded prompt — this presumably relies on left padding; confirm.
            candidates_tokens = tokenizer(candidates, padding=True, add_special_tokens=False, return_tensors='pt').input_ids
            candidates_embeds = embed_layer(candidates_tokens.to(model.device))

            # Repeat the target once per actual candidate; the population can be smaller than
            # batch_size (e.g. fewer references than batch_size on the first step).
            input_embeds = torch.cat([candidates_embeds, target_embeds.repeat(len(candidates), 1, 1)], dim=1)
            loss = find_executable_batch_size(self.compute_candidates_loss, self.forward_batch_size)(input_embeds, target_ids)

            best_idx = int(loss.argmin())
            current_loss = loss.min().item()

            # Record the test case that actually achieved current_loss (not the previous step's).
            test_case = adv_prefixes[best_idx] + behavior
            logs.append(dict(loss=current_loss, test_case=test_case))

            # ========== Eval and Early Stopping ========== #
            if (step % eval_steps == 0) or (step == num_steps - 1):
                log_str = f'\n===>Step {step}\n===>Test Case: {test_case}\n===>Loss: {current_loss}'

                # ====== Checking for refusal =======
                jailbroken = False
                if eval_with_check_refusal := self.eval_with_check_refusal:
                    input_str = template.format(instruction=test_case)
                    is_refusal, completions, _ = check_refusal_completions(model, tokenizer, inputs=[input_str])
                    log_str += f"\n===>Completion: {completions[0]}"
                    jailbroken = not is_refusal[0]

                # Print before breaking so the successful step is visible in verbose mode.
                if verbose:
                    print(log_str)
                if jailbroken:
                    break

            # ====== Checking for early stopping if loss stuck at local minima ======
            if previous_loss is None or current_loss < previous_loss:
                previous_loss = current_loss
                early_stopping_threshold = patience
            else:
                early_stopping_threshold -= 1
                if early_stopping_threshold <= 0:
                    break

            # Next population, derived from the current prefixes and their losses.
            adv_prefixes = autodan_sample_control(control_prefixes=adv_prefixes,
                                                  score_list=loss.float().cpu().numpy().tolist(),
                                                  num_elites=self.num_elites,
                                                  batch_size=batch_size,
                                                  crossover=self.crossover,
                                                  num_points=self.num_points,
                                                  mutation=self.mutation,
                                                  mutate_model=self.mutate_model,
                                                  reference=references,
                                                  no_mutate_words=[])

            # Drop the large per-step tensors before the next iteration allocates new ones.
            del candidates_embeds, input_embeds, loss
            torch.cuda.empty_cache()
            gc.collect()

        return test_case, logs

    def compute_candidates_loss(self, forward_batch_size, input_embeds, target_ids):
        if self.forward_batch_size != forward_batch_size:
            print(f"INFO: Setting candidates forward_batch_size to {forward_batch_size})")
            self.forward_batch_size = forward_batch_size

        all_loss = []
        for i in range(0, input_embeds.shape[0], forward_batch_size):
            with torch.no_grad():
                input_embeds_batch = input_embeds[i:i+forward_batch_size]
                outputs = self.model(inputs_embeds=input_embeds_batch)
            logits = outputs.logits

            # compute loss
            # Shift so that tokens < n predict n
            tmp = input_embeds.shape[1] - target_ids.shape[1]
            shift_logits = logits[..., tmp-1:-1, :].contiguous()
            shift_labels = target_ids.repeat(forward_batch_size, 1)
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction='none')
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss = loss.view(forward_batch_size, -1).mean(dim=1)
            all_loss.append(loss)

        return torch.cat(all_loss, dim=0)        