
import os

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from tqdm import tqdm, trange

from helper.extractor import VitExtractor

class DINO:
    """Compare image pairs using DINO ViT features from ``VitExtractor``.

    A pair's score is the MSE between the outputs of
    ``VitExtractor.get_keys_self_sim_from_input`` for the two images
    (presumably key self-similarity matrices, per the method name -- TODO
    confirm in helper.extractor). Lower scores mean closer features.
    """

    def __init__(self, dino_id='dino_vitb8', device='cuda'):
        """Build the preprocessing pipeline and load the DINO extractor.

        Args:
            dino_id: Model name forwarded to ``VitExtractor``.
            device: Torch device string; images are moved to this device and
                it is also forwarded to ``VitExtractor``.
        """
        self.device = torch.device(device)
        # Resize to 224x224 (interpolation=3 is PIL bicubic) and apply the
        # standard ImageNet mean/std normalization.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224), interpolation=3),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        print('[INFO] Loading DINO model:', dino_id)
        self.extractor = VitExtractor(model_name=dino_id, device=device)
        print('[INFO] DINO model loaded')

    def load_img(self, img_path):
        """Load ``img_path`` and return it as a preprocessed 3-channel tensor.

        The image is converted to RGB first. Otherwise RGBA, grayscale or
        palette PNGs would not match the 3-channel normalization. The file
        handle is closed once the pixels have been read.
        """
        with Image.open(img_path) as img:
            return self.transform(img.convert('RGB'))

    @staticmethod
    def _collect_pairs(data_dir, num_samples=300):
        """Return (baseline, final) paths for samples 0..num_samples-1 having both files."""
        pairs = []
        for i in range(num_samples):
            sample_dir = os.path.join(data_dir, str(i))
            base_path = os.path.join(sample_dir, 'baseline.png')
            final_path = os.path.join(sample_dir, 'final.png')
            # Samples missing either image are silently skipped.
            if os.path.exists(base_path) and os.path.exists(final_path):
                pairs.append((base_path, final_path))
        return pairs

    def _mean_key_mse(self, pairs, layer_num):
        """Return the mean extractor-output MSE over a non-empty list of path pairs."""
        get_keys = self.extractor.get_keys_self_sim_from_input
        ssim_loss = 0.0
        # Load one pair at a time so only two images live on the device at once.
        for base_path, final_path in tqdm(pairs):
            src_img = self.load_img(base_path).to(self.device)
            tgt_img = self.load_img(final_path).to(self.device)
            with torch.no_grad():
                src_keys = get_keys(src_img.unsqueeze(0), layer_num=layer_num)
                tgt_keys = get_keys(tgt_img.unsqueeze(0), layer_num=layer_num)
                ssim_loss += F.mse_loss(src_keys, tgt_keys)
        # ssim_loss is a tensor here because pairs is non-empty.
        return (ssim_loss / len(pairs)).item()

    def compute_ssim_loss(self, data_dir, num_samples=300, layer_num=11):
        """Score every baseline/final pair under ``data_dir``, print and return the mean.

        Looks for ``<data_dir>/<i>/baseline.png`` and ``<data_dir>/<i>/final.png``
        for ``i`` in ``range(num_samples)``. "ssim" here appears to mean the
        extractor's key self-similarity, not the SSIM image metric.

        Args:
            data_dir: Directory holding numbered sample subdirectories.
            num_samples: Number of sample indices to probe (default 300).
            layer_num: ViT layer passed to the extractor (default 11).

        Returns:
            The mean MSE as a float, or NaN when no complete pair exists.
        """
        pairs = self._collect_pairs(data_dir, num_samples)
        # Avoid a ZeroDivisionError (and an aborted sweep) for empty/missing dirs.
        dino_score = self._mean_key_mse(pairs, layer_num) if pairs else float('nan')
        print(f'data_dir {data_dir} | N {len(pairs)} | dino {dino_score:.4f}')
        return dino_score

# Driver: for every tok setting, score the quantized output directory and
# then the binarized one, in that order.
dino = DINO()
tok_values = (32, 64, 128, 256, 512)
for tok in tok_values:
    for data_dir in (f'../gen/images_quant/tok{tok}', f'../gen/images_bin/tok{tok}'):
        dino.compute_ssim_loss(data_dir)