import os
import math
import torch as t
import torch.nn as nn
# Run CPU ops on a single thread and make cuDNN use deterministic algorithms
# without benchmark autotuning, trading speed for run-to-run reproducibility.
t.set_num_threads(1)
t.backends.cudnn.benchmark=False
t.backends.cudnn.deterministic=True
import torch.nn.functional as F
import argparse
import pandas as pd
from torch.distributions import Categorical
from torchvision import datasets, transforms
from timeit import default_timer as timer
from ap_spec import APSpec
from bnn.transforms import Identity, Blur


# Project-local modules: layer/prior/transform library and model definitions.
import bnn
import models.resnet8
import models.lenet

parser = argparse.ArgumentParser()
parser.add_argument('output_filename',     type=str,   help='output filename', nargs='?', default='test')
parser.add_argument('--dataset',           type=str,   help='dataset', nargs='?', default='mnist')
parser.add_argument('--ap_lower',          type=str,   help='method', nargs='?', default='gi')
parser.add_argument('--ap_top',            type=str,   help='method', nargs='?', default='gi')
parser.add_argument('--model',             type=str,   help='model', nargs='?', default='lenet')
parser.add_argument('--lr',                type=float, help='learning rate', nargs='?', default=1E-3)
parser.add_argument('--seed',              type=int,   help='seed', nargs='?', default=0)
parser.add_argument('--L',                 type=float, help='temperature scaling', nargs='?', default=1.)
parser.add_argument('--temperL',           action='store_true',  help='temper beta', default=False)
parser.add_argument('--test_samples',      type=int,   help='samples of the weights', nargs='?', default=1)
parser.add_argument('--test_runs',         type=int,   help='samples of the weights', nargs='?', default=1)
parser.add_argument('--train_samples',     type=int,   help='samples of the weights', nargs='?', default=1)
parser.add_argument('--transform_weights', type=str,   help='Identity or Blur', nargs='?', default="Identity")
parser.add_argument('--transform_inputs',  type=str,   help='Identity or Blur', nargs='?', default="Identity")
parser.add_argument('--prior',             type=str,   help='Prior', nargs='?', default="ScalePrior")
parser.add_argument('--device',            type=str,   help='Device', nargs='?', default="cuda")
parser.add_argument('--batch',             type=int,   help='Batch size', nargs='?', default=500)
parser.add_argument('--subset',            type=int,   help='subset of data size', nargs='?', default=500)
parser.add_argument('--epochs_per_period', type=int,   help='default uses subset', nargs='?')
parser.add_argument('--periods',           type=int,   help='number of periods (blocks of epochs)', nargs='?', default=400)
parser.add_argument('--lr_steps',          type=int,   help='learning rate steps', nargs='*')
parser.add_argument('--print', action='store_true', default=False)
parser.add_argument('--noshuffle', action='store_true', help='dont shuffle training data')
parser.add_argument('--aug', type=str, nargs='?', help='data augmentation, aug or noaug', default="noaug")
args = parser.parse_args()

device = args.device


# `--lr_steps` omitted entirely leaves None; normalise to an empty milestone list.
if args.lr_steps is None:
    args.lr_steps = []

# Default number of epochs per period. When training on a subset, scale it so
# one period covers roughly 50000 examples (epochs_per_period * subset ~ 50000).
# Clamp to at least 1: a subset larger than 50000 would otherwise give 0 epochs,
# the inner training loop would never run, and the per-period metrics recorded
# afterwards would be undefined.
if args.epochs_per_period is None:
    if args.subset is None:
        args.epochs_per_period = 1
    else:
        args.epochs_per_period = max(1, 50000 // args.subset)

ap_spec = APSpec(args.ap_lower, args.ap_top)
t.manual_seed(args.seed)

cifar10_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
cifar10_augment = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
])

mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

#MNIST data augment from https://www.kaggle.com/tonysun94/pytorch-1-0-1-on-mnist-acc-99-8
# No trailing comma: a trailing comma would make this a 1-tuple, which
# transforms.Compose would later try (and fail) to call as a transform.
mnist_augment = transforms.RandomAffine(degrees=45, translate=(0.1, 0.1), scale=(0.8, 1.2))


# Select dataset class, base transform and augmentation by name.
dataset, transform, augment = {
    'mnist'    : (datasets.MNIST,    mnist_transform, mnist_augment),
    'cifar10'  : (datasets.CIFAR10,  cifar10_transform, cifar10_augment),
}[args.dataset]
# Augmentation (if requested) runs on the raw image before ToTensor/Normalize.
transform_train = {
    'aug'   : transforms.Compose([augment, transform]),
    'noaug' : transform
}[args.aug]
train_dataset = dataset('data', train=True, download=True, transform=transform_train)
# `targets` may be a tensor or a plain list depending on the dataset class;
# as_tensor handles both and int() keeps num_classes a plain Python int.
num_classes = int(t.as_tensor(train_dataset.targets).max()) + 1
if args.subset is not None:
    train_dataset = t.utils.data.Subset(train_dataset, range(args.subset))
