import copy
import torch

class Architecture():
    """ Compute gradients of alphas """
    def __init__(self, net, w_momentum, w_weight_decay):
        '''
        Args:
            net
            w_momentum: weights momentum
        '''
        self.net = net
        self.v_net = copy.deepcopy(net)
        self.w_momentum = w_momentum
        self.w_weight_decay = w_weight_decay

    def virtual_step(self, trn_X, trn_y, xi, w_optim):
        loss = self.cal_loss_net(trn_X, trn_y)
        dwLtrain = torch.autograd.grad(loss, self.net.weights())
        # w' = w - ξ * dw Ltrain(w, α)
        with torch.no_grad():
            for w, vw, g in zip(self.net.weights(), self.v_net.weights(), dwLtrain):
                # m = v * w_momentum
                m = w_optim.state[w].get('momentum_buffer', 0.) * self.w_momentum
                # w' = w - ξ * (m + dw Ltrain(w, α) + regu )
                vw.copy_(w - xi * (m + g + self.w_weight_decay * w))
            for a, va in zip(self.net.alphas(), self.v_net.alphas()):
                va.copy_(a)

    def unrolled_backward(self, trn_X, trn_y, val_X, val_y, xi, w_optim):

        # do virtual step (calc w`)
        self.virtual_step(trn_X, trn_y, xi, w_optim)

        # calc unrolled loss
        loss = self.cal_loss_v_net(val_X, val_y)

        # compute gradient
        v_alphas = tuple(self.v_net.alphas())
        v_weights = tuple(self.v_net.weights())
        v_grads = torch.autograd.grad(loss, v_alphas + v_weights)
        dalpha = v_grads[:len(v_alphas)]  # dα L_val(w', α)
        dw = v_grads[len(v_alphas):]  # dw' L_val(w', α)
        hessian = self.compute_hessian(dw, trn_X, trn_y)

        # update final gradient = dalpha - xi*hessian
        with torch.no_grad():
            for alpha, da, h in zip(self.net.alphas(), dalpha, hessian):
                alpha.grad = da - xi * h

    def cal_loss_net(self, X, Y):
        """ Compute loss on multiple steps """
        if X.dim() == 4:
            return self.net.loss(X, Y)
        else: # X.dim() == 5 for event-based data
            timestep = X.size(0)
            loss = 0
            for t in range(timestep):
                loss += self.net.loss(X[t], Y)
            return loss / timestep

    def cal_loss_v_net(self, X, Y):
        """ Compute loss on multiple steps """
        if X.dim() == 4:
            return self.v_net.loss(X, Y)
        else: # X.dim() == 5 for event-based data
            timestep = X.size(0)
            loss = 0
            for t in range(timestep):
                loss += self.v_net.loss(X[t], Y)
            return loss / timestep

    def compute_hessian(self, dw, trn_X, trn_y):
        """
        dw = dw` { L_val(w`, alpha) }
        w+ = w + eps * dw
        w- = w - eps * dw
        hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)    [1]
        eps = 0.01 / ||dw||
        """
        norm = torch.cat([w.view(-1) for w in dw]).norm()
        eps = 0.01 / norm

        # w+ = w + eps * dw`
        with torch.no_grad():
            for p, d in zip(self.net.weights(), dw):
                p += eps * d
        loss = self.cal_loss_net(trn_X, trn_y)  # L_trn(w+)
        dalpha_pos = torch.autograd.grad(loss, self.net.alphas())  # dalpha { L_trn(w+) }

        # w- = w - eps * dw`
        with torch.no_grad():
            for p, d in zip(self.net.weights(), dw):
                p -= 2. * eps * d  #  w- = w - eps * dw = w+ - eps * dw * 2
        loss = self.cal_loss_net(trn_X, trn_y)  # L_trn(w-)
        dalpha_neg = torch.autograd.grad(loss, self.net.alphas())  # dalpha { L_trn(w-) }

        # recover w
        with torch.no_grad():
            for p, d in zip(self.net.weights(), dw):
                p += eps * d  # w,  w = w- + eps * dw

        hessian = [(p - n) / 2. * eps for p, n in zip(dalpha_pos, dalpha_neg)]
        return hessian

