import sys
from functools import partial
from copy import deepcopy
import logging
import numpy as np
import torch
from torch.nn.utils.convert_parameters import vector_to_parameters
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import ExponentialLR, CosineAnnealingLR
from torch.nn import CrossEntropyLoss, MSELoss, Sequential
from torch.nn.utils import parameters_to_vector
from torch.distributions import Normal
import wandb

from hetreg.utils import wandb_log_prior, wandb_log_parameter_norm, select_criterion


# Number of bytes in one gibibyte (1024 ** 3 == 2 ** 30).
# NOTE(review): not referenced in the visible part of this file; presumably
# used to convert byte counts (e.g. memory statistics) to GiB -- confirm.
GB_FACTOR = 1024 ** 3


def expand_prior_precision(prior_prec, model):
    """Expand a prior precision to one value per entry of the parameter vector.

    Parameters
    ----------
    prior_prec : torch.Tensor
        one-dimensional tensor whose length selects the structure:
        1 (scalar prior), P (diagonal prior, P = total number of parameter
        entries) or H (layerwise prior, H = number of tensors yielded by
        ``model.parameters()``). The checks run in this order, so when
        lengths coincide the earlier interpretation wins.
    model : torch.nn.Module

    Returns
    -------
    torch.Tensor
        vector of length P, laid out in the order of
        ``parameters_to_vector(model.parameters())``

    Raises
    ------
    ValueError
        if the length of ``prior_prec`` is neither 1, P nor H. Without this
        check ``zip`` would silently truncate a mis-sized layerwise prior.
    """
    params = list(model.parameters())
    theta = parameters_to_vector(params)
    device, P = theta.device, len(theta)
    assert prior_prec.ndim == 1
    n_prec = len(prior_prec)
    if n_prec == 1:  # scalar
        return torch.ones(P, device=device) * prior_prec
    if n_prec == P:  # full diagonal
        return prior_prec.to(device)
    if n_prec != len(params):
        raise ValueError(
            f'Invalid prior precision length {n_prec}: expected 1, '
            f'{len(params)} (layerwise) or {P} (diagonal)'
        )
    # layerwise: repeat each precision over all entries of its parameter tensor
    return torch.cat([delta * torch.ones_like(m).flatten()
                      for delta, m in zip(prior_prec, params)])


def get_prior_hyperparams(prior_prec_init, prior_structure, H, P, device):
    """Create the trainable log prior precision for a given prior structure.

    Parameters
    ----------
    prior_prec_init : float
        initial prior precision; its natural logarithm fills the result
    prior_structure : str
        'scalar' (1 entry), 'layerwise' (H entries) or 'diagonal' (P entries)
    H : int
        number of entries for the layerwise structure
    P : int
        number of entries for the diagonal structure
    device : torch.device or str
        device on which the tensor is created

    Returns
    -------
    torch.Tensor
        one-dimensional tensor with ``requires_grad`` set to True

    Raises
    ------
    ValueError
        if ``prior_structure`` is not one of the three supported names
    """
    init_value = np.log(prior_prec_init)
    if prior_structure not in ('scalar', 'layerwise', 'diagonal'):
        raise ValueError(f'Invalid prior structure {prior_structure}')
    n_entries = {'scalar': 1, 'layerwise': H, 'diagonal': P}[prior_structure]
    log_prior_prec = init_value * torch.ones(n_entries, device=device)
    return log_prior_prec.requires_grad_(True)


def valid_performance(model, test_loader, likelihood, device):
    """Evaluate the model on a data loader with the likelihood's criterion.

    Parameters
    ----------
    model : torch.nn.Module
        called as ``model(X)``; its output is handed to the criterion as is
    test_loader : DataLoader
        yields ``(X, y)`` batches; must expose ``dataset`` and ``len``
    likelihood : str
        name passed to ``select_criterion``. When it contains both "loss" and
        "_and_corr" the criterion is expected to return
        ``(loss, nll, corr_loss)``; for "loss_nll_sole" ``(loss, nll)``;
        otherwise a single value that is used as the nll.
    device : torch.device or str

    Returns
    -------
    perf
        sum over batches of the per-batch nll values as returned by the
        criterion (not converted with ``.item()``)
    nll : float
        per-batch nll averaged over the number of batches
    nll_corr : float
        correlation loss averaged over the number of batches; stays 0 for
        likelihoods without a correlation term
    """
    criterion = select_criterion(likelihood)
    n_examples = len(test_loader.dataset)  # evaluated but not used below
    n_batches = len(test_loader)
    with_corr = "loss" in likelihood and "_and_corr" in likelihood
    nll_only = likelihood == "loss_nll_sole"

    total_nll = 0
    mean_nll = 0
    mean_nll_corr = 0
    for inputs, targets in test_loader:
        inputs = inputs.detach().to(device)
        targets = targets.detach().to(device)
        # only the forward pass is wrapped in no_grad; the criterion runs outside
        with torch.no_grad():
            output = model(inputs)

        if with_corr:
            _, batch_nll, batch_corr = criterion(output, targets)
            mean_nll_corr += batch_corr.item() / n_batches
        elif nll_only:
            _, batch_nll = criterion(output, targets)
        else:
            batch_nll = criterion(output, targets)

        total_nll += batch_nll
        mean_nll += batch_nll.item() / n_batches

    return total_nll, mean_nll, mean_nll_corr


