from unittest import TestCase

import einops
from sympy import Point
import torch

from GINN.batch_helper import tensor_product_xz


class Test_PointWrapper(TestCase):
    """Unit tests for ``tensor_product_xz``.

    NOTE(review): despite the class name, the only test here exercises
    ``GINN.batch_helper.tensor_product_xz``; the name is kept unchanged so
    existing test IDs / selections keep working.
    """

    def test_merge(self):
        """Check that every row of ``z`` is paired with every row of ``x``.

        With two rows in each input, the expected result (per the literals
        below) has four rows: ``x`` is tiled as a whole block once per row of
        ``z``, while each row of ``z`` is repeated once per row of ``x``.
        """
        x = torch.tensor([
            [1., 2.],
            [3., 4.],
        ])

        z = torch.tensor([
            [7., 8.],
            [9., 10.],
        ])

        # x cycles fastest: the full x block appears once per z row.
        x_res = torch.tensor([
            [1., 2.],
            [3., 4.],
            [1., 2.],
            [3., 4.],
        ])

        # z cycles slowest: each z row is repeated len(x) times in a row.
        z_res = torch.tensor([
            [7., 8.],
            [7., 8.],
            [9., 10.],
            [9., 10.],
        ])

        x2, z2 = tensor_product_xz(x, z)

        # Use torch's assertion helper instead of a bare ``assert``: bare
        # asserts are stripped under ``python -O`` (the test would silently
        # pass) and give no diagnostic on failure. rtol=0/atol=0 keeps the
        # comparison exact, like the original ``torch.equal`` check.
        torch.testing.assert_close(x2, x_res, rtol=0, atol=0)
        torch.testing.assert_close(z2, z_res, rtol=0, atol=0)