import itertools
import unittest
from collections import Counter
from itertools import product

import backend as F
import networkx as nx
import numpy as np
import pytest
import scipy.sparse as ssp
import test_utils
from scipy.sparse import rand
from test_utils import get_cases, parametrize_idtype

import dgl
import dgl.function as fn
from dgl import DGLError

# Built-in DGL reducers keyed by their name (not referenced in this section).
rfuncs = dict(sum=fn.sum, max=fn.max, min=fn.min, mean=fn.mean)
# Width of every random node/edge feature tensor created by the tests below.
feat_size = 2


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def create_test_heterograph(idtype):
    """Build the small, fixed heterograph shared by the tests.

    The graph has 3 users, 2 games and 2 developers. Its canonical edge types
    are:

        ('user', 'follows', 'user')
        ('user', 'plays', 'game')
        ('user', 'wishes', 'game')
        ('developer', 'develops', 'game')

    NOTE: the ``unittest.skipIf`` decorator sits on this helper, not on a test.
    When the backend is not PyTorch, calling the helper raises
    ``unittest.SkipTest``. Any test that calls it is therefore skipped.

    Args:
        idtype: ID dtype for the graph.

    Returns:
        The heterograph, placed on ``F.ctx()``.
    """
    edges = {
        ("user", "follows", "user"): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ("user", "wishes", "game"): ([0, 1, 1], [0, 0, 1]),
        ("developer", "develops", "game"): ([0, 1, 0], [0, 1, 1]),
    }
    graph = dgl.heterograph(edges, idtype=idtype, device=F.ctx())
    # Sanity-check that the requested dtype and device were honoured.
    assert graph.idtype == idtype
    assert graph.device == F.ctx()
    return graph


def create_test_heterograph_2(idtype):
    """Build a random heterograph with five relations over four node types.

    Edge sizes and ID ranges:

    * ``becomes``, ``follows`` and ``plays`` share the same 25 random
      ``(src, dst)`` pairs, with IDs in ``[0, 50)``.
    * ``wishes`` has 10 random pairs, with IDs in ``[0, 25)``.
    * ``develops`` has 1000 random pairs, with IDs in ``[0, 100)``.

    Args:
        idtype: ID dtype for the graph.

    Returns:
        The heterograph, placed on ``F.ctx()``.
    """
    # Draw order matters for reproducibility under a fixed seed. Keep it.
    shared_u = np.random.randint(0, 50, 25)
    shared_v = np.random.randint(0, 50, 25)
    wish_u = np.random.randint(0, 25, 10)
    wish_v = np.random.randint(0, 25, 10)
    dev_u = np.random.randint(0, 100, 1000)
    dev_v = np.random.randint(0, 100, 1000)

    shared = (shared_u, shared_v)
    edges = {
        ("user", "becomes", "player"): shared,
        ("user", "follows", "user"): shared,
        ("user", "plays", "game"): shared,
        ("user", "wishes", "game"): (wish_u, wish_v),
        ("developer", "develops", "game"): (dev_u, dev_v),
    }
    graph = dgl.heterograph(edges, idtype=idtype, device=F.ctx())
    assert graph.idtype == idtype
    assert graph.device == F.ctx()
    return graph


def create_test_heterograph_large(idtype):
    """Build a heterograph whose four relations share 2500 random edges.

    Every relation reuses the same ``(src, dst)`` arrays, with IDs in
    ``[0, 50)``.

    Args:
        idtype: ID dtype for the graph.

    Returns:
        The heterograph, placed on ``F.ctx()``.
    """
    # Tuple elements are evaluated left to right, so src is drawn before dst.
    pair = (
        np.random.randint(0, 50, 2500),
        np.random.randint(0, 50, 2500),
    )
    relations = [
        ("user", "follows", "user"),
        ("user", "plays", "game"),
        ("user", "wishes", "game"),
        ("developer", "develops", "game"),
    ]
    graph = dgl.heterograph(
        {rel: pair for rel in relations}, idtype=idtype, device=F.ctx()
    )
    assert graph.idtype == idtype
    assert graph.device == F.ctx()
    return graph


