import os
import sys
sys.path.append("../LAVIS")
sys.path.append("../LLaVA")
os.chdir("../LAVIS")

import random
import pickle
import torch
from PIL import Image
from torchvision.utils import save_image

from lavis.models import load_model_and_preprocess

import torch
from tqdm import tqdm
import random
from torchvision.utils import save_image

import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont

import cv2
import numpy as np
import time
import json
import subprocess
import re
import argparse
import seaborn as sns
import shutil

from experiments.blip_experiments.transferable.transform_image import DIM, SIM, SGA, SIA, TIM, Admix, AIP


def remove_image_extensions(text):
    """Strip a trailing ``.jpg`` or ``.png`` extension from a file name.

    Only a suffix is removed. Occurrences of ``.jpg``/``.png`` in the middle
    of the name are kept, so a name such as ``"a.png_mask.jpg"`` becomes
    ``"a.png_mask"`` instead of being mangled into ``"a_mask"``.

    Args:
        text: File name (or path) as a string.

    Returns:
        ``text`` without its trailing ``.jpg``/``.png`` extension, or ``text``
        unchanged when it ends with neither.
    """
    for extension in (".jpg", ".png"):
        if text.endswith(extension):
            return text[:-len(extension)]
    return text


def crop_resize(image, image_size):
    """Center-crop ``image`` to a square and resize it to ``image_size``.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` together with
            ``crop(box)`` and ``resize(size)``, as ``PIL.Image.Image`` does.
        image_size: Side length, in pixels, of the returned square image.

    Returns:
        The value returned by ``image.crop(box).resize((image_size, image_size))``.
    """
    width, height = image.size
    side = min(width, height)
    # Use integer offsets and derive the far edges from them. Fractional box
    # coordinates get rounded by Pillow, and a half-pixel offset on both
    # edges can round to a crop that is one pixel narrower than `side`.
    left = (width - side) // 2
    top = (height - side) // 2
    box = (left, top, left + side, top + side)
    return image.crop(box).resize((image_size, image_size))


def render_typos(image, texts, font_path, font_size, font_color, max_attempts=1000):
    """Draw every string of ``texts`` onto ``image`` at a random free spot.

    Positions are drawn uniformly at random; a candidate is rejected when its
    bounding box intersects the box of a text that was already placed. After
    ``max_attempts`` rejected candidates the text is skipped and a message is
    printed.

    Args:
        image: PIL image that is drawn on in place.
        texts: Sequence of strings to render.
        font_path: Path of the TrueType font file.
        font_size: Font size handed to ``ImageFont.truetype``.
        font_color: Fill colour handed to ``draw.text``.
        max_attempts: Maximum number of random positions tried per text.

    Returns:
        The same ``image`` object, modified in place.
    """
    draw = ImageDraw.Draw(image)
    image_width, image_height = image.size
    if not texts:
        return image

    # The font depends only on the arguments, so load it once rather than
    # once per text.
    font = ImageFont.truetype(font_path, font_size)
    placed_boxes = []  # (x, y, width, height) of every text drawn so far

    for text in texts:
        text_width = int(draw.textlength(text, font=font))
        # Keep the historical 30 px minimum, but grow the box when the
        # rendered text is taller so large fonts cannot overlap other texts
        # or run past the bottom edge.
        rendered_bottom = int(draw.textbbox((0, 0), text, font=font)[3])
        text_height = max(30, rendered_bottom)

        for _ in range(max_attempts):
            text_x = random.randint(0, max(0, image_width - text_width))
            text_y = random.randint(0, max(0, image_height - text_height))

            overlaps = any(
                text_x < x + w and text_x + text_width > x
                and text_y < y + h and text_y + text_height > y
                for x, y, w, h in placed_boxes
            )
            if not overlaps:
                placed_boxes.append((text_x, text_y, text_width, text_height))
                draw.text((text_x, text_y), text, fill=font_color, font=font)
                break
        else:
            # Reached only when no attempt ended with `break`.
            print(f"Failed to place text '{text}' without overlap after {max_attempts} attempts.")

    return image


def normalize(images):
    """Apply the per-channel normalization ``(images - mean) / std``.

    The constants are created on ``images.device``, so the function does not
    depend on module-level state and works before ``main`` has run.

    Args:
        images: Tensor that broadcasts against a ``(1, 3, 1, 1)`` constant,
            i.e. a 4-D tensor whose second dimension holds three channels.

    Returns:
        A new tensor with the normalization applied.
    """
    # NOTE(review): presumably the statistics used by the model's image
    # preprocessor - confirm against the vis_processor configuration.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device)
    return (images - mean[None, :, None, None]) / std[None, :, None, None]


