import torch as t
import layers as l
import math


def train(net, uci, opt, train_samples, device, dtype, L=1.):
    """Run one pass over ``uci.trainloader``, taking an optimiser step per batch.

    Each minibatch is replicated ``train_samples`` times along a new leading
    axis and pushed through ``l.propagate``. The per-batch objective is::

        elbo = ll + L * mean(logpq) * batch_size / len(uci.trainset)

    and ``-elbo`` is minimised with ``opt``.

    Args:
        net: model passed to ``l.propagate``.
        uci: object exposing ``trainset`` (sized) and ``trainloader``
            (iterable of ``(data, target)`` pairs).
        opt: optimiser providing ``zero_grad()`` and ``step()``.
        train_samples: number of copies of each batch along the leading axis.
        device, dtype: forwarded to ``Tensor.to`` for both data and target.
        L: weight on the ``logpq`` term in the optimised objective. It is
            not applied to the reported KL value.

    Returns:
        ``(elbo_sum / N, ll_sum / N, kl_mean)``, where ``N = len(uci.trainset)``,
        the first two are sums of the per-batch scalars, and ``kl_mean`` is the
        average over batches of ``-mean(logpq) / N``. ``logpq`` is presumably
        log p(w) - log q(w), so ``kl_mean`` would be a per-datapoint KL
        estimate; confirm against ``layers.propagate``.

    Raises:
        ZeroDivisionError: if the loader yields no batches.
    """
    n_train = len(uci.trainset)
    n_batches = 0
    elbo_sum = 0
    ll_sum = 0
    kl_sum = 0

    for inputs, targets in uci.trainloader:
        opt.zero_grad()

        inputs = inputs.to(device=device, dtype=dtype)
        targets = targets.to(device=device, dtype=dtype)
        # One replica of the batch per sample, stacked on a new leading axis.
        inputs = inputs.expand(train_samples, *inputs.shape)

        output, logpq, _ = l.propagate(net, inputs)
        # Outputs carry an extra leading sample axis relative to the targets.
        assert targets.shape == output.loc.shape[1:]

        # Average the log-likelihood over samples, then sum it over the batch.
        ll = output.log_prob(targets.unsqueeze(0)).mean(0).sum()
        # Scale the logpq term by the batch's share of the training set.
        elbo = ll + L * logpq.mean() * targets.shape[0] / n_train

        (-elbo).backward()
        opt.step()

        n_batches += 1
        elbo_sum += elbo.item()
        ll_sum += ll.item()
        kl_sum -= logpq.mean().item() / n_train

    return elbo_sum / n_train, ll_sum / n_train, kl_sum / n_batches


def test(net, uci, test_samples, device, dtype):
    with t.no_grad():
        test_SE = 0
        test_ll = 0
        lenX = len(uci.testloader.dataset)

        for data, target in uci.testloader:
            data, target = data.to(device=device, dtype=dtype), target.to(device=device, dtype=dtype)
            data = data.expand(test_samples, *data.shape)
            Py = net(data)
            d_Py = uci.denormalize_Py(Py)
            d_target = uci.denormalize_y(target)
            ind_ll = d_Py.log_prob(d_target).detach()

            test_ll += (t.logsumexp(ind_ll, 0) - math.log(test_samples)).sum().item()/lenX

            mean_y = t.mean(d_Py.loc, 0, keepdim=True)
            test_SE += ((d_target - mean_y)**2).sum().item()/lenX

        test_RMSE = math.sqrt(test_SE)
    return test_ll, test_RMSE
