import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torch.optim as optim
from PIL import Image
import numpy as np

# Torchvision
import torchvision
import torchvision.transforms as transforms
import argparse
import os


from transformers import CLIPProcessor, CLIPModel
from clip_attack import image_embedding,text_embedding, get_cluster_map, get_text_sim_map
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Default to the first physical GPU, but respect a CUDA_VISIBLE_DEVICES value the
# caller already exported. The variable is only honoured if it is set before
# CUDA is initialized, which is why this runs before any tensor is placed on GPU.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
# Target device for the model and all tensors; falls back to CPU without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def normalize(images):
    """Map pixel-space images in [0, 1] into CLIP's normalized input space.

    Computes ``(images - mean) / std`` per RGB channel, using the mean/std
    constants of OpenAI CLIP preprocessing. ``denormalize`` undoes this.

    Args:
        images: tensor with the 3 RGB channels on dim 1, expected layout
            ``(N, 3, H, W)``.

    Returns:
        Normalized tensor of the same shape, on the same device as ``images``.
    """
    # Create the constants on the input's device rather than hard-coding
    # ``.cuda()``, so this also works when the script runs on CPU.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device)
    return (images - mean[None, :, None, None]) / std[None, :, None, None]

def denormalize(images):
    """Map CLIP-normalized images back to pixel space (nominally [0, 1]).

    Computes ``images * std + mean`` per RGB channel, using the mean/std
    constants of OpenAI CLIP preprocessing. This undoes ``normalize``.

    Args:
        images: tensor with the 3 RGB channels on dim 1, expected layout
            ``(N, 3, H, W)``.

    Returns:
        De-normalized tensor of the same shape, on the same device as
        ``images``. No clamping is applied.
    """
    # Create the constants on the input's device rather than hard-coding
    # ``.cuda()``, so this also works when the script runs on CPU.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device)
    return images * std[None, :, None, None] + mean[None, :, None, None]

def conver_to_image_mask(mask, image_size=224, grid_size=16, device='cuda:0'):
    """Upsample a flat per-cell score vector into a pixel-level mask.

    The ``grid_size * grid_size`` values of ``mask`` are read in row-major
    order. Each value fills one square cell of ``image_size // grid_size``
    pixels per side. If ``image_size`` is not divisible by ``grid_size``, the
    leftover bottom rows and right columns stay 0.

    Args:
        mask: tensor holding exactly ``grid_size ** 2`` values in any shape,
            e.g. ``(1, grid_size**2, 1)``. It is flattened before use.
        image_size: side length of the output mask in pixels.
        grid_size: number of cells per side of the grid.
        device: device the returned mask is moved to.

    Returns:
        float32 tensor of shape ``(1, 1, image_size, image_size)`` on ``device``.

    Raises:
        ValueError: if ``mask`` does not contain exactly ``grid_size ** 2``
            values.
    """
    num_cells = grid_size * grid_size
    cell = image_size // grid_size

    values = mask.reshape(-1)
    if values.numel() != num_cells:
        raise ValueError(
            f"mask has {values.numel()} values, expected grid_size**2 = {num_cells}"
        )

    # Expand each grid value into a (cell x cell) block in one vectorized step,
    # instead of hard-coding 256 cells in a Python loop.
    grid = values.reshape(grid_size, grid_size).to(torch.float32)
    upsampled = grid.repeat_interleave(cell, dim=0).repeat_interleave(cell, dim=1)

    image_mask = torch.zeros((image_size, image_size), dtype=torch.float32, device=grid.device)
    covered = grid_size * cell
    image_mask[:covered, :covered] = upsampled

    return image_mask.unsqueeze(0).unsqueeze(0).to(device)

