import torch
import os
import IPython
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch
import torchvision
import math as ms
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.models as models
import torchvision.datasets as datasets
import random as rd
from PIL import Image
from IPython import embed
import matplotlib.image as mpimg
import os


# Per-channel statistics used both to normalize inputs (below) and to undo the
# normalization in save_img. Keep them in one place so the two stay in sync.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# CIFAR-10 alternative: mean = [0.4914, 0.4822, 0.4465], std = [0.2023, 0.1994, 0.2010]
normalize = transforms.Normalize(mean=mean, std=std)

# Deterministic evaluation preprocessing, shared by both datasets so the
# n-th sample of each loader goes through exactly the same pipeline.
eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

# NOTE: despite the name, this reads the Stylized-ImageNet *val* directory
# with the deterministic eval transform (no augmentation, no shuffling).
train_dataset = datasets.ImageFolder(
    '/cephfs/tiange/stylized_imagenet/val',
    eval_transform)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=False,
    num_workers=16, pin_memory=True)
val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder('/cephfs/tiange/imagenet/imagenet-2012/val', eval_transform),
    batch_size=1, shuffle=False,
    num_workers=16, pin_memory=True)
def save_img(kk, path, name, channel_mean=None, channel_std=None):
    """Undo input normalization on one image and save it as a PNG.

    Args:
        kk: array indexed as ``kk[0, c, :, :]``. Only the first element of the
            leading (batch) axis and the first 3 channels are used.
        path: output directory. It is created if missing. The file is
            written to ``os.path.join(path, name + '.png')``.
        name: file name without extension.
        channel_mean, channel_std: per-channel normalization statistics.
            They default to the module-level ``mean`` / ``std``.
    """
    if channel_mean is None:
        channel_mean = mean
    if channel_std is None:
        channel_std = std

    img = np.asarray(kk)[0, :3]
    # Match the arithmetic precision of the input (float32 stays float32).
    dtype = np.result_type(img.dtype, np.float32)
    m = np.asarray(channel_mean, dtype=dtype).reshape(3, 1, 1)
    s = np.asarray(channel_std, dtype=dtype).reshape(3, 1, 1)

    pixels = np.around((img * s + m) * 255)
    # Saturate out-of-range values explicitly, the same way PIL's 'F' -> 'L'
    # conversion did in the per-channel version.
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)

    if path:
        os.makedirs(path, exist_ok=True)
    # (3, H, W) -> (H, W, 3) for an RGB image.
    Image.fromarray(pixels.transpose(1, 2, 0), 'RGB').save(
        os.path.join(path, name + '.png'))

#from resnet import ResNet18
#from senet import SENet18
#from densenet import DenseNet121
#from resnet import ResNet50
#from vgg import VGG
#from pytorch_cifar.resnet_drop_12_09 import ResNet18 as n2
#from resnet_drop_12_09 import ResNet18
#from resnet_drop_012_07 import ResNet18
#from resnet import resnet18
import utils

#net1 = torch.nn.DataParallel(ResNet18().cuda()).eval()
#net1.load_state_dict(torch.load('./pytorch_cifar/checkpoint/ckpt3.t7')['net'])
#net2 = torch.nn.DataParallel(SENet18().cuda()).eval()
#net2.load_state_dict(torch.load('./pytorch_cifar/checkpoint/ckpt_senet.t7')['net'])
#net3 = torch.nn.DataParallel(DenseNet121().cuda()).eval()
#net3.load_state_dict(torch.load('./pytorch_cifar/checkpoint/ckpt_dense.t7')['net'])
#net4 = torch.nn.DataParallel(ResNet50().cuda()).eval()
#net4.load_state_dict(torch.load('./pytorch_cifar/checkpoint/ckpt_resnet50.t7')['net'])
#net5 = torch.nn.DataParallel(VGG('VGG19').cuda()).eval()
#net5.load_state_dict(torch.load('./pytorch_cifar/checkpoint/ckpt_vgg.t7')['net'])
#net1 = torch.nn.DataParallel(ResNet18().cuda()).eval()
#utils.load_state_ckpt('/cephfs/tiange/sn/checkpoint2/model.pth-100', net1) 
#net2 = torch.nn.DataParallel(resnet18(pretrained=True).cuda()).eval()

import torchvision.models as models
from resnet50_drop_012_09 import ResNet50 as net2
from densenet_drop_012_09 import DenseNet121 as net4
from senet import SENet18 as net5
from senet_drop_012_09 import SENet18 as net6
from vgg_drop_01_09 import VGG as v2
def _gpu_eval(model, checkpoint=None):
    """Move ``model`` to the GPU, wrap it in ``DataParallel`` and switch it to
    eval mode. If ``checkpoint`` is given, load it with ``utils.load_state_ckpt``.
    """
    wrapped = torch.nn.DataParallel(model.cuda()).eval()
    if checkpoint is not None:
        utils.load_state_ckpt(checkpoint, wrapped)
    return wrapped


