from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import os
import json
import glob

class QwenFeedback:
    """Ask a Qwen2.5-VL model for a negative prompt describing an image's flaws.

    The model is shown two images plus the positive prompt that produced them
    and is instructed to answer with a comma-separated keyword list.
    """

    # Hub id used as the default checkpoint and as the processor fallback.
    DEFAULT_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"

    def __init__(self, device, path=DEFAULT_MODEL, greedy=False):
        """Load the model and its processor.

        Args:
            device: Passed to ``from_pretrained`` as ``device_map`` and later
                used as the target of ``inputs.to(...)``.
            path: Checkpoint name or directory to load the model from.
            greedy: When true, generation is run with ``do_sample=False``;
                when false, the checkpoint's own generation settings apply.
        """
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.bfloat16, device_map=device
        )
        self.model.eval()
        # Load the processor from the same checkpoint as the model so the two
        # cannot drift apart; checkpoints saved without processor files fall
        # back to the stock processor (the only one used previously).
        try:
            self.processor = AutoProcessor.from_pretrained(path)
        except OSError:
            self.processor = AutoProcessor.from_pretrained(self.DEFAULT_MODEL)
        self.device = device
        self.greedy = greedy

    def build_message(self, filepath, filepath2, prompt):
        """Build the chat messages (system + user turn) for one image pair.

        Args:
            filepath: Image placed first in the user turn ("Image 1" in the
                instructions).
            filepath2: Image placed second in the user turn ("Image 2", the
                one the instructions ask the model to analyze).
            prompt: Positive prompt text interpolated into the instructions.

        Returns:
            A two-element list of message dicts with ``role`` and ``content``.
        """
        # NOTE: both prompt strings are sent to the model verbatim, including
        # the leading indentation inside the triple-quoted block.
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": 
                "You are an expert multimodal AI assistant specializing in analyzing image generation quality. Your task is to meticulously analyze a provided output image (Image 2), potentially referencing an intermediate image (Image 1) and the original positive prompt. Your goal is to identify specific visual flaws that degrade **perceptual quality** and **realism** (e.g., blur, noise, artifacts, anatomical errors, unnatural textures/lighting). Based on these identified quality-impacting visual issues, you will generate a targeted negative prompt (a comma-separated list of keywords) designed to **significantly improve** the quality of future generations, while also including standard high-impact quality and safety terms."
                }],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": filepath,
                    },
                    {
                        "type": "image",
                        "image": filepath2,
                    },
                    {
                        "type": "text",
                        "text": f"""**Positive Prompt Used:** {prompt}, ultra realistic, concept art, intricate details

                        **Task:** Analyze **Image 2** (the flawed output) against the positive prompt's goal and general quality standards. Identify specific visual flaws degrading its quality/realism. Based *only* on these observed flaws, generate a comma-separated negative prompt list.

                        **Guidelines:**

                        1.  **Identify Key Flaws & Keywords:** Look for issues like:
                            * **Clarity/Detail:** `blurry, noisy, jpeg artifacts, low detail, unclear`
                            * **Anatomy/Structure:** `bad anatomy, deformed, disfigured, bad hands, extra limbs, distorted face, malformed`
                            * **Texture/Material:** `plastic look, unrealistic texture, smooth skin, glossy`
                            * **Lighting/Color:** `flat lighting, harsh lighting, unnatural colors, oversaturated, bad colors`
                            * **Artifacts/Style:** `text, watermark, signature, artifacts, glitches, cropped, bad composition, wrong style (cartoon, drawing, illustration, 3d render)`
                        2.  **Be Specific:** Use precise keywords targeting the *actual problems* you see in Image 2.
                        3.  **Prioritize:** Focus on flaws most harmful to realism and detail.
                        4.  **Add Standard Terms:** **Always** append this core list: `low quality, worst quality, bad quality, normal quality, lowres, ugly, nsfw, text, signature, watermark`
                        5.  **Output Format:** Provide **only** the comma-separated keywords. No explanations.

                        **Example (If Image 2 is blurry, has a deformed hand, and a watermark):**
                        blurry, unclear, bad hands, deformed limbs, bad anatomy, deformed, watermark, low quality, worst quality, bad quality, normal quality, lowres, ugly, nsfw, text, signature
                        """
            }
                ],
            }
        ]
        return messages

    def evaluate_image(self, filepath, filepath2, prompt, max_new_tokens=128):
        """Run the model on one image pair and return its decoded reply.

        Args:
            filepath: First image (see ``build_message``).
            filepath2: Second image (see ``build_message``).
            prompt: Positive prompt used to generate the images.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The generated text with the prompt tokens and special tokens
            removed.
        """
        messages = self.build_message(filepath, filepath2, prompt)
        chat_text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[chat_text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)

        generation_kwargs = {"max_new_tokens": max_new_tokens}
        if self.greedy:
            # Honor the constructor flag: disable sampling explicitly.
            generation_kwargs["do_sample"] = False
        generated_ids = self.model.generate(**inputs, **generation_kwargs)

        # generate() returns prompt + continuation; keep only the continuation.
        continuation_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        decoded = self.processor.batch_decode(
            continuation_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return decoded[0]


def main(
    prompt_dir="/path/to/prompt",
    fused_dir="/path/to/output/cofusion",
    pred_dir="/path/to/output/pred",
    output_path="/path/to/reflection.json",
    num_samples=5000,
    vlm_path="Qwen/Qwen2.5-VL-7B-Instruct",
):
    """Collect model feedback for a numbered set of samples and save it as JSON.

    For each index ``i`` in ``range(num_samples)`` this reads the ``caption``
    field of ``<prompt_dir>/<i:05d>.json``, picks an image matching
    ``<fused_dir>/<i:05d>*.jpg``, pairs it with the same-named file in
    ``pred_dir`` and asks ``QwenFeedback`` for a reply. The replies are written
    to ``output_path`` as a list of single-key objects ``{"<position>": reply}``.

    Raises:
        FileNotFoundError: If no image in ``fused_dir`` matches a sample index.
    """
    feedback = QwenFeedback('cuda', vlm_path, greedy=False)
    feedback_texts = []
    for i in range(num_samples):
        base_name = f"{i:05d}"
        json_path = os.path.join(prompt_dir, f"{base_name}.json")
        with open(json_path, 'r', encoding='utf-8') as f:
            prompt = json.load(f)['caption']

        file_pattern = os.path.join(fused_dir, f"{base_name}*.jpg")
        matching_files = glob.glob(file_pattern)
        if not matching_files:
            # Fail with the offending pattern instead of a bare IndexError.
            raise FileNotFoundError(f"No image matches pattern: {file_pattern}")
        # glob does not guarantee an order; the first match is used as-is.
        image_path2 = matching_files[0]
        # The companion image shares the file name but lives in pred_dir.
        image_path1 = os.path.join(pred_dir, os.path.basename(image_path2))

        feedback_texts.append(
            feedback.evaluate_image(image_path1, image_path2, prompt)
        )

    result = [{str(i): text} for i, text in enumerate(feedback_texts)]
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    main()