# Taken from https://github.com/AlexImmer/Laplace/blob/main/laplace/marglik_training.py
# and modified for differentiable learning of invariances/augmentation strategies.
import os
import os.path as osp
from copy import deepcopy
import logging
import numpy as np
import torch
from torch.nn.utils.convert_parameters import vector_to_parameters
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ExponentialLR, CosineAnnealingLR
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn.utils import parameters_to_vector, clip_grad_norm_

from laplace import KronLaplace
from laplace.curvature import AsdlGGN

from utils import edge_homophily, _is_norm, kl_bernoulli

# Root logger (no name is passed to getLogger). Log calls in this module attach
# a ``repeat`` field through ``extra``.
logger = logging.getLogger()


def validation(model, test_loader, likelihood, device, valid_mask, test_mask, n_samples=1):
    """Evaluate ``model`` on the validation and test subsets of ``test_loader``.

    The model is switched to eval mode and, for every batch, run ``n_samples``
    times without gradient tracking. Losses are computed on the outputs averaged
    over those runs; classification accuracy uses the averaged softmax
    probabilities.

    Parameters
    ----------
    model : torch.nn.Module
        called as ``model(X)`` for every batch
    test_loader : iterable
        yields ``(X, y)`` tensor pairs
    likelihood : str
        'classification' selects cross-entropy and accuracy; any other value
        is treated as regression (squared error for both loss and performance)
    device : torch.device or str
        device the batches are moved to
    valid_mask : torch.Tensor or None
        mask selecting the validation samples of a batch; its ``sum()`` is used
        as the number of selected samples. ``None`` selects every sample.
    test_mask : torch.Tensor or None
        same as ``valid_mask`` for the test samples
    n_samples : int
        number of forward passes per batch

    Returns
    -------
    tuple of float
        ``(valid_loss, valid_perf, test_loss, test_perf)``, each averaged over
        the selected samples. ``nan`` is returned for an empty selection.
    """
    def _ratio(total, count):
        # An empty selection has no meaningful average.
        if count == 0:
            return float('nan')
        return float(total / count)

    # ``slice(None)`` keeps every row, so a missing mask means "use all samples".
    valid_idx = slice(None) if valid_mask is None else valid_mask
    test_idx = slice(None) if test_mask is None else test_mask

    is_classification = likelihood == 'classification'
    if is_classification:
        loss_fn = CrossEntropyLoss(reduction='sum')
    else:
        loss_fn = MSELoss(reduction='sum')

    loss_test, perf_test = 0, 0
    loss_valid, perf_valid = 0, 0
    n_seen = 0  # total number of samples; the denominator when a mask is None
    model.eval()
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.detach().to(device), y.detach().to(device)
            f = torch.stack([model(X).detach() for _ in range(n_samples)], dim=0)
            f_mean = f.mean(dim=0)
            n_seen += y.shape[0]
            if is_classification:
                # predictions come from the sample-averaged probabilities
                pred = torch.softmax(f, dim=-1).mean(dim=0).argmax(dim=-1)
                perf_test += (pred[test_idx] == y[test_idx]).sum()
                perf_valid += (pred[valid_idx] == y[valid_idx]).sum()
            else:
                perf_test += (f_mean[test_idx] - y[test_idx]).square().sum()
                perf_valid += (f_mean[valid_idx] - y[valid_idx]).square().sum()
            loss_test += loss_fn(f_mean[test_idx], y[test_idx])
            loss_valid += loss_fn(f_mean[valid_idx], y[valid_idx])

    n_test = n_seen if test_mask is None else test_mask.sum()
    n_valid = n_seen if valid_mask is None else valid_mask.sum()
    return (_ratio(loss_valid, n_valid), _ratio(perf_valid, n_valid),
            _ratio(loss_test, n_test), _ratio(perf_test, n_test))



