import collections
import torch

from backpack import backpack
from backpack.extensions import BatchGrad


def grad_with_backpack(loss, proxy, manipulate_grad=None):
    """Estimate the mean gradient and an unbiased squared mean via BackPACK.

    Per-sample gradients are obtained in a single backward pass with
    BackPACK's ``BatchGrad`` extension. They are flattened and concatenated
    across all trainable parameters of ``proxy`` into one row per sample.

    Args:
        loss: Per-sample losses; ``loss.size(0)`` is taken as the batch size.
        proxy: Module whose trainable (``requires_grad``) parameters are
            differentiated. Presumably it must have been extended with
            BackPACK for ``grad_batch`` to be populated -- confirm at caller.
        manipulate_grad: Optional callable ``f(grads, scalar=False)`` whose
            result is added to the per-sample gradient matrix.

    Returns:
        ``(mean_grad, reg_sq)``:
            ``mean_grad`` is the mean of the per-sample gradients.
            ``reg_sq`` is the elementwise estimate
            ``(n * mean_grad**2 - mean(g**2)) / (n - 1)``. This equals the
            average of ``g_i * g_j`` over distinct samples ``i != j``.

    Raises:
        ValueError: if fewer than two samples are given, where the
            ``n - 1`` denominator would be zero.
    """
    batch_size = loss.size(0)
    if batch_size < 2:
        raise ValueError(
            f"grad_with_backpack needs at least 2 samples, got {batch_size}")

    parameters = [p for p in proxy.parameters() if p.requires_grad]
    with backpack(BatchGrad()):
        # The summed gradients returned here are not needed. The pass exists
        # so BackPACK's hooks populate ``grad_batch`` on each parameter.
        # NOTE(review): whether ``grad_batch`` stays differentiable under
        # create_graph=True depends on the BackPACK version -- confirm if
        # reg_sq is backpropagated.
        torch.autograd.grad(loss.sum(), parameters, create_graph=True)

    # Reuse the exact list that was differentiated so the columns line up:
    # one row per sample, one column per scalar parameter entry.
    grads = torch.cat([p.grad_batch.flatten(start_dim=1) for p in parameters],
                      dim=1)

    if manipulate_grad is not None:
        grads = grads + manipulate_grad(grads, scalar=False)

    mean_grad = grads.mean(0)
    mean_sq_grad = grads.pow(2).mean(0)
    reg_sq = (mean_grad.pow(2) * batch_size - mean_sq_grad) / (batch_size - 1)

    return mean_grad, reg_sq


def clone_state_dict(state_dict, detach=False):
    """Return a copy of ``state_dict`` in which every tensor is cloned.

    Args:
        state_dict: Mapping of names to tensors, e.g. ``module.state_dict()``.
        detach: If True, each clone is also detached from the autograd graph.

    Returns:
        A new ``collections.OrderedDict`` with the same keys in the same order.
        The input may carry a ``_metadata`` attribute, as results of
        ``Module.state_dict()`` do; ``load_state_dict`` reads it for
        per-module version info. If present, a deep copy of it is attached to
        the result so the clone loads exactly like the original.
    """
    import copy

    new_state_dict = collections.OrderedDict(
        (key, item.clone().detach() if detach else item.clone())
        for key, item in state_dict.items()
    )
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        # Deep copy so later edits to either dict's metadata do not alias.
        new_state_dict._metadata = copy.deepcopy(metadata)
    return new_state_dict


def sq_mean_grad(loss, proxy, manipulate_grad=None):
    """Estimate the elementwise squared mean gradient, one backward per sample.

    Each ``loss[i].sum()`` is differentiated separately with respect to the
    trainable parameters of ``proxy``. The gradient is flattened into a single
    vector, optionally corrected by ``manipulate_grad``, and accumulated.

    Args:
        loss: Per-sample losses indexed along dim 0; ``loss.size(0)`` is the
            batch size.
        proxy: Module whose ``requires_grad`` parameters are differentiated.
        manipulate_grad: Optional callable ``f(grad, scalar=False)`` whose
            result is added to each flattened per-sample gradient.

    Returns:
        1-D tensor with one entry per scalar trainable parameter. It holds
        ``(n * mean_grad**2 - mean(g**2)) / (n - 1)``, which equals the
        average of ``g_i * g_j`` over distinct samples ``i != j``. It stays
        differentiable because gradients are taken with ``create_graph=True``.

    Raises:
        ValueError: if fewer than two samples are given, where the
            ``n - 1`` denominator would be zero.
    """
    batch_size = loss.size(0)
    if batch_size < 2:
        raise ValueError(
            f"sq_mean_grad needs at least 2 samples, got {batch_size}")

    # The trainable-parameter list is identical for every sample; build once.
    parameters = [p for p in proxy.parameters() if p.requires_grad]
    grad_sum = 0.
    grad_sq_sum = 0.

    # Every sample, index 0 included, must contribute, because both sums are
    # divided by batch_size below.
    for i in range(batch_size):
        # create_graph=True makes the gradient differentiable. It also retains
        # the graph (retain_graph defaults to create_graph), so ``loss`` can be
        # differentiated again on the next iteration.
        sample_grads = torch.autograd.grad(loss[i].sum(), parameters,
                                           create_graph=True)
        flat = torch.cat([g.flatten() for g in sample_grads])

        if manipulate_grad is not None:
            flat = flat + manipulate_grad(flat, scalar=False)
        grad_sum = grad_sum + flat
        grad_sq_sum = grad_sq_sum + flat.pow(2)

    mean_sq_grad = grad_sq_sum / batch_size
    mean_grad = grad_sum / batch_size
    reg_sq = (mean_grad.pow(2) * batch_size - mean_sq_grad) / (batch_size - 1)

    return reg_sq


def check_proxymodel_valid(proxy_model):
    """Placeholder validity check for ``proxy_model``.

    Intended to exercise the proxy model inside a try/except to confirm it is
    usable. It is currently unimplemented: no checks are performed and the
    function always returns ``None``.
    """
    # TODO: implement the try/except probe described in the docstring.
    return None
