import torch
from torch.func import vjp

from core import MonoidReduce, check

def init(a, b):
    """Create the zero-filled starting accumulator for the reduction.

    Args:
        a: 1-tuple ``(x,)``; only ``x.shape[0]`` is read, and the result
            inherits dtype/device from ``x`` via ``new_zeros``.
        b: 2-tuple ``(p, q)``; only ``q.shape[1]`` is read.

    Returns:
        A 1-tuple holding a zero tensor of shape
        ``(x.shape[0], q.shape[1])``.
    """
    (lhs,) = a
    # The first element of ``b`` is not needed for sizing the output, but
    # the two-way unpack still enforces that ``b`` has exactly two entries.
    _, w_out = b
    rows, cols = lhs.shape[0], w_out.shape[1]
    accumulator = lhs.new_zeros((rows, cols))
    return (accumulator,)

def islice(parts, *, num_chunks=16):
    """Split each tensor in ``parts`` along dim 0 and yield aligned slices.

    Every tensor is cut with ``Tensor.chunk(num_chunks)`` and the resulting
    chunk sequences are zipped together, so the i-th item produced is a
    tuple containing the i-th chunk of each tensor in ``parts``.

    Args:
        parts: iterable of tensors to be sliced in lock-step.
        num_chunks: requested number of chunks per tensor. Defaults to 16,
            the value that was previously hard-coded.

    Returns:
        A ``zip`` iterator of tuples, one tuple per chunk index.

    Note:
        ``Tensor.chunk`` may return fewer chunks than requested, and ``zip``
        stops at the shortest sequence, so tensors whose dim-0 sizes differ
        can be truncated silently.
    """
    return zip(*(part.chunk(num_chunks) for part in parts))

def proj_fold(a, b):
    """Compute ``relu(x @ p.T) @ q`` for one pair of slices.

    Args:
        a: 1-tuple ``(x,)``.
        b: 2-tuple ``(p, q)``.

    Returns:
        A 1-tuple holding the projected result.

    Note:
        ``proj_fold_bwd`` hard-codes the ReLU derivative, so swapping the
        activation here requires updating that function as well.
    """
    (inputs,) = a
    w_in, w_out = b

    # <- change this for other activation functions
    activation = torch.nn.functional.relu

    hidden = activation(inputs @ w_in.t())
    return (hidden @ w_out,)

def binary_reduce(x, y):
    """Combine two partial results by adding them.

    Args:
        x: 1-tuple holding the left operand.
        y: 1-tuple holding the right operand.

    Returns:
        A 1-tuple holding ``x[0] + y[0]``.
    """
    (left,) = x
    (right,) = y
    total = left + right
    return (total,)

def proj_fold_bwd(a, b, p, gp):
    """Hand-written backward pass for ``proj_fold``.

    Recomputes the ReLU activation and propagates the output gradient back
    to ``x``, ``p`` and ``q``. This is the manual counterpart of
    ``vjp(proj_fold, a, b)[1](gp)``.

    Args:
        a: 1-tuple ``(x,)`` as passed to ``proj_fold``.
        b: 2-tuple ``(p, q)`` as passed to ``proj_fold``.
        p: not used by this implementation.
        gp: 1-tuple holding the gradient with respect to the output of
            ``proj_fold``.

    Returns:
        ``((grad_x,), (grad_p, grad_q))``, mirroring the ``(a, b)`` layout.
    """
    (inputs,) = a
    w_in, w_out = b
    (grad_out,) = gp

    # Recompute the forward activation instead of relying on saved state.
    hidden = torch.nn.functional.relu(inputs @ w_in.t())

    # Gradient reaching the activation output, then gated by the ReLU
    # derivative: it only passes where the activation is strictly positive.
    grad_hidden = grad_out @ w_out.t()
    grad_pre = torch.where(hidden > 0.0, grad_hidden, 0.0)

    grad_inputs = grad_pre @ w_in
    grad_w_in = grad_pre.t() @ inputs
    grad_w_out = hidden.t() @ grad_out

    return (grad_inputs,), (grad_w_in, grad_w_out)

# Operator object tying together the pieces defined above.
# NOTE(review): ``avals=1`` / ``bvals=2`` presumably give the number of
# tensors in the ``a`` group (x) and the ``b`` group (p, q) -- confirm
# against ``core.MonoidReduce``.
MLP = MonoidReduce(
    'MLP',
    init=init,
    avals=1,
    bvals=2,
    # The same slicing function is used for both ``islice`` and ``jslice``.
    islice=islice,
    jslice=islice,
    proj_fold=proj_fold,
    proj_fold_bwd=proj_fold_bwd,
    binary_reduce=binary_reduce,
)

def mlp(x, p, q):
    """Evaluate the two-layer MLP through the ``MLP`` operator.

    Intended to produce the same result as ``naive_mlp``, i.e.
    ``relu(x @ p.T) @ q``.

    Args:
        x: 2-D input, expected shape ``(B, D)``.
        p: 2-D first-layer weight, expected shape ``(H, D)``.
        q: 2-D second-layer weight, expected shape ``(H, Y)``.

    Returns:
        The single tensor returned by ``MLP.apply(x, p, q)``.

    Raises:
        ValueError: if an argument is not 2-D (raised by the shape
            unpacking), or if the feature/hidden dimensions disagree.
    """
    _, in_features = x.shape
    hidden, p_in_features = p.shape
    q_hidden, _ = q.shape

    # Explicit raises rather than ``assert`` so the checks survive
    # ``python -O`` and report the offending sizes.
    if in_features != p_in_features:
        raise ValueError(
            f"feature dimension mismatch: x has {in_features} columns "
            f"but p has {p_in_features}"
        )
    if hidden != q_hidden:
        raise ValueError(
            f"hidden dimension mismatch: p has {hidden} rows "
            f"but q has {q_hidden}"
        )

    (out,) = MLP.apply(x, p, q)
    return out

def naive_mlp(x, p, q):
    """Reference implementation: ``relu(x @ p.T) @ q`` in a single pass.

    Args:
        x: input matrix.
        p: first-layer weight, applied transposed.
        q: second-layer weight.

    Returns:
        The output tensor of the two-layer ReLU network.
    """
    hidden = torch.nn.functional.relu(x @ p.t())
    return hidden @ q
    
if __name__ == '__main__':
    # Problem sizes: batch and hidden are 16 * 256, feature dims are 256.
    batch, d_in, hidden, d_out = 16 * 256, 256, 16 * 256, 256

    # Double precision, with gradients tracked on all three operands.
    tensor_opts = dict(dtype=torch.double, requires_grad=True)

    # Creation order is kept (x, p, q, mock) so the random stream is
    # consumed exactly as before.
    x = torch.randn(batch, d_in, **tensor_opts)
    p = torch.randn(hidden, d_in, **tensor_opts)
    q = torch.randn(hidden, d_out, **tensor_opts)

    # NOTE(review): ``mock`` has the output's shape; presumably it is the
    # upstream gradient ``check`` uses for the backward comparison --
    # confirm against ``core.check``.
    mock = torch.randn(batch, d_out, dtype=torch.double)
    check(mlp, naive_mlp, (x, p, q), mock)
