import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
import copy
import math
import numpy as np
import os
import argparse

#from networks import *

# Training device: prefer the second GPU (the original target), but fall back
# to the first GPU or to the CPU so the script still starts on machines that
# expose fewer than two CUDA devices.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
EPOCH_NUM = 200    # number of epochs run by main()
BATCH_SIZE = 128   # mini-batch size passed to the data loaders
DATA_NUM = 50000   # number of training samples (CIFAR-10 train split)

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with a padding of one pixel.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).

    Returns:
        The newly constructed ``nn.Conv2d`` layer.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class PDEBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm pairs with ELU.

    The block computes ``ELU(bn2(conv2(ELU(bn1(conv1(x))))) + shortcut)``,
    where ``shortcut`` is ``x`` itself or ``downsample(x)`` when a
    ``downsample`` module is supplied.

    Args:
        inplanes: number of input channels.
        planes: number of channels produced by both convolutions.
        stride: stride of the first convolution (default 1).
        downsample: optional module applied to ``x`` on the shortcut path.
    """

    # Output channels are ``planes * expansion``; read by ResNet_Cifar.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(PDEBlock, self).__init__()
        # Only the first convolution takes the stride; the second keeps the
        # shape it receives.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        # One in-place ELU instance is shared by both activation sites.
        self.relu = nn.ELU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return the block output for input ``x``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        # Identity shortcut unless a projection module was provided.
        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)


class ResNet_Cifar(nn.Module):
    """Three-stage residual classifier whose input carries an extra time plane.

    ``forward(x, t)`` prepends one channel filled with ``t`` to ``x`` before
    the stem convolution, which is why the stem expects 4 input channels
    (so ``x`` itself must have 3).

    Args:
        block: block factory called as ``block(inplanes, planes, stride,
            downsample)``; must expose an ``expansion`` attribute.
        layers: indexable with the number of blocks for stages 1, 2 and 3.
        num_classes: size of the final linear layer's output (default 10).
    """

    def __init__(self, block, layers, num_classes=10):
        super(ResNet_Cifar, self).__init__()
        self.inplanes = 16

        # Stem: 1 time channel + 3 image channels -> 16 feature maps.
        self.conv1 = nn.Conv2d(4, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ELU(inplace=True)

        # Stages 2 and 3 double the width and use stride 2 in their first block.
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)

        self.avgpool = nn.AvgPool2d(8, stride=1)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        # Re-initialise every conv (zero-mean normal, std sqrt(2 / (k*k*out)))
        # and every batch-norm (weight 1, bias 0) registered above.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan_out = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` instances of ``block`` into one ``nn.Sequential``.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a 1x1 conv + batch-norm projection for its
        shortcut. ``self.inplanes`` is updated to the stage's output width.
        """
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x, t):
        """Return class scores for images ``x`` evaluated at time ``t``.

        ``t`` is multiplied into a ones-plane shaped like one channel of
        ``x`` and concatenated in front of ``x`` along the channel axis.
        """
        time_plane = torch.ones_like(x[:, :1, :, :]) * t
        feats = torch.cat([time_plane, x], 1)

        feats = self.relu(self.bn1(self.conv1(feats)))
        for stage in (self.layer1, self.layer2, self.layer3):
            feats = stage(feats)

        feats = self.avgpool(feats)
        feats = feats.view(feats.size(0), -1)
        return self.fc(feats)

def net20_cifar(**kwargs):
    """Return a ``ResNet_Cifar`` built from ``PDEBlock`` with 3 blocks per stage.

    All keyword arguments are forwarded unchanged to ``ResNet_Cifar``.
    """
    blocks_per_stage = [3] * 3
    return ResNet_Cifar(PDEBlock, blocks_per_stage, **kwargs)

def net56_cifar(**kwargs):
    """Return a ``ResNet_Cifar`` built from ``PDEBlock`` with 9 blocks per stage.

    All keyword arguments are forwarded unchanged to ``ResNet_Cifar``.
    """
    blocks_per_stage = [9] * 3
    return ResNet_Cifar(PDEBlock, blocks_per_stage, **kwargs)

def net110_cifar(**kwargs):
    """Return a ``ResNet_Cifar`` built from ``PDEBlock`` with 18 blocks per stage.

    All keyword arguments are forwarded unchanged to ``ResNet_Cifar``.
    """
    blocks_per_stage = [18] * 3
    return ResNet_Cifar(PDEBlock, blocks_per_stage, **kwargs)


def loss_fn_kd(outputs, labels, teacher_outputs, T, alpha, reduction="mean"):
    """Knowledge-distillation loss: softened KL term plus hard-label cross-entropy.

    The result is
    ``alpha * T**2 * KLDiv(log_softmax(outputs / T), softmax(teacher_outputs / T))
    + (1 - alpha) * cross_entropy(outputs, labels)``.

    Args:
        outputs: student logits, classes along dim 1.
        labels: class indices accepted by ``F.cross_entropy``.
        teacher_outputs: teacher logits with the same shape as ``outputs``.
        T: temperature dividing both sets of logits.
        alpha: weight of the distillation term; ``1 - alpha`` weights the
            cross-entropy term.
        reduction: reduction passed to ``nn.KLDivLoss``. The default
            ``"mean"`` keeps the historical behaviour of this function, which
            averages over every element (batch and classes). Pass
            ``"batchmean"`` to sum over classes and average over the batch,
            i.e. the mathematical KL divergence.

    Returns:
        The combined loss tensor.
    """
    student_log_probs = F.log_softmax(outputs / T, dim=1)
    teacher_probs = F.softmax(teacher_outputs / T, dim=1)
    distill_term = nn.KLDivLoss(reduction=reduction)(student_log_probs, teacher_probs)
    hard_term = F.cross_entropy(outputs, labels)
    # T*T restores the gradient scale lost by dividing the logits by T.
    return distill_term * (alpha * T * T) + hard_term * (1. - alpha)

