import logging
import os
import re
import threading
import time
from dataclasses import dataclass
from glob import iglob

import numpy as np
import torch
from einops import rearrange
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import FileResponse
from PIL import ExifTags, Image

from flux.sampling import prepare
from flux.util import (configs, load_ae, load_clip, load_t5)
from models.kv_edit import Flux_kv_edit
# Ask transformers to run offline, i.e. from locally cached files only.
# NOTE(review): transformers evaluates this variable when it is first imported;
# if the flux/models imports above already pull in transformers, setting it
# here may be too late to take effect — confirm, or export it before launch.
os.environ.update(TRANSFORMERS_OFFLINE="1")
# Application object; the HTTP routes below are registered on it.
app: FastAPI = FastAPI()
@dataclass
class SamplingOptions:
    """Per-request options passed to the KV-edit model's ``inverse``/``denoise``.

    ``width`` and ``height`` are placeholders: the editor overwrites them with
    the size of the cropped input image before inversion.
    """

    # Prompt describing the input image (inversion) and the desired edit.
    source_prompt: str = ""
    target_prompt: str = ""
    # Image size in pixels; replaced by the actual cropped size at runtime.
    width: int = 512
    height: int = 512
    # Step counts / guidance consumed by the KV-edit model; their exact
    # semantics are defined in models.kv_edit — see there.
    inversion_num_steps: int = 0
    denoise_num_steps: int = 0
    skip_step: int = 0
    inversion_guidance: float = 1.5
    denoise_guidance: float = 5.5
    # Seed applied to torch (CPU and all CUDA devices) before inverse/edit.
    seed: int = 42
    # Flags and scale forwarded to the KV-edit model unchanged.
    re_init: bool = False
    attn_mask: bool = True
    attn_scale: float = 1.0
