# -*- coding: utf-8 -*-
import math

import numpy as np
import torch
import torch.nn as nn

from einops import rearrange
from scene.transformer_decoder.dino_predictor import DINOGSPred
from utils.general_utils import quaternion_raw_multiply
from utils.graphics_utils import fov2focal


class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal embedding applied independently to every scalar of a 2-D input.

    Each of the ``c`` input channels is expanded into ``num_channels // 2``
    cosines followed by ``num_channels // 2`` sines, so an input of shape
    ``(b, c)`` becomes ``(b, c * 2 * (num_channels // 2))``.

    Args:
        num_channels: embedding width per input scalar (an odd value loses
            one channel because only ``num_channels // 2`` frequencies are used).
        max_positions: the frequencies form a geometric ladder from 1 down
            towards ``1 / max_positions``.
        endpoint: if True the ladder includes ``1 / max_positions`` itself.
            With ``endpoint=True`` ``num_channels`` must be at least 4,
            otherwise the exponent normalisation divides by zero.
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        """Embed a ``(b, c)`` tensor into ``(b, c * 2 * (num_channels // 2))``."""
        b, c = x.shape  # enforces a 2-D input
        half = self.num_channels // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=x.device)
        exponents = exponents / (half - (1 if self.endpoint else 0))
        freqs = (1 / self.max_positions) ** exponents
        # One row of angles per input scalar: (b*c, half).
        # torch.outer replaces the deprecated Tensor.ger.
        angles = torch.outer(x.reshape(-1), freqs.to(x.dtype))
        emb = torch.cat([angles.cos(), angles.sin()], dim=1)
        # Rows are ordered batch-major, so this groups each sample's c embeddings.
        return emb.reshape(b, c * emb.shape[1])

def networkCallBack(cfg, name, out_channels):
    """Build the network registered under ``name``.

    Args:
        cfg: configuration object forwarded to the network constructor.
        name: network identifier; only ``"dino_decoder"`` is supported.
        out_channels: passed to the network as ``out_dim``.

    Returns:
        A ``(network, None)`` pair; the second element is always ``None`` here.

    Raises:
        NotImplementedError: if ``name`` is not a known network.
    """
    if name != "dino_decoder":
        raise NotImplementedError("Unknown network name: {!r}".format(name))
    return DINOGSPred(cfg=cfg, out_dim=out_channels), None


class GaussianSplatPredictor(nn.Module):
    """Two-stage predictor of 3D Gaussian-splat parameters.

    Stage 1 reuses the ``scene.coarse_stage`` model, initialised from the
    checkpoint at ``cfg.model.coarse_model_dir``, to predict per-pixel
    Gaussians for the first ``cfg.data.coarse_stage_input_images`` views.
    Those Gaussians are subsampled to ``cfg.model.dino_decoder.num_queries``
    points which, together with encoder features, drive the ``dino_decoder``
    network (stage 2). Every stage-2 output is turned into a dict of
    world-space Gaussian parameters.
    """

    def __init__(self, cfg):
        super(GaussianSplatPredictor, self).__init__()
        self.cfg = cfg

        self.get_splits_and_inits(cfg)

        from scene.coarse_stage import GaussianSplatPredictor as GaussianSplatPredictorSplatterImage

        # The optional coarse model and the encoder are initialised from the
        # same checkpoint: read it once, onto the CPU. load_state_dict copies
        # the values into each module's own tensors, so sharing is safe.
        ckpt_state = torch.load(cfg.model.coarse_model_dir, map_location="cpu")["model_state_dict"]

        # Submodule construction order (coarse, stage2_net, unet_encoder) is
        # kept stable so parameter ordering and RNG-driven init do not change.
        if self.cfg.model.dino_decoder.use_point_cloud_init == 'coarse_stage':
            self.coarse = GaussianSplatPredictorSplatterImage(cfg)
            self.coarse.load_state_dict(ckpt_state)
            if not self.cfg.model.finetune_coarse:
                self.coarse.requires_grad_(False)
        self.stage2_net, _ = networkCallBack(cfg,
                                             'dino_decoder',
                                             self.split_dimensions_stage2)

        self.unet_encoder = GaussianSplatPredictorSplatterImage(cfg)
        self.unet_encoder.load_state_dict(ckpt_state)
        if cfg.model.freeze_encoder:
            self.unet_encoder.requires_grad_(False)
        del ckpt_state

        self.init_ray_dirs()

        # Activation functions for different parameters
        self.depth_act = nn.Sigmoid()
        self.scaling_activation = torch.exp
        self.opacity_activation = torch.sigmoid
        self.rotation_activation = torch.nn.functional.normalize

        if self.cfg.model.max_sh_degree > 0:
            self.init_sh_transform_matrices()

    def init_sh_transform_matrices(self):
        """Register the fixed 3x3 axis permutations used by ``transform_SHs``.

        ``sh_to_v_transform`` is the transpose of ``v_to_sh_transform``; both
        are stored with a leading batch dimension of 1.
        """
        v_to_sh_transform = torch.tensor([[0, 0, -1],
                                          [-1, 0, 0],
                                          [0, 1, 0]], dtype=torch.float32)
        sh_to_v_transform = v_to_sh_transform.transpose(0, 1)
        self.register_buffer('sh_to_v_transform', sh_to_v_transform.unsqueeze(0))
        self.register_buffer('v_to_sh_transform', v_to_sh_transform.unsqueeze(0))

    def _pixel_ray_grid(self, resolution, num_lines):
        """Return a (1, 3, num_lines, num_lines) grid of un-normalised ray directions.

        ``num_lines`` samples span the pixel centres of a ``resolution``-wide
        image centred on the optical axis (x increasing, y decreasing, z = 1),
        optionally mirrored by ``cfg.model.inverted_x`` / ``inverted_y``.
        """
        # True division keeps the grid symmetric for odd resolutions too;
        # for even resolutions this equals the old ``-res // 2 + 0.5`` form.
        half = resolution / 2
        x = torch.linspace(-half + 0.5, half - 0.5, num_lines)
        y = torch.linspace(half - 0.5, -half + 0.5, num_lines)
        if self.cfg.model.inverted_x:
            x = -x
        if self.cfg.model.inverted_y:
            y = -y
        grid_x, grid_y = torch.meshgrid(x, y, indexing='xy')
        ones = torch.ones_like(grid_x, dtype=grid_x.dtype)
        return torch.stack([grid_x, grid_y, ones]).unsqueeze(0)

    def init_ray_dirs(self):
        """Register the camera-space ray grids for both stages.

        ``ray_dirs_vis`` covers every pixel of the stage-1 (coarse) image;
        ``ray_dirs`` has sqrt(num_queries) x sqrt(num_queries) rays spanning a
        fixed 128-pixel image for stage 2. The x/y components are divided by
        the focal length derived from the dataset-wide ``cfg.data.fov``.
        """
        fov = self.cfg.data.fov * np.pi / 180

        # for stage 1 pixels
        resolution_vis = self.cfg.data.coarse_resolution
        ray_dirs_vis = self._pixel_ray_grid(resolution_vis, resolution_vis)
        ray_dirs_vis[:, :2, ...] /= fov2focal(fov, resolution_vis)
        self.register_buffer('ray_dirs_vis', ray_dirs_vis)

        # for stage 2 pixels
        resolution = 128
        num_lines = int(math.sqrt(self.cfg.model.dino_decoder.num_queries))
        ray_dirs = self._pixel_ray_grid(resolution, num_lines)
        ray_dirs[:, :2, ...] /= fov2focal(fov, resolution)
        self.register_buffer('ray_dirs', ray_dirs)

    def get_splits_and_inits(self, cfg):
        """Compute the channel split sizes of the stage-1 and stage-2 outputs.

        Sets ``split_dimensions_with_offset`` (stage 1: depth, offset,
        opacity, scaling, rotation, features_dc[, SH rest]) and
        ``split_dimensions_stage2`` (stage 2: position, opacity, scaling,
        rotation, features_dc[, SH rest]). The SH entry is present only when
        ``cfg.model.max_sh_degree != 0``.
        """
        gaussian_dims = [3, 1, 3, 4, 3]
        if cfg.model.max_sh_degree != 0:
            sh_num = (cfg.model.max_sh_degree + 1) ** 2 - 1
            sh_dims = [sh_num * 3]
        else:
            # Previously this path crashed with UnboundLocalError.
            sh_dims = []
        self.split_dimensions_with_offset = [1] + gaussian_dims + sh_dims
        self.split_dimensions_stage2 = gaussian_dims + sh_dims

    def flatten_vector(self, x):
        # Gets rid of the image dimensions and flattens to a point list
        # B x C x H x W -> B x C x N -> B x N x C
        return x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)

    def make_contiguous(self, tensor_dict):
        """Return a new dict whose tensors (or lists of tensors) are contiguous."""
        return {
            key: [v.contiguous() for v in value] if isinstance(value, list) else value.contiguous()
            for key, value in tensor_dict.items()
        }

    def _transform_SHs_batched(self, shs, source_cameras_to_world):
        """Rotate degree-1 SH coefficients of a batched tensor.

        Args:
            shs: B x N x SH_num x 3 tensor; SH_num must be 3 (degree 1 only).
            source_cameras_to_world: B x 4 x 4 camera-to-world matrices.
        """
        assert shs.shape[2] == 3, "Can only process shs order 1"
        batch = source_cameras_to_world.shape[0]
        shs = rearrange(shs, 'b n sh_num rgb -> b (n rgb) sh_num')
        # No transpose is applied: in this class the camera matrices are used
        # to right-multiply row vectors (see the pos @ M product in
        # postprocess_gs_params).
        transforms = torch.bmm(self.sh_to_v_transform.expand(batch, 3, 3),
                               source_cameras_to_world[:, :3, :3])
        transforms = torch.bmm(transforms, self.v_to_sh_transform.expand(batch, 3, 3))
        shs_transformed = torch.bmm(shs, transforms)
        return rearrange(shs_transformed, 'b (n rgb) sh_num -> b n sh_num rgb', rgb=3)

    def transform_SHs(self, shss, source_cameras_to_world):
        """Rotate degree-1 SH coefficients from camera space to world space.

        Args:
            shss: B x N x SH_num x 3 tensor, or a list of N_i x SH_num x 3
                tensors (one per batch item).
            source_cameras_to_world: B x 4 x 4 camera-to-world matrices.

        Returns:
            The transformed coefficients, as a tensor or list matching ``shss``.
        """
        if isinstance(shss, list):
            return [
                self._transform_SHs_batched(shs.unsqueeze(0),
                                            source_cameras_to_world[idx:idx + 1]).squeeze(0)
                for idx, shs in enumerate(shss)
            ]
        return self._transform_SHs_batched(shss, source_cameras_to_world)

    def transform_rotations(self, rotations, source_cv2wT_quat):
        """
        Applies a transform that rotates the predicted rotations from
        camera space to world space.
        Args:
            rotations: predicted in-camera rotation quaternions (B x N x 4),
                or a list of N_i x 4 tensors (one per batch item).
            source_cv2wT_quat: transformation quaternions from
                camera-to-world matrices transposed (B x 4)
        Returns:
            rotations with appropriately applied transform to world space
            (same container type as ``rotations``)
        """
        if isinstance(rotations, list):
            rotated = []
            for idx, rotation in enumerate(rotations):
                rotation = rotation.unsqueeze(0)
                Mq = source_cv2wT_quat[idx].unsqueeze(0).unsqueeze(1).expand(*rotation.shape)
                rotated.append(quaternion_raw_multiply(Mq, rotation).squeeze(0))
            return rotated

        Mq = source_cv2wT_quat.unsqueeze(1).expand(*rotations.shape)
        return quaternion_raw_multiply(Mq, rotations)

    def get_pos_from_network_output(self, sigmoid_depth_network, offset, const_offset=None, stage=None):
        """Lift per-pixel depths (and optional offsets) to camera-space positions.

        Args:
            sigmoid_depth_network: depth in [0, 1], shaped (B, 1, H, W); it is
                rescaled to [znear, zfar].
            offset: optional (B, 3, H, W) additive offset, or None.
            const_offset: optional per-pixel depth offset, cropped to H x W.
            stage: 'first' uses the stage-1 ray grid ``ray_dirs_vis``; any
                other value uses the stage-2 grid ``ray_dirs``.

        Returns:
            (B, 3, H, W) positions ``ray_dir * depth (+ offset)``.
        """
        # Pick the ray grid by stage in every branch (the offset-free path
        # previously always used the stage-2 grid).
        base_ray_dirs = self.ray_dirs_vis if stage == 'first' else self.ray_dirs
        ray_dirs = base_ray_dirs.expand(sigmoid_depth_network.shape[0], 3, *base_ray_dirs.shape[2:])

        depth = sigmoid_depth_network * (self.cfg.data.zfar - self.cfg.data.znear) + self.cfg.data.znear
        if const_offset is not None:
            resolution = sigmoid_depth_network.shape[-1]
            depth = depth + const_offset[:, :, :resolution, :resolution]

        pos = ray_dirs * depth
        if offset is not None:
            pos = pos + offset
        return pos

    def postprocess_gs_params(self, split_network_outputs, source_cameras_view_to_world, const_offset, B, N_views, stage=None):
        """Split raw network channels into activated, flattened Gaussian parameters.

        Args:
            split_network_outputs: raw output with channels on dim 1, split by
                ``split_dimensions_with_offset`` ('first') or
                ``split_dimensions_stage2`` ('second').
            source_cameras_view_to_world: camera-to-world matrices used to move
                positions to world space, or None to keep them as predicted.
                For 'first' they are reshaped to (B * N_views, 4, 4).
            const_offset: optional depth offset ('first' only).
            B, N_views: batch size and number of views ('first' only).
            stage: 'first' or 'second'.

        Returns:
            'second': (xyz, opacity, scaling, rotation, features_dc), each
            flattened to B x N x C (features_dc gets an extra singleton dim).
            'first': the same five plus the raw ``scaling`` and ``opacity``.

        Raises:
            NotImplementedError: for an unknown ``stage``.
        """
        if stage == 'second':
            split_network_outputs = split_network_outputs.split(self.split_dimensions_stage2, dim=1)
            pos, opacity, scaling, rotation, features_dc = split_network_outputs[:5]
        elif stage == 'first':
            source_cameras_view_to_world = source_cameras_view_to_world.reshape(B * N_views, 4, 4)
            split_network_outputs = split_network_outputs.split(self.split_dimensions_with_offset, dim=1)
            depth, offset, opacity, scaling, rotation, features_dc = split_network_outputs[:6]  # [1, 3, 1, 3, 4, 3]
            sigmoid_depth = self.depth_act(depth)
            pos = self.get_pos_from_network_output(sigmoid_depth, offset, const_offset=const_offset, stage=stage)
        else:
            raise NotImplementedError("Unknown stage {!r}; expected 'first' or 'second'".format(stage))

        if self.cfg.model.isotropic:
            scaling_out = torch.cat([scaling[:, :1, ...]] * 3, dim=1)
        else:
            scaling_out = scaling

        # Pos prediction is in camera space - compute the positions in the world space
        pos = self.flatten_vector(pos)
        if source_cameras_view_to_world is not None:
            pos = torch.cat([pos,
                             torch.ones((pos.shape[0], pos.shape[1], 1), device=pos.device, dtype=pos.dtype)
                             ], dim=2)
            pos = torch.bmm(pos, source_cameras_view_to_world)
            pos = pos[:, :, :3] / (pos[:, :, 3:] + 1e-10)

        if stage == 'second':
            opacity_flatten = self.flatten_vector(self.opacity_activation(opacity - 2.0))
            scaling_flatten = self.flatten_vector(torch.clamp(self.scaling_activation(scaling_out - 2.3), max=0.3))
            # Quaternion components live on dim 1 after the channel split;
            # normalising over dim=-1 (a spatial axis) was a bug.
            rotation_flatten = self.flatten_vector(rotation / (1e-8 + rotation.norm(dim=1, keepdim=True)))
        else:
            opacity_flatten = self.flatten_vector(self.opacity_activation(opacity))
            scaling_flatten = self.flatten_vector(self.scaling_activation(scaling_out))
            rotation_flatten = self.flatten_vector(self.rotation_activation(rotation, eps=1e-6))
        features_dc_flatten = self.flatten_vector(features_dc).unsqueeze(2)

        if stage == 'first':
            return pos, opacity_flatten, scaling_flatten, rotation_flatten, features_dc_flatten, scaling, opacity
        return pos, opacity_flatten, scaling_flatten, rotation_flatten, features_dc_flatten

    def transform_params(self, out_dict, features_rest, source_cv2wT_quat, source_cameras_view_to_world, B, N_views):
        """Move rotations and SH coefficients of ``out_dict`` to world space.

        ``features_rest`` (raw channels, used only when max_sh_degree > 0) is
        flattened, reshaped to B x N x SH_num x 3 and rotated; otherwise an
        all-zero tensor with zero SH coefficients is stored. All entries are
        made contiguous.
        """
        assert source_cv2wT_quat is not None
        source_cv2wT_quat = source_cv2wT_quat.reshape(B * N_views, *source_cv2wT_quat.shape[2:])

        out_dict["rotation"] = self.transform_rotations(out_dict["rotation"],
                                                        source_cv2wT_quat=source_cv2wT_quat)

        if self.cfg.model.max_sh_degree > 0:
            features_rest = self.flatten_vector(features_rest)
            # Channel dimension holds SH_num * RGB(3) -> renderer expects split across RGB
            # Split channel dimension B x N x C -> B x N x SH_num x 3
            out_dict["features_rest"] = features_rest.reshape(*features_rest.shape[:2], -1, 3)
            assert self.cfg.model.max_sh_degree == 1, "Only accepting degree 1"
            out_dict["features_rest"] = self.transform_SHs(out_dict["features_rest"],
                                                           source_cameras_view_to_world)
        else:
            features_dc = out_dict["features_dc"]
            # Allocate on the same device as the other outputs instead of a hard-coded "cuda".
            out_dict["features_rest"] = torch.zeros((features_dc.shape[0],
                                                     features_dc.shape[1],
                                                     (self.cfg.model.max_sh_degree + 1) ** 2 - 1,
                                                     3), dtype=features_dc.dtype, device=features_dc.device)

        return self.make_contiguous(out_dict)

    def get_coarse_feature(self,
                           x_raw,
                           source_cameras_view_to_world_coarse,
                           source_cv2wT_quat_coarse,
                           const_offset,
                           B,
                           masks
                           ):
        """Run the coarse model and subsample its Gaussians into stage-2 queries.

        Each of the first ``cfg.data.coarse_stage_input_images`` views
        contributes ``num_queries // N + 1`` points (mask-selected when
        ``cfg.data.use_mask`` and the view's mask is non-empty). A view with
        more points is truncated to the first ones, and a view with fewer has
        its points tiled cyclically. The concatenation is then cut to
        ``num_queries``.

        Returns:
            gs_init_batch: (B, C - 4, sqrt_nq, sqrt_nq) stage-1 channels with
                the depth (1) and offset (3) channels dropped.
            point_cloud_batch: (B, 3, sqrt_nq, sqrt_nq) world-space positions.
            stage1_network_outputs: (B, N_views_coarse, H*W, C) raw stage-1 output.
        """
        N_views_coarse = self.cfg.data.coarse_stage_input_images
        x_coarse = x_raw[:, :N_views_coarse].detach()
        cameras_coarse = source_cameras_view_to_world_coarse[:, :N_views_coarse, ...]
        _, _, stage1_network_outputs, _ = self.coarse(x_coarse,
                                                      cameras_coarse,
                                                      source_cv2wT_quat_coarse[:, :N_views_coarse, ...])
        pos_vis = self.postprocess_gs_params(stage1_network_outputs, cameras_coarse, const_offset, B,
                                             N_views=N_views_coarse, stage='first')[0]

        # input to stage 2: point cloud with shape (B, N, 3) and per-point stage-1 channels
        point_cloud_init = pos_vis.reshape(B, N_views_coarse, *pos_vis.shape[1:])
        stage1_network_outputs = stage1_network_outputs.reshape(
            B, N_views_coarse, *stage1_network_outputs.shape[1:]).flatten(3).permute(0, 1, 3, 2)

        nq = self.cfg.model.dino_decoder.num_queries
        sqrt_nq = int(math.sqrt(nq))
        point_num_per_view = nq // N_views_coarse + 1
        use_mask = self.cfg.data.use_mask

        point_cloud_batch = []
        gs_init_batch = []
        for b_idx in range(B):
            point_cloud_obj = []
            gs_init_obj = []
            for view_idx in range(N_views_coarse):
                points = point_cloud_init[b_idx, view_idx]
                feats = stage1_network_outputs[b_idx, view_idx]
                if use_mask:
                    view_mask = masks[b_idx][view_idx]
                    if view_mask.sum() != 0:
                        points = points[view_mask]
                        feats = feats[view_mask]
                # Index i -> i % n truncates when n > P and tiles cyclically when n < P.
                keep = torch.arange(point_num_per_view, device=points.device) % points.shape[0]
                point_cloud_obj.append(points[keep])
                gs_init_obj.append(feats[keep])
            point_cloud_batch.append(torch.stack(point_cloud_obj))
            gs_init_batch.append(torch.stack(gs_init_obj))
        point_cloud_batch = torch.stack(point_cloud_batch)
        gs_init_batch = torch.stack(gs_init_batch)

        num_channels = gs_init_batch.shape[-1]
        gs_init_batch = gs_init_batch.reshape(B, N_views_coarse * point_num_per_view, -1)[:, :nq, :] \
            .permute(0, 2, 1).reshape(B, num_channels, sqrt_nq, sqrt_nq)[:, 4:]
        point_cloud_batch = point_cloud_batch.reshape(B, N_views_coarse * point_num_per_view, -1)[:, :nq, :] \
            .permute(0, 2, 1).reshape(B, 3, sqrt_nq, sqrt_nq)
        return gs_init_batch, point_cloud_batch, stage1_network_outputs

    @staticmethod
    def _resize_views(views, resolution):
        """Bicubically resize a (B, N, C, H, W) stack to (B, N, C, resolution, resolution)."""
        batch, n_views = views.shape[:2]
        flat = nn.functional.interpolate(views.flatten(0, 1), size=(resolution, resolution),
                                         mode="bicubic", align_corners=False)
        return flat.reshape(batch, n_views, *flat.shape[1:])

    def forward(self, x,
                masks,
                intrinsics,
                source_cameras_view_to_world,
                source_cv2wT_quat=None,
                plucker_emb=None,
                input_cameras=None,
                unnorm_imges=None,
                source_cameras_view_to_world_coarse=None,
                source_cv2wT_quat_coarse=None
                ):
        """Predict world-space Gaussians.

        Returns:
            A list with one dict per element of the stage-2 network output
            (presumably one per decoder layer). Each dict has the keys "xyz",
            "opacity", "scaling", "rotation", "features_dc" and
            "features_rest", expressed relative to the first source view's
            camera-to-world transform.
        """
        B = x.shape[0]
        N_views = x.shape[1]
        coarse_resolution = self.cfg.data.coarse_resolution

        if self.cfg.data.use_mask:
            if masks.shape[-1] != coarse_resolution:
                # NOTE(review): bicubic resampling of a binary mask can overshoot,
                # and `.bool()` keeps every non-zero pixel, slightly dilating the
                # mask. Left unchanged to preserve behaviour.
                masks = self._resize_views(masks, coarse_resolution)
            masks = masks.flatten(start_dim=2).bool()

        x_raw = unnorm_imges
        if x_raw.shape[-1] != coarse_resolution:
            x_raw = self._resize_views(x_raw, coarse_resolution)

        x = x.reshape(B * N_views, *x.shape[2:])
        if self.cfg.data.origin_distances:
            const_offset = x[:, 3:, ...]
            x = x[:, :3, ...]
        else:
            const_offset = None

        ### encoder
        _, _, _, stage1_feas = self.unet_encoder(x_raw,
                                                 source_cameras_view_to_world_coarse,
                                                 source_cv2wT_quat_coarse)
        gs_init_batch, point_cloud_batch, _ = self.get_coarse_feature(x_raw,
                                                                      source_cameras_view_to_world_coarse,
                                                                      source_cv2wT_quat_coarse,
                                                                      const_offset,
                                                                      B,
                                                                      masks)
        source_cameras_view_to_world_batch = source_cameras_view_to_world.clone()

        # NOTE(review): `x` is not used after this point, and `self.plucker_emb`
        # is not defined in this class; this branch looks like leftover code —
        # confirm before relying on it.
        x = x.contiguous(memory_format=torch.channels_last)
        if 'before_enc' in self.cfg.model.plucker_emb:
            plucker_emb_flatten = plucker_emb.reshape(B * N_views, *plucker_emb.shape[2:])
            x = torch.cat([x, plucker_emb_flatten], dim=1)
            x = self.plucker_emb(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        # stage 2 new points outputs
        stage2_network_outputs = self.stage2_net(stage1_feas=stage1_feas,
                                                 intrinsics=intrinsics,
                                                 const_offset=const_offset,
                                                 source_cameras_view_to_world=source_cameras_view_to_world_batch,
                                                 input_cameras=input_cameras,
                                                 point_cloud_init=point_cloud_batch,
                                                 stage1_network_outputs=gs_init_batch
                                                 )

        first_view_to_world = source_cameras_view_to_world_batch[:, 0]
        out_dicts = []
        for stage2_network_output in stage2_network_outputs:
            pos, opacity_flatten, scaling_flatten, rotation_flatten, features_dc_flatten = self.postprocess_gs_params(
                stage2_network_output, first_view_to_world, const_offset, B, N_views, stage='second')
            out_dict = {
                "xyz": pos,
                "opacity": opacity_flatten,
                "scaling": scaling_flatten,
                "rotation": rotation_flatten,
                "features_dc": features_dc_flatten
            }
            if self.cfg.model.max_sh_degree > 0:
                features_rest = stage2_network_output.split(self.split_dimensions_stage2, dim=1)[5]
            else:
                features_rest = None

            out_dicts.append(self.transform_params(out_dict, features_rest,
                                                   source_cv2wT_quat[:, [0]], first_view_to_world, B, N_views=1))
        return out_dicts
