import os
import time
import json
import sys
import pickle
import shutil
import random
import argparse
from tqdm import tqdm

import seaborn as sns
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont

from experiments.blip_experiments.transferable.transform_image import DIM, SIM, SGA, SIA, TIM, Admix, AIP


def remove_image_extensions(text):
    """Return *text* with every ".jpg" and ".png" substring removed.

    The match is case-sensitive and not anchored to the end of the string:
    each occurrence is deleted wherever it appears, ".jpg" first and then
    ".png".
    """
    for extension in (".jpg", ".png"):
        text = text.replace(extension, "")
    return text


def crop_resize(image, image_size):
    """Center-crop *image* to a square and resize it to ``image_size`` per side.

    Args:
        image: Object exposing ``size`` (width, height), ``crop(box)`` and,
            on the cropped result, ``resize(size)`` -- e.g. a PIL image.
        image_size: Edge length, in pixels, of the returned square image.

    Returns:
        The cropped and resized image; the input image is not modified here.
    """
    width, height = image.size
    side = min(width, height)
    # Use integer offsets. Fractional box coordinates (x.5 when the margin is
    # odd) are rounded independently by Pillow and can produce a crop that is
    # one pixel wider than it is tall, which the resize would then distort.
    left = (width - side) // 2
    top = (height - side) // 2
    square = image.crop((left, top, left + side, top + side))
    return square.resize((image_size, image_size))


def render_typos(image, texts, font_path, font_size, font_color, max_attempts=1000):
    """Draw each string of *texts* onto *image* at a random, non-overlapping spot.

    For every text, up to ``max_attempts`` random top-left positions are
    tried; the first one whose bounding rectangle does not intersect any
    previously placed text is used.  If no such position is found the text is
    skipped and a message is printed.

    Args:
        image: PIL image to draw on; it is modified in place.
        texts: Iterable of strings to render.
        font_path: Path of the TrueType font file.
        font_size: Font size passed to ``ImageFont.truetype``.
        font_color: Fill colour passed to ``ImageDraw.text``.
        max_attempts: Maximum number of random positions tried per text.

    Returns:
        The same ``image`` object, with the texts drawn on it.
    """
    draw = ImageDraw.Draw(image)
    texts = list(texts)
    if not texts:
        # Nothing to draw: skip loading the font altogether.
        return image

    image_width, image_height = image.size
    # The font does not depend on the text, so load it once instead of
    # re-reading the font file for every string.
    font = ImageFont.truetype(font_path, font_size)

    # (x, y, width, height) rectangles of the texts drawn so far.
    placed = []

    for text in texts:
        text_width = int(draw.textlength(text, font=font))
        # Reserve the measured vertical extent of the rendered text (relative
        # to the drawing origin) but never less than the historical 30 px, so
        # small fonts behave as before while large fonts no longer overlap or
        # run off the bottom edge.
        text_height = max(30, int(font.getbbox(text)[3]))

        max_x = max(0, image_width - text_width)
        max_y = max(0, image_height - text_height)

        for _ in range(max_attempts):
            text_x = random.randint(0, max_x)
            text_y = random.randint(0, max_y)

            # Standard axis-aligned rectangle intersection test.
            overlap = any(
                text_x < px + pw and text_x + text_width > px
                and text_y < py + ph and text_y + text_height > py
                for px, py, pw, ph in placed
            )
            if not overlap:
                placed.append((text_x, text_y, text_width, text_height))
                draw.text((text_x, text_y), text, fill=font_color, font=font)
                break
        else:
            # Reached only when no attempt succeeded (the loop never hit `break`).
            print(f"Failed to place text '{text}' without overlap after {max_attempts} attempts.")

    return image


