import os
# Set before `ultralytics` is imported below; presumably this silences YOLO's
# console logging -- confirm against the ultralytics settings documentation.
os.environ['YOLO_VERBOSE'] = 'False'

import json
from ultralytics import YOLO
# Detector used by the evaluation loop later in this file. It is created at
# import time, so importing this module already loads the "yolo11m.pt" weights.
model = YOLO("yolo11m.pt")
import matplotlib.pyplot as plt
import cv2
import argparse
import numpy as np
from PIL import Image

from diffusers import DDIMScheduler, StableDiffusionPipeline
import torch

class BlendedLatnetDiffusion:
    """Mask-guided image editor built from Stable Diffusion components.

    The class loads the VAE, tokenizer, text encoder and UNet of a Stable
    Diffusion checkpoint and runs a DDIM denoising loop in which, after every
    step, the latents outside the mask are replaced by a noised copy of the
    source image latents. Only the masked region is therefore re-generated
    from the text prompt.

    The class name keeps its original spelling ("Latnet") so that existing
    references to it continue to work.
    """

    # Scale factor applied to VAE latents after encoding and undone before
    # decoding.
    _LATENT_SCALE = 0.18215

    def __init__(self, **initial_args):
        """Build the default configuration, apply overrides and load models.

        Args:
            **initial_args: Overrides for keys defined in ``parse_args``.

        Raises:
            AttributeError: If an override names an unknown configuration key.
        """
        self.args = self.parse_args()
        self.update_args(**initial_args)
        self.load_models()

    def parse_args(self):
        """Return the default configuration as an ``argparse.Namespace``."""
        args = {
            "prompt": "",
            "init_image": "",
            "mask": "",
            "model_path": "stabilityai/stable-diffusion-2-1-base",
            "batch_size": 4,
            "blending_start_percentage": 0.25,
            "device": "cuda",
            "output_path": "outputs/res.jpg",
        }
        return argparse.Namespace(**args)

    def update_args(self, **kwargs):
        """Overwrite existing configuration entries.

        Args:
            **kwargs: ``name=value`` pairs; every name must already exist on
                ``self.args``.

        Raises:
            AttributeError: If a name is not a known configuration key.
                Entries processed before the offending one stay applied.
        """
        for key, value in kwargs.items():
            if not hasattr(self.args, key):
                raise AttributeError(f"Invalid argument: {key}")
            setattr(self.args, key, value)

    def load_models(self):
        """Load the pipeline components in float16 and create the scheduler.

        The VAE, text encoder and UNet are moved to ``self.args.device``; the
        tokenizer is kept as loaded. A ``DDIMScheduler`` is created with the
        fixed beta schedule below.
        """
        pipe = StableDiffusionPipeline.from_pretrained(
            self.args.model_path, torch_dtype=torch.float16
        )
        self.vae = pipe.vae.to(self.args.device)
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder.to(self.args.device)
        self.unet = pipe.unet.to(self.args.device)
        self.scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )

    @torch.no_grad()
    def edit_image(
        self,
        image_img,
        mask_img,
        prompts,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=torch.manual_seed(42),
        blending_percentage=0.25,
    ):
        """Re-generate the masked region of an image from text prompts.

        Args:
            image_img: Source image, either a file path (``str``) or a PIL
                image.
            mask_img: Edit mask, either a file path (``str``) or a PIL image.
                It is converted to grayscale and thresholded at 0.5; pixels
                that end up as 1 are re-generated, pixels that end up as 0
                are taken from the source image.
            prompts: List of prompt strings; its length is the batch size.
            height: Working image height in pixels.
            width: Working image width in pixels.
            num_inference_steps: Number of scheduler timesteps to set up.
            guidance_scale: Classifier-free guidance weight.
            generator: ``torch.Generator`` used to sample the initial latents.
                NOTE: the default expression is evaluated once, when the class
                is defined. It seeds torch's global RNG with 42 at that moment
                and the returned generator is reused by every call that does
                not pass its own, so successive calls draw different noise.
                This is kept as-is to preserve existing sampling behaviour.
            blending_percentage: Fraction of the leading scheduler timesteps
                that is skipped before denoising starts.

        Returns:
            ``uint8`` NumPy array with one channel-last RGB image per prompt.
        """
        device = self.args.device
        batch_size = len(prompts)

        # Fixed: the path branch used to read an undefined name `image_path`.
        image = Image.open(image_img) if isinstance(image_img, str) else image_img
        # PIL expects (width, height); the previous (height, width) order was
        # only correct for square sizes. Converting to RGB after the resize
        # drops an alpha channel exactly like the old `[:, :, :3]` slice did,
        # and additionally makes grayscale / palette inputs work.
        image = image.resize((width, height), Image.BILINEAR).convert("RGB")
        source_latents = self._image2latent(np.array(image))

        if isinstance(mask_img, str):
            mask_img = Image.open(mask_img)
        # The mask must have the same spatial size as the latents below.
        latent_mask, _ = self._read_mask(
            mask_img, dest_size=(width // 8, height // 8)
        )

        text_input = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]

        # Empty prompts provide the unconditional branch for guidance.
        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer(
            [""] * batch_size,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, height // 8, width // 8),
            generator=generator,
        )
        latents = latents.to(device).half()

        self.scheduler.set_timesteps(num_inference_steps)
        first_step = int(len(self.scheduler.timesteps) * blending_percentage)

        for t in self.scheduler.timesteps[first_step:]:
            # Duplicate the latents so the unconditional and the conditional
            # prediction come out of a single UNet forward pass.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, timestep=t
            )

            # Predict the noise residual.
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            ).sample

            # Classifier-free guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            # Compute the previous noisy sample x_t -> x_t-1.
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

            # Blending: keep generated latents where the mask is 1 and a
            # freshly noised copy of the source latents where it is 0.
            noise_source_latents = self.scheduler.add_noise(
                source_latents, torch.randn_like(latents), t
            )
            latents = latents * latent_mask + noise_source_latents * (1 - latent_mask)

        latents = 1 / self._LATENT_SCALE * latents
        image = self.vae.decode(latents).sample

        # Map the decoder output from [-1, 1] to [0, 255], channel-last.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")

        return images

    @torch.no_grad()
    def _image2latent(self, image):
        """Encode a channel-last image array into scaled VAE latents.

        Args:
            image: NumPy array indexed as (row, column, channel); presumably
                ``uint8`` values in 0..255, which the ``/ 127.5 - 1`` step
                maps to [-1, 1].

        Returns:
            Mean of the VAE latent distribution multiplied by the latent
            scale factor, with a leading batch dimension of 1.
        """
        image = torch.from_numpy(image).float() / 127.5 - 1
        image = image.permute(2, 0, 1).unsqueeze(0).to(self.args.device)
        image = image.half()
        latents = self.vae.encode(image)["latent_dist"].mean
        latents = latents * self._LATENT_SCALE

        return latents

    def _read_mask(self, mask_img, dest_size=(64, 64)):
        """Turn a PIL mask into a binary latent-resolution tensor.

        Args:
            mask_img: PIL image; converted to 8-bit grayscale.
            dest_size: Target ``(width, height)`` passed to PIL ``resize``.

        Returns:
            Tuple ``(mask, org_mask)``: ``mask`` is a float16 tensor of 0/1
            values with two leading singleton dimensions on
            ``self.args.device``; ``org_mask`` is the grayscale PIL image at
            its original size.
        """
        org_mask = mask_img.convert("L")
        mask = org_mask.resize(dest_size, Image.NEAREST)
        mask = np.array(mask) / 255
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = mask[np.newaxis, np.newaxis, ...]
        mask = torch.from_numpy(mask).half().to(self.args.device)

        return mask, org_mask

