"""Implementations of alpha update steps for DARTS."""
import torch
from .utils import psnr_compute


def _loss_evaluation(model, inputs, targets):
    """Run one forward pass and return ``(loss, psnr)``.

    The model is called as ``model(inputs, x_true=targets)`` and is expected to
    return a pair ``(outputs, aux_loss)``. The base loss is
    ``model.criterion(outputs, targets)``. When ``aux_loss`` is not ``None`` it
    is added to the base loss and the sum is divided by ``model.layers``.

    Returns:
        A tuple ``(loss, psnr)`` where ``psnr`` is the value returned by
        ``psnr_compute(outputs, targets)``, evaluated without gradient tracking.
    """
    prediction, auxiliary = model(inputs, x_true=targets)
    objective = model.criterion(prediction, targets)
    with torch.no_grad():
        quality = psnr_compute(prediction, targets)

    # Without an auxiliary term the criterion value is returned untouched.
    if auxiliary is None:
        return objective, quality
    return (objective + auxiliary) / model.layers, quality


def _fill_alpha_gradients_liu(model, state, data_train, data_meta, eps=1e-2):
    """Fill ``alpha.grad`` with a finite-difference second-order approximation.

    Steps performed:
      1. Snapshot all parameters and buffers of ``model``.
      2. Take one ``state.param_optimizer`` step on the training loss to obtain
         the virtual weights w'.
      3. Backpropagate the meta loss at w'; this populates ``alpha.grad`` and
         the weight gradients used as the perturbation direction.
      4. Restore the snapshot, evaluate the training loss at ``w + eps_n * g``
         and ``w - eps_n * g`` (``g`` being the weight gradients from step 3)
         and subtract the central difference of the alpha gradients from
         ``alpha.grad``.
      5. Restore the snapshot again.

    Parameter and buffer *values* are unchanged on return. The optimizer's
    internal state is not restored (the step in 2 is a real optimizer step),
    and the weights' ``.grad`` fields still hold the gradients from step 3.

    Args:
        model: Model exposing ``parameters()``, ``buffers()``,
            ``arch_parameters()`` and a ``no_sync()`` context manager.
        state: Object holding ``param_optimizer`` and ``alpha_optimizer``.
        data_train: Tuple unpacked into ``_loss_evaluation(model, *data_train)``.
        data_meta: Tuple unpacked into ``_loss_evaluation(model, *data_meta)``.
        eps: Perturbation size before normalisation by the gradient norm.
    """
    # 1) Hardcopy state dict.
    original_parameters = [p.detach().clone() for p in model.parameters()]
    original_buffers = [b.detach().clone() for b in model.buffers()]

    def _load_old_state(model):
        with torch.no_grad():
            for param, old_state in zip(model.parameters(), original_parameters):
                param.copy_(old_state)
            for buffer, old_state in zip(model.buffers(), original_buffers):
                buffer.copy_(old_state)

    # 2) Fill model with w' = w - xi\nabla_w L_train(w, alpha)
    # Note that we do incur an additional momentum update
    # NOTE(review): gradients are not zeroed before this backward pass, so the
    # caller presumably guarantees clean ``.grad`` fields -- confirm.
    loss, _ = _loss_evaluation(model, *data_train)
    loss.backward()
    state.param_optimizer.step()

    with model.no_sync():
        # 3) Compute \nabla_w' L_val(w', alpha) and \nabla_alpha L_val(w', alpha)
        state.param_optimizer.zero_grad()
        state.alpha_optimizer.zero_grad()
        loss, _ = _loss_evaluation(model, *data_meta)
        loss.backward()  # alpha_grads contain grads w.r.t to L_val(alpha)
        # Weights that received no gradient (``.grad is None``) contribute a
        # zero direction instead of crashing on ``None.detach()``.
        w_dash_grad = [
            torch.zeros_like(p) if p.grad is None else p.grad.detach().clone()
            for p in model.parameters()
        ]

        # 4) Re-insert w into model, compute finite diff approximation
        grad_norm = torch.stack([g.pow(2).sum() for g in w_dash_grad]).sum().sqrt()
        # Clamp so that a vanishing gradient yields a zero correction rather
        # than inf/NaN weights and NaN alpha gradients.
        eps_n = eps / grad_norm.clamp_min(torch.finfo(grad_norm.dtype).tiny)
        # Positive evaluation
        _load_old_state(model)  # we still have gradients in param.grad but that is ok
        with torch.no_grad():
            for param, grad in zip(model.parameters(), w_dash_grad):
                param.add_(eps_n * grad)
        loss, _ = _loss_evaluation(model, *data_train)
        pos_grads = torch.autograd.grad(loss, model.arch_parameters(), only_inputs=True)
        # Negative evaluation (overwriting positive offset)
        with torch.no_grad():
            for param, grad in zip(model.parameters(), w_dash_grad):
                param.sub_(2 * eps_n * grad)
        loss, _ = _loss_evaluation(model, *data_train)
        neg_grads = torch.autograd.grad(loss, model.arch_parameters(), only_inputs=True)
        # Add finite diff. evaluation to alpha_grads
        # NOTE(review): the DARTS formula multiplies this term by the weight
        # learning rate xi; no such factor is applied here -- confirm intended.
        for alpha, pos_grad, neg_grad in zip(model.arch_parameters(), pos_grads, neg_grads):
            alpha.grad -= (pos_grad - neg_grad) / (2 * eps_n)  # these gradient are scaled

        # 5) Reset model to initial settings:
        _load_old_state(model)


