import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))))
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# Directory two levels above the directory that contains this file; used as the
# base for resolving `model_path` below.
path = os.path.dirname(os.path.abspath(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))))
from torch.autograd import Variable

# Checkpoint file location, resolved relative to `path`.
# NOTE(review): 'ciafar10_checkpoints' looks like a misspelling of 'cifar10' --
# confirm against the real directory name on disk before changing it.
model_path = os.path.join(path,'model/ciafar10_checkpoints/cifar10_wide10_linf_eps8.pth')

def multiple_adv_loss(model,
                x_natural,
                y,
                optimizer,
                step_size=0.03,
                epsilon=0.1,
                perturb_steps=10,
                beta=1.0,
                distance='l_inf', logger=None, tem=0.1, conf=0.2):
    """Craft an adversarial batch and return the natural + consistency training loss.

    An adversarial copy of ``x_natural['x']`` is produced by iterated
    sign-gradient ascent on the summed KL divergence between the model's
    predictions for the perturbed and the clean batch, with the perturbation
    kept inside ``[-epsilon, epsilon]`` element-wise around the clean input.
    The returned loss is ``cross_entropy(clean logits, y) + beta *
    consistancy_loss(adversarial logits, clean logits, ...)``.

    Args:
        model: Module called as ``model(batch, None, scalar)``. Its return is
            unpacked as a 3-tuple whose first element holds the class scores.
            The meaning of the scalar third argument is defined by the model
            and is not visible here.
        x_natural: Dict holding the input tensor under key ``'x'``; any other
            entries are passed to the model unchanged.
        y: Targets handed to ``F.cross_entropy``.
        optimizer: Only ``zero_grad()`` is called on it.
        step_size: Size of each sign-gradient step.
        epsilon: Half-width of the element-wise box around the clean input.
        perturb_steps: Number of attack iterations.
        beta: Weight of the consistency term.
        distance: Currently unused; only the sign-gradient / box-projection
            attack is implemented.
        logger: Forwarded to ``consistancy_loss``.
        tem: Sharpening temperature forwarded to ``consistancy_loss``.
        conf: Confidence threshold forwarded to ``consistancy_loss``.

    Returns:
        Scalar loss tensor; the caller is expected to call ``backward()``.

    Side effects:
        Switches ``model`` to eval mode for the attack and leaves it in train
        mode afterwards; clears the optimizer's gradients.
    """
    # Summed KL divergence; `reduction='sum'` is the supported spelling of the
    # deprecated `size_average=False`.
    criterion_kl = nn.KLDivLoss(reduction='sum')

    model.eval()

    # Shallow copy: every entry except 'x' stays shared with `x_natural`.
    x_adv1 = x_natural.copy()
    clean_x = x_natural['x'].detach()

    # Random start. The perturbed tensor must be stored back: otherwise the
    # attack begins exactly at the clean input (where the KL gradient carries
    # no signal) and `requires_grad_()` below would flip the flag on the
    # caller's own tensor. `randn_like` keeps the noise on the input's device.
    x_adv1['x'] = clean_x + 0.001 * torch.randn_like(clean_x)

    for _ in range(perturb_steps):
        x_adv1['x'].requires_grad_()
        with torch.enable_grad():
            loss_kl = criterion_kl(
                F.log_softmax(model(x_adv1, None, 0.055)[0].squeeze(0), dim=1),
                F.softmax(model(x_natural, None, 0.055)[0].squeeze(0), dim=1))
        grad = torch.autograd.grad(loss_kl, [x_adv1['x']])[0]
        # Ascend along the gradient sign, then project back into the epsilon box.
        x_adv1['x'] = x_adv1['x'].detach() + step_size * torch.sign(grad.detach())
        x_adv1['x'] = torch.min(torch.max(x_adv1['x'], clean_x - epsilon), clean_x + epsilon)

    model.train()

    # zero gradient
    optimizer.zero_grad()

    logits, _, _ = model(x_natural, None, 0.0005)
    loss_natural = F.cross_entropy(logits, y)
    y_adv, _, _ = model(x_adv1, None, 0.0005)

    consit_loss = consistancy_loss(y_adv, logits, tem, conf, logger)

    # Only the natural cross-entropy and the consistency term form the
    # objective; no separate KL robust term is added.
    loss = loss_natural + beta * (consit_loss)
    return loss