def denormalize(images):
    """Apply the per-channel mapping ``images * std + mean``.

    The constants are created on ``images.device``, so the function does not
    depend on module-level state and works before ``main`` has run.

    Args:
        images: Tensor that broadcasts against a ``(1, 3, 1, 1)`` constant,
            i.e. a 4-D tensor whose second dimension holds three channels.

    Returns:
        A new tensor with the mapping applied.
    """
    # NOTE(review): presumably the statistics used by the model's image
    # preprocessor - confirm against the vis_processor configuration.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device)
    return images * std[None, :, None, None] + mean[None, :, None, None]


# Module-level device setting; main() rebinds it to args.device via `global device`.
device = None

def _read_json_lines(path):
    """Return the JSON value parsed from every non-blank line of ``path``."""
    with open(path, 'r') as file:
        return [json.loads(line.strip()) for line in file if line.strip()]


def _load_square_image(path, image_size):
    """Open ``path`` as RGB and square it with ``crop_resize``.

    The file handle is closed on return; ``convert`` has already produced an
    independent in-memory image by then.
    """
    with Image.open(path) as raw:
        return crop_resize(raw.convert('RGB'), image_size)


def main(args):
    """Craft a bounded adversarial perturbation for every image of a dataset.

    For each file of the dataset folder an additive noise tensor is updated
    with signed-gradient steps of size 1/255 that lower the loss of the
    ``blip2_vicuna_instruct`` / ``vicuna7b`` model for one of the target
    responses. The noise is clamped to ``+-args.eps / 255`` and the perturbed
    image to ``[0, 1]``. Every 100 iterations the loss curve, the image built
    from the stored best noise and a sampled model answer are written out.

    Args:
        args: Parsed command-line namespace. Attributes read: ``eps``,
            ``num_iter``, ``target_responses``, ``device``, ``dataset``,
            ``multiprompt``, ``add_typo``, ``typo_num``, ``typo_size``,
            ``imagenet``, ``wordnetnoun``, ``wordnetverb``, ``wordnetadj``,
            ``mixed``, ``dim``, ``sim``, ``sga``, ``sia``, ``tim``, ``admix``
            and ``aip``.

    Raises:
        ValueError: If ``args.dataset`` is neither ``"mscoco"`` nor
            ``"celeba"``, or if typos are requested but the word pool holds
            fewer than ``args.typo_num`` entries.
        SystemExit: If the output directory of this configuration already
            exists; the run is skipped.
    """
    global device

    model_name = 'blip2_vicuna_instruct'
    model_type = 'vicuna7b'

    image_size = 224
    epsilon = args.eps / 255
    alpha = 1 / 255
    num_iter = args.num_iter

    target_responses = args.target_responses

    device = args.device
    print(f"Using GPU: {device}")

    dataset = args.dataset
    dataset_folders = {
        "mscoco": "../LLaVA/dataset/transferable/mscoco_clean300",
        "celeba": "../LLaVA/dataset/transferable/celeba_clean300",
    }
    if dataset not in dataset_folders:
        # Without this check the folder would stay None and os.listdir(None)
        # would silently enumerate the current working directory.
        raise ValueError(
            f"Unsupported dataset '{dataset}'; expected one of {sorted(dataset_folders)}"
        )
    image_folder = dataset_folders[dataset]

    multiprompt = args.multiprompt
    questions_file = "../LLaVA/dataset/transferable/question_describe.txt"
    period = 10
    batch_size = 1

    # Both supported datasets use the same evaluation prompt.
    test_prompt = "describe the image."

    add_typo = args.add_typo
    typo_num = args.typo_num
    imagenet = args.imagenet
    wordnetnoun = args.wordnetnoun
    wordnetverb = args.wordnetverb
    wordnetadj = args.wordnetadj
    mixed = args.mixed
    typo_size = args.typo_size
    typo_color = (255, 255, 255)
    typo_font = '../LLaVA/fonts/arial_bold.ttf'
    imagenet_class_file = '../LLaVA/dataset/transferable/imagenet_class.pkl'
    noun_file = '../LLaVA/dataset/transferable/nouns.json'
    verb_file = '../LLaVA/dataset/transferable/verbs.json'
    adj_file = '../LLaVA/dataset/transferable/adjectives.json'

    # pixel augmentation
    dim = args.dim
    sim = args.sim
    sga = args.sga
    sia = args.sia
    tim = args.tim
    admix = args.admix
    aip = args.aip

    # When any of these is set, the clean image is re-rendered and
    # re-processed on every iteration instead of reusing origin_x.
    augmented = bool(add_typo or dim or sim or sia or sga or admix or tim or aip)

    tag = f"describe-{dataset}-{model_name}-{model_type}-response_{', '.join(target_responses)}-iter{num_iter}-eps{args.eps}"

    if add_typo:
        tag = tag + f"-random{typo_num}typo-fs{typo_size}"
        typo_suffixes = (
            (imagenet, "-imagenet"),
            (wordnetnoun, "-wordnetnoun"),
            (wordnetverb, "-wordnetverb"),
            (wordnetadj, "-wordnetadj"),
            (mixed, "-mixed"),
        )
        for enabled, suffix in typo_suffixes:
            if enabled:
                tag = tag + suffix
    option_suffixes = (
        (multiprompt, "-multiprompt"),
        (dim, "-dim"),
        (sim, "-sim"),
        (sga, "-sga"),
        (sia, "-sia"),
        (tim, "-tim"),
        (admix, "-admix"),
        (aip, "-aip"),
    )
    for enabled, suffix in option_suffixes:
        if enabled:
            tag = tag + suffix

    print(f"Curr Experiment Setting: {tag}", flush=True)

    # Load the prompt and word pools BEFORE creating any output directory:
    # a failure here must not leave a directory behind, because an existing
    # output directory makes every later run exit immediately.
    question_pool = []
    if multiprompt:
        with open(questions_file, 'r') as file:
            for line in file:
                question_pool.append(line.strip())
    else:
        question_pool.append(test_prompt)

    class_pool = []
    if add_typo:
        # Only the first enabled source is used, in this priority order.
        if imagenet:
            # NOTE: pickle.load can execute arbitrary code; only point this
            # at a trusted class file.
            with open(imagenet_class_file, 'rb') as file:
                class_tuples = pickle.load(file)
            class_pool = [item for tup in class_tuples for item in tup]
        elif wordnetnoun:
            class_pool = _read_json_lines(noun_file)
        elif wordnetverb:
            class_pool = _read_json_lines(verb_file)
        elif wordnetadj:
            class_pool = _read_json_lines(adj_file)
        elif mixed:
            pools = [_read_json_lines(path) for path in (noun_file, verb_file, adj_file)]
            # Truncate to the shortest pool so each word class contributes
            # the same number of entries.
            min_length = min(len(pool) for pool in pools)
            class_pool = [word for pool in pools for word in pool[:min_length]]
            random.shuffle(class_pool)

        if len(class_pool) < typo_num:
            # random.sample would raise the same complaint, but only after
            # the model has been loaded and the directories created.
            raise ValueError(
                f"Typo pool holds {len(class_pool)} entries but --typo-num is {typo_num}; "
                "enable one of --imagenet/--wordnetnoun/--wordnetverb/--wordnetadj/--mixed"
            )

    advimg_dir = f"../LLaVA/dataset/transferable/{tag}"
    if os.path.exists(advimg_dir):
        print(f"Output directory already exists, skipping run: {advimg_dir}", flush=True)
        sys.exit()
    os.makedirs(advimg_dir)

    log_dir = 'temp_log'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # The loss plots of a previous run with the same tag are discarded.
    loss_dir = os.path.join(log_dir, f"loss_{tag}")
    if os.path.exists(loss_dir):
        shutil.rmtree(loss_dir)
    os.makedirs(loss_dir)

    # ========================================
    #             Model Initialization
    # ========================================

    print('>>> Initializing Models')
    model, vis_processor, _ = load_model_and_preprocess(name=model_name, model_type=model_type, is_eval=True, device=device)
    model.eval()
    model.requires_grad_(False)
    print('[Initialization Finished]\n')

    records = []
    image_files = sorted(os.listdir(image_folder))
    for image_file in tqdm(image_files):
        best_loss, best_noise = None, None

        # Decode the clean image once per file; iterations work on copies.
        base_image = _load_square_image(os.path.join(image_folder, image_file), image_size)
        image = vis_processor["eval"](base_image).unsqueeze(0).to(device)
        origin_x = denormalize(image).clone().to(device)

        # Random start inside the epsilon ball, projected so that
        # origin_x + noise stays within [0, 1].
        adv_noise = torch.rand_like(image).to(device) * 2 * epsilon - epsilon
        adv_noise.data = (adv_noise.data + origin_x.data).clamp(0, 1) - origin_x.data
        adv_noise.requires_grad_(True)
        adv_noise.retain_grad()

        loss_buffer = []
        for t in range(num_iter + 1):
            batch_inputs = random.sample(question_pool, batch_size)

            # render_typos draws in place, so never hand out base_image itself.
            image = base_image.copy()
            if add_typo:
                typos = random.sample(class_pool, k=typo_num)
                image = render_typos(image, typos, typo_font, typo_size, typo_color)
            if dim:
                image = DIM(image)
            if sim:
                image = SIM(image)
            if sga:
                image = SGA(image)
            if sia:
                image = SIA(image)
            if tim:
                image = TIM(image)
            if admix:
                added_image = _load_square_image(
                    os.path.join(image_folder, random.choice(image_files)), image_size
                )
                image = Admix(image, added_image)
            if aip:
                added_image = _load_square_image(
                    os.path.join(image_folder, random.choice(image_files)), image_size
                )
                image = AIP(image, added_image)

            if t % period == 0:
                # Overwritten each time: a snapshot of the latest model input image.
                image.save(f"{log_dir}/typoimg_{tag}.png")

            if augmented:
                image = vis_processor["eval"](image).unsqueeze(0).to(device)
                x = denormalize(image).clone().to(device)
            else:
                x = origin_x

            # Keep x + noise inside [0, 1] for the image actually fed to the model.
            adv_noise.data = (adv_noise.data + x.data).clamp(0, 1) - x.data
            x_adv = x + adv_noise
            x_adv = normalize(x_adv).repeat(batch_size, 1, 1, 1)

            samples = {
                'image': x_adv,
                'text_input': batch_inputs,
                'text_output': [random.choice(target_responses)] * batch_size
            }

            target_loss = model(samples)['loss']
            target_loss.backward()
            loss_buffer.append(target_loss.item())

            # Signed-gradient descent step, then projection onto the epsilon
            # ball and onto the valid pixel range of the clean image.
            adv_noise.data = (adv_noise.data - alpha * adv_noise.grad.detach().sign()).clamp(-epsilon, epsilon)
            adv_noise.data = (adv_noise.data + origin_x.data).clamp(0, 1) - origin_x.data
            adv_noise.grad.zero_()
            model.zero_grad()

            # NOTE(review): the loss compared here was measured before this
            # iteration's update, while the stored noise is the post-update
            # value. Kept as is to preserve the published behaviour.
            if best_loss is None or loss_buffer[-1] < best_loss:
                best_loss = loss_buffer[-1]
                # detach(): the snapshot must not stay attached to autograd.
                best_noise = adv_noise.detach().clone()

            if t % 100 == 0:
                sns.set_theme()
                x_ticks = list(range(len(loss_buffer)))
                plt.plot(x_ticks, loss_buffer, label='Target Loss')
                plt.title('Loss Plot')
                plt.xlabel('Iters')
                plt.ylabel('Loss')
                plt.legend(loc='best')
                plt.savefig(os.path.join(loss_dir, image_file))
                plt.clf()

                print(f'######### Output - Iter = {t} loss = {best_loss}##########')
                # The evaluated image does not depend on the prompt, so build it once.
                x_adv = normalize(origin_x + best_noise)
                for question in batch_inputs:
                    print(f"Image file: {image_file}")
                    print(f'Question: {question}')
                    with torch.no_grad():
                        output = model.generate({"image": x_adv, "prompt": question}, use_nucleus_sampling=True, top_p=0.9, temperature=1)[0]
                        print("Answer:", output)
                        print()

                x_adv = denormalize(x_adv).detach().cpu().squeeze(0)
                save_image(x_adv, os.path.join(advimg_dir, remove_image_extensions(image_file)+".png"))

                record = {'image file': image_file, "prompt": batch_inputs[0], 'iter': t, 'loss': best_loss, 'answer': output}
                records.append(record)
                # Rewritten in full each time so a partial run still leaves valid JSON.
                with open(f'{log_dir}/output_{tag}.json', 'w') as json_file:
                    json.dump(records, json_file)


if __name__ == "__main__":
    # Script entry point: register the experiment options, run main() and
    # report the elapsed wall-clock time in hours.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='mscoco')
    parser.add_argument('--target-responses', type=json.loads, default='["../LLaVA/dataset/transferable/test_images"]')
    parser.add_argument('--device', type=str, default='cuda:0')

    # Options are registered group by group, in the same order as before.
    for option, fallback in (('--num-iter', 1000), ('--eps', 16)):
        parser.add_argument(option, type=int, default=fallback)

    for switch in ('--multiprompt', '--add-typo'):
        parser.add_argument(switch, action='store_true')

    for option, fallback in (('--typo-num', 1), ('--typo-size', 15)):
        parser.add_argument(option, type=int, default=fallback)

    remaining_switches = (
        '--imagenet', '--wordnetnoun', '--wordnetverb', '--wordnetadj', '--mixed',
        '--dim', '--sim', '--sga', '--sia', '--tim', '--admix', '--aip',
    )
    for switch in remaining_switches:
        parser.add_argument(switch, action='store_true')

    args = parser.parse_args()
    started_at = time.time()
    main(args)
    finished_at = time.time()
    print(f"execution time: {(finished_at - started_at) / 3600}h")