import random

import backend as F
import numpy as np
import pytest
import torch
from test_utils import parametrize_idtype

import dgl

# Seed the Python, NumPy, DGL and PyTorch random number generators with one
# fixed value at import time, so the random graphs and features drawn by the
# tests below are reproducible from run to run.
for _seed_fn in (random.seed, np.random.seed, dgl.seed, torch.random.manual_seed):
    _seed_fn(42)
del _seed_fn


@parametrize_idtype
@pytest.mark.parametrize("feat_size", [(5,), ()])
def test_copy_u(idtype, feat_size):
    """Check ``dgl.copy_u`` and its gradient against source-node indexing.

    The forward result must equal ``x[src]`` for every edge. The gradient
    w.r.t. ``x`` must match the one autograd computes for that indexing.
    """
    ctx = F.ctx()
    g = dgl.rand_graph(30, 100)
    g = g.astype(idtype).to(ctx)
    x = torch.randn((g.num_nodes(),) + feat_size, requires_grad=True, device=ctx)

    y = dgl.copy_u(g, x)
    y.sum().backward()
    # Snapshot the gradient. The reference pass below zeroes ``x.grad`` in
    # place and accumulates into it again. Keeping only a reference would make
    # the final gradient check compare a tensor with itself.
    x_grad = x.grad.clone()

    x.grad.zero_()
    u, _ = g.edges()
    y_true = x[u.long()]
    y_true.sum().backward()
    x_grad_true = x.grad

    assert torch.allclose(y, y_true)
    assert torch.allclose(x_grad, x_grad_true)


@parametrize_idtype
@pytest.mark.parametrize("feat_size", [(5,), ()])
def test_copy_u_hetero(idtype, feat_size):
    """Check ``dgl.copy_u`` on one edge type of a heterograph.

    Features live on ``user`` nodes (the source type of ``like``). The result
    and the gradient are compared with plain indexing by the ``like`` source
    IDs.
    """
    ctx = F.ctx()
    hg = dgl.heterograph(
        {
            ("user", "follow", "user"): ([0, 1, 2], [2, 3, 4]),
            ("user", "like", "movie"): ([3, 3, 1, 2], [0, 0, 1, 1]),
        }
    )

    hg = hg.astype(idtype).to(ctx)
    x = torch.randn((hg.num_nodes("user"),) + feat_size, requires_grad=True, device=ctx)

    y = dgl.copy_u(hg, x, etype="like")
    y.sum().backward()
    # Clone: ``x.grad`` is zeroed and re-accumulated in place below, so a bare
    # reference would turn the gradient assertion into a self-comparison.
    x_grad = x.grad.clone()

    x.grad.zero_()
    u, _ = hg.edges(etype="like")
    y_true = x[u.long()]
    y_true.sum().backward()
    x_grad_true = x.grad

    assert torch.allclose(y, y_true)
    assert torch.allclose(x_grad, x_grad_true)


@parametrize_idtype
@pytest.mark.parametrize("feat_size", [(5,), ()])
def test_copy_v(idtype, feat_size):
    """Check ``dgl.copy_v`` and its gradient against destination-node indexing.

    The forward result must equal ``x[dst]`` for every edge. The gradient
    w.r.t. ``x`` must match the one autograd computes for that indexing.
    """
    ctx = F.ctx()
    g = dgl.rand_graph(30, 100)
    g = g.astype(idtype).to(ctx)
    x = torch.randn((g.num_nodes(),) + feat_size, requires_grad=True, device=ctx)

    y = dgl.copy_v(g, x)
    y.sum().backward()
    # Snapshot the gradient before ``x.grad`` is zeroed and reused in place by
    # the reference pass. Otherwise both sides of the check are one tensor.
    x_grad = x.grad.clone()

    x.grad.zero_()
    _, v = g.edges()
    y_true = x[v.long()]
    y_true.sum().backward()
    x_grad_true = x.grad

    assert torch.allclose(y, y_true)
    assert torch.allclose(x_grad, x_grad_true)


@parametrize_idtype
@pytest.mark.parametrize("feat_size", [(5,), ()])
def test_copy_v_hetero(idtype, feat_size):
    """Check ``dgl.copy_v`` on one edge type of a heterograph.

    Features live on ``movie`` nodes (the destination type of ``like``). The
    result and the gradient are compared with plain indexing by the ``like``
    destination IDs.
    """
    ctx = F.ctx()
    hg = dgl.heterograph(
        {
            ("user", "follow", "user"): ([0, 1, 2], [2, 3, 4]),
            ("user", "like", "movie"): ([3, 3, 1, 2], [0, 0, 1, 1]),
        }
    )

    hg = hg.astype(idtype).to(ctx)
    x = torch.randn((hg.num_nodes("movie"),) + feat_size, requires_grad=True, device=ctx)

    y = dgl.copy_v(hg, x, etype="like")
    y.sum().backward()
    # Clone: ``x.grad`` is zeroed and re-accumulated in place below, so a bare
    # reference would turn the gradient assertion into a self-comparison.
    x_grad = x.grad.clone()

    x.grad.zero_()
    _, v = hg.edges(etype="like")
    y_true = x[v.long()]
    y_true.sum().backward()
    x_grad_true = x.grad

    assert torch.allclose(y, y_true)
    assert torch.allclose(x_grad, x_grad_true)


# Per-node feature shapes ``(x_size, y_size)`` for the binary-op tests. Each
# pair must be broadcast-compatible; ``()`` means one scalar per node.
binary_arg_sizes = [
    ((5,), (5,)),
    ((5,), ()),
    ((), (5,)),
    ((1, 3, 3), (4, 1, 3)),
    ((3, 3), (4, 1, 3)),
    ((4, 1, 3), (3, 3)),
]