def get_scheduler(scheduler, optimizer, train_loader, n_epochs, lr, lr_min):
    """Create a learning-rate scheduler that is stepped once per batch.

    Parameters
    ----------
    scheduler : str
        'exp' for exponential decay or 'cos' for cosine annealing
    optimizer : torch.optim.Optimizer
        optimizer whose learning rate is scheduled
    train_loader : sized
        only its length (batches per epoch) is used
    n_epochs : int
        number of epochs; the horizon is ``n_epochs * len(train_loader)`` steps
    lr : float
        initial learning rate
    lr_min : float
        learning rate to reach at the end of the horizon

    Returns
    -------
    ExponentialLR or CosineAnnealingLR

    Raises
    ------
    ValueError
        if ``scheduler`` is neither 'exp' nor 'cos'
    """
    total_steps = n_epochs * len(train_loader)

    if scheduler == 'cos':
        return CosineAnnealingLR(optimizer, total_steps, eta_min=lr_min)

    if scheduler != 'exp':
        raise ValueError(f'Invalid scheduler {scheduler}')

    # Per-step factor chosen so that lr * gamma ** total_steps == lr_min.
    decay_ratio = lr_min / lr
    per_step_gamma = np.exp(np.log(decay_ratio) / total_steps)
    return ExponentialLR(optimizer, gamma=per_step_gamma)


def gradient_to_vector(parameters):
    """Flatten the ``.grad`` tensors of ``parameters`` into one 1-D tensor.

    The gradients are concatenated in iteration order via
    ``parameters_to_vector``.
    """
    grads = []
    for param in parameters:
        grads.append(param.grad)
    return parameters_to_vector(grads)


