import pytest
from itertools import product
import numpy as np
import torch
from torch import nn
from torch.nn.utils import parameters_to_vector
from torch.utils.data import DataLoader, TensorDataset
from torch.distributions import Normal, Categorical

from laplace.curvature import AsdlGGN, AsdlEF, BackPackEF, BackPackGGN
from laplace.curvature.augmented_asdl import AugAsdlGGN
from laplace.laplace import FullLaplace, KronLaplace, DiagLaplace
from tests.utils import jacobians_naive


# Seed the global torch RNG so that randomly initialised models and the random
# data produced by the fixtures below are reproducible.
torch.manual_seed(240)
# Run all tests in double precision. ``torch.set_default_dtype`` is the
# supported replacement for the deprecated
# ``torch.set_default_tensor_type(torch.DoubleTensor)``.
torch.set_default_dtype(torch.float64)
# Laplace approximation classes that the tests below are parametrized over.
flavors = [FullLaplace, KronLaplace, DiagLaplace]


def get_grad(model):
    """Return the gradients of all parameters of ``model`` as one flat vector.

    Each parameter's ``.grad`` is flattened and the pieces are concatenated in
    the order produced by ``model.parameters()``. Every parameter must already
    hold a gradient (i.e. a backward pass has been run).
    """
    flat_grads = []
    for param in model.parameters():
        flat_grads.append(param.grad.flatten())
    return torch.cat(flat_grads)


@pytest.fixture
def model():
    """Small two-layer linear network (3 -> 20 -> 2) used by the tests.

    The returned module carries three extra attributes that the tests read:
    ``output_size`` (2), ``n_layers`` (number of parameter tensors) and
    ``n_params`` (total number of scalar parameters).
    """
    net = nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2))
    params = list(net.parameters())
    net.output_size = 2
    net.n_layers = len(params)  # number of parameter groups
    net.n_params = len(parameters_to_vector(params))
    return net


@pytest.fixture
def class_loader():
    """Loader over 10 random 3-d inputs with binary integer labels, batch size 3."""
    inputs = torch.randn(10, 3)
    labels = torch.randint(2, (10,))
    dataset = TensorDataset(inputs, labels)
    return DataLoader(dataset, batch_size=3)


@pytest.fixture
def reg_loader():
    """Loader over 10 random 3-d inputs with random 2-d real targets, batch size 3."""
    inputs = torch.randn(10, 3)
    targets = torch.randn(10, 2)
    dataset = TensorDataset(inputs, targets)
    return DataLoader(dataset, batch_size=3)


@pytest.fixture
def aug_class_loader():
    """Shuffled loader over 12 inputs of shape (7, 3) with binary integer labels.

    The second input dimension (size 7) is the augmentation dimension that the
    stochastic-gradient test averages over. Batch size is 3.
    """
    inputs = torch.randn(12, 7, 3)
    labels = torch.randint(2, (12,))
    dataset = TensorDataset(inputs, labels)
    return DataLoader(dataset, batch_size=3, shuffle=True)


@pytest.fixture
def aug_reg_loader():
    """Shuffled loader over 12 inputs of shape (7, 3) with random 2-d real targets.

    The second input dimension (size 7) is the augmentation dimension that the
    stochastic-gradient test averages over. Batch size is 3.
    """
    inputs = torch.randn(12, 7, 3)
    targets = torch.randn(12, 2)
    dataset = TensorDataset(inputs, targets)
    return DataLoader(dataset, batch_size=3, shuffle=True)


@pytest.mark.parametrize('laplace', flavors)
def test_laplace_init(laplace, model):
    """Each flavor can be constructed with a classification likelihood."""
    # Construction alone is the check; the instance itself is not needed.
    laplace(model, 'classification')


@pytest.mark.parametrize('laplace', flavors)
def test_laplace_invalid_likelihood(laplace, model):
    """An unknown likelihood name is rejected with ``ValueError``."""
    with pytest.raises(ValueError):
        laplace(model, 'otherlh')


