import torch as t
import math
import pandas as pd
from uci_train_test import train, test
from models.uci_dgp import dgp_net
from models.uci_dwp import dwp_net
import jug

from data.UCI.uci import UCI


# Network constructors selectable through train_uci's ``model`` argument.
_NET_BUILDERS = {
    'dgp': dgp_net,
    'dwp': dwp_net,
}

# Column order of the metrics CSV written by train_uci.
_METRIC_COLUMNS = ('epoch', 'elbo', 'elbo_ll', 'elbo_KL', 'test_ll', 'test_RMSE', 'split')


def _kl_weight(epoch, temper_end):
    """Return the linear warm-up factor passed to ``train`` as ``L``.

    Grows as ``epoch / temper_end`` for early epochs and is capped at 1 from
    epoch ``temper_end`` onwards.
    """
    return min(epoch / temper_end, 1.)


@jug.TaskGenerator
def train_uci(output_fn, model, dataset, split, depth, lr=1e-2, gradient_steps=20000,
              inducing_batch=100, mbatch_size=10000, train_samples=10, test_samples=100, device='cpu', dtype=t.float64,
              seed=0, thin=True, width=None):
    """Train a 'dgp' or 'dwp' network on one UCI split and write metrics to CSV.

    Decorated with ``jug.TaskGenerator``, so calling it creates a jug task
    rather than running the training immediately.

    Args:
        output_fn: Path the metrics table is written to with ``DataFrame.to_csv``.
        model: ``'dgp'`` or ``'dwp'``; selects the network constructor.
        dataset: Dataset name passed to ``UCI``.
        split: Split index passed to ``UCI``; also stored in the ``split`` column.
        depth: Network depth passed to the constructor.
        lr: Adam learning rate. ``MultiStepLR`` (default gamma) decays it once,
            at epoch ``epochs // 2``.
        gradient_steps: Training budget. It is converted to a whole number of
            epochs using the number of minibatches per epoch.
        inducing_batch: Maximum number of rows from the first training batch
            used as inducing data/targets.
        mbatch_size: Minibatch size passed to ``UCI``.
        train_samples: Sample count passed to ``train``.
        test_samples: Sample count passed to ``test``.
        device: Torch device name for the network.
        dtype: Torch dtype for the network parameters.
        seed: Seed for ``torch.manual_seed``.
        thin: If True, only the final epoch is evaluated and recorded.
            If False, every epoch is.
        width: Hidden width. Defaults to the number of input features.

    Raises:
        ValueError: If ``model`` is not a supported model name.
    """
    # Fail fast, before loading data, instead of hitting an unbound ``net`` later.
    if model not in _NET_BUILDERS:
        raise ValueError(f"unknown model {model!r}; expected one of {sorted(_NET_BUILDERS)}")

    t.set_num_threads(1)
    device = t.device(device)
    t.manual_seed(seed)

    uci = UCI(dataset, split, mbatch_size)

    # Convert step budgets into epochs. This presumably assumes ``train`` takes
    # one optimizer step per minibatch; confirm against uci_train_test.train.
    batches_per_epoch = math.ceil(uci.num_train_set / mbatch_size)
    epochs = math.ceil(gradient_steps / batches_per_epoch)
    temper_end = math.ceil(1000 / batches_per_epoch)

    inducing_data, inducing_targets = next(iter(uci.trainloader))
    if inducing_data.shape[0] > inducing_batch:
        inducing_data = inducing_data[:inducing_batch]
        inducing_targets = inducing_targets[:inducing_batch]
    in_features = inducing_data.shape[-1]

    if width is None:
        width = in_features

    # NOTE(review): if the first batch has fewer than ``inducing_batch`` rows,
    # the constructor still receives ``inducing_batch``. Confirm it tolerates that.
    net = _NET_BUILDERS[model](inducing_batch, inducing_data, inducing_targets, width,
                               depth=depth, in_features=in_features)
    net = net.to(device=device, dtype=dtype)

    opt = t.optim.Adam(net.parameters(), lr=lr)
    scheduler = t.optim.lr_scheduler.MultiStepLR(opt, [epochs // 2])

    history = {column: [] for column in _METRIC_COLUMNS}

    for _epoch in range(epochs):
        _elbo, _elbo_ll, _elbo_KL = train(net, uci, opt, train_samples, device, dtype,
                                          L=_kl_weight(_epoch, temper_end))

        # A single condition, so the final epoch is never evaluated or recorded twice.
        record = not thin or _epoch + 1 == epochs
        if record:
            _test_ll, _test_RMSE = test(net, uci, test_samples, device, dtype)
            row = {
                'epoch': _epoch,
                'elbo': _elbo,
                'elbo_ll': _elbo_ll,
                'elbo_KL': _elbo_KL,
                'test_ll': _test_ll,
                'test_RMSE': _test_RMSE,
                'split': split,
            }
            for column, value in row.items():
                history[column].append(value)

        scheduler.step()
        print(f"epoch:{_epoch: 3d}, elbo:{_elbo:.2f}, elbo_ll:{_elbo_ll:.2f}, elbo_KL:{_elbo_KL:.2f}")

        # Rewrite the CSV only when a new row was added, which keeps a
        # partial record if the run is interrupted.
        if record:
            pd.DataFrame(history).to_csv(output_fn)
