import gc
import torch
import numpy as np
from tqdm import tqdm

from prj_rag.opt_utils import (
    token_gradients,
    sample_control,
    multi_cord_sample_control,
    get_logits,
    target_loss,
    get_filtered_cands,
)


def get_nonascii_toks(tokenizer, device="cpu"):
    """Return a tensor of token ids that are not printable ASCII, plus special ids.

    Ids ``3 .. tokenizer.vocab_size - 1`` are decoded one at a time; an id is
    collected when its decoded text is not entirely ASCII or not entirely
    printable. The tokenizer's BOS, EOS, PAD and UNK ids are then appended
    (in that order) whenever they are not ``None``.

    Args:
        tokenizer: Object exposing ``vocab_size``, ``decode(list_of_ids)`` and
            the ``bos_token_id`` / ``eos_token_id`` / ``pad_token_id`` /
            ``unk_token_id`` attributes.
        device: Device on which the resulting tensor is created.

    Returns:
        A 1-D ``torch.Tensor`` holding the collected ids.
    """
    blocked_ids = []

    # The scan starts at id 3; ids 0-2 are never decoded here.
    for tok_id in range(3, tokenizer.vocab_size):
        text = tokenizer.decode([tok_id])
        if not text.isascii() or not text.isprintable():
            blocked_ids.append(tok_id)

    # Special tokens are always excluded when the tokenizer defines them.
    special_ids = (
        tokenizer.bos_token_id,
        tokenizer.eos_token_id,
        tokenizer.pad_token_id,
        tokenizer.unk_token_id,
    )
    blocked_ids.extend(tok_id for tok_id in special_ids if tok_id is not None)

    return torch.tensor(blocked_ids, device=device)


# DEPRECATED
def generate(model, tokenizer, input_ids, assistant_role_slice, gen_config=None):
    """Generate a continuation for the prompt that ends at ``assistant_role_slice.stop``.

    Args:
        model: Object exposing ``device``, ``generation_config`` and a
            ``generate(input_ids, attention_mask=..., generation_config=...,
            pad_token_id=...)`` method.
        tokenizer: Object exposing ``pad_token_id``.
        input_ids: 1-D tensor of token ids; only the part before
            ``assistant_role_slice.stop`` is fed to the model.
        assistant_role_slice: Slice whose ``stop`` marks the end of the prompt.
        gen_config: Generation config to use. When ``None``, a copy of
            ``model.generation_config`` with ``max_new_tokens = 32`` is used.

    Returns:
        The generated ids with the first ``assistant_role_slice.stop``
        positions (the prompt) removed.
    """
    if gen_config is None:
        # Local import keeps this block self-contained.
        import copy

        # Copy before editing: the model's default config is shared state and
        # must not silently inherit the 32-token cap used for this call.
        gen_config = copy.deepcopy(model.generation_config)
        gen_config.max_new_tokens = 32

    if gen_config.max_new_tokens > 50:
        # Message now states the threshold that is actually checked.
        print("WARNING: max_new_tokens > 50 may cause testing to slow down.")

    prompt_end = assistant_role_slice.stop

    # Truncate to the prompt and add a batch dimension of size one.
    input_ids = input_ids[:prompt_end].to(model.device).unsqueeze(0)

    # ones_like already allocates on the same device as input_ids.
    attn_masks = torch.ones_like(input_ids)
    output_ids = model.generate(
        input_ids,
        attention_mask=attn_masks,
        generation_config=gen_config,
        pad_token_id=tokenizer.pad_token_id,
    )[0]

    return output_ids[prompt_end:]


# DEPRECATED
def check_for_attack_success(
    model, tokenizer, input_ids, assistant_role_slice, test_prefixes, gen_config=None
):
    """Report whether the generated reply contains none of ``test_prefixes``.

    The reply is produced by ``generate`` with the same arguments, decoded
    with ``tokenizer.decode`` and stripped of surrounding whitespace.

    Args:
        model, tokenizer, input_ids, assistant_role_slice, gen_config:
            Forwarded unchanged to ``generate``.
        test_prefixes: Iterable of strings to look for in the reply.

    Returns:
        ``True`` when no entry of ``test_prefixes`` occurs in the reply,
        ``False`` otherwise.
    """
    reply_ids = generate(
        model, tokenizer, input_ids, assistant_role_slice, gen_config=gen_config
    )
    reply_text = tokenizer.decode(reply_ids).strip()

    # Despite the parameter name, this is a substring test anywhere in the
    # reply, not only at its start.
    return all(marker not in reply_text for marker in test_prefixes)


