import sys
import math
import torch as t
t.set_num_threads(1)
import argparse
import pandas as pd
import numpy as np
from torch.distributions import Normal
from timeit import default_timer as timer
from ap_spec import APSpec
import gc

import bnn
import models.fc_uci
import models.uci_resgp

import psutil
import os

from data.UCI.uci import UCI

# Command-line interface. Every argument has a default so the script can be
# started with no arguments at all, which helps debugging.
parser = argparse.ArgumentParser()
parser.add_argument('output_filename', type=str, help='output filename', nargs='?', default='test_protein_3_1000.csv')
parser.add_argument('--dataset',       type=str, help='UCI dataset: boston, concrete, energy, kin8nm, naval, power, protein, wine, yacht', nargs='?', default='boston')
# NOTE(review): `args.split` is not read anywhere in this file as shown; the
# driver loop at the bottom iterates over the splits itself -- confirm intent.
parser.add_argument('--split',         type=int, help='Split number', nargs='?', default=0)
parser.add_argument('--ap_lower',      type=str, help='variational family', nargs='?', default='gigp')
parser.add_argument('--ap_top',        type=str, help='variational family', nargs='?', default='gigp')
parser.add_argument('--model',         type=str, help='model', nargs='?', default='uciresgp')
parser.add_argument('--depth',         type=int, help='number of layers', nargs='?', default=2)
parser.add_argument('--seed',          type=int, help='random seed', nargs='?', default=0)
parser.add_argument('--test_samples',      type=int,   help='samples of the weights', nargs='?', default=100)
parser.add_argument('--train_samples',     type=int,   help='samples of the weights', nargs='?', default=1)
parser.add_argument('--device',            type=str,   default="cpu")
parser.add_argument('--dtype',            type=str,   default="float64")
parser.add_argument('--lr',            type=float,   default=1E-2)
# `--print` is a store_true flag that already defaults to True, so passing it
# changes nothing. It is kept so existing command lines keep working;
# `--noprint` writes False to the same destination and is the switch that
# actually turns the per-epoch log line off.
parser.add_argument('--print', action='store_true', default=True)
parser.add_argument('--noprint', dest='print', action='store_false', help='disable the per-epoch progress line')
# NOTE(review): `args.noshuffle` is not read anywhere in this file as shown.
parser.add_argument('--noshuffle', action='store_true')
# When set, the learning-rate scheduler is stepped once per epoch in run().
parser.add_argument('--step', action='store_true')
parser.add_argument('--steps', type=int, nargs='?')
args = parser.parse_args()

device = t.device(args.device)
# Resolve the dtype name (e.g. "float64") to the torch attribute of that name.
dtype  = getattr(t, args.dtype)

t.manual_seed(args.seed)

dataset = args.dataset

# Target number of gradient steps; run() converts this into a whole number of
# epochs based on the training-set size and the minibatch size.
gradient_steps = args.steps if (args.steps is not None) else 20000
# Upper bound on how many rows of the first minibatch run() keeps as inducing
# inputs/targets; also forwarded to the layers as `inducing_batch`.
inducing_batch = 100
# Minibatch size handed to the UCI data wrapper.
mbatch_size = 10000

# Result columns. run() appends one entry to each list per split (at the final
# epoch); the driver loop at the bottom of the file writes them to CSV.
epoch = []
elbo = []
elbo_ll = []
elbo_KL = []
test_ll = []
test_RMSE = []
splits = []


