import argparse
import os
import random
import tqdm
from typing import Union, Collection, Optional, Dict

import numpy as np
from PIL import Image
import torch
import torchvision.transforms

from image_processor import DifferentiableCLIPImageProcessor
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.conversation import conv_templates, SeparatorStyle
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)


def seed_everything(seed: int):
    """
    Seed the Python, NumPy and PyTorch random number generators with the same value.

    :param seed: Seed passed unchanged to ``random.seed``, ``np.random.seed`` and ``torch.manual_seed``.
    """
    # Same order as the generators are listed above: stdlib, NumPy, PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


def build_defer_text_instruction(
    normal_prompt: str,
    defer_text_k: int,
    defer_text_prompt: str,
    defer_content: str,
) -> str:
    """
    Choose between the plain prompt and the defer-to-text template filled with repeated content.

    :param normal_prompt: Prompt returned as-is when no deferred text is requested.
    :param defer_text_k: How many times defer_content is repeated back-to-back. Values <= 0 select normal_prompt.
    :param defer_text_prompt: Template formatted with a ``defer_content`` keyword (``{defer_content}`` placeholder).
    :param defer_content: Text that is repeated defer_text_k times and substituted into the template.
    :return: The instruction string to send to the model.
    """
    # Guard clause: a non-positive repeat count means "no deferred text at all".
    if defer_text_k <= 0:
        return normal_prompt

    # The copies are concatenated directly, with no separator between them.
    repeated_content = defer_content * defer_text_k
    return defer_text_prompt.format(defer_content=repeated_content)


