import argparse
import torch
from diffusers.utils import load_image, check_min_version
from diffusers import FluxPriorReduxPipeline, FluxFillPipeline
from diffusers import FluxTransformer2DModel
import numpy as np
from torchvision import transforms
import base64
from io import BytesIO
from PIL import Image
from vllm import LLM
from vllm.sampling_params import SamplingParams

# VLM prompts for the three-round check driven by run_inference(): each round is
# sent as a follow-up user turn in the same conversation (see
# get_response_multi_round).
# NOTE(review): "characters" presumably means text/lettering printed on the
# garment — confirm with the prompt author.
# Round 1: ask the VLM to describe the garment characters in each of the two images.
prompts_1 = """Please describe the characters on the upper-body garment in each of the two images separately."""
# Round 2: ask for a 0-5 consistency score; the prompt requests the number only.
prompts_2 = """Please rate the consistency of the characters on the upper-body garments in the two images on a scale from 0 to 5. If the characters are completely consistent, please output 5. If the characters are completely inconsistent, please output 0, and do not output anything else."""
# Round 3: ask which part of the garment in the second image should be modified.
prompts_3 = """Please indicate which part of the upper-body garment of the second image needs to be modified based on the description of the garments."""


def encode_image_base64(
        image: Image.Image,
        *,
        image_mode: str = "RGB",
        format: str = "JPEG",
) -> str:
    """Return ``image`` serialized in ``format`` and base64-encoded as text.

    The image is converted to ``image_mode`` first, written into an in-memory
    buffer, and the resulting bytes are base64-encoded and decoded as UTF-8,
    ready to be embedded in a ``data:`` URL.
    """
    converted = image.convert(image_mode)
    with BytesIO() as buf:
        converted.save(buf, format)
        raw_bytes = buf.getvalue()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

def resize_max_side(image, max_side=1024):
    """Resize ``image`` so that its longer side is exactly ``max_side`` pixels.

    The aspect ratio is preserved, and the image is scaled up or down as
    needed. Square images take the ``else`` branch, which gives the same
    result. The shorter side is truncated to an int and clamped to at least
    1 pixel, so very elongated images never produce a zero-size target.

    Args:
        image: A PIL image.
        max_side: Target length, in pixels, of the longer side.

    Returns:
        A new image resampled with LANCZOS.

    Raises:
        ValueError: If ``max_side`` is less than 1 or ``image`` has a zero
            dimension.
    """
    if max_side < 1:
        raise ValueError(f"max_side must be >= 1, got {max_side}")
    width, height = image.size
    if width <= 0 or height <= 0:
        raise ValueError(f"cannot resize an image of size {image.size}")
    if width > height:
        new_width = max_side
        new_height = max(1, int(height * max_side / width))
    else:
        new_height = max_side
        new_width = max(1, int(width * max_side / height))
    return image.resize((new_width, new_height), Image.LANCZOS)