def _fill_alpha_gradients_single(model, state, data_meta):
    """Fill ``alpha.grad`` with the plain gradient of the meta loss.

    The loss from ``_loss_evaluation(model, *data_meta)`` is differentiated
    with respect to ``model.arch_parameters()`` only, and the result is
    assigned to each ``alpha.grad``. Both optimizers in ``state`` have their
    gradients zeroed beforehand. Parameter and buffer values are snapshotted
    first and copied back afterwards, so they are unchanged on return.
    """
    # Snapshot of everything the forward pass could modify.
    saved_params = [weight.detach().clone() for weight in model.parameters()]
    saved_buffers = [buf.detach().clone() for buf in model.buffers()]

    with model.no_sync():
        # Gradient of the meta loss w.r.t. the architecture parameters only.
        state.param_optimizer.zero_grad()
        state.alpha_optimizer.zero_grad()
        meta_loss, _ = _loss_evaluation(model, *data_meta)
        arch_grads = torch.autograd.grad(
            meta_loss, model.arch_parameters(), only_inputs=True
        )

        for arch_param, arch_grad in zip(model.arch_parameters(), arch_grads):
            arch_param.grad = arch_grad

        # Copy the snapshot back into the model.
        with torch.no_grad():
            for weight, snapshot in zip(model.parameters(), saved_params):
                weight.copy_(snapshot)
            for buf, snapshot in zip(model.buffers(), saved_buffers):
                buf.copy_(snapshot)


def _fill_alpha_gradients_binaryconnect(model, state, data_train):
    """Fill ``alpha.grad`` using a BinaryConnect-style approximation.

    Each architecture parameter is temporarily replaced by a one-hot tensor
    that selects its arg-max entry along the last dimension. The training loss
    is evaluated with these binary values and its gradient with respect to the
    architecture parameters is stored in ``alpha.grad`` after the original
    alpha values and buffers have been restored.

    ``state.param_optimizer.step()`` is called in between, and the model
    weights are not restored afterwards.

    Args:
        model: Model exposing ``arch_parameters()``, ``buffers()`` and a
            ``no_sync()`` context manager.
        state: Object holding ``param_optimizer``.
        data_train: Tuple unpacked into ``_loss_evaluation(model, *data_train)``.
    """
    # 1) Hardcopy state dict.
    with model.no_sync():
        original_weights = [a.detach().clone() for a in model.arch_parameters()]
        original_buffers = [b.detach().clone() for b in model.buffers()]

        def _load_old_state(model):
            with torch.no_grad():
                for a, old_state in zip(model.arch_parameters(), original_weights):
                    a.copy_(old_state)
                for buffer, old_state in zip(model.buffers(), original_buffers):
                    buffer.copy_(old_state)

        # 2) Compute forward pass from binary weights
        # The one-hot is scattered along the same (last) dimension the arg-max
        # was taken over, so alphas of any rank are binarized consistently.
        with torch.no_grad():
            for alpha in model.arch_parameters():
                winner = alpha.argmax(dim=-1, keepdim=True)
                alpha.zero_().scatter_(-1, winner, 1)
        loss, _ = _loss_evaluation(model, *data_train)

        # binary backward step for alpha
        alpha_grads = torch.autograd.grad(loss, model.arch_parameters(), only_inputs=True)
        # NOTE(review): torch.autograd.grad returns the gradients instead of
        # accumulating them into ``.grad``, so this step uses whatever
        # gradients are already stored on the weights -- confirm intended.
        state.param_optimizer.step()

        # recover old weights and insert gradients
        _load_old_state(model)
        for alpha, grad in zip(model.arch_parameters(), alpha_grads):
            alpha.grad = grad


def _fill_alpha_gradients_higher(model, state, data_train, data_meta, iterations=2):
    """Fill ``alpha.grad`` by unrolling inner training steps with ``higher``.

    Inside ``higher.innerloop_ctx`` the patched model takes ``iterations``
    optimizer steps on the training loss. The meta loss of the patched model is
    then differentiated with respect to ``model.arch_parameters()`` and the
    resulting gradients are assigned to each ``alpha.grad``.

    ``higher`` is imported inside the function, so it is only needed when this
    update rule is actually used.
    """
    import higher

    with model.no_sync():
        with higher.innerloop_ctx(model, state.param_optimizer) as (
            functional_model,
            diff_optimizer,
        ):
            # Unrolled training steps on the patched model.
            for _ in range(iterations):
                train_loss, _ = _loss_evaluation(functional_model, *data_train)
                diff_optimizer.step(train_loss)

            # Differentiate the meta loss through the unrolled steps.
            outer_loss, _ = _loss_evaluation(functional_model, *data_meta)
            arch_grads = torch.autograd.grad(
                outer_loss, model.arch_parameters(), only_inputs=True
            )

        for arch_param, arch_grad in zip(model.arch_parameters(), arch_grads):
            arch_param.grad = arch_grad
