import torch
from torch.autograd.variable import Variable
import torch.nn

def get_SI_omega(model, W, epsilon, old_P):
    """Convert per-parameter accumulators into importance weights.

    For the parameter at position ``i`` of ``model.parameters()`` that has
    ``requires_grad=True`` the result is
    ``W[i] / ((p - old_P[i]) ** 2 + epsilon)``; parameters with
    ``requires_grad=False`` get a zero tensor shaped like the parameter.
    (The "SI" in the name presumably refers to Synaptic Intelligence.)

    Args:
        model: module whose parameters are scored.
        W: sequence indexed like ``model.parameters()``; numerator per parameter.
        epsilon: damping term added to the squared change.
        old_P: sequence indexed like ``model.parameters()``; reference values
            the current parameters are compared against.

    Returns:
        List with one tensor per parameter, in ``model.parameters()`` order.
    """
    omega = []
    for idx, param in enumerate(model.parameters()):
        if not param.requires_grad:
            omega.append(torch.zeros_like(param.data))
            continue
        # Squared drift since the reference snapshot, damped by epsilon.
        drift = param.detach().clone() - old_P[idx]
        omega.append(W[idx] / (drift ** 2 + epsilon))
    return omega

def get_grads(model, dataloader, device, criterion):
    """Average the loss gradient over the batches of ``dataloader``.

    Every batch contributes equally (its gradient is divided by
    ``len(dataloader)``), regardless of how many samples it holds.

    Args:
        model: module to evaluate; it is put in eval mode and moved to ``device``.
        dataloader: sized iterable of ``(X, Y)`` batches.
        device: device the model and batches are moved to.
        criterion: callable ``criterion(outputs, targets) -> scalar loss``.

    Returns:
        List of tensors aligned with ``model.parameters()``, each shaped like
        its parameter. Parameters that never receive a gradient stay zero.
        Note: the last batch's gradients are left in each ``p.grad``.
    """
    model.eval()
    model.to(device)
    Gs = [torch.zeros_like(p) for p in model.parameters()]
    num_batches = len(dataloader)  # loop-invariant divisor
    for X, Y in dataloader:
        # Plain tensors are used directly; ``Variable`` has been a no-op since PyTorch 0.4.
        X, Y = X.to(device), Y.to(device)
        model.zero_grad()
        loss = criterion(model(X), Y)
        loss.backward()
        for g, p in zip(Gs, model.parameters()):
            if p.grad is not None:
                g += p.grad.detach() / num_batches
    return Gs

def get_diag_fisher(model, dataloader, device, criterion):
    """Approximate the diagonal Fisher information from squared batch gradients.

    For every batch the gradient of ``criterion`` is squared elementwise and
    accumulated with weight ``batch_size``; the sum is then divided by the
    total number of samples.

    Args:
        model: module to evaluate (put in eval mode).
        dataloader: iterable of ``(inputs, targets)`` batches.
        device: device the batches are moved to.
        criterion: callable ``criterion(outputs, targets) -> scalar loss``.

    Returns:
        List aligned with ``model.parameters()``, each entry shaped like its
        parameter. Parameters with ``requires_grad=False`` or without a
        gradient stay zero. An empty dataloader yields all zeros (no NaNs).
    """
    model.eval()
    params = list(model.parameters())
    diag_fisher = [torch.zeros_like(p) for p in params]
    num_data = 0
    for inputs, targets in dataloader:
        model.zero_grad()
        batch_size = inputs.size(0)
        outputs = model(inputs.to(device))
        loss = criterion(outputs, targets.to(device))
        loss.backward()
        # Index by position in model.parameters() so frozen parameters cannot
        # shift later gradients onto the wrong accumulator.
        for i, p in enumerate(params):
            if p.requires_grad and p.grad is not None:
                diag_fisher[i] = diag_fisher[i] + (p.grad.detach() ** 2) * batch_size
        model.zero_grad()
        num_data += batch_size
    if num_data > 0:
        diag_fisher = [d / num_data for d in diag_fisher]
    return diag_fisher
    

