'''
PyTest for torch_canon/pointcloud/align.py
==========================================

'''

import pytest
import torch
import math

from torch_canon.pointcloud.align import (align_pc_t,
                                          align_pc_s3,
                                          cartesian2spherical_xtheta,
                                          cartesian2spherical_ytheta,
                                          cartesian2spherical_ztheta,
                                          xy_planar_alignment,
                                          xz_planar_alignment,
                                          zy_planar_alignment,
                                          z_axis_alignment,
                                          )


# Predefined Objects
#-------------------
# Sample points shared by the parametrized alignment tests: the three axis
# directions, the three face diagonals and the main diagonal of the positive
# octant, followed by the same seven points negated (same order).
unit_cube = [
    # positive octant: axes
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    # positive octant: face diagonals
    [1.0, 1.0, 0.0],
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 1.0],
    # positive octant: main diagonal
    [1.0, 1.0, 1.0],
    # negated axes
    [-1.0, 0.0, 0.0],
    [0.0, -1.0, 0.0],
    [0.0, 0.0, -1.0],
    # negated face diagonals
    [-1.0, -1.0, 0.0],
    [-1.0, 0.0, -1.0],
    [0.0, -1.0, -1.0],
    # negated main diagonal
    [-1.0, -1.0, -1.0],
]


# Unit Tests
#-----------
def unittest_cartesian_to_spherical(direction, tol, *, inexact_tol=1e-7):
    """Check one cartesian -> spherical converter on six unit vectors.

    ``direction`` selects both the converter under test
    (``cartesian2spherical_{x,y,z}theta``) and the cartesian index
    (0, 1 or 2) treated as the polar axis.  Each case writes a unit vector
    into ``torch.zeros(3)``, either on that axis or in the plane of the two
    remaining components. It then asserts that ``fn(*pos_cart)``
    approximately equals the expected ``(r, theta, phi)`` triple.

    Args:
        direction: one of ``'x'``, ``'y'``, ``'z'``.
        tol: relative and absolute tolerance for the on-axis case, whose
            expected result ``(1, 0, 0)`` is exactly representable.
        inexact_tol: relative and absolute tolerance for the other five
            cases. Their inputs (``sqrt(2)/2``) are stored in torch's
            default dtype (float32 unless changed). Their expected angles
            (pi, pi/2, pi/4, ...) are double-precision constants. Tolerances
            below the storage precision, such as ``tol=1e-16``, therefore
            cannot be applied to them. The default of 1e-7 was previously
            hard-coded.
    """
    sqrt2_2 = math.sqrt(2) / 2
    dir2fn = {'x': cartesian2spherical_xtheta,
              'y': cartesian2spherical_ytheta,
              'z': cartesian2spherical_ztheta}
    dir2idx = {'x': 0, 'y': 1, 'z': 2}
    fn = dir2fn[direction]
    axis = dir2idx[direction]
    # The two components orthogonal to the polar axis, in ascending order.
    idx = [i for i in range(3) if i != axis]

    # Build a vector from {component index: value}, convert it and compare
    # the result with `expected` using `tolerance` for both rel and abs.
    def check(values, expected, tolerance):
        pos_cart = torch.zeros(3)
        for i, value in values.items():
            pos_cart[i] = value
        pos_sph = fn(*pos_cart)
        assert pos_sph == pytest.approx(expected, rel=tolerance, abs=tolerance)

    # axial test
    check({axis: 1.0}, (1.0, 0, 0), tol)

    # negative axial test
    check({axis: -1.0}, (1.0, torch.pi, 0), inexact_tol)

    # planar test: orthogonal components (+, +)
    check({idx[0]: sqrt2_2, idx[1]: sqrt2_2},
          (1.0, torch.pi/2, torch.pi/4), inexact_tol)

    # rotated planar test: orthogonal components (+, -)
    check({idx[0]: sqrt2_2, idx[1]: -sqrt2_2},
          (1.0, torch.pi/2, 3*torch.pi/4), inexact_tol)

    # negative planar test: orthogonal components (-, -)
    check({idx[0]: -sqrt2_2, idx[1]: -sqrt2_2},
          (1.0, torch.pi/2, -3*torch.pi/4), inexact_tol)

    # rotated planar test: orthogonal components (-, +)
    check({idx[0]: -sqrt2_2, idx[1]: sqrt2_2},
          (1.0, torch.pi/2, -torch.pi/4), inexact_tol)


