import numpy as np
import torch
import torch.nn as nn

# from torch_scatter import scatter

# z-component of every per-view translation; Realistic_Projection.point_transform
# subtracts the translation from the rotated points.
TRANS = -1.5

# Realistic projection parameters.
params = dict(
    # Max-pool window / padding (z, then x-y) used by Grid2Image to densify the grid.
    maxpool_z=1,
    maxpool_xy=7,
    maxpool_pad_z=0,
    maxpool_pad_xy=3,
    # Gaussian smoothing convolution (size, sigma, padding) used by Grid2Image.
    conv_z=1,
    conv_xy=3,
    conv_sigma_xy=3,
    conv_sigma_z=1,
    conv_pad_z=0,
    conv_pad_xy=1,
    # Quantisation / rendering options.
    img_bias=0.,
    depth_bias=0.2,  # offset added to normalised z before it is mapped to depth bins
    obj_ratio=0.8,   # scale applied to normalised x and y
    bg_clr=0.0,      # value stored in voxels that receive no point
    # Grid size: height/width and number of depth bins.
    resolution=112,
    depth=8,
)


class Grid2Image(nn.Module):
    """
    A pytorch implementation to turn 3D grid to 2D image.
    Maxpool: densifying the grid
    Convolution: smoothing via Gaussian
    Maximize: squeezing the depth channel
    """

    def __init__(self):
        super().__init__()
        # NOTE: this sets a process-wide cuDNN flag as a side effect of construction.
        torch.backends.cudnn.benchmark = False

        self.maxpool = nn.MaxPool3d(kernel_size=(params['maxpool_z'], params['maxpool_xy'], params['maxpool_xy']),
                                    stride=1,
                                    padding=(params['maxpool_pad_z'], params['maxpool_pad_xy'], params['maxpool_pad_xy']))
        self.conv = torch.nn.Conv3d(in_channels=1,
                                    out_channels=1,
                                    kernel_size=(params['conv_z'], params['conv_xy'], params['conv_xy']),
                                    stride=1,
                                    padding=(params['conv_pad_z'], params['conv_pad_xy'], params['conv_pad_xy']),
                                    bias=True)
        kn3d = get3DGaussianKernel(params['conv_xy'], params['conv_z'], params['conv_sigma_xy'], params['conv_sigma_z'])

        # Initialize the 3D convolution with the (depth, k, k) Gaussian kernel, reshaped to the
        # (out=1, in=1, depth, k, k) weight layout. Copying under no_grad (instead of assigning
        # `.data`) keeps the existing Parameter and fails loudly on a shape mismatch.
        with torch.no_grad():
            kernel = torch.as_tensor(kn3d, dtype=self.conv.weight.dtype, device=self.conv.weight.device)
            self.conv.weight.copy_(kernel.reshape(self.conv.weight.shape))
            self.conv.bias.zero_()

    def forward(self, x):
        """
        :param x: 3D grid with size of [B * num_views, depth, resolution, resolution]
        :return: image with size of [B * num_views, 3, resolution, resolution]. Each view is
                 divided by its own maximum and inverted (1 - value), so the strongest response
                 maps to 0 and pixels whose value is 0 map to 1. The single channel is repeated 3x.
        """
        # Densify.
        x = self.maxpool(x.unsqueeze(1))  # (B*nv, 1, d, h, w)
        # Smooth.
        x = self.conv(x)  # (B*nv, 1, d, h, w)
        # Squeeze the depth axis.
        img = torch.max(x, dim=2)[0]  # (B*nv, 1, h, w)

        # Largest pixel value per image, used to normalize it.
        # (B*nv, 1, h, w) -> (B*nv, 1, h) -> (B*nv, 1) -> (B*nv, 1, 1, 1)
        peak = torch.max(torch.max(img, dim=-1)[0], dim=-1)[0][:, :, None, None]
        # A blank view has peak 0; dividing by it would give 0/0 = NaN, so divide by 1 instead.
        peak = torch.where(peak == 0, torch.ones_like(peak), peak)
        img = img / peak  # (B*n, 1, h, w)
        img = 1 - img
        img = img.repeat(1, 3, 1, 1)
        return img


