import math
import torch as t
import torch.nn as nn
t.set_num_threads(1)
import argparse
import pandas as pd
from torch.distributions import Normal, MultivariateNormal
from timeit import default_timer as timer
from ap_spec import APSpec

import bnn

# Command-line interface for the experiment.  Every option keeps its original
# name, type, default and nargs; only the help strings were corrected, since
# several were copy-pasted from neighbouring options and described the wrong
# thing.
parser = argparse.ArgumentParser()
parser.add_argument('output_filename', type=str, nargs='?', default='test',
                    help='path of the CSV file the per-epoch metrics are written to')
# The method names listed below are the keys accepted by the per-method
# keyword-argument tables later in this script.
parser.add_argument('--ap_lower', type=str, nargs='?', default='fac',
                    help='approximate posterior for the lower layers '
                         '(fac, gfac, rand, gi, li or det)')
parser.add_argument('--ap_top', type=str, nargs='?', default='gi',
                    help='approximate posterior for the top layer '
                         '(fac, gfac, rand, gi, li or det)')
parser.add_argument('--depth', type=int, nargs='?', default=2,
                    help='network depth')
parser.add_argument('--width', type=int, nargs='?', default=10,
                    help='network width')
parser.add_argument('--test_samples', type=int, nargs='?', default=10,
                    help='number of weight samples used for test evaluation')
# NOTE(review): --test_runs and --batch are parsed but not read anywhere in the
# visible part of this script; their help text is a best guess -- confirm.
parser.add_argument('--test_runs', type=int, nargs='?', default=1,
                    help='number of test runs')
parser.add_argument('--train_samples', type=int, nargs='?', default=1,
                    help='number of weight samples used per training step')
parser.add_argument('--prior', type=str, nargs='?', default="NealPrior",
                    help='name of the prior class in bnn.priors '
                         '(NealPrior or InsanePrior)')
parser.add_argument('--device', type=str, nargs='?', default="cpu",
                    help='torch device to run on, e.g. cpu or cuda')
parser.add_argument('--batch', type=int, nargs='?', default=500,
                    help='batch size')
parser.add_argument('--lr', type=float, nargs='?', default=1E-2,
                    help='learning rate')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed')
args = parser.parse_args()

# Approximate-posterior specification built from the two method names given on
# the command line.
ap = APSpec(args.ap_lower, args.ap_top)

# Seed torch's RNG before any random draw so the dataset is reproducible.
t.manual_seed(args.seed)

# Problem dimensions.
in_features, out_features = 5, 1
train_batch, test_batch = 1000, 100
inducing_batch = 10

s2 = 0.1  # noise variance
noise_std = math.sqrt(s2)

# Ground-truth linear map; entries have variance 1/in_features.
W = t.randn(in_features, out_features, device=args.device) / math.sqrt(in_features)

# Training set: Gaussian inputs, linear targets corrupted by Gaussian noise.
X_train = t.randn(train_batch, in_features, device=args.device)
y_train = t.matmul(X_train, W) + noise_std*t.randn(train_batch, out_features, device=args.device)

# Log-density of the training targets under a zero-mean Gaussian whose
# covariance is X X^T / in_features + s2 * I (the covariance implied by the
# weight scale and noise variance above), reported per datapoint.
mean = t.zeros(train_batch, device=args.device)
cov = X_train @ X_train.T/in_features + s2 * t.eye(train_batch, device=args.device)
Py = MultivariateNormal(loc=mean, covariance_matrix=cov)
logP = Py.log_prob(y_train.squeeze(-1)).item()/train_batch
print(logP)

# Test set, drawn the same way as the training set with the same W.
X_test = t.randn(test_batch, in_features, device=args.device)
y_test = t.matmul(X_test, W) + noise_std*t.randn(test_batch, out_features, device=args.device)

# Inducing points are only prepared when either part of the network uses the
# 'gi' method; they are the first `inducing_batch` training examples, on CPU.
if 'gi' in (args.ap_lower, args.ap_top):
    inducing_data = X_train[:inducing_batch, :].cpu()
    inducing_targets = y_train[:inducing_batch, :].cpu()
else:
    inducing_data = None
    inducing_targets = None

# Keyword arguments shared by every layer: the prior class, looked up by name
# on bnn.priors.
kwargs = {'prior': getattr(bnn.priors, args.prior)}

# Methods that take nothing beyond the shared arguments.
_plain_methods = ('fac', 'gfac', 'rand', 'det')

# Per-method keyword arguments for the lower layers; each entry is a fresh
# copy of the shared arguments plus any method-specific extras.
_lower_options = {method: {**kwargs} for method in _plain_methods}
_lower_options['gi'] = {**kwargs, 'log_prec_lr': 3., 'inducing_batch': inducing_batch}
_lower_options['li'] = {**kwargs, 'log_prec_lr': 3.}
kwargs_lower = _lower_options[args.ap_lower]

# Per-method keyword arguments for the top layer.  The 'gi' and 'li' variants
# additionally set log_prec_init, and 'gi' receives the inducing targets.
_top_options = {method: {**kwargs} for method in _plain_methods}
_top_options['gi'] = {
    **kwargs,
    'log_prec_lr': 3.,
    'log_prec_init': 0.,
    'inducing_targets': inducing_targets,
    'inducing_batch': inducing_batch,
}
_top_options['li'] = {**kwargs, 'log_prec_init': 0., 'log_prec_lr': 3.}
kwargs_top = _top_options[args.ap_top]