def convert_mask_to_patch_mask(mask, patch_size=14, image_size=224):
    """Reduce a pixel mask to a binary per-patch mask.

    Dims 2 and 3 of ``mask`` are tiled into non-overlapping
    ``patch_size x patch_size`` patches. A patch is marked 1.0 when the sum of
    its pixels is positive, and 0.0 otherwise.

    Args:
        mask: tensor indexed as ``(N, C, H, W)``. Only ``C == 1`` reshapes
            cleanly, since each sample must yield exactly ``num_patches``
            patch values.
        patch_size: side length of a patch in pixels.
        image_size: expected spatial size, used to derive
            ``num_patches = (image_size // patch_size) ** 2``.

    Returns:
        float tensor of shape ``(N, num_patches)``, with patches in row-major
        order.
    """
    num_patches = (image_size // patch_size) ** 2
    patches = mask.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)

    patch_mask = (patches.sum(dim=(4, 5)) > 0).float()
    # Keep the batch dimension instead of hard-coding 1. Reshaping to exactly
    # num_patches per sample still fails loudly if the mask's spatial size
    # does not match image_size.
    return patch_mask.reshape(mask.shape[0], num_patches)


# Load CLIP. The blank strings are placeholders for a local checkpoint path or
# a hub id.
model = CLIPModel.from_pretrained(" ")  # add your clip path here
processor = CLIPProcessor.from_pretrained(" ") # add your clip path here
model.eval()
model.to(device)
# Only the adversarial noise is optimized. Freezing CLIP stops every
# loss.backward() in the attack from allocating .grad buffers for all model
# weights and accumulating into them. Nothing would ever zero those buffers,
# because the optimizer only holds the noise tensor.
model.requires_grad_(False)
vision_model = model.vision_model

# Torchvision pipeline: resize to 224x224, convert to tensor, and normalize
# with the processor's mean/std. The attack loop preprocesses through
# `processor` instead.
preprocess = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=processor.image_processor.image_mean, std=processor.image_processor.image_std)
])

