import math
import torch as t
t.set_num_threads(1)
import argparse
import pandas as pd
import numpy as np
from torch.distributions import Normal
from timeit import default_timer as timer
from ap_spec import APSpec

import bnn
import models.fc_uci
from data.UCI.uci import UCI

# Command-line interface. Every argument is optional (the positional one uses
# nargs='?' with a default) so the script can be launched bare while debugging.
parser = argparse.ArgumentParser()
parser.add_argument('output_filename', type=str, help='output filename', nargs='?', default='test')
parser.add_argument('--dataset',       type=str, help='UCI dataset: boston, concrete, energy, kin8nm, naval, power, protein, wine, yacht', nargs='?', default='boston')
parser.add_argument('--split',         type=int, help='Split number', nargs='?', default=1)
parser.add_argument('--ap_lower',      type=str, help='variational family', nargs='?', default='gi')
parser.add_argument('--ap_top',        type=str, help='variational family', nargs='?', default='gi')
parser.add_argument('--model',         type=str, help='model', nargs='?', default='fc')
parser.add_argument('--depth',         type=int, help='number of layers', nargs='?', default=2)
parser.add_argument('--seed',          type=int, help='random seed', nargs='?', default=0)
parser.add_argument('--prior',         type=str,   help='InsanePrior, NealPrior or ScalePrior', nargs='?', default="ScalePrior")
# `--print` is store_true with default=True, so on its own it can never be
# switched off. `--no-print` writes False into the same destination, which
# keeps existing invocations working while making the logging optional.
parser.add_argument('--print', dest='print', action='store_true', default=True,
                    help='print per-epoch progress (enabled by default)')
parser.add_argument('--no-print', dest='print', action='store_false',
                    help='disable per-epoch progress printing')
parser.add_argument('--steps', type=int, nargs='?',
                    help='override the default number of gradient steps')
args = parser.parse_args()

# The factorised top-layer families ('fac', 'facLR') run on the CPU and use
# their own hyperparameter grid below; compute the test once so both places
# stay in agreement.
factorised_top = args.ap_top in ('fac', 'facLR')

if factorised_top:
    device = t.device('cpu')
elif t.cuda.is_available():
    device = t.device('cuda')
else:
    # Without this fallback the script only fails later, when the network is
    # moved to a CUDA device that does not exist on this machine.
    if args.print:
        print('CUDA is not available; falling back to CPU.')
    device = t.device('cpu')

# Number of inducing points handed to the network, and the dtype used for the
# network and for every data batch.
inducing_batch = 500
dtype = t.float64

# Seed both torch and numpy RNGs from the command line.
t.manual_seed(args.seed)
np.random.seed(args.seed)

dataset = args.dataset
split = args.split

# Define hyperparameters: the step budget plus the minibatch-size and
# learning-rate grids that are searched exhaustively.
if factorised_top:
    gradient_steps = 25000
    mbatch_size_array = [32, 100, 500]
    lr_array = [3e-4, 1e-3, 3e-3, 1e-2]
else:
    gradient_steps = 10000
    mbatch_size_array = [10000]
    lr_array = [3e-3, 1e-2]

# An explicit --steps overrides the family-specific default budget.
if args.steps is not None:
    gradient_steps = args.steps

# Number of network samples per input during training and evaluation.
n_train_samples = 10
n_test_samples = 100


def train(epoch, uci):
    """Run one optimisation pass over ``uci.trainloader``.

    Uses the module-level ``net``, ``opt``, ``log_s2``, ``device``, ``dtype``
    and ``n_train_samples``. Each minibatch is replicated along a new leading
    dimension of size ``n_train_samples`` before being fed to the network,
    and one optimiser step is taken on the negated objective.

    Args:
        epoch: Index of the current epoch; not used inside the function.
        uci: Dataset wrapper providing ``trainloader`` and ``num_train_set``.

    Returns:
        A tuple ``(elbo, ll, KL)`` of per-step averages: the objective, the
        per-datapoint log-likelihood term, and the negated
        ``bnn.logpq(net)`` mean divided by ``uci.num_train_set``.
    """
    n_steps = 0
    elbo_sum = 0.
    ll_sum = 0.
    kl_sum = 0.

    for x, y in uci.trainloader:
        opt.zero_grad()
        x = x.to(device, dtype)
        y = y.to(device, dtype)
        # Add a leading sample dimension without copying the batch.
        x = x.expand(n_train_samples, *x.shape)
        pred = net(x)
        log_pq = bnn.logpq(net)

        # Gaussian likelihood with a learned noise variance exp(log_s2):
        # average over the sample dimension, then sum over the batch.
        noise_std = t.exp(0.5*log_s2())
        ll = Normal(pred, noise_std).log_prob(y).mean(0).sum()
        elbo = ll/y.shape[0] + log_pq.mean() / uci.num_train_set
        loss = -elbo
        loss.backward()
        opt.step()

        n_steps += 1
        elbo_sum += elbo.detach().item()
        ll_sum += ll.detach().item()/y.shape[0]
        kl_sum -= log_pq.mean().detach().item()/uci.num_train_set

    return (elbo_sum / n_steps, ll_sum/n_steps, kl_sum/n_steps)