def get_scheduler(scheduler, optimizer, train_loader, n_epochs, lr, lr_min):
    """Build a learning-rate scheduler that spans every optimisation step.

    The schedule horizon is ``n_epochs * len(train_loader)``, i.e. it is
    sized for one scheduler step per batch.

    Parameters
    ----------
    scheduler : str
        'exp' for exponential decay or 'cos' for cosine annealing
    optimizer : torch.optim.Optimizer
    train_loader : DataLoader
        only its length (number of batches) is used
    n_epochs : int
    lr : float
        initial learning rate; used for 'exp' to derive the decay factor
    lr_min : float
        learning rate targeted at the end of the horizon

    Returns
    -------
    ExponentialLR or CosineAnnealingLR

    Raises
    ------
    ValueError
        if ``scheduler`` is neither 'exp' nor 'cos'
    """
    total_steps = n_epochs * len(train_loader)
    if scheduler == 'cos':
        return CosineAnnealingLR(optimizer, total_steps, eta_min=lr_min)
    if scheduler == 'exp':
        # choose gamma so that lr * gamma ** total_steps equals lr_min
        decay_ratio = lr_min / lr
        per_step_gamma = np.exp(np.log(decay_ratio) / total_steps)
        return ExponentialLR(optimizer, gamma=per_step_gamma)
    raise ValueError(f'Invalid scheduler {scheduler}')


def get_model_optimizer(optimizer, model, lr, lr_corr, weight_decay=0):
    """Create the model optimizer with per-group learning rates.

    Parameters
    ----------
    optimizer : str
        'Adam' or 'SGD'
    model : torch.nn.Module
    lr : float
        base learning rate
    lr_corr : float
        learning rate for parameters whose name contains "correlation";
        only used by 'Adam'
    weight_decay : float, default=0

    Returns
    -------
    torch.optim.Optimizer
        For 'Adam' the parameter groups are, in this order, the
        "correlation" parameters (``lr_corr``) and all remaining parameters
        (``lr``). Code that reads the first group, such as
        ``scheduler.get_last_lr()[0]``, therefore sees the correlation rate.
        Either group may be empty.
        For 'SGD' (momentum 0.9) the groups are the standard parameters
        (``lr``) followed by the size-one parameters (``lr / 10``).

    Raises
    ------
    ValueError
        if ``optimizer`` is neither 'Adam' nor 'SGD'
    """
    if optimizer == 'Adam':
        corr_params, base_params = [], []
        for name, param in model.named_parameters():
            if 'correlation' in name:
                corr_params.append(param)
            else:
                base_params.append(param)
        return Adam([
            {'params': corr_params, 'lr': lr_corr, 'weight_decay': weight_decay},
            {'params': base_params, 'lr': lr, 'weight_decay': weight_decay},
        ])

    if optimizer == 'SGD':
        # fixup parameters (scalars, i.e. tensors of size [1]) should have a
        # 10x smaller learning rate
        scalar_size = torch.Size([1])
        fixup_params, standard_params = [], []
        for param in model.parameters():
            if param.size() == scalar_size:
                fixup_params.append(param)
            else:
                standard_params.append(param)
        groups = [{'params': standard_params},
                  {'params': fixup_params, 'lr': lr / 10.}]
        return SGD(groups, lr=lr, momentum=0.9, weight_decay=weight_decay)

    raise ValueError(f'Invalid optimizer {optimizer}')