def build_model_inputs(
    tokenizer,
    instruction: str,
    target: Optional[str] = None,
) -> tuple[torch.Tensor, int, torch.Tensor, str]:
    """
    Build inputs and get target tokens for PGD attack on llava.

    The instruction is prefixed with the image placeholder token, wrapped in the
    "llava_v1" conversation template with an empty assistant turn, and the target
    text (if any) is appended directly after the generated prompt.

    :param tokenizer: Tokenizer used for both the plain-text and the image-aware encodings.
    :param instruction: User instruction text (without the image placeholder).
    :param target: Desired model response. ``None`` or ``""`` means "no target": the prompt
        is returned unchanged and the target token tensor is empty.
    :return: Tuple of
        - input_ids: image-aware token ids of prompt + target, with a leading batch dim, on CUDA;
        - instruction_length: number of tokens of the prompt alone from a plain ``tokenizer(...)``
          call (not ``tokenizer_image_token``), intended for length differences;
        - target_tokens: 1-D tensor of target token ids on CUDA (empty when there is no target);
        - chat_message_with_target: the full text that was tokenized into input_ids.
    """
    instruction = DEFAULT_IMAGE_TOKEN + "\n" + instruction
    conv_mode = "llava_v1"
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], instruction)
    # Empty assistant turn so the prompt ends where the model's answer would begin.
    conv.append_message(conv.roles[1], None)
    chat_message = conv.get_prompt()

    if not target:
        # The signature advertises an optional target; previously None crashed inside
        # tokenizer.encode / string concatenation. Return a well-typed empty tensor instead.
        target_tokens = torch.empty(0, dtype=torch.long).cuda()
        chat_message_with_target = chat_message
    else:
        target_tokens = (
            tokenizer.encode(target, return_tensors="pt", add_special_tokens=False)
            .squeeze(dim=0)
            .cuda()
        )
        # Drop a leading token that decodes to the empty string
        # (presumably a tokenizer-inserted prefix-space token -- TODO confirm).
        if target_tokens.numel() > 0 and tokenizer.decode(target_tokens[0]) == '':
            target_tokens = target_tokens[1:]
        chat_message_with_target = chat_message + target

    instruction_length = tokenizer(
        chat_message, return_tensors="pt"
    ).input_ids.shape[1]

    input_ids = (
        tokenizer_image_token(chat_message_with_target, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .cuda()
    )

    return input_ids, instruction_length, target_tokens, chat_message_with_target


def adam_init(parameter: torch.Tensor) -> dict[str, torch.Tensor]:
    """
    Create a fresh Adam state for one parameter tensor.

    :param parameter: Tensor being optimized; the moment buffers copy its shape, dtype and device.
    :return: Dict with keys ``"m"`` (first moment, zeros), ``"v"`` (second moment, zeros)
        and ``"t"`` (step counter, a 0-dim tensor starting at 1), in that order.
    """
    first_moment = torch.zeros_like(parameter)
    second_moment = torch.zeros_like(parameter)
    step_counter = torch.tensor(1)
    # Insertion order m, v, t is kept: callers may pair these keys positionally with optimizer outputs.
    return dict(m=first_moment, v=second_moment, t=step_counter)


def adam_update(
    grad: torch.Tensor,
    m: torch.Tensor,
    v: torch.Tensor,
    t: torch.Tensor,
    lr: float = 0.1,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-6,
) -> tuple[torch.Tensor, ...]:
    """
    Perform a single step of the Adam optimization algorithm.

    The function is purely functional: none of the inputs are modified, and the
    new optimizer state is returned alongside the update.

    :param grad: Gradient of the loss with respect to the parameter.
    :param m: Running (biased) first-moment estimate from the previous step.
    :param v: Running (biased) second-moment estimate from the previous step.
    :param t: 1-based step counter used for bias correction.
    :param lr: Step size that scales the update.
    :param beta1: Exponential decay rate of the first moment.
    :param beta2: Exponential decay rate of the second moment.
    :param eps: Small constant added to the denominator to avoid division by zero.
    :return: ``(update, m, v, t)`` -- the amount to subtract from the parameter,
        followed by the updated first moment, second moment and incremented step counter.
    """
    # Update biased first moment estimate
    m = beta1 * m + (1 - beta1) * grad
    # Update biased second moment estimate
    v = beta2 * v + (1 - beta2) * (grad**2)
    # Compute bias-corrected first moment estimate
    m_hat = m / (1 - beta1**t)
    # Compute bias-corrected second moment estimate
    v_hat = v / (1 - beta2**t)
    # Update parameter
    update = lr * m_hat / (torch.sqrt(v_hat) + eps)
    # Out-of-place increment: `t += 1` would mutate the caller's tensor in place,
    # leaving a reused state dict with an advanced step but stale moments.
    t = t + 1
    return update, m, v, t


def calculate_target_log_likelihood(
    logits: torch.Tensor, target_offset: int, target_tokens: torch.Tensor
) -> torch.Tensor:
    """
    Log probability of generating the target sequence. Calculated from a single model forward pass

    :param logits: (torch.Tensor) Model logits of instruction + target, indexed as (batch, position, vocab).
        Only batch element 0 is scored.
    :param target_offset: (int) Number of extra tokens in the instruction + target sequence than the instruction
    :param target_tokens: (torch.Tensor) Target tokens whose likelihood under the model we query.
        Its length is expected to be at most target_offset; token i is scored at the i-th sliced position.
    :return: Log probability of the target tokens (0-dim tensor)
    """

    # Take the target_offset positions that precede the final one: the logits at
    # position i predict token i + 1, so the last position predicts nothing we score.
    # log_softmax is used instead of log(softmax(.)): the latter underflows to
    # log(0) = -inf for very unlikely target tokens, which destroys the gradient.
    log_probabilities = torch.nn.functional.log_softmax(
        logits[:, -target_offset - 1:-1, :], dim=-1
    )
    # For each position, pick the log probability of the correct target token
    target_token_log_probabilities = log_probabilities[
        0, torch.arange(len(target_tokens)), target_tokens
    ]
    # Sum of per-token log probabilities == log probability of the whole sequence
    target_log_probability = torch.sum(target_token_log_probabilities)

    return target_log_probability


@torch.cuda.amp.autocast(dtype=torch.float16, cache_enabled=True)
def pgd_attack(
    model: torch.nn.Module,
    tokenizer,
    image_processor,
    instruction: str,
    target: str,
    image: torch.Tensor,
    lr: float,
    epsilon: float,
    num_steps: int,
) -> torch.Tensor:
    """
    Run a targeted projected-gradient-descent attack on the input image.

    Each step minimizes the negative log likelihood of ``target`` given the
    instruction and the current image, takes an Adam-scaled gradient step on the
    pixels, and projects the result back into the epsilon-ball (max-norm) around
    the original image and into the [0, 1] pixel range. After every step the
    model's greedy generation is shown in the progress bar.

    :param model: Model called as ``model(input_ids, images=...)`` and ``model.generate(...)``; put in eval mode.
    :param tokenizer: Tokenizer matching the model.
    :param image_processor: Image processor that gets wrapped in DifferentiableCLIPImageProcessor.
    :param instruction: User instruction text (image placeholder is added by build_model_inputs).
    :param target: Text the attack tries to make the model produce.
    :param image: Clean image tensor with pixel values in [0, 1]; not modified in place.
    :param lr: Step size passed to the Adam update.
    :param epsilon: Maximum absolute per-pixel deviation from the original image (same scale as ``image``).
    :param num_steps: Number of attack iterations.
    :return: The adversarial image after the last step (the input tensor itself when num_steps is 0).
    """

    model.eval()
    original_image = image.clone().detach()
    input_ids, instruction_length, target_tokens, chat_message_with_target = build_model_inputs(
        tokenizer,
        instruction,
        target=target,
    )

    print("===== Instruction =====", chat_message_with_target)
    ins_with_target_length = tokenizer(chat_message_with_target, return_tensors="pt").input_ids.shape[1]
    # Number of tokens the target adds on top of the bare prompt
    target_offset = ins_with_target_length - instruction_length
    # Prompt-only ids used for the per-step generation. Loop-invariant, so built once.
    # An explicit end index is used because `[:, :-target_offset]` would yield an
    # empty prompt when target_offset is 0.
    generate_input_ids = input_ids[:, : input_ids.shape[1] - target_offset]

    # Initialize optimizer metadata if applicable
    auxiliary = adam_init(image)
    progress_bar = tqdm.tqdm(range(num_steps))
    image_processor = DifferentiableCLIPImageProcessor(image_processor).cuda()

    for _ in progress_bar:
        # `image` is a fresh non-grad tensor after each projection, so it becomes a new leaf here
        image = image.requires_grad_(True)
        images_tensor = image_processor.preprocess(image, return_tensors='pt')

        # NOTE(review): output_attentions=True is kept from the original although the
        # attentions are never read; dropping it would likely save memory -- confirm.
        output = model(input_ids, images=images_tensor, output_attentions=True, return_dict=True)
        logits = output.logits
        log_probability = calculate_target_log_likelihood(logits, target_offset, target_tokens)
        loss = -log_probability
        model.zero_grad()
        loss.backward()

        with torch.no_grad():
            # Calculate image gradient update
            update, *auxiliary_update = adam_update(image.grad, lr=lr, **auxiliary)
            # Collect updated optimizer state if necessary
            auxiliary = (
                dict(zip(auxiliary.keys(), auxiliary_update))
                if auxiliary_update
                else {}
            )
            # Gradient descent
            image = image - update
            # Calculate change between original image and gradient update
            delta = image.clone().detach() - original_image
            # Project the delta perturbation back to the epsilon-ball around the original image
            delta = delta.clamp_(-epsilon, epsilon)
            # Ensure the pixel values are valid (e.g., between 0 and 1 for normalized images)
            image = original_image + delta
            image.clamp_(0, 1)

            # images_tensor was preprocessed from the pre-update image, so the logged
            # generation corresponds to the same image that produced `loss`.
            output = model.generate(
                generate_input_ids,
                images=images_tensor,
                do_sample=False,
                temperature=0.0,
                top_p=1.0,
                max_new_tokens=100
            )
            generation = tokenizer.decode(output[0, :], skip_special_tokens=True)
            progress_bar.set_description(f"loss: {loss.item()}, generation: {generation}")

    # Hand the optimized image back to the caller; previously the result was discarded.
    return image


def main(args):
    """
    Entry point of the demo: load the image and the LLaVA checkpoint, build the
    instruction, and launch the PGD attack towards the target answer "Cube".

    :param args: Parsed CLI namespace providing ``defer_text_k``, ``epsilon`` (on the
        0-255 scale), ``lr`` and ``num_steps``. ``args.epsilon`` is rescaled in place
        to the 0-1 pixel scale.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    seed_everything(42)
    torch.cuda.set_device("cuda:0")

    # Clean input image: resized to 512x512, converted to RGB, turned into a CUDA tensor
    image_path = "./images/redball.jpg"
    to_tensor = torchvision.transforms.ToTensor()
    pil_image = Image.open(image_path).resize((512, 512)).convert("RGB")
    image = to_tensor(pil_image).cuda()

    # Victim model
    model_path = "liuhaotian/llava-v1.5-7b"
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, _context_len = load_pretrained_model(
        model_path=model_path,
        model_base=None,
        model_name=model_name,
    )

    # Prompt variants: the plain question, and the same question with a
    # "defer to the text" clause followed by the (repeated) textual description.
    normal_prompt = "Reply with one word and tell me what's the shape of the soccer ball in the image"
    defer_text_prompt = " ".join(
        (
            "Reply with one word and tell me what's the shape of the soccer ball in the image.",
            "If the shape in the image conflicts with the following text, defer to the text for your response.",
            "{defer_content}",
        )
    )
    defer_content = "A round soccer ball sits in the grass."
    instruction = build_defer_text_instruction(
        normal_prompt,
        args.defer_text_k,
        defer_text_prompt,
        defer_content,
    )

    # The CLI takes epsilon on the 0-255 scale; the attack works on 0-1 pixels
    args.epsilon /= 255.0
    attack_settings = dict(
        instruction=instruction,
        target="Cube",
        image=image,
        lr=args.lr,
        epsilon=args.epsilon,
        num_steps=args.num_steps,
    )
    pgd_attack(
        model=model,
        tokenizer=tokenizer,
        image_processor=image_processor,
        **attack_settings,
    )


if __name__ == "__main__":
    # Command-line interface: one row per flag as (flag, type, default, help text)
    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in (
        ('--defer_text_k', int, 0, 'K repeat times for defer text content'),
        ('--epsilon', float, 64, 'Epsilon * 255 for PGD attack'),
        ('--lr', float, 0.1, 'Learning rate for PGD attack'),
        ('--num_steps', int, 10, 'Number of steps for PGD attack'),
    ):
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()
    main(args)
