from unittest import TestCase

import torch

from models.point_wrapper import PointWrapper
from GINN.morse.cp_speedups import update_proximity_mask


class TestProximityStopping(TestCase):
    def test_matching_simple(self):
        
        shape_list = [
                torch.tensor([[1., 2.]]),
                torch.tensor([
                    [7., 8.],
                    [10., 20.],
                ]),
            ]
        p = PointWrapper.create_from_pts_per_shape_list(shape_list)  # [(1 + 2) 2]
        
        eps = 1.1
        x_path_over_iters = torch.tensor([
                [
                    [5., 6.],
                    [7., 8.],
                    [10., 20.],
                ],
                p.data,
            ])
        print('x_path_over_iters.shape', x_path_over_iters.shape)
        squared_norms = torch.tensor([3., 2., 1.])

        new_mask = update_proximity_mask(p, x_path_over_iters, squared_norms, eps=eps)
        print('new_mask', new_mask)
        target_mask = torch.tensor([True, True, True])
        
        assert torch.equal(new_mask, target_mask)
        
    def test_matching_with_t_equal_3(self):
        
        shape_list = [
                torch.tensor([[1., 2.]]),
                torch.tensor([
                    [3., 4.],
                    [3., 4.],
                    [3., 4.],
                ]),
            ]
        p = PointWrapper.create_from_pts_per_shape_list(shape_list)  # [(1 + 2) 2]
        
        eps = 1.1
        x_path_over_iters = torch.tensor([
                # t=0
                [
                    [5., 6.],
                    [5., 6.],
                    [7., 8.],
                    [9., 10.],
                ],
                # t=1
                p.data,
            ])
        print('x_path_over_iters.shape', x_path_over_iters.shape)
        squared_norms = torch.tensor([3., 2., 1., 0.5])

        new_mask = update_proximity_mask(p, x_path_over_iters, squared_norms, eps=eps)
        print('new_mask', new_mask)
        target_mask = torch.tensor([True, False, False, True])
        
        assert torch.equal(new_mask, target_mask)
        
    def test_matching_with_4_points(self):

        # current points        
        shape_list = [
                torch.tensor([[1., 2.]]),
                torch.tensor([
                    [5., 6.],
                    [7., 8.],
                    [9., 10.],
                ]),
            ]
        p = PointWrapper.create_from_pts_per_shape_list(shape_list)  # [(1 + 2) 2]
        
        eps = 3
        x_path_over_iters = torch.tensor([
                # t=0
                [
                    [5., 6.],
                    [5., 6.],
                    [7., 8.],
                    [9., 10.],
                ],
                # t=1
                [
                    [5., 6.],
                    [5., 6.],
                    [7., 8.],
                    [9., 10.],
                ],
                # t=2
                [
                    [5., 6.],
                    [5., 6.],
                    [7., 8.],
                    [9., 10.],
                ],
                p.data,
            ])
        print('x_path_over_iters.shape', x_path_over_iters.shape)
        squared_norms = torch.tensor([3., 2., 1., 0.5])

        new_mask = update_proximity_mask(p, x_path_over_iters, squared_norms, eps=eps)
        print('new_mask', new_mask)
        target_mask = torch.tensor([True, False, False, True])
        print('target_mask', target_mask)
        
        assert torch.equal(new_mask, target_mask)
        
    def test_matching_with_inf(self):

        # current points        
        shape_list = [
                torch.tensor([[1., 2.]]),
                torch.tensor([
                    [5., 6.],
                    [7., 8.],
                    [9., 10.],
                ]),
            ]
        p = PointWrapper.create_from_pts_per_shape_list(shape_list)  # [(1 + 2) 2]
        
        eps = 3
        x_path_over_iters = torch.tensor([
                # t=0
                [
                    [5., 6.],
                    [5., 6.],
                    [7., 8.],
                    [9., 10.],
                ],
                # t=1
                [
                    [5., 6.],
                    [5., 6.],
                    [7., 8.],
                    [9., 10.],
                ],
                # t=2
                [
                    [5., 6.],
                    [5., 6.],
                    [7., 8.],
                    [9., 10.],
                ],
                p.data,
            ])
        print('x_path_over_iters.shape', x_path_over_iters.shape)
        squared_norms = torch.tensor([3., 2., float('inf'), 0.5])

        new_mask = update_proximity_mask(p, x_path_over_iters, squared_norms, eps=eps)
        print('new_mask', new_mask)
        target_mask = torch.tensor([True, True, False, True])
        print('target_mask', target_mask)
        
        assert torch.equal(new_mask, target_mask)
        