from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
import torch
import torch.nn.functional as F
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms
# Pick the compute device once, *before* anything is moved onto it, so the
# CPU fallback is actually honoured on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained CLIP model together with its tokenizer and processor.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model.eval()  # inference only

# Tensor-side preprocessing: resize to CLIP's input size and normalize with
# the OpenAI CLIP per-channel mean/std.
clip_preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # size required by CLIP
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711]),
])

def compute_CLIP_score(image, prompt):
    """Return the CLIP cosine similarity between an image and a text prompt.

    Args:
        image: Either a ``torch.Tensor`` or a PIL image.
            A tensor is clamped to ``[0, 1]``, cast to float and passed
            through ``clip_preprocess`` (resize + CLIP normalization); it is
            fed to the model directly and is *not* run through
            ``clip_processor`` again, which would rescale/normalize it a
            second time. Assumes a ``(3, H, W)`` or ``(N, 3, H, W)`` layout
            -- TODO confirm against callers.
            A PIL image is converted to RGB and preprocessed by
            ``clip_processor``.
        prompt: Text to compare against (a string, or a list of strings as
            accepted by ``clip_processor``). Prompts longer than the
            tokenizer's maximum length are truncated instead of crashing the
            text encoder.

    Returns:
        The mean cosine similarity. For tensor input this is a 0-dim tensor
        (computed under ``torch.no_grad()``, so it carries no gradient); for
        PIL input it is a Python ``float``.
    """
    is_tensor_input = isinstance(image, torch.Tensor)

    if is_tensor_input:
        pixels = image.clamp(0, 1).float()
        if pixels.dim() == 3:
            # Single unbatched image: add the batch dimension the model expects.
            pixels = pixels.unsqueeze(0)
        pixel_values = clip_preprocess(pixels).to(device)
    else:
        image_inputs = clip_processor(images=image.convert("RGB"), return_tensors="pt")
        pixel_values = image_inputs.pixel_values.to(device)

    text_inputs = clip_processor(
        text=prompt, return_tensors="pt", padding=True, truncation=True
    ).to(device)

    with torch.no_grad():
        image_features = clip_model.get_image_features(pixel_values=pixel_values)
        text_features = clip_model.get_text_features(
            input_ids=text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask,
        )
        # cosine_similarity normalizes both operands itself, so no explicit
        # F.normalize pass is needed beforehand.
        clip_score = F.cosine_similarity(image_features, text_features).mean()

    return clip_score if is_tensor_input else clip_score.item()



def preprocess_for_clip(decoded, target_size=224):
    """Resize a batch of images and normalize it with CLIP's channel statistics.

    Args:
        decoded: Image batch passed to ``F.interpolate``; the statistics are
            broadcast as ``(1, 3, 1, 1)``, so a 3-channel ``(N, 3, H, W)``
            layout is expected.
        target_size: Output spatial size handed to ``F.interpolate``.

    Returns:
        The bilinearly resized batch with the OpenAI CLIP mean subtracted and
        divided by the OpenAI CLIP std, per channel.
    """
    # OpenAI CLIP per-channel normalization statistics.
    clip_mean = [0.48145466, 0.4578275, 0.40821073]
    clip_std = [0.26862954, 0.26130258, 0.27577711]

    def _as_channel_stat(values):
        # Shape the statistic so it broadcasts over batch and spatial dims.
        return torch.tensor(values, device=decoded.device).view(1, 3, 1, 1)

    resized = F.interpolate(
        decoded,
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )
    return (resized - _as_channel_stat(clip_mean)) / _as_channel_stat(clip_std)

def get_clip_features_folder(folder):
    """Compute CLIP image features for every image file in a folder.

    Files are visited in sorted filename order; only names ending in
    ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are used.

    Args:
        folder: Path of the directory to scan (not recursive).

    Returns:
        A CPU tensor holding the ``get_image_features`` output of each image,
        concatenated along dim 0 in sorted filename order. No normalization
        is applied to the features here.

    Raises:
        RuntimeError: If the folder contains no matching image files.
    """
    image_names = [
        fname
        for fname in sorted(os.listdir(folder))
        if fname.lower().endswith((".png", ".jpg", ".jpeg"))
    ]
    if not image_names:
        # torch.cat([]) would fail anyway, but with an opaque message; keep
        # the exception type and say what actually went wrong.
        raise RuntimeError(f"No .png/.jpg/.jpeg images found in {folder!r}")

    features = []
    for fname in tqdm(image_names):
        # Release the file handle as soon as the pixel data has been copied.
        with Image.open(os.path.join(folder, fname)) as img:
            image = img.convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            feat = clip_model.get_image_features(**inputs)
        features.append(feat.cpu())
    return torch.cat(features)

