import torch
from torchvision import transforms
from PIL import Image
import clip
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from diffusers import StableDiffusionPipeline
from transformers import BlipProcessor, BlipForConditionalGeneration
from models.clipcap import ClipCap  
from diffusers import StableDiffusionXLImg2ImgPipeline
from transformers import Blip2Processor, Blip2ForConditionalGeneration
class RAGGenerator:
    def __init__(self, device=None):
        """Load the BLIP-2 image-to-text model and the SDXL refiner img2img pipeline.

        Args:
            device: Torch device (or anything ``torch.device`` accepts, e.g.
                ``"cuda"`` / ``"cpu"``) to run on. Defaults to CUDA when it is
                available, otherwise CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Normalise strings such as "cuda"/"cpu" so `.type` can be inspected below.
        self.device = torch.device(device)
        on_cuda = self.device.type == "cuda"

        print("Loading BLIP model for image-to-text generation...")

        self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        # NOTE(review): BLIP-2 is always loaded in float16, even when self.device
        # is CPU; half precision on CPU may be slow or unsupported for some ops
        # — confirm before running without a GPU.
        self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
        )
        self.blip_model.to(self.device)
        print("BLIP model loaded successfully")

        # diffusers cannot run float16 pipelines on CPU, so use float32 there.
        # The fp16 weight variant is still downloaded and upcast on load.
        sd_dtype = torch.float16 if on_cuda else torch.float32
        self.sd_pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            torch_dtype=sd_dtype,
            variant="fp16",
            use_safetensors=True,
        )

        if on_cuda:
            self.sd_pipeline.enable_attention_slicing()
            # Model CPU offload moves each sub-model to the GPU only while it
            # runs, so the pipeline must not be moved with `.to(...)` first.
            # NOTE(review): called without gpu_id, so a non-default CUDA index
            # in `device` may not be honoured — confirm if that matters.
            self.sd_pipeline.enable_model_cpu_offload()
        else:
            self.sd_pipeline.to(self.device)

        print("StableDiffusion pipeline loaded successfully")
        
    def text_to_image(self, sample_img, text, num_inference_steps=10,
                      guidance_scale=7.5, height=512, width=512):
        """Generate images from one reference image file and a text prompt.

        The reference image is converted to RGB. It is resized to a square when
        it is small (either side < 256 px -> 256x256) or large (both sides
        > 512 px -> 512x512). It is then used as the img2img input.

        Args:
            sample_img: Path (or file object) of the reference image.
            text: Prompt passed to the diffusion pipeline.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            height, width: Requested output size; forwarded unchanged to
                ``_generate_img``.

        Returns:
            The value returned by ``_generate_img`` for the prepared image.
        """
        # Read the pixels and release the file handle immediately.
        with Image.open(sample_img) as src:
            img = src.convert("RGB")

        # Distinct names: the caller's `width`/`height` must not be clobbered
        # by the reference image's own dimensions.
        img_width, img_height = img.size
        target = self._square_resize_target(img_width, img_height)
        if target is not None:
            # Image.LANCZOS is a higher-quality (slower) alternative.
            img = img.resize(target, Image.BICUBIC)

        return self._generate_img(
            text,
            img,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )

    @staticmethod
    def _square_resize_target(img_width, img_height, min_side=256, max_side=512):
        """Return the ``(side, side)`` size to resize to, or ``None`` to keep the size.

        Images with either side below ``min_side`` are upscaled to a
        ``min_side`` square. Images with *both* sides above ``max_side`` are
        downscaled to a ``max_side`` square. Everything else is left alone.
        """
        if img_width < min_side or img_height < min_side:
            return (min_side, min_side)
        if img_width > max_side and img_height > max_side:
            return (max_side, max_side)
        return None

    def text_to_image_enhance(self, text, num_inference_steps=10,
                              guidance_scale=7.5, height=512, width=512):
        """Generate images from a text prompt alone via ``_generate_img_enhance``.

        The prompt goes positionally and the generation settings go as keyword
        arguments. The helper's result is returned unchanged.

        NOTE(review): ``_generate_img_enhance`` is not defined in the visible
        part of this class — confirm it exists elsewhere, otherwise calling
        this method raises ``AttributeError``.
        """
        options = {
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "height": height,
            "width": width,
        }
        return self._generate_img_enhance(text, **options)
    def text_to_image_3(self, sample_img_1, sample_img_2, sample_img_3, text,
                        num_inference_steps=10, guidance_scale=7.5,
                        height=512, width=512):
        """Generate images from three reference images tiled side by side.

        Each reference image is opened, converted to RGB and resized to
        256x256. The tiles are pasted left to right into one 768x256 RGB
        strip, which is then used as the img2img input for ``text``.

        Args:
            sample_img_1, sample_img_2, sample_img_3: Paths (or file objects)
                of the reference images, in left-to-right order.
            text: Prompt passed to the diffusion pipeline.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            height, width: Requested output size; forwarded unchanged to
                ``_generate_img``.

        Returns:
            The value returned by ``_generate_img`` for the tiled strip.
        """
        tile = 256
        sources = (sample_img_1, sample_img_2, sample_img_3)
        concat_img = Image.new('RGB', (tile * len(sources), tile))  # horizontal strip
        for i, source in enumerate(sources):
            # Close each file as soon as it is read. Normalise to RGB before
            # resizing so every tile is resampled in the canvas's mode.
            with Image.open(source) as src:
                tile_img = src.convert("RGB").resize((tile, tile))
            concat_img.paste(tile_img, (i * tile, 0))

        return self._generate_img(
            text,
            concat_img,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )
    
    def image_to_text(self, image, sample_txt):
        """Generate text for ``image`` conditioned on the prompt ``sample_txt``.

        Public entry point that delegates to ``_generate_text`` with both
        arguments in the same order and returns its result unchanged.
        """
        return self._generate_text(image, sample_txt)
    

    def _generate_text(self, image_path, sample_txt):
        """Run BLIP-2 generation on one image with a text prompt.

        Args:
            image_path: Path (or file object) of the input image.
            sample_txt: Text prompt passed to the BLIP-2 processor together
                with the image.

        Returns:
            The first decoded sequence with special tokens removed and
            surrounding whitespace stripped. Sampling is enabled
            (``do_sample=True``), so repeated calls can return different text.
        """
        # Read the pixels as RGB and release the file handle immediately.
        with Image.open(image_path) as src:
            raw_image = src.convert("RGB")

        # Cast floating-point inputs to the model's own dtype rather than a
        # hard-coded float16, so inputs always match the loaded weights.
        inputs = self.blip_processor(
            images=raw_image, text=sample_txt, return_tensors="pt"
        ).to(self.device, self.blip_model.dtype)

        generated_ids = self.blip_model.generate(
            **inputs,
            do_sample=True,
            temperature=1.0,
            top_p=0.9,
            max_new_tokens=50,
        )
        return self.blip_processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()

        
    def _generate_img(self, prompt_text, sample_img, num_inference_steps=20,
                      guidance_scale=7.5, height=256, width=256):
        """Forward a prompt and reference image to ``_generate_with_text_prompt``.

        Every argument is passed through positionally and unchanged, in
        declaration order. The delegate's return value is returned as-is.
        """
        forwarded = (
            prompt_text,
            sample_img,
            num_inference_steps,
            guidance_scale,
            height,
            width,
        )
        return self._generate_with_text_prompt(*forwarded)
   
    def _generate_with_text_prompt(self, prompt_text, sample_img, num_inference_steps,
                                   guidance_scale, height, width, *,
                                   strength=0.75, num_images_per_prompt=5):
        """Run the img2img pipeline on ``sample_img`` guided by ``prompt_text``.

        Args:
            prompt_text: Text prompt for the pipeline.
            sample_img: Reference image used as the img2img input.
            num_inference_steps: Number of denoising steps passed to the pipeline.
            guidance_scale: Guidance scale passed to the pipeline.
            height, width: Accepted for call-site compatibility only; they are
                not forwarded to the pipeline.
            strength: Img2img ``strength`` passed to the pipeline (default 0.75).
            num_images_per_prompt: Number of images requested for the prompt
                (default 5).

        Returns:
            The ``.images`` attribute of the pipeline output.
        """
        with torch.no_grad():
            # Honour the caller's settings; these were previously hard-coded.
            output = self.sd_pipeline(
                prompt=prompt_text,
                image=sample_img,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                num_images_per_prompt=num_images_per_prompt,
            )
        return output.images
    