import argparse
import os
import random
from tqdm import tqdm

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from torchvision.utils import save_image
from llava_llama_2.utils import get_model


def load_image(image_path):
    """Open the image file at *image_path* and return it converted to RGB.

    Args:
        image_path: path to an image file readable by PIL.

    Returns:
        A ``PIL.Image.Image`` in ``'RGB'`` mode.
    """
    return Image.open(image_path).convert(mode='RGB')


def normalize(images):
    """Apply per-channel normalization ``(images - mean) / std``.

    The constants are broadcast as shape (1, 3, 1, 1), so *images* is expected
    to be (N, 3, H, W) in [0, 1] pixel units (assumption — TODO confirm for all
    callers). They match the OpenAI CLIP preprocessing values; presumably they
    must agree with ``image_processor.image_mean`` / ``image_std`` — verify.

    Args:
        images: float tensor with channels on dim 1.

    Returns:
        A new normalized tensor on the same device as *images*.
    """
    # Create the constants on the input's device rather than hard-coding
    # ``.cuda()``: that broke CPU inputs and tensors on a non-default GPU.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device)
    images = images - mean[None, :, None, None]
    images = images / std[None, :, None, None]
    return images


def denormalize(images):
    """Invert the per-channel normalization: ``images * std + mean``.

    The constants are broadcast as shape (1, 3, 1, 1), so *images* is expected
    to be (N, 3, H, W) (assumption — TODO confirm for all callers). They match
    the OpenAI CLIP preprocessing values; presumably they must agree with
    ``image_processor.image_mean`` / ``image_std`` — verify.

    Args:
        images: normalized float tensor with channels on dim 1.

    Returns:
        A new tensor in approximately [0, 1] pixel units, on the same device
        as *images*.
    """
    # Create the constants on the input's device rather than hard-coding
    # ``.cuda()``: that broke CPU inputs and tensors on a non-default GPU.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device)
    images = images * std[None, :, None, None]
    images = images + mean[None, :, None, None]
    return images


def parse_args():
    """Parse the command-line options for this script.

    The three folder options are required. Without them the values would be
    ``None``, and ``os.listdir(None)`` silently lists the current directory
    instead of failing. Failing at parse time also avoids loading the model
    first.

    Returns:
        argparse.Namespace with ``model_path``, ``model_base``, ``gpu_id``,
        ``attacked_image_fold``, ``raw_image_fold``, ``output_fold`` and
        ``options``.
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--model-path", type=str, default="ckpts/llava_llama_2_13b_chat_freeze")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--gpu_id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--attacked_image_fold", type=str, required=True)
    parser.add_argument("--raw_image_fold", type=str, required=True)
    parser.add_argument("--output_fold", type=str, required=True)

    # NOTE(review): parsed but not read anywhere in this file.
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into config file (deprecate), "
             "change to --cfg-options instead.",
    )
    args = parser.parse_args()
    return args


def setup_seeds(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed passed unchanged to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Prefer reproducible cuDNN kernels over autotuned (benchmark) ones.
    cudnn.deterministic = True
    cudnn.benchmark = False


# ========================================
#             Model Initialization
# ========================================

print('>>> Initializing Models')

args = parse_args()
print('model = ', args.model_path)

# Of the returned components, only image_processor is used below.
tokenizer, model, image_processor, model_name = get_model(args)
model.eval()
print('Initialization Finished')

# Half-width of the uniform noise, in [0, 1] pixel units.
epsilon = 32 / 255

attacked_image_files = sorted(os.listdir(args.attacked_image_fold))
# A set gives O(1) membership checks inside the loop.
raw_image_files = set(os.listdir(args.raw_image_fold))
os.makedirs(args.output_fold, exist_ok=True)
with torch.no_grad():
    for image_file in tqdm(attacked_image_files):
        # Strip only the final extension so stems containing dots survive.
        image_file = os.path.splitext(image_file)[0]
        raw_name = image_file + '.jpg'
        # Explicit check instead of `assert`, which is stripped under -O.
        if raw_name not in raw_image_files:
            raise FileNotFoundError(
                f"no raw image {raw_name!r} in {args.raw_image_fold!r} "
                f"for attacked image {image_file!r}"
            )
        raw_img = load_image(os.path.join(args.raw_image_fold, raw_name))
        raw_img = image_processor.preprocess(raw_img, return_tensors='pt')['pixel_values'].cuda()

        # Random-noise baseline: map the preprocessed image back to pixel
        # space, add uniform noise in [-epsilon, epsilon], and clamp so the
        # result stays a valid [0, 1] image. Everything stays on raw_img's
        # device, so no cross-device copies are needed.
        x = denormalize(raw_img)
        noise = torch.rand_like(x) * 2 * epsilon - epsilon
        x_adv = (x + noise).clamp(0, 1)

        save_image(x_adv.squeeze(0).cpu(), os.path.join(args.output_fold, image_file + '.bmp'))