def unittest_planar_alignment(direction, positions, tol):
    """Check that a planar alignment zeroes one component and keeps another.

    ``direction`` selects ``xy_planar_alignment``, ``xz_planar_alignment``
    or ``zy_planar_alignment``. The chosen function is called with
    ``positions`` both as a one-row batch and as the reference vector. Every
    returned row must then have:

    * a zero component at the index normal to the target plane
      (z for ``'xy'``, y for ``'xz'``, x for ``'zy'``), and
    * an unchanged component at the index the alignment keeps fixed
      (x for ``'xy'`` and ``'xz'``, z for ``'zy'``).

    Args:
        direction: one of ``'xy'``, ``'xz'``, ``'zy'``.
        positions: a single point as a sequence of three floats
            (e.g. one ``unit_cube`` entry).
        tol: relative and absolute tolerance for both checks.
    """
    dir2fn = {'xy': xy_planar_alignment,
              'xz': xz_planar_alignment,
              'zy': zy_planar_alignment}
    dir2idx = {'xy': 2, 'xz': 1, 'zy': 0}  # component that must vanish
    dir2axs = {'xy': 0, 'xz': 0, 'zy': 2}  # component that must be preserved

    positions = torch.tensor(positions)
    # The second return value is not checked by this helper.
    aligned_positions, _ = dir2fn[direction](positions.unsqueeze(0), positions)

    # Both lookups are loop-invariant; resolve them once.
    idx = dir2idx[direction]
    axs = dir2axs[direction]
    expected_static = positions[axs].item()

    for pos in aligned_positions:
        # test dynamic vector: the out-of-plane component is gone.
        # pytest.approx wraps the *expected* value, as its tolerance is
        # computed relative to that argument.
        assert pos[idx].item() == pytest.approx(0.0, abs=tol, rel=tol)

        # test static vector: the fixed component is untouched
        assert pos[axs].item() == pytest.approx(expected_static, abs=tol, rel=tol)

def unittest_z_axis_alignment(vector, tol):
    """Check that ``z_axis_alignment`` maps ``vector`` onto the +z axis.

    The vector is passed both as a one-row batch and as the reference
    direction. After alignment, the x and y components of the first row
    must be ~0 and its z component must equal the original vector's norm.

    Args:
        vector: a sequence of three floats (e.g. one ``unit_cube`` entry).
        tol: relative and absolute tolerance for every comparison.
    """
    target = torch.tensor(vector, dtype=torch.float32)
    batch = target.unsqueeze(0)
    aligned, _ = z_axis_alignment(batch, target)
    row = aligned[0]

    # In-plane components must vanish, x first, then y.
    for component in (0, 1):
        assert row[component] == pytest.approx(0.0, abs=tol, rel=tol)

    # All of the length ends up on z.
    assert row[2] == pytest.approx(target.norm(), abs=tol, rel=tol)


# Pytests
#--------
@pytest.mark.parametrize('tol', (1e-12, 1e-16))
@pytest.mark.parametrize('direction', ("x", "y", "z"))
def test_cartesian_to_spherical(direction, tol):
    """Run the spherical-conversion checks for every polar axis and tolerance.

    Inside the helper, ``tol`` is only applied to the on-axis case. The
    remaining cases use a fixed 1e-7 tolerance.
    """
    unittest_cartesian_to_spherical(direction=direction, tol=tol)

@pytest.mark.parametrize('tol', (1e-6,))
@pytest.mark.parametrize('pos', unit_cube)
@pytest.mark.parametrize('direction', ("xy", "xz", "zy"))
def test_planar_align(direction, pos, tol):
    """Align every ``unit_cube`` point into each of the xy, xz and zy planes."""
    unittest_planar_alignment(direction=direction, positions=pos, tol=tol)

@pytest.mark.parametrize('tol', (1e-6,))
@pytest.mark.parametrize('pos', unit_cube)
def test_zaxis_align(pos, tol):
    """Align every ``unit_cube`` point onto the z axis."""
    unittest_z_axis_alignment(vector=pos, tol=tol)