def euler2mat(angle):
    """
    Convert euler angles to rotation matrix.

    The matrix is composed as ``R_x(x) @ R_y(y) @ R_z(z)``, where ``x, y, z`` are the first
    three entries along the last axis of ``angle`` (radians, as consumed by torch.cos/sin).

    :param angle: [3] or [b, 3] tensor (more generally any [..., 3])
    :return rot_mat: [3, 3] or [b, 3, 3] (more generally [..., 3, 3])
    :raises ValueError: if ``angle`` is a 0-dim tensor
    source
    https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py
    """
    if angle.dim() == 0:
        raise ValueError('euler2mat expects angles of shape [3] or [..., 3], got a 0-dim tensor')
    x, y, z = angle[..., 0], angle[..., 1], angle[..., 2]
    mat_shape = tuple(angle.shape[:-1]) + (3, 3)

    # Constant entries share the angles' dtype/device and do not track gradients.
    zero = torch.zeros_like(z)
    one = torch.ones_like(z)

    def _as_matrix(entries):
        # Stack nine row-major entries on a new trailing axis and fold them into 3x3.
        return torch.stack(entries, dim=-1).reshape(mat_shape)

    cos_z, sin_z = torch.cos(z), torch.sin(z)
    z_mat = _as_matrix([cos_z, -sin_z, zero,
                        sin_z, cos_z, zero,
                        zero, zero, one])

    cos_y, sin_y = torch.cos(y), torch.sin(y)
    y_mat = _as_matrix([cos_y, zero, sin_y,
                        zero, one, zero,
                        -sin_y, zero, cos_y])

    cos_x, sin_x = torch.cos(x), torch.sin(x)
    x_mat = _as_matrix([one, zero, zero,
                        zero, cos_x, -sin_x,
                        zero, sin_x, cos_x])

    return x_mat @ y_mat @ z_mat


def points2grid(points, resolution=params['resolution'], depth=params['depth']):
    """
    Quantize each point cloud to a 3D grid.

    Each cloud is centred on its bounding-box centre and divided by its largest bounding-box
    extent so coordinates lie in [-1, 1]; x and y are then scaled by ``params['obj_ratio']``.
    Every point writes its continuous depth value into the voxel (z, y, x) it falls in;
    voxels that receive no point hold ``params['bg_clr']``.

    :param points: tensor with size of [B * num_views, points_num, 3]
    :param resolution: the height and width of 3D grid
    :param depth: the depth of 3D grid
    :return grid: tensor with size of [B * num_views, depth, resolution, resolution], with the
                  last two axes permuted to (x, y) order; its dtype follows ``points``
    """
    batch, p_num, _ = points.shape

    p_max, p_min = points.max(dim=1)[0], points.min(dim=1)[0]  # [B * num_views, 3]
    p_cent = ((p_max + p_min) / 2)[:, None, :]  # [B * num_views, 1, 3]
    p_range = (p_max - p_min).max(dim=-1)[0][:, None, None]  # [B * num_views, 1, 1]
    # A degenerate cloud (all points identical) has zero extent: divide by 1 instead of 0 so
    # it lands at the grid centre rather than producing NaN coordinates.
    p_range = torch.where(p_range == 0, torch.ones_like(p_range), p_range)
    points = (points - p_cent) / p_range * 2.  # [-1, 1]; a new tensor, the caller's is untouched
    points[:, :, :2] = points[:, :, :2] * params['obj_ratio']  # x, y: [-obj_ratio, obj_ratio]

    depth_bias = params['depth_bias']
    # x, y: [(1 - obj_ratio) / 2, (1 + obj_ratio) / 2] * resolution
    _x = (points[:, :, 0] + 1) / 2 * resolution
    _y = (points[:, :, 1] + 1) / 2 * resolution
    # z: [depth_bias, 1 + depth_bias] / (1 + depth_bias) * (depth - 2)
    _z = ((points[:, :, 2] + 1) / 2 + depth_bias) / (1 + depth_bias) * (depth - 2)

    # Integer voxel coordinates, each [B * num_views, p_num]. x/y stay one voxel away from the
    # image border. z is not clipped: the normalisation above keeps ceil(_z) within
    # [0, depth - 1] (for depth >= 2 and depth_bias >= 0, up to float rounding).
    x_idx = torch.clip(_x.ceil(), 1, resolution - 2).long()
    y_idx = torch.clip(_y.ceil(), 1, resolution - 2).long()
    z_idx = _z.ceil().long()
    b_idx = torch.arange(batch, device=points.device)[:, None].expand(batch, p_num)

    # Depth value stored in each occupied voxel.
    z_val = torch.clip(_z, 1, depth - 2)

    grid = torch.full((batch, depth, resolution, resolution), params['bg_clr'],
                      dtype=z_val.dtype, device=points.device)
    # NOTE(review): index_put without accumulate is documented as undefined when indices repeat,
    # i.e. when several points share a voxel. Such points' depth values differ by less than one
    # bin, so only the exact stored value is affected; on torch >= 1.12 a flattened
    # scatter_reduce(reduce='amax') would make this deterministic.
    grid = grid.index_put((b_idx, z_idx, y_idx, x_idx), z_val).permute((0, 1, 3, 2))
    return grid


