# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Literal
from torch import Tensor
import torch
from einops import rearrange
from safetensors.torch import load_file
from PIL import ExifTags, Image
import torchvision.transforms.functional as TVF

from core.flux.modules.layers import (
    DoubleStreamBlockLoraProcessor,
    DoubleStreamBlockProcessor,
    SingleStreamBlockLoraProcessor,
    SingleStreamBlockProcessor,
)
from core.flux.sampling_mask import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
from core.flux.util_mask import (
    get_lora_rank,
    load_ae,
    load_checkpoint,
    load_clip,
    load_flow_model,
    load_flow_model_only_lora,
    load_flow_model_quintized,
    load_t5,
)

def create_custom_attention_mask(text_len, gen_img_len, ref_img_len, device='cuda'):
    """Build a boolean mask flagging generated-image rows against reference columns.

    The sequence is laid out as ``[text | generated image | reference image]``.
    The returned ``(seq_len, seq_len)`` boolean tensor is ``True`` only where
    the row index lies in the generated-image span and the column index lies
    in the reference-image span; every other entry is ``False``.

    Args:
        text_len: Number of text positions at the start of the sequence.
        gen_img_len: Number of generated-image positions following the text.
        ref_img_len: Number of reference-image positions at the end.
        device: Device on which the mask is allocated.

    Returns:
        A ``torch.bool`` tensor of shape ``(seq_len, seq_len)``.
    """
    ref_offset = text_len + gen_img_len
    total_len = ref_offset + ref_img_len

    flagged = torch.zeros(total_len, total_len, dtype=torch.bool, device=device)
    flagged[text_len:ref_offset, ref_offset:] = True
    return flagged

def find_sublist_range(main_list, sublist):
    """Locate the first contiguous occurrence of ``sublist`` inside ``main_list``.

    Args:
        main_list: Sequence to search.
        sublist: Sequence to look for; compared slice-by-slice with ``==``.

    Returns:
        A ``range`` covering the indices of the first match, or ``None`` when
        ``sublist`` does not occur.
    """
    width = len(sublist)
    matching_starts = (
        start
        for start in range(len(main_list) - width + 1)
        if main_list[start:start + width] == sublist
    )
    first = next(matching_starts, None)
    if first is None:
        return None  # sublist is not present
    return range(first, first + width)

def find_nearest_scale(image_h, image_w, predefined_scales):
    """Pick the predefined ``(h, w)`` scale whose aspect ratio is closest to the image's.

    Aspect ratio is measured as height divided by width.  When several scales
    are equally close, the one appearing first in ``predefined_scales`` wins.

    Args:
        image_h: Image height.
        image_w: Image width.
        predefined_scales: Iterable of ``(height, width)`` pairs.

    Returns:
        The best-matching scale as a ``(height, width)`` tuple, or ``None``
        when ``predefined_scales`` is empty.
    """
    target_ratio = image_h / image_w

    best_scale = None
    best_gap = float('inf')
    for cand_h, cand_w in predefined_scales:
        gap = abs(cand_h / cand_w - target_ratio)
        # Strict comparison keeps the earliest candidate on ties.
        if not gap < best_gap:
            continue
        best_gap = gap
        best_scale = (cand_h, cand_w)

    return best_scale

def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
    """Resize a reference image and centre-crop it to multiples of 16.

    The image is rescaled with LANCZOS resampling so that its longer edge
    equals ``long_size``; the shorter edge follows the aspect ratio and is
    truncated to an integer.  Each dimension is then trimmed down to the
    nearest lower multiple of 16 with a centred crop, and the result is
    converted to RGB.

    Args:
        raw_image: Source image.
        long_size: Target length of the longer edge after resizing.

    Returns:
        The resized, cropped RGB image.
    """
    src_w, src_h = raw_image.size

    # A single scale factor taken from the longer edge (width wins ties).
    factor = long_size / max(src_w, src_h)
    if src_w >= src_h:
        scaled_w, scaled_h = long_size, int(factor * src_h)
    else:
        scaled_w, scaled_h = int(factor * src_w), long_size

    resized = raw_image.resize((scaled_w, scaled_h), resample=Image.LANCZOS)

    # Drop the remainder modulo 16, split evenly between both sides of each axis.
    crop_w = scaled_w - scaled_w % 16
    crop_h = scaled_h - scaled_h % 16
    x0 = (scaled_w - crop_w) // 2
    y0 = (scaled_h - crop_h) // 2

    return resized.crop((x0, y0, x0 + crop_w, y0 + crop_h)).convert("RGB")