def marglik_optimization(model,
                         train_loader,
                         marglik_loader=None,
                         valid_loader=None,
                         likelihood='classification',
                         partial_loader=None,
                         train_mask=None,
                         marglik_mask=None,
                         partial_mask=None,
                         valid_mask=None,
                         test_mask=None,
                         n_epochs=500,
                         lr=1e-3,
                         weight_decay=0,
                         lr_min=None,
                         optimizer='Adam',
                         scheduler='exp',
                         n_epochs_burnin=0,
                         n_hypersteps=100,
                         marglik_frequency=1,
                         lr_graph=1e-2,
                         lr_graph_min=None,
                         graph_grad_norm=False,
                         laplace=KronLaplace,
                         backend=AsdlGGN,
                         differentiable=True,
                         early_stop_crit='marglik',
                         graph_opt_kwargs={},
                         graph_prior=None,
                         prior_edge_index=None,
                         prior_edge_probs=None,
                         log_frequency=20,
                         graph_kl_weight=1.,
                         log_det_weight=1.,
                         n_samples=10,
                         checkpoint_dir=None,
                         repeat=None):
    """Train a model while learning its graph by differentiating the marginal likelihood.

    Every epoch first trains the non-graph parameters on ``train_loader``. Every
    ``marglik_frequency`` epochs (once ``n_epochs_burnin`` is reached) a Laplace
    approximation is fitted and the parameters of ``model.graph_builder`` are
    updated for ``n_hypersteps`` steps on the negative log marginal likelihood
    (plus an optional KL term to a Bernoulli prior graph). The model state with
    the best ``early_stop_crit`` value is restored at the end.

    Parameters
    ----------
    model : torch.nn.Module
        torch model. It must expose a ``graph_builder`` sub-module, the methods
        ``disable_graph_builder_grad()`` / ``enable_graph_builder_grad()`` and
        the attributes ``discrete_adj`` and ``discrete_adj_prob``.
    train_loader : DataLoader
        pytorch training dataset loader
    marglik_loader : DataLoader
        pytorch training dataset loader for marglik opt.
        defaults to train_loader
    valid_loader : DataLoader
        loader passed to ``validation``; defaults to train_loader
    likelihood : str
        'classification' or 'regression'
    partial_loader : DataLoader
        pytorch training dataset loader for partial marglik opt
        defaults to marglik_loader
    train_mask : torch.Tensor
        mask for training data
    marglik_mask : torch.Tensor
        mask used when fitting the Laplace approximation, defaults to train_mask
    partial_mask : torch.Tensor
        defaults to marglik_mask
    valid_mask : torch.Tensor
        mask of the validation samples, passed to ``validation``
    test_mask : torch.Tensor
        mask of the test samples, passed to ``validation``
    n_epochs : int
    lr : float
        learning rate for model optimizer
    weight_decay : float
        weight decay of the model optimizer
    lr_min : float
        minimum learning rate, defaults to lr and hence no decay
        to have the learning rate decay from 1e-3 to 1e-6, set
        lr=1e-3 and lr_min=1e-6.
    optimizer : str
        accepted for interface compatibility but currently ignored:
        Adam is always used for the model parameters.
    scheduler : str
        either 'exp' for exponential and 'cos' for cosine decay towards lr_min
    n_epochs_burnin : int default=0
        how many epochs to train without estimating and differentiating marglik
    n_hypersteps : int
        how many steps to take on the graph parameters when marglik is estimated
    marglik_frequency : int
        how often to estimate (and differentiate) the marginal likelihood
    lr_graph : float
        learning rate for graph optimizer
    lr_graph_min: float
        minimum learning rate for graph optimizer, defaults to lr_graph
    graph_grad_norm : bool
        whether or not to clip the graph gradients to norm 1
    laplace : Laplace
        type of Laplace approximation (Kron/Diag/Full)
    backend : Backend
        AsdlGGN/AsdlEF or BackPackGGN/BackPackEF
        only AsdlGGN is currently supported.
    differentiable : bool
        forwarded to the backend as ``backend_kwargs``
    early_stop_crit: str
        'marglik', 'valid_loss' or 'valid_acc'
        which criterion to use for selecting the best model
    graph_opt_kwargs : dict
        extra keyword arguments for the SGD graph optimizer
    graph_prior: str
        'bernoulli' or None
        only Bernoulli prior is supported. None means no prior used for KL.
    prior_edge_index: torch.Tensor
        edge index of the prior graph
    prior_edge_probs: torch.Tensor
        edge probabilities of the prior graph
    log_frequency : int
        log every ``log_frequency`` epochs
    graph_kl_weight: float
        weight for the KL divergence term in the loss function
        'alpha' in the paper
    log_det_weight: float
        weight for the log determinant term in the marglik
        'beta' in the paper
    n_samples: int
        number of MC samples
    checkpoint_dir: str
        directory to save the checkpoints of the learned graph
    repeat :
        value attached to every log record as ``extra={'repeat': repeat}``

    Returns
    -------
    lap : Laplace
        laplace approximation fitted on the restored best model
    model : torch.nn.Module
        model with the best state loaded
    best_model_stats : dict
        'valid_loss', 'valid_acc', 'test_loss', 'test_acc' and 'marglik'
        recorded at the best epoch

    Raises
    ------
    ValueError
        for an unsupported backend, likelihood, scheduler or early_stop_crit
    """
    if checkpoint_dir is not None:
        os.makedirs(checkpoint_dir, exist_ok=True)
    if laplace == KronLaplace and backend == AsdlGGN and any(_is_norm(m) for m in model.modules()):
        logger.warning('Ignoring norm parameters as AsdlGGN with Kron does not support it.')
        model.ignore_norm_params = True
    if backend != AsdlGGN:
        raise ValueError('Only AsdlGGN backend is currently supported.')
    if lr_min is None:  # don't decay lr
        lr_min = lr
    if lr_graph_min is None:  # don't decay lr
        lr_graph_min = lr_graph
    if marglik_loader is None:
        marglik_loader = train_loader
    if partial_loader is None:
        partial_loader = marglik_loader
    if valid_loader is None:
        valid_loader = train_loader
    marglik_mask = train_mask if marglik_mask is None else marglik_mask
    partial_mask = marglik_mask if partial_mask is None else partial_mask
    device = parameters_to_vector(model.parameters()).device
    N = len(train_loader.dataset) if train_mask is None else train_mask.sum()
    backend_kwargs = dict(differentiable=differentiable)

    # set up loss
    if likelihood == 'classification':
        criterion = CrossEntropyLoss(reduction='mean')
    elif likelihood == 'regression':
        criterion = MSELoss(reduction='mean')
    else:
        raise ValueError(f'Invalid likelihood {likelihood}')

    # set up model optimizer and scheduler (graph_builder parameters are excluded;
    # they get their own optimizer below)
    optimizer = Adam([v for k, v in model.named_parameters() if 'graph_builder' not in k], lr=lr, weight_decay=weight_decay)
    scheduler = get_scheduler(scheduler, optimizer, train_loader, n_epochs, lr, lr_min)

    # set up hyperparameter optimizer
    optimize_graph = parameters_to_vector(model.graph_builder.parameters()).requires_grad
    if optimize_graph:
        logger.info('MARGLIK: optimize adjacency matrix.', extra={'repeat': repeat})
        graph_optimizer = SGD(model.graph_builder.parameters(),
                              lr=lr_graph, momentum=0.9, **graph_opt_kwargs)
        graph_scheduler = get_scheduler('exp', graph_optimizer, train_loader,
                                        n_epochs // marglik_frequency * n_hypersteps,
                                        lr_graph, lr_graph_min)

    best_epoch = None
    best_model_dict = None
    criteria = {
        'marglik': {
            'best': np.inf,
            'log': list(),
            'check': lambda: criteria['marglik']['log'][-1] < criteria['marglik']['best'],
        },
        'valid_loss': {
            'best': np.inf,
            'log': list(),
            'check': lambda: criteria['valid_loss']['log'][-1] < criteria['valid_loss']['best'],
        },
        'valid_acc': {
            'best': 0.0,
            'log': list(),
            'check': lambda: criteria['valid_acc']['log'][-1] > criteria['valid_acc']['best'],
        }
    }
    # Fail before training instead of with a KeyError after the first epoch.
    if early_stop_crit not in criteria:
        raise ValueError(f'Invalid early_stop_crit {early_stop_crit}; expected one of {sorted(criteria)}.')
    best_model_stats = {
        'valid_loss': np.inf,
        'valid_acc': 0.0,
        'test_loss': np.inf,
        'test_acc': 0.0,
        'marglik': np.inf,
    }
    losses = list()

    if not optimize_graph:
        n_hypersteps = 1
        n_epochs_burnin = 0

    if optimize_graph and graph_prior == 'bernoulli':
        # dense symmetric matrix of prior edge probabilities
        prior_edge_index, prior_edge_probs = prior_edge_index.to(device), prior_edge_probs.to(device)
        n_nodes = prior_edge_index.max().item() + 1
        prior_adj = torch.zeros((n_nodes, n_nodes), device=device)
        prior_adj[prior_edge_index[0], prior_edge_index[1]] = prior_edge_probs
        prior_adj[prior_edge_index[1], prior_edge_index[0]] = prior_edge_probs

    for epoch in range(1, n_epochs + 1):
        epoch_loss = 0
        epoch_perf = 0

        model.disable_graph_builder_grad()  # no adj updates during burnin
        # standard NN training per batch
        torch.cuda.empty_cache()
        model.train()
        for X, Y in train_loader:
            optimizer.zero_grad()
            # Fresh list per batch: losses of a previous batch belong to a graph
            # that was already freed by its backward pass.
            sample_losses = list()
            for _ in range(n_samples):
                X, Y = X.detach().to(device), Y.to(device)
                f = model(X)
                if train_mask is not None:
                    f = f[train_mask]
                    y = Y[train_mask]
                else:
                    y = Y
                loss = criterion(f, y)
                sample_losses.append(loss)
            loss = torch.stack(sample_losses).mean()
            loss.backward()

            if epoch == 1:
                n_edges = torch.sum(model.discrete_adj)
                homophily = edge_homophily(model.discrete_adj, Y)
                logger.info(f'GRAPH[epoch={epoch}]: edges={n_edges}, homophily={homophily:.4f}.',
                            extra={'repeat': repeat})

            optimizer.step()
            epoch_loss += loss.cpu().item() / len(train_loader)
            # performance is measured on the last of the n_samples forward passes
            if likelihood == 'regression':
                epoch_perf += (f.detach() - y).square().sum() / N
            else:
                epoch_perf += torch.sum(torch.argmax(f.detach(), dim=-1) == y).item() / N
            scheduler.step()

        losses.append((epoch, epoch_loss))

        if epoch % log_frequency == 0:
            logger.info(f'MARGLIK[epoch={epoch}]: training performance {epoch_perf*100:.2f}%.',
                        extra={'repeat': repeat})
            gb_factor = 1024 ** 3
            logger.info('Max memory allocated: ' + str(torch.cuda.max_memory_allocated(loss.device)/gb_factor) + ' Gb.',
                        extra={'repeat': repeat})
        optimizer.zero_grad(set_to_none=True)

        # optimize graph by differentiating marglik
        # 1. fit laplace approximation
        torch.cuda.empty_cache()
        lap = laplace(model, likelihood, backend=backend, backend_kwargs=backend_kwargs)

        # only update graph every "Frequency" steps after "burnin"
        if (epoch % marglik_frequency) != 0 or epoch < n_epochs_burnin:
            # compute validation error to report during training
            # NOTE(review): validation is evaluated (and logged to ``criteria``)
            # a second time after this branch; kept as in the original.
            valid_loss, valid_acc, test_loss, test_acc = validation(
                model, valid_loader, likelihood, device, valid_mask, test_mask, n_samples=n_samples)
            criteria['valid_acc']['log'].append(valid_acc)
            criteria['valid_loss']['log'].append(valid_loss)
            lap.fit(marglik_loader, train_mask=marglik_mask)
            log_lik, log_det = lap.log_marginal_likelihood()
            marglik = -(log_lik - log_det).item()
            criteria['marglik']['log'].append(marglik)
            if epoch % log_frequency == 0:
                logger.info(f'MARGLIK[epoch={epoch}]: marglik {marglik:.2f} | valid acc {valid_acc*100:.2f}% (loss: {valid_loss:.2f}) | test acc {test_acc*100:.2f}% (loss: {test_loss:.2f})',
                            extra={'repeat': repeat})
        else:
            if optimize_graph:
                model.enable_graph_builder_grad()
                # fit without grad and make data iterator to draw from for partial fits
                lap.fit(marglik_loader, train_mask=marglik_mask, keep_factors=True)  # calculate the preconditioner with full data
                torch.cuda.empty_cache()
                partial_iterator = iter(partial_loader)
                X, y = next(partial_iterator)
                lap.fit_partial(X, y, train_mask=marglik_mask)  # gradient only on batch
            else:
                lap.fit(marglik_loader, train_mask=marglik_mask)
                torch.cuda.empty_cache()

            # 2. differentiate wrt. graph for n_hypersteps
            for i in range(n_hypersteps):
                model.zero_grad()
                # the graph optimizer only exists when the graph is learnable
                if optimize_graph:
                    graph_optimizer.zero_grad()
                sample_losses = list()
                for j in range(n_samples):
                    log_lik, log_det = lap.log_marginal_likelihood()
                    marglik = -(log_lik - log_det).item()  # negative log marglik
                    loss = -(log_lik - log_det_weight * log_det)  # log_det_weight is beta in paper
                    if optimize_graph and graph_prior == 'bernoulli':
                        edge_index = model.discrete_adj_prob.indices()
                        probs = model.discrete_adj_prob.values()
                        prior_probs = prior_adj[edge_index[0], edge_index[1]]
                        kl_loss = kl_bernoulli(probs, prior_probs).sum() / 2  # symmetric
                        loss = loss + graph_kl_weight * kl_loss  # graph_kl_weight is alpha in paper
                    torch.cuda.empty_cache()
                    sample_losses.append(loss)
                    # Refit after every sample, including the last one.
                    # NOTE(review): the original guard
                    # ``(i < n_hypersteps) and (j < n_samples)`` was always true;
                    # confirm whether the last refit was meant to be skipped.
                    if optimize_graph:
                        try:
                            X, y = next(partial_iterator)
                        except StopIteration:
                            # iterator exhausted: keep using the last batch
                            pass
                        lap.fit_partial(X, y, train_mask=marglik_mask)

                if not optimize_graph:
                    # Fixed graph: the marglik is only recorded, there is nothing
                    # to differentiate or step.
                    continue

                loss = torch.stack(sample_losses).mean()
                loss.backward()

                torch.cuda.empty_cache()

                # replace NaN gradients so a single bad step does not poison the graph
                if torch.any(torch.isnan(gradient_to_vector(model.graph_builder.parameters()))):
                    for param in model.graph_builder.parameters():
                        param.grad = torch.nan_to_num(param.grad)

                if graph_grad_norm:
                    clip_grad_norm_(model.graph_builder.parameters(), max_norm=1.0)
                graph_optimizer.step()
                graph_scheduler.step()
                torch.cuda.empty_cache()
                if (i < n_hypersteps - 1):
                    try:
                        X, y = next(partial_iterator)
                    except StopIteration:
                        pass
                    lap.fit_partial(X, y, train_mask=marglik_mask)

            if optimize_graph and epoch % log_frequency == 0:
                n_edges = torch.sum(model.discrete_adj)
                homophily = edge_homophily(model.discrete_adj, y)
                logger.info(f'GRAPH[epoch={epoch}]: edges={n_edges}, homophily={homophily:.4f}.',
                            extra={'repeat': repeat})
                # get_last_lr() reports the rate currently in use; get_lr() called
                # outside of step() would already apply the next decay factor.
                logger.info(f'GRAPH[epoch={epoch}]: lr={graph_scheduler.get_last_lr()[-1]:.6f}.',
                            extra={'repeat': repeat})
                if checkpoint_dir is not None:
                    torch.save(model.graph_builder.state_dict(), osp.join(checkpoint_dir, f'epoch_{epoch}.pt'))

            criteria['marglik']['log'].append(marglik)
            del lap

        # compute validation error to report during training
        valid_loss, valid_acc, test_loss, test_acc = validation(
                model, valid_loader, likelihood, device, valid_mask, test_mask, n_samples=n_samples)
        criteria['valid_acc']['log'].append(valid_acc)
        criteria['valid_loss']['log'].append(valid_loss)
        if epoch % log_frequency == 0:
            logger.info(f'MARGLIK[epoch={epoch}]: valid acc {valid_acc*100:.2f}% (loss: {valid_loss:.2f}) | test acc {test_acc*100:.2f}% (loss: {test_loss:.2f})',
                        extra={'repeat': repeat})

        best_val = criteria[early_stop_crit]['best']
        current_val = criteria[early_stop_crit]['log'][-1]
        # best-model tracking (training always runs for all n_epochs)
        if criteria[early_stop_crit]['check']():
            best_epoch = epoch
            # snapshot the weights on CPU, then move the model back
            model.cpu()
            best_model_dict = deepcopy(model.state_dict())
            model.to(device)
            if optimize_graph:
                best_graph_builder = deepcopy(model.graph_builder.state_dict())
            # update best values
            criteria[early_stop_crit]['best'] = current_val
            best_model_stats['valid_loss'] = valid_loss
            best_model_stats['valid_acc'] = valid_acc
            best_model_stats['test_loss'] = test_loss
            best_model_stats['test_acc'] = test_acc
            best_model_stats['marglik'] = marglik
            logger.info(f'MARGLIK[epoch={epoch}]: {early_stop_crit}={current_val:.4f}. Saving new best model (test acc={best_model_stats["test_acc"]:.4f}).',
                        extra={'repeat': repeat})
        else:
            if epoch % log_frequency == 0:
                logger.info(f'MARGLIK[epoch={epoch}]: {early_stop_crit}={current_val:.2f}. (best {early_stop_crit}={best_val:.4f} | test_acc={best_model_stats["test_acc"]:.4f})', extra={'repeat': repeat})

    torch.cuda.empty_cache()
    logger.info('MARGLIK: finished training. Recover best model and fit Laplace.', extra={'repeat': repeat})
    logger.info(f'MARGLIK: best epoch {best_epoch}.', extra={'repeat': repeat})
    if best_model_dict is not None:
        model.load_state_dict(best_model_dict)
        if optimize_graph:
            model.graph_builder.load_state_dict(best_graph_builder)

    lap = laplace(model, likelihood, backend=backend, backend_kwargs=backend_kwargs)
    # Use the same mask as every fit during training so the final posterior is
    # not fitted on samples outside the marglik/training subset.
    lap.fit(marglik_loader, train_mask=marglik_mask)
    logger.info(f'MARGLIK: valid_loss={best_model_stats["valid_loss"]:.4f}, valid_acc={best_model_stats["valid_acc"]*100:.2f}%.', extra={'repeat': repeat})
    logger.info(f'MARGLIK: test_loss={best_model_stats["test_loss"]:.4f}, test_acc={best_model_stats["test_acc"]*100:.2f}%.', extra={'repeat': repeat})
    logger.info(f'MARGLIK: marglik={best_model_stats["marglik"]:.4f}.', extra={'repeat': repeat})

    return lap, model, best_model_stats

