import os
import sys

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.utils.data as data
from sklearn.metrics import roc_auc_score

from model import *
from data import *

# Tag interpolated into the output sample filename ('./test/mask_samples_{}.png').
# NOTE(review): `cls` is not defined in this file; presumably it is provided by
# one of the star imports (`from model import *` / `from data import *`) — confirm.
ki = cls

def segmentation_auroc(mask, anomaly_maps):
    """Return the pixel-level ROC AUC of anomaly scores against a mask.

    Both inputs are flattened, so any shape works as long as they hold the
    same number of elements in matching order.

    Args:
        mask: NumPy array of ground-truth labels. It is cast to int64 before
            scoring, so fractional values are truncated toward zero.
            NOTE(review): if masks can be interpolated (values in (0, 1)),
            thresholding may be more appropriate than truncation — confirm.
        anomaly_maps: array of per-pixel scores. Per ``roc_auc_score``,
            larger scores are ranked toward the larger (positive) label.

    Returns:
        The AUROC value computed by ``sklearn.metrics.roc_auc_score``.
    """
    flat_labels = mask.astype(np.int64).ravel()
    flat_scores = anomaly_maps.reshape(-1)
    return roc_auc_score(flat_labels, flat_scores)

# --- Model -------------------------------------------------------------------
# NOTE(review): `Generator_test`, `MVTecDataset_test` and `vutils` are not
# defined in this file; presumably they come from the star imports — confirm.
gan = Generator_test().cuda()
# NOTE(review): device_ids is hard-coded to GPUs 0 and 1; DataParallel will
# fail on machines with fewer visible GPUs — adjust `gps` if needed.
gps = [0, 1]
gan = torch.nn.DataParallel(gan, device_ids=gps)

# --- Data --------------------------------------------------------------------
size = 256
test_dataset = MVTecDataset_test(sss=size, resize=size)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=4
)

# --- Checkpoint --------------------------------------------------------------
path = './cls_last_auc_tile.pth'
# strict=False tolerates key mismatches, but silently evaluating a partly
# uninitialised model would make the reported AUROC meaningless, so report
# whatever load_state_dict could not match.
load_result = gan.load_state_dict(torch.load(path), strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('WARNING: checkpoint', path, 'does not match the model exactly;',
          'missing keys:', load_result.missing_keys,
          'unexpected keys:', load_result.unexpected_keys)

gan.eval()

# --- Inference ---------------------------------------------------------------
m1 = []  # ground-truth masks, one entry per batch (kept on CPU)
i1 = []  # anomaly maps (1 - generator output), one entry per batch (CPU)

for inputs, _, mas in test_dataloader:
    # Only the images need to go to the GPU; masks are only collected.
    X = inputs.cuda()

    with torch.no_grad():
        anomap = gan(X)

    m1.append(mas)
    # Produced under no_grad, so no detach is needed.
    i1.append(1 - anomap.cpu())

m1 = torch.cat(m1, dim=0)
i1 = torch.cat(i1, dim=0)

# --- Scoring -----------------------------------------------------------------
pred_mask = i1.reshape(-1).numpy()
GT_mask = m1.reshape(-1).numpy()

score = segmentation_auroc(GT_mask, pred_mask)

torch.save(gan.state_dict(), 'tile.pth')

print('\x1b[1;31mAUC :\x1b[1;m'
      , 'seg score : ', score, '\n'
      )

# --- Visualisation -----------------------------------------------------------
# nrow = dataset size lays the grid out as one row of predicted anomaly maps
# followed by one row of ground-truth masks.
os.makedirs('./test', exist_ok=True)
tot = torch.cat([i1, m1.float()], dim=0)
vutils.save_image(tot, './test/mask_samples_{}.png'.format(ki), normalize=True, nrow=len(test_dataset))




