# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import dataclasses
from typing import Literal
import torch
from accelerate import Accelerator
from transformers import HfArgumentParser
from PIL import Image
import json
import itertools
from safetensors.torch import load_file
from core.flux.pipeline_mask import UNOPipeline, preprocess_ref
from core.flux.util_mask import (
    get_lora_rank,
    load_ae,
    load_checkpoint,
    load_clip,
    load_flow_model,
    load_flow_model_only_lora,
    load_flow_model_quintized,
    load_t5,
)
import random
from tqdm import tqdm

def horizontal_concat(images):
    """Paste the given images side by side onto a single RGB canvas.

    The canvas is as wide as the sum of all image widths and as tall as the
    tallest image. Images are pasted left to right in the order given, each
    aligned to the top edge (y = 0).

    Args:
        images: PIL images to place next to each other.

    Returns:
        A new RGB PIL image containing all inputs in one row.
    """
    sizes = [picture.size for picture in images]
    canvas_width = sum(width for width, _ in sizes)
    canvas_height = max(height for _, height in sizes)

    canvas = Image.new('RGB', (canvas_width, canvas_height))

    # Walk a horizontal cursor across the canvas, advancing by each image's width.
    cursor = 0
    for picture, (width, _) in zip(images, sizes):
        canvas.paste(picture, (cursor, 0))
        cursor += width

    return canvas

@dataclasses.dataclass
class InferenceArgs:
    """Command-line options for the inference script, parsed by HfArgumentParser."""

    # Text prompt for the single-prompt mode; main only falls back to it when
    # eval_json_path is None.
    prompt: str | None = None
    # Intended as the reference images for the single-prompt mode.
    # NOTE(review): confirm how main consumes this field.
    image_paths: list[str] | None = None
    # JSON file listing the evaluation entries; main reads "prompt",
    # "image_paths" and "ref_txt" from each entry and resolves image paths
    # relative to this file's directory.
    eval_json_path: str = "./test.json"
    # When True the flow model is created on the CPU instead of the accelerator
    # device; the flag is also forwarded to UNOPipeline.
    offload: bool = False
    # Number of generation runs performed for every entry.
    num_images_per_prompt: int = 1
    # Passed to the model loaders and the pipeline; a name containing "fp8"
    # turns on use_fp8 for the LoRA-only loader.
    model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
    # Output size forwarded to the pipeline.
    width: int = 512
    height: int = 512
    # Size handed to preprocess_ref; -1 lets main choose 512 for a single
    # reference image and 320 otherwise.
    ref_size: int = -1
    # Sampling settings forwarded unchanged to the pipeline.
    num_steps: int = 25
    guidance: float = 4
    seed: int = 41
    # Directory that receives "{i}_{j}.png" and the matching "{i}_{j}.json".
    save_path: str = "./output/mask"
    # True -> load_flow_model_only_lora, False -> load_flow_model.
    only_lora: bool = True
    # When True the saved image is the generated image followed by the
    # preprocessed reference images, joined horizontally.
    concat_refs: bool = False
    # Forwarded to the LoRA-only loader and to UNOPipeline.
    lora_rank: int = 512
    # Not referenced by the code in this file.
    data_resolution: int = 512
    # Forwarded to the pipeline as `pe`; the meaning of each letter is defined
    # by the pipeline, not here.
    pe: Literal['d', 'h', 'w', 'o'] = 'd'

    # Checkpoint root and training step; weights are read from
    # "{checkpoint}/checkpoint-{step}/step{step}_*_params.safetensors".
    checkpoint: str = './checkpoint'
    step: int = 100000

    # Forwarded to the pipeline as `mask_ref`.
    mask_ref: bool = False

# dropout
def create_custom_attention_mask(text_len, gen_img_len, ref_img_len, device='cuda'):
    """Build a square boolean mask over a [text | generated image | reference image] sequence.

    The sequence is laid out as ``text_len`` positions, then ``gen_img_len``
    positions, then ``ref_img_len`` positions. The returned tensor is ``True``
    exactly where a generated-image row meets a reference-image column and
    ``False`` everywhere else.

    Args:
        text_len: Number of text positions at the start of the sequence.
        gen_img_len: Number of generated-image positions following the text.
        ref_img_len: Number of reference-image positions at the end.
        device: Device on which the mask tensor is allocated.

    Returns:
        A ``torch.bool`` tensor of shape ``(seq_len, seq_len)`` with
        ``seq_len = text_len + gen_img_len + ref_img_len``.
    """
    total_len = text_len + gen_img_len + ref_img_len

    # Generated-image rows span [row_begin, row_end); reference columns start
    # right after them and run to the end of the sequence.
    row_begin = text_len
    row_end = row_begin + gen_img_len
    col_begin = row_end

    flagged = torch.zeros(total_len, total_len, dtype=torch.bool, device=device)
    flagged[row_begin:row_end, col_begin:] = True
    return flagged