@pytest.mark.parametrize('laplace', flavors)
def test_laplace_init_noise(laplace, model):
    """Check which ``sigma_noise`` values the constructor accepts and rejects."""
    # Accepted for regression: python float, 0-dim tensor, 1-dim tensor.
    accepted = (
        1.2,
        torch.tensor(1.2),
        torch.tensor(1.2).reshape(-1),
    )
    for noise in accepted:
        laplace(model, likelihood='regression', sigma_noise=noise)

    # With a classification likelihood, sigma_noise=1.2 must be rejected.
    with pytest.raises(ValueError):
        laplace(model, likelihood='classification', sigma_noise=1.2)

    # Rejected for regression: a tensor with more than one dimension and a
    # non-real datatype (only reals are supported).
    rejected = (
        torch.tensor(1.2).reshape(1, 1),
        '1.2',
    )
    for noise in rejected:
        with pytest.raises(ValueError):
            laplace(model, likelihood='regression', sigma_noise=noise)


@pytest.mark.parametrize('laplace', flavors)
def test_laplace_init_precision(laplace, model):
    """Check which ``prior_precision`` values the constructor accepts and rejects."""

    def build(prior_precision):
        # All cases use the same model and a regression likelihood.
        return laplace(model, likelihood='regression', prior_precision=prior_precision)

    # Accepted by every flavor: python float, 0-dim tensor, 1-dim tensor.
    for scalar_like in (10.6, torch.tensor(10.6), torch.tensor(10.7).reshape(-1)):
        build(scalar_like)

    # One precision per parameter (1-dim, param-shape).
    per_param = torch.tensor(10.7).reshape(-1).repeat(model.n_params)
    if laplace == KronLaplace:
        # Kron should not accept per parameter prior precision
        with pytest.raises(ValueError):
            build(per_param)
    else:
        build(per_param)

    # One precision per layer (1-dim, layer-shape) is accepted by every flavor.
    per_layer = torch.tensor(10.7).reshape(-1).repeat(model.n_layers)
    build(per_layer)

    # Everything else must fail: higher-dim tensor, a length that matches
    # neither layers nor parameters, and a non-real datatype.
    rejected = (
        torch.tensor(10.6).reshape(1, 1),
        torch.tensor(10.6).reshape(-1).repeat(17),
        '1.5',
    )
    for bad_precision in rejected:
        with pytest.raises(ValueError):
            build(bad_precision)


@pytest.mark.parametrize('laplace', flavors)
def test_laplace_init_prior_mean_and_scatter(laplace, model):
    """Check ``prior_mean`` handling and the resulting ``scatter`` term.

    For several equivalent ways of passing a prior mean of one, the stored
    ``prior_mean`` must match and ``scatter`` must equal the 0-dim value
    ``(theta - 1) * 1e-2 @ (theta - 1)``. Invalid prior means must raise
    ``ValueError``.
    """
    theta = parameters_to_vector(model.parameters())
    n_params = len(theta)

    def build(prior_mean):
        return laplace(model, 'classification',
                       prior_precision=1e-2, prior_mean=prior_mean)

    # (value passed as prior_mean, value expected on lap.prior_mean)
    cases = [
        (1., torch.tensor([1.])),                       # python scalar
        (torch.ones(1), torch.tensor([1.])),            # 1-dim, single element
        (torch.ones(1)[0], torch.tensor(1.)),           # 0-dim tensor
        (torch.ones(n_params), torch.ones(n_params)),   # one mean per parameter
    ]
    laps = []
    for prior_mean, stored in cases:
        lap = build(prior_mean)
        assert torch.allclose(lap.prior_mean, stored)
        laps.append(lap)

    expected = ((theta - 1) * 1e-2) @ (theta - 1)
    assert expected.ndim == 0
    for lap in laps:
        assert torch.allclose(lap.scatter, expected)
        assert lap.scatter.shape == expected.shape

    # Must be rejected: too many dims, unmatched length, invalid argument type.
    invalid = [
        torch.ones(n_params).unsqueeze(-1),
        torch.ones(n_params - 3),
        '72',
    ]
    for prior_mean in invalid:
        with pytest.raises(ValueError):
            build(prior_mean)