@parametrize_idtype
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def test_unary_copy_u(idtype):
    """Check ``multi_update_all`` against ``update_all`` for ``copy_u``.

    For each reducer, source-node features are propagated twice:

    * with ``multi_update_all``: one message-passing call per canonical edge
      type, combined by a cross-type reducer of the same name;
    * with a single heterogeneous ``update_all`` call.

    The aggregated outputs and the gradients with respect to the source node
    features must agree.
    """

    def _test(mfunc, rfunc):
        g = create_test_heterograph_2(idtype)
        # The cross-type reducer is named after the per-etype reducer,
        # e.g. fn.sum -> "sum".
        cross_reducer = rfunc.__name__
        x1 = F.randn((g.num_nodes("user"), feat_size))
        x2 = F.randn((g.num_nodes("developer"), feat_size))
        F.attach_grad(x1)
        F.attach_grad(x2)
        g.nodes["user"].data["h"] = x1
        g.nodes["developer"].data["h"] = x2

        #################################################################
        #  multi_update_all(): call msg_passing separately for each etype
        #################################################################

        with F.record_grad():
            g.multi_update_all(
                {
                    etype: (mfunc("h", "m"), rfunc("m", "y"))
                    for etype in g.canonical_etypes
                },
                cross_reducer,
            )
            # Clone so the results survive the data reset below.
            r1 = g.nodes["game"].data["y"].clone()
            r2 = g.nodes["user"].data["y"].clone()
            r3 = g.nodes["player"].data["y"].clone()
            loss = r1.sum() + r2.sum() + r3.sum()
            F.backward(loss)
            n_grad1 = F.grad(g.nodes["user"].data["h"]).clone()
            n_grad2 = F.grad(g.nodes["developer"].data["h"]).clone()

        # Start the second pass from a clean slate on every node type.
        for ntype in g.ntypes:
            g.nodes[ntype].data.clear()

        #################################################################
        #  update_all(): call msg_passing for all etypes
        #################################################################

        F.attach_grad(x1)
        F.attach_grad(x2)
        g.nodes["user"].data["h"] = x1
        g.nodes["developer"].data["h"] = x2

        with F.record_grad():
            g.update_all(mfunc("h", "m"), rfunc("m", "y"))
            r4 = g.nodes["game"].data["y"]
            r5 = g.nodes["user"].data["y"]
            r6 = g.nodes["player"].data["y"]
            loss = r4.sum() + r5.sum() + r6.sum()
            F.backward(loss)
            n_grad3 = F.grad(g.nodes["user"].data["h"])
            n_grad4 = F.grad(g.nodes["developer"].data["h"])

        assert F.allclose(r1, r4)
        assert F.allclose(r2, r5)
        assert F.allclose(r3, r6)
        assert F.allclose(n_grad1, n_grad3)
        assert F.allclose(n_grad2, n_grad4)

    for rfunc in (fn.sum, fn.max, fn.min):
        _test(fn.copy_u, rfunc)
    # TODO: fn.mean is not exercised by this test yet.


@parametrize_idtype
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def test_unary_copy_e(idtype):
    """Check ``multi_update_all`` against ``update_all`` for ``copy_e``.

    Every relation of ``create_test_heterograph_large`` gets a random edge
    feature ``"eid"``. Each reducer is checked twice:

    * per-etype message passing with ``multi_update_all``, combined by a
      cross-type reducer of the same name;
    * a single ``update_all`` call.

    Node outputs and edge-feature gradients must match between the two.
    """
    # Relations of create_test_heterograph_large(). The order fixes the order
    # of the random draws below.
    etypes = ["plays", "follows", "develops", "wishes"]

    def _test(mfunc, rfunc):
        g = create_test_heterograph_large(idtype)
        cross_reducer = rfunc.__name__
        feats = {
            etype: F.randn((g.num_edges(etype), feat_size)) for etype in etypes
        }

        def _set_edge_feats():
            # Reset/enable gradients and (re)attach every edge feature.
            for etype, x in feats.items():
                F.attach_grad(x)
                g[etype].edata["eid"] = x

        _set_edge_feats()

        #################################################################
        #  multi_update_all(): call msg_passing separately for each etype
        #################################################################

        with F.record_grad():
            g.multi_update_all(
                {etype: (mfunc("eid", "m"), rfunc("m", "y")) for etype in etypes},
                cross_reducer,
            )
            # Clone so the results survive the second pass.
            r1 = g.nodes["game"].data["y"].clone()
            r2 = g.nodes["user"].data["y"].clone()
            loss = r1.sum() + r2.sum()
            F.backward(loss)
            grads_multi = {
                etype: F.grad(g[etype].edata["eid"]).clone() for etype in etypes
            }

        for _, etype, _ in g.canonical_etypes:
            g[etype].edata.clear()

        #################################################################
        #  update_all(): call msg_passing for all etypes
        #################################################################

        # TODO(Israt): output type can be None in multi_update and empty
        _set_edge_feats()

        with F.record_grad():
            g.update_all(mfunc("eid", "m"), rfunc("m", "y"))
            r3 = g.nodes["game"].data["y"]
            r4 = g.nodes["user"].data["y"]
            loss = r3.sum() + r4.sum()
            F.backward(loss)
            grads_all = {etype: F.grad(g[etype].edata["eid"]) for etype in etypes}

        assert F.allclose(r1, r3)
        assert F.allclose(r2, r4)
        for etype in etypes:
            assert F.allclose(grads_multi[etype], grads_all[etype])

    for rfunc in (fn.sum, fn.max, fn.min):
        _test(fn.copy_e, rfunc)
    # TODO: fn.mean is not exercised by this test yet.