def train(net, uci, opt):
    """Make one pass over ``uci.trainloader``, stepping ``opt`` once per minibatch.

    For every minibatch the inputs are expanded along a new leading dimension
    of size ``args.train_samples``, pushed through ``net``, and the objective
    ``ll + bnn.logpq(net).mean() * batch_size / len(uci.trainset)`` is
    maximised (its negation is back-propagated).

    Returns a 3-tuple of Python floats:
      * the sum of the minibatch objectives divided by ``len(uci.trainset)``;
      * the sum of the minibatch log-likelihoods divided by
        ``len(uci.trainset)``;
      * the mean over minibatches of ``-bnn.logpq(net).mean() /
        len(uci.trainset)``.

    Raises ZeroDivisionError if the loader yields no minibatches.
    """
    num_train = len(uci.trainset)
    num_batches = 0
    elbo_sum = 0.
    ll_sum = 0.
    kl_sum = 0.

    for data, target in uci.trainloader:
        opt.zero_grad()

        data = data.to(device=device, dtype=dtype)
        target = target.to(device=device, dtype=dtype)

        # Add a leading dimension of size train_samples (a broadcast view, not a copy).
        sampled_inputs = data.expand(args.train_samples, *data.shape)
        Pf = net(sampled_inputs)
        # presumably log p(w) - log q(w) for the weights drawn in the forward
        # pass above -- semantics live in the project `bnn` module; confirm there.
        log_pq = bnn.logpq(net)

        # The predictive location must match the targets once the leading
        # sample dimension is dropped.
        assert target.shape == Pf.loc.shape[1:]

        # Average the log-probability over the sample dimension, then sum
        # over everything else in the minibatch.
        batch_ll = Pf.log_prob(target.unsqueeze(0)).mean(0).sum()
        log_pq_mean = log_pq.mean()
        # Scale the weight term by the fraction of the training set that this
        # minibatch represents.
        batch_elbo = batch_ll + log_pq_mean*target.shape[0]/num_train

        (-batch_elbo).backward()
        opt.step()

        num_batches += 1
        elbo_sum += batch_elbo.detach().item()
        ll_sum += batch_ll.detach().item()
        kl_sum -= log_pq_mean.detach().item()/num_train

    return (elbo_sum / num_train, ll_sum/num_train, kl_sum/num_batches)


def test(net, uci):
    """Evaluate ``net`` on ``uci.testloader`` with gradients disabled.

    Each minibatch is expanded along a new leading dimension of size
    ``args.test_samples`` before the forward pass. Predictions and targets
    are mapped through ``uci.denormalize_Py`` / ``uci.denormalize_y`` before
    any metric is computed.

    Returns ``(test_ll, test_RMSE)``:
      * ``test_ll`` -- sum over the test set of
        ``logsumexp(log_prob, dim 0) - log(test_samples)``, divided by the
        number of test points;
      * ``test_RMSE`` -- square root of the summed squared error between the
        targets and the sample-averaged predictive location, divided by the
        number of test points.
    """
    ll_total = 0
    se_total = 0

    with t.no_grad():
        num_test = len(uci.testloader.dataset)

        for data, target in uci.testloader:
            data = data.to(device=device, dtype=dtype)
            target = target.to(device=device, dtype=dtype)

            # Add a leading dimension of size test_samples (broadcast view).
            sampled_inputs = data.expand(args.test_samples, *data.shape)
            d_Py = uci.denormalize_Py(net(sampled_inputs))
            d_target = uci.denormalize_y(target)

            # Log-probability of each target under each sample's prediction.
            per_sample_ll = d_Py.log_prob(d_target).detach()
            # Log of the sample-averaged probability:
            # logsumexp over dim 0 minus log(number of samples).
            mixture_ll = t.logsumexp(per_sample_ll, 0) - math.log(args.test_samples)
            ll_total += mixture_ll.sum().item()/num_test

            # Point prediction: predictive location averaged over dim 0.
            mean_y = t.mean(d_Py.loc, 0, keepdim=True)
            squared_error = (d_target-mean_y)**2
            se_total += squared_error.sum().item()/num_test

    return (ll_total, math.sqrt(se_total))