@pytest.mark.parametrize('laplace', flavors)
def test_laplace_init_temperature(laplace, model):
    """A float ``temperature`` passed to the constructor is stored unchanged."""
    temperature = 1.1
    lap = laplace(model, likelihood='classification', temperature=temperature)
    assert lap.temperature == temperature


@pytest.mark.parametrize('laplace,lh', product(flavors, ['classification', 'regression']))
def test_laplace_functionality(laplace, lh, model, reg_loader, class_loader):
    """End-to-end check of a fitted Laplace approximation.

    After fitting on the loader that matches ``lh`` the test verifies, in order:
    the recorded data/output sizes, the training log likelihood against
    ``torch.distributions``, the log marginal likelihood against a manual
    computation, the shape and mean of posterior samples, and the functional
    variance against an explicit ``J Sigma J^T`` product.
    """
    if lh == 'classification':
        loader = class_loader
        sigma_noise = 1.
    else:
        loader = reg_loader
        sigma_noise = 0.3
    lap = laplace(model, lh, sigma_noise=sigma_noise, prior_precision=0.7)
    lap.fit(loader)
    assert lap.n_data == len(loader.dataset)
    assert lap.n_outputs == model.output_size
    f = model(loader.dataset.tensors[0])
    y = loader.dataset.tensors[1]
    assert f.shape == torch.Size([10, 2])

    # Test log likelihood (Train)
    log_lik = lap.log_likelihood
    # compute true log lik
    if lh == 'classification':
        log_lik_true = Categorical(logits=f).log_prob(y).sum()
        assert torch.allclose(log_lik, log_lik_true)
    else:
        assert y.size() == f.size()
        log_lik_true = Normal(loc=f, scale=sigma_noise).log_prob(y).sum()
        assert torch.allclose(log_lik, log_lik_true)
        # change likelihood and test again
        lap.sigma_noise = 0.72
        log_lik = lap.log_likelihood
        log_lik_true = Normal(loc=f, scale=0.72).log_prob(y).sum()
        assert torch.allclose(log_lik, log_lik_true)

    # Test marginal likelihood
    # lml = log p(y|f) - 1/2 theta @ prior_prec @ theta
    #       + 1/2 logdet prior_prec - 1/2 log det post_prec
    lml = log_lik_true
    theta = parameters_to_vector(model.parameters()).detach()
    assert torch.allclose(theta, lap.mean)
    prior_prec = torch.diag(lap.prior_precision_diag)
    assert prior_prec.shape == torch.Size([len(theta), len(theta)])
    lml = lml - 1/2 * theta @ prior_prec @ theta
    if laplace == DiagLaplace:
        # The diagonal flavor's posterior precision is used element-wise here,
        # so its log-determinant is the sum of the element-wise logs.
        log_det_post_prec = lap.posterior_precision.log().sum()
    else:
        log_det_post_prec = lap.posterior_precision.logdet()
    lml = lml + 1/2 * (prior_prec.logdet() - log_det_post_prec)
    assert torch.allclose(lml, lap.log_marginal_likelihood())

    # test sampling
    # Re-seed so the drawn samples do not depend on earlier RNG consumption.
    torch.manual_seed(61)
    samples = lap.sample(n_samples=1)
    assert samples.shape == torch.Size([1, len(theta)])
    samples = lap.sample(n_samples=1000000)
    assert samples.shape == torch.Size([1000000, len(theta)])
    mu_comp = samples.mean(dim=0)
    mu_true = lap.mean
    # Deliberately loose tolerance: the empirical mean is a Monte Carlo estimate.
    assert torch.allclose(mu_comp, mu_true, rtol=1)

    # test functional variance
    if laplace == FullLaplace:
        Sigma = lap.posterior_covariance
    elif laplace == KronLaplace:
        Sigma = lap.posterior_precision.to_matrix(exponent=-1)
    elif laplace == DiagLaplace:
        Sigma = torch.diag(lap.posterior_variance)
    Js, f = jacobians_naive(model, loader.dataset.tensors[0])
    # Reference value: J Sigma J^T for every data point.
    true_f_var = torch.einsum('mkp,pq,mcq->mkc', Js, Sigma, Js)
    comp_f_var = lap.functional_variance(Js)
    assert torch.allclose(true_f_var, comp_f_var, rtol=1e-4)


