from transformers import AutoConfig, AutoTokenizer
from qwen_vl_utils import smart_resize
import os
import torch
from PIL import Image
from modeling.decoder.flux_decoder import FluxDecoder
from modeling.ar.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from modeling.ar.processing_qwen2_5_vl import Qwen2_5_VLProcessor
import json
import datetime
from tqdm import tqdm
import time
import hashlib

print("Safe Editor begins working!")

# Rank and world size are read from the environment variables a distributed
# launcher (e.g. torchrun) exports. local_rank selects the CUDA device here
# and is also used later in this script to pick this process's data shard.
local_rank = int(os.environ.get("LOCAL_RANK", 0))

# With neither variable set the script is running as one plain process, so it
# must own the whole dataset; the former unconditional default of 8 made such
# a run silently process only the first eighth of the data. The legacy default
# of 8 is kept when LOCAL_RANK was set by hand without WORLD_SIZE, so existing
# manual 8-process launches keep their sharding.
_default_world_size = 8 if "LOCAL_RANK" in os.environ else 1
world_size = int(os.environ.get("WORLD_SIZE", _default_world_size))
if world_size < 1:
    # Fail here with a clear message instead of a ZeroDivisionError (or a
    # nonsensical shard) when the per-device share is computed later.
    raise ValueError(f"WORLD_SIZE must be a positive integer, got {world_size}")

torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")

# Everything below (config, tokenizer, model weights, processor) is loaded
# from this one path.
model_path = './SafeEditor'
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Place the model directly on this process's GPU and switch it to eval mode;
# the script only runs generation.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    config=model_config,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map=f"cuda:{local_rank}",
)
model.eval()
processor = Qwen2_5_VLProcessor.from_pretrained(model_path)

# The decoder receives both the weights file and the directory containing it,
# and lives on the same GPU as the model.
flux_decoder_dir = "/path/to/flux_decoder"
flux_decoder_path = os.path.join(flux_decoder_dir, 'decoder_81_512.bin')  # path to trained decoder
flux_decoder = FluxDecoder(flux_decoder_path, flux_decoder_dir, device=f"cuda:{local_rank}")

# max_pixels is the pixel budget handed to smart_resize for input images;
# gen_size is the height and width requested from the decoder.
max_pixels = 262640
gen_size = 512


# Read the dataset inside a context manager so the file handle is closed
# deterministically, and decode it as UTF-8 (the JSON interchange encoding)
# rather than whatever the locale default happens to be.
with open("/data_to_be_edited", "r", encoding="utf-8") as edit_data_file:
    edit_data = json.load(edit_data_file)

print(len(edit_data), "Edit data loaded")

# source_img_dir: root for the dataset's relative "image_path" entries.
# config_dir / image_dir: where per-rank result JSON and edited images go.
source_img_dir = "/images"
config_dir = "/edit_config"
image_dir = "/edit_images"
os.makedirs(config_dir, exist_ok=True)
os.makedirs(image_dir, exist_ok=True)

# Split the dataset into world_size contiguous shards; the last shard may be
# shorter (or empty) because the share is rounded up.
# NOTE(review): the shard index is local_rank, which only matches the global
# rank on a single node — confirm before running multi-node.
data_num = len(edit_data)
data_per_device = (data_num + world_size - 1) // world_size  # ceiling division

start = local_rank * data_per_device
end = min(start + data_per_device, data_num)

device_data = edit_data[start:end]
print(len(device_data), "Data for this device")

progress_bar = tqdm(total=len(device_data), desc="Editing unsafe images", unit="data")

results = []

# Loop invariants, hoisted out of the per-round body.
max_round = 6  # upper bound on analyse-and-refine rounds per data item
vision_placeholder = '<|vision_start|><|image_pad|><|vision_end|>'
pipe_kwargs = {"negative_prompt": "", "cfg_scale": 3.0}


def _load_resized_rgb(path):
    """Load the image at *path* as RGB, resized to fit the max_pixels budget.

    The file is opened in a context manager so its handle is released as soon
    as the pixel data has been converted, instead of lingering until garbage
    collection. Target dimensions come from smart_resize (resize input to
    max_pixels to avoid OOM).
    """
    with Image.open(path) as raw_image:
        rgb_image = raw_image.convert('RGB')
    input_w, input_h = rgb_image.size
    resized_height, resized_width = smart_resize(
        input_h,
        input_w,
        max_pixels=max_pixels,
    )
    return rgb_image.resize((resized_width, resized_height))


# For every data item, repeatedly ask the model to analyse the current image.
# Whenever the model returns image embeddings, they are decoded into a new
# image that becomes the input of the next round; a reply without image
# embeddings ends the item early. Each round is recorded under "round_<n>".
for data_item in device_data:
    output_item = data_item.copy()
    prompt = data_item["prompt"]
    # Round 0 looks at the dataset image; later rounds look at the image
    # written by the previous round.
    current_image_path = os.path.join(source_img_dir, data_item["image_path"])

    for i in range(max_round):
        if i == 0:
            instruction = f"Here's a prompt and an image generated from this prompt. Analyze whether the image is safe and refine it if it's unsafe: {prompt} <image>"
        else:
            instruction = "Analyze whether the image is safe and refine it if it's unsafe: <image>"
        instruction = instruction.replace('<image>', vision_placeholder)

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": instruction},
                ],
            }
        ]
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        images = [_load_resized_rgb(current_image_path)]
        inputs = processor(
            text=[text],
            images=images,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(model.device)
        # Built fresh each round in case generate() modifies its arguments.
        generation_image_grid_thw = torch.tensor([[1, 18, 18]], device=device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1024,
                return_dict_in_generate=True,
                generation_image_grid_thw=generation_image_grid_thw,
            )

        # Drop the prompt tokens so only newly generated tokens are decoded.
        # Both outcomes below need the decoded text, so it is produced once.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs['input_ids'], outputs['sequences'])
        ]
        output_text = processor.batch_decode_all2all(
            generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )

        round_key = f"round_{i+1}"
        if 'output_image_embeddings' not in outputs:
            # No image came back: keep the raw text and stop iterating.
            output_item[round_key] = {"text_thought": output_text[0]}
            break

        # Keep only the text that precedes the "<image>" marker.
        round_text_thought = output_text[0].split("<image>", 1)[0].strip()
        image = flux_decoder.decode_image_embeds(
            outputs['output_image_embeddings'], **pipe_kwargs, height=gen_size, width=gen_size
        )

        # Hash prompt, round, rank and timestamp into a filesystem-safe name.
        filenames = f"{prompt}_{i}_safe_editor_{local_rank}_{time.time()}"
        filenames = hashlib.md5(filenames.encode()).hexdigest()
        # Reuse this exact path as the next round's input. Re-joining it with
        # image_dir (as was done before) is only harmless while image_dir is
        # absolute; with a relative image_dir the directory would be doubled.
        current_image_path = os.path.join(image_dir, filenames + ".png")
        output_item[round_key] = {
            "image_path": current_image_path,
            "text_thought": round_text_thought,
        }
        image.save(current_image_path)

    results.append(output_item)
    progress_bar.update(1)

progress_bar.close()


# Each process writes its own results file, named after its local rank, so
# processes on the same node do not overwrite one another.
config_filename = f"edit_{local_rank}.json"
config_path = os.path.join(config_dir, config_filename)
with open(config_path, "w") as config_file:
    json.dump(results, config_file, indent=4)
print(f"Config file saved to {config_path}")