import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()
def _fetch_pixel_val(x: torch.Tensor, vertex_slice):
    '''
    Index the two leading dimensions of ``x`` with a (row, column) index pair.

    :param x: tensor with at least two dimensions, e.g. shape (H, W, ...)
    :param vertex_slice: indexable pair; element 0 indexes dim 0 and
        element 1 indexes dim 1 (the entries of TRIANGLE_SLICES are slices)
    :return: ``x[vertex_slice[0], vertex_slice[1]]``; the output shape is
        determined by the given indices
    '''
    row_index = vertex_slice[0]
    col_index = vertex_slice[1]
    return x[row_index, col_index]


@torch.no_grad()
def get_triangle_valid(valid: torch.Tensor):
    '''
    Per-triangle validity mask: a triangle is valid only when all three of
    its vertices are valid.

    :param valid: shape (H, W), per-pixel validity mask
    :return: triangle_valid
        triangle_valid: bool tensor of shape (H - 2, W - 2, NUM_TRIANGLE),
        on the same device as ``valid``
    '''
    rows, cols = valid.shape
    tri_valid = torch.empty(
        (rows - 2, cols - 2, NUM_TRIANGLE),
        dtype=torch.bool,
        device=valid.device,
    )
    for tri_idx, (v0, v1, v2) in enumerate(TRIANGLE_SLICES):
        # AND the validity of the three vertices of this triangle.
        all_ok = _fetch_pixel_val(valid, v0) & _fetch_pixel_val(valid, v1)
        all_ok = all_ok & _fetch_pixel_val(valid, v2)
        tri_valid[..., tri_idx] = all_ok
    return tri_valid

# One entry per triangle; each entry holds three (row_slice, col_slice) pairs,
# one per vertex. Applied to an (H, W, ...) grid, every pair selects an
# (H - 2, W - 2, ...) window, so the vertices are 2 pixels apart.
TRIANGLE_SLICES = (
    (
        (slice(None, -2), slice(None, -2)),  # vertex 0: un-shifted window
        (slice(2, None), slice(None, -2)),   # vertex 1: shifted 2 along dim 0
        (slice(None, -2), slice(2, None)),   # vertex 2: shifted 2 along dim 1
    ),
)
# Number of triangles described by TRIANGLE_SLICES.
NUM_TRIANGLE = len(TRIANGLE_SLICES)
@torch.no_grad()
def get_triangle_normal(xyz: torch.Tensor):
    '''
    Compute a unit normal for every triangle described by TRIANGLE_SLICES
    (vertices taken at a 2-pixel spacing).

    For each triangle the two edges leaving vertex 0 are normalized, crossed,
    and the cross product is re-normalized. A triangle is flagged invalid
    when the cross product is (near) zero, i.e. the edges are degenerate or
    (almost) parallel.

    :param xyz: shape (H, W, 3) -- presumably per-pixel 3D coordinates
    :return: normal, normal_valid
        normal: shape (H - 2, W - 2, len(TRIANGLE_SLICES), 3), same dtype
            and device as ``xyz``
        normal_valid: bool, shape (H - 2, W - 2, len(TRIANGLE_SLICES)),
            True where the cross-product norm exceeds 1e-5
    '''
    H, W = xyz.shape[:2]
    device = xyz.device
    dtype = xyz.dtype
    # Size the triangle axis from the slice table rather than a literal 1, so
    # the outputs stay correct if more triangles are added.
    num_triangle = len(TRIANGLE_SLICES)
    normal = torch.empty((H - 2, W - 2, num_triangle, 3), dtype=dtype, device=device)
    normal_valid = torch.empty((H - 2, W - 2, num_triangle), dtype=torch.bool, device=device)
    for i, TRIANGLE_SLICE in enumerate(TRIANGLE_SLICES):
        origin = _fetch_pixel_val(xyz, TRIANGLE_SLICE[0])
        edge_a = F.normalize(_fetch_pixel_val(xyz, TRIANGLE_SLICE[1]) - origin, dim=-1)
        edge_b = F.normalize(_fetch_pixel_val(xyz, TRIANGLE_SLICE[2]) - origin, dim=-1)
        normal[..., i, :] = torch.linalg.cross(edge_a, edge_b, dim=-1)
        vec_norm = torch.norm(normal[..., i, :], dim=-1)
        normal_valid[..., i] = vec_norm > 1e-5
        # Clamp the divisor so degenerate triangles do not divide by zero;
        # those entries are already marked invalid above.
        normal[..., i, :] /= vec_norm.clamp(min=1e-5).unsqueeze(-1)
    return normal, normal_valid

@torch.no_grad()
def get_triangle_normal_and_valid(xyz: torch.Tensor, valid: torch.Tensor, flatten: bool = True):
    '''
    Compute triangle normals together with a combined validity mask.

    A triangle is reported valid only when get_triangle_normal marks its
    normal as valid and get_triangle_valid marks all its vertices as valid.

    :param xyz: passed to get_triangle_normal (documented there as shape (H, W, 3))
    :param valid: passed to get_triangle_valid (documented there as shape (H, W))
    :param flatten: if True, reshape normal to (-1, 3) and valid to (-1,);
        otherwise keep the per-pixel / per-triangle layout
    :return: normal, valid
    '''
    tri_normal, nonzero_mask = get_triangle_normal(xyz)
    combined_mask = nonzero_mask & get_triangle_valid(valid)
    if not flatten:
        return tri_normal, combined_mask
    return tri_normal.reshape(-1, 3), combined_mask.reshape(-1)
