from ops import FFHQDataset
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
from torchmetrics import JaccardIndex
from tqdm import tqdm
import numpy as np
from ops import LinearOperator
from roomsegmentation import SegmentationModule, ModelBuilder

# Folder of ground-truth label maps (decoded into class indices below).
# Left empty here; fill in before running.
label_folder = ""
# Folder of images that are segmented by the model and scored against the labels.
predi_folder = ""
# Size of the label space: used to decode label pixels and to size the metric.
num_classes = 150
# Evaluation device: first CUDA GPU.
device = torch.device("cuda", 0)

class Segmentation(LinearOperator):
    """Frozen semantic-segmentation network wrapped as a ``LinearOperator``.

    Builds a ``resnet50dilated`` encoder and a ``ppm_deepsup`` decoder via
    ``ModelBuilder`` from the checkpoints in ``./bins``. Every parameter is
    frozen and both networks are put in eval mode, so the operator is a fixed
    function of its input.

    Args:
        device: Device (``torch.device`` or string) the networks are moved to.
        num_class: Number of classes the decoder is built with. Defaults to
            150; presumably must match the loaded checkpoint — TODO confirm.
        seg_size: ``(H, W)`` passed to the decoder as ``segSize``
            (presumably the spatial size of the decoder output). Defaults to
            ``(256, 256)``.
    """

    def __init__(self, device, num_class=150, seg_size=(256, 256)):
        # Use the requested device; this used to be hard-coded to 'cuda'.
        self.encoder = ModelBuilder.build_encoder(
            arch="resnet50dilated",
            fc_dim=2048,
            weights="./bins/encoder_epoch_20_B.pth",
        ).to(device)
        self.decoder = ModelBuilder.build_decoder(
            arch="ppm_deepsup",
            fc_dim=2048,
            num_class=num_class,
            weights="./bins/decoder_epoch_20_B.pth",
            use_softmax=True,
        ).to(device)
        self.seg_size = tuple(seg_size)
        for net in (self.encoder, self.decoder):
            # Inference only: without eval(), any BatchNorm / Dropout layers in
            # the networks would run in training mode (batch statistics,
            # random dropping), making results depend on the batch and on chance.
            net.eval()
            for param in net.parameters():
                param.requires_grad = False
        # Standard ImageNet mean/std normalisation, applied after mapping to [0, 1].
        self.transform = transforms.Compose(
            [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
        )

    def forward(self, data, **kwargs):
        """Segment ``data``.

        Args:
            data: Image batch with values in [-1, 1]; assumed (N, 3, H, W) —
                TODO confirm.
            mode: Required keyword. ``'init'`` returns the per-pixel argmax
                class index; any other value returns the raw decoder output
                (presumably per-class softmax scores, since the decoder is
                built with ``use_softmax=True``).

        Returns:
            ``torch.argmax(pred, dim=1, keepdim=True)`` for ``mode='init'``,
            otherwise the decoder output ``pred``.

        Raises:
            TypeError: if ``mode`` is not supplied.
        """
        # Validate before running the (expensive) network; a real exception
        # is used rather than an assert, which is stripped under ``python -O``.
        if 'mode' not in kwargs:
            raise TypeError("forward() requires a 'mode' keyword argument")
        mode = kwargs['mode']
        data = (data + 1) / 2.0  # [-1, 1] -> [0, 1]
        data = self.transform(data)
        features = self.encoder(data, return_feature_maps=True)
        pred = self.decoder(features, segSize=self.seg_size)
        if mode == 'init':
            # Collapse the class axis to a single index channel.
            return torch.argmax(pred, dim=1, keepdim=True)
        return pred

    def transpose(self, data):
        """Identity: return ``data`` unchanged."""
        return data

model = Segmentation(device=device)
transform = transforms.Compose([transforms.ToTensor()])

label_set = FFHQDataset(label_folder, transform)
predi_set = FFHQDataset(predi_folder, transform)
# Labels and predictions are paired by index, so the two folders must line up.
if len(label_set) != len(predi_set):
    raise ValueError(
        f"label/prediction count mismatch: {len(label_set)} labels vs "
        f"{len(predi_set)} predictions"
    )

# Iterate lazily instead of materialising every image in memory with list().
label_loader = DataLoader(label_set, 1)
predi_loader = DataLoader(predi_set, 1)

# ignore_index=-1 skips label pixels that decode to -1 (e.g. pixel value 0);
# a negative class index would otherwise break the confusion-matrix update.
# NOTE(review): average='micro' pools pixels across classes per image rather
# than averaging per-class IoU; confirm this is the intended "mIoU".
jaccard = JaccardIndex(
    task='multiclass',
    num_classes=num_classes,
    average='micro',
    ignore_index=-1,
).to(device)

mious = []
with torch.no_grad():
    for label_batch, predi_batch in tqdm(
        zip(label_loader, predi_loader), total=len(label_set)
    ):
        # Decode labels from [0, 1] to class indices, keeping the first channel.
        # NOTE(review): .to(torch.int32) truncates, so a float value just below
        # an integer drops by one class; confirm the label encoding / rounding.
        label_img = ((label_batch * num_classes).to(torch.int32) - 1).to(device)[:, 0:1]
        predi_img = predi_batch.to(device)
        predi_seg = model.forward(predi_img, mode='init')
        # torchmetrics expects (preds, target).
        miou = jaccard(predi_seg, label_img)
        mious.append(miou.item())

if not mious:
    raise RuntimeError("no image pairs to evaluate; check label_folder and predi_folder")
print("miou: {0:.4f}".format(np.mean(mious)))
