import unittest

import backend as F
from test_utils import parametrize_idtype

import dgl
from dgl.dataloading import (
    NeighborSampler,
    as_edge_prediction_sampler,
    negative_sampler,
)


def create_test_graph(idtype):
    """Build the small heterograph used by the tests in this file.

    This is the test heterograph from the docstring, plus a
    user -- wishes -- game relation.  It has 3 users, 2 games and
    2 developers, with the metagraph:

        ('user', 'follows', 'user')
        ('user', 'plays', 'game')
        ('user', 'wishes', 'game')
        ('developer', 'develops', 'game')

    The graph is created with the requested ``idtype`` on the device
    returned by ``F.ctx()``; both properties are asserted before the graph
    is returned.
    """
    device = F.ctx()

    # One entry per relation: canonical edge type -> pair of node ID lists.
    relations = {
        ("user", "follows", "user"): ([0, 1], [1, 2]),
        ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ("user", "wishes", "game"): ([0, 2], [1, 0]),
        ("developer", "develops", "game"): ([0, 1], [0, 1]),
    }
    graph = dgl.heterograph(relations, idtype=idtype, device=device)

    # Sanity-check that construction honoured the requested idtype/device.
    assert graph.idtype == idtype
    assert graph.device == device
    return graph


@parametrize_idtype
def test_edge_prediction_sampler(idtype):
    """Smoke-test an edge-prediction sampler built on ``NeighborSampler``.

    A ``NeighborSampler([10, 10])`` is wrapped by
    ``as_edge_prediction_sampler`` with a ``negative_sampler.Uniform(1)``
    negative sampler, then asked to sample for edge IDs 0 and 1 of the
    ``"follows"`` relation of the test graph.
    """
    graph = create_test_graph(idtype)

    base_sampler = NeighborSampler([10, 10])
    edge_sampler = as_edge_prediction_sampler(
        base_sampler, negative_sampler=negative_sampler.Uniform(1)
    )

    # Seed IDs 0 and 1, created with the graph's idtype and moved to the
    # test device.
    seed_ids = F.arange(0, 2, dtype=idtype)
    seed_ids = F.copy_to(seed_ids, ctx=F.ctx())

    # The output is deliberately not inspected: this is just a smoke test to
    # make sure sampling does not trip any internal assertions.
    _ = edge_sampler.sample(graph, {"follows": seed_ids})


if __name__ == "__main__":
    # ``test_edge_prediction_sampler`` requires an ``idtype`` argument.  Under
    # pytest it is supplied by ``parametrize_idtype``; when this file is run
    # as a script it must be passed explicitly, otherwise the call raises
    # ``TypeError`` for the missing positional argument.
    # NOTE(review): int32/int64 are assumed to be the id types covered by
    # ``parametrize_idtype`` -- confirm against test_utils.
    for _idtype in (F.int32, F.int64):
        test_edge_prediction_sampler(_idtype)