def test(uci):
    """Evaluate the module-level ``net`` on ``uci.testloader``.

    Uses the module-level ``net``, ``log_s2``, ``device``, ``dtype`` and
    ``n_test_samples``. Each test batch is replicated ``n_test_samples``
    times along a new leading dimension. Targets and the predictive
    distribution are passed through ``uci.denormalize_y`` /
    ``uci.denormalize_Py`` before scoring (presumably undoing the target
    normalisation -- confirm against the ``UCI`` class).

    Args:
        uci: Dataset wrapper providing ``testloader``, ``denormalize_y`` and
            ``denormalize_Py``.

    Returns:
        A tuple ``(test_ll, test_RMSE)`` of Python floats: the mean
        per-datapoint predictive log-likelihood and the root-mean-square
        error of the sample-averaged prediction.

    Raises:
        ValueError: If the test loader yields no data points.
    """
    with t.no_grad():
        test_SE = 0
        test_ll = 0
        n_data = 0
        # The observation noise does not depend on the batch, so compute it
        # once. No .detach() is needed anywhere inside no_grad().
        noise_std = t.exp(0.5*log_s2())
        for data, target in uci.testloader:
            data, target = data.to(device, dtype), target.to(device, dtype)
            n_data += target.shape[0]
            data = data.expand(n_test_samples, *data.shape)
            output = net(data)

            d_target = uci.denormalize_y(target)
            d_Py = uci.denormalize_Py(Normal(output, noise_std))
            ind_ll = d_Py.log_prob(d_target)
            # log of the sample-averaged likelihood:
            # logsumexp over the sample dimension minus log(n_test_samples).
            test_ll += (t.logsumexp(ind_ll, 0) - math.log(n_test_samples)).sum()

            # Squared error of the prediction averaged over samples.
            mean_y = t.mean(d_Py.loc, 0, keepdim=True)
            test_SE += ((d_target-mean_y)**2).sum()

    if n_data == 0:
        # Otherwise the divisions below fail with an unhelpful error.
        raise ValueError('uci.testloader yielded no data; cannot compute test metrics')

    test_ll /= n_data
    test_SE /= n_data
    test_RMSE = t.sqrt(test_SE)
    return (test_ll.item(), test_RMSE.item())



# Grid search over (learning rate, minibatch size). One CSV of final-epoch
# metrics is written per configuration, followed by a CSV holding the final
# training ELBO of every configuration.
ELBOs_array = np.zeros([len(lr_array), len(mbatch_size_array)])

