import pretrainedmodels
import torch 
from torchvision import models
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms as T
import os
from tqdm import tqdm
from Normalize import Normalize
from torchvision.datasets import ImageFolder
import pretrainedmodels.utils as utils
from collections import Counter
import timm
import torch.nn.functional as F

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

res152 = torch.nn.Sequential(Normalize(mean, std),models.resnet152(pretrained=True).eval()).cuda()
inc_v3 = torch.nn.Sequential(Normalize(mean, std),models.inception_v3(pretrained=True).eval()).cuda()
dense121 = torch.nn.Sequential(Normalize(mean, std),models.densenet121(pretrained=True).eval()).cuda()
dense169 = torch.nn.Sequential(Normalize(mean, std),models.densenet169(pretrained=True).eval()).cuda()
shufflenet = torch.nn.Sequential(Normalize(mean, std),models.shufflenet_v2_x1_0(pretrained=True).eval()).cuda()
squee = torch.nn.Sequential(Normalize(mean, std), models.squeezenet1_1(pretrained=True).eval()).cuda()
mobile = torch.nn.Sequential(Normalize(mean, std), models.mobilenet_v2(pretrained=True).eval()).cuda()
wrn50 = torch.nn.Sequential(Normalize(mean, std), models.wide_resnet50_2(pretrained=True).eval()).cuda()
wrn101 = torch.nn.Sequential(Normalize(mean, std), models.wide_resnet101_2(pretrained=True).eval()).cuda()
vgg = torch.nn.Sequential(Normalize(mean, std), models.vgg19_bn(pretrained=True).eval()).cuda()
se = torch.nn.Sequential(Normalize(mean, std), timm.create_model('legacy_senet154', pretrained=True).eval()).cuda()
pna = torch.nn.Sequential(Normalize(mean, std), timm.create_model('pnasnet5large', pretrained=True).eval()).cuda()

non_res, non_v3, non_dense121, non_dense169, non_shuff, non_squee, non_mob, non_rext, non_wrn50, non_wrn101, non_se, non_vgg, non_pna = 0, 0, 0, 0, 0, 0, 0,0,0,0,0,0, 0
pre_res, pre_v3, pre_dense, pre_shuff, pre_squee, pre_mob, pre_rext = 0, 0, 0, 0, 0, 0, 0
inc_list, res_list, dense_list, shuff_list, squee_list, mob_list, rext_list = [], [], [], [], [], [], []
data_path = 'output/'
val_loader = torch.utils.data.DataLoader(
    ImageFolder(data_path, T.Compose([
        T.ToTensor()])),
    batch_size=64, shuffle=False,
    num_workers=8, pin_memory=True)

with torch.no_grad():
    for cur_img, gt_cpu in tqdm(val_loader):
        cuda_img = cur_img.cuda()
        gt = gt_cpu.cuda()
        batch_size = len(gt)
        non_res += (res152(cuda_img).max(1)[1] != gt).detach().cpu().sum()
        non_v3 += (inc_v3(cuda_img).max(1)[1] != gt).detach().cpu().sum()
        non_dense121 += (dense121(cuda_img).max(1)[1] != gt).detach().cpu().sum()
        non_dense169 += (dense169(cuda_img).max(1)[1] != gt).detach().cpu().sum()
        non_shuff += (shufflenet(cuda_img).max(1)[1] != gt).detach().cpu().sum()
        non_squee += (squee(cuda_img).max(1)[1] != gt).detach().cpu().sum()
        non_mob += (mobile(cuda_img).max(1)[1] != gt).detach().cpu().sum()
        # non_rext += (resnext50_32x4d(cuda_img).max(1)[1] != gt).detach().cpu().sum()
        non_wrn50 += (wrn50(cuda_img).max(1)[1] != gt).detach().cpu().sum()
        non_wrn101 += (wrn101(cuda_img).max(1)[1] != gt).detach().cpu().sum()
        non_se += (se(cuda_img).max(1)[1] != gt).detach().cpu().sum()
        non_vgg += (vgg(cuda_img).max(1)[1] != gt).detach().cpu().sum()
        non_pna += (pna(cuda_img).max(1)[1] != gt).detach().cpu().sum()


print('vgg = ', non_vgg)
print('Inc_v3 = ', non_v3)
print('res152 = ', non_res)
print('dense121 = ', non_dense121)
print('dense169 = ', non_dense169)
print('wrn50 = ', non_wrn50)
print('wrn101 = ', non_wrn101)
print('se = ', non_se)
print('pna = ', non_pna)
print('shuffle = ', non_shuff)
print('squeeze = ', non_squee)
print('mobile = ', non_mob)





