#!/usr/bin/env python
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
import torchvision
import torchvision.transforms as transforms

#from models.resnet import ResNet18
from models.vgg import VGG

# Command-line arguments: four required, free-form string options.
parser = argparse.ArgumentParser(description='Bayesian Inference')
for _flag in ('--model', '--defense', '--data', '--max_norm'):
    parser.add_argument(_flag, type=str, required=True)
opt = parser.parse_args()


# dataset: class count and image width (nclass, img_width) for each supported --data value
print('==> Preparing data..')
_dataset_specs = {
    'cifar10': (10, 32),
    'stl10': (10, 96),
    'cifar100': (100, 32),
    'tiny-imagenet': (200, 64),
}
if opt.data not in _dataset_specs:
    raise ValueError('invlid dataset: {}'.format(opt.data))
nclass, img_width = _dataset_specs[opt.data]

# load model
# --model picks the network family and --defense picks the training variant.
# Each variant lives in its own module, so only the selected one is imported.
# Defense names are compared with `==`. Writing `x in ('adv')` would be a
# substring test on a plain string (not a tuple), so inputs like 'a' or 'dv'
# would silently select a model.
if opt.model == 'vgg':
    if opt.defense == 'adv':
        from models.vgg import VGG
        base_net = VGG('VGG16', nclass, img_width=img_width)
    elif opt.defense == 'adv_vi':
        from models.vgg_vi import VGG
        # NOTE(review): the meaning of the leading positional args (0.1, 1.0, 0.1)
        # is defined in models.vgg_vi / models.vgg_hvi -- confirm there.
        base_net = VGG(0.1, 1.0, 0.1, 'VGG16', nclass, img_width=img_width)
    elif opt.defense == 'adv_hvi':
        from models.vgg_hvi import VGG
        base_net = VGG(0.1, 1.0, 0.1, 'VGG16', nclass, img_width=img_width)
    else:
        raise ValueError('invalid opt.defense: {}'.format(opt.defense))
elif opt.model == 'aaron':
    if opt.defense == 'adv':
        from models.aaron import Aaron
        base_net = Aaron(nclass)
    elif opt.defense == 'adv_vi':
        from models.aaron_vi import Aaron
        base_net = Aaron(1.0, 1.0, 1.0, nclass)
    elif opt.defense == 'adv_hvi':
        from models.aaron_hvi import Aaron
        base_net = Aaron(1.0, 1.0, 1.0, nclass)
    else:
        raise ValueError('invalid opt.defense: {}'.format(opt.defense))
else:
    raise ValueError('invalid opt.model')

# Wrap in a single-device DataParallel; later code accesses the model via `net.module`.
net = nn.DataParallel(base_net, device_ids=range(1))


# --max_norm arrives as a comma-separated string (e.g. "0.03" or "0,0.03"). Parse it
# into floats so the checkpoint name uses the first full value, not the first
# character of the string.
opt.max_norm = [float(s) for s in opt.max_norm.split(',')]
# A zero budget is stored as int so it formats as "0" (not "0.0") in the filename.
if opt.max_norm[0] == 0:
    opt.max_norm[0] = int(opt.max_norm[0])

net.load_state_dict(torch.load('./checkpoint/{}_{}_{}_{}.pth'.format(opt.data, opt.model, opt.max_norm[0], opt.defense)))
net.cuda()
net.eval() # must set to evaluation mode
loss_f = nn.CrossEntropyLoss()
softmax = nn.Softmax(dim=1)
cudnn.benchmark = True

import math

# Prior standard deviation hard-coded in the adv_vi KL term below.
# NOTE(review): the aaron_vi model is constructed with (1.0, 1.0, 1.0, ...) while
# vgg_vi uses (0.1, 1.0, 0.1, ...); confirm 0.1 is the right prior for both.
_VI_PRIOR_STD = 0.1


def _kl_hvi(layer):
    """Return mean(0.5 * log(1 + mu^2 / sigma^2)) over a layer's weights.

    sigma = exp(layer.sigma_weight), i.e. sigma_weight holds a log standard deviation.
    """
    sig = torch.exp(layer.sigma_weight)
    return torch.mean(0.5 * torch.log(1 + layer.mu_weight ** 2 / sig ** 2)).item()


def _kl_vi(layer):
    """Return the per-weight mean of KL(N(mu, sigma^2) || N(0, _VI_PRIOR_STD^2)).

    sigma = exp(layer.sigma_weight), so log(prior_std / sigma) = log(prior_std) - sigma_weight.
    """
    sig = torch.exp(layer.sigma_weight)
    return torch.mean(math.log(_VI_PRIOR_STD) - layer.sigma_weight
                      + (sig ** 2 + layer.mu_weight ** 2) / (2 * _VI_PRIOR_STD ** 2)
                      - 0.5).item()


# Exact-match dispatch: `opt.defense in ('adv_hvi')` would be a substring test, so
# '--defense adv' would wrongly enter the hvi branch. Other defenses report no KL.
_KL_REPORTS = {
    'adv_hvi': ('[adv-hvi] KLD: ', _kl_hvi),
    'adv_vi': ('[adv-vi] KLD: ', _kl_vi),
}

if opt.defense in _KL_REPORTS:
    kl_tag, kl_fn = _KL_REPORTS[opt.defense]
    kl_weight = []
    # Per-layer KL for every feature layer whose class name contains 'Rand'.
    for layer in net.module.features:
        if 'Rand' in type(layer).__name__:
            kl = kl_fn(layer)
            kl_weight.append(kl)
            print(type(layer).__name__ + ' : ', kl)
    # The classifier is included in the average but not printed individually.
    kl_weight.append(kl_fn(net.module.classifier))
    print(kl_tag, sum(kl_weight)/len(kl_weight))

     
# Summarize the variational parameters over the whole network:
#   avg_mu    - mean over every element of every `mu_weight` tensor (element-weighted);
#   avg_sigma - mean, over `sigma_weight` tensors, of each tensor's mean exp(sigma_weight)
#               (every tensor counts equally regardless of its size).
len_ = 0            # number of sigma_weight tensors seen
numbering = 0       # total number of mu_weight elements seen
mu_weight = []      # per-tensor sums of mu_weight
sigma_weight = []   # per-tensor means of exp(sigma_weight)
for name, param in net.named_parameters():
    if 'sigma_weight' in name:
        len_ += 1
        sigma_weight.append(torch.mean(torch.exp(param.data)).item())
    elif 'mu_weight' in name:
        mu_weight.append(torch.sum(param.data).item())
        numbering += param.data.numel()

if numbering == 0 or len_ == 0:
    # A model without mu_weight/sigma_weight parameters (presumably the plain `adv`
    # variant) would otherwise raise ZeroDivisionError here.
    print('no mu_weight/sigma_weight parameters found; skipping avg_mu/avg_sigma')
else:
    avg_mu = sum(mu_weight)/numbering
    avg_sigma = sum(sigma_weight)/len_
    print('avg_mu: {:.4f}, avg_sigma: {:.4f}'.format(avg_mu, avg_sigma))
