import torch

from functorch import jacrev, vmap, grad

def compute_dfdx(x, model):
    """Run ``model`` forward on ``x`` and back-propagate an all-ones seed.

    The seed has one row per row of ``x`` and
    ``model.net[-1].out_features`` columns, and shares the dtype and device
    of ``x``.

    Args:
        x: Input tensor; its first dimension sets the number of seed rows.
        model: Object exposing ``net`` (whose last entry has
            ``out_features``), ``forward(coords, keep_cache=True)`` returning
            ``(output, cache)``, and ``reverse(seed, cache)`` returning a pair.

    Returns:
        The first element of the pair returned by ``model.reverse`` —
        presumably the derivative of the output w.r.t. ``x``; confirm against
        the model implementation.
    """
    n_outputs = model.net[-1].out_features
    seed = x.new_ones((x.shape[0], n_outputs))

    # The forward pass is only needed for the cache consumed by reverse().
    _, cache = model.forward(x, keep_cache=True)
    input_grad, _ = model.reverse(seed, cache)
    return input_grad

def mag_hessian(x, model):
    """Per-sample Hessian, w.r.t. the input, of the summed layer magnitudes.

    For a single sample the differentiated scalar is
    ``sum_l |dfdz_l|^2 * (1 + |h_l|^2)``, where ``h_l`` runs over the input
    coordinates followed by ``cache['hs'][:-1]`` and ``dfdz_l`` runs over the
    second value returned by ``model.reverse``. Squared magnitudes are taken
    over the last dimension.

    Args:
        x: Batch of inputs; vmap maps over its leading dimension, so the
            model sees one sample at a time.
        model: Object exposing
            ``forward(coords, keep_cache=True, detach=False)`` returning
            ``(output, cache)`` with a ``cache['hs']`` list, and
            ``reverse(seed, cache, detach=False)`` returning a pair whose
            second element is a list aligned with the layer inputs.

    Returns:
        numpy.ndarray holding, for every sample, the Jacobian of the gradient
        of the scalar above. For ``x`` of shape ``(N, D)`` this is
        ``(N, D, D)``.
    """
    seed = x.new_ones((1, 1))

    def mags(coords):
        _, cache = model.forward(coords, keep_cache=True, detach=False)
        _, dfdzs = model.reverse(seed, cache, detach=False)

        layer_inputs = [coords] + cache['hs'][:-1]
        # Out-of-place sum: avoids in-place accumulation into a tensor while
        # it is being traced by vmap/grad.
        total = sum(
            (dfdz ** 2).sum(-1) * (1 + (h ** 2).sum(-1))
            for h, dfdz in zip(layer_inputs, dfdzs)
        )
        # grad() needs a scalar; squeeze drops any leftover singleton dims.
        return total.squeeze()

    H = vmap(jacrev(grad(mags)), in_dims=(0,))(x)

    # detach() is the supported way to drop autograd tracking before numpy().
    return H.detach().cpu().numpy()


def compute_d2fdzdx(x, model):
    """Per-sample Jacobians of the model's ``dfdz`` terms w.r.t. the input.

    Each sample of ``x`` is pushed through ``model.forward`` and
    ``model.reverse`` (seeded with a ``(1, 1)`` tensor of ones), and the list
    returned as the second value of ``reverse`` is differentiated with
    respect to that sample.

    Args:
        x: Batch of inputs; vmap maps over its leading dimension.
        model: Object exposing
            ``forward(coords, keep_cache=True, detach=False)`` returning
            ``(output, cache)`` and ``reverse(seed, cache, detach=False)``
            returning a pair whose second element is a list of tensors.

    Returns:
        list of CPU tensors without autograd history, one per entry of the
        ``dfdz`` list. Each holds the batched Jacobian of that entry w.r.t.
        the input sample.
    """
    seed = x.new_ones((1, 1))

    def dfdz(coords):
        _, cache = model.forward(coords, keep_cache=True, detach=False)
        _, dfdzs = model.reverse(seed, cache, detach=False)

        # Inside vmap a sample is 1-D; drop the leading axis of each term
        # (presumably inherited from the (1, 1) seed) in that case.
        if coords.ndim == 1:
            dfdzs = [g.squeeze(0) for g in dfdzs]

        return tuple(dfdzs)

    jacobians = vmap(jacrev(dfdz))(x)
    # Detach first so the host copy is not recorded by autograd.
    return [jac.detach().cpu() for jac in jacobians]

