from unittest import TestCase, skip

import einops
import torch

from models.point_wrapper import PointWrapperVec as PointWrapper


class TestPointWrapperVec(TestCase):
    """Tests for ``PointWrapperVec`` (imported here as ``PointWrapper``).

    Wrappers are built either from a dense tensor via ``create_from_equal_bx``
    or from a list of per-shape point tensors via
    ``create_from_pts_per_shape_list``; the tests then compare the per-shape
    views (``pts_of_shape``) and derived tensors (``z_in``) against values
    computed directly from the inputs.
    """

    def setUp(self):
        # Fixed seed so a failure involving random points is reproducible.
        torch.manual_seed(0)

    def _assert_tensor_equal(self, actual, expected, msg=None):
        """Fail the test unless ``torch.equal(actual, expected)`` holds.

        A TestCase failure is used instead of a bare ``assert`` so the check
        is not stripped when Python runs with ``-O``; both tensors are
        included in the failure message for debugging.
        """
        if not torch.equal(actual, expected):
            detail = f'expected {expected}, got {actual}'
            self.fail(f'{msg}: {detail}' if msg else detail)

    def test_merge(self):
        """Check that merge concatenates each shape's points (p1 then p2)."""
        bz = 4
        x1 = torch.rand(bz, 64, 2)
        x2 = torch.rand(bz, 32, 2)

        p1 = PointWrapper.create_from_equal_bx(x1)
        p2 = PointWrapper.create_from_equal_bx(x2)
        p3 = PointWrapper.merge(p1, p2)

        for i in range(bz):
            with self.subTest(shape_index=i):
                self._assert_tensor_equal(
                    p3.pts_of_shape(i), torch.cat([x1[i], x2[i]]))

    def test_merge_w_masks(self):
        """Check that masking a merged wrapper matches masking each input.

        The same point-wise predicate (coord 0 < coord 1) is evaluated on the
        raw inputs and on the merged data, so per shape the selected points
        must equal the concatenation of the individually selected points.
        """
        bz = 4
        x1 = torch.rand(bz, 3, 2)
        x2 = torch.rand(bz, 5, 2)

        m1 = x1[:, :, 0] < x1[:, :, 1]
        m2 = x2[:, :, 0] < x2[:, :, 1]

        p1 = PointWrapper.create_from_equal_bx(x1)
        p2 = PointWrapper.create_from_equal_bx(x2)
        p3 = PointWrapper.merge(p1, p2)

        # NOTE(review): p3.data is indexed with two axes here; presumably
        # (all points of all shapes, coord) — confirm against PointWrapperVec.
        m3 = p3.data[:, 0] < p3.data[:, 1]
        p3_m = p3.select_w_mask(incl_mask=m3)

        for i in range(bz):
            with self.subTest(shape_index=i):
                target = torch.cat([x1[i][m1[i]], x2[i][m2[i]]])
                self._assert_tensor_equal(p3_m.pts_of_shape(i), target)

    def test_get_map_from_same_bx(self):
        """Check that a dense (bz, bx, nx) tensor round-trips per shape."""
        bz = 4
        bx = 3
        # 24 = bz * bx * nx, so einops infers nx = 2. Distinct, ordered values
        # make any index mix-up in the wrapper's map visible.
        ordered = torch.arange(24)
        reshaped = einops.rearrange(
            ordered, '(bz bx nx) -> bz bx nx', bz=bz, bx=bx)

        p = PointWrapper.create_from_equal_bx(reshaped)
        for i in range(bz):
            with self.subTest(shape_index=i):
                self._assert_tensor_equal(
                    p.pts_of_shape(i),
                    reshaped[i],
                    msg=f'indices: {p._map[i]}',
                )

    @skip('disabled (was commented out); re-enable once net_in is confirmed')
    def test_net_in(self):
        """Check that net_in yields one (point ++ latent) row per point."""
        bz = 4
        bx = 64
        nx = 2
        nz = 16
        x = torch.rand(bz, bx, nx)
        z = torch.rand(bz, nz)

        x_pw = PointWrapper.create_from_equal_bx(x)
        net_in = x_pw.net_in(z)
        self.assertEqual(net_in.dim(), 2)
        self.assertEqual(net_in.shape[0], bz * bx)
        self.assertEqual(net_in.shape[1], nx + nz)

    def test_z_in(self):
        """Check that z_in repeats each shape's latent row once per point.

        Shape 0 has one point and shape 1 has two, so the expected output is
        z[0] once followed by z[1] twice.
        """
        shape_list = [
            torch.tensor([[1., 2.]]),
            torch.tensor([
                [3., 4.],
                [5., 6.],
            ]),
        ]

        z = torch.tensor([
            [7., 8.],
            [9., 10.],
        ])

        expected = torch.tensor([
            [7., 8.],
            [9., 10.],
            [9., 10.],
        ])

        p = PointWrapper.create_from_pts_per_shape_list(shape_list)
        z_in = p.z_in(z)
        self.assertEqual(z_in.dim(), 2)
        self._assert_tensor_equal(z_in, expected)