class UNOPipeline:
    """Text-to-image pipeline that feeds reference images to a flow model.

    The pipeline loads the CLIP and T5 text encoders and the auto-encoder,
    and drives a caller-supplied flow model through ``denoise``.  When
    ``offload`` is enabled, sub-models are moved between ``device`` and the
    CPU around each stage of ``forward``.
    """

    def __init__(
        self,
        model,
        model_type: str,
        device: torch.device,
        offload: bool = False,
        only_lora: bool = False,
        lora_rank: int = 16
    ):
        """Load the text encoders and auto-encoder and attach the flow model.

        Args:
            model: Already-constructed flow model passed to ``denoise``.
            model_type: Identifier forwarded to ``load_ae``; a name containing
                ``"fp8"`` enables autocast in ``__call__``.
            device: Device on which inference runs.
            offload: When true, the auto-encoder is loaded on the CPU and
                sub-models are moved to ``device`` only while in use.
            only_lora: Accepted for interface compatibility; not read here.
            lora_rank: Accepted for interface compatibility; not read here.
        """
        self.device = device
        self.offload = offload
        self.model_type = model_type

        self.clip = load_clip(self.device)
        self.t5 = load_t5(self.device, max_length=512)
        self.ae = load_ae(model_type, device="cpu" if offload else self.device)
        self.use_fp8 = "fp8" in model_type
        self.model = model

    def load_ckpt(self, ckpt_path):
        """Overwrite matching weights of ``self.model`` from a checkpoint file.

        Paths ending in ``safetensors`` are read with safetensors; anything
        else is read with ``torch.load`` and has ``'module.'`` stripped from
        its keys.  Loading is non-strict, and the missing/unexpected key lists
        are printed.  Does nothing when ``ckpt_path`` is ``None``.

        Args:
            ckpt_path: Path to the checkpoint, or ``None`` to skip loading.
        """
        if ckpt_path is None:
            return

        from safetensors.torch import load_file as load_sft
        print("Loading checkpoint to replace old keys")
        # load_sft doesn't support torch.device
        if ckpt_path.endswith('safetensors'):
            sd = load_sft(ckpt_path, device='cpu')
            missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
            # NOTE(review): unlike the branch below, the model is not moved to
            # self.device here — confirm callers handle placement.
        else:
            # NOTE(review): torch.load unpickles the file; only use it with
            # trusted checkpoints (or pass weights_only=True where supported).
            dit_state = torch.load(ckpt_path, map_location='cpu')
            sd = {k.replace('module.', ''): v for k, v in dit_state.items()}
            missing, unexpected = self.model.load_state_dict(sd, strict=False, assign=True)
            self.model.to(str(self.device))
        print(f"missing keys: {missing}\n\n\n\n\nunexpected keys: {unexpected}")

    def set_lora(self, local_path: str | None = None, repo_id: str | None = None,
                name: str | None = None, lora_weight: float = 0.7):
        """Load a LoRA checkpoint and install it on the model.

        Args:
            local_path: Forwarded to ``load_checkpoint``.
            repo_id: Forwarded to ``load_checkpoint``.
            name: Forwarded to ``load_checkpoint``.
            lora_weight: Factor every LoRA tensor is multiplied by.
        """
        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.update_model_with_lora(checkpoint, lora_weight)

    def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
        """Load a named LoRA from the configured collection and install it.

        Args:
            lora_type: Key into ``self.lora_types_to_names``.
            lora_weight: Factor every LoRA tensor is multiplied by.
        """
        # NOTE(review): hf_lora_collection and lora_types_to_names are not
        # assigned anywhere in this class; they must be set on the instance
        # before calling this method — confirm where they are defined.
        checkpoint = load_checkpoint(
            None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
        )
        self.update_model_with_lora(checkpoint, lora_weight)

    def update_model_with_lora(self, checkpoint, lora_weight):
        """Replace the model's attention processors, using LoRA where weights exist.

        For every processor name on the model, checkpoint entries whose key
        contains that name are collected (scaled by ``lora_weight``).  Names
        with at least one entry get a LoRA processor loaded with those
        weights; the rest get the plain processor.

        Args:
            checkpoint: Mapping from parameter key to tensor.
            lora_weight: Factor every LoRA tensor is multiplied by.
        """
        rank = get_lora_rank(checkpoint)
        lora_attn_procs = {}

        for name in self.model.attn_processors:
            # Drop the processor name and the following separator from each key.
            lora_state_dict = {
                key[len(name) + 1:]: tensor * lora_weight
                for key, tensor in checkpoint.items()
                if name in key
            }
            is_single = name.startswith("single_blocks")

            if lora_state_dict:
                if is_single:
                    processor = SingleStreamBlockLoraProcessor(dim=3072, rank=rank)
                else:
                    processor = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
                processor.load_state_dict(lora_state_dict)
                processor.to(self.device)
            elif is_single:
                processor = SingleStreamBlockProcessor()
            else:
                processor = DoubleStreamBlockProcessor()

            lora_attn_procs[name] = processor

        self.model.set_attn_processor(lora_attn_procs)

    def __call__(
        self,
        prompt: str,
        width: int = 512,
        height: int = 512,
        guidance: float = 4,
        num_steps: int = 50,
        seed: int = 123456789,
        **kwargs
    ):
        """Generate an image, running ``forward`` under autocast.

        ``width`` and ``height`` are rounded down to multiples of 16.  Autocast
        (bfloat16) is enabled only when the model type contains ``"fp8"``.

        Args:
            prompt: Text prompt.
            width: Requested output width.
            height: Requested output height.
            guidance: Guidance value forwarded to ``denoise``.
            num_steps: Number of steps used to build the schedule.
            seed: Seed for the initial noise.
            **kwargs: Extra keyword arguments forwarded to ``forward``.

        Returns:
            Whatever ``forward`` returns (a PIL image).
        """
        width = 16 * (width // 16)
        height = 16 * (height // 16)

        # torch.device() accepts both strings and device objects and strips
        # any index ("cuda:0" -> "cuda"), which autocast requires.
        device_type = torch.device(self.device).type
        if device_type == "mps":
            device_type = "cpu"  # for support macos mps
        with torch.autocast(enabled=self.use_fp8, device_type=device_type, dtype=torch.bfloat16):
            return self.forward(
                prompt,
                width,
                height,
                guidance,
                num_steps,
                seed,
                **kwargs
            )

    @torch.inference_mode()
    def gradio_generate(
        self,
        prompt: str,
        width: int,
        height: int,
        guidance: float,
        num_steps: int,
        seed: int,
        image_prompt1: Image.Image,
        image_prompt2: Image.Image,
        image_prompt3: Image.Image,
        image_prompt4: Image.Image,
        personalized_concepts: list[str] | None = None,
    ):
        """Generate an image from up to four reference images and save it as PNG.

        Non-image entries among the four image prompts are dropped.  The
        remaining images are preprocessed with a long side of 512 when there
        is at most one, otherwise 320.  A ``seed`` of ``-1`` is replaced by a
        random one.  The result is written under ``output/gradio/`` with EXIF
        metadata describing the generation settings.

        Args:
            prompt: Text prompt.
            width: Requested output width.
            height: Requested output height.
            guidance: Guidance value.
            num_steps: Number of denoising steps.
            seed: Seed, or ``-1`` for a random seed.
            image_prompt1: Optional reference image.
            image_prompt2: Optional reference image.
            image_prompt3: Optional reference image.
            image_prompt4: Optional reference image.
            personalized_concepts: Concept strings paired in order with the
                kept reference images; forwarded to ``forward``.

        Returns:
            Tuple of the generated PIL image and the path it was saved to.
        """
        ref_imgs = [image_prompt1, image_prompt2, image_prompt3, image_prompt4]
        ref_imgs = [img for img in ref_imgs if isinstance(img, Image.Image)]
        ref_long_side = 512 if len(ref_imgs) <= 1 else 320
        ref_imgs = [preprocess_ref(img, ref_long_side) for img in ref_imgs]

        seed = seed if seed != -1 else torch.randint(0, 10 ** 8, (1,)).item()

        img = self(prompt=prompt, width=width, height=height, guidance=guidance,
                num_steps=num_steps, seed=seed, ref_imgs=ref_imgs,
                personalized_concepts=personalized_concepts)

        # The prompt is user-supplied: neutralise path separators and other
        # reserved or non-printable characters so it cannot redirect the write.
        safe_prompt = "".join(
            ch if ch.isprintable() and ch not in '\\/:*?"<>|' else "_"
            for ch in prompt[:20]
        )
        filename = f"output/gradio/{seed}_{safe_prompt}.png"
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Make] = "UNO"
        exif_data[ExifTags.Base.Model] = self.model_type
        info = f"{prompt=}, {seed=}, {width=}, {height=}, {guidance=}, {num_steps=}"
        exif_data[ExifTags.Base.ImageDescription] = info
        img.save(filename, format="png", exif=exif_data)
        return img, filename

    @torch.inference_mode()
    def forward(
        self,
        prompt: str,
        width: int,
        height: int,
        guidance: float,
        num_steps: int,
        seed: int,
        personalized_concepts: list[str] | None = None,
        ref_imgs: list[Image.Image] | None = None,
        pe: Literal['d', 'h', 'w', 'o'] = 'd',
        mask_ref: bool = False,
    ):
        """Run the full generation: encode references, denoise, decode.

        Concepts and reference images are paired in order.  A pair is used
        only if the concept's T5 token ids occur in the prompt's token ids;
        otherwise a warning is printed and that reference image is skipped.

        Args:
            prompt: Text prompt.
            width: Output width.
            height: Output height.
            guidance: Guidance value forwarded to ``denoise``.
            num_steps: Number of steps used to build the schedule.
            seed: Seed for the initial noise.
            personalized_concepts: Concept strings, one per reference image.
                ``None`` is treated as an empty list.
            ref_imgs: Reference images, one per concept.  ``None`` is treated
                as an empty list.
            pe: Positional-encoding mode forwarded to ``prepare_multi_ip``.
            mask_ref: Flag forwarded to ``denoise``.

        Returns:
            The generated image as a ``PIL.Image.Image``.
        """
        # Both lists are optional; normalise them so the pairing loop below is
        # well-defined when either is omitted.
        personalized_concepts = personalized_concepts or []
        ref_imgs = ref_imgs or []
        if len(personalized_concepts) != len(ref_imgs):
            print(
                f"[Warning] Got {len(personalized_concepts)} concept(s) for "
                f"{len(ref_imgs)} reference image(s); unpaired entries are ignored."
            )

        x = get_noise(
            1, height, width, device=self.device,
            dtype=torch.bfloat16, seed=seed
        )
        timesteps = get_schedule(
            num_steps,
            (width // 8) * (height // 8) // (16 * 16),
            shift=True,
        )
        if self.offload:
            self.ae.encoder = self.ae.encoder.to(self.device)

        prompt_t5_ids = self.t5.get_t5_idx_of_word(prompt)
        concept_tokens_range_list = []
        x_1_refs = []

        for concept, ref_img in zip(personalized_concepts, ref_imgs):
            # The final id is dropped before matching — presumably an
            # end-of-sequence token; TODO confirm against the tokenizer.
            concept_t5_ids = self.t5.get_t5_idx_of_word(concept)[:-1]
            token_range = find_sublist_range(prompt_t5_ids, concept_t5_ids)
            if token_range is None:
                print(f"[Warning] Concept token '{concept}' not found in prompt. Skipping.")
                continue
            concept_tokens_range_list.append(token_range)

            # Map pixel values from [0, 1] to [-1, 1] and add a batch dimension.
            ref_img_tensor = (TVF.to_tensor(ref_img) * 2.0 - 1.0).unsqueeze(0).to(self.device, torch.float32)
            ref_img_encoded = self.ae.encode(ref_img_tensor).to(torch.bfloat16)
            x_1_refs.append(ref_img_encoded)

        if self.offload:
            self.offload_model_to_cpu(self.ae.encoder)
            self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)

        inp_cond = prepare_multi_ip(
            t5 = self.t5, clip=self.clip,
            img = x,
            prompt = prompt, ref_imgs = x_1_refs, pe = pe
        )

        # Sequence lengths for the attention mask: a quarter of each latent's
        # spatial size (presumably 2x2 patch packing — TODO confirm).
        gen_img_len = x.shape[2] * x.shape[3] // 4
        ref_img_len = sum(ref.shape[2] * ref.shape[3] // 4 for ref in x_1_refs)
        text_len = inp_cond["txt"].shape[1]
        attn_mask = create_custom_attention_mask(text_len, gen_img_len, ref_img_len, self.device)

        if self.offload:
            self.offload_model_to_cpu(self.t5, self.clip)
            self.model = self.model.to(self.device)

        x = denoise(
            self.model,
            **inp_cond,
            timesteps = timesteps,
            guidance = guidance,
            concept_tokens_range_list = concept_tokens_range_list,
            mask_ref = mask_ref,
            attn_mask = attn_mask
        )

        if self.offload:
            self.offload_model_to_cpu(self.model)
            self.ae.decoder.to(x.device)

        x = unpack(x.float(), height, width)
        x = self.ae.decode(x)
        self.offload_model_to_cpu(self.ae.decoder)

        # Convert the last image of the batch from [-1, 1] to 8-bit HWC.
        x1 = x.clamp(-1, 1)
        x1 = rearrange(x1[-1], "c h w -> h w c")
        output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
        return output_img

    def offload_model_to_cpu(self, *models):
        """Move the given models to the CPU when offloading is enabled.

        Does nothing when ``self.offload`` is false.  The CUDA cache is
        emptied once after all models have been moved.

        Args:
            *models: Modules exposing a ``cpu()`` method.
        """
        if not self.offload:
            return
        for model in models:
            model.cpu()
        torch.cuda.empty_cache()