# Shared editor instance with the default configuration. The constructor
# calls load_models(), so the diffusion weights are loaded at this point.
bld = BlendedLatnetDiffusion()
def generate_mask(xyxy, orig_shape):
    """Rasterise bounding boxes into a filled single-channel mask.

    Args:
        xyxy: NumPy array holding either one box ``(x1, y1, x2, y2)`` or a
            stack of such rows. Coordinates are cast to ``int``.
        orig_shape: Image size given as ``(width, height)``.

    Returns:
        ``uint8`` array of shape ``(height, width)`` that is 255 inside every
        box and 0 elsewhere.
    """
    width, height = orig_shape[0], orig_shape[1]
    # Blank canvas; NumPy images are indexed rows (height) first.
    canvas = np.zeros((height, width), dtype=np.uint8)

    # Promote a single box to a one-row stack so the loop handles both cases.
    boxes = xyxy[np.newaxis, :] if xyxy.ndim == 1 else xyxy

    # thickness=-1 asks OpenCV for a filled rectangle; the fill value is 255.
    for left, top, right, bottom in boxes.astype(int):
        cv2.rectangle(canvas, (left, top), (right, bottom), color=255, thickness=-1)

    return canvas

from tqdm.auto import tqdm

# Evaluation driver. For every image in `path`:
#   1. run the detector and keep the image only if its first detection has a
#      confidence of at least 0.8;
#   2. mask every detection of that same class whose confidence exceeds 0.5;
#   3. re-generate the masked region with the diffusion editor;
#   4. run the detector on the edited image and count the edit as a success
#      when no detection of that class above 0.5 confidence remains.
# Annotated before/after plots are written to
# `<output_path>/<image stem>_s` (success) or `<image stem>_f` (failure).
path = 'coco/images/val2017/'
output_path = 'outputs_erase'