def get_response_multi_round(llm, images, instruction_1, other_instructions, sampling_params):
    """Run a batch of multi-turn chats with a vision-language model.

    Each entry of ``images`` starts its own conversation. The first user turn
    contains that conversation's first instruction followed by its image(s)
    as base64 JPEG data URLs. After every round, each model reply is appended
    as an *assistant* turn, and the next prompt from ``other_instructions`` is
    appended as a new user turn. All conversations advance together with one
    ``llm.chat`` call per round.

    Args:
        llm: A vLLM ``LLM`` (anything exposing a compatible ``chat`` method).
        images: One entry per conversation. Each entry is a single PIL image
            or an iterable of PIL images.
        instruction_1: First-turn prompt. Either one string shared by every
            conversation, or a sequence with one string per conversation.
        other_instructions: Follow-up prompts, one per additional round,
            shared by all conversations.
        sampling_params: vLLM ``SamplingParams`` used for every round.

    Returns:
        A list with ``1 + len(other_instructions)`` entries, one per round.
        Each entry is a list of reply texts, one per conversation, in input
        order.

    Raises:
        ValueError: If ``instruction_1`` is a sequence whose length differs
            from the number of conversations.
    """
    images = list(images)
    if isinstance(instruction_1, str):
        instruction_1 = [instruction_1] * len(images)
    elif len(instruction_1) != len(images):
        # zip() below would otherwise silently drop the unmatched conversations.
        raise ValueError(
            f"instruction_1 has {len(instruction_1)} entries but images has {len(images)}"
        )

    conversations = []
    for image, inst in zip(images, instruction_1):
        if isinstance(image, Image.Image):
            image = [image]
        content = [{"type": "text", "text": inst}]
        for img in image:
            base64_image = encode_image_base64(img)
            content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}})
        conversations.append([{"role": "user", "content": content}])

    outputs = llm.chat(messages=conversations, sampling_params=sampling_params)
    all_text = [out.outputs[0].text for out in outputs]
    output_full = [all_text]
    for this_round_instruction in other_instructions:
        for conversation, reply in zip(conversations, all_text):
            # The model's own previous reply is an assistant turn. Sending it as
            # a mid-conversation "system" message misattributes it, and many
            # chat templates reject or misrender a late system turn.
            conversation.append({"role": "assistant", "content": [{"type": "text", "text": reply}]})
            conversation.append({"role": "user", "content": [{"type": "text", "text": this_round_instruction}]})
        outputs = llm.chat(messages=conversations, sampling_params=sampling_params)
        all_text = [out.outputs[0].text for out in outputs]
        output_full.append(all_text)
    return output_full

def _parse_score(text, default=0.0):
    """Return the first number found in a VLM reply (e.g. ``"4"`` or ``"Score: 3.5"``).

    If the reply contains no number, ``default`` is returned. An unparsable
    answer therefore counts as a failing score and triggers a retry instead of
    crashing the loop.
    """
    import re

    match = re.search(r"\d+(?:\.\d+)?", text)
    return float(match.group()) if match else default