@pytest.mark.parametrize('laplace', flavors)
def test_regression_predictive(laplace, model, reg_loader):
    """Check the shapes of the regression predictive for 'glm' and 'nn'.

    An unsupported ``pred_type`` must raise ``ValueError``; the GLM predictive
    mean must equal the network output.
    """
    lap = laplace(model, 'regression', sigma_noise=0.3, prior_precision=0.7)
    lap.fit(reg_loader)
    X, _ = reg_loader.dataset.tensors
    network_out = model(X)

    # 'linear' is not a valid predictive type
    with pytest.raises(ValueError):
        lap(X, pred_type='linear')

    # GLM predictive: full output covariance per data point. The functional
    # variance values themselves are covered by test_laplace_functionality.
    glm_mean, glm_var = lap(X, pred_type='glm')
    assert torch.allclose(glm_mean, network_out)
    n_points, n_outputs = glm_mean.shape[0], glm_mean.shape[1]
    assert glm_var.shape == torch.Size([n_points, n_outputs, n_outputs])
    assert len(glm_mean) == len(X)

    # NN predictive: only a diagonal variance estimate, same shape as the mean.
    nn_mean, nn_var = lap(X, pred_type='nn')
    assert nn_mean.shape == nn_var.shape
    assert nn_var.shape == torch.Size([nn_mean.shape[0], nn_mean.shape[1]])
    assert len(nn_mean) == len(X)


@pytest.mark.parametrize('laplace', flavors)
def test_classification_predictive(laplace, model, class_loader):
    """Check shape and normalisation of the classification predictive.

    Covers the GLM predictive with the 'mc', 'probit' and 'bridge' link
    approximations and the NN predictive. An unsupported ``pred_type`` must
    raise ``ValueError``.
    """
    lap = laplace(model, 'classification', prior_precision=0.7)
    lap.fit(class_loader)
    X, _ = class_loader.dataset.tensors
    probs = torch.softmax(model(X), dim=-1)

    # 'linear' is not a valid predictive type
    with pytest.raises(ValueError):
        lap(X, pred_type='linear')

    # Keyword arguments of each predictive call, in the order they are run.
    call_kwargs = (
        dict(pred_type='glm', link_approx='mc', n_samples=100),
        dict(pred_type='glm', link_approx='probit'),
        dict(pred_type='glm', link_approx='bridge'),
        dict(pred_type='nn', n_samples=100),
    )
    for kwargs in call_kwargs:
        f_pred = lap(X, **kwargs)
        assert f_pred.shape == probs.shape
        # Each row sums up to 1, so the grand total equals the number of rows.
        n_rows = torch.tensor(len(f_pred), dtype=torch.double)
        assert torch.allclose(f_pred.sum(), n_rows)


@pytest.mark.parametrize('laplace', flavors)
def test_regression_predictive_samples(laplace, model, reg_loader):
    """Check the shape of regression predictive samples for 'glm' and 'nn'."""
    lap = laplace(model, 'regression', sigma_noise=0.3, prior_precision=0.7)
    lap.fit(reg_loader)
    X, _ = reg_loader.dataset.tensors
    network_out = model(X)
    n_samples = 100
    expected_shape = torch.Size([n_samples, network_out.shape[0], network_out.shape[1]])

    # 'linear' is not a valid predictive type
    with pytest.raises(ValueError):
        lap(X, pred_type='linear')

    # GLM predictive first, then NN predictive; both return one output-shaped
    # sample per draw.
    for pred_type in ('glm', 'nn'):
        fsamples = lap.predictive_samples(X, pred_type=pred_type, n_samples=n_samples)
        assert fsamples.shape == expected_shape