test_dataset  = dataset('data', train=False, transform=transform)
print(num_classes, flush=True)
print(len(train_dataset), flush=True)

train_loader = t.utils.data.DataLoader(train_dataset, batch_size=args.batch, shuffle=(not args.noshuffle))
test_loader = t.utils.data.DataLoader(test_dataset, batch_size=args.batch)
# Per-example input shape (trailing three dims of a batch).
in_shape = next(iter(train_loader))[0].shape[-3:]

# Inducing inputs come from one training batch; their labels are one-hot encoded.
# (Kept as a separate next(iter(...)) call so the seeded RNG stream, and thus
# results for a given --seed, match earlier runs.)
inducing_data, inducing_targets = next(iter(train_loader))
inducing_targets = (t.arange(num_classes) == inducing_targets[:, None]).float() 
# NOTE(review): fixed at 500 regardless of args.batch / args.subset, while
# inducing_data holds a single loader batch -- confirm InducingAdd expects
# these sizes to agree.
inducing_batch = 500

# Layer keyword arguments shared by every approximate-posterior type.
kwargs = {
    'transform_inputs'  : getattr(bnn.transforms, args.transform_inputs),
    'transform_weights' : getattr(bnn.transforms, args.transform_weights),
    'prior'             : getattr(bnn.priors, args.prior),
}
# Extra per-method kwargs for the lower layers and the top layer. The 'gi'
# entries reuse `inducing_batch` so they always agree with the
# InducingAdd/InducingRemove wrappers built below.
kwargs_lower = {
    'fac' : dict(kwargs),
    'gfac': dict(kwargs),
    'rand': dict(kwargs),
    'gi'  : dict(kwargs, log_prec_lr=3., inducing_batch=inducing_batch),
    'li'  : dict(kwargs, log_prec_lr=3.),
    'det' : dict(kwargs, unit_std=False)
}[args.ap_lower]
kwargs_top = {
    'fac' : dict(kwargs),
    'gfac': dict(kwargs),
    'rand': dict(kwargs),
    'gi'  : dict(kwargs, log_prec_lr=3., log_prec_init=0., inducing_targets=inducing_targets, inducing_batch=inducing_batch),
    'li'  : dict(kwargs, log_prec_lr=3., log_prec_init=0.),
    'det' : dict(kwargs, unit_std=False)
}[args.ap_top]


net = {
    'resnet8' : models.resnet8.net,
    'lenet' : models.lenet.net,
}[args.model](ap_spec, in_shape, num_classes, kwargs_lower, kwargs_top)
# Global-inducing ('gi') layers need the inducing inputs added before, and
# removed after, the network body.
if (args.ap_lower == 'gi') or (args.ap_top == 'gi'):
    net = nn.Sequential(
        bnn.InducingAdd(inducing_batch, inducing_data=inducing_data), 
        net, 
        bnn.InducingRemove(inducing_batch)
    )
net = net.to(device=device)


#initialize with a forward pass
# (unsqueeze(0) adds the leading sample dimension the net is fed during training.)
# NOTE(review): presumably this materialises lazily-created parameters, which is
# why the optimiser is only constructed afterwards -- confirm.
net(next(iter(train_loader))[0].to(device).unsqueeze(0))
opt = t.optim.Adam(net.parameters(), lr=args.lr)


# Per-period history, written to CSV at the end of training.
epoch = []
elbo = []
train_ll = []
train_KL = []
test_ll = []
train_correct = []
test_correct = []

