# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from detectron2.layers import Conv2d

from scene.transformer_decoder.dino_decoder import build_deformable_transformer
from scene.transformer_decoder.dino_decoder import DINODecoder
from utils.camera_utils import CameraEmbedder


class DINOGSPred(nn.Module):
    """Prediction head: 1x1 channel projection followed by a ``DINODecoder``.

    The last stage-1 feature map is projected from ``cfg.model.base_dim`` to
    ``cfg.model.dino_decoder.in_channels[0]`` channels, and the resulting
    feature list is handed to a ``DINODecoder`` together with the camera
    inputs. The decoder output is returned unchanged.
    """

    def __init__(self, cfg, out_dim=23):
        """Build the sub-modules described by ``cfg``.

        Args:
            cfg: Project config object. This constructor reads
                ``cfg.model.base_dim``, ``cfg.model.plucker_emb``,
                ``cfg.model.dino_decoder.in_channels``,
                ``cfg.model.dino_decoder.num_queries``,
                ``cfg.data.mod_camera_dec`` and, only when the latter is
                truthy, ``cfg.model.dino_decoder.camera_embed_dim``.
            out_dim: Forwarded to ``build_deformable_transformer``.
        """
        super().__init__()

        self.cfg = cfg
        decoder_cfg = cfg.model.dino_decoder
        first_level_channels = decoder_cfg.in_channels[0]

        # 1x1 conv: changes the channel count only, not the spatial size.
        self.up_channel = Conv2d(
            cfg.model.base_dim, first_level_channels, kernel_size=1
        )

        # Cross attention on the features after the encoder.
        transformer_dec = build_deformable_transformer(cfg, out_dim)
        # NOTE(review): hard-coded; presumably must agree with
        # len(in_channels) -- confirm against DINODecoder.
        num_feature_levels = 4

        # Optional embedder taking in_channels[0] + 6 input channels
        # (presumably image features concatenated with 6-D Plucker rays --
        # TODO confirm). It is not referenced by forward() in this class.
        if 'after_enc' in cfg.model.plucker_emb:
            self.plucker_emb = CameraEmbedder(
                first_level_channels + 6, first_level_channels
            )

        self.dino_dec = DINODecoder(
            cfg,
            transformer_dec,
            decoder_cfg.num_queries,
            decoder_cfg.in_channels,
            num_feature_levels=num_feature_levels,
        )

        # Optional camera embedder: raw_dim 16 (written as 12 + 4 in the
        # original) -> camera_embed_dim.
        if cfg.data.mod_camera_dec:
            self.cam_emb = CameraEmbedder(
                raw_dim=12 + 4, embed_dim=decoder_cfg.camera_embed_dim
            )

    def forward(self,
                stage1_feas=None,
                intrinsics=None,
                const_offset=None,
                source_cameras_view_to_world=None,
                input_cameras=None,
                point_cloud_init=None,
                stage1_network_outputs=None):
        """Project the last feature map and run the decoder.

        Args:
            stage1_feas: Sequence of feature maps; the last entry is passed
                through ``self.up_channel``. Required despite the ``None``
                default. The caller's sequence is NOT modified.
            intrinsics: Forwarded to the decoder as-is.
            const_offset: Forwarded to the decoder as-is.
            source_cameras_view_to_world: Forwarded to the decoder as-is.
            input_cameras: Fed to ``self.cam_emb`` when
                ``cfg.data.mod_camera_dec`` is truthy; otherwise unused.
            point_cloud_init: Forwarded to the decoder as-is.
            stage1_network_outputs: Forwarded to the decoder as-is.

        Returns:
            Whatever ``self.dino_dec`` returns.
        """
        # Work on a shallow copy so the caller's container is left untouched.
        # Writing the projected tensor back into the input list would make a
        # second call on the same list apply ``up_channel`` twice, and would
        # fail outright for tuple inputs.
        features = list(stage1_feas)
        features[-1] = self.up_channel(features[-1])
        h, w = features[-1].shape[-2:]

        # Embed the raw camera vector only when the config asks for it.
        if self.cfg.data.mod_camera_dec:
            cam_emb = self.cam_emb(input_cameras)
        else:
            cam_emb = None

        return self.dino_dec(
            features,
            h, w,
            intrinsics=intrinsics,
            const_offset=const_offset,
            source_cameras_view_to_world=source_cameras_view_to_world,
            cam_emb=cam_emb,
            point_cloud_init=point_cloud_init,
            stage1_network_outputs=stage1_network_outputs,
        )