# Attack metadata: maps each image path to a record whose
# ['attack_target']['category_name'] names the object the attack targets.
with open('disappear_dataset_info.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

for num, (image_path, content) in enumerate(tqdm(data.items()), start=1):
    # Attack only the first 100 entries. `break` ends the loop without killing
    # the interpreter the way exit() did.
    if num > 100:
        break

    # File stem of the image (not used below; presumably meant for save paths).
    img_id = image_path.split('/')[-1].split('.')[0]
    target_obj = content['attack_target']["category_name"]
    print(target_obj)
    print(image_path)

    with Image.open(image_path) as raw_image:
        image = processor(images=raw_image.convert('RGB'), return_tensors="pt").to(device)
    image_tensor = image['pixel_values']

    # L_inf budget for the perturbation, in [0, 1] pixel units.
    epsilon = 8 / 255.

    with torch.no_grad():
        # Clean image in pixel space, computed once and reused everywhere below.
        clean = denormalize(image_tensor)

        # Random start drawn uniformly from [-epsilon, epsilon]. The old
        # rand().clamp(-eps, eps) only produced values in [0, eps]. The shift
        # keeps clean + noise inside the valid [0, 1] pixel range.
        init_rand = torch.empty_like(clean).uniform_(-epsilon, epsilon)
        adv_noise = (clean + init_rand).clamp(0, 1) - clean

        # Reference features of the clean image (token 0 is CLS, 1: are patches).
        image_output = vision_model(image_tensor, output_attentions=True)
        image_embs = image_output.last_hidden_state[:, 1:, :]
        image_embs_proj = model.visual_projection(image_embs)
        image_cls_embs = image_output.pooler_output
        image_cls_embs_proj = model.visual_projection(image_cls_embs)

        # Average attention over layers (stack dim 0) and heads (dim 1). Then
        # take column 0 of rows 1:, i.e. the weight each patch token places on
        # the CLS token, and normalize to sum 1. unsqueeze(-1) makes it
        # broadcast against the (1, num_patches, dim) embeddings.
        image_attention = torch.stack(image_output.attentions).mean(dim=0)
        image_attention = torch.mean(image_attention, dim=1)
        image_attention = image_attention[:, 1:, 0]
        image_attention = image_attention / image_attention.sum()
        image_attention = image_attention.unsqueeze(-1)

        target_text = processor(text=target_obj, return_tensors="pt", padding=True,
                                truncation=True, max_length=77).to(device)
        target_text_embs = model.text_model(**target_text).pooler_output
        target_text_embs = model.text_projection(target_text_embs)

        # Per-patch text-similarity map from clip_attack (presumably one value
        # per patch), shaped (1, num_patches, 1) for broadcasting.
        mask = get_text_sim_map(image_embs_proj, target_text_embs)
        mask = mask.unsqueeze(0).unsqueeze(-1)

        image_mask = conver_to_image_mask(mask, device=device)
        masked_image = torch.clamp(image_mask * clean, 0, 1)
        pil_image = transforms.functional.to_pil_image(masked_image.squeeze(0))
        save_path = ' '  # add your path to save the masked image
        pil_image.save(save_path)

        # Reload the saved file so the reference embeddings are computed from
        # its 8-bit contents, re-preprocessed by the CLIP processor.
        with Image.open(save_path) as saved_masked:
            masked_image = processor(images=saved_masked.convert('RGB'), return_tensors="pt").to(device)
        masked_image_tensor = masked_image['pixel_values']
        masked_image_output = vision_model(masked_image_tensor, output_attentions=True)
        masked_image_embs = masked_image_output.last_hidden_state[:, 1:, :]
        masked_image_embs_proj = model.visual_projection(masked_image_embs)
        masked_image_cls = masked_image_output.last_hidden_state[:, 0, :]
        masked_image_cls_proj = model.visual_projection(masked_image_cls)

        # Denominator for the masked mean in loss1. An all-zero mask would give
        # 0/0 = NaN and corrupt the noise. In that case the numerator is 0 as
        # well, so dividing by 1 instead keeps loss1 finite.
        mask_total = mask.sum()
        if mask_total.item() == 0:
            mask_total = torch.ones_like(mask_total)

    for i in range(501):
        adv_noise.requires_grad_()
        # A fresh Adam instance each step carries no moment estimates, so every
        # update is approximately lr * sign(grad): a PGD-style signed step.
        optimizer = optim.Adam([{'params': adv_noise, 'lr': 0.01}])

        adv_x = normalize(clean + adv_noise)

        adv_output = vision_model(adv_x, output_attentions=True)
        adv_embs = adv_output.last_hidden_state[:, 1:, :]
        adv_cls_embs = adv_output.pooler_output
        adv_embs_proj = model.visual_projection(adv_embs)
        adv_cls_embs_proj = model.visual_projection(adv_cls_embs)

        # Same layer/head-averaged patch->CLS attention as for the clean image.
        adv_attention = torch.stack(adv_output.attentions).mean(dim=0)
        adv_attention = torch.mean(adv_attention, dim=1)
        adv_attention = adv_attention[:, 1:, 0]
        adv_attention = adv_attention / adv_attention.sum()
        adv_attention = adv_attention.unsqueeze(-1)

        # loss1: keep attention-weighted patch features inside the mask close
        #        to the clean image (mean over the mask).
        # loss2: pull projected patch embeddings toward the masked image's.
        # loss3: push the projected CLS embedding away from the target text.
        loss1 = 1 - F.cosine_similarity(mask * adv_attention * adv_embs,
                                         mask * image_attention * image_embs, dim=-1).sum() / mask_total
        loss2 = 1 - F.cosine_similarity(adv_embs_proj, masked_image_embs_proj, dim=-1).mean()
        loss3 = F.cosine_similarity(adv_cls_embs_proj, target_text_embs, dim=-1).mean()

        loss = 0.5 * loss1 + 2 * loss2 + 0.2 * loss3

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(
                f'Loss1: {loss1.item():.4f}, Loss2: {loss2.item():.4f}, Loss3: {loss3.item():.4f}')

        # Project back onto the L_inf ball and the valid [0, 1] pixel range.
        adv_noise = adv_noise.detach().clamp(-epsilon, epsilon)
        adv_noise = (clean + adv_noise).clamp(0, 1) - clean

    # The output path is fixed, so only the final write of the former
    # per-iteration saves survived. Save the final adversarial image once.
    with torch.no_grad():
        adv_x = (clean + adv_noise).clamp(0, 1)
        pil_image = transforms.functional.to_pil_image(adv_x.squeeze(0))
        pil_image.save('  ') # your save path
