# -*- coding: utf-8 -*-
import math
import numpy as np
import torch
from torch import nn

from utils.graphics_utils import fov2focal


class ReferencePointProjection(nn.Module):
    """Turn per-query network outputs into 3D points and project them to UV.

    The module keeps a fixed grid of ray directions (one ray per query,
    laid out on a ``num_lines x num_lines`` grid), converts predicted depth
    and offsets into 3D positions along those rays, applies the inverse of
    the supplied camera matrices and finally projects the points with a
    pinhole intrinsic matrix.
    """

    # Side length, in pixels, of the virtual image used both to build the
    # ray grid and to normalise projected pixel coordinates.
    IMAGE_RESOLUTION = 128

    def __init__(self, cfg):
        """Store ``cfg`` and precompute the ray-direction buffer.

        Args:
            cfg: configuration object. This class reads
                ``cfg.model.dino_decoder.num_queries``,
                ``cfg.model.inverted_x``, ``cfg.model.inverted_y``,
                ``cfg.data.fov``, ``cfg.data.znear``, ``cfg.data.zfar`` and
                (when no intrinsics are passed) ``cfg.data.intrinsics``.
        """
        super().__init__()
        self.cfg = cfg
        self.init_ray_dirs()
        self.depth_act = nn.Sigmoid()

    def init_ray_dirs(self):
        """Build the ``ray_dirs`` buffer of shape (1, 3, num_lines, num_lines).

        ``num_lines`` is ``int(sqrt(num_queries))``. The x / y channels hold
        pixel-centre coordinates spanning the virtual image, divided by the
        focal length derived from ``cfg.data.fov``; the z channel is all ones.
        """
        resolution = self.IMAGE_RESOLUTION
        self.num_lines = int(math.sqrt(self.cfg.model.dino_decoder.num_queries))
        x = torch.linspace(-resolution // 2 + 0.5,
                           resolution // 2 - 0.5,
                           self.num_lines)
        # y runs in the opposite direction to x (positive to negative).
        y = torch.linspace(resolution // 2 - 0.5,
                           -resolution // 2 + 0.5,
                           self.num_lines)
        if self.cfg.model.inverted_x:
            x = -x
        if self.cfg.model.inverted_y:
            y = -y
        grid_x, grid_y = torch.meshgrid(x, y, indexing='xy')
        ones = torch.ones_like(grid_x, dtype=grid_x.dtype)
        ray_dirs = torch.stack([grid_x, grid_y, ones]).unsqueeze(0)

        # for cars and chairs the focal length is fixed across dataset
        # so we can preprocess it
        # for co3d this is done on the fly
        ray_dirs[:, :2, ...] /= fov2focal(self.cfg.data.fov * np.pi / 180,
                                          resolution)
        self.register_buffer('ray_dirs', ray_dirs)  # (1, 3, num_lines, num_lines)

    def get_pos_from_network_output(self, depth_network, offset, const_offset=None):
        """Convert raw depth and offset predictions into 3D positions.

        Args:
            depth_network: raw (pre-activation) depth; its first dimension is
                used as the batch size and it must broadcast against
                ``ray_dirs``.
            offset: additive offset applied to the points along the rays.
            const_offset: optional extra depth term; only its leading
                ``num_lines x num_lines`` spatial window is used.

        Returns:
            ``ray_dirs * depth + offset`` where ``depth`` is the sigmoid of
            ``depth_network`` rescaled to ``[znear, zfar]`` (plus
            ``const_offset`` when given).
        """
        # Expand the ray directions along the batch dimension:
        # (B, 3, num_lines, num_lines).
        ray_dirs_xy = self.ray_dirs.expand(depth_network.shape[0], 3, *self.ray_dirs.shape[2:])

        znear = self.cfg.data.znear
        zfar = self.cfg.data.zfar
        depth = self.depth_act(depth_network) * (zfar - znear) + znear
        if const_offset is not None:
            depth = depth + const_offset[:, :, :self.num_lines, :self.num_lines]
        return ray_dirs_xy * depth + offset

    def flatten_vector(self, x):
        """Flatten the image dimensions into a point list.

        B x C x H x W -> B x C x N -> B x N x C
        """
        return x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)

    def project_3d_gaussian_to_uv(self, gaussian, intrinsics=None, const_offset=None, source_cameras_view_to_world=None):
        """Compute 3D reference points and their 2D projections.

        Args:
            gaussian: network output of shape (bs, C, sqrt(nq), sqrt(nq)).
                When ``C == 24``, channel 0 is treated as raw depth and
                channels 1:4 as offsets; otherwise channels 0:3 are used
                directly as positions.
            intrinsics: optional per-sample intrinsics forwarded to
                :meth:`xyz2uv`; ``None`` selects ``cfg.data.intrinsics``.
            const_offset: optional constant depth offset, only used when
                ``C == 24``.
            source_cameras_view_to_world: camera matrices of shape
                (bs, N_views, 4, 4). Required despite the ``None`` default.

        Returns:
            Tuple ``(uv_coord, pos, pos_raw)``:
            ``uv_coord`` has shape (bs * N_views, nq, 2), ``pos`` has shape
            (bs * N_views, nq, 3) and ``pos_raw`` keeps the image layout
            (bs, 3, sqrt(nq), sqrt(nq)).
        """
        bs, n_views = source_cameras_view_to_world.shape[0], source_cameras_view_to_world.shape[1]
        if gaussian.shape[1] == 24:
            depth, offset = gaussian[:, 0, :, :].unsqueeze(1), gaussian[:, 1:4, :, :]
            pos_raw = self.get_pos_from_network_output(depth, offset, const_offset=const_offset)
        else:
            pos_raw = gaussian[:, 0:3, :, :]

        # Invert in float32 (the input may be half precision), then cast
        # back so the batched matmul below sees matching dtypes.
        world_to_view = torch.inverse(
            source_cameras_view_to_world.float().reshape(
                bs * n_views, *source_cameras_view_to_world.shape[2:]))
        world_to_view = world_to_view.to(dtype=gaussian.dtype)

        pos = self.flatten_vector(pos_raw)
        # Homogeneous coordinate. Allocate it on the same device as the
        # points instead of hard-coding "cuda", so the module also works on
        # CPU and on non-default GPUs.
        homogeneous = pos.new_ones((pos.shape[0], pos.shape[1], 1))
        pos = torch.cat([pos, homogeneous], dim=2)

        # Duplicate every sample once per view: (bs, nq, 4) ->
        # (bs * N_views, nq, 4), with the views of one sample contiguous.
        pos = pos.unsqueeze(1).expand(-1, n_views, -1, -1).reshape(
            world_to_view.shape[0], pos.shape[1], 4)
        # Row-vector convention: points are multiplied on the left.
        pos = torch.bmm(pos, world_to_view)
        pos = pos[:, :, :3] / (pos[:, :, 3:] + 1e-10)

        uv_coord = self.xyz2uv(pos, intrinsics)
        return uv_coord, pos, pos_raw

    @staticmethod
    def _intrinsic_matrix(f_x, f_y, c_x, c_y, device, dtype):
        """Build a 3x3 pinhole intrinsic matrix from focal and principal point."""
        return torch.tensor([[f_x, 0, c_x],
                             [0, f_y, c_y],
                             [0, 0, 1]], device=device, dtype=dtype)

    def xyz2uv(self, coord_3d, intrinsics):
        """Project 3D points to normalised image coordinates.

        Args:
            coord_3d: points of shape (bs, nq, 3).
            intrinsics: ``None`` to use the shared ``cfg.data.intrinsics``
                (read as ``f_x, f_y, c_x, c_y``), or a per-sample container
                indexed as ``intrinsics[i][0][k]`` with the same ordering.

        Returns:
            Tensor of shape (bs, nq, 2): pixel coordinates divided by
            ``IMAGE_RESOLUTION``.
        """
        bs, nq, _ = coord_3d.shape
        device, dtype = coord_3d.device, coord_3d.dtype

        if intrinsics is None:
            f_x, f_y, c_x, c_y = (self.cfg.data.intrinsics[k] for k in range(4))
            intrinsic_matrix = self._intrinsic_matrix(f_x, f_y, c_x, c_y, device, dtype)
            # (bs, nq, 3) -> (3, bs * nq), then project onto the image plane.
            coord_image = torch.matmul(intrinsic_matrix,
                                       coord_3d.permute(2, 0, 1).flatten(1))
        else:
            # One matrix per sample. torch.tensor copies the values, so no
            # gradient flows back into ``intrinsics``.
            intrinsic_matrix = torch.stack([
                self._intrinsic_matrix(intrinsics[i][0][0],
                                       intrinsics[i][0][1],
                                       intrinsics[i][0][2],
                                       intrinsics[i][0][3],
                                       device, dtype)
                for i in range(bs)
            ], dim=0)
            # (bs, 3, 3) @ (bs, 3, nq) -> (bs, 3, nq) -> (3, bs * nq)
            coord_image = torch.matmul(
                intrinsic_matrix, coord_3d.permute(0, 2, 1)
            ).permute(1, 0, 2).flatten(1)

        # Perspective divide by the third (depth) row.
        coord_image_normalized = coord_image / (coord_image[2] + 1e-10)
        # The projection always targets a 128 x 128 image; dividing by the
        # resolution expresses the result as a fraction of the image size.
        coord_image_normalized = coord_image_normalized[:2] / self.IMAGE_RESOLUTION
        return coord_image_normalized.t().reshape(bs, nq, 2)