@pytest.mark.parametrize('laplace', flavors)
def test_classification_predictive_samples(laplace, model, class_loader):
    """Check shape and normalisation of classification predictive samples.

    Both the 'glm' and the 'nn' predictive must return 100 samples shaped like
    the softmax output, with every sampled row summing to one. An unsupported
    ``pred_type`` must raise ``ValueError``.
    """
    lap = laplace(model, 'classification', prior_precision=0.7)
    lap.fit(class_loader)
    X, y = class_loader.dataset.tensors
    f = torch.softmax(model(X), dim=-1)

    # error
    with pytest.raises(ValueError):
        lap(X, pred_type='linear')

    # GLM predictive
    fsamples = lap.predictive_samples(X, pred_type='glm', n_samples=100)
    assert fsamples.shape == torch.Size([100, f.shape[0], f.shape[1]])
    assert np.allclose(fsamples.sum().item(), len(f) * 100)  # sum up to 1

    # NN predictive
    # The result must be bound to ``fsamples``: the assertions below previously
    # re-checked the GLM samples and never looked at the NN samples.
    fsamples = lap.predictive_samples(X, pred_type='nn', n_samples=100)
    assert fsamples.shape == torch.Size([100, f.shape[0], f.shape[1]])
    assert np.allclose(fsamples.sum().item(), len(f) * 100)  # sum up to 1


@pytest.mark.parametrize('curv_type,laplace',
                         product(['ggn', 'ef'], [FullLaplace, DiagLaplace, KronLaplace]))
def test_differentiable_marglik_backends_class(laplace, model, class_loader, curv_type):
    """Compare the differentiable marginal likelihood of the ASDL and BackPACK backends.

    For the given curvature type both backends must yield the same log marginal
    likelihood. Its gradient w.r.t. the model parameters is compared as well,
    except for GGN combined with the Diag or Kron flavor.
    """
    if curv_type == 'ef' and laplace is KronLaplace:
        # Report the untestable combination as skipped rather than letting it
        # count as a silent pass.
        pytest.skip("backpack doesn't have Kron-EF")
    if curv_type == 'ggn':
        ba, bb = AsdlGGN, BackPackGGN
    else:
        ba, bb = AsdlEF, BackPackEF
    backend_kwargs = dict(differentiable=True)

    def marglik_and_grad(backend):
        # Fit with the given backend, then backpropagate the log marginal
        # likelihood into freshly cleared model parameter gradients.
        lap = laplace(model, 'classification', backend=backend, backend_kwargs=backend_kwargs)
        lap.fit(class_loader)
        model.zero_grad()
        value = lap.log_marginal_likelihood()
        value.backward()
        return value, get_grad(model).clone()

    marglik, grad = marglik_and_grad(ba)
    marglikb, gradb = marglik_and_grad(bb)

    assert torch.allclose(marglik, marglikb)
    # Gradients are not compared for GGN with Diag/Kron -- presumably because
    # BackPackGGN is not fully differentiable there (see
    # test_differentiable_marglik_diag); TODO confirm.
    if not (curv_type == 'ggn' and laplace in [DiagLaplace, KronLaplace]):
        assert torch.allclose(grad, gradb)


@pytest.mark.parametrize('backend', [AsdlGGN, BackPackGGN])
def test_differentiable_marglik_diag(model, class_loader, backend):
    """Compare gradients through the diagonal of the posterior precision.

    The sum of the diagonal of the ``FullLaplace`` posterior precision must
    equal the sum of the ``DiagLaplace`` posterior precision. The gradients of
    both sums w.r.t. the model parameters must agree for ``AsdlGGN`` and are
    expected to differ for ``BackPackGGN``.
    """
    backend_kwargs = dict(differentiable=True)
    lap = FullLaplace(model, 'classification', backend=backend, backend_kwargs=backend_kwargs)
    lap.fit(class_loader)
    model.zero_grad()
    diag_posterior_prec = lap.posterior_precision.diagonal()
    pps = diag_posterior_prec.sum()
    pps.backward()
    grad = get_grad(model).clone()

    lap = DiagLaplace(model, 'classification', backend=backend, backend_kwargs=backend_kwargs)
    lap.fit(class_loader)
    model.zero_grad()
    ppsb = lap.posterior_precision.sum()
    ppsb.backward()
    gradb = get_grad(model).clone()

    assert torch.allclose(pps, ppsb)
    if backend == BackPackGGN:
        # BackPackGGN does not backpropagate differentiably when using GGN (at
        # least in parts incorrect), so the two gradients must NOT match.
        # Asserting the negation directly replaces catching the AssertionError
        # of our own assert.
        assert not torch.allclose(grad, gradb)
    elif backend == AsdlGGN:  # should work
        assert torch.allclose(grad, gradb)