def run_inference(
    image_path,
    mask_path,
    garment_path,
    size=(768, 1024),
    num_steps=50,
    guidance_scale=30,
    seed=42,
    pipe=None,
    llm=None,
    vlm_model_name="/mnt/pretrained_model/pixtral_fineturn/",
    max_attempts=3,
    score_threshold=4,
):
    """Generate a virtual try-on image, re-prompting from VLM feedback.

    The garment and the person image are concatenated side by side (garment on
    the left). Only the person half is masked for inpainting with FLUX Fill.
    After each attempt except the last, a VLM compares the garment with the
    cropped try-on result in three rounds: describe, score 0-5, and suggest
    what to fix. If the score reaches ``score_threshold``, the loop stops.
    Otherwise the suggestion is appended to the base prompt for the next
    attempt.

    Args:
        image_path: Path or URL of the person/model image.
        mask_path: Path or URL of the inpainting mask. Only its first channel
            is used.
        garment_path: Path or URL of the garment image.
        size: ``(width, height)`` every input is resized to.
        num_steps: Diffusion inference steps per attempt.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for the CUDA generator. The generator is shared across
            attempts, so each retry draws different noise.
        pipe: Optional preloaded ``FluxFillPipeline``. If None, it is loaded.
        llm: Optional preloaded vLLM ``LLM``. If None, it is loaded from
            ``vlm_model_name`` the first time feedback is needed.
        vlm_model_name: Model path used when ``llm`` is None.
        max_attempts: Maximum number of generation attempts.
        score_threshold: Minimum VLM score that accepts a result.

    Returns:
        The last generated try-on image (the right half of the pipeline
        output).

    Raises:
        ValueError: If ``max_attempts`` is less than 1.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    # Build pipeline
    if pipe is None:
        transformer = FluxTransformer2DModel.from_pretrained(
            "flux_fineturn",
            torch_dtype=torch.bfloat16
        )
        pipe = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            transformer=transformer,
            torch_dtype=torch.bfloat16
        ).to("cuda")
    else:
        pipe.to("cuda")

    pipe.transformer.to(torch.bfloat16)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # single mean/std broadcast over all channels
    ])
    mask_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    image = load_image(image_path).convert("RGB").resize(size)
    mask = load_image(mask_path).convert("RGB").resize(size)
    garment = load_image(garment_path).convert("RGB").resize(size)

    image_tensor = transform(image)
    mask_tensor = mask_transform(mask)[:1]  # keep only the first channel
    garment_tensor = transform(garment)

    # Garment on the left, person on the right (concatenated along width).
    # Only the person half is masked, so the garment panel is preserved.
    inpaint_image = torch.cat([garment_tensor, image_tensor], dim=2)
    garment_mask = torch.zeros_like(mask_tensor)
    extended_mask = torch.cat([garment_mask, mask_tensor], dim=2)

    prompt_prior = "The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; " \
            "[IMAGE1] Detailed product shot of a clothing" \
            "[IMAGE2] The same cloth is worn by a model in a lifestyle setting."

    generator = torch.Generator(device="cuda").manual_seed(seed)
    sampling_params = SamplingParams(max_tokens=16384, temperature=0.70, seed=1024)

    # The first attempt uses the base prompt; later attempts append the VLM's
    # suggested fix.
    prompt = prompt_prior
    tryon_result = None
    for attempt in range(max_attempts):
        result = pipe(
            height=size[1],
            width=size[0] * 2,
            image=inpaint_image,
            mask_image=extended_mask,
            num_inference_steps=num_steps,
            generator=generator,
            max_sequence_length=512,
            guidance_scale=guidance_scale,
            prompt=prompt,
        ).images[0]

        tryon_result = result.crop((size[0], 0, size[0] * 2, size[1]))

        if attempt == max_attempts - 1:
            # No further attempt could use the feedback, so skip the VLM call.
            break

        if llm is None:
            # Build the vLLM engine once. Constructing it on every attempt
            # reloads the weights and re-reserves GPU memory each time.
            llm = LLM(model=vlm_model_name, limit_mm_per_prompt={"image": 10, "video": 10}, max_model_len=32768)

        # Compare the garment against the cropped try-on half, not the whole
        # side-by-side canvas.
        cur_images = [[resize_max_side(garment), resize_max_side(tryon_result)]]
        responses = get_response_multi_round(llm, cur_images, prompts_1, [prompts_2, prompts_3], sampling_params)
        # responses[round][conversation]: 0 = description, 1 = score, 2 = suggested fix.
        score = _parse_score(responses[1][0])
        if score >= score_threshold:
            break
        prompt = prompt_prior + " " + responses[2][0]

    return tryon_result

def main():
    """Command-line entry point: parse arguments, run try-on inference, save the result."""
    parser = argparse.ArgumentParser(description='Run FLUX virtual try-on inference')
    parser.add_argument('--image', required=True, help='Path to the model image')
    parser.add_argument('--mask', required=True, help='Path to the agnostic mask')
    parser.add_argument('--garment', required=True, help='Path to the garment image')
    # NOTE: run_inference() returns only the try-on image, so this option is
    # currently unused. It is kept so existing command lines keep working.
    parser.add_argument('--output_garment', default='flux_inpaint_garment.png', help='Output path for garment result')
    parser.add_argument('--output_tryon', default='flux_inpaint_tryon.png', help='Output path for try-on result')
    parser.add_argument('--steps', type=int, default=50, help='Number of inference steps')
    parser.add_argument('--guidance_scale', type=float, default=30, help='Guidance scale')
    parser.add_argument('--seed', type=int, default=0, help='Random seed')
    parser.add_argument('--width', type=int, default=576, help='Width')
    parser.add_argument('--height', type=int, default=768, help='Height')

    args = parser.parse_args()

    # NOTE(review): FluxFillPipeline presumably needs a newer diffusers than
    # this; on older versions the top-level import fails before this check.
    check_min_version("0.30.2")

    tryon_result = run_inference(
        image_path=args.image,
        mask_path=args.mask,
        garment_path=args.garment,
        num_steps=args.steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        size=(args.width, args.height)
    )
    output_tryon_path = args.output_tryon
    tryon_result.save(output_tryon_path)

    # Only the try-on image is written; the old message also claimed a garment image.
    print(f"Successfully saved try-on image to {output_tryon_path}")


if __name__ == "__main__":
    main()