@parametrize_idtype
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def test_binary_op(idtype):
    """Check ``multi_update_all`` against ``update_all`` for binary messages.

    Every ordered pair of distinct operand targets among ``u``/``v``/``e`` is
    tried with each binary op (e.g. ``fn.u_add_e``). The message is reduced
    per etype with ``multi_update_all`` and again with one ``update_all``
    call. The aggregated ``game`` features must match. Backward is run on
    both outputs, but gradients are not compared yet (see the TODO below).
    """

    def _test(lhs, rhs, binary_op, reducer):
        g = create_test_heterograph(idtype)

        # Node feature "h" on every node type.
        for ntype in ("user", "developer", "game"):
            x = F.randn((g.num_nodes(ntype), feat_size))
            F.attach_grad(x)
            g.nodes[ntype].data["h"] = x

        # Edge feature "h" on every relation, sized from the graph rather
        # than from hard-coded edge counts.
        for etype in ("plays", "follows", "develops", "wishes"):
            x = F.randn((g.num_edges(etype), feat_size))
            F.attach_grad(x)
            g[etype].edata["h"] = x

        builtin_msg = getattr(fn, "{}_{}_{}".format(lhs, binary_op, rhs))
        builtin_red = getattr(fn, reducer)

        #################################################################
        #  multi_update_all(): call msg_passing separately for each etype
        #################################################################

        with F.record_grad():
            g.multi_update_all(
                {
                    etype: (builtin_msg("h", "h", "m"), builtin_red("m", "y"))
                    for etype in g.canonical_etypes
                },
                # Combine etypes with the same reducer, as the copy tests do.
                reducer,
            )
            r1 = g.nodes["game"].data["y"]
            F.backward(r1, F.ones(r1.shape))

        #################################################################
        #  update_all(): call msg_passing for all etypes
        #################################################################

        with F.record_grad():
            g.update_all(builtin_msg("h", "h", "m"), builtin_red("m", "y"))
            r2 = g.nodes["game"].data["y"]
            F.backward(r2, F.ones(r2.shape))

        def _print_error(a, b):
            # Print every element position where the two tensors differ.
            for i, (x, y) in enumerate(
                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
            ):
                if not np.allclose(x, y):
                    print("@{} {} v.s. {}".format(i, x, y))

        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        # TODO(Israt): r1 and r2 have different grad functions associated
        # with them. Compare the input-feature gradients of both passes once
        # that is resolved.

    target = ["u", "v", "e"]
    # Every ordered pair of distinct operand targets.
    for lhs, rhs in itertools.permutations(target, 2):
        for binary_op in ["add", "sub", "mul", "div"]:
            # TODO(Israt): Add support for reduce func "max", "min", "mean"
            for reducer in ["sum"]:
                # Printed so a failing combination is visible in pytest output.
                print(lhs, rhs, binary_op, reducer)
                _test(lhs, rhs, binary_op, reducer)


if __name__ == "__main__":
    # The tests take an ``idtype`` argument that pytest normally supplies via
    # ``parametrize_idtype``. Calling them bare raises TypeError, so pass each
    # ID dtype explicitly here.
    # NOTE(review): assumes parametrize_idtype covers F.int32/F.int64 — confirm.
    for idtype in (F.int32, F.int64):
        test_unary_copy_u(idtype)
        test_unary_copy_e(idtype)
        test_binary_op(idtype)