@pytest.mark.parametrize('laplace,lh', product(flavors, ['classification', 'regression']))
def test_laplace_stochastic_gradient_estimate(laplace, lh, model, aug_reg_loader, aug_class_loader):
    """Check that the partially detached marginal likelihood gives an unbiased input gradient.

    Only ``KronLaplace`` is exercised; for the other flavors the body is a
    no-op. The gradient of the log marginal likelihood w.r.t. the inputs is
    computed once from a regular fit and then repeatedly from fits with
    ``only_diff_last=1``. The averaged stochastic gradients, rescaled by
    ``N / B``, must match the full gradient up to ``atol=0.1``.
    """
    model.zero_grad()
    torch.manual_seed(711)
    if lh == 'classification':
        loader = aug_class_loader
        sigma_noise = 1.
    else:
        loader = aug_reg_loader
        sigma_noise = 0.3

    # Differentiate w.r.t. the inputs stored in the dataset.
    x = loader.dataset.tensors[0]
    x.requires_grad = True

    if laplace == KronLaplace:
        lap = laplace(model, lh, backend=AugAsdlGGN, sigma_noise=sigma_noise, prior_precision=0.7)
        lap.fit(loader)
        lml = lap.log_marginal_likelihood()
        lml.backward()
        real_grad = x.grad.mean(1)  # <- aug dim

        last_lap = laplace(model, lh, backend=AugAsdlGGN, sigma_noise=sigma_noise, prior_precision=0.7)
        last_lap.fit(loader, only_diff_last=1)
        last_lml = last_lap.log_marginal_likelihood()
        x.grad.zero_()
        last_lml.backward()
        stoch_grad = x.grad.mean(1)  # <- aug dim

        # assert log marginal likelihood and partially detached log marginal likelihood are equal
        assert torch.allclose(lml, last_lml)

        # assert gradients and stochastic gradients are different
        assert not torch.allclose(real_grad, stoch_grad)

        # assert stochastic gradients are the same in expectation (tests unbiased estimate)
        N = len(loader.dataset)
        B = loader.batch_size
        n_rounds = 1000
        sum_stoch_grad = stoch_grad * 0.0
        for _ in range(n_rounds):
            # Reset the accumulated state so the instance can be fitted again.
            last_lap.H = None
            last_lap.loss = 0.0
            last_lap.fit(loader, only_diff_last=1)
            last_lml = last_lap.log_marginal_likelihood()

            x.grad.zero_()
            last_lml.backward()

            # find samples with gradient
            # ideally, we should be able to get these from the dataloader, but with these checks this should suffice
            samples_i = torch.nonzero(x.grad.std([1, 2])).flatten()
            assert len(samples_i) == B  # assert num samples with a grad equals the batch size
            # All remaining samples must have a gradient that is zero
            # everywhere. (The previous loop iterated over ``samples_i`` and
            # tested ``not i in samples_i``, so this check never executed.)
            without_grad = torch.ones(N, dtype=torch.bool)
            without_grad[samples_i] = False
            assert torch.all(x.grad[without_grad] == 0.0)

            # sum the stochastic gradients, averaged over the aug dim
            sum_stoch_grad += x.grad.mean(1)

        # average (/ n_rounds) and multiply with (N / B) to account for stochastic sampling
        # and obtain unbiased estimate
        exp_stoch_grad = sum_stoch_grad * (N / B) / n_rounds
        assert torch.allclose(exp_stoch_grad, real_grad, atol=0.1)