# Shape pairs for the dot-product test: the binary cases without a scalar
# (``()``) operand. Presumably this is because a dot product needs a trailing
# feature dimension to reduce over.
dot_arg_sizes = [sizes for sizes in binary_arg_sizes if () not in sizes]

# Elementwise ops under test. Each name selects both ``dgl.u_<op>_v`` and the
# reference ``torch.<op>``.
ops = "add sub mul div".split()


def pad_shape(x, y, x_size, y_size):
    """Left-pad the feature shapes of two row-batched tensors to a common rank.

    ``x`` and ``y`` each have a leading row dimension followed by features of
    shape ``x_size`` and ``y_size``. Both feature shapes get leading size-1
    axes until they reach the rank of their broadcast shape. Torch then
    broadcasts the feature dimensions against each other while the leading row
    dimension stays aligned.

    Args:
        x: Tensor whose per-row feature shape is ``x_size``.
        y: Tensor whose per-row feature shape is ``y_size``.
        x_size: Feature shape of ``x`` (a sequence of ints, possibly empty).
        y_size: Feature shape of ``y`` (a sequence of ints, possibly empty).

    Returns:
        ``(new_x, new_y)``, reshaped to ``(-1, *padded_x_size)`` and
        ``(-1, *padded_y_size)`` respectively.

    Raises:
        RuntimeError: if ``x_size`` and ``y_size`` are not broadcastable.
    """
    xy_size = torch.broadcast_shapes(x_size, y_size)
    new_x_size = (1,) * (len(xy_size) - len(x_size)) + tuple(x_size)
    new_y_size = (1,) * (len(xy_size) - len(y_size)) + tuple(y_size)
    # ``reshape`` also accepts non-contiguous inputs, where ``view`` would
    # raise. For contiguous inputs it returns the same view.
    return x.reshape(-1, *new_x_size), y.reshape(-1, *new_y_size)


@parametrize_idtype
@pytest.mark.parametrize("op", ops)
@pytest.mark.parametrize("x_size,y_size", binary_arg_sizes)
def test_u_op_v(idtype, op, x_size, y_size):
    """Check ``dgl.u_<op>_v`` and its gradients against the matching torch op.

    ``x`` is gathered at edge sources and ``y`` at edge destinations. Their
    feature shapes are padded with ``pad_shape`` so that torch broadcasts them
    feature-wise, and the reference is ``torch.<op>`` on the gathered tensors.
    """
    ctx = F.ctx()
    g = dgl.rand_graph(30, 100)
    g = g.astype(idtype).to(ctx)
    x = torch.randn((g.num_nodes(),) + x_size, requires_grad=True, device=ctx)
    y = torch.randn((g.num_nodes(),) + y_size, requires_grad=True, device=ctx)

    f_dgl = getattr(dgl, f"u_{op}_v")
    z = f_dgl(g, x, y)
    z.sum().backward()
    # Clone: the reference pass below zeroes and accumulates into the same
    # ``.grad`` tensors in place. Plain references would make both gradient
    # checks compare a tensor with itself.
    x_grad = x.grad.clone()
    y_grad = y.grad.clone()

    x.grad.zero_()
    y.grad.zero_()
    u, v = g.edges()
    f_torch = getattr(torch, op)
    x_u, y_v = pad_shape(x[u.long()], y[v.long()], x_size, y_size)
    z_true = f_torch(x_u, y_v)
    z_true.sum().backward()
    x_grad_true = x.grad
    y_grad_true = y.grad

    assert torch.allclose(z, z_true)
    # The DGL kernel and torch's indexing backward may sum per-edge gradient
    # contributions in a different order, so leave room for round-off.
    assert torch.allclose(x_grad, x_grad_true, atol=1e-4, rtol=1e-4)
    assert torch.allclose(y_grad, y_grad_true, atol=1e-4, rtol=1e-4)


@parametrize_idtype
@pytest.mark.parametrize("x_size,y_size", dot_arg_sizes)
def test_u_dot_v(idtype, x_size, y_size):
    """Check ``dgl.u_dot_v`` and its gradients against a torch reference.

    The reference multiplies the broadcast source/destination features
    elementwise and sums over the last axis. It then keeps that axis as size 1
    (``unsqueeze(-1)``) to match the shape DGL returns.
    """
    ctx = F.ctx()
    g = dgl.rand_graph(30, 100)
    g = g.astype(idtype).to(ctx)
    x = torch.randn((g.num_nodes(),) + x_size, requires_grad=True, device=ctx)
    y = torch.randn((g.num_nodes(),) + y_size, requires_grad=True, device=ctx)

    z = dgl.u_dot_v(g, x, y)
    z.sum().backward()
    # Clone: the reference pass below zeroes and accumulates into the same
    # ``.grad`` tensors in place. Plain references would make both gradient
    # checks compare a tensor with itself.
    x_grad = x.grad.clone()
    y_grad = y.grad.clone()

    x.grad.zero_()
    y.grad.zero_()
    u, v = g.edges()
    x_u, y_v = pad_shape(x[u.long()], y[v.long()], x_size, y_size)
    z_true = (x_u * y_v).sum(-1).unsqueeze(-1)
    z_true.sum().backward()
    x_grad_true = x.grad
    y_grad_true = y.grad

    assert torch.allclose(z, z_true, atol=1e-4, rtol=1e-4)
    # Gradients are sums of products accumulated in possibly different orders
    # by the two implementations, so use the same tolerance as the forward check.
    assert torch.allclose(x_grad, x_grad_true, atol=1e-4, rtol=1e-4)
    assert torch.allclose(y_grad, y_grad_true, atol=1e-4, rtol=1e-4)