class Realistic_Projection:
    """
    For creating images from PC based on the view information.

    Ten fixed views are defined: eight that vary the first Euler angle (multiples of pi/4 and
    pi/2) and two that set the second Euler angle to -pi/2 and +pi/2. Each view applies its view
    rotation, then a small "bias" rotation, then subtracts a translation.
    """

    def __init__(self, device='cpu'):
        # Per view: [[euler angles (radians)], [translation]].
        _views = np.asarray([[[1 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
                             [[3 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
                             [[5 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
                             [[7 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
                             [[0 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
                             [[1 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
                             [[2 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
                             [[3 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
                             [[0, -np.pi / 2,    np.pi / 2], [-0.5, -0.5, TRANS]],
                             [[0,  np.pi / 2,    np.pi / 2], [-0.5, -0.5, TRANS]]])

        # Adding some bias to the view angle to reveal more surface.
        # NOTE: only the angles of _views_bias are used; its translation row is ignored.
        _views_bias = np.asarray([[[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
                                  [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
                                  [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
                                  [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
                                  [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
                                  [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
                                  [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
                                  [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
                                  [[0, np.pi /15, 0], [-0.5, 0, TRANS]],
                                  [[0, np.pi /15, 0], [-0.5, 0, TRANS]]])

        self.num_views = _views.shape[0]
        self.device = torch.device(device)

        # The matrices are transposed so point_transform can apply them to row-vector points
        # with `points @ rot_mat` (equivalent to R @ p for column vectors).
        angle1 = torch.tensor(_views[:, 0, :], requires_grad=False).float().to(self.device)
        self.rot_mat1 = euler2mat(angle1).transpose(1, 2)
        angle2 = torch.tensor(_views_bias[:, 0, :], requires_grad=False).float().to(self.device)
        self.rot_mat2 = euler2mat(angle2).transpose(1, 2)

        self.translation = torch.tensor(_views[:, 1, :], requires_grad=False).float().to(self.device)
        self.translation = self.translation.unsqueeze(1)  # [num_views, 1, 3]

        self.grid2image = Grid2Image().to(self.device)

    def get_img(self, points):
        """
        Get depth map from point cloud.
        :param: points (torch.tensor): of size [B, num_points, 3]
        :return img (torch.tensor): of size [B * self.num_views, 3, resolution, resolution];
                the views of one cloud are contiguous (all views of cloud 0 first).
        """
        b, _, _ = points.shape
        v = self.translation.shape[0]  # v = self.num_views

        # repeat_interleave groups each cloud's copies together (p0 x v, p1 x v, ...), while
        # repeat tiles the per-view transforms (view 0..v-1, b times) so the two line up.
        _points = self.point_transform(
            points=torch.repeat_interleave(points, v, dim=0),
            rot_mat1=self.rot_mat1.repeat(b, 1, 1),
            rot_mat2=self.rot_mat2.repeat(b, 1, 1),
            translation=self.translation.repeat(b, 1, 1))

        # No .squeeze(): points2grid already returns [B * v, depth, res, res]; squeezing would
        # drop a real axis whenever one of those sizes is 1 and break grid2image.
        grid = points2grid(points=_points, resolution=params['resolution'], depth=params['depth'])
        img = self.grid2image(grid)
        return img

    @staticmethod
    def point_transform(points, rot_mat1, rot_mat2, translation):
        """
        Rotate row-vector points by rot_mat1 then rot_mat2, then subtract the translation.

        :param points: [batch * num_views, num_points, 3]
        :param rot_mat1: [batch * num_views, 3, 3]
        :param rot_mat2: [batch * num_views, 3, 3]
        :param translation: [batch * num_views, 1, 3]
        :return: points: [batch * num_views, num_points, 3]
        """
        # Follow the points' device so callers need not pre-move the transforms.
        rot_mat1 = rot_mat1.to(points.device)
        rot_mat2 = rot_mat2.to(points.device)
        translation = translation.to(points.device)
        points = torch.matmul(points, rot_mat1)
        points = torch.matmul(points, rot_mat2)
        points = points - translation
        return points


def get2DGaussianKernel(ksize, sigma=0):
    """
    Build a normalized, square 2D Gaussian kernel.

    :param ksize: side length of the kernel; the centre is at index ksize // 2.
    :param sigma: standard deviation in pixels. A non-positive value (including the default 0)
                  selects sigma from ksize with OpenCV's getGaussianKernel rule,
                  0.3 * ((ksize - 1) * 0.5 - 1) + 0.8.
    :return: float32 torch tensor of shape [ksize, ksize] whose entries sum to 1.
    """
    if sigma <= 0:
        # With sigma == 0 the exponent below is 0/0 at the centre, yielding a NaN kernel.
        sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8
    center = ksize // 2
    xs = np.arange(ksize, dtype=np.float32) - center
    kernel1d = np.exp(-(xs ** 2) / (2 * sigma ** 2))
    # The 2D Gaussian is separable: outer product of the 1D profile with itself.
    kernel = torch.from_numpy(np.outer(kernel1d, kernel1d))
    return kernel / kernel.sum()


def get3DGaussianKernel(ksize, depth, sigma=2, z_sigma=2):
    """
    Build a normalized 3D Gaussian kernel as a 2D (x, y) Gaussian times a 1D z profile.

    :param ksize: height and width of the kernel.
    :param depth: number of kernel slices along z; the z centre is at index depth // 2.
    :param sigma: x-y standard deviation, forwarded to get2DGaussianKernel.
    :param z_sigma: z standard deviation. A non-positive value selects one from depth with the
                    OpenCV-style rule 0.3 * ((depth - 1) * 0.5 - 1) + 0.8.
    :return: torch tensor of shape [depth, ksize, ksize] whose entries sum to 1; its dtype
             follows the tensor returned by get2DGaussianKernel.
    """
    kernel2d = get2DGaussianKernel(ksize, sigma)
    if z_sigma <= 0:
        # z_sigma == 0 would make the exponent 0/0 at the centre slice (NaN kernel).
        z_sigma = 0.3 * ((depth - 1) * 0.5 - 1) + 0.8
    zs = torch.arange(depth, dtype=kernel2d.dtype) - depth // 2  # e.g. depth=1 -> [0.]
    z_kernel = torch.exp(-(zs ** 2) / (2 * z_sigma ** 2))        # e.g. depth=1 -> [1.]
    # Broadcast (depth, 1, 1) * (1, k, k) -> (depth, k, k) entirely in torch, instead of calling
    # np.repeat on a torch tensor (which only works through NumPy's fallback conversion).
    kernel3d = z_kernel[:, None, None] * kernel2d[None, :, :]
    return kernel3d / kernel3d.sum()