def net(inducing_inputs, inducing_targets):
    """Build the network described by the command-line arguments.

    With ``args.depth == 1`` the model is a single top layer mapping
    ``in_features`` to ``out_features``.  Otherwise it is an ``nn.Sequential``
    of one input layer, ``args.depth - 2`` hidden layers of size
    ``args.width`` and a top layer.  When ``args.ap_top == 'gi'`` the model is
    wrapped between ``bnn.InducingAdd`` and ``bnn.InducingRemove``.

    Args:
        inducing_inputs: inducing inputs handed to ``bnn.InducingAdd``; only
            read when ``args.ap_top == 'gi'``.
        inducing_targets: inducing targets passed to the top layer; only
            used when ``args.ap_top == 'gi'``.

    Returns:
        The constructed module (not yet moved to any device).
    """
    # Previously both parameters were ignored and the module-level globals
    # were read instead.  Use the arguments, so the function builds what its
    # caller asked for; for the call in this script the result is the same.
    top_kwargs = dict(kwargs_top)
    if args.ap_top == 'gi':
        top_kwargs['inducing_targets'] = inducing_targets

    if args.depth == 1:
        model = ap.top_linear(in_features, out_features, **top_kwargs)
    else:
        # Layers are constructed in input-to-output order, inside the call,
        # so any random initialisation consumes the RNG in a fixed order.
        model = nn.Sequential(
            ap.lower_linear(in_features, args.width, **kwargs_lower),
            *[ap.lower_linear(args.width, args.width, **kwargs_lower)
              for _ in range(args.depth - 2)],
            ap.top_linear(args.width, out_features, **top_kwargs),
        )

    if args.ap_top == 'gi':
        num_inducing = inducing_inputs.shape[0]
        model = nn.Sequential(
            bnn.InducingAdd(num_inducing, inducing_data=inducing_inputs),
            model,
            bnn.InducingRemove(num_inducing),
        )
    return model


# Build the model, move it to the requested device and set up the optimiser.
_net = net(inducing_data, inducing_targets).to(device=args.device)

opt = t.optim.Adam(_net.parameters(), lr=args.lr)

# Per-epoch metric histories, written to CSV at the end of the script.
_iter = []
_elbo = []
_train_ll = []
_test_ll = []
_train_rmse = []
_test_rmse = []

# Loop-invariant inputs: training uses the full training set every step, so
# the device transfer and the expansion over weight samples are done once.
# `expand` adds a leading dimension of size train_samples / test_samples.
data, target = X_train.to(args.device), y_train.to(args.device)
data = data.expand(args.train_samples, *data.shape)
data_test, target_test = X_test.to(args.device), y_test.to(args.device)
data_test = data_test.expand(args.test_samples, *data_test.shape)
noise_std = math.sqrt(s2)

for epoch in range(40):
    # Running sums over this epoch's iterations.
    iters = 0
    total_elbo = 0
    total_train_ll = 0
    total_test_ll = 0
    total_train_se = 0
    total_test_se = 0

    start_time = timer()
    for i in range(1000):
        opt.zero_grad()

        yhat = _net(data)
        # NOTE(review): presumably log p(w) - log q(w) for the weights sampled
        # in the forward pass above -- confirm against bnn.logpq.
        logPQw = bnn.logpq(_net)
        train_se = ((yhat - target)**2).mean()
        train_ll = Normal(yhat, noise_std).log_prob(target).mean()
        # Objective on a per-datapoint scale: the mean data log-likelihood
        # plus the weight term divided by the training-set size.
        elbo = train_ll + logPQw.mean() / train_batch
        (-elbo).backward()
        opt.step()

        # Evaluation only: nothing is back-propagated through the test
        # metrics, so skip building an autograd graph for this forward pass.
        with t.no_grad():
            yhat = _net(data_test)
            test_se = ((yhat - target_test)**2).mean()
            test_ll = Normal(yhat, noise_std).log_prob(target_test).mean()

        iters          += 1
        total_elbo     += elbo.item()
        total_train_ll += train_ll.item()
        total_test_ll  += test_ll.item()
        total_train_se += train_se.item()
        total_test_se  += test_se.item()

    epoch_time = timer() - start_time

    # `epoch*iters` is the number of iterations completed before this epoch
    # began (0 for the first epoch).
    _iter.append(epoch*iters)
    _elbo.append(total_elbo/iters)
    _train_ll.append(total_train_ll/iters)
    _test_ll.append(total_test_ll/iters)
    _train_rmse.append(math.sqrt(total_train_se/iters))
    _test_rmse.append(math.sqrt(total_test_se/iters))

    print((epoch_time, total_train_ll/iters, total_elbo/iters))

# The 'train_se' / 'test_se' columns hold root-mean-squared errors; the column
# names are kept as they are so existing consumers of the CSV keep working.
pd.DataFrame({
    'iter' : _iter,
    'elbo' : _elbo,
    'train_ll' : _train_ll,
    'test_ll'  : _test_ll,
    'train_se' : _train_rmse,
    'test_se'  : _test_rmse,
}).to_csv(args.output_filename)