def run(split):
    """Train a fresh model on one dataset split and record its final metrics.

    Builds the ``UCI`` data wrapper for ``split``, constructs the network
    selected by ``args.model``, and trains it with Adam for enough epochs to
    reach at least ``gradient_steps`` minibatch updates. After the last epoch
    the model is evaluated with ``test`` and one entry is appended to each
    module-level result list (``epoch``, ``elbo``, ``elbo_ll``, ``elbo_KL``,
    ``test_ll``, ``test_RMSE``, ``splits``). Nothing is returned.

    Raises KeyError if ``args.ap_lower``, ``args.ap_top`` or ``args.model``
    is not a key of the corresponding lookup table below.
    """
    uci = UCI(dataset, split, mbatch_size)

    # calculate number of epochs
    batches_per_epoch = math.ceil(uci.num_train_set/mbatch_size)
    epochs = math.ceil(gradient_steps/batches_per_epoch)
    print(epochs, flush=True)

    # Use (at most) the first `inducing_batch` rows of the first training
    # minibatch as inducing inputs/targets.
    first_batch = next(iter(uci.trainloader))
    inducing_data, inducing_targets = first_batch
    if inducing_data.shape[0] > inducing_batch:
        inducing_data = inducing_data[:inducing_batch]
        inducing_targets = inducing_targets[:inducing_batch]
    in_features = inducing_data.shape[-1]

    # Only the 'gi' / 'gigp' families are handed the inducing data.
    data_driven_families = ['gi', 'gigp']
    wants_inducing_data = (args.ap_lower in data_driven_families) or (args.ap_top in data_driven_families)
    if not wants_inducing_data:
        inducing_data = None
        inducing_targets = None

    # Per-family keyword arguments for the lower layers and for the top layer.
    lower_kwargs_by_family = {
        'gigp' : {'inducing_batch': inducing_batch},
        'ligp' : {'inducing_batch': inducing_batch},
    }
    kwargs_lower = lower_kwargs_by_family[args.ap_lower]
    top_kwargs_by_family = {
        'gigp' : {'log_prec_init': 0., 'inducing_targets': inducing_targets, 'inducing_batch': inducing_batch},
        'ligp' : {'inducing_batch': inducing_batch},
    }
    kwargs_top = top_kwargs_by_family[args.ap_top]
    ap_spec = APSpec(args.ap_lower, args.ap_top)

    # Width is the input dimensionality, capped at 30.
    channels = min(uci.in_features, 30)
    print(channels, flush=True)

    model_builders = {
        'uciresgp': models.uci_resgp.net
    }
    build_net = model_builders[args.model]
    net = build_net(ap_spec, inducing_data, in_features, args.depth, kwargs_lower, kwargs_top, channels=channels)
    net = net.to(device=device, dtype=dtype)

    opt = t.optim.Adam(net.parameters(), lr=args.lr)
    # Single learning-rate milestone halfway through training; it only takes
    # effect when --step is given, because the scheduler is stepped only then.
    scheduler = t.optim.lr_scheduler.MultiStepLR(opt, [epochs//2])

    last_epoch = epochs - 1
    for _epoch in range(epochs):
        start_time = timer()

        _elbo, _elbo_ll, _elbo_KL = train(net, uci, opt)

        # Evaluate and record results once, after the final epoch.
        if _epoch == last_epoch:
            _test_ll, _test_RMSE = test(net, uci)
            recorded = (
                (test_ll, _test_ll),
                (test_RMSE, _test_RMSE),
                (splits, split),
                (epoch, _epoch),
                (elbo, _elbo),
                (elbo_ll, _elbo_ll),
                (elbo_KL, _elbo_KL),
            )
            for column, value in recorded:
                column.append(value)

        # On the final epoch this duration also covers the evaluation above.
        time = timer() - start_time
        if args.step:
            scheduler.step()
        if args.print:
            print(f"epoch:{_epoch: 3d}, time:{time:.2f}, elbo:{_elbo:.2f}")


# Driver: the protein dataset is run on a single split, every other dataset
# on 20. After each split the newest value of every result column is echoed
# and the CSV is rewritten with everything collected so far, so partial
# results are on disk if a later split fails.
num_splits = 1 if args.dataset=="protein" else 20
for split in range(num_splits):
    run(split)

    results = {
        'epoch': epoch,
        'elbo': elbo,
        'elbo_ll': elbo_ll,
        'elbo_KL': elbo_KL,
        'test_ll': test_ll,
        'test_RMSE': test_RMSE,
        'split': splits,
    }

    # Echo the entry just appended to each column, in column order.
    for column in results.values():
        print(column[-1])

    pd.DataFrame(results).to_csv(args.output_filename)
