
import os
import torch
from tqdm import trange

from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel

class CLIP:
    """Score prompt/image agreement with a pretrained Hugging Face CLIP model.

    The instance holds the CLIP processor and model (in eval mode, on
    ``device``) plus a 224x224 bicubic resize that is applied to every image
    before it is handed to the processor.
    """

    def __init__(self, clip_id='openai/clip-vit-large-patch14', device='cuda'):
        """Load the processor and model identified by ``clip_id``.

        Args:
            clip_id: Hugging Face model id (or local path) of the checkpoint.
            device: Torch device string the model is moved to.
        """
        self.device = torch.device(device)
        print('[INFO] Loading CLIP model:', clip_id)
        self.processor = CLIPProcessor.from_pretrained(clip_id)
        # BICUBIC is the named equivalent of the legacy integer code 3;
        # torchvision deprecates passing the raw integer.
        self.resize = transforms.Resize(
            (224, 224), interpolation=transforms.InterpolationMode.BICUBIC
        )
        self.model = CLIPModel.from_pretrained(clip_id).to(self.device)
        self.model.eval()
        print('[INFO] CLIP model loaded')

    @staticmethod
    def _read_prompts(prompt_path):
        """Return the prompts stored one per line in ``prompt_path``.

        Trailing blank lines are dropped so that a final newline in the file
        does not add a phantom empty prompt (which would shift the modulo
        wrap-around used to pair prompts with sample indices).

        Raises:
            ValueError: if the file contains no prompt at all.
        """
        with open(prompt_path, 'r', encoding='utf-8') as f:
            prompts = [line.rstrip('\n') for line in f]
        while prompts and not prompts[-1].strip():
            prompts.pop()
        if not prompts:
            raise ValueError(f'no prompts found in {prompt_path!r}')
        return prompts

    def _load_resized(self, path):
        """Open the image at ``path``, resize it, and close the file handle."""
        with Image.open(path) as img:
            # Resizing yields a new in-memory image, so the source file can
            # be closed as soon as we leave this block.
            return self.resize(img)

    def _score_pairs(self, prompts, images, batch_size):
        """Return a 1-D CPU tensor with one CLIP logit per (prompt, image) pair.

        Pairs are processed in mini-batches so memory stays bounded; within a
        batch only the diagonal of ``logits_per_text`` is kept, i.e. each
        prompt scored against the image at the same index. ``prompts`` must
        be non-empty.
        """
        scores = []
        with torch.no_grad():
            for start in range(0, len(prompts), batch_size):
                stop = start + batch_size
                inputs = self.processor(
                    text=prompts[start:stop],
                    images=images[start:stop],
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                )
                output = self.model(
                    **{k: v.to(self.device) for k, v in inputs.items()}
                )
                scores.append(torch.diag(output.logits_per_text).detach().cpu())
        return torch.cat(scores)

    def compute_clip_scores(self, data_dir, prompt_path='../data/prompt.txt',
                            num_samples=300, batch_size=64):
        """Print mean CLIP scores of baseline vs. final images in ``data_dir``.

        For every index ``i`` in ``range(num_samples)`` the method looks for
        ``<data_dir>/<i>/baseline.png`` and ``<data_dir>/<i>/final.png``.
        Indices where either file is missing are skipped. Sample ``i`` is
        paired with prompt ``i % number_of_prompts``.

        Args:
            data_dir: Directory containing one numbered sub-directory per sample.
            prompt_path: Text file with one prompt per line.
            num_samples: Number of sample indices to scan.
            batch_size: Number of pairs scored per forward pass.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1 or the prompt file
                holds no prompts.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be at least 1')
        base_prompts = self._read_prompts(prompt_path)

        prompts, base_images, final_images = [], [], []
        for i in trange(num_samples):
            sample_dir = os.path.join(data_dir, str(i))
            base_path = os.path.join(sample_dir, 'baseline.png')
            final_path = os.path.join(sample_dir, 'final.png')
            # Only score samples for which both variants were generated.
            if not (os.path.exists(base_path) and os.path.exists(final_path)):
                continue
            prompts.append(base_prompts[i % len(base_prompts)])
            base_images.append(self._load_resized(base_path))
            final_images.append(self._load_resized(final_path))

        if not prompts:
            # Nothing to score; skip instead of crashing inside the processor.
            print(f'[WARN] No baseline/final image pairs found in {data_dir}')
            return

        base_scores = self._score_pairs(prompts, base_images, batch_size)
        final_scores = self._score_pairs(prompts, final_images, batch_size)
        print(f'base {base_scores.mean().item()} (N = {len(base_images)})')
        print(f'final {final_scores.mean().item()} (N = {len(final_images)})')
        print(f'clip diff {(final_scores.mean().item() - base_scores.mean().item()):.2f}')

# Evaluate both generation variants (quantized and binary) at every token
# budget, reusing a single loaded CLIP model for all runs.
clip = CLIP()
for tok in (32, 64, 128, 256, 512):
    quant_dir = f'../gen/images_quant/tok{tok}'
    bin_dir = f'../gen/images_bin/tok{tok}'
    clip.compute_clip_scores(quant_dir)
    clip.compute_clip_scores(bin_dir)