def marglik_for_lucid_v23(model,
                          train_loader,
                          marglik_loader=None,
                          valid_loader=None,
                          partial_loader=None,
                          likelihood='classification',
                          reg_config={},
                          n_epochs=500,
                          lr=1e-3,
                          lr_corr=1e-2,
                          lr_min=None,
                          optimizer='Adam',
                          scheduler='cos',
                          early_stopping=False,
                          use_wandb=False,
                          grad_clip_norm=None,
                          mean_head=None):
    """Train a model with the criterion selected by ``likelihood``.

    Runs ``n_epochs`` of mini-batch training, optionally evaluates on a
    validation loader after every epoch and, with ``early_stopping``, keeps
    the weights that achieved the best validation nll. Training is never
    interrupted early; "early stopping" here only selects a checkpoint.

    Parameters
    ----------
    model : torch.nn.Module
        called as ``model(X)``. For likelihoods containing "loss_" the output
        is passed to the criterion as is; when the likelihood also contains
        "nll" it is indexed with the key "overall_mu_and_var".
    train_loader : DataLoader
        yields ``(X, y)`` batches; must expose ``dataset`` and ``len``
    marglik_loader : DataLoader, optional
        accepted for interface compatibility; not used by this function
    valid_loader : DataLoader, optional
        if given, ``valid_performance`` is evaluated on it after every epoch
    partial_loader : DataLoader, optional
        accepted for interface compatibility; not used by this function
    likelihood : str
        name passed to ``select_criterion``. If it contains "loss_" and
        "_and_corr" the criterion must return ``(loss, nll, corr_loss)``;
        for "loss_nll_sole" ``(loss, nll)``; otherwise a single loss.
    reg_config : dict
        forwarded to the criterion as keyword argument ``reg_config``.
        NOTE(review): the default is a shared mutable dict; this function
        only forwards it.
    n_epochs : int
    lr : float
        learning rate for the model optimizer
    lr_corr : float
        learning rate of parameters whose name contains "correlation"
        (only applied by the 'Adam' optimizer)
    lr_min : float, optional
        minimum learning rate, defaults to lr and hence no decay.
        To have the learning rate decay from 1e-3 to 1e-6, set
        lr=1e-3 and lr_min=1e-6.
    optimizer : str
        either 'Adam' or 'SGD'
    scheduler : str
        either 'exp' for exponential and 'cos' for cosine decay towards
        lr_min; stepped once per batch
    early_stopping : bool
        if True and ``valid_loader`` is given, remember the state dict with
        the lowest validation nll among epochs after the 30th; for
        likelihoods containing "corr" the state of ``model.correlation`` with
        the lowest validation correlation loss is tracked separately
    use_wandb : bool
        log the run configuration and per-epoch metrics to wandb
    grad_clip_norm : float, optional
        if given, gradients are clipped to this norm before each step
    mean_head : optional
        only tested against None: if None, the first two columns of
        "overall_mu_and_var" are combined as ``-0.5 * f[:, 0] / f[:, 1]`` to
        obtain the predicted mean for the training error; otherwise
        ``f[:, 0]`` is used as the mean directly

    Returns
    -------
    model : torch.nn.Module
        the model after the last epoch
    best_model_earlystop : torch.nn.Module
        deep copy of ``model``; carries the best recorded weights when
        ``early_stopping`` is set and a best checkpoint was recorded,
        otherwise identical to the final model
    negliks_pred : list
        per-epoch training nll (mean over batches)
    valid_perfs : list
        per-epoch first return value of ``valid_performance``
    valid_nlls : list
        per-epoch validation nll
    """
    if lr_min is None:  # don't decay lr
        lr_min = lr
    if marglik_loader is None:
        marglik_loader = train_loader
    if partial_loader is None:
        partial_loader = marglik_loader
    device = parameters_to_vector(model.parameters()).device
    N = len(train_loader.dataset)
    H = len(list(model.parameters()))
    P = len(parameters_to_vector(model.parameters()))

    if use_wandb:
        wandb.config.update(dict(n_params=P, n_param_groups=H, n_data=N))

    criterion = partial(select_criterion(likelihood), reg_config=reg_config)
    # set up model optimizer and scheduler
    optimizer = get_model_optimizer(optimizer, model, lr, lr_corr)
    scheduler = get_scheduler(scheduler, optimizer, train_loader, n_epochs, lr, lr_min)

    losses = list()
    valid_perfs = list()
    valid_nlls = list()
    valid_nlls_corr = list()
    negliks_pred = list()
    negliks_corr = list()

    best_neglik = np.inf
    best_neglik_corr = np.inf
    # Initialised so the checks after the loop are well defined even when no
    # checkpoint was ever recorded (no valid_loader, or n_epochs <= 30).
    best_model_dict = None
    best_model_dict_corr = None
    for epoch in range(1, n_epochs + 1):
        print("Epoch " + str(epoch))
        epoch_loss = 0
        epoch_perf = 0
        epoch_nll = 0
        epoch_nll_corr = 0
        epoch_log = dict(epoch=epoch)
        # standard NN training per batch
        torch.cuda.empty_cache()
        for X, y in train_loader:
            X, y = X.detach().to(device), y.to(device)
            optimizer.zero_grad()

            if 'loss_' in likelihood and '_and_corr' in likelihood:
                dict_rst = model(X)
                loss, loss_nll, loss_correlation_part = criterion(dict_rst, y)
                epoch_nll += loss_nll.item() / len(train_loader)
                epoch_nll_corr += loss_correlation_part.cpu().item() / len(train_loader)
            elif likelihood == 'loss_nll_sole':
                dict_rst = model(X)
                loss, loss_nll = criterion(dict_rst, y)
                epoch_nll += loss_nll.item() / len(train_loader)
            else:
                f = model(X)
                loss = criterion(f, y)
                epoch_nll += criterion(f.detach(), y).item() / len(train_loader)

            loss.backward()
            if grad_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            optimizer.step()
            epoch_loss += loss.cpu().item() / len(train_loader)

            # running training metric (summed squared error / N, or accuracy)
            if "loss_" not in likelihood:
                epoch_perf += (f.detach() - y).square().sum() / N
            elif "nll" in likelihood:
                # detach so the running metric does not keep autograd
                # references to every batch output alive for the whole epoch
                f = dict_rst["overall_mu_and_var"].detach()
                if mean_head is None:
                    epoch_perf += (y.squeeze() + 0.5 * f[:, 0] / f[:, 1]).square().sum() / N
                else:  # mean_head is natural reparam head use mean-var parameterization
                    epoch_perf += (y.squeeze() - f[:, 0]).square().sum() / N
            else:
                # NOTE(review): reached only when the likelihood contains
                # "loss_" but not "nll"; `f` is not assigned in this iteration
                # on that path -- confirm which output should be used here.
                epoch_perf += torch.sum(torch.argmax(f.detach(), dim=-1) == y).item() / N
            scheduler.step()
        losses.append(epoch_loss)
        negliks_pred.append(epoch_nll)
        negliks_corr.append(epoch_nll_corr)

        optimizer.zero_grad(set_to_none=True)
        # learning rate of the first parameter group
        llr = scheduler.get_last_lr()[0]
        epoch_log.update({'train/loss': epoch_loss, 'train/nll': epoch_nll, 'train/perf': epoch_perf, 'train/lr': llr})
        print({'train/loss': epoch_loss, 'train/nll': epoch_nll, 'train/perf': epoch_perf, 'train/lr': llr})

        # compute validation error to report during training
        if valid_loader is not None:
            with torch.no_grad():
                val_perf, val_nll, val_nll_corr = valid_performance(model, valid_loader, likelihood, device)
                valid_perfs.append(val_perf)
                valid_nlls.append(val_nll)
                valid_nlls_corr.append(val_nll_corr)
                epoch_log.update({'valid/perf': val_perf, 'valid/nll': val_nll})

        if use_wandb and (epoch % 50) == 0:
            wandb_log_parameter_norm(model)

        if use_wandb:
            wandb.log(epoch_log, step=epoch, commit=True)

        # checkpoint selection on the validation nll (training is not stopped)
        if valid_loader is not None:
            if early_stopping and (valid_nlls[-1] < best_neglik) and epoch > 30:
                best_model_dict = deepcopy(model.state_dict())
                best_neglik = valid_nlls[-1]
                print("update model with best score")
            if early_stopping and ("corr" in likelihood) and (valid_nlls_corr[-1] < best_neglik_corr) and epoch > 30:
                best_model_dict_corr = deepcopy(model.correlation.state_dict())
                best_neglik_corr = valid_nlls_corr[-1]
                print("update corr model with best corr score")

    best_model_earlystop = deepcopy(model)
    if early_stopping and (best_model_dict is not None):
        best_model_earlystop.load_state_dict(best_model_dict)

    # the correlation sub-module is restored from its own best checkpoint,
    # overriding whatever the full-model checkpoint contained for it
    if early_stopping and ("corr" in likelihood):
        if (best_model_dict_corr is not None):
            best_model_earlystop.correlation.load_state_dict(best_model_dict_corr, strict=False)

    return model, best_model_earlystop, negliks_pred, valid_perfs, valid_nlls