def gcg_attack(
    model,
    tokenizer,
    num_steps_gcg: int,
    batch_size_gcg: int,
    topk_gcg: int,
    prompt_managers: dict,
    adv_control_prefix: str,
    adv_control_suffix: str,
    device: str,
    num_coordinates=1,
    iters_per_cord_batch=1,
    allow_non_ascii: bool = False,
    optimize_prefix: bool = True,
    multi_coordinate: bool = True,
    early_termination: bool = False,
    early_stop_threshold: float = 0.005,
    optimize_gpu_memory=True,
    points_per_device: int = 8,
    old_mcg_version=False,
    ret_losses=False
):
    """Run the GCG optimisation loop over one adversarial control string.

    Either the control prefix or the control suffix is optimised (selected by
    ``optimize_prefix``); the other string is passed through unchanged. Each
    step:

    1. For every prompt manager (in sorted key order), builds the input ids
       and calls ``token_gradients`` for the control slice.
    2. Averages those gradients over all prompts.
    3. Samples a batch of candidate control token sequences with
       ``sample_control`` or ``multi_cord_sample_control``.
    4. Converts candidates to strings with ``get_filtered_cands``
       (``filter_cand=False``).
    5. Scores every candidate on every prompt via ``get_logits`` and
       ``target_loss``, averages the losses over prompts, and picks the
       candidate with the smallest average.
    6. Adjusts ``num_coordinates`` (schedule depends on ``old_mcg_version``)
       and adopts the chosen candidate as the new control string.

    Args:
        model: Passed to ``token_gradients`` and ``get_logits``.
        tokenizer: Used to build the disallowed-token list and passed to
            ``get_filtered_cands`` and ``get_logits``.
        num_steps_gcg: Maximum number of loop iterations.
        batch_size_gcg: Batch size handed to the candidate sampler.
        topk_gcg: ``topk`` value handed to the candidate sampler.
        prompt_managers: Mapping of query id to prompt manager. Must be
            non-empty: the manager under the smallest key supplies
            ``max_control_prefix_tokens`` / ``max_control_suffix_tokens`` and
            the slice used to read the current control tokens. Every manager
            must provide ``get_input_ids``, ``_adv_control_prefix_slice``,
            ``_adv_control_suffix_slice``, ``_target_slice`` and
            ``_loss_slice``.
        adv_control_prefix: Initial control prefix string.
        adv_control_suffix: Initial control suffix string.
        device: Device the input ids and current control tokens are moved to.
        num_coordinates: Forwarded to ``multi_cord_sample_control``; forced to
            1 when ``multi_coordinate`` is False and reduced during the run.
        iters_per_cord_batch: Only read when ``old_mcg_version`` is True; see
            the schedule comment in the body.
        allow_non_ascii: When False, the ids from ``get_nonascii_toks`` are
            passed to the sampler as ``not_allowed_tokens``.
        optimize_prefix: True to optimise the prefix, False for the suffix.
        multi_coordinate: True to sample with ``multi_cord_sample_control``,
            False to use ``sample_control``.
        early_termination: When True, stop as soon as the best average loss
            of a step is below ``early_stop_threshold``.
        early_stop_threshold: Loss threshold for early termination.
        optimize_gpu_memory: When True, intermediate tensors are deleted and
            ``torch.cuda.empty_cache()`` is called at additional points.
        points_per_device: Forwarded to ``get_logits`` as ``batch_size``.
        old_mcg_version: Selects the coordinate-reduction schedule.
        ret_losses: When True, also return the per-step best losses.

    Returns:
        ``(adv_control_prefix, adv_control_suffix)``, or
        ``(adv_control_prefix, adv_control_suffix, iter_losses)`` when
        ``ret_losses`` is True. Steps skipped by the ``old_mcg_version``
        ``continue`` add no entry to ``iter_losses``.
    """

    not_allowed_tokens = None if allow_non_ascii else get_nonascii_toks(tokenizer)
    # Sorted keys give a fixed prompt order that both per-prompt loops share.
    sorted_query_keys = sorted(prompt_managers.keys())
    print("Number of prompts: ", len(sorted_query_keys))
    # Reference manager: its token limits and control slice are used for all prompts.
    p0 = prompt_managers[sorted_query_keys[0]]

    # Logging the losses
    iter_losses = []

    # Setting the min_loss to a large value, which will later be overwritten.
    min_loss = np.inf
    # Counts the steps completed since num_coordinates was last reduced in the
    # old_mcg_version branch (starts at 1); compared against iters_per_cord_batch.
    min_iters = 1

    max_control_prefix_tokens = p0.max_control_prefix_tokens
    max_control_suffix_tokens = p0.max_control_suffix_tokens

    print(f"Intial control prefix string: {adv_control_prefix}")
    print(f"Intial control suffix string: {adv_control_suffix}")
    print("Number of coordinates that will be changed:", num_coordinates)

    pbar = tqdm(range(num_steps_gcg), desc="GCG Attack")

    for i in pbar:
        input_id_list = []
        coordinate_grad_list = []

        for qid in sorted_query_keys:
            prompt_manager = prompt_managers[qid]

            # Step 1. Encode user prompt with the context as tokens and return token ids.
            input_ids = prompt_manager.get_input_ids(
                adv_control_prefix, adv_control_suffix, optimize_prefix
            )
            input_ids = input_ids.to(device)

            if optimize_prefix:
                control_slice = prompt_manager._adv_control_prefix_slice
            else:
                control_slice = prompt_manager._adv_control_suffix_slice

            # Step 2. Compute Coordinate Gradient of prefix/suffix for each prompt
            coordinate_grad = token_gradients(
                model,
                input_ids,
                control_slice,
                prompt_manager._target_slice,
                prompt_manager._loss_slice,
            )

            # Keep only CPU copies in the per-prompt lists.
            input_id_list.append(input_ids.cpu())
            coordinate_grad_list.append(coordinate_grad.cpu())

            if optimize_gpu_memory:
                del input_ids, coordinate_grad
                gc.collect()
                torch.cuda.empty_cache()

        # Computes the average of gradients (Gradient = dLoss/dcontext_control_slice tokens) over all user queries.
        avg_coordinate_grad = torch.stack(coordinate_grad_list, dim=0).mean(dim=0)

        # Step 3. Sample a batch of new tokens based on the coordinate gradient.
        # Notice that we only need the one that minimizes the loss.
        with torch.no_grad():
            # Step 3.1 Slice the input to locate the adversarial control prefix/suffix.
            # Read from the first prompt using p0's slice; this assumes the control
            # tokens are the same across all prompts.
            if optimize_prefix:

                adv_control_tokens = input_id_list[0][p0._adv_control_prefix_slice].to(
                    device
                )

            else:
                adv_control_tokens = input_id_list[0][p0._adv_control_suffix_slice].to(
                    device
                )

            # Step 3.2 Randomly sample a batch of replacements.
            if not multi_coordinate:
                # Single-coordinate mode overrides whatever num_coordinates was passed in.
                num_coordinates = 1
                new_adv_control_toks = sample_control(
                    adv_control_tokens,
                    avg_coordinate_grad,
                    batch_size_gcg,
                    topk=topk_gcg,
                    temp=1,
                    not_allowed_tokens=not_allowed_tokens,
                    verbose=True,
                )

            else:
                # Multi coordinate version
                new_adv_control_toks = multi_cord_sample_control(
                    adv_control_tokens,
                    avg_coordinate_grad,
                    batch_size_gcg,
                    topk=topk_gcg,
                    temp=1,
                    not_allowed_tokens=not_allowed_tokens,
                    num_coordinates=num_coordinates,
                    verbose=False,
                )
            new_adv_control_toks = new_adv_control_toks.cpu()

            if optimize_gpu_memory:
                del avg_coordinate_grad
                gc.collect()
                torch.cuda.empty_cache()

            # Step 3.3 Convert the sampled token sequences to candidate strings.
            # Tokenizers are not invertible, so Encode(Decode(tokens)) may produce a
            # different tokenization; max_toks is passed so the token count stays
            # bounded and memory does not keep growing until OOM.
            # NOTE(review): exactly how get_filtered_cands applies max_toks when
            # filter_cand=False is not visible here -- confirm in opt_utils.

            # filter_cand is set to False: with filter_cand=True it often filtered out
            # every sample in the batch (of size 64) and raised an error.
            if optimize_prefix:
                max_control_toks = max_control_prefix_tokens
                control_str = adv_control_prefix
            else:
                max_control_toks = max_control_suffix_tokens
                control_str = adv_control_suffix

            new_adv_control_strings = get_filtered_cands(
                tokenizer,
                new_adv_control_toks,
                max_toks=max_control_toks,
                filter_cand=False,
                curr_control=control_str,
            )

            # print("Number of filtered candidates: ", new_adv_control_strings)

            candidate_losses_list = []
            for p, qid in enumerate(sorted_query_keys):
                prompt_manager = prompt_managers[qid]

                # print("AT Prompt:", p)
                # print("CUDA BEG", torch.cuda.memory_allocated(1) / 1024 / 1024)

                if optimize_prefix:
                    control_slice = prompt_manager._adv_control_prefix_slice
                else:
                    control_slice = prompt_manager._adv_control_suffix_slice

                # Step 3.4 Compute loss on these candidates and take the argmin.
                logits, ids = get_logits(
                    model=model,
                    tokenizer=tokenizer,
                    input_ids=input_id_list[p],
                    control_slice=control_slice,
                    test_controls=new_adv_control_strings,
                    return_ids=True,
                    batch_size=points_per_device,
                )

                losses = target_loss(logits, ids, prompt_manager._target_slice)
                # losses = target_loss(
                # logits.cpu(), ids.cpu(), prompt_manager._target_slice
                # )
                candidate_losses_list.append(losses)
                # print("CUDA PRE", torch.cuda.memory_allocated(1) / 1024 / 1024)

                if optimize_gpu_memory:
                    # Cache is only flushed on every third prompt. The loss tensor stays
                    # alive through candidate_losses_list; only the local names are dropped.
                    if p % 3 == 0:
                        # logits, ids, losses = logits.cpu(), ids.cpu(), losses.cpu()
                        del logits, ids, losses
                        gc.collect()
                        torch.cuda.empty_cache()

                # print("CUDA POST", torch.cuda.memory_allocated(1) / 1024 / 1024)

            # Computes the average loss over the queries for candidates of batch size B.
            candidate_avg_losses = torch.stack(candidate_losses_list, dim=0).mean(dim=0)

            best_new_adv_id = candidate_avg_losses.argmin().cpu()
            best_new_adv_str = new_adv_control_strings[best_new_adv_id]

            current_loss = candidate_avg_losses[best_new_adv_id]

            # Heuristic to decrease the number of coordinates in GCG.
            # if num_coordinates > 1 and ((i + 1) % 10 == 0):
            # Old schedule (version before ChatTemplate was introduced): halve
            # num_coordinates only when more than iters_per_cord_batch steps have been
            # counted at the current value and this step's best loss is worse than
            # the best loss seen so far.
            if old_mcg_version:
                if (
                    (num_coordinates > 1)
                    and (min_iters > iters_per_cord_batch)
                    and (min_loss < current_loss)
                ):
                    # if min_loss < current_loss:
                    num_coordinates = max(num_coordinates // 2, 1)
                    min_iters = 1
                    print(f"----Changing number of coordinates to {num_coordinates}----")
                    # The candidate is discarded: this skips the control-string update,
                    # loss logging, the cleanup below and the early-termination check.
                    continue

            # New version: halves the number of coordinates every iteration (floor at 1).
            else: 
                if num_coordinates > 1:
                    num_coordinates = max(num_coordinates // 2, 1)
                    print(f"----Changing number of coordinates to {num_coordinates}----")

            if current_loss <= min_loss:
                min_loss = current_loss.item()

            # The step's best candidate is adopted even if its loss is higher than min_loss.
            if optimize_prefix:
                adv_control_prefix = best_new_adv_str
            else:
                adv_control_suffix = best_new_adv_str

            min_iters += 1

            print("Updated adversarial string:", best_new_adv_str)

            # NOTE(review): current_loss is still a tensor here (it is only converted
            # with .item() for min_loss and iter_losses), so the postfix shows a tensor.
            pbar.set_postfix({"Loss": current_loss})
            if ret_losses:
                iter_losses.append(current_loss.cpu().item())

            # (Optional) Clean up the cache.
            del (
                coordinate_grad_list,
                input_id_list,
                adv_control_tokens,
                new_adv_control_toks,
                candidate_avg_losses,
                best_new_adv_id,
                candidate_losses_list,
            )
            gc.collect()
            torch.cuda.empty_cache()

            # Early Termination
            if early_termination and (current_loss < early_stop_threshold):
                print("Loss value < threshold, possibly jailbroken")
                break

    if ret_losses:
        return adv_control_prefix, adv_control_suffix, iter_losses
    return adv_control_prefix, adv_control_suffix
