import numpy as np
import scipy.sparse as sp
import dgl
import dgl.function as fn
import backend as F
from test_utils import parametrize_idtype

# Width of the trailing feature dimension used for the 2-D node features
# ('f2', 'h') and the 2-D edge weights ('w2') created by the tests below.
D = 5

def generate_graph(idtype):
    """Build the shared 10-node test graph with random features.

    Topology: node 0 has an edge to each of nodes 1..8, each of nodes 1..8
    has an edge to node 9, and one extra edge goes from 9 back to 0, for a
    total of 17 edges.

    Node data: 'f1' with shape (10,) and 'f2' with shape (10, D).
    Edge data: 'e1' with shape (17,) and 'e2', the same values with an
    extra trailing axis added via ``F.unsqueeze(..., 1)``.

    ``idtype`` is forwarded to ``astype`` and the graph is moved to
    ``F.ctx()`` before any nodes or edges are added.
    """
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(10)
    # 0 acts as the source and 9 as the sink; edges are inserted pairwise
    # (0 -> k, then k -> 9) so edge ids follow the same order every time.
    for mid in range(1, 9):
        graph.add_edges(0, mid)
        graph.add_edges(mid, 9)
    # back flow from the sink to the source
    graph.add_edges(9, 0)
    graph.ndata.update({
        'f1': F.randn((10,)),
        'f2': F.randn((10, D)),
    })
    edge_w = F.randn((17,))
    graph.edata.update({
        'e1': edge_w,
        'e2': F.unsqueeze(edge_w, 1),
    })
    return graph

@parametrize_idtype
def test_v2v_update_all(idtype):
    """``update_all`` with builtin functions must match the UDF versions."""
    def _test(fld):
        def copy_src(edges):
            return {'m': edges.src[fld]}

        def weighted_src(edges):
            src_feat = edges.src[fld]
            # 1-D features use the flat weights, anything else the
            # weights with the extra trailing axis.
            weight_key = 'e1' if len(src_feat.shape) == 1 else 'e2'
            return {'m': src_feat * edges.data[weight_key]}

        def sum_mailbox(nodes):
            return {fld: F.sum(nodes.mailbox['m'], 1)}

        def double(nodes):
            return {fld: 2 * nodes.data[fld]}

        graph = generate_graph(idtype)

        def _compare(builtin_msg, udf_msg):
            # Snapshot the feature, run the builtin pipeline, restore the
            # snapshot, run the UDF pipeline, and compare the two results.
            before = graph.ndata[fld]
            graph.update_all(builtin_msg, fn.sum(msg='m', out=fld), double)
            expected = graph.ndata[fld]
            graph.ndata.update({fld: before})
            graph.update_all(udf_msg, sum_mailbox, double)
            actual = graph.ndata[fld]
            assert F.allclose(expected, actual)

        # plain copy of the source feature
        _compare(fn.copy_u(u=fld, out='m'), copy_src)
        # source feature multiplied by the edge weights
        _compare(fn.u_mul_e(fld, 'e1', 'm'), weighted_src)

    # 1-D node features first, then 2-D node features
    for field in ('f1', 'f2'):
        _test(field)

@parametrize_idtype
def test_v2v_snr(idtype):
    """``send_and_recv`` with builtin functions must match the UDF versions."""
    src_ids = F.tensor([0, 0, 0, 3, 4, 9], idtype)
    dst_ids = F.tensor([1, 2, 3, 9, 9, 0], idtype)

    def _test(fld):
        def copy_src(edges):
            return {'m': edges.src[fld]}

        def weighted_src(edges):
            src_feat = edges.src[fld]
            # 1-D features use the flat weights, anything else the
            # weights with the extra trailing axis.
            weight_key = 'e1' if len(src_feat.shape) == 1 else 'e2'
            return {'m': src_feat * edges.data[weight_key]}

        def sum_mailbox(nodes):
            return {fld: F.sum(nodes.mailbox['m'], 1)}

        def double(nodes):
            return {fld: 2 * nodes.data[fld]}

        graph = generate_graph(idtype)

        def _compare(builtin_msg, udf_msg):
            # Snapshot the feature, run the builtin pipeline on the chosen
            # edges, restore the snapshot, run the UDF pipeline, compare.
            before = graph.ndata[fld]
            graph.send_and_recv((src_ids, dst_ids), builtin_msg,
                                fn.sum(msg='m', out=fld), double)
            expected = graph.ndata[fld]
            graph.ndata.update({fld: before})
            graph.send_and_recv((src_ids, dst_ids), udf_msg,
                                sum_mailbox, double)
            actual = graph.ndata[fld]
            assert F.allclose(expected, actual)

        # plain copy of the source feature
        _compare(fn.copy_u(u=fld, out='m'), copy_src)
        # source feature multiplied by the edge weights
        _compare(fn.u_mul_e(fld, 'e1', 'm'), weighted_src)

    # 1-D node features first, then 2-D node features
    for field in ('f1', 'f2'):
        _test(field)