def compute_mag_and_derivatives(x, model):
    """Squared layer magnitudes and their input gradients, per sample.

    For every sample two families of squared L2 norms (over the last
    dimension) are evaluated:

    * ``hmags``: of the input coordinates followed by ``cache['hs'][:-1]``;
    * ``dfmags``: of each entry in the list returned as the second value of
      ``model.reverse`` (seeded with a ``(1, 1)`` tensor of ones).

    Both families are then differentiated w.r.t. the input sample.

    Args:
        x: Batch of inputs; vmap maps over its leading dimension.
        model: Object exposing
            ``forward(coords, keep_cache=True, detach=False)`` returning
            ``(output, cache)`` with a ``cache['hs']`` list, and
            ``reverse(seed, cache, detach=False)`` returning a pair whose
            second element is a list of tensors.

    Returns:
        dict with keys ``'hmags'``, ``'dfmags'``, ``'dhm_dx'`` and
        ``'dfm_dx'``. Each value is a list (one entry per layer) of CPU
        tensors without autograd history.
    """
    seed = x.new_ones((1, 1))

    def mags(coords):
        _, cache = model.forward(coords, keep_cache=True, detach=False)
        _, dfdzs = model.reverse(seed, cache, detach=False)

        layer_inputs = [coords] + cache['hs'][:-1]
        hmags = tuple(h.norm(p=2, dim=-1) ** 2 for h in layer_inputs)

        # Inside vmap a sample is 1-D; drop the leading axis of each term
        # (presumably inherited from the (1, 1) seed) before taking norms.
        if coords.ndim == 1:
            dfdzs = [g.squeeze(0) for g in dfdzs]
        dfmags = tuple(g.norm(p=2, dim=-1) ** 2 for g in dfdzs)

        return hmags, dfmags

    hmags, dfmags = vmap(mags)(x)
    dhm_dx, dfm_dx = vmap(jacrev(mags))(x)

    per_key = {
        'hmags': hmags,
        'dfmags': dfmags,
        'dhm_dx': dhm_dx,
        'dfm_dx': dfm_dx,
    }
    # Build a fresh dict rather than reassigning entries while iterating it.
    return {
        key: [t.detach().cpu() for t in tensors]
        for key, tensors in per_key.items()
    }

def compute_dhdx(x, model):
    """Per-sample Jacobians of the layer inputs w.r.t. the model input.

    The differentiated quantities are the input coordinates themselves
    followed by ``cache['hs'][:-1]`` from ``model.forward``.

    Args:
        x: Batch of inputs; vmap maps over its leading dimension.
        model: Object exposing
            ``forward(coords, keep_cache=True, detach=False)`` returning
            ``(output, cache)`` with a ``cache['hs']`` list.

    Returns:
        list of CPU tensors without autograd history, one per layer input.
        Each holds the batched Jacobian of that layer input w.r.t. the
        input sample.
    """

    def hs(coords):
        _, cache = model.forward(coords, keep_cache=True, detach=False)
        layer_inputs = [coords] + cache['hs'][:-1]

        # Inside vmap a sample is 1-D; strip a leading singleton axis, if
        # any, so each Jacobian comes out as (features, input_dim).
        if coords.ndim == 1:
            layer_inputs = [h.squeeze(0) for h in layer_inputs]

        return tuple(layer_inputs)

    jacobians = vmap(jacrev(hs))(x)
    # Detach first so the host copy is not recorded by autograd.
    return [jac.detach().cpu() for jac in jacobians]

def gaussian_approx(x, model):
    """Accumulate per-layer ``a2``, ``D`` and ``H`` terms over all layers.

    With ``mh = |h|^2`` and ``mf = |dfdz|^2`` for a layer (taken from
    ``compute_mag_and_derivatives``), ``Jh`` from ``compute_dhdx`` and ``Jf``
    from ``compute_d2fdzdx``, each layer contributes:

    * ``a2 = mf * (1 + mh)``
    * ``D  = (1 + mh) * d(mf)/dx + d(mh)/dx * mf``
    * ``H  = Jf^T Jf * (1 + mh) + Jh^T Jh * mf
             + 0.25 * (d(mh)/dx d(mf)/dx^T + its transpose)``

    NOTE(review): how these three quantities parameterise the "Gaussian"
    in the function name is not visible here — confirm with the caller.

    Args:
        x: Batch of inputs, forwarded unchanged to the three helpers.
        model: Model object, forwarded unchanged to the three helpers.

    Returns:
        dict with keys ``'H'``, ``'a2'`` and ``'D'`` holding the layer sums
        as numpy arrays (each stays the integer ``0`` when there are no
        layers).
    """
    magnitudes = compute_mag_and_derivatives(x, model)
    jac_dfdz_layers = compute_d2fdzdx(x, model)
    jac_h_layers = compute_dhdx(x, model)

    totals = {'H': 0, 'a2': 0, 'D': 0}

    for layer in range(len(jac_h_layers)):
        h_sq = magnitudes['hmags'][layer]
        dfdz_sq = magnitudes['dfmags'][layer]
        jac_h = jac_h_layers[layer]
        jac_dfdz = jac_dfdz_layers[layer]
        grad_h_sq = magnitudes['dhm_dx'][layer]
        grad_dfdz_sq = magnitudes['dfm_dx'][layer]

        # Gram matrices of the two Jacobians, each weighted by the other
        # factor of the product mf * (1 + mh).
        gram_dfdz = torch.bmm(jac_dfdz.transpose(1, 2), jac_dfdz)
        gram_dfdz = gram_dfdz * (1 + h_sq[:, None, None])
        gram_h = torch.bmm(jac_h.transpose(1, 2), jac_h)
        gram_h = gram_h * dfdz_sq[:, None, None]

        # Symmetrised outer product of the two magnitude gradients.
        outer = torch.bmm(grad_h_sq[:, :, None], grad_dfdz_sq[:, None, :])
        outer = outer + outer.transpose(1, 2)

        layer_H = gram_dfdz + gram_h + 0.25 * outer
        layer_a2 = dfdz_sq * (1 + h_sq)
        # Product rule applied to mf * (1 + mh).
        layer_D = (
            (1 + h_sq[:, None]) * grad_dfdz_sq
            + grad_h_sq * dfdz_sq[:, None]
        )

        totals['H'] += layer_H.cpu().numpy()
        totals['a2'] += layer_a2.cpu().numpy()
        totals['D'] += layer_D.cpu().numpy()

    return totals