class FluxEditorAPI:
    """Hold the FLUX encoders, autoencoder and KV-edit model; run inversion + editing.

    The T5/CLIP text encoders and the autoencoder live on ``self.device[1]``;
    the ``Flux_kv_edit`` model lives on ``self.device[0]``.  :meth:`inverse`
    stores ``init_image``, ``z0``, ``zt`` and ``info`` on the instance and
    :meth:`edit` reads them back, so the two must be called in that order and
    a single instance must not serve concurrent requests.
    """

    def __init__(self, name="flux-dev", devices=None, output_dir='regress_result'):
        """Load every model onto its device and put it in eval mode.

        Args:
            name: FLUX variant passed to the loaders; ``"flux-schnell"`` uses a
                256-token T5 context, any other name 512.
            devices: Optional ``(model_device, encoder_device)`` pair of
                ``torch.device`` objects or device strings.  Defaults to the
                previously hard-coded ``("cuda:5", "cuda:4")``.
            output_dir: Directory that :meth:`edit` writes its results into.

        Raises:
            ValueError: If ``devices`` does not contain exactly two entries.
        """
        if devices is None:
            devices = ("cuda:5", "cuda:4")
        self.device = [torch.device(d) for d in devices]
        if len(self.device) != 2:
            raise ValueError(f"devices must hold exactly two entries, got {len(self.device)}")
        self.name = name
        self.is_schnell = name == "flux-schnell"
        self.output_dir = output_dir
        self.t5 = load_t5(self.device[1], max_length=256 if self.is_schnell else 512)
        self.clip = load_clip(self.device[1])
        self.model = Flux_kv_edit(self.device[0], name=self.name)
        self.ae = load_ae(self.name, device=self.device[1])
        for module in (self.t5, self.clip, self.ae, self.model):
            module.eval()
        self.info = {}

    def _crop_canvas(self, brush_canvas):
        """Crop the canvas to multiples of 16 pixels and build the edit mask.

        Args:
            brush_canvas: Dict with ``"background"`` (an RGBA image array) and
                ``"layers"``, whose first entry is an RGBA array whose alpha
                channel marks the region to edit.

        Returns:
            ``(init_image, rgba_init_image, rgba_mask, mask)``: the cropped RGB
            view, the cropped RGBA view and the cropped mask-layer view (all
            views into the caller's arrays, not copies), plus a
            ``(1, 1, height, width)`` bfloat16 tensor on ``self.device[0]``
            that is 1 where the mask alpha equals 255 and 0 elsewhere.

        Raises:
            ValueError: If the background is smaller than 16 pixels on a side.
        """
        rgba_init_image = brush_canvas["background"]
        height = rgba_init_image.shape[0] - rgba_init_image.shape[0] % 16
        width = rgba_init_image.shape[1] - rgba_init_image.shape[1] % 16
        if height == 0 or width == 0:
            raise ValueError(
                f"Image must be at least 16x16 pixels, got "
                f"{rgba_init_image.shape[1]}x{rgba_init_image.shape[0]}"
            )
        rgba_init_image = rgba_init_image[:height, :width, :]
        init_image = rgba_init_image[:, :, :3]
        rgba_mask = brush_canvas["layers"][0][:height, :width, :]
        # alpha / 255 truncated to int: only fully opaque pixels (255) become 1.
        mask = (rgba_mask[:, :, 3] / 255).astype(int)
        mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(torch.bfloat16).to(self.device[0])
        return init_image, rgba_init_image, rgba_mask, mask

    @torch.inference_mode()
    def inverse(self, brush_canvas, opts):
        """Invert the canvas background into the latents later used by :meth:`edit`.

        Side effects: sets ``opts.width``/``opts.height`` to the cropped size,
        seeds torch with ``opts.seed``, and stores ``self.init_image``,
        ``self.z0``, ``self.zt`` and ``self.info``.

        Raises:
            ValueError: If the background is smaller than 16 pixels on a side.
        """
        # Drop the previous request's latents/features before allocating new
        # ones so their memory can be reclaimed first.
        if hasattr(self, 'z0'):
            del self.z0
            del self.zt
        if 'feature' in self.info:
            for key in list(self.info['feature'].keys()):
                del self.info['feature'][key]
        self.info = {}

        init_image, _, _, mask = self._crop_canvas(brush_canvas)
        opts.height, opts.width = init_image.shape[:2]
        torch.manual_seed(opts.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(opts.seed)
        torch.cuda.empty_cache()
        # Encode on the autoencoder's device, then move to the model's device.
        self.init_image = self.encode(init_image, self.device[1]).to(self.device[0])
        inp = prepare(self.t5, self.clip, self.init_image, prompt=opts.source_prompt)
        self.z0, self.zt, self.info = self.model.inverse(inp, mask, opts)

    @torch.inference_mode()
    def edit(self, brush_canvas, opts):
        """Denoise toward ``opts.target_prompt`` and save the edited image.

        Must be called after :meth:`inverse` with the same canvas, because it
        reuses ``self.init_image``, ``self.z0``, ``self.zt`` and ``self.info``.
        Writes ``img_<n>.jpg`` (with EXIF metadata) and a ``img_<n>_mask.png``
        preview (mask overlaid at half opacity) into ``self.output_dir``.
        The caller's ``brush_canvas`` arrays are not modified.

        Returns:
            Path of the written JPEG.
        """
        torch.cuda.empty_cache()
        _, rgba_init_image, rgba_mask, mask = self._crop_canvas(brush_canvas)

        # Halve the mask alpha on a COPY: rgba_mask is a view of the caller's
        # layer, and modifying it in place would make a later inverse/edit on
        # the same canvas see an all-zero mask (127 / 255 truncates to 0).
        overlay = rgba_mask.copy()
        overlay[:, :, 3] //= 2
        masked_image = Image.alpha_composite(Image.fromarray(rgba_init_image, 'RGBA'),
                                             Image.fromarray(overlay, 'RGBA'))

        torch.manual_seed(opts.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(opts.seed)
        inp_target = prepare(self.t5, self.clip, self.init_image, prompt=opts.target_prompt)
        x = self.model.denoise(self.z0, self.zt, inp_target, mask, opts, self.info)
        with torch.autocast(device_type=self.device[1].type, dtype=torch.bfloat16):
            x = self.ae.decode(x.to(self.device[1]))
        x = x.clamp(-1, 1).float().cpu()
        x = rearrange(x[0], "c h w -> h w c")
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        fn = self._next_output_path()
        # Map [-1, 1] back to [0, 255] pixel values.
        img = Image.fromarray((127.5 * (x + 1.0)).byte().numpy())
        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
        exif_data[ExifTags.Base.Make] = "Black Forest Labs"
        exif_data[ExifTags.Base.Model] = self.name
        exif_data[ExifTags.Base.ImageDescription] = opts.target_prompt
        img.save(fn, exif=exif_data, quality=95, subsampling=0)
        masked_image.save(os.path.splitext(fn)[0] + "_mask.png", format='PNG')
        return fn

    def _next_output_path(self):
        """Return ``<output_dir>/img_<n>.jpg`` where ``n`` is one past the highest existing index.

        Creates ``self.output_dir`` if needed; ``n`` is 0 when no numbered
        image exists yet.
        """
        os.makedirs(self.output_dir, exist_ok=True)
        output_name = os.path.join(self.output_dir, "img_{idx}.jpg")
        pattern = re.compile(r"img_([0-9]+)\.jpg$")
        indices = [
            int(match.group(1))
            for match in map(pattern.search, iglob(output_name.format(idx="*")))
            if match
        ]
        return output_name.format(idx=max(indices, default=-1) + 1)

    @torch.inference_mode()
    def encode(self, init_image, torch_device):
        """Encode an HWC image array with values in [0, 255] into bfloat16 latents.

        The array is scaled to [-1, 1], given a batch dimension, and passed
        through the autoencoder on ``torch_device``.
        """
        init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
        init_image = init_image.unsqueeze(0).to(torch_device)
        self.ae.encoder.to(torch_device)
        return self.ae.encode(init_image).to(torch.bfloat16)
editor = FluxEditorAPI()
# The editor keeps per-request state (init_image, z0, zt, info) on the shared
# instance, so an inverse+edit pair must run for one request at a time.
_edit_lock = threading.Lock()


@app.post("/edit_image")
def edit_image(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    source_prompt: str = Form(default="a group of red apples with one green apple on the left"),
    target_prompt: str = Form(default="a group of red apples with one pear on the left"),
    inversion_num_steps: int = Form(default=24),
    denoise_num_steps: int = Form(default=24),
    skip_step: int = Form(default=0),
    inversion_guidance: float = Form(default=1.5),
    denoise_guidance: float = Form(default=5.5),
    seed: int = Form(default=0),
    re_init: bool = Form(default=False),
    attn_scale: float = Form(default=1.0),
):
    """Edit the masked region of an uploaded image and return the result as a JPEG.

    Any non-zero pixel of ``mask`` (converted to grayscale and resized with
    nearest-neighbour to the image size if needed) marks a pixel to edit.

    Declared with plain ``def`` so FastAPI runs it in its threadpool: the model
    calls are long and blocking, and inside ``async def`` they would stall the
    event loop for every other request.  ``_edit_lock`` keeps requests from
    interleaving on the shared editor.

    Raises:
        HTTPException: 400 if an upload cannot be decoded as an image,
            500 if inversion or editing fails.
    """
    try:
        input_image = Image.open(image.file).convert("RGBA")
        mask_image = Image.open(mask.file).convert("L")
    except OSError as e:
        # PIL raises UnidentifiedImageError (an OSError) for unreadable data:
        # that is a client error, not a server failure.
        raise HTTPException(status_code=400, detail=f"Could not read uploaded image: {e}") from e

    image_array = np.array(input_image)
    if mask_image.size != input_image.size:
        mask_image = mask_image.resize(input_image.size, Image.Resampling.NEAREST)
    mask_array = np.array(mask_image)
    # Editor expects an RGBA layer whose alpha is 255 inside the edit region.
    mask_rgba = np.zeros((*mask_array.shape, 4), dtype=np.uint8)
    mask_rgba[:, :, 3] = (mask_array > 0).astype(np.uint8) * 255

    brush_canvas = {"background": image_array, "layers": [mask_rgba]}
    opts = SamplingOptions(
        source_prompt=source_prompt,
        target_prompt=target_prompt,
        inversion_num_steps=inversion_num_steps,
        denoise_num_steps=denoise_num_steps,
        skip_step=skip_step,
        inversion_guidance=inversion_guidance,
        denoise_guidance=denoise_guidance,
        seed=seed,
        re_init=re_init,
        attn_mask=True,
        attn_scale=attn_scale,
    )
    try:
        with _edit_lock:
            editor.inverse(brush_canvas, opts)
            output_file = editor.edit(brush_canvas, opts)
    except Exception as e:
        logging.getLogger(__name__).exception("Image edit failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return FileResponse(output_file, media_type="image/jpeg", filename=os.path.basename(output_file))
if __name__ == "__main__":
    # Imported here because uvicorn is only needed when run as a script.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=5000)