def main(args):
    """Optimise a bounded additive perturbation for every image of a dataset.

    For each image in the selected dataset folder, a noise tensor is updated
    for ``args.num_iter + 1`` iterations with signed-gradient steps so that
    the LLaVA model's loss on a randomly chosen entry of
    ``args.target_responses`` decreases.  The noise is kept within
    ``args.eps / 255`` per pixel and the perturbed image within ``[0, 1]``.
    Optionally the clean image is augmented on every iteration (rendered
    words, DIM/SIM/SGA/SIA/TIM/Admix/AIP transforms).

    Side effects: sets ``CUDA_VISIBLE_DEVICES``; creates
    ``dataset/transferable/<tag>`` (adversarial PNGs) and ``temp_log``
    (loss plots, a preview image and a JSON log); calls ``sys.exit()``
    if the output directory for the current setting already exists.

    Args:
        args: Namespace produced by the argument parser at the bottom of this
            file.

    Raises:
        ValueError: If ``args.dataset`` is unknown, ``args.target_responses``
            is empty/missing, or typo rendering is requested with fewer
            candidate words than ``args.typo_num``.
    """

    print(f"curr used gpu: {args.device}")
    device_id = args.device.split(':')[-1]
    os.environ['CUDA_VISIBLE_DEVICES'] = device_id

    # Imported only after CUDA_VISIBLE_DEVICES has been set above --
    # presumably so the device restriction is in place before CUDA is
    # initialised.
    import torch
    import torch.nn.functional as F
    from transformers import TextStreamer
    from torchvision.utils import save_image

    from llava.constants import (IGNORE_INDEX, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
    from llava.conversation import SeparatorStyle, conv_templates
    from llava.mm_utils import (KeywordsStoppingCriteria, get_model_name_from_path, tokenizer_image_token)
    from llava.model.builder import load_pretrained_model
    from llava.utils import disable_torch_init

    # NOTE(review): the mean/std constants below look like the CLIP
    # preprocessing statistics -- confirm they match `image_processor`.
    def normalize(images):
        """Map images from pixel space to the model's normalised space."""
        mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
        std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()
        images = images - mean[None, :, None, None]
        images = images / std[None, :, None, None]
        return images

    def denormalize(images):
        """Inverse of ``normalize``: map normalised images back to pixel space."""
        mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
        std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()
        images = images * std[None, :, None, None]
        images = images + mean[None, :, None, None]
        return images

    disable_torch_init()

    model_path = "models/llava-v1.5-7b"
    image_size = 336
    # model_path = "models/LLaVA-7B-Lightening-v1-1"
    # image_size = 224

    model_name = get_model_name_from_path(model_path)

    epsilon = args.eps / 255  # per-pixel perturbation bound in [0, 1] pixel space
    alpha = 1 / 255           # step size of each signed-gradient update
    num_iter = args.num_iter

    target_responses = args.target_responses
    if not target_responses:
        # Fail early with a clear message instead of a TypeError when the tag
        # is built below.
        raise ValueError("--target-responses must be a non-empty JSON list of strings")

    dataset = args.dataset
    image_folders = {
        "mscoco": "dataset/transferable/mscoco_clean300",
        "celeba": "dataset/transferable/celeba_clean300",
    }
    if dataset not in image_folders:
        # Without this check the folder would be None and os.listdir(None)
        # would silently enumerate the current working directory.
        raise ValueError(f"Unknown dataset '{dataset}'; expected one of {sorted(image_folders)}")
    image_folder = image_folders[dataset]

    # Both supported datasets use the same single test prompt.
    test_prompt = "describe the image."

    multiprompt = args.multiprompt
    questions_file = "dataset/transferable/question_describe.txt"
    period = 10      # the augmented preview image is saved every `period` iterations
    batch_size = 1   # number of questions sampled per iteration

    add_typo = args.add_typo
    typo_num = args.typo_num
    imagenet = args.imagenet
    wordnetnoun = args.wordnetnoun
    wordnetverb = args.wordnetverb
    wordnetadj = args.wordnetadj
    mixed = args.mixed
    typo_size = args.typo_size
    typo_color = (255, 255, 255)
    typo_font = 'fonts/arial_bold.ttf'
    imagenet_class_file = 'dataset/transferable/imagenet_class.pkl'
    noun_file = 'dataset/transferable/nouns.json'
    verb_file = 'dataset/transferable/verbs.json'
    adj_file = 'dataset/transferable/adjectives.json'

    # pixel augmentation
    dim = args.dim
    sim = args.sim
    sga = args.sga
    sia = args.sia
    tim = args.tim
    admix = args.admix
    aip = args.aip

    # The tag names the output directory and log files for this setting.
    tag = f"describe-{dataset}-{model_name}-response_{', '.join(target_responses)}-iter{num_iter}-eps{args.eps}"

    if add_typo:
        tag = tag + f"-random{typo_num}typo-fs{typo_size}"
        for enabled, suffix in (
            (imagenet, "-imagenet"),
            (wordnetnoun, "-wordnetnoun"),
            (wordnetverb, "-wordnetverb"),
            (wordnetadj, "-wordnetadj"),
            (mixed, "-mixed"),
        ):
            if enabled:
                tag = tag + suffix
    for enabled, suffix in (
        (multiprompt, "-multiprompt"),
        (dim, "-dim"),
        (sim, "-sim"),
        (sga, "-sga"),
        (sia, "-sia"),
        (tim, "-tim"),
        (admix, "-admix"),
        (aip, "-aip"),
    ):
        if enabled:
            tag = tag + suffix

    print(f"Curr Experiment Setting: {tag}", flush=True)

    advimg_dir = f"dataset/transferable/{tag}"
    if os.path.exists(advimg_dir):
        # An existing output directory ends the run silently -- presumably
        # because this setting was already processed.
        sys.exit()
    else:
        os.makedirs(advimg_dir)

    log_dir = 'temp_log'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # Loss plots are recreated from scratch for every run.
    loss_dir = os.path.join(log_dir, f"loss_{tag}")
    if os.path.exists(loss_dir):
        shutil.rmtree(loss_dir)
    os.makedirs(loss_dir)

    question_pool = []
    if multiprompt:
        with open(questions_file, 'r') as file:
            for line in file:
                question_pool.append(line.strip())
    else:
        question_pool.append(test_prompt)
        batch_size = 1

    def load_json_lines(path):
        """Parse one JSON value per non-blank line of ``path``."""
        with open(path, 'r') as file:
            return [json.loads(line.strip()) for line in file if line.strip()]

    # Candidate words for typo rendering; the first enabled source wins.
    class_pool = []
    if add_typo:
        if imagenet:
            # NOTE(review): pickle.load can execute arbitrary code; only use
            # a trusted class file here.
            with open(imagenet_class_file, 'rb') as file:
                class_pool = pickle.load(file)
            # Flatten the sequence of tuples into a single list of names.
            class_pool = [item for tup in class_pool for item in tup]
        elif wordnetnoun:
            class_pool = load_json_lines(noun_file)
        elif wordnetverb:
            class_pool = load_json_lines(verb_file)
        elif wordnetadj:
            class_pool = load_json_lines(adj_file)
        elif mixed:
            noun_pool = load_json_lines(noun_file)
            verb_pool = load_json_lines(verb_file)
            adj_pool = load_json_lines(adj_file)
            # Truncate to the shortest pool so each word type contributes
            # the same number of entries.
            min_length = min(len(noun_pool), len(verb_pool), len(adj_pool))
            class_pool = noun_pool[:min_length] + verb_pool[:min_length] + adj_pool[:min_length]
            random.shuffle(class_pool)

        if len(class_pool) < typo_num:
            # random.sample would fail on the first iteration anyway; report
            # it before the model is loaded.
            raise ValueError(
                f"--add-typo needs at least {typo_num} candidate words but only "
                f"{len(class_pool)} were loaded; enable one of --imagenet, "
                "--wordnetnoun, --wordnetverb, --wordnetadj or --mixed"
            )

    print('>>> Initializing Models')
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path=model_path, model_base=None, model_name=model_name, load_8bit=False, load_4bit=False, device='cuda')
    model.eval()
    model.requires_grad_(False)
    print('[Initialization Finished]\n')

    def to_tensor(pil_image):
        """Preprocess a PIL image into a half-precision CUDA tensor."""
        return image_processor.preprocess(pil_image, return_tensors='pt')['pixel_values'].half().cuda()

    # The conversation template depends only on the model name.
    if 'llama-2' in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    # Image-only transforms, applied in this order when enabled.
    pixel_transforms = ((dim, DIM), (sim, SIM), (sga, SGA), (sia, SIA), (tim, TIM))
    any_augmentation = add_typo or dim or sim or sga or sia or tim or admix or aip

    records = []
    image_files = os.listdir(image_folder)
    image_files.sort()
    for image_file in tqdm(image_files):

        best_loss, best_noise = None, None

        origin_image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
        origin_image = crop_resize(origin_image, image_size)
        origin_image_tensor = to_tensor(origin_image)

        # Random start, uniform in [-epsilon, epsilon].
        adv_noise = torch.rand_like(origin_image_tensor).cuda() * 2 * epsilon - epsilon
        adv_noise.requires_grad_(True)
        adv_noise.retain_grad()

        loss_buffer = []
        for t in range(num_iter + 1):
            # ---- Build this iteration's (optionally augmented) base image ----
            image = origin_image.copy()
            if add_typo:
                typos = random.sample(class_pool, k=typo_num)
                image = render_typos(image, typos, typo_font, typo_size, typo_color)
            for enabled, transform in pixel_transforms:
                if enabled:
                    image = transform(image)
            if admix:
                added_image = Image.open(os.path.join(image_folder, random.choice(image_files))).convert('RGB')
                added_image = crop_resize(added_image, image_size)
                image = Admix(image, added_image)
            if aip:
                added_image = Image.open(os.path.join(image_folder, random.choice(image_files))).convert('RGB')
                added_image = crop_resize(added_image, image_size)
                image = AIP(image, added_image)

            # Only the final augmented image is fed to the model, so it is
            # preprocessed once; without augmentation the clean tensor is reused.
            image_tensor = to_tensor(image) if any_augmentation else origin_image_tensor

            if t % period == 0:
                image.save(os.path.join(log_dir, f"typoimg_{tag}.png"))

            # Keep base image + noise inside the valid pixel range [0, 1].
            x = denormalize(image_tensor).clone().cuda()
            adv_noise.data = (adv_noise.data + x.data).clamp(0, 1) - x.data

            # ---- Build prompts that end with a target response ----
            conversations = []
            conv = conv_templates[conv_mode].copy()
            questions = random.sample(question_pool, batch_size)
            for question in questions:
                conv.messages = []
                if model.config.mm_use_im_start_end:
                    question = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + question
                else:
                    question = DEFAULT_IMAGE_TOKEN + '\n' + question
                conv.append_message(conv.roles[0], question)
                conv.append_message(conv.roles[1], random.choice(target_responses))
                conversations.append(conv.get_prompt())

            # Mask targets
            input_ids = [tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') for prompt in conversations]
            max_length = max(tensor.shape[0] for tensor in input_ids)
            padded_input_ids = [F.pad(tensor, (0, max_length - tensor.shape[0]), 'constant', tokenizer.pad_token_id) for tensor in input_ids]
            input_ids = torch.stack(padded_input_ids, dim=0).cuda()
            attention_masks = []

            # Labels start as a copy of the inputs; everything that is not
            # part of the assistant's reply is overwritten with IGNORE_INDEX.
            targets = input_ids.clone()
            if 'llama-2' in model_name.lower():
                sep = "[/INST] "
            else:
                sep = conv.sep + conv.roles[1] + ": "
            for conversation, target in zip(conversations, targets):
                attention_mask = target.ne(tokenizer.pad_token_id)
                attention_masks.append(attention_mask)
                rounds = conversation.split(conv.sep2)
                # The first token is always ignored (presumably BOS -- confirm).
                cur_len = 1
                target[:cur_len] = IGNORE_INDEX
                for rou in rounds:
                    if rou == "":
                        break
                    parts = rou.split(sep)
                    if len(parts) != 2:
                        break
                    parts[0] += sep
                    round_len = len(tokenizer_image_token(rou, tokenizer))
                    # NOTE(review): the "- 2" offset is taken as-is; presumably
                    # it compensates for special tokens -- confirm.
                    instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
                    target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
                    cur_len += round_len
                target[cur_len:] = IGNORE_INDEX

            attention_masks = torch.stack(attention_masks, dim=0).cuda()

            # ---- Forward/backward on the perturbed image ----
            x_adv = x + adv_noise
            x_adv = normalize(x_adv)
            images = x_adv.repeat(len(questions), 1, 1, 1)
            inputs = {'input_ids': input_ids, 'labels': targets, 'attention_mask': attention_masks, 'images': images.half()}

            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()
            loss_buffer.append(loss.item())

            # Signed-gradient descent step on the loss, clipped to
            # [-epsilon, epsilon], then re-projected so that the *clean*
            # image plus noise stays within [0, 1].
            adv_noise.data = (adv_noise.data - alpha * adv_noise.grad.detach().sign()).clamp(-epsilon, epsilon)
            x = denormalize(origin_image_tensor).clone().cuda()
            adv_noise.data = (adv_noise.data + x.data).clamp(0, 1) - x.data
            adv_noise.grad.zero_()
            model.zero_grad()

            # NOTE(review): the recorded loss was measured before the update
            # above, while the stored noise is the post-update noise.
            if best_loss is None or loss_buffer[-1] < best_loss:
                best_loss = loss_buffer[-1]
                # Detach so the snapshot does not carry autograd history.
                best_noise = adv_noise.detach().clone()

            if t % 100 == 0:
                # ---- Periodic checkpoint: loss plot, sample answer, image, log ----
                sns.set_theme()
                num_iters = len(loss_buffer)
                x_ticks = list(range(0, num_iters))
                plt.plot(x_ticks, loss_buffer, label='Target Loss')
                plt.title('Loss Plot')
                plt.xlabel('Iters')
                plt.ylabel('Loss')
                plt.legend(loc='best')
                plt.savefig(os.path.join(loss_dir, image_file))
                plt.clf()

                print(f'######### Output - Iter = {t} loss = {best_loss}##########')
                x = denormalize(origin_image_tensor).clone().cuda()
                x_adv = x + best_noise
                x_adv = normalize(x_adv)
                for question in questions:
                    inp = question
                    if model.config.mm_use_im_start_end:
                        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
                    else:
                        inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
                    conv = conv_templates[conv_mode].copy()
                    conv.append_message(conv.roles[0], inp)
                    conv.append_message(conv.roles[1], None)
                    prompt = conv.get_prompt()
                    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
                    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
                    keywords = [stop_str]
                    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
                    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
                    print(f"Image file: {image_file}")
                    print(f"Question: {question}")
                    print("Answer: ", end="")
                    with torch.inference_mode():
                        output_ids = model.generate(
                            input_ids,
                            images=x_adv.half(),
                            do_sample=True,
                            temperature=args.temperature,
                            max_new_tokens=args.max_new_tokens,
                            streamer=streamer,
                            use_cache=True,
                            stopping_criteria=[stopping_criteria])
                    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
                    print()

                # Save the clean image plus best noise, in pixel space.
                x_adv = x + best_noise
                adv_img = x_adv.detach().cpu().squeeze(0)
                save_image(adv_img, os.path.join(advimg_dir, remove_image_extensions(image_file)+".png"))

                # `question` / `output` refer to the last question of this batch.
                record = {'image file': image_file, "prompt": question, 'iter': t, 'loss': best_loss, 'answer': output}
                records.append(record)
                # The whole log is rewritten at every checkpoint.
                with open(f'{log_dir}/output_{tag}.json', 'w') as json_file:
                    json.dump(records, json_file)


# Command-line entry point: parse the experiment options, run `main`, and
# report the wall-clock duration in hours.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='mscoco',
                        help="dataset name; main() knows 'mscoco' and 'celeba'")
    # Required: main() joins and samples from this list, so a missing value
    # would otherwise only fail later with a TypeError.
    parser.add_argument('--target-responses', type=json.loads, required=True,
                        help='JSON list of target responses, e.g. \'["unknown"]\'')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help="device string; the part after ':' becomes CUDA_VISIBLE_DEVICES")
    parser.add_argument('--num-iter', type=int, default=1000,
                        help="number of optimisation iterations per image")
    parser.add_argument('--eps', type=int, default=16,
                        help="perturbation bound, in units of 1/255")
    parser.add_argument('--multiprompt', action='store_true',
                        help="sample the question from the questions file instead of the fixed prompt")
    parser.add_argument('--add-typo', action='store_true',
                        help="render random words onto the image each iteration")
    parser.add_argument('--typo-num', type=int, default=1,
                        help="number of words rendered per iteration")
    parser.add_argument('--typo-size', type=int, default=15,
                        help="font size of the rendered words")
    # Word sources for --add-typo; main() uses the first enabled one, in this order.
    parser.add_argument('--imagenet', action='store_true')
    parser.add_argument('--wordnetnoun', action='store_true')
    parser.add_argument('--wordnetverb', action='store_true')
    parser.add_argument('--wordnetadj', action='store_true')
    parser.add_argument('--mixed', action='store_true')
    # Optional image transforms applied on every iteration.
    parser.add_argument('--dim', action='store_true')
    parser.add_argument('--sim', action='store_true')
    parser.add_argument('--sga', action='store_true')
    parser.add_argument('--sia', action='store_true')
    parser.add_argument('--tim', action='store_true')
    parser.add_argument('--admix', action='store_true')
    parser.add_argument('--aip', action='store_true')

    # Generation settings for the periodic sample answers.
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--max-new-tokens", type=int, default=200)

    args = parser.parse_args()
    start_time = time.time()
    main(args)
    end_time = time.time()
    print(f"execution time: {(end_time - start_time) / 3600}h")