import torch
import unittest
from ray_samplers import *


class RaySamplersTest(unittest.TestCase):
    """Tests for helpers imported from ``ray_samplers``."""

    @staticmethod
    def _rotation_matrix(beta, gamma):
        """Build R = R_y(beta) @ R_x(gamma) (yaw assumed 0) elementwise.

        Args:
            beta: tensor of rotation angles about the y axis, in radians.
            gamma: tensor of rotation angles about the x axis, in radians;
                must have the same shape as ``beta``.

        Returns:
            Tensor of shape ``beta.shape + (3, 3)`` holding

                [ cos B,  sin B sin G, sin B cos G]
                [ 0,      cos G,       -sin G     ]
                [ -sin B, cos B sin G, cos B cos G]

            The inverse of this matrix is its transpose. Every row must come
            from this one matrix: mixing rows of R with rows of another
            rotation (as an earlier version of this test did) does not
            produce an orthonormal matrix.
        """
        cos_b, sin_b = torch.cos(beta), torch.sin(beta)
        cos_g, sin_g = torch.cos(gamma), torch.sin(gamma)
        zeros = torch.zeros_like(beta)
        rows = [
            torch.stack([cos_b, sin_b * sin_g, sin_b * cos_g], -1),
            torch.stack([zeros, cos_g, -sin_g], -1),
            torch.stack([-sin_b, cos_b * sin_g, cos_b * cos_g], -1),
        ]
        # Stack rows along dim -2 so mat[..., i, :] is row i.
        return torch.stack(rows, -2)

    def test_get_beta_gamma(self):
        """get_beta_gamma returns finite angles that yield a valid rotation.

        Includes axis-aligned directions (e.g. +/-z, +y), which are the
        usual edge cases for angle extraction.
        """
        directions = torch.tensor(
            [[0, 0, 1],
             [0, 0, -1],
             [1, 0, 0],
             [0, 1, 0],
             # NOTE(review): these cases were disabled in the original test;
             # the reason is not recorded here.
             # [-1, 0, 0],
             # [0, -1, 0],
             # [1, 1, 0],
             # [1, 0, 1],
             [0, 1, 1]],
            dtype=torch.float32,
        )
        directions = torch.nn.functional.normalize(directions, dim=-1)

        beta, gamma = get_beta_gamma(directions)
        # torch.stack also requires beta and gamma to have matching shapes.
        angles = torch.stack([beta, gamma], -1)
        self.assertTrue(
            torch.isfinite(angles).all(),
            f"non-finite (beta, gamma) for directions:\n{directions}\n{angles}",
        )

        mat = self._rotation_matrix(beta, gamma)
        # A rotation matrix is orthonormal: R @ R^T == I.
        gram = mat @ mat.transpose(-1, -2)
        identity = torch.eye(3, dtype=gram.dtype).expand_as(gram)
        self.assertTrue(
            torch.allclose(gram, identity, atol=1e-5),
            f"rotation matrix is not orthonormal:\n{mat}",
        )
        # TODO(review): also pin how (beta, gamma) relate to `directions`
        # once get_beta_gamma's convention is confirmed, e.g. whether
        # mat[..., :, 2] (R applied to +z) should equal `directions`.


if __name__ == "__main__":
    # Executed as a script: let unittest discover and run this module's tests.
    unittest.main()