# Built in the same order as before. net2/net4/net5/net6 are rebound from the
# imported classes to the wrapped instances.
net1 = _gpu_eval(models.resnet50(pretrained=True))
net2 = _gpu_eval(net2(), '/cephfs/tiange/DefectiveCNN/RandomShuffle/resnet50-checkpoint/model.pth-100')
net3 = _gpu_eval(models.densenet121(pretrained=True))
net4 = _gpu_eval(net4(), '/cephfs/tiange/DefectiveCNN/RandomShuffle/densenet-checkpoint/model.pth-100')
net5 = _gpu_eval(net5(), '/cephfs/tiange/DefectiveCNN/RandomShuffle/senet50-checkpoint/model.pth-100')
net6 = _gpu_eval(net6(), '/cephfs/tiange/DefectiveCNN/RandomShuffle/senet50_01209-checkpoint/model.pth-100')
net7 = _gpu_eval(models.vgg19(pretrained=True))
net8 = _gpu_eval(v2('VGG19'), '/cephfs/tiange/DefectiveCNN/RandomShuffle/vgg-checkpoint/model.pth-100')
#net2 = torch.nn.DataParallel(resnet18().cuda()).eval()
#utils.load_state_ckpt('./model.pth-200', net2) 
#net6.load_state_dict(torch.load('./pytorch_cifar/checkpoint/ckpt_drop_12_09.t7')['net'])
#net_dict={'12_09':net6}
#net_dict={'dense':net3}
#net_dict={'vgg':net5, 'res50':net4}
#net_dict={'senet':net2, 'res18':net1}
#net_dict={'01207': net1}
#net_dict={'res18': net2}

#path_list = sorted(os.listdir('t_sample_testdata'),key=lambda x:int(x[:-6]))
# Module.cuda() moves parameters in place and returns the same module.
# NOTE(review): criterion is not used by the visible code below.
criterion = nn.CrossEntropyLoss().cuda()
# Cap on how many images each loop below keeps.
tn = 50000

#parameters =[(1, 80, 64.0),
#             (2, 40, 64.0),
#             (4, 20, 64.0),
#             (8, 10, 64.0)]
#parameters =[(1, 20, 16.0),
#             (2, 10, 16.0),
#             (4,  5, 16.0),
#             (1, 40, 32.0),
#             (2, 20, 32.0),
#             (4, 10, 32.0)]

#parameters =[(1,  8,  4.0),
#             (2,  4,  4.0),
#             (4,  2,  4.0),
#             (1, 12,  8.0),
#             (2,  6,  8.0),
#             (4,  3,  8.0),
#             (1, 20, 16.0),
#             (2, 10, 16.0),
#             (4,  5, 16.0),
#             (1, 40, 32.0),
#             (2, 20, 32.0),
#             (4, 10, 32.0)]
#parameters =[
#             (1, 20, 16.0),
#             (1, 40, 32.0),
#             (1, 80, 64.0)]
#parameters = [(1, 200, 64.0), (1, 500, 64.0), (1, 1000,64.0), (1, 200,128.0), (1,500, 128.0), (1,1000,128.0)]

#fo = open('PGD/info.txt','a')

#up = torch.from_numpy(((np.ones([3,224,224]) - np.array(mean).reshape(3,1,1))/np.array(std).reshape(3,1,1)).reshape(1,3,224,224)).type(torch.FloatTensor).cuda()
#down = torch.from_numpy(((np.zeros([3,224,224]) - np.array(mean).reshape(3,1,1))/np.array(std).reshape(3,1,1)).reshape(1,3,224,224)).type(torch.FloatTensor).cuda()
#index = np.load('index2.npy')
# A validation image is kept only if every one of these networks predicts its
# label with softmax confidence of at least conf_threshold.
filter_nets = [net1, net2, net3, net4, net5, net6, net7, net8]
conf_threshold = 0.90


def _confidently_correct(net, images, labels, threshold=conf_threshold):
    """Return True if ``net``'s top-1 class equals ``labels`` and its softmax
    probability is >= ``threshold``.

    Written for batch size 1 (val_loader uses batch_size=1). The ``bool``
    conversions below only work on single-element tensors.
    """
    with torch.no_grad():  # pure inference: no autograd graph needed
        probs = F.softmax(net(images), dim=1)
    score, predicted = torch.max(probs, 1)
    return bool(predicted.cpu() == labels) and bool(score >= threshold)


count = 0
overall_count = -1
index_arr = []  # positions in val_loader of images that pass every net

for inputs, labels in val_loader:
    overall_count += 1
    if count == tn:
        break
    inputs = inputs.cuda()
    # all() short-circuits, so later nets run only if earlier ones passed,
    # the same as the sequential `continue` checks this replaces.
    if not all(_confidently_correct(net, inputs, labels) for net in filter_nets):
        continue
    index_arr.append(overall_count)
    count += 1

embed()
# Save the Stylized-ImageNet counterparts of the images selected above. This
# relies on train_loader and val_loader yielding images in the same order
# (both use ImageFolder with shuffle=False).
count = 0
overall_count = -1
# Never filled now that the unreachable re-evaluation code (which also used an
# undefined `index`) is gone; kept so the names still exist.
predict_arr1 = []
predict_arr2 = []
label_arr = []

selected = set(index_arr)  # O(1) membership instead of scanning the list
n_selected = len(selected)
out_dir = 'sample_style/'
os.makedirs(out_dir, exist_ok=True)

for inputs, labels in train_loader:
    overall_count += 1
    # Stop at the cap, or as soon as every selected image has been written.
    if count == tn or count == n_selected:
        break
    if overall_count not in selected:
        continue
    save_img(inputs.data.cpu().numpy(), out_dir,
             '{}_{}'.format(count, labels[0].item()))
    count += 1
embed()
