import torch
import torchvision
from diffusers import StableUnCLIPImg2ImgPipeline, StableDiffusionImg2ImgPipeline
from PIL import Image
from transformers import CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPTokenizer, AutoProcessor

# Preprocessing applied to a PIL image before it is handed to an encoder:
# resize, convert to a float tensor, then normalize each channel.
# Requires `torchvision` to be imported at the top of the module.
PREPROCESS = torchvision.transforms.Compose([
    # An int `size` rescales the shorter edge to 224 and keeps the aspect
    # ratio, so non-square inputs stay non-square (no crop is applied here).
    torchvision.transforms.Resize(size=224),
    torchvision.transforms.ToTensor(),
    # NOTE(review): these mean/std values look like the commonly quoted
    # CIFAR-10 channel statistics rather than CLIP's own normalization
    # constants -- confirm this is intended for the images fed to the encoder.
    torchvision.transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616],
    ),
])

class LargeMultiModalModel:
    """CLIP encoders and Stable Diffusion pipelines behind one small interface.

    The instance holds a StableUnCLIP img2img pipeline, a Stable Diffusion
    v1.5 img2img pipeline, and the CLIP ViT-H/14 vision encoder, text encoder,
    tokenizer and processor.  It offers helpers to embed text and images,
    decode an image embedding back into an image, and compare image and text
    embeddings by cosine similarity.
    """

    def __init__(self, device):
        """Load all models and move them to *device*.

        Args:
            device: Target passed to ``.to(...)`` for every model and used
                later when moving inputs.
        """
        self.device = device
        clip_repo = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"

        # Load the StableUnCLIP Image-to-Image Pipeline
        self.unclip_pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1-unclip",
            torch_dtype=torch.float16
        ).to(device)

        # Load the Stable Diffusion v1.5 Pipeline
        self.sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch.float16
        ).to(device)

        # Load CLIP model components.  The vision encoder is loaded in
        # float16; the text encoder keeps the checkpoint's default dtype.
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            clip_repo,
            torch_dtype=torch.float16
        ).to(device)

        self.processor = AutoProcessor.from_pretrained(clip_repo)

        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
            clip_repo
        ).to(device)

        self.tokenizer = CLIPTokenizer.from_pretrained(clip_repo)

        # Enable memory efficient attention in the UnCLIP pipeline
        self.unclip_pipe.enable_xformers_memory_efficient_attention()

    @staticmethod
    def _require_inputs(image_side, text_side):
        """Raise ``ValueError`` unless each side has at least one non-None value."""
        if all(item is None for item in image_side):
            raise ValueError("Provide either an image or an image embedding.")
        if all(item is None for item in text_side):
            raise ValueError("Provide either text or a text embedding.")

    def encode_text(self, prompt):
        """Embed a single prompt with the CLIP text encoder.

        Args:
            prompt: The text to embed; it is tokenized as a batch of one.

        Returns:
            The encoder's ``text_embeds`` tensor, moved to the CPU.
        """
        inputs = self.tokenizer([prompt], padding=True, return_tensors="pt").to(self.device)
        return self.text_encoder(**inputs).text_embeds.cpu()

    def encode_images(self, images):
        """Embed a batch of images with the CLIP vision encoder.

        Args:
            images: Tensor passed directly to the encoder after casting to
                half precision (presumably already preprocessed pixel values
                with a leading batch axis -- confirm against callers).

        Returns:
            The encoder's ``image_embeds`` as a float32 CPU tensor.
        """
        return self.image_encoder(images.half().to(self.device)).image_embeds.cpu().float()

    def encode_one_image(self, image):
        """Embed one image tensor that has no batch axis.

        A batch axis is added before encoding and size-1 axes are squeezed
        out of the result.

        Returns:
            The squeezed ``image_embeds`` as a float32 CPU tensor.
        """
        return self.image_encoder(image.unsqueeze(0).half().to(self.device)).image_embeds.cpu().float().squeeze()

    def decode_one_embed(self, embed, n_steps=20, noise_level=0, prompt="", negative_prompt="", guidance_scale=10.0):
        """Generate an image from one image embedding with the UnCLIP pipeline.

        Args:
            embed: Image embedding without a batch axis; one is added here.
            n_steps: Forwarded as ``num_inference_steps``.
            noise_level: Forwarded as ``noise_level``.
            prompt: Forwarded as ``prompt``.
            negative_prompt: Forwarded as ``negative_prompt``.
            guidance_scale: Forwarded as ``guidance_scale``.

        Returns:
            The first entry of the pipeline output's ``images``.
        """
        return self.unclip_pipe(
            image_embeds=embed.unsqueeze(0).half().to(self.device),
            num_inference_steps=n_steps,
            noise_level=noise_level,
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale
        ).images[0]

    def classification_score(self, image=None, image_embed=None, text=None, text_embed=None):
        """Return the cosine similarity between an image and a text.

        Each side may be given raw (``image`` / ``text``) or as a precomputed
        embedding; a raw input takes precedence and is encoded here.

        Raises:
            ValueError: If the image side or the text side is missing entirely.
        """
        self._require_inputs((image, image_embed), (text, text_embed))
        if text is not None:
            text_embed = self.encode_text(text)
        if image is not None:
            image_embed = self.encode_one_image(image)
        # Reduce over the last (feature) axis so that 1-D embeddings work as
        # well as batched ones; the default dim=1 fails when both are 1-D.
        return torch.cosine_similarity(image_embed, text_embed, dim=-1)

    def classification(self, image=None, image_embed=None, texts=None, text_embeds=None):
        """Return the index of the text most similar to the image.

        Args:
            image: Image tensor to encode with ``encode_one_image``.
            image_embed: Precomputed image embedding, used when ``image`` is
                not given.
            texts: Iterable of prompts; each is encoded separately and the
                results are concatenated along the first axis.
            text_embeds: Precomputed text embeddings, used when ``texts`` is
                not given.

        Returns:
            Tensor holding the argmax of the cosine similarities over the
            last axis.

        Raises:
            ValueError: If the image side or the text side is missing entirely.
        """
        self._require_inputs((image, image_embed), (texts, text_embeds))
        if texts is not None:
            text_embeds = torch.cat([self.encode_text(text) for text in texts])
        if image is not None:
            image_embed = self.encode_one_image(image)
        return torch.cosine_similarity(image_embed, text_embeds, dim=-1).argmax(-1)

# Usage example
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lmm = LargeMultiModalModel(device)

    # Example text encoding
    text_embed = lmm.encode_text("A photo of a futuristic city at sunset.")

    # Example image encoding.  The image is passed by keyword so it cannot be
    # mistaken for text, and the processor returns its tensor under
    # `pixel_values` with a leading batch axis.
    image = Image.open("path_to_image.jpg").convert("RGB")
    processed_image = lmm.processor(images=image, return_tensors="pt")
    # encode_one_image adds the batch axis itself, so hand it a single image.
    image_embed = lmm.encode_one_image(processed_image.pixel_values[0])

    # Decode an embedding
    generated_image = lmm.decode_one_embed(image_embed)