def consistancy_loss(y_adv, logits, tem, conf, logger):
    """Cross-entropy of both predictions against a sharpened average target.

    The softmax outputs of ``y_adv`` and ``logits`` are averaged, the average
    is sharpened by raising it to ``1 / tem`` and renormalising, and the
    (detached) result serves as the target for a cross-entropy against each
    of the two predictions. Only samples whose averaged prediction has a
    maximum probability above ``conf`` contribute.

    Args:
        y_adv: Class scores for the perturbed batch; softmax is taken over dim 1.
        logits: Class scores for the clean batch; softmax is taken over dim 1.
        tem: Sharpening temperature; smaller values give a peakier target.
        conf: Confidence threshold on the averaged prediction.
        logger: Unused; kept for interface compatibility.

    Returns:
        Scalar tensor: mean of the two cross-entropies over the selected
        samples, or a zero that is still attached to the autograd graph when
        no sample passes the threshold.
    """
    y_1 = F.softmax(y_adv, dim=1)
    y_2 = F.softmax(logits, dim=1)
    avg_y = (y_1 + y_2) / 2

    # Sharpened pseudo-target; detached so gradients only reach the predictions.
    num = torch.pow(avg_y, 1. / tem)
    deno = torch.sum(num, dim=1, keepdim=True) + 1e-8
    sharp_p = (num / deno).detach()

    ce_adv = (-sharp_p * torch.log(y_1 + 1e-8)).sum(1)
    ce_nat = (-sharp_p * torch.log(y_2 + 1e-8)).sum(1)

    confident = avg_y.max(1)[0] > conf
    if not bool(confident.any()):
        # The mean over an empty selection is NaN and would poison the whole
        # training loss; contribute a zero that keeps the graph connected.
        return 0.0 * (ce_adv.sum() + ce_nat.sum())

    return (ce_adv[confident].mean() + ce_nat[confident].mean()) / 2


def load(device):
    """Build the wrapped network on ``device``, load its checkpoint and return it.

    ``Robust_Overfitting`` requires ``(model, device, model_path)``; calling it
    with only ``device`` raises ``TypeError``, so the backbone and the
    module-level checkpoint path are supplied here.

    Args:
        device: Device the network and the checkpoint tensors are placed on.

    Returns:
        The ``Robust_Overfitting`` instance in eval mode, with ``name`` set to
        ``'robust_overfitting'``.
    """
    # NOTE(review): depth 34 / widen factor 10 / 10 classes is assumed to match
    # the checkpoint referenced by `model_path` -- confirm before relying on it.
    backbone = WideResNet(depth=34, num_classes=10, widen_factor=10)
    model = Robust_Overfitting(backbone, device, model_path)
    model.load()
    model.name = 'robust_overfitting'
    return model

class Robust_Overfitting(torch.nn.Module):
    """Wrap a network in ``DataParallel`` and load its weights from a checkpoint file.

    Args:
        model: Network to wrap.
        device: Device the wrapped network and incoming batches are moved to.
        model_path: Path of the checkpoint read by ``load()``.
    """

    def __init__(self, model, device, model_path):
        super(Robust_Overfitting, self).__init__()
        self.device = device
        self.model_path = model_path
        self.model = torch.nn.DataParallel(model).to(device)
        # Per-channel constants, shaped (3, 1, 1) and kept on `device`.
        # NOTE(review): they are not applied in forward() -- confirm whether
        # inputs are expected to arrive already normalised.
        self._mean_torch = torch.tensor((0.4914, 0.4822, 0.4465)).view(3,1,1).to(device)
        self._std_torch = torch.tensor((0.2471, 0.2435, 0.2616)).view(3,1,1).to(device)

    def forward(self, x):
        """Move ``x`` to ``self.device`` and return the wrapped network's output."""
        input_var = x.to(self.device)
        labels = self.model(input_var)
        return labels

    def load(self):
        """Load the state dict stored at ``self.model_path`` and switch to eval mode.

        The file must contain a state dict whose keys match the
        ``DataParallel``-wrapped network.
        """
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # load on `self.device`, including CPU-only hosts.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(self.model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint)
        self.model.eval()