@parametrize_idtype
def test_v2v_pull(idtype):
    """``pull`` with builtin functions must match the UDF versions."""
    pull_ids = F.tensor([1, 2, 3, 9], idtype)

    def _test(fld):
        def copy_src(edges):
            return {'m': edges.src[fld]}

        def weighted_src(edges):
            src_feat = edges.src[fld]
            # 1-D features use the flat weights, anything else the
            # weights with the extra trailing axis.
            weight_key = 'e1' if len(src_feat.shape) == 1 else 'e2'
            return {'m': src_feat * edges.data[weight_key]}

        def sum_mailbox(nodes):
            return {fld: F.sum(nodes.mailbox['m'], 1)}

        def double(nodes):
            return {fld: 2 * nodes.data[fld]}

        graph = generate_graph(idtype)

        def _compare(builtin_msg, udf_msg):
            # Snapshot the feature, pull with the builtin pipeline, restore
            # the snapshot, pull with the UDF pipeline, compare.
            before = graph.ndata[fld]
            graph.pull(pull_ids, builtin_msg,
                       fn.sum(msg='m', out=fld), double)
            expected = graph.ndata[fld]
            graph.ndata[fld] = before
            graph.pull(pull_ids, udf_msg, sum_mailbox, double)
            actual = graph.ndata[fld]
            assert F.allclose(expected, actual)

        # pull with a plain copy of the source feature
        _compare(fn.copy_u(u=fld, out='m'), copy_src)
        # pull with the source feature multiplied by the edge weights
        _compare(fn.u_mul_e(fld, 'e1', 'm'), weighted_src)

    # 1-D node features first, then 2-D node features
    for field in ('f1', 'f2'):
        _test(field)

@parametrize_idtype
def test_update_all_multi_fallback(idtype):
    """Compare builtin ``u_mul_e``/``sum`` against UDFs under ``update_all``.

    The graph has no edge back into node 0, so node 0 has zero in-degree.
    """
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(10)
    for mid in range(1, 9):
        graph.add_edges(0, mid)
        graph.add_edges(mid, 9)
    graph.ndata['h'] = F.randn((10, D))
    graph.edata['w1'] = F.randn((16,))
    graph.edata['w2'] = F.randn((16, D))

    def msg_h_times_w1(edges):
        # w1 is 1-D per edge; add a trailing axis so it multiplies with h
        return {'m1': edges.src['h'] * F.unsqueeze(edges.data['w1'], 1)}

    def msg_h_times_w2(edges):
        return {'m2': edges.src['h'] * edges.data['w2']}

    def sum_m1(nodes):
        return {'o1': F.sum(nodes.mailbox['m1'], 1)}

    def sum_m2(nodes):
        return {'o2': F.sum(nodes.mailbox['m2'], 1)}

    def max_m1(nodes):
        return {'o3': F.max(nodes.mailbox['m1'], 1)}

    def double_outputs(nodes):
        # double every node field whose name starts with 'o'
        return {
            name: 2 * value
            for name, value in nodes.data.items()
            if name.startswith('o')
        }

    # Reference results from the UDF pipelines.
    graph.update_all(msg_h_times_w1, sum_m1, double_outputs)
    ref_o1 = graph.ndata.pop('o1')
    graph.update_all(msg_h_times_w2, sum_m2, double_outputs)
    ref_o2 = graph.ndata.pop('o2')
    # The max-reducer pipeline is executed and its output removed, but the
    # result is not compared against anything in this test.
    graph.update_all(msg_h_times_w1, max_m1, double_outputs)
    graph.ndata.pop('o3')

    # Builtin pipeline with the 1-D edge weights.
    graph.update_all(
        fn.u_mul_e('h', 'w1', 'm1'),
        fn.sum(msg='m1', out='o1'),
        double_outputs,
    )
    assert F.allclose(ref_o1, graph.ndata.pop('o1'))

    # Builtin pipeline with the (16, D) edge weights.
    graph.update_all(
        fn.u_mul_e('h', 'w2', 'm2'),
        fn.sum(msg='m2', out='o2'),
        double_outputs,
    )
    assert F.allclose(ref_o2, graph.ndata.pop('o2'))

