import argparse
import torch
import os
from tqdm import tqdm
import pickle

from PIL import Image
from llava_llama_2.utils import get_model
# Cap the number of threads torch uses for intra-op CPU parallelism at 8.
torch.set_num_threads(8)


def parse_args():
    """Parse the command-line options of this embedding-extraction script.

    The three folder options are required. Without them ``os.listdir`` would
    be handed ``None`` (silently listing the current directory), and
    ``os.makedirs`` would later fail with an unhelpful ``TypeError``.

    Returns:
        argparse.Namespace holding every option below. ``save_dir`` and
        ``batch_size`` are not read elsewhere in this file. The namespace is
        passed whole to ``get_model``, which may consume them.
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--model-path", type=str, default="ckpts/llava_llama_2_13b_chat_freeze")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--gpu_id", type=int, default=0, help="specify the gpu to load the model.")

    parser.add_argument("--save_dir", type=str, default='output',
                        help="save directory")
    parser.add_argument("--attacked_image_fold", type=str, required=True,
                        help="folder of attacked images")
    parser.add_argument("--raw_image_fold", type=str, required=True,
                        help="folder of raw images; each attacked image needs a "
                             ".jpg with the same base name here")
    parser.add_argument("--output_fold", type=str, required=True,
                        help="output folder for the pickled embedding files")
    parser.add_argument("--batch_size", type=int, default=1)

    return parser.parse_args()


def load_image(image_path):
    """Load the image at *image_path* and return it as an RGB ``PIL.Image``.

    The file is opened in a ``with`` block so its handle is released as soon
    as the conversion finishes. ``convert`` returns a new, fully loaded image
    that remains valid after the source file is closed.
    """
    with Image.open(image_path) as image:
        return image.convert('RGB')

# ========================================
#             Model Initialization
# ========================================
# Executed at import time: parse the CLI, then build the tokenizer, model and
# image processor through get_model. The functions and loop further down read
# `args`, `model` and `image_processor` as module-level globals.

print('>>> Initializing Models')
args = parse_args()
print('model = ', args.model_path)

# get_model yields four values; unpack them straight into the globals.
(tokenizer,
 model,
 image_processor,
 model_name) = get_model(args)

# Switch the model to evaluation mode; this script only runs inference.
model.eval()
print('[Initialization Finished]\n')


def get_embeddings(file_prefix, file_name, file_suffix):
    """Encode a single image file with the globally loaded model.

    The file ``file_prefix/<file_name><file_suffix>`` is read as RGB and run
    through the module-level ``image_processor``. The resulting
    ``pixel_values`` tensor is cast to float16 and fed to
    ``model.encode_images``, whose output is returned unchanged. The output
    shape is presumably (num_images, tokens, dim); TODO confirm.
    """
    path = os.path.join(file_prefix, file_name + file_suffix)
    batch = image_processor.preprocess(load_image(path), return_tensors='pt')
    pixel_values = batch['pixel_values']
    return model.encode_images(pixel_values.half())


# For every attacked image, encode it together with the raw .jpg that shares
# its base name, and pickle the concatenated embeddings (attacked first, then
# raw, along dim 0) as <output_fold>/<base name>.pkl.
attacked_image_files = sorted(os.listdir(args.attacked_image_fold))
# A set makes the per-file membership check below O(1) instead of O(n).
raw_image_files = set(os.listdir(args.raw_image_fold))
os.makedirs(args.output_fold, exist_ok=True)
with torch.no_grad():
    for attacked_file in tqdm(attacked_image_files):
        # splitext strips only the final extension, so base names that
        # themselves contain dots (e.g. "img.v2.bmp") stay intact.
        image_file, attacked_ext = os.path.splitext(attacked_file)
        raw_file = image_file + '.jpg'
        if raw_file not in raw_image_files:
            # Explicit raise instead of assert: asserts vanish under `python -O`.
            raise FileNotFoundError(
                f"No raw image {raw_file!r} in {args.raw_image_fold!r} "
                f"matching attacked image {attacked_file!r}"
            )
        # Use the file's real extension rather than assuming '.bmp'.
        attacked_img = get_embeddings(args.attacked_image_fold, image_file, attacked_ext).half()
        raw_img = get_embeddings(args.raw_image_fold, image_file, '.jpg').half()
        # Close the output file deterministically instead of leaking the handle.
        with open(os.path.join(args.output_fold, image_file + '.pkl'), 'wb') as f:
            pickle.dump(torch.cat([attacked_img, raw_img], dim=0).cpu(), f)