def get_diag_hessian(model, dataloader, device, criterion):
    """Estimate per-parameter curvature with random Rademacher probes, plus the mean gradient.

    For each batch, a probe ``v`` with entries in {-1, +1} is drawn for every
    trainable parameter. ``H v`` is obtained from ``dataloader_hv_product``,
    and ``|H v|`` is accumulated as the curvature estimate. Note: this is the
    magnitude of ``H v``, not the ``v * (H v)`` diagonal estimator. For
    parameters with more than two dimensions, the estimate is averaged over
    every dimension after the first two. ``keepdim=True`` is used so that the
    average broadcasts back onto the parameter's shape.

    Args:
        model: module to evaluate (put in eval mode).
        dataloader: iterable of ``(inputs, targets)`` batches.
        device: device the batches are moved to.
        criterion: callable ``criterion(outputs, targets) -> scalar loss``.

    Returns:
        ``(hutchinson_trace, Gs)``: two lists aligned with
        ``model.parameters()``, each entry shaped like its parameter.
        Each holds the sample-weighted average over the dataloader.
        Entries for parameters with ``requires_grad=False`` stay zero.
        Both lists are all-zero for an empty dataloader; ``Gs`` is detached.
    """
    model.eval()
    params = list(model.parameters())
    hutchinson_trace = [torch.zeros_like(p) for p in params]
    Gs = [torch.zeros_like(p) for p in params]
    # dataloader_hv_product works on the trainable parameters only, so probes
    # and results are mapped through these positions in model.parameters().
    trainable_idx = [i for i, p in enumerate(params) if p.requires_grad]
    total_number = 0
    for data in dataloader:
        batch_size = data[0].size(0)
        v = [2 * torch.randint_like(params[i], high=2) - 1 for i in trainable_idx]
        hvs, gs = dataloader_hv_product(v, [data], model, criterion, device)
        for i, hv, g in zip(trainable_idx, hvs, gs):
            estimate = hv.abs()
            if estimate.dim() > 2:
                # e.g. conv kernels: average over the kernel dims (2, 3 for Conv2d).
                estimate = torch.mean(
                    estimate, dim=tuple(range(2, estimate.dim())), keepdim=True)
            hutchinson_trace[i] = hutchinson_trace[i] + estimate * batch_size
            Gs[i] = Gs[i] + g.detach() * batch_size
        total_number += batch_size
    if total_number > 0:
        hutchinson_trace = [h / total_number for h in hutchinson_trace]
        Gs = [g / total_number for g in Gs]

    return hutchinson_trace, Gs

def dataloader_hv_product(v, data, model, criterion, device):

    num_data = 0  # count the number of datum points in the dataloader

    THv = [torch.zeros(p.size()).to(device) for p in model.parameters()
            ]  # accumulate result
    TG = [torch.zeros(p.size()).to(device) for p in model.parameters()
            ]
    for inputs, targets in data:
        model.zero_grad()
        tmp_num_data = inputs.size(0)
        outputs = model(inputs.to(device))
        loss = criterion(outputs, targets.to(device))
        loss.backward(create_graph=True)
        params, gradsH = get_params_grad(model)
        model.zero_grad()
        Hv = torch.autograd.grad(gradsH,
                                    params,
                                    grad_outputs=v,
                                    only_inputs=True,
                                    retain_graph=False)
        THv = [
            THv1 + Hv1 * float(tmp_num_data) + 0.
            for THv1, Hv1 in zip(THv, Hv)
        ]
        TG = [
            TG1 + G1 * float(tmp_num_data) + 0.
            for TG1, G1 in zip(TG, gradsH)
        ]
        num_data += float(tmp_num_data)

    THv = [THv1 / float(num_data) for THv1 in THv]
    TG = [TG1 / float(num_data) for TG1 in TG]
    return THv, TG

def get_params_grad(model):
    """
    Get the trainable model parameters and a copy of their current gradients.

    Returns:
        (params, grads): ``params`` lists every parameter with
        ``requires_grad=True`` in ``model.parameters()`` order. ``grads[i]`` is
        ``params[i].grad + 0.``: an out-of-place copy that keeps any autograd
        history attached, as needed after ``backward(create_graph=True)``.
        When the parameter has no gradient, ``grads[i]`` is a zero tensor
        shaped like the parameter.
    """
    params = []
    grads = []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        params.append(param)
        # A zero tensor (rather than the float 0.) keeps the list homogeneous:
        # every entry has the parameter's shape, dtype and device.
        grads.append(torch.zeros_like(param) if param.grad is None else param.grad + 0.)
    return params, grads

def copy_hessian(hessian, new_hessian, s, alpha=0.1):
    """Blend ``new_hessian`` into ``hessian`` entry by entry.

    Each entry becomes ``alpha * old + new / s``. The ``hessian`` list itself
    is updated (its elements are rebound to the new values) and returned.
    """
    for idx, previous in enumerate(hessian):
        hessian[idx] = alpha * previous + new_hessian[idx] / s
    return hessian

def combine_hessian(hessian, new_hessian, L):
    """Add ``new_hessian`` scaled by ``L`` to ``hessian`` entry by entry.

    Augmented assignment is used, so tensor entries of ``hessian`` are
    modified in place (the same tensor objects stay in the list). The
    ``hessian`` list is returned.
    """
    for idx, _ in enumerate(hessian):
        scaled = new_hessian[idx] * L
        hessian[idx] += scaled
    return hessian