def _save_plot(bgr_image, destination):
    """Save a BGR image through matplotlib on its own, immediately closed figure.

    The previous code drew every image onto the same implicit pyplot figure
    and never cleared it, so images piled up on one set of axes over the
    whole dataset.
    """
    fig, ax = plt.subplots()
    ax.imshow(cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB))
    fig.savefig(destination)
    plt.close(fig)


acc = []
files = sorted(os.listdir(path))
pbar = tqdm(files)
for filename in pbar:
    img = Image.open(os.path.join(path, filename)).convert('RGB')
    res = model(img)

    # Copy the detections to the CPU once; the indexing below reads column 4
    # as the confidence and column 5 as the class id.
    detections = res[0].boxes.data.cpu().numpy()
    if len(detections) == 0 or detections[0][4] < 0.8:
        continue

    # The class of the first detection is the one to erase.
    cls = detections[0][5]
    same_class = (detections[:, 4] > 0.5) & (detections[:, 5] == cls)
    xyxys = res[0].boxes.xyxy.cpu().numpy()[same_class]

    # PIL's `size` is (width, height), which is what generate_mask expects.
    mask_img = Image.fromarray(generate_mask(xyxys, img.size))

    new_img = bld.edit_image(
        img,
        mask_img,
        ["Integrate this area into the surrounding environment, be like the background around"],
        num_inference_steps=100
    )
    new_img = Image.fromarray(new_img[0])
    new_res = model(new_img)

    # Success (flag == 1) means the erased class is no longer detected.
    new_detections = new_res[0].boxes.data.cpu().numpy()
    still_detected = (new_detections[:, 5] == cls) & (new_detections[:, 4] > 0.5)
    flag = 0 if still_detected.any() else 1
    acc.append(flag)
    pbar.set_postfix({'Success rate': sum(acc) / len(acc)})

    result_dir = os.path.join(
        output_path, filename.split('.')[0] + ('_s' if flag == 1 else '_f')
    )
    os.makedirs(result_dir, exist_ok=True)
    # The annotated plot is rendered only for images that pass the filter.
    _save_plot(res[0].plot(), os.path.join(result_dir, 'annotated.jpg'))
    _save_plot(new_res[0].plot(), os.path.join(result_dir, 'new.jpg'))