@parametrize_idtype
def test_pull_multi_fallback(idtype):
    """Compare builtin ``u_mul_e``/``sum`` against UDFs under ``pull``.

    The graph has no edge back into node 0, so node 0 has zero in-degree.
    """
    graph = dgl.DGLGraph().astype(idtype).to(F.ctx())
    graph.add_nodes(10)
    for mid in range(1, 9):
        graph.add_edges(0, mid)
        graph.add_edges(mid, 9)
    graph.ndata['h'] = F.randn((10, D))
    graph.edata['w1'] = F.randn((16,))
    graph.edata['w2'] = F.randn((16, D))

    def msg_h_times_w1(edges):
        # w1 is 1-D per edge; add a trailing axis so it multiplies with h
        return {'m1': edges.src['h'] * F.unsqueeze(edges.data['w1'], 1)}

    def msg_h_times_w2(edges):
        return {'m2': edges.src['h'] * edges.data['w2']}

    def sum_m1(nodes):
        return {'o1': F.sum(nodes.mailbox['m1'], 1)}

    def sum_m2(nodes):
        return {'o2': F.sum(nodes.mailbox['m2'], 1)}

    def max_m1(nodes):
        return {'o3': F.max(nodes.mailbox['m1'], 1)}

    def double_outputs(nodes):
        # double every node field whose name starts with 'o'
        return {
            name: 2 * value
            for name, value in nodes.data.items()
            if name.startswith('o')
        }

    def _check_pull(targets):
        # Reference results from the UDF pipelines.
        graph.pull(targets, msg_h_times_w1, sum_m1, double_outputs)
        ref_o1 = graph.ndata.pop('o1')
        graph.pull(targets, msg_h_times_w2, sum_m2, double_outputs)
        ref_o2 = graph.ndata.pop('o2')
        # The max-reducer pipeline is executed and its output removed, but
        # the result is not compared against anything in this test.
        graph.pull(targets, msg_h_times_w1, max_m1, double_outputs)
        graph.ndata.pop('o3')

        # Builtin pipeline with the 1-D edge weights.
        graph.pull(
            targets,
            fn.u_mul_e('h', 'w1', 'm1'),
            fn.sum(msg='m1', out='o1'),
            double_outputs,
        )
        assert F.allclose(ref_o1, graph.ndata.pop('o1'))

        # Builtin pipeline with the (16, D) edge weights.
        graph.pull(
            targets,
            fn.u_mul_e('h', 'w2', 'm2'),
            fn.sum(msg='m2', out='o2'),
            double_outputs,
        )
        assert F.allclose(ref_o2, graph.ndata.pop('o2'))

    # case 1: only nodes that have incoming edges
    _check_pull([1, 2, 9])
    # case 2: the zero in-degree node 0 together with the others
    _check_pull([0, 1, 2, 9])

@parametrize_idtype
def test_spmv_3d_feat(idtype):
    """Builtin ``u_mul_e``/``sum`` must match UDFs for (n, 5, 5) features."""
    def msg_scalar_edge(edges):
        # per-edge scalar weight, expanded by two trailing axes
        edge_w = F.unsqueeze(F.unsqueeze(edges.data['h'], 1), 1)
        return {'sum': edges.src['h'] * edge_w}

    def msg_full_edge(edges):
        return {'sum': edges.src['h'] * edges.data['h']}

    def sum_udf(nodes):
        return {'h': F.sum(nodes.mailbox['sum'], 1)}

    n = 100
    p = 0.1
    # random sparse n x n matrix whose stored entries are all ones
    adj = sp.random(n, n, p, data_rvs=lambda n: np.ones(n))
    graph = dgl.DGLGraph(adj)
    graph = graph.astype(idtype).to(F.ctx())
    m = graph.number_of_edges()

    def _load(node_feat, edge_feat):
        # (re)install the inputs; update_all overwrites ndata['h']
        graph.ndata['h'] = node_feat
        graph.edata['h'] = edge_feat

    def _check(msg_udf, node_feat, edge_feat):
        # 1: builtin message + builtin reduce gives the reference answer
        _load(node_feat, edge_feat)
        graph.update_all(message_func=fn.u_mul_e('h', 'h', 'sum'),
                         reduce_func=fn.sum('sum', 'h'))
        ans = graph.ndata['h']

        # 2: UDF message + builtin reduce
        _load(node_feat, edge_feat)
        graph.update_all(message_func=msg_udf,
                         reduce_func=fn.sum('sum', 'h'))
        assert F.allclose(graph.ndata['h'], ans)

        # 3: UDF message + UDF reduce
        _load(node_feat, edge_feat)
        graph.update_all(message_func=msg_udf, reduce_func=sum_udf)
        assert F.allclose(graph.ndata['h'], ans)

    # test#1: edge data is one value per edge
    h = F.randn((n, 5, 5))
    e = F.randn((m,))
    _check(msg_scalar_edge, h, e)

    # test#2: edge data has the same trailing shape as the node data
    h = F.randn((n, 5, 5))
    e = F.randn((m, 5, 5))
    _check(msg_full_edge, h, e)

if __name__ == '__main__':
    # Script entry point: run every test defined in this module.
    #
    # Each test takes an ``idtype`` argument, so it has to be passed
    # explicitly when the functions are called outside of a test runner.
    # Only functions that actually exist in this file are invoked; calling
    # names that are not defined here would raise NameError.
    #
    # NOTE(review): assumes the backend shim exposes ``int32`` / ``int64``
    # and that these are the dtypes ``parametrize_idtype`` supplies --
    # confirm against test_utils.
    for _idtype in (F.int32, F.int64):
        test_v2v_update_all(_idtype)
        test_v2v_snr(_idtype)
        test_v2v_pull(_idtype)
        test_update_all_multi_fallback(_idtype)
        test_pull_multi_fallback(_idtype)
        test_spmv_3d_feat(_idtype)