def train(epoch):
    """Train ``net`` for one pass over ``train_loader`` and return epoch metrics.

    Each batch is replicated ``args.train_samples`` times along a new leading
    dimension before the forward pass. One optimiser step is taken per batch
    on the negative tempered ELBO, scaled by the training-set size.

    With ``--temperL`` and ``epoch < 100`` the weight on ``bnn.logpq(net)`` is
    annealed: 0 for epochs 1-10, 0.1/L for 11-20, ..., 0.9/L for 91-99.
    Otherwise, and always for the reported ELBO, the weight is ``1/args.L``.

    Args:
        epoch: 1-based epoch counter; only used by the annealing schedule.

    Returns:
        ``(elbo, ll, KL, accuracy)`` averaged over the training examples seen
        this epoch (each batch weighted by its size), with ``elbo == ll - KL``.
    """
    if args.temperL and epoch < 100:
        tempered_beta = 0.1 * math.floor((epoch - 1) / 10.) / args.L
    else:
        tempered_beta = 1 / args.L
    beta = 1 / args.L
    num_train = len(train_dataset)

    num_seen = 0
    total_elbo = 0.
    total_ll = 0.
    total_KL = 0.
    total_correct = 0.

    for data, target in train_loader:
        opt.zero_grad()
        data, target = data.to(device), target.to(device)
        data = data.expand(args.train_samples, *data.shape)
        outputs = net(data)
        logPQw = bnn.logpq(net)
        outputs = outputs.squeeze(-1).squeeze(-1)

        ll = Categorical(logits=outputs).log_prob(target).mean()
        nloss = ll + tempered_beta * logPQw.mean() / num_train  # tempered ELBO
        (-nloss * num_train).backward()
        opt.step()

        # Metrics only: no need to record these ops in the autograd graph.
        with t.no_grad():
            batch_ll = ll.item()
            batch_KL = -(beta * logPQw.mean() / num_train).item()
            # Log of the mean predictive probability over the sample dimension.
            output = outputs.log_softmax(-1).logsumexp(0) - math.log(outputs.shape[0])
            batch_correct = (output.argmax(dim=-1) == target).sum().item()

        # Weight by batch size so a smaller final batch is not over-counted.
        batch_size = target.shape[0]
        num_seen      += batch_size
        total_ll      += batch_ll * batch_size
        total_KL      += batch_KL * batch_size
        total_elbo    += (batch_ll - batch_KL) * batch_size
        total_correct += batch_correct

    return (total_elbo / num_seen, total_ll / num_seen, total_KL / num_seen, total_correct / num_seen)


def test():
    """Evaluate predictive log-likelihood and accuracy over ``test_loader``.

    For each batch the net is run ``args.test_runs`` times on a batch
    replicated ``args.test_samples`` times along a new leading dimension. All
    sampled outputs are concatenated along dim 0, and the prediction is the
    mean of their softmax probabilities (computed in log space).

    Returns:
        ``(ll, accuracy)`` averaged over all test examples. Example-level
        averaging keeps the result exact when the last batch is smaller.
    """
    num_seen = 0
    total_ll = 0.
    total_correct = 0.
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        data = data.expand(args.test_samples, *data.shape)

        #run net multiple times, average
        outputs = t.cat(
            [net(data).squeeze(-1).squeeze(-1) for _ in range(args.test_runs)],
            0,
        )

        # log( mean over samples of softmax(outputs) )
        output = outputs.log_softmax(-1).logsumexp(0) - math.log(outputs.shape[0])

        total_ll      += -F.nll_loss(output, target, reduction='sum').item()
        total_correct += (output.argmax(dim=-1) == target).sum().item()
        num_seen      += target.shape[0]

    return (total_ll / num_seen, total_correct / num_seen)


scheduler = t.optim.lr_scheduler.MultiStepLR(opt, args.lr_steps, gamma=0.1)

# Training schedule: `args.periods` periods, each consisting of
# `args.epochs_per_period` training epochs followed by one test evaluation.
# Metrics from the last epoch of each period are recorded, and the LR
# scheduler is stepped once per period (so `--lr_steps` counts periods).
completed_epochs = 0
for period in range(args.periods):
    start_time = timer()
    for _ in range(args.epochs_per_period):
        completed_epochs += 1
        _elbo, _train_ll, _train_KL, _train_correct = train(completed_epochs)

    with t.no_grad():
        _test_ll, _test_correct = test()

    history = zip(
        (epoch, elbo, train_ll, train_KL, train_correct, test_ll, test_correct),
        (completed_epochs, _elbo, _train_ll, _train_KL, _train_correct, _test_ll, _test_correct),
    )
    for column, value in history:
        column.append(value)

    time = timer() - start_time
    if args.print:
        print(f"{os.path.basename(args.output_filename):<32}, period:{period:03d}, time:{time: 3.1f}, elbo:{_elbo:.3f}, KL:{_train_KL:.3f}, ll:{_test_ll:.3f}, train_c:{_train_correct:.3f}, test_c:{_test_correct:.3f}", flush=True)

    scheduler.step()

# Per-period metric columns first, then constant run metadata (broadcast by pandas).
results = {
    'epoch' : epoch,
    'elbo' : elbo,
    'train_ll' : train_ll,
    'train_KL' : train_KL,
    'test_ll' : test_ll,
    'train_correct' : train_correct,
    'test_correct' : test_correct,
}
results.update(
    model=args.model,
    dataset=args.dataset,
    ap_lower=args.ap_lower,
    ap_top=args.ap_top,
    seed=args.seed,
)
pd.DataFrame(results).to_csv(args.output_filename)
