# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# modified from DUSt3R

import torch
import torch.nn as nn
import torch.nn.functional as F
from dust3r.heads.postprocess import (
    postprocess,
    postprocess_desc,
    postprocess_rgb,
    postprocess_pose_conf,
    postprocess_pose,
    reg_dense_conf,
)
import dust3r.utils.path_to_croco  # noqa
from dust3r.blocks import Mlp  # noqa
from dust3r.utils.geometry import geotrf
from dust3r.utils.camera import pose_encoding_to_camera, PoseDecoder
from dust3r.blocks import ConditionModulationBlock


class LinearPts3d(nn.Module):
    """
    Per-token MLP head for dust3r.

    Each decoder token is mapped by an ``Mlp`` to one ``patch_size x patch_size``
    patch of 3D points (+ one confidence channel when ``has_conf``); the patches
    are reassembled into a dense (B, C, H, W) map with ``F.pixel_shuffle`` and
    passed to ``postprocess``.

    Optional branches, each with its own ``Mlp``:
      - ``has_depth``: a second point(+conf) map, returned as
        ``pts3d_in_self_view`` / ``conf_self``;
      - ``has_rgb``: a 3-channel map passed through ``postprocess_rgb``;
      - ``has_pose_conf``: a 1-channel map passed through
        ``postprocess_pose_conf``.
    """

    def __init__(
        self, net, has_conf=False, has_depth=False, has_rgb=False, has_pose_conf=False
    ):
        super().__init__()
        self.patch_size = net.patch_embed.patch_size[0]
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.has_conf = has_conf
        self.has_rgb = has_rgb
        self.has_pose_conf = has_pose_conf
        self.has_depth = has_depth
        # Submodules are created in their original order so that parameter
        # init (RNG consumption) and state_dict keys of existing configurations
        # stay unchanged. `3 + has_conf` relies on bool -> int promotion.
        self.proj = Mlp(
            net.dec_embed_dim, out_features=(3 + has_conf) * self.patch_size**2
        )
        if has_depth:
            self.self_proj = Mlp(
                net.dec_embed_dim, out_features=(3 + has_conf) * self.patch_size**2
            )
        if has_rgb:
            self.rgb_proj = Mlp(net.dec_embed_dim, out_features=3 * self.patch_size**2)
        if has_pose_conf:
            # forward() reads self.pose_conf_proj, which used to be missing, so
            # has_pose_conf=True always raised AttributeError.
            # NOTE(review): one output channel per pixel is assumed, like the
            # single-channel `cross_conf_proj` head elsewhere in this file;
            # confirm against what postprocess_pose_conf expects.
            self.pose_conf_proj = Mlp(
                net.dec_embed_dim, out_features=self.patch_size**2
            )

    def setup(self, croconet):
        """No-op hook; this head has nothing to set up."""
        pass

    def _unpatchify(self, feat, batch, H, W):
        """Turn per-token features (batch, S, C * p**2) into a (batch, C, H, W) map.

        The token axis is reshaped onto the (H // p, W // p) patch grid and
        ``F.pixel_shuffle`` spreads the p**2 sub-channels of each token over its
        p x p pixels, where p is ``self.patch_size``.
        """
        p = self.patch_size
        grid = feat.transpose(-1, -2).view(batch, -1, H // p, W // p)
        return F.pixel_shuffle(grid, p)

    def forward(self, decout, img_shape):
        """
        Args:
            decout: sequence of decoder outputs; only ``decout[-1]`` is used and
                its shape is unpacked as (B, S, D).
            img_shape: (H, W) of the input image.

        Returns:
            dict from ``postprocess`` with ``pts3d`` renamed to
            ``pts3d_in_other_view``, updated with the outputs of the enabled
            optional branches.
        """
        H, W = img_shape
        tokens = decout[-1]
        B, _, _ = tokens.shape

        feat = self._unpatchify(self.proj(tokens), B, H, W)  # B,3(+1),H,W
        final_output = postprocess(feat, self.depth_mode, self.conf_mode)
        final_output["pts3d_in_other_view"] = final_output.pop("pts3d")

        if self.has_depth:
            self_feat = self._unpatchify(self.self_proj(tokens), B, H, W)
            self_3d_output = postprocess(self_feat, self.depth_mode, self.conf_mode)
            self_3d_output["pts3d_in_self_view"] = self_3d_output.pop("pts3d")
            self_3d_output["conf_self"] = self_3d_output.pop("conf")
            final_output.update(self_3d_output)

        if self.has_rgb:
            rgb_feat = self._unpatchify(self.rgb_proj(tokens), B, H, W)  # B,3,H,W
            final_output.update(postprocess_rgb(rgb_feat))

        if self.has_pose_conf:
            pose_conf = self._unpatchify(self.pose_conf_proj(tokens), B, H, W)
            final_output.update(postprocess_pose_conf(pose_conf))

        return final_output


class LinearPts3d_Desc(nn.Module):
    """
    Linear point head with an additional local-descriptor branch.

    A single ``nn.Linear`` maps each decoder token to a ``patch_size x patch_size``
    patch of 3D points (+ confidence when ``has_conf``). The channel count is
    doubled when ``has_depth`` is set (forwarded to ``postprocess_desc`` as
    ``double_channel``). A second ``Mlp``, fed with ``decout[0]`` and
    ``decout[-1]`` concatenated on the feature axis, emits
    ``local_feat_dim + 1`` channels per pixel. Both dense maps are stacked along
    the channel axis and handed to ``postprocess_desc``.
    """

    def __init__(
        self,
        net,
        has_conf=False,
        has_depth=False,
        local_feat_dim=24,
        hidden_dim_factor=4.0,
    ):
        super().__init__()
        self.patch_size = net.patch_embed.patch_size[0]
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.has_conf = has_conf
        self.double_channel = has_depth
        self.local_feat_dim = local_feat_dim

        pixels_per_token = self.patch_size**2
        # has_depth doubles the point(+conf) channels of the linear layer.
        n_maps = 2 if has_depth else 1
        self.proj = nn.Linear(
            net.dec_embed_dim, (3 + has_conf) * n_maps * pixels_per_token
        )

        # Input width assumes decout[0] has enc_embed_dim features and
        # decout[-1] has dec_embed_dim features.
        fused_dim = net.enc_embed_dim + net.dec_embed_dim
        self.head_local_features = Mlp(
            in_features=fused_dim,
            hidden_features=int(hidden_dim_factor * fused_dim),
            out_features=(self.local_feat_dim + 1) * pixels_per_token,
        )

    def setup(self, croconet):
        """No-op hook; this head has nothing to set up."""
        pass

    def _unpatchify(self, feat, batch, H, W):
        """Reshape (batch, S, C * p**2) token features to a (batch, C, H, W) map."""
        p = self.patch_size
        grid = feat.transpose(-1, -2).view(batch, -1, H // p, W // p)
        return F.pixel_shuffle(grid, p)

    def forward(self, decout, img_shape):
        """
        Args:
            decout: decoder outputs. ``decout[-1]`` (unpacked as B, S, D) feeds
                the point projection; ``decout[0]`` and ``decout[-1]`` are
                concatenated on the last axis for the descriptor MLP.
            img_shape: (H, W) of the input image.

        Returns:
            the result of ``postprocess_desc`` on the channel-concatenated
            point and descriptor maps.
        """
        H, W = img_shape
        first, last = decout[0], decout[-1]
        batch, _, _ = last.shape

        point_map = self._unpatchify(self.proj(last), batch, H, W)
        fused_tokens = torch.cat([first, last], dim=-1)
        desc_map = self._unpatchify(
            self.head_local_features(fused_tokens), batch, H, W
        )  # B,d,H,W

        return postprocess_desc(
            torch.cat([point_map, desc_map], dim=1),
            self.depth_mode,
            self.conf_mode,
            self.local_feat_dim,
            self.double_channel,
        )


class LinearPts3dPoseDirect(nn.Module):
    """
    Per-token MLP point head with an optional direct pose regressor.

    When ``has_pose`` is set, token 0 of ``decout[-1]`` is split off and fed to
    a ``PoseDecoder``. The remaining tokens produce self-view points
    (+ confidence), optional RGB and an optional single-channel cross-view
    confidence, each unpatchified with ``F.pixel_shuffle``. With a pose, the
    self-view points are transformed by ``geotrf`` using the camera decoded
    from that pose, and the result is returned as ``pts3d_in_other_view``.
    """

    def __init__(self, net, has_conf=False, has_rgb=False, has_pose=False):
        super().__init__()
        self.patch_size = net.patch_embed.patch_size[0]
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.pose_mode = net.pose_mode
        self.has_conf = has_conf
        self.has_rgb = has_rgb
        self.has_pose = has_pose

        width = net.dec_embed_dim
        pixels_per_token = self.patch_size**2
        # Creation order is kept as-is: it fixes parameter-init RNG draws.
        self.proj = Mlp(width, out_features=(3 + has_conf) * pixels_per_token)
        if has_rgb:
            self.rgb_proj = Mlp(width, out_features=3 * pixels_per_token)
        if has_pose:
            self.pose_head = PoseDecoder(hidden_size=width)
        if has_conf:
            # One channel per pixel: the cross-view confidence.
            self.cross_conf_proj = Mlp(width, out_features=pixels_per_token)

    def setup(self, croconet):
        """No-op hook; this head has nothing to set up."""
        pass

    def _unpatchify(self, feat, batch, H, W):
        """Reshape (batch, S, C * p**2) token features to a (batch, C, H, W) map."""
        p = self.patch_size
        grid = feat.transpose(-1, -2).view(batch, -1, H // p, W // p)
        return F.pixel_shuffle(grid, p)

    def forward(self, decout, img_shape):
        """
        Args:
            decout: decoder outputs; only ``decout[-1]`` is used.
            img_shape: (H, W) of the input image.

        Returns:
            dict with ``pts3d_in_self_view`` / ``conf_self`` and, depending on
            the flags, RGB outputs, ``camera_pose``, ``pts3d_in_other_view``
            and ``conf``.
        """
        H, W = img_shape
        tokens = decout[-1]
        if self.has_pose:
            # Token 0 goes to the pose decoder; the rest are patch tokens.
            pose_token, tokens = tokens[:, 0], tokens[:, 1:]
        batch, _, _ = tokens.shape

        points = self._unpatchify(self.proj(tokens), batch, H, W)  # B,3(+1),H,W
        out = postprocess(points, self.depth_mode, self.conf_mode)
        out["pts3d_in_self_view"] = out.pop("pts3d")
        out["conf_self"] = out.pop("conf")

        if self.has_rgb:
            rgb = self._unpatchify(self.rgb_proj(tokens), batch, H, W)  # B,3,H,W
            out.update(postprocess_rgb(rgb))

        if self.has_pose:
            camera_pose = postprocess_pose(self.pose_head(pose_token), self.pose_mode)
            out["camera_pose"] = camera_pose  # B,7
            out["pts3d_in_other_view"] = geotrf(
                pose_encoding_to_camera(camera_pose),
                out["pts3d_in_self_view"],
            )

        if self.has_conf:
            cross_conf = self._unpatchify(self.cross_conf_proj(tokens), batch, H, W)
            out["conf"] = reg_dense_conf(cross_conf[:, 0], mode=self.conf_mode)
        return out


class LinearPts3dPose(nn.Module):
    """
    Per-token MLP point head with a pose token and a pose-conditioned cross branch.

    When ``has_pose`` is set, token 0 of ``decout[-1]`` is decoded by a
    ``PoseDecoder``. The remaining patch tokens pass through two
    ``ConditionModulationBlock``s (conditioned on the pose token) before a
    separate ``cross_proj`` MLP predicts ``pts3d_in_other_view`` / ``conf``.
    The unmodulated patch tokens predict ``pts3d_in_self_view`` / ``conf_self``
    (post-processed with ``pos_z=True``) and, optionally, RGB. The pose decoder
    and all dense heads run with CUDA autocast disabled.
    """

    def __init__(
        self, net, has_conf=False, has_rgb=False, has_pose=False, mlp_ratio=4.0
    ):
        super().__init__()
        self.patch_size = net.patch_embed.patch_size[0]
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.pose_mode = net.pose_mode
        self.has_conf = has_conf
        self.has_rgb = has_rgb
        self.has_pose = has_pose

        width = net.dec_embed_dim
        hidden = int(mlp_ratio * width)
        pts_channels = (3 + has_conf) * self.patch_size**2

        # Creation order is preserved (parameter-init RNG draws / state_dict).
        self.proj = Mlp(width, hidden_features=hidden, out_features=pts_channels)
        if has_rgb:
            self.rgb_proj = Mlp(
                width,
                hidden_features=hidden,
                out_features=3 * self.patch_size**2,
            )
        if has_pose:
            self.pose_head = PoseDecoder(hidden_size=width)
            # The blocks' MLP ratio is fixed at 4.0, independent of `mlp_ratio`.
            modulation_blocks = [
                ConditionModulationBlock(
                    width,
                    net.dec_num_heads,
                    mlp_ratio=4.0,
                    qkv_bias=True,
                    rope=net.rope,
                )
                for _ in range(2)
            ]
            self.final_transform = nn.ModuleList(modulation_blocks)
            self.cross_proj = Mlp(
                width, hidden_features=hidden, out_features=pts_channels
            )

    def setup(self, croconet):
        """No-op hook; this head has nothing to set up."""
        pass

    def _unpatchify(self, feat, batch, H, W):
        """Reshape (batch, S, C * p**2) token features to a (batch, C, H, W) map."""
        p = self.patch_size
        grid = feat.transpose(-1, -2).view(batch, -1, H // p, W // p)
        return F.pixel_shuffle(grid, p)

    def forward(self, decout, img_shape, **kwargs):
        """
        Args:
            decout: decoder outputs; only ``decout[-1]`` is used.
            img_shape: (H, W) of the input image.
            **kwargs: ``pos`` (if given) is forwarded to every modulation block.

        Returns:
            dict with ``pts3d_in_self_view`` / ``conf_self``, optional RGB
            outputs and, with ``has_pose``, ``camera_pose``,
            ``pts3d_in_other_view`` and ``conf``.
        """
        H, W = img_shape
        tokens = decout[-1]
        if self.has_pose:
            pose_token, tokens = tokens[:, 0], tokens[:, 1:]
            with torch.amp.autocast("cuda", enabled=False):
                pose = self.pose_head(pose_token)
            pos = kwargs.get("pos")
            cross_tokens = tokens
            for block in self.final_transform:
                cross_tokens = block(cross_tokens, pose_token, pos)

        with torch.amp.autocast("cuda", enabled=False):
            batch, _, _ = tokens.shape

            points = self._unpatchify(self.proj(tokens), batch, H, W)  # B,3(+1),H,W
            out = postprocess(points, self.depth_mode, self.conf_mode, pos_z=True)
            out["pts3d_in_self_view"] = out.pop("pts3d")
            out["conf_self"] = out.pop("conf")

            if self.has_rgb:
                rgb = self._unpatchify(self.rgb_proj(tokens), batch, H, W)  # B,3,H,W
                out.update(postprocess_rgb(rgb))

            if self.has_pose:
                out["camera_pose"] = postprocess_pose(pose, self.pose_mode)  # B,7
                cross_map = self._unpatchify(
                    self.cross_proj(cross_tokens), batch, H, W
                )
                cross_out = postprocess(cross_map, self.depth_mode, self.conf_mode)
                out["pts3d_in_other_view"] = cross_out.pop("pts3d")
                out["conf"] = cross_out.pop("conf")

        return out