# Finite-difference step sizes used by train(): `h` scales the random input
# perturbation, `tau` perturbs the time argument passed to the model.
h = 0.1
tau = 0.01
# Classification loss on the logits, and the MSE loss train() applies between
# the time difference quotient and the input-space second difference.
loss_func = nn.CrossEntropyLoss()
loss_fn = nn.MSELoss()
# One-element tensor of ones on `device`, used as the model's time argument.
s = torch.ones(1, device=device)

def _pde_loss(model, inputs, labels, t0):
    """Return ``(loss, logits)`` for one time level ``t0`` (a Python number).

    ``logits`` is ``model(inputs, t0 * s)``. ``loss`` is the cross-entropy of
    those logits plus ``0.05`` times the MSE between

    * the central difference of the model output in the time argument
      (step ``tau``), and
    * the second central difference of the model output along one random
      input perturbation (step ``h``, noise scaled by ``torch.std(inputs)``).
    """
    t_mid = t0 * s
    logits = model(inputs, t_mid)

    # Central difference in time around t0.
    grad_t = (model(inputs, (t0 + tau) * s) - model(inputs, (t0 - tau) * s)) / (2 * tau)

    # Second difference along a fresh Gaussian direction; `inputs` already
    # lives on the target device, so the noise does too.
    direction = torch.std(inputs) * torch.randn_like(inputs)
    laplace_x = (model(inputs + h * direction, t_mid)
                 + model(inputs - h * direction, t_mid)
                 - 2 * logits) / h ** 2

    return loss_func(logits, labels) + 0.05 * loss_fn(grad_t, laplace_x), logits


def train(model, train_loader, epoch):
    """Train ``model`` for one epoch and print progress every 50 batches.

    The learning rate is 0.1 for ``epoch < 100``, 0.01 for ``epoch < 150``
    and 0.001 afterwards (``epoch`` is zero-based). Each batch contributes the
    sum of the regularised losses at time 1 and time 0, which costs ten
    forward passes of the model. The reported accuracy is that of the time-1
    logits over all samples seen so far in this epoch.

    Args:
        model: network called as ``model(inputs, t)``.
        train_loader: iterable of ``(inputs, labels)`` batches exposing
            ``__len__`` and a sized ``dataset`` attribute.
        epoch: zero-based epoch index.

    Returns:
        None.
    """
    model.train()

    if epoch < 100:
        lr = 0.1
    elif epoch < 150:
        lr = 0.01
    else:
        lr = 0.001

    # NOTE: the optimizer is rebuilt on every call, so SGD momentum buffers
    # start from zero at the beginning of each epoch.
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)

    num_samples = len(train_loader.dataset)
    num_batches = len(train_loader)
    correct_training = 0
    seen = 0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        loss1, outputs1 = _pde_loss(model, inputs, labels, 1)
        loss2, _ = _pde_loss(model, inputs, labels, 0)
        loss = loss1 + loss2

        loss.backward()
        optimizer.step()

        pred = outputs1.argmax(dim=1, keepdim=True)  # index of the largest logit
        correct_training += pred.eq(labels.view_as(pred)).sum().item()
        # Count real samples: the final batch may be smaller than the others,
        # so (batch_idx + 1) * len(inputs) is not a valid denominator.
        seen += labels.size(0)

        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}\tAccuracy: {:.4f}'.format(
                epoch + 1, seen, num_samples, 100. * (batch_idx + 1) / num_batches,
                loss.item(), correct_training / seen))


def test(model, test_loader):
    """Evaluate ``model`` at time 1 and print average loss and accuracy.

    Args:
        model: network called as ``model(inputs, t)``.
        test_loader: iterable of ``(inputs, labels)`` batches exposing a
            sized ``dataset`` attribute.

    Returns:
        The number of correctly classified samples.
    """
    model.eval()
    # Sum per-sample losses so the final average weights every sample equally,
    # even when the last batch is smaller than the others.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    # Built once: no gradient is needed for the time input during evaluation.
    t_eval = torch.ones(1, device=device)

    test_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs, t_eval)
            test_loss += criterion(outputs, labels).item()
            pred = outputs.argmax(dim=1, keepdim=True)  # index of the largest logit
            correct += pred.eq(labels.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples

    print('\nTest set: Loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples, 100. * correct / num_samples))
    return correct

def main():
    """Train ``net110_cifar`` on CIFAR-10 and checkpoint the best weights.

    Downloads CIFAR-10 into ``./data`` if needed, trains for ``EPOCH_NUM``
    epochs, evaluates on the test split after every epoch and saves the
    model's ``state_dict`` to ``110-5-10.pt`` whenever the number of correct
    test predictions exceeds the best seen so far.
    """
    # Training pipeline: padded random crop + horizontal flip, then tensor.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    train_set = datasets.CIFAR10(root='./data', train=True, download=True,
                                 transform=train_transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE,
                                               shuffle=True, num_workers=4)

    # Evaluation pipeline: tensor conversion only.
    test_transform = transforms.Compose([transforms.ToTensor()])
    test_set = datasets.CIFAR10(root='./data', train=False, download=True,
                                transform=test_transform)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE)

    model = net110_cifar().to(device)

    best_correct = 0
    for epoch in range(EPOCH_NUM):
        train(model, train_loader, epoch)
        correct = test(model, test_loader)
        if correct <= best_correct:
            continue
        # New best score: remember it and overwrite the checkpoint.
        best_correct = correct
        torch.save(model.state_dict(), '110-5-10.pt')


if __name__ == '__main__':
    main()
