import os
# Set before `import torch` below. PyTorch's deterministic mode (enabled further
# down) documents this cuBLAS workspace setting as required on CUDA >= 10.2.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import torch
import torch.nn.functional as F
from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration, AutoProcessor, LlavaForConditionalGeneration, AutoTokenizer, Qwen2VLForConditionalGeneration
from PIL import Image
from torchvision import transforms
# Process-wide switch: ask PyTorch to use deterministic algorithms so runs are
# reproducible (ops without a deterministic implementation will raise).
torch.use_deterministic_algorithms(True)

from qwen2vl_tensor_processor import Qwen2VLTensorImageProcessor

# Generated-token count at or above which `OptimizedImageModifier.optimize_image`
# stops early and sets its success flag to 1.
Threshold = 1000

class VisionLanguageFeatureModifier:
    """Load a vision-language model and replace its projected visual features with text embeddings.

    Depending on a substring of ``model_name`` the class loads a BLIP-2,
    InstructBLIP, LLaVA or Qwen2-VL checkpoint, locates the module whose output
    is handed to the language model (``self.vision_layer``), and offers a
    forward hook that substitutes that output with embeddings built from a
    short slice of tokenized text.

    Attributes set by the hook / helpers:
        visual_features: output of ``vision_layer`` captured by the hook.
        text_embeddings: text embeddings produced for the replacement.
        cos_similarity: per-token cosine similarities from ``compute_similarity``.
        y: token ids that were embedded by ``generate_embedding_from_text``.
    """

    def __init__(self, model_name, device=None, start_slice_num=2, length_slice=2):
        """Load processor, model and tokenizer and find the layer to hook.

        Args:
            model_name: checkpoint name/path; must contain "blip2",
                "instructblip", "llava" or "Qwen" to select the model family.
            device: target device; defaults to "cuda" when available, else "cpu".
            start_slice_num: index of the first token of the text slice that
                is tiled to build replacement embeddings.
            length_slice: number of tokens in that slice.

        Raises:
            ValueError: if the model family is unsupported or the hook layer
                cannot be found among the model's named modules.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.model_type = None
        self.start_slice_num = start_slice_num
        self.length_slice = length_slice

        hook_layername = None
        if "blip2" in model_name:
            self.model_type = "blip2"
            hook_layername = "language_projection"
            self.processor = Blip2Processor.from_pretrained(model_name)
            self.model = Blip2ForConditionalGeneration.from_pretrained(model_name, attn_implementation="eager", torch_dtype=torch.float32)
        elif "instructblip" in model_name:
            self.model_type = "instructblip"
            hook_layername = "language_projection"
            self.processor = InstructBlipProcessor.from_pretrained(model_name)
            self.model = InstructBlipForConditionalGeneration.from_pretrained(model_name, attn_implementation="eager", torch_dtype=torch.float32)
        elif "llava" in model_name:
            self.model_type = "llava"
            hook_layername = "multi_modal_projector.linear_2"
            self.processor = AutoProcessor.from_pretrained(model_name)
            self.model = LlavaForConditionalGeneration.from_pretrained(model_name, attn_implementation="eager", torch_dtype=torch.float32)
            # Only the language model is cast to half precision; the rest of
            # the model stays in float32 as loaded above.
            self.model.language_model.half()
        elif "Qwen" in model_name:
            self.model_type = "qwen"
            hook_layername = "visual.merger.mlp.2"
            self.processor = AutoProcessor.from_pretrained(model_name)
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, attn_implementation="eager", torch_dtype=torch.float32)
        else:
            raise ValueError("Unsupported model.")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = self.model.to(self.device)

        self.hook = None
        self.replacement_text = None
        self.embedding_replacement = None
        self.visual_features = None
        self.text_embeddings = None
        self.cos_similarity = None
        self.y = None

        # First module whose qualified name equals the family-specific hook name.
        self.vision_layer = next(
            (module for name, module in self.model.named_modules() if name == hook_layername),
            None,
        )
        if self.vision_layer is None:
            raise ValueError("Visual embedding layer not found.")

    def generate_embedding_from_text(self, text, num_queries):
        """Embed a tiled slice of ``text`` tokens.

        The text is tokenized without special tokens, the slice
        ``[start_slice_num, start_slice_num + length_slice)`` is taken and
        repeated ``num_queries // length_slice`` times along the sequence axis,
        and the result is passed through the model's input embedding layer.
        The tiled token ids are kept in ``self.y``.

        Args:
            text: text whose tokens supply the slice.
            num_queries: number of visual tokens the embeddings should cover.

        Returns:
            The embedding tensor; for the "qwen" model type the leading batch
            axis is squeezed away.
        """
        tokenized = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)
        input_ids = tokenized.input_ids.to(self.device)

        slice_len = self.length_slice
        slice_start = self.start_slice_num
        # Integer division: a float round-trip is unnecessary for a repeat count.
        repeats = num_queries // slice_len

        token_slice = input_ids[:, slice_start:slice_start + slice_len]
        tiled_ids = torch.cat([token_slice] * repeats, dim=1)
        self.y = tiled_ids

        embeddings = self.model.get_input_embeddings()(tiled_ids)
        if self.model_type == "qwen":
            embeddings = embeddings.squeeze(0)
        print(f"Generated text embedding shape: {embeddings.shape}")
        return embeddings

    def register_hook(self, replacement_text):
        """Attach a forward hook that replaces the visual features with text embeddings.

        On every forward pass of ``vision_layer`` the hook stores the detached
        original output in ``self.visual_features``, stores the detached text
        embeddings in ``self.text_embeddings``, and returns the text embeddings
        in place of the output. The hook raises ``ValueError`` when the two
        shapes differ.

        Any hook previously registered through this object is removed first, so
        repeated calls never leave a stale hook attached to the layer.

        Args:
            replacement_text: text passed to ``generate_embedding_from_text``.
        """
        self.replacement_text = replacement_text

        def hook_fn(module, input, output):
            print("Original visual feature output shape:", output.shape)  # (batch_size, num_queries, embed_dim)
            if self.model_type == "qwen":
                # Qwen features carry no batch axis: (num_queries, embed_dim).
                num_queries, embed_dim = output.shape
            else:
                batch_size, num_queries, embed_dim = output.shape

            text_embeddings = self.generate_embedding_from_text(self.replacement_text, num_queries)

            self.visual_features = output.detach()
            self.text_embeddings = text_embeddings.detach()

            if text_embeddings.shape != output.shape:
                raise ValueError(f"Generated text embedding dimension mismatch: expected {output.shape}, but got {text_embeddings.shape}")

            print("Replace the visual embeddings with text embeddings.")
            return text_embeddings

        # Drop a previously registered hook; otherwise overwriting the handle
        # would leave that hook attached with no way to remove it.
        previous = getattr(self, "hook", None)
        if previous is not None:
            previous.remove()
        self.hook = self.vision_layer.register_forward_hook(hook_fn)

    def remove_hook(self):
        """Detach the currently registered hook, if any."""
        if self.hook:
            self.hook.remove()
            self.hook = None
            print("Hook has been removed.")

    def process_image(self, image_path):
        """Open ``image_path`` and return it as an RGB PIL image."""
        # `convert` loads the pixel data into a new image, so the source file
        # can be closed as soon as the conversion is done.
        with Image.open(image_path) as source:
            return source.convert("RGB")

    def generate_text(self, inputs, max_length=1024, do_sample=False):
        """Run ``model.generate`` without gradients and decode the first sequence.

        Args:
            inputs: mapping of model inputs; must contain ``"input_ids"``.
            max_length: passed to ``generate`` as ``max_new_tokens``.
            do_sample: passed to ``generate`` as ``do_sample``.

        Returns:
            ``(generated_text, generated_token_length)`` where the length is the
            generated sequence length minus the prompt's ``input_ids`` length.
            NOTE(review): this assumes ``generate`` returns prompt plus new
            tokens for every supported model family - confirm.
        """
        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                do_sample=do_sample,
            )

        generated_token_length = generated_ids.shape[1] - inputs["input_ids"].shape[1]
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return generated_text, generated_token_length

    def compute_similarity(self):
        """Return the mean per-token cosine similarity of visual vs. text features.

        Both tensors are flattened to ``(tokens, embed_dim)``; the per-token
        similarities are stored in ``self.cos_similarity``.

        Raises:
            ValueError: if visual features or text embeddings are not set yet.
        """
        if self.visual_features is None or self.text_embeddings is None:
            raise ValueError("Please run `generate_text` first to generate the visual features and text embeddings.")

        # `reshape` behaves like `view` for contiguous tensors and also accepts
        # non-contiguous ones.
        visual_flat = self.visual_features.reshape(-1, self.visual_features.shape[-1])
        text_flat = self.text_embeddings.reshape(-1, self.text_embeddings.shape[-1])

        cosine_sim = F.cosine_similarity(visual_flat, text_flat, dim=-1)
        self.cos_similarity = cosine_sim

        return cosine_sim.mean()

    def compute_avgemb_similarity(self):
        """Return the cosine similarity between token-averaged visual and text features.

        The token axis is the second-to-last axis, so this works both for
        batched ``(batch, tokens, dim)`` features and for the un-batched
        ``(tokens, dim)`` features produced for the "qwen" model type.

        Raises:
            ValueError: if visual features or text embeddings are not set yet.
        """
        if self.visual_features is None or self.text_embeddings is None:
            raise ValueError("Please run `generate_text` first to generate the visual features and text embeddings.")

        # Average over tokens (dim=-2), never over the embedding axis.
        visual_avg = self.visual_features.mean(dim=-2)
        text_avg = self.text_embeddings.mean(dim=-2)

        cosine_sim = F.cosine_similarity(visual_avg, text_avg, dim=-1)
        return cosine_sim.squeeze()
    
class OptimizedImageModifier(VisionLanguageFeatureModifier):
    """Optimise a bounded image perturbation against the visual/text feature similarity.

    The optimisation loop in ``optimize_image`` minimises
    ``-mean(cos_sim) + alpha * std(cos_sim)``, where ``cos_sim`` are the
    per-token cosine similarities between the hooked visual features and
    ``self.text_embeddings``. The text embeddings are not produced here: they
    must already have been stored on the instance (e.g. by a forward pass with
    the replacement hook from ``register_hook``), otherwise
    ``compute_similarity`` raises ``ValueError``.
    """

    def __init__(self, model_name, device=None, start_slice_num=2, length_slice=2, EPSILON=0.032, lr=0.01, alpha=1, num_steps=50, mu=0.9):
        """Load the model and set up the optimisation hyper-parameters.

        Args:
            model_name, device, start_slice_num, length_slice: forwarded to
                ``VisionLanguageFeatureModifier.__init__``.
            EPSILON: maximum absolute per-pixel perturbation, applied in
                un-normalised [0, 1] pixel space.
            lr: step size of each signed update.
            alpha: weight of the similarity standard-deviation term in the loss.
            num_steps: number of optimisation steps.
            mu: momentum decay factor; ``0`` disables momentum.
        """
        super().__init__(model_name=model_name, device=device, start_slice_num=start_slice_num, length_slice=length_slice)
        self.EPSILON = EPSILON
        self.lr = lr
        self.alpha = alpha
        self.num_steps = num_steps
        self.mu = mu
        self.optimized_image = None
        self.language_projection_output = None
        self.hook = None

        self.normalize_mean = self.processor.image_processor.image_mean
        self.normalize_std = self.processor.image_processor.image_std

        # Inverse of Normalize(mean, std): x * std + mean expressed as a Normalize.
        self.inverse_transform = transforms.Normalize(
            mean=[-m / s for m, s in zip(self.normalize_mean, self.normalize_std)],
            std=[1 / s for s in self.normalize_std],
        )

        self.norm_transform = transforms.Normalize(mean=self.normalize_mean, std=self.normalize_std)

        # Tensor-based preprocessor used in the "qwen" branch of
        # `optimize_image` (via `preprocess_tensor_keep_grad`).
        self.tensor_processor = Qwen2VLTensorImageProcessor(
            #min_pixels=min_pixels,
            #max_pixels=max_pixels,
            patch_size=14,
            temporal_patch_size=2,
            merge_size=2,
            do_resize=False,
            do_rescale=False,
        )

    def register_adv_hook(self):
        """Attach a forward hook that captures ``vision_layer``'s output with its graph.

        Unlike the replacement hook, the output is stored un-detached in
        ``self.visual_features`` so gradients can reach the input image. A hook
        already held in ``self.hook`` is removed first so it cannot be leaked.
        """
        def hook_fn(module, input, output):
            self.visual_features = output

        if self.hook is not None:
            self.hook.remove()
        self.hook = self.vision_layer.register_forward_hook(hook_fn)

    def remove_adv_hook(self):
        """Detach the capture hook, if any."""
        if self.hook:
            self.hook.remove()
            self.hook = None

    def optimize_image(self, image_path, text_input):
        """Optimise a perturbation of the image at ``image_path``.

        Every 10th step and on the last step the model generates text for the
        current image; the image with the longest generation seen so far is
        kept. The loop stops early as soon as a generation reaches
        ``Threshold`` tokens. The best image is stored on the CPU in
        ``self.optimized_image``.

        Args:
            image_path: path of the image to perturb.
            text_input: prompt passed to the processor together with the image.

        Returns:
            A list ``[image_path, origin_token_length, begin_cos, begin_score,
            max_token_length, record_step, record_cos, record_score,
            origin_generated_text, record_generated_text, flag]`` where the
            ``origin``/``begin`` values come from step 0, the ``record`` values
            from the best step, and ``flag`` is 1 when ``Threshold`` was reached.

        Raises:
            ValueError: from ``compute_similarity`` when no text embeddings
                have been stored on the instance beforehand.
        """
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        if self.model_type == "qwen":
            # Qwen is optimised directly in [0, 1] pixel space on a fixed
            # 336x336 image; the tensor processor rebuilds pixel_values each step.
            image = image.resize((336, 336))
            inputs = self.processor(images=image, text=text_input, return_tensors="pt").to(self.device)
            image_tensor = transforms.ToTensor()(image)
            image_tensor = image_tensor.unsqueeze(0).to(self.device)
            original_image = image_tensor.detach().clone()
        else:
            # Other models are optimised on the processor's normalised pixel_values.
            inputs = self.processor(images=image, text=text_input, return_tensors="pt").to(self.device)
            original_image = inputs["pixel_values"].detach().clone()

        input_image = original_image.clone().requires_grad_(True)
        # Start from the clean image so the result is never uninitialised
        # memory, even when the loop body does not run (num_steps == 0).
        bestadvimage = original_image.clone()
        begin_cos = 0
        begin_score = 0
        record_step = 0
        record_cos = 0
        record_score = 0
        token_length = 0
        max_token_length = float('-inf')
        record_generated_text = None
        generated_text = None
        origin_generated_text = None
        origin_token_length = 0
        flag = 0

        EPSILON = self.EPSILON
        g = torch.zeros_like(input_image)  # momentum accumulator
        print("image:", image_path)
        print("lr:", self.lr)
        print("EPSILON:", EPSILON)
        print("alpha:", self.alpha)
        print("decay factor:", self.mu)

        for step in range(self.num_steps):
            input_image.grad = None
            is_eval_step = step % 10 == 0 or step == self.num_steps - 1

            # Project the current image back into the valid pixel range and
            # the EPSILON-ball around the original before using it.
            if self.model_type == "qwen":
                clamp_image = torch.clamp(input_image.data, 0, 1)
                perturbation = torch.clamp(clamp_image - original_image, -EPSILON, EPSILON)
                adv_image = perturbation + original_image
                input_image.data.copy_(adv_image.data)
                pixel_values, _ = self.tensor_processor.preprocess_tensor_keep_grad(input_image)
                inputs['pixel_values'] = pixel_values
            else:
                # The bound is applied in un-normalised space, then re-normalised.
                inv_image = self.inverse_transform(input_image.data.squeeze(0))
                inv_image = torch.clamp(inv_image, 0, 1)
                inv_original_image = self.inverse_transform(original_image.squeeze(0))
                perturbation = torch.clamp(inv_image - inv_original_image, -EPSILON, EPSILON)
                inv_image = perturbation + inv_original_image
                input_image_normalized = self.norm_transform(inv_image).unsqueeze(0)
                input_image.data.copy_(input_image_normalized.data)

                inputs['pixel_values'] = input_image

            if is_eval_step:
                generated_text, token_length = self.generate_text(inputs=inputs, max_length=1024, do_sample=False)

            # Capture the visual features with gradients; always detach the
            # hook again, even if the forward pass raises.
            self.register_adv_hook()
            try:
                _ = self.model(**inputs)
            finally:
                self.remove_adv_hook()

            cosine_sim = self.compute_similarity()
            score = torch.std(self.cos_similarity, correction=0)

            if is_eval_step:
                reached_threshold = token_length >= Threshold
                if reached_threshold or token_length > max_token_length:
                    max_token_length = token_length
                    record_generated_text = generated_text
                    record_step = step
                    record_cos = cosine_sim.item()
                    record_score = score.item()
                    # Detached copy: the snapshot must not hold on to the graph.
                    bestadvimage.copy_(input_image.detach())
                if reached_threshold:
                    flag = 1
                    if step == 0:
                        # The loop exits before the regular step-0 bookkeeping below.
                        begin_cos = cosine_sim.item()
                        begin_score = score.item()
                        origin_generated_text = generated_text
                        origin_token_length = token_length
                    break

            # Raise the mean similarity while lowering its spread across tokens.
            loss = -cosine_sim + self.alpha * score

            loss.backward()
            if self.mu != 0:
                # Momentum update: accumulate the gradient scaled by its mean
                # absolute value, then take a signed step.
                grad = input_image.grad.detach()
                grad_norm = grad.abs().mean(dim=(1, 2, 3), keepdim=True)
                grad_normalized = grad / (grad_norm + 1e-8)
                g = self.mu * g + grad_normalized
                input_image.data = input_image.detach() - self.lr * g.sign()
            else:
                input_image.data = input_image.detach() - self.lr * input_image.grad.sign()

            if step == 0:
                print("cos sim shape:", self.cos_similarity.shape)
                begin_cos = cosine_sim.item()
                begin_score = score.item()
                origin_generated_text = generated_text
                origin_token_length = token_length
            if is_eval_step:
                print(f"Step {step}/{self.num_steps} - loss: {loss.item():.4f}, sim: {cosine_sim.item():.4f}, score: {score.item():.4f}, token length: {token_length}")

        self.optimized_image = bestadvimage.detach().cpu()

        return [
            image_path,
            origin_token_length,
            begin_cos,
            begin_score,
            max_token_length,
            record_step,
            record_cos,
            record_score,
            origin_generated_text,
            record_generated_text,
            flag
        ]

    def save_optimized_image(self, save_path):
        """Save the optimised image as .jpg, .png and .bmp, and the raw tensor as .pt.

        Args:
            save_path: output path without extension; each extension is
                appended to it.

        Raises:
            ValueError: if ``optimize_image`` has not produced an image yet.
        """
        if self.optimized_image is None:
            raise ValueError("No optimized image available; run `optimize_image` first.")

        optimized_tensor = self.optimized_image.squeeze(0)
        if self.model_type != "qwen":
            # Non-qwen images are stored in normalised space; undo that first.
            optimized_tensor = self.inverse_transform(optimized_tensor)
        # Guard against values slightly outside [0, 1], which would wrap
        # around when converted to 8-bit.
        optimized_tensor = optimized_tensor.clamp(0, 1)

        optimized_pil = transforms.ToPILImage()(optimized_tensor)

        # Build every path from `save_path` itself; rewriting ".jpg" inside the
        # full path would also alter directory or stem parts containing ".jpg".
        optimized_path = save_path + ".jpg"
        optimized_pil.save(optimized_path)
        print(f"The optimized image has been saved to: {optimized_path}")
        optimized_pil.save(save_path + ".png")
        optimized_pil.save(save_path + ".bmp")

        tensorfile = save_path + ".pt"
        torch.save(self.optimized_image, tensorfile)
        print(f"The optimized tensor has been saved to: {tensorfile}")