def main(args: InferenceArgs):
    """Generate one image per (entry, repeat) pair and save it with its config.

    Steps:
      1. Build the flow model (LoRA-only or full) and load the
         ``concept_MLP_global``, ``double_ref_injects`` and
         ``single_ref_injects`` weights from
         ``{checkpoint}/checkpoint-{step}/``, casting them and the attention
         processors to bfloat16.
      2. Read the evaluation entries from ``args.eval_json_path``, or build a
         single entry from ``args.prompt`` / ``args.image_paths`` when no JSON
         path is given.
      3. For every entry ``i`` and repeat ``j`` assigned to this process, run
         the pipeline and write ``{i}_{j}.png`` plus ``{i}_{j}.json`` (the
         settings used) into ``args.save_path``.

    Work is split across processes round-robin on the flattened job index
    ``i * num_images_per_prompt + j``.

    Args:
        args: Parsed inference options. The object is not modified.

    Raises:
        ValueError: If neither ``prompt`` nor ``eval_json_path`` is provided.
    """
    # Validate before loading any weights so a bad invocation fails fast.
    # A real exception is used because asserts are stripped under `python -O`.
    if args.prompt is None and args.eval_json_path is None:
        raise ValueError("Please provide either prompt or eval_json_path")

    accelerator = Accelerator()
    model_device = "cpu" if args.offload else accelerator.device

    # ==================== init model ====================
    use_fp8 = "fp8" in args.model_type
    if args.only_lora:
        model = load_flow_model_only_lora(
            args.model_type,
            device=model_device,
            lora_rank=args.lora_rank,
            use_fp8=use_fp8,
        )
    else:
        model = load_flow_model(args.model_type, device=model_device)

    model.set_MLP(device=model_device)

    # ==================== load checkpoint ====================
    checkpoint_dir = os.path.join(args.checkpoint, f"checkpoint-{args.step}")
    for module_name in ("concept_MLP_global", "double_ref_injects", "single_ref_injects"):
        weights_path = os.path.join(
            checkpoint_dir, f"step{args.step}_{module_name}_params.safetensors"
        )
        module = getattr(model, module_name)
        module.load_state_dict(load_file(weights_path))
        module.to(dtype=torch.bfloat16)

    for processor in model.attn_processors.values():
        processor.to(dtype=torch.bfloat16)

    pipeline = UNOPipeline(
        model,
        args.model_type,
        accelerator.device,
        args.offload,
        only_lora=args.only_lora,
        lora_rank=args.lora_rank,
    )

    # ==================== collect evaluation entries ====================
    if args.eval_json_path is not None:
        with open(args.eval_json_path, "rt", encoding="utf-8") as f:
            data_dicts = json.load(f)
        # Image paths inside the JSON are resolved relative to the JSON file.
        data_root = os.path.dirname(args.eval_json_path)
    else:
        # Accept both several separate paths and comma-separated lists
        # (e.g. "a.png,b.png") by flattening every entry on commas.
        image_path_list = [
            path
            for entry in (args.image_paths or [])
            for path in entry.split(",")
            if path
        ]
        data_root = "./"
        # NOTE(review): this entry carries no "ref_txt", which the loop below
        # requires; the prompt-only mode needs a source for the concept texts
        # before it can run end to end.
        data_dicts = [{
            "prompt": args.prompt,
            "image_paths": image_path_list,
        }]

    os.makedirs(args.save_path, exist_ok=True)

    jobs = itertools.product(enumerate(data_dicts), range(args.num_images_per_prompt))
    total_jobs = len(data_dicts) * args.num_images_per_prompt
    for (i, data_dict), j in tqdm(jobs, total=total_jobs):
        # Round-robin sharding: each process handles the jobs whose flattened
        # index maps to its own process index.
        if (i * args.num_images_per_prompt + j) % accelerator.num_processes != accelerator.process_index:
            continue

        ref_imgs = [
            Image.open(os.path.join(data_root, img_path))
            for img_path in data_dict["image_paths"]
        ]

        # Resolve the automatic size per entry. Keeping it in a local (rather
        # than writing back to args.ref_size) ensures later entries with a
        # different number of references get their own size.
        ref_size = args.ref_size
        if ref_size == -1:
            ref_size = 512 if len(ref_imgs) == 1 else 320
        print(len(ref_imgs))
        ref_imgs = [preprocess_ref(img, ref_size) for img in ref_imgs]
        concepts = data_dict["ref_txt"]

        # NOTE(review): every repeat j uses the same args.seed; if the pipeline
        # is deterministic for a given seed, repeats of one entry will be
        # identical -- confirm this is intended.
        image_gen = pipeline(
            prompt=data_dict["prompt"],
            width=args.width,
            height=args.height,
            guidance=args.guidance,
            num_steps=args.num_steps,
            seed=args.seed,
            ref_imgs=ref_imgs,
            personalized_concepts=concepts,
            pe=args.pe,
            mask_ref=args.mask_ref,
        )
        if args.concat_refs:
            image_gen = horizontal_concat([image_gen, *ref_imgs])

        image_gen.save(os.path.join(args.save_path, f"{i}_{j}.png"))

        # Save the settings used for this image. asdict() returns a copy, so
        # the per-entry overrides below do not leak back into `args`.
        config = dataclasses.asdict(args)
        config['prompt'] = data_dict["prompt"]
        config['image_paths'] = data_dict["image_paths"]
        config['ref_size'] = ref_size
        with open(os.path.join(args.save_path, f"{i}_{j}.json"), 'w', encoding="utf-8") as f:
            json.dump(config, f, indent=4)

if __name__ == "__main__":
    # Parse the command line into a single InferenceArgs instance and run.
    cli_parser = HfArgumentParser([InferenceArgs])
    parsed_dataclasses = cli_parser.parse_args_into_dataclasses()
    main(parsed_dataclasses[0])