class BasicBlock(nn.Module):
    """Pre-activation residual unit that also collects squared-activation maps.

    ``forward`` takes and returns a pair ``(collected_maps, features)`` so that
    units can be chained inside ``nn.Sequential``.

    Args:
        in_planes: Input channel count.
        out_planes: Output channel count.
        stride: Stride of the first 3x3 convolution (and of the 1x1 shortcut).
        dropRate: Dropout probability applied between the two convolutions.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # A 1x1 projection is only needed when the channel count changes.
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                          stride=stride, padding=0, bias=False)

    @staticmethod
    def _squared_map(activation):
        """Sum squares over dim 1, add a leading axis, resize to 14x14 bilinearly."""
        summed = (activation ** 2).sum(1)
        return F.interpolate(summed[None, :], size=(14, 14), mode='bilinear')

    def forward(self, x):
        """Run the unit on ``x = (collected_maps, features)``.

        Returns:
            ``(collected_maps, output)`` where ``collected_maps`` has the maps
            of both convolutions concatenated along dim 0 and ``output`` is the
            shortcut plus the residual branch.
        """
        collected = x[0]
        features = x[1]

        activated = self.relu1(self.bn1(features))
        # With a channel change the shortcut is taken from the activated
        # tensor; otherwise the untouched input is reused.
        shortcut_source = features if self.equalInOut else activated

        branch = self.conv1(activated)
        first_map = self._squared_map(branch)

        branch = self.relu2(self.bn2(branch))
        if self.droprate > 0:
            branch = F.dropout(branch, p=self.droprate, training=self.training)
        branch = self.conv2(branch)
        second_map = self._squared_map(branch)

        collected = torch.cat((collected, first_map, second_map))
        if self.equalInOut:
            shortcut = shortcut_source
        else:
            shortcut = self.convShortcut(shortcut_source)
        return collected, torch.add(shortcut, branch)

class NetworkBlock(nn.Module):
    """A stage made of ``nb_layers`` consecutive ``block`` units.

    Args:
        nb_layers: Number of units; converted with ``int()`` so a float count
            is accepted.
        in_planes: Input channel count of the first unit.
        out_planes: Output channel count of every unit.
        block: Callable building one unit as
            ``block(in_planes, out_planes, stride, dropRate)``.
        stride: Stride of the first unit; later units use stride 1.
        dropRate: Dropout probability passed to every unit.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        """Build the units and chain them in an ``nn.Sequential``."""
        layers = []
        for i in range(int(nb_layers)):
            is_first = (i == 0)
            # Explicit conditional expressions: the `cond and a or b` trick
            # silently falls through to `b` whenever `a` is falsy.
            layers.append(block(in_planes if is_first else out_planes,
                                out_planes,
                                stride if is_first else 1,
                                dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through every unit in order."""
        return self.layer(x)


class WideResNet(nn.Module):
    """Wide residual network whose forward pass also returns squared-activation maps.

    Args:
        depth: Total depth; ``depth - 4`` must be divisible by 6.
        num_classes: Size of the final linear layer's output.
        widen_factor: Multiplier applied to the 16/32/64 base channel widths.
        dropRate: Dropout probability passed to every residual unit.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super(WideResNet, self).__init__()
        widths = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert((depth - 4) % 6 == 0)
        units_per_stage = (depth - 4) // 6
        unit = BasicBlock

        # Stem convolution applied before the residual stages.
        self.conv1 = nn.Conv2d(3, widths[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # Three residual stages; the second and third use stride 2.
        self.block1 = NetworkBlock(units_per_stage, widths[0], widths[1], unit, 1, dropRate)
        self.block2 = NetworkBlock(units_per_stage, widths[1], widths[2], unit, 2, dropRate)
        self.block3 = NetworkBlock(units_per_stage, widths[2], widths[3], unit, 2, dropRate)
        # Final normalisation, activation and classifier.
        self.bn1 = nn.BatchNorm2d(widths[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(widths[3], num_classes)
        self.nChannels = widths[3]

        # Weight initialisation per layer type.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
                continue
            if isinstance(module, nn.Linear):
                module.bias.data.zero_()

    def forward(self, x):
        """Return ``(collected_maps, class_scores)`` for the batch ``x``.

        ``collected_maps`` starts with the stem convolution's squared output
        (summed over dim 1 and resized to 14x14) and is extended by every
        residual unit it passes through.
        """
        features = self.conv1(x)
        stem_map = (features ** 2).sum(1)
        stem_map = F.interpolate(stem_map[None, :], size=(14, 14), mode='bilinear')

        state = (stem_map, features)
        for stage in (self.block1, self.block2, self.block3):
            state = stage(state)
        collected, features = state

        features = self.relu(self.bn1(features))
        features = F.avg_pool2d(features, 8)
        flat = features.view(-1, self.nChannels)
        return collected, self.fc(flat)
if __name__ == '__main__':
    # When run as a script: if the checkpoint is missing, make sure its
    # directory exists and tell the user where to fetch the file from.
    checkpoint_missing = not os.path.exists(model_path)
    if checkpoint_missing:
        checkpoint_dir = os.path.dirname(model_path)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir, exist_ok=True)
        url = 'https://drive.google.com/file/d/1b4ikBAFDevxGskNtG-GU8FDHOW7j61-2/view'
        print('Please download "{}" to "{}".'.format(url, model_path))
