import hydra
import numpy as np
from omegaconf import OmegaConf
from hydra.utils import instantiate
import torch
import random
import os, sys
import json
import argparse
from utils import find_or_create_run_dir
from utils.json_parser import parse_adversarial_json_output
from PIL import Image, ImageDraw, ImageFont
import torchvision.transforms as transforms
import pandas as pd
from dotenv import load_dotenv

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also sets the cuDNN ``deterministic`` flag and clears the ``benchmark``
    flag.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed every RNG once, at import time, with the fixed seed 1.
set_seed(1) 
# python-dotenv: pull settings from a .env file into the environment (if present).
load_dotenv()

def text_to_art_image_pil(text, font_path='arial.ttf', font_size=100, text_color=(0, 0, 0), bg_color=(255, 255, 255), image_width=500):
    """Render ``text`` as centred, word-wrapped lines on a solid background.

    Words are packed greedily into lines no wider than 90% of
    ``image_width``. A single word wider than that limit is kept on its own
    line and is not split. The canvas height is derived from the number of
    lines, so empty or whitespace-only text yields a blank strip that is
    20 pixels tall.

    Args:
        text: String to render; it is split on whitespace.
        font_path: TrueType font file to try first.
        font_size: Requested font size, also used as the per-line height.
        text_color: Fill colour for the text.
        bg_color: Canvas background colour.
        image_width: Width of the returned image in pixels.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(image_width, height)``.
    """
    try:
        font = ImageFont.truetype(font_path, font_size)
    except IOError:
        print(f"Font not found at {font_path}, trying arial.ttf...")
        try:
            font = ImageFont.truetype("arial.ttf", font_size)
        except IOError:
            print("Arial font not found, using default font.")
            # load_default() is called without a size, so font_size is not
            # applied to this last-resort font.
            font = ImageFont.load_default()

    def line_width(draw, line):
        """Return the rendered pixel width of ``line`` in the chosen font."""
        try:
            bbox = draw.textbbox((0, 0), line, font=font)
            return bbox[2] - bbox[0]
        except AttributeError:
            # Pillow builds without ImageDraw.textbbox only offer textsize.
            return draw.textsize(line, font=font)[0]

    # Measuring text only needs a drawing context, not a full-size canvas.
    scratch = ImageDraw.Draw(Image.new('RGB', (1, 1), color=bg_color))
    max_width = int(image_width * 0.9)

    lines = []
    current_line = ""
    for word in text.split():
        candidate = f"{current_line} {word}".strip()
        if line_width(scratch, candidate) <= max_width:
            current_line = candidate
            continue
        if current_line:  # never emit an empty line before an over-long word
            lines.append(current_line)
        current_line = word
    if current_line:
        lines.append(current_line)

    line_spacing = max(5, font_size // 10)
    # Clamp the gap count: with no lines, ``len(lines) - 1`` is -1, which would
    # shrink the canvas and make its height non-positive for large fonts.
    gap_count = max(len(lines) - 1, 0)
    total_height = len(lines) * font_size + gap_count * line_spacing + 20

    image = Image.new('RGB', (image_width, total_height), color=bg_color)
    draw = ImageDraw.Draw(image)

    y_offset = 10  # top margin; the matching bottom margin is in total_height
    for line in lines:
        text_x = (image.width - line_width(draw, line)) // 2
        draw.text((text_x, y_offset), line, font=font, fill=text_color)
        y_offset += font_size + line_spacing

    return image

def _load_label_font(font_size):
    """Return Arial at ``font_size``, or PIL's built-in font if it is unavailable."""
    try:
        return ImageFont.truetype("arial.ttf", font_size)
    except IOError:
        return ImageFont.load_default()


def _label_width(draw, text, font):
    """Return the rendered pixel width of ``text`` in ``font``."""
    try:
        bbox = draw.textbbox((0, 0), text, font=font)
        return bbox[2] - bbox[0]
    except AttributeError:
        # Pillow builds without ImageDraw.textbbox only offer textsize.
        return draw.textsize(text, font=font)[0]


def _centre_on_tile(img, target_size, fill_color):
    """Paste ``img`` centred on a new ``target_size`` tile filled with ``fill_color``."""
    tile = Image.new('RGB', target_size, fill_color)
    offset_x = (target_size[0] - img.size[0]) // 2
    offset_y = (target_size[1] - img.size[1]) // 2
    # Offsets are negative when ``img`` is larger than the tile, in which
    # case paste keeps only the centre region.
    tile.paste(img, (offset_x, offset_y))
    return tile


def _add_caption(tile, label, font, font_size, fill_color):
    """Return ``tile`` with a ``font_size + 10`` pixel strip below it holding ``label``."""
    tile_width, tile_height = tile.size
    captioned = Image.new('RGB', (tile_width, tile_height + font_size + 10), fill_color)
    captioned.paste(tile, (0, 0))

    draw = ImageDraw.Draw(captioned)
    text_x = (tile_width - _label_width(draw, label, font)) // 2
    text_y = tile_height + (font_size // 2)
    draw.text((text_x, text_y), label, font=font, fill=(0, 0, 0))
    return captioned


def concatenate_images_with_distraction(distraction_image_paths, diffusion_image, typo_image, images_per_row=3, target_size=(300, 300), fill_color=(255, 255, 255), font_size=20, rotation_angle=0, llava_images_root=None):
    """Build a captioned grid from distraction images plus two generated images.

    Up to seven distraction images are loaded from disk, resized to fit
    ``target_size``, optionally rotated, and captioned "1".."7" in order.
    ``diffusion_image`` and ``typo_image``, when truthy, are appended with the
    fixed captions "8" and "9". All tiles are laid out row by row.

    Args:
        distraction_image_paths: Image file paths; only the first seven are used.
        diffusion_image: PIL image for the "8" tile, or a falsy value to skip it.
        typo_image: PIL image for the "9" tile, or a falsy value to skip it.
        images_per_row: Number of tiles per grid row.
        target_size: ``(width, height)`` of each tile, excluding its caption.
        fill_color: Background colour for padding and caption strips.
        font_size: Caption font size; the caption strip is ``font_size + 10`` tall.
        rotation_angle: Rotation in degrees applied to distraction tiles only.
        llava_images_root: Optional directory that replaces the leading
            ``./llava_images/`` of a path. With the default ``None`` such paths
            are opened exactly as given.

    Returns:
        A new RGB ``PIL.Image.Image`` containing the tile grid.
    """
    llava_prefix = './llava_images/'
    font = _load_label_font(font_size)
    # Side of a square big enough to hold a tile at any rotation angle.
    diagonal = int((target_size[0]**2 + target_size[1]**2)**0.5)

    images = []
    for idx, img_path in enumerate(distraction_image_paths[:7]):
        if llava_images_root is not None and img_path.startswith(llava_prefix):
            img_path = os.path.join(llava_images_root, img_path[len(llava_prefix):])

        expanded_img = Image.new('RGB', (diagonal, diagonal), fill_color)
        # Close the file handle as soon as the pixels have been copied out.
        with Image.open(img_path) as img:
            img.thumbnail(target_size)
            img_x, img_y = img.size
            expanded_img.paste(img, ((diagonal - img_x) // 2, (diagonal - img_y) // 2))

        rotated_img = expanded_img.rotate(rotation_angle, expand=True, fillcolor=fill_color)
        tile = _centre_on_tile(rotated_img, target_size, fill_color)
        images.append(_add_caption(tile, str(idx + 1), font, font_size, fill_color))

    # The generated images keep fixed labels even when fewer than seven
    # distraction images were supplied.
    for extra_image, label in ((diffusion_image, "8"), (typo_image, "9")):
        if not extra_image:
            continue
        resized = extra_image.copy()  # thumbnail() works in place; spare the caller's image
        resized.thumbnail(target_size)
        tile = _centre_on_tile(resized, target_size, fill_color)
        images.append(_add_caption(tile, label, font, font_size, fill_color))

    width, height = target_size[0], target_size[1] + font_size + 10
    rows = (len(images) + images_per_row - 1) // images_per_row  # ceiling division

    new_image = Image.new('RGB', (width * images_per_row, height * rows), fill_color)
    for index, img in enumerate(images):
        x_offset = (index % images_per_row) * width
        y_offset = (index // images_per_row) * height
        new_image.paste(img, (x_offset, y_offset))

    return new_image

@hydra.main(config_path="configs", config_name="vlm_dual", version_base=None)
def main(cfg):
    """Run the per-prompt attack loop and log each prompt's best attempt.

    Instantiates the target VLM, attacker, judge, diffusion model, data,
    optimizer and loss from ``cfg``. For every prompt inside the hard-coded
    index window it repeatedly asks the attacker for a batch of adversarial
    queries, renders each one into a composite image, scores the batch with
    the loss function, and stops once a judge score reaches ``cfg.threshold``
    or ``cfg.max_queries`` queries have been spent. The highest-scoring
    attempt is appended to ``best_prompts_incremental.txt`` in the run
    directory.

    Args:
        cfg: Hydra config providing ``target``, ``attacker``, ``judge``,
            ``diffusion``, ``data``, ``optimizer``, ``loss``, ``max_queries``,
            ``threshold`` and ``batch_size``.
    """
    cfg_omega = OmegaConf.create(cfg)
    run_dir, _is_resume = find_or_create_run_dir(cfg_omega)
    target_vlm = instantiate(cfg.target)
    print("Loaded target VLM")
    print("Loading models...")
    attacker = instantiate(cfg.attacker)
    print("Loaded attacker")
    judge = instantiate(cfg.judge)
    print("Loaded judge")
    diffusion_model = instantiate(cfg.diffusion)
    print("Loaded diffusion model")

    prompts, categories = instantiate(cfg.data)
    print("Loaded data")
    optimizer = instantiate(cfg.optimizer, attacker.parameters())
    print("Loaded optimizer")
    loss_fn = instantiate(cfg.loss)(
        target_lm=target_vlm,
        judge=judge,
    )
    print("Loaded loss function")

    max_queries = cfg.max_queries
    THRESHOLD = cfg.threshold
    BATCH_SIZE = cfg.batch_size

    JB = 0  # number of prompts with at least one above-threshold response
    # Inclusive window of prompt indices to process; with both bounds at 0
    # only the first prompt is attacked.
    start_idx = 0
    end_idx = 0

    select_img_path = "" # image map path

    distraction_images_map = {}
    if os.path.exists(select_img_path):
        with open(select_img_path, 'r', encoding='utf-8') as f:
            distraction_images_map = json.load(f)
        print(f"Loaded distraction images map from {select_img_path}")
    else:
        print(f"Warning: Distraction images map not found at {select_img_path}")

    font_path = ''
    best_prompts_file = f"{run_dir}/best_prompts_incremental.txt"

    # The loop variable gets its own name so it cannot shadow the list being zipped.
    for prompt_idx, (prompt, _category) in enumerate(zip(prompts, categories)):

        if prompt_idx < start_idx:
            continue
        if prompt_idx > end_idx:
            break

        # The lookup does not change between batches, so do it once and stop
        # before any model is queried when the prompt has no entry.
        if prompt not in distraction_images_map:
            print(f"Warning: No distraction images found for prompt: {prompt}")
            sys.exit(1)
        distraction_image_paths = distraction_images_map[prompt][:7]

        attacker.initialize_attacker()
        optimizer = instantiate(cfg.optimizer, attacker.parameters())

        cur_prompt_adv = []
        cur_prompt_resp = []
        scores = []

        CUR_COUNT = 0
        batch_prompt = [prompt for _ in range(BATCH_SIZE)]
        PASS = False
        print("=====New prompt ======")
        print("Original prompt: ", prompt)
        while CUR_COUNT < max_queries:
            optimizer.zero_grad()
            print()
            print(CUR_COUNT)
            print("Original prompt: ", prompt)
            adversarial_queries, log_likelihoods, entropy, log_prior = attacker.generate_conditional_prompt(
                batch_prompt, prior=True
            )

            image_prompts, text_prompts = parse_adversarial_json_output(adversarial_queries, prompt)

            print("Log-likelihoods", log_likelihoods)
            print("Log-prior", log_prior)
            diffusion_images = diffusion_model.generate_batch(image_prompts)

            # Generate concatenated images for each batch item
            concatenated_images = []
            for i, text_prompt in enumerate(text_prompts):
                diffusion_image = diffusion_images[i] if i < len(diffusion_images) else None
                typo_image = text_to_art_image_pil(text_prompt,
                                              font_path=font_path, font_size=80,
                                              text_color=(0, 0, 0), bg_color=(255, 255, 255),
                                              image_width=800)
                concatenated_images.append(
                    concatenate_images_with_distraction(
                        distraction_image_paths,
                        diffusion_image,
                        typo_image,
                        images_per_row=3,
                        target_size=(300, 300)
                    )
                )

            # NOTE(review): the directory is created but nothing in this
            # function writes the composite images into it.
            save_dir = os.path.join(run_dir, "concatenated_images")
            os.makedirs(save_dir, exist_ok=True)

            print(f"Generated {len(concatenated_images)} concatenated images for batch size {len(text_prompts)}")
            loss_dct = loss_fn(
                adversarial_queries=text_prompts,
                log_likelihoods=log_likelihoods,
                entropies=entropy,
                log_prior=log_prior,
                original_queries=batch_prompt,
                images=concatenated_images,
            )

            loss = loss_dct.pop("loss")
            print(f"Loss: {loss.item()}")

            batch_scores = loss_dct.pop("judge_scores")
            scores += batch_scores
            # NOTE(review): gradients are computed here but optimizer.step()
            # is never called in this function - confirm whether the attacker
            # is meant to be updated.
            loss.backward()

            cur_prompt_adv += adversarial_queries
            cur_prompt_resp += loss_dct["responses"]

            if np.max(batch_scores) >= THRESHOLD:
                print("Above-threshold response detected, updating and finishing current prompt")
                JB = JB + 1
                PASS = True
                print("------------------------------------------------------------------------------------")
                break

            # The query counter advances only for batches that did not succeed.
            CUR_COUNT += BATCH_SIZE
            print("KL:", loss_dct['kl'], "entropy:" , loss_dct['entropy'])
            for q, r in zip(adversarial_queries, loss_dct['responses']):
                print("Adv Query:", q)
                print("VLM Response:\n", r)
                print("--------------------------------")

        if PASS:
            print(f"Successfully finished prompt {prompt}")
        else:
            print(f"Failed to jailbreak prompt {prompt}")

        best_adv_prompt_idx = int(np.argmax(scores))
        print("best_adv_prompt_idx:", best_adv_prompt_idx)
        best_adv_prompt = cur_prompt_adv[best_adv_prompt_idx]
        best_score = scores[best_adv_prompt_idx]
        best_response = cur_prompt_resp[best_adv_prompt_idx]
        processed = prompt_idx - start_idx + 1

        # UTF-8 matches how the distraction map is read and keeps non-ASCII
        # prompts and responses writable on any platform.
        with open(best_prompts_file, "a", encoding="utf-8") as f:
            f.write(f"PROMPT {prompt_idx} (Index: {processed}):\n")
            f.write(f"Timestamp: {pd.Timestamp.now()}\n")
            f.write(f"Original Prompt: {prompt}\n")
            f.write(f"Best Adversarial Prompt: {best_adv_prompt}\n")
            f.write(f"Best Response: {best_response}\n")
            f.write(f"Best Score: {best_score:.4f}\n")
            f.write(f"Threshold: {THRESHOLD}\n")
            f.write(f"Query Count: {CUR_COUNT}\n")
            # Report the same >= THRESHOLD outcome the loop used, so a score
            # exactly at the threshold is not logged as a failure.
            f.write(f"Jailbreak Success: {'YES' if PASS else 'NO'}\n")
            f.write(f"Current JB Rate: {JB}/{processed} = {JB/processed:.4f}\n")
            f.write("=" * 80 + "\n\n")

        print(f"Best prompt saved to {best_prompts_file}")
        

# Script entry point: main() is called with no arguments; the @hydra.main
# decorator on it is what supplies ``cfg``.
if __name__ == "__main__":
    main()