"""Evaluate the ``Layout`` model's label predictions with a mean-IoU metric.

For every index ``i``, the ``i``-th sample of ``image_folder`` is normalized
with mean/std 0.5 per channel, perturbed with Gaussian noise of std
``noise_std``, passed through ``Layout.forward(..., mode='init')`` and compared
with the ``i``-th label map of ``label_folder`` using a macro-averaged
multiclass Jaccard index (IoU). The per-image scores are averaged and printed.
"""
import os

import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader
from torchmetrics import JaccardIndex
from torchvision import transforms
from tqdm import tqdm

from ops import FFHQDataset
from ops import Layout

device = torch.device('cuda:0')

label_folder = ''  # TODO: set to the folder of ground-truth label maps
image_folder = ''  # TODO: set to the folder of input images
num_classes = 5
noise_std = 0.2  # std of the Gaussian noise added to the normalized inputs
seed = 0  # fixes the noise so repeated runs report the same score

model = Layout(weight_path='./bins/model_retrained.ckpt', device=device)
# Labels are only converted to tensors; inputs are additionally normalized
# with mean 0.5 / std 0.5 per channel.
transform1 = transforms.Compose([transforms.ToTensor()])
transform2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

label_set = FFHQDataset(label_folder, transform1)
image_set = FFHQDataset(image_folder, transform2)
# Samples are paired by index, so every label needs a matching image.
if len(image_set) < len(label_set):
    raise ValueError(
        f"image_folder has {len(image_set)} samples but label_folder has "
        f"{len(label_set)}; every label needs a matching image"
    )

# Stream samples one at a time instead of materializing both datasets
# in memory with list(DataLoader(...)).
label_loader = DataLoader(label_set, batch_size=1)
image_loader = DataLoader(image_set, batch_size=1)

jaccard = JaccardIndex(task='multiclass', num_classes=num_classes, average='macro')
torch.manual_seed(seed)  # makes the randn_like noise below reproducible
mious = []
with torch.no_grad():  # evaluation only; no autograd graph is needed
    # zip stops after the last label (extra images are ignored), like the
    # original index loop over the label set.
    pairs = zip(label_loader, image_loader)
    for label_batch, image_batch in tqdm(pairs, total=len(label_loader)):
        # Decode the label image back to class indices: scale by num_classes,
        # truncate to int, shift by -1, keep channel 0 only.
        # NOTE(review): assumes label pixels encode class k as
        # (k + 1) / num_classes after ToTensor — confirm; pixels that decode
        # to -1 or >= num_classes are out of range for JaccardIndex.
        label_img = ((label_batch * num_classes).to(torch.int32) - 1)[:, 0:1]
        predi_img = image_batch.to(device)
        noisy_img = predi_img + torch.randn_like(predi_img) * noise_std
        predi_label = model.forward(noisy_img, mode='init').to(torch.int32)
        # torchmetrics expects (preds, target). Assumes predi_label has the
        # same shape as label_img — TODO confirm against Layout.forward.
        miou = jaccard(predi_label.cpu(), label_img)
        mious.append(miou.item())

print("miou: {0:.4f}".format(np.mean(mious)))