for lr_i, lr in enumerate(lr_array):
    for mbatch_i, mbatch_size in enumerate(mbatch_size_array):
        uci = UCI(args.dataset, args.split, mbatch_size)

        # Turn the gradient-step budget into a whole number of epochs.
        steps_per_epoch = math.ceil(uci.num_train_set / mbatch_size)
        epochs = math.ceil(gradient_steps / steps_per_epoch)

        # Take the first training minibatch as inducing points, truncated to
        # `inducing_batch` rows, or padded with standard-normal noise when
        # the batch is smaller than that.
        inducing_data, inducing_targets = next(iter(uci.trainloader))
        if mbatch_size > inducing_batch:
            inducing_data = inducing_data[:inducing_batch]
            inducing_targets = inducing_targets[:inducing_batch]
        if inducing_data.shape[0] < inducing_batch:
            data_pad = t.randn(inducing_batch - inducing_data.shape[0],
                               *inducing_data.shape[1:],
                               dtype=inducing_data.dtype)
            inducing_data = t.cat([inducing_data, data_pad], 0)
            target_pad = t.randn(inducing_batch - inducing_targets.shape[0],
                                 *inducing_targets.shape[1:],
                                 dtype=inducing_targets.dtype)
            inducing_targets = t.cat([inducing_targets, target_pad], 0)
        in_features = inducing_data.shape[-1]
        # Inducing points are only kept when one of the layers uses the
        # 'gi' or 'gigp' family.
        inducing_families = ['gi', 'gigp']
        if args.ap_lower not in inducing_families and args.ap_top not in inducing_families:
            inducing_data = None
            inducing_targets = None

        # Per-family constructor arguments. Every entry is a fresh dict that
        # starts from the shared prior setting.
        shared_kwargs = {'prior': getattr(bnn.priors, args.prior)}
        lower_options = {
            'fac': {**shared_kwargs},
            'facLR': {**shared_kwargs},
            'gi': {**shared_kwargs, 'log_prec_lr': 3, 'log_prec_init': -4.,
                   'inducing_batch': inducing_batch, 'neuron_prec': True},
            'li': {**shared_kwargs, 'log_prec_lr': 3,
                   'inducing_batch': inducing_batch, 'neuron_prec': True},
            'det': {**shared_kwargs},
            'gigp': {**shared_kwargs},
            'ligp': {**shared_kwargs},
        }
        top_options = {
            'fac': {**shared_kwargs},
            'facLR': {**shared_kwargs},
            'gi': {**shared_kwargs, 'log_prec_lr': 3., 'log_prec_init': 0.,
                   'inducing_targets': inducing_targets,
                   'inducing_batch': inducing_batch, 'neuron_prec': True},
            'li': {**shared_kwargs, 'log_prec_lr': 3., 'log_prec_init': 0.,
                   'inducing_batch': inducing_batch, 'neuron_prec': True},
            'det': {**shared_kwargs},
            'gigp': {**shared_kwargs},
            'ligp': {**shared_kwargs},
        }
        kwargs_lower = lower_options[args.ap_lower]
        kwargs_top = top_options[args.ap_top]
        ap_spec = APSpec(args.ap_lower, args.ap_top)

        net_factory = {
            'fc': models.fc_uci.net,
        }[args.model]
        net = net_factory(ap_spec, inducing_data, in_features, args.depth, kwargs_lower, kwargs_top)
        net = net.to(device=device, dtype=dtype)

        # Metrics recorded for this configuration (final epoch only).
        history = {
            'epoch': [],
            'elbo': [],
            'elbo_ll': [],
            'elbo_KL': [],
            'test_ll': [],
            'test_RMSE': [],
        }

        # The log noise variance is optimised through a rescaled parameter:
        # log_s2 = factor * log_s2_scaled, initialised so that log_s2 = -3.
        # NOTE(review): this tensor is created without an explicit dtype, so
        # it uses torch's default dtype rather than `dtype` -- confirm that
        # this is intended.
        factor = 10.
        log_s2_scaled = t.tensor(-3./factor, requires_grad=True, device=device)

        def log_s2():
            return factor*log_s2_scaled
        opt = t.optim.Adam([*net.parameters(), log_s2_scaled], lr=lr)

        for _epoch in range(epochs):
            start_time = timer()

            _elbo, _elbo_ll, _elbo_KL = train(_epoch, uci)

            # Record training metrics and evaluate only after the last epoch.
            if _epoch == epochs - 1:
                _test_ll, _test_RMSE = test(uci)
                history['epoch'].append(_epoch)
                history['elbo'].append(_elbo)
                history['elbo_ll'].append(_elbo_ll)
                history['elbo_KL'].append(_elbo_KL)
                history['test_ll'].append(_test_ll)
                history['test_RMSE'].append(_test_RMSE)

            time = timer() - start_time
            if args.print:
                print(f"epoch:{_epoch: 3d}, time:{time:.2f}, elbo:{_elbo:.3f}, elbo_ll: {_elbo_ll:.3f} KL:{_elbo_KL:.3f}")

        ELBOs_array[lr_i, mbatch_i] = _elbo

        results = pd.DataFrame({
            **history,
            'dataset': args.dataset,
            'split': args.split,
            'ap_lower': args.ap_lower,
            'ap_upper': args.ap_top,
            'model': args.model,
            'depth': args.depth,
            'seed': args.seed,
        })
        results_path = args.output_filename + '_lr{}_mb{}'.format(lr, mbatch_size)
        results.to_csv(results_path)

elbo_table = pd.DataFrame(ELBOs_array)
elbo_table.to_csv(args.output_filename + '_ELBOs')
