import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from copy import deepcopy
from functools import partial
from typing import Optional, Tuple, List, Any
from dataclasses import dataclass
from transformers import PretrainedConfig
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from transformers.file_utils import ModelOutput
from dust3r.utils.misc import (
    fill_default_args,
    freeze_all_params,
    transpose_to_landscape,
)
from dust3r.heads import head_factory
from dust3r.utils.camera import PoseEncoder
from dust3r.patch_embed import get_patch_embed
from models.croco import CroCoNet, CrocoConfig  
from dust3r.decode_blocks import (
    Block,
    MemoryDecoderBlock,
    DecoderBlock,
) 

from accelerate.logging import get_logger

# Module-level alias for +infinity.
# NOTE(review): presumably kept so that model-constructor strings evaluated with
# eval() in load_model can reference `inf` (e.g. in depth/conf mode tuples) —
# confirm before removing.
inf = float("inf")

# Module logger (accelerate-aware), emitting at DEBUG level and above.
printer = get_logger(__name__, log_level="DEBUG")

@dataclass
class ARCroco3DStereoOutput(ModelOutput):
    """Return container of ``Point3R.forward``.

    Attributes:
        ress: one downstream-head result per input view, in view order
            (each is indexed by key, e.g. ``"pts3d_in_other_view"``).
        views: the input views, passed through unchanged.
    """

    ress: Optional[List[Any]] = None
    views: Optional[List[Any]] = None


def strip_module(state_dict):
    """Return a copy of ``state_dict`` with any leading ``"module."`` removed from keys.

    This is the prefix that ``torch.nn.DataParallel`` / ``DistributedDataParallel``
    wrappers add to parameter names. Keys without the prefix are kept as-is,
    and insertion order is preserved. If two keys collapse to the same name,
    the later value wins.

    Args:
        state_dict (Mapping): the original state dict.

    Returns:
        OrderedDict: the renamed state dict.
    """
    prefix = "module."
    cut = len(prefix)
    return OrderedDict(
        (key[cut:] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

def from_dust3r_to_ours(state_dict):
    """Rename DUSt3R checkpoint keys to this model's parameter names.

    The first matching prefix in the table below is rewritten (using
    ``str.replace``), and no further prefixes are tried for that key. Keys
    matching none of the prefixes are copied unchanged. Order is preserved.

    Args:
        state_dict (Mapping): a DUSt3R-style state dict.

    Returns:
        OrderedDict: the state dict with renamed keys.
    """
    renames = (
        ("dec_blocks2.", "dec_blocks_memory."),
        ("downstream_head1.dpt.", "downstream_head.dpt_self."),
        ("downstream_head2.dpt.", "downstream_head.dpt_cross."),
    )
    converted = OrderedDict()
    for key, tensor in state_dict.items():
        for old, new in renames:
            if key.startswith(old):
                key = key.replace(old, new)
                break
        converted[key] = tensor
    return converted


def load_model(model_path, device, verbose=True):
    """Rebuild a model from a training checkpoint file and move it to ``device``.

    The checkpoint is expected to hold ``ckpt["args"].model``, a Python
    expression that constructs the network, and ``ckpt["model"]``, its state
    dict. The expression is patched before evaluation:

    - ``ManyAR_PatchEmbed`` is replaced with ``PatchEmbedDust3R``.
    - ``landscape_only`` is forced to ``False``. If it is absent, it is appended
      before the final two characters, which are assumed to be ``"))"``.

    The state dict is then loaded non-strictly.

    Args:
        model_path: path to the checkpoint file.
        device: target device for the returned model.
        verbose: print progress and the ``load_state_dict`` report.

    Returns:
        The instantiated model, on ``device``.
    """
    if verbose:
        print("... loading model from", model_path)
    # SECURITY NOTE(review): torch.load unpickles the file, and eval() below
    # executes a string taken from it. Only load checkpoints from trusted sources.
    ckpt = torch.load(model_path, map_location="cpu")
    args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if "landscape_only" in args:
        args = args.replace(" ", "").replace(
            "landscape_only=True", "landscape_only=False"
        )
    else:
        args = args[:-2] + ", landscape_only=False))"
    assert "landscape_only=False" in args
    if verbose:
        print(f"instantiating : {args}")
    net = eval(args)
    load_report = net.load_state_dict(ckpt["model"], strict=False)
    if verbose:
        print(load_report)
    return net.to(device)


class Point3RConfig(PretrainedConfig):
    """Hugging Face configuration for :class:`Point3R`.

    Each named argument is stored unchanged as an attribute of the same name.
    Any further keyword arguments are collected into ``self.croco_kwargs``,
    which ``Point3R.__init__`` completes with ``CrocoConfig`` defaults and
    uses to build the CroCo backbone.
    """

    model_type = "arcroco_3d_stereo"

    def __init__(
        self,
        output_mode="pts3d",
        head_type="linear",
        depth_mode=("exp", float("-inf"), float("inf")),
        conf_mode=("exp", 1, float("inf")),
        pose_mode=("exp", float("-inf"), float("inf")),
        freeze="none",
        landscape_only=True,
        patch_embed_cls="PatchEmbedDust3R",
        local_mem_size=256,
        memory_dec_num_heads=16,
        depth_head=False,
        pose_conf_head=False,
        pose_head=False,
        **croco_kwargs,
    ):
        super().__init__()
        # Head type / output parameterisation.
        self.output_mode, self.head_type = output_mode, head_type
        self.depth_mode, self.conf_mode, self.pose_mode = depth_mode, conf_mode, pose_mode
        # Training / input handling.
        self.freeze, self.landscape_only = freeze, landscape_only
        self.patch_embed_cls = patch_embed_cls
        # Memory decoder.
        self.memory_dec_num_heads, self.local_mem_size = memory_dec_num_heads, local_mem_size
        # Optional head outputs.
        self.depth_head, self.pose_conf_head, self.pose_head = (
            depth_head,
            pose_conf_head,
            pose_head,
        )
        # Everything else is forwarded to the CroCo backbone.
        self.croco_kwargs = croco_kwargs

class Point3R(CroCoNet):
    """Point3R: streaming multi-view 3D reconstruction built on CroCoNet.

    Each view is encoded with the CroCo image encoder. It is then decoded
    jointly with a token memory: ``dec_blocks_memory`` runs in lock-step with
    ``dec_blocks``. The interleaved decoder outputs feed a downstream head.

    When ``forward`` is called with ``point3r_tag=True``, the memory grows after
    every view. New entries come from that view's predicted 3D points. An entry
    whose patch centre lies close to an existing memory position is merged into
    it instead of being appended (see ``_forward_addmem_merge``).
    """

    config_class = Point3RConfig
    base_model_prefix = "arcroco3dstereo"
    supports_gradient_checkpointing = True

    def __init__(self, config: Point3RConfig):
        """Build the model from a :class:`Point3RConfig`.

        Missing entries in ``config.croco_kwargs`` are filled with the defaults
        of ``CrocoConfig.__init__`` before the CroCo backbone is constructed.
        The memory decoder, point-value encoder, downstream head and
        memory-key projection are then added.
        """
        # These are assigned before CroCoNet.__init__ runs.
        # NOTE(review): presumably because the parent constructor calls the
        # overridden _set_* hooks below, which read e.g. self.patch_embed_cls.
        # Confirm before reordering.
        self.gradient_checkpointing = False
        self.fixed_input_length = True
        config.croco_kwargs = fill_default_args(
            config.croco_kwargs, CrocoConfig.__init__
        )
        self.config = config
        self.patch_embed_cls = config.patch_embed_cls
        self.croco_args = config.croco_kwargs
        croco_cfg = CrocoConfig(**self.croco_args)
        super().__init__(croco_cfg)
        self.dec_num_heads = self.croco_args["dec_num_heads"]
        self.pose_head_flag = config.pose_head
        if self.pose_head_flag:
            # Learned token that _decoder prepends to the image tokens.
            self.pose_token = nn.Parameter(
                torch.randn(1, 1, self.dec_embed_dim) * 0.02, requires_grad=True
            )

        self._set_memory_decoder(
            self.enc_embed_dim,
            self.dec_embed_dim,
            config.memory_dec_num_heads,
            self.dec_depth,
            self.croco_args.get("mlp_ratio", None),
            self.croco_args.get("norm_layer", None),
            self.croco_args.get("norm_im2_in_dec", None),
        )

        # Encoder for dense 3D point maps (the memory "values"). Its width is
        # fixed at 1024. The output is summed with memory_attn_head's output
        # (width enc_embed_dim) in _new_memory_entries, so the two must agree.
        self._set_value_encoder(
            enc_depth=6,
            enc_embed_dim=1024,
            out_dim=1024,
            enc_num_heads=16,
            mlp_ratio=4,
            norm_layer=self.croco_args.get("norm_layer", None),
        )

        self.set_downstream_head(
            config.output_mode,
            config.head_type,
            config.landscape_only,
            config.depth_mode,
            config.conf_mode,
            config.pose_mode,
            config.depth_head,
            config.pose_conf_head,
            config.pose_head,
            **self.croco_args,
        )
        # Projects [encoder feature | decoder feature] to a memory "key" of
        # width enc_embed_dim.
        self.memory_attn_head = nn.Sequential(
            nn.Linear(self.enc_embed_dim + self.dec_embed_dim, self.enc_embed_dim + self.dec_embed_dim),
            nn.GELU(),
            nn.Linear(self.enc_embed_dim + self.dec_embed_dim, self.enc_embed_dim),
        )

        self.set_freeze(config.freeze)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """Load a model from a local checkpoint file or a Hugging Face repo/dir.

        A path to an existing file is loaded with ``load_model`` on the CPU;
        ``kw`` is ignored in that case. Anything else goes to the Hugging Face
        ``from_pretrained``.

        Raises:
            Exception: if the Hugging Face loader raises ``TypeError``. The
                original error is chained as ``__cause__``.
        """
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device="cpu")
        try:
            model = super(Point3R, cls).from_pretrained(
                pretrained_model_name_or_path, **kw
            )
        except TypeError as e:
            raise Exception(
                f"tried to load {pretrained_model_name_or_path} from huggingface, but failed"
            ) from e
        return model

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        """Create the image patch embedding and a separate one for 3D point maps."""
        self.patch_embed = get_patch_embed(
            self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=3
        )
        self.pts_patch_embed = get_patch_embed(
            self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=3
        )

    def _set_decoder(
        self,
        enc_embed_dim,
        dec_embed_dim,
        dec_num_heads,
        dec_depth,
        mlp_ratio,
        norm_layer,
        norm_im2_in_dec,
    ):
        """Create the image-side decoder: input projection, ``dec_depth`` blocks, final norm."""
        self.dec_depth = dec_depth
        self.dec_embed_dim = dec_embed_dim
        self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
        self.dec_blocks = nn.ModuleList(
            [
                DecoderBlock(
                    dec_embed_dim,
                    dec_num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    norm_mem=norm_im2_in_dec,
                    rope=self.rope,
                    rope3d=self.rope3d,
                )
                for _ in range(dec_depth)
            ]
        )
        self.dec_norm = norm_layer(dec_embed_dim)

    def _set_memory_decoder(
        self,
        enc_embed_dim,
        dec_embed_dim,
        dec_num_heads,
        dec_depth,
        mlp_ratio,
        norm_layer,
        norm_im2_in_dec,
    ):
        """Create the memory-side decoder, mirroring ``_set_decoder`` with ``MemoryDecoderBlock``s."""
        self.dec_depth_memory = dec_depth
        self.dec_embed_dim_memory = dec_embed_dim
        self.decoder_embed_memory = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
        self.dec_blocks_memory = nn.ModuleList(
            [
                MemoryDecoderBlock(
                    dec_embed_dim,
                    dec_num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    norm_mem=norm_im2_in_dec,
                    rope=self.rope,
                    rope3d=self.rope3d,
                )
                for _ in range(dec_depth)
            ]
        )
        self.dec_norm_memory = norm_layer(dec_embed_dim)

    def _set_value_encoder(
        self,
        enc_depth,
        enc_embed_dim,
        out_dim,
        enc_num_heads,
        mlp_ratio,
        norm_layer
    ):
        """Create the transformer that encodes patch-embedded 3D points (see ``enc_pts_value``)."""
        self.value_encoder = nn.ModuleList(
            [
                Block(
                    enc_embed_dim,
                    enc_num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    rope=self.rope
                )
                for _ in range(enc_depth)
            ]
        )
        self.value_norm = norm_layer(enc_embed_dim)
        self.value_out = nn.Linear(enc_embed_dim, out_dim)

    def load_state_dict(self, ckpt, **kw):
        """Load a state dict, adapting checkpoints from related models.

        - If every key starts with ``"module"``, the ``"module."`` prefix is stripped.
        - If the checkpoint has no ``pts_patch_embed`` weights, they are
          initialised from the corresponding ``patch_embed`` weights.

        ``kw`` is forwarded to the parent ``load_state_dict``.
        """
        if all(k.startswith("module") for k in ckpt):
            ckpt = strip_module(ckpt)
        new_ckpt = dict(ckpt)
        if not any(k.startswith("pts_patch_embed") for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith("patch_embed"):
                    new_ckpt[key.replace("patch_embed", "pts_patch_embed")] = value

        return super().load_state_dict(new_ckpt, **kw)

    def set_freeze(self, freeze):  # this is for use by downstream models
        """Freeze a named parameter group: ``"none"``, ``"mask"`` or ``"encoder"``.

        An unknown name raises ``KeyError``.
        """
        self.freeze = freeze
        to_be_frozen = {
            "none": [],
            "mask": [self.mask_token] if hasattr(self, "mask_token") else [],
            "encoder": [
                self.patch_embed,
                self.enc_blocks,
                self.enc_norm,
            ],
        }
        freeze_all_params(to_be_frozen[freeze])

    def set_prediction_head(self, *args, **kwargs):
        """Intentionally a no-op.

        The heads are created by ``set_downstream_head``. NOTE(review): this
        presumably overrides a CroCoNet hook; confirm.
        """
        return

    def set_downstream_head(
        self,
        output_mode,
        head_type,
        landscape_only,
        depth_mode,
        conf_mode,
        pose_mode,
        depth_head,
        pose_conf_head,
        pose_head,
        patch_size,
        img_size,
        **kw,
    ):
        """Create ``self.downstream_head`` and its landscape-transposing wrapper ``self.head``.

        Raises:
            AssertionError: if either image dimension is not a multiple of ``patch_size``.
        """
        assert (
            img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
        ), f"{img_size=} must be multiple of {patch_size=}"
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        self.pose_mode = pose_mode
        self.downstream_head = head_factory(
            head_type,
            output_mode,
            self,
            has_conf=bool(conf_mode),
            has_depth=bool(depth_head),
            has_pose_conf=bool(pose_conf_head),
            has_pose=bool(pose_head),
        )
        self.head = transpose_to_landscape(
            self.downstream_head, activate=landscape_only
        )

    def encode_image(self, image, true_shape):
        """Run the CroCo image encoder.

        Returns:
            ``([tokens], pos, None)``: a one-element list with the normalised
            encoder output, the patch positions, and an unused placeholder.
        """
        x, pos = self.patch_embed(image, true_shape=true_shape)
        assert self.enc_pos_embed is None
        for blk in self.enc_blocks:
            if self.gradient_checkpointing and self.training:
                x = checkpoint(blk, x, pos, use_reentrant=False)
            else:
                x = blk(x, pos)
        x = self.enc_norm(x)
        return [x], pos, None

    def encode_views(self, views, img_mask=None):
        """Encode all views in a single batched encoder pass.

        Images of all views are stacked and flattened over (view, batch). Only
        entries selected by the mask are encoded. Unselected entries are left
        as zeros in the outputs.

        Args:
            views: sequence of dicts with ``"img"``. Optional keys are
                ``"true_shape"`` and ``"img_mask"``; the latter is required when
                ``img_mask`` is None.
            img_mask: optional selection mask. NOTE(review): assumed to be a
                boolean tensor of shape (num_views, batch).

        Returns:
            ``(shapes, feats, poss)``, each split per view:

            - ``shapes``: tuple of per-view shape tensors.
            - ``feats``: a list, one entry per encoder output, of per-view
              chunk tuples.
            - ``poss``: tuple of per-view position tensors.

        Raises:
            NotImplementedError: if the mask selects no image at all.
        """
        device = views[0]["img"].device
        batch_size = views[0]["img"].shape[0]
        if img_mask is None:
            img_mask = torch.stack([view["img_mask"] for view in views], dim=0)
        imgs = torch.stack([view["img"] for view in views], dim=0)
        shapes = []
        for view in views:
            if "true_shape" in view:
                shapes.append(view["true_shape"])
            else:
                shape = torch.tensor(view["img"].shape[-2:], device=device)
                shapes.append(shape.unsqueeze(0).repeat(batch_size, 1))
        shapes = torch.stack(shapes, dim=0).to(imgs.device)
        # Flatten the (view, batch) leading dims so one encoder call covers everything.
        imgs = imgs.view(-1, *imgs.shape[2:])
        shapes = shapes.view(-1, 2)
        img_masks_flat = img_mask.view(-1)
        selected_imgs = imgs[img_masks_flat]
        selected_shapes = shapes[img_masks_flat]

        if selected_imgs.size(0) == 0:
            raise NotImplementedError(
                "encode_views: img_mask selects no images; encoding an empty batch is not supported"
            )
        img_out, img_pos, _ = self.encode_image(selected_imgs, selected_shapes)

        # Scatter the encoded entries back into zero-filled full-size buffers.
        full_out = [
            torch.zeros(
                len(views) * batch_size, *img_out[0].shape[1:], device=img_out[0].device
            )
            for _ in range(len(img_out))
        ]
        full_pos = torch.zeros(
            len(views) * batch_size,
            *img_pos.shape[1:],
            device=img_pos.device,
            dtype=img_pos.dtype,
        )
        for i in range(len(img_out)):
            full_out[i][img_masks_flat] += img_out[i]
        full_pos[img_masks_flat] += img_pos

        return (
            shapes.chunk(len(views), dim=0),
            [out.chunk(len(views), dim=0) for out in full_out],
            full_pos.chunk(len(views), dim=0),
        )

    def _decoder(self, i, mask_mem, f_mem, pos_mem, f_img, pos_img, f_pose, pos_pose, point3r_tag=False):
        """Run the memory and image decoders in lock-step.

        At each depth the memory block sees ``(f_mem, f_img)`` from the previous
        level, and the image block sees ``(f_img, f_mem)``. Both use the
        outputs of the previous level, not of each other at the same level.

        Returns:
            ``zip(*levels)``: two tuples, the memory states and the image states.
            Index 0 of each holds the decoder inputs, with the image state taken
            *before* ``decoder_embed``. Index k >= 1 holds the output of block k.
            The last level is normalised with ``dec_norm_memory`` / ``dec_norm``.
            When ``pose_head_flag`` is set, image states from index 1 on carry
            the pose token as their first token.
        """
        if isinstance(f_mem, torch.Tensor):
            assert f_mem.shape[-1] == self.dec_embed_dim
        else:
            assert f_mem[-1].shape[-1] == self.dec_embed_dim

        final_output = [(f_mem, f_img)]
        f_img = self.decoder_embed(f_img)
        if self.pose_head_flag:
            assert f_pose is not None
            f_img = torch.cat([f_pose, f_img], dim=1)
        final_output.append((f_mem, f_img))

        for blk_mem, blk_img in zip(self.dec_blocks_memory, self.dec_blocks):
            if (
                self.gradient_checkpointing
                and self.training
                and torch.is_grad_enabled()
            ):
                # [::+1] passes (f_mem, f_img); [::-1] passes (f_img, f_mem).
                f_mem, _ = checkpoint(
                    blk_mem,
                    i,
                    *final_output[-1][::+1],
                    mask_mem,
                    pos_mem,
                    pos_img,
                    point3r_tag,
                    use_reentrant=not self.fixed_input_length,
                )
                f_img, _ = checkpoint(
                    blk_img,
                    i,
                    *final_output[-1][::-1],
                    mask_mem,
                    pos_img,
                    pos_mem,
                    point3r_tag,
                    use_reentrant=not self.fixed_input_length,
                )
            else:
                f_mem, _ = blk_mem(i, *final_output[-1][::+1], mask_mem, pos_mem, pos_img, point3r_tag=point3r_tag)
                f_img, _ = blk_img(i, *final_output[-1][::-1], mask_mem, pos_img, pos_mem, point3r_tag=point3r_tag)
            final_output.append((f_mem, f_img))
        # Drop the post-embedding entry, so index 0 keeps the raw encoder features.
        del final_output[1]
        final_output[-1] = (
            self.dec_norm_memory(final_output[-1][0]),
            self.dec_norm(final_output[-1][1]),
        )
        return zip(*final_output)

    def _downstream_head(self, decout, img_shape, **kwargs):
        """Apply the (landscape-wrapped) downstream head to the decoder outputs."""
        return self.head(decout, img_shape, **kwargs)

    def _init_mem(self, image_tokens, image_pos):
        """Initial memory: project the given encoder tokens with ``decoder_embed_memory``.

        ``image_pos`` is unused. Returns ``(mem_feat, None)``.
        """
        mem_feat = self.decoder_embed_memory(image_tokens)
        return mem_feat, None

    def _recurrent_rollout(
        self,
        i,
        mask_mem,
        mem_feat,
        mem_pos,
        current_feat,
        current_pos,
        pose_feat,
        pose_pos,
        point3r_tag=False,
    ):
        """Decode one view against the memory.

        Returns:
            ``(new_mem_feat, dec)``: the final memory state, and the tuple of
            per-level image states from ``_decoder``.
        """
        new_mem_feat, dec = self._decoder(
            i,
            mask_mem,
            mem_feat, mem_pos, current_feat, current_pos, pose_feat, pose_pos,
            point3r_tag=point3r_tag,
        )
        new_mem_feat = new_mem_feat[-1]
        return new_mem_feat, dec

    def enc_pts_value(self, pts, shape):
        """Encode a (B, H, W, 3) point map into per-patch value tokens."""
        out, pos = self.pts_patch_embed(pts.permute(0, 3, 1, 2), true_shape=shape)
        for block in self.value_encoder:
            out = block(out, pos)
        out = self.value_norm(out)
        out = self.value_out(out)
        return out

    def _patch_centers(self, pts3d):
        """Average a (B, H, W, 3) point map over non-overlapping patches.

        The patch size is taken from ``croco_args["patch_size"]``, so the
        result lines up with the patch tokens. Returns a (B, num_patches, 3)
        tensor in row-major patch order.
        """
        p = self.croco_args["patch_size"]
        bs, img_h, img_w, _ = pts3d.shape
        centers = pts3d.permute(0, 3, 1, 2).unfold(2, p, p).unfold(3, p, p)
        centers = centers.reshape(bs, 3, img_h // p, img_w // p, -1).mean(dim=-1)
        return centers.permute(0, 2, 3, 1).reshape(bs, -1, 3)

    def _new_memory_entries(self, pts3d, feat_i, dec_i, shape_i):
        """Build one view's candidate memory entries.

        Returns:
            ``(mem_add, img_pos)``: float32 memory features (key + value,
            projected with ``decoder_embed_memory``) and the matching patch
            centres from ``_patch_centers``.
        """
        img_pos = self._patch_centers(pts3d)
        feat_key = self.memory_attn_head(torch.cat((feat_i, dec_i), dim=-1))
        feat_pts = self.enc_pts_value(pts3d, shape_i)
        mem_add = self.decoder_embed_memory(feat_key + feat_pts).float()
        return mem_add, img_pos

    def _forward_addmem(
        self,
        i,
        pts3d,
        mem_feat,
        mem_pos,
        feat_i,
        dec_i,
        shape_i,
        ):
        """Append this view's entries to a batched memory, without merging.

        For ``i == 0`` the memory is replaced by the new entries.

        Returns:
            ``(mem_feat, mem_pos, img_pos)``: updated memory features,
            positions, and this view's patch centres.
        """
        mem_add, img_pos = self._new_memory_entries(pts3d, feat_i, dec_i, shape_i)
        if i == 0:
            return mem_add, img_pos, img_pos
        mem_feat = torch.cat((mem_feat, mem_add), dim=1)
        chosen_pts = torch.cat((mem_pos, img_pos), dim=1)
        return mem_feat, chosen_pts, img_pos

    def _forward_addmem_merge(
        self,
        i,
        pts3d,
        mem_feat,
        mem_pos,
        feat_i,
        dec_i,
        shape_i,
        ):
        """Add this view's entries to the memory, merging those near an existing entry.

        For ``i == 0`` the new entries become the memory (batched tensors are
        returned). Otherwise each sample is handled separately:

        - The merge threshold is the norm of (bbox extent of memory and new
          positions) / 20.
        - A new entry whose nearest memory position is closer than the threshold
          is merged. Each affected memory slot gets the mean position and
          feature of the new entries assigned to it; the old values are
          overwritten in place.
        - All other new entries are appended.

        Per-sample lists are returned because memory lengths can differ.

        NOTE: ``mem_pos`` / ``mem_feat`` rows are modified in place.

        Returns:
            ``(mem_feat, mem_pos, img_pos)``. For ``i > 0`` the first two are
            lists with one tensor per sample.
        """
        mem_add, img_pos = self._new_memory_entries(pts3d, feat_i, dec_i, shape_i)
        if i == 0:
            return mem_add, img_pos, img_pos

        len_unit = 20  # the scene's bbox extent is divided into this many units
        bs = pts3d.shape[0]
        mem_feat_list = []
        mem_pos_list = []
        for j in range(bs):
            mem_pos_j = mem_pos[j]
            mem_feat_j = mem_feat[j]
            img_pos_j = img_pos[j]
            new_feat_j = mem_add[j]
            all_pos_j = torch.cat((mem_pos_j, img_pos_j), dim=0)
            unit_j = (all_pos_j.max(dim=0).values - all_pos_j.min(dim=0).values) / len_unit
            threshold_j = torch.norm(unit_j)
            distances = torch.cdist(img_pos_j, mem_pos_j)
            min_dists, min_indices = distances.min(dim=-1)
            mask_add = min_dists >= threshold_j
            mask_merge = ~mask_add
            if mask_merge.sum() > 0:
                indices_merge = min_indices[mask_merge]
                pos_merge = img_pos_j[mask_merge]
                feat_merge = new_feat_j[mask_merge]
                # Group the merged entries by target memory slot and average them.
                unique_indices, inverse_indices = torch.unique(indices_merge, return_inverse=True)
                num_merge = unique_indices.shape[0]
                pos_sum = torch.zeros((num_merge, 3), device=mem_pos_j.device)
                feat_sum = torch.zeros((num_merge, feat_merge.shape[-1]), device=mem_feat_j.device)
                count = torch.zeros((num_merge, 1), device=mem_feat_j.device)
                pos_sum.index_add_(0, inverse_indices, pos_merge)
                feat_sum.index_add_(0, inverse_indices, feat_merge)
                count.index_add_(0, inverse_indices, torch.ones_like(inverse_indices, dtype=torch.float32).unsqueeze(1))
                pos_avg = (pos_sum / count).float()
                feat_avg = (feat_sum / count).float()
                mem_pos_j[unique_indices] = pos_avg
                mem_feat_j[unique_indices] = feat_avg
            if mask_add.sum() > 0:
                mem_pos_j = torch.cat([mem_pos_j, img_pos_j[mask_add]], dim=0)
                mem_feat_j = torch.cat([mem_feat_j, new_feat_j[mask_add]], dim=0)
            mem_feat_list.append(mem_feat_j)
            mem_pos_list.append(mem_pos_j)

        return mem_feat_list, mem_pos_list, img_pos

    @staticmethod
    def _pad_memory(mem_feat, mem_pos):
        """Right-pad per-sample memories to a common length and stack them.

        Args:
            mem_feat: indexable sequence of 2-D feature tensors (L_j, C).
            mem_pos: indexable sequence of 2-D position tensors (L_j, P),
                aligned with ``mem_feat``.

        Returns:
            ``(feat, pos, mask)``: (B, L_max, C) features, (B, L_max, P)
            positions, and a (B, L_max) float mask that is 1 for real entries
            and 0 for padding.
        """
        mem_len_max = max(f_mem_j.shape[0] for f_mem_j in mem_feat)
        f_mem_list_padded = []
        pos_mem_list_padded = []
        mask_mem_list_padded = []
        for j, f_mem_j in enumerate(mem_feat):
            pos_mem_j = mem_pos[j]
            padding_size = mem_len_max - f_mem_j.shape[0]
            padding = torch.zeros(padding_size, f_mem_j.shape[1], device=f_mem_j.device)
            padding_pos = torch.zeros(padding_size, pos_mem_j.shape[1], device=pos_mem_j.device)
            mask_valid = torch.ones(f_mem_j.shape[0], device=f_mem_j.device)
            mask_invalid = torch.zeros(padding_size, device=f_mem_j.device)
            f_mem_list_padded.append(torch.cat((f_mem_j, padding), dim=0))
            pos_mem_list_padded.append(torch.cat((pos_mem_j, padding_pos), dim=0))
            mask_mem_list_padded.append(torch.cat((mask_valid, mask_invalid), dim=0))
        return (
            torch.stack(f_mem_list_padded, dim=0),
            torch.stack(pos_mem_list_padded, dim=0),
            torch.stack(mask_mem_list_padded, dim=0),
        )

    @staticmethod
    def _unpad_memory(mem_feat, mem_pos, mask_mem):
        """Inverse of ``_pad_memory``: keep only the entries marked valid in ``mask_mem``.

        Returns:
            Two lists, ``(features, positions)``, with one tensor per sample.
        """
        keep = mask_mem.bool()
        feats = [mem_feat[j][keep[j]] for j in range(keep.shape[0])]
        poss = [mem_pos[j][keep[j]] for j in range(keep.shape[0])]
        return feats, poss

    def forward(self, views, point3r_tag=False):
        """Reconstruct the views sequentially against a growing memory.

        Args:
            views: list of view dicts; see ``encode_views``. ``"true_shape"`` is
                required when ``point3r_tag`` is True.
            point3r_tag: if True, the memory is updated after each view with
                ``_forward_addmem_merge``, using the head's
                ``"pts3d_in_other_view"``.

        Returns:
            ARCroco3DStereoOutput with one head result per view.

        NOTE(review): from the third view on, the memory is padded per sample.
        This expects ``pos_decode_mem`` to have been filled by the
        ``point3r_tag`` update, so ``point3r_tag=False`` with more than two
        views fails in ``_pad_memory``.
        """
        shape, feat_ls, pos = self.encode_views(views)
        feat = feat_ls[-1]
        mem_feat, _ = self._init_mem(feat[0], pos[0])
        ress = []
        # Positions used for the *next* view's image tokens. They are set from
        # the current view's patch centres by the memory update below.
        pos_decode_img = None
        pos_decode_mem = None
        # Decoder levels >= 1 start with the pose token only when a pose head exists.
        n_pose_tokens = 1 if self.pose_head_flag else 0
        for i in range(len(views)):
            feat_i = feat[i]
            pos_i = pos[i]
            if i >= 2:
                # The memory is a per-sample list of variable length; batch it.
                mem_feat, pos_decode_mem, mask_mem = self._pad_memory(mem_feat, pos_decode_mem)
            else:
                mask_mem = None

            if self.pose_head_flag:
                pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
            else:
                pose_feat_i = None
            pose_pos_i = None

            new_mem_feat, dec = self._recurrent_rollout(
                i,
                mask_mem,
                mem_feat,
                pos_decode_mem,
                feat_i,
                pos_decode_img,
                pose_feat_i,
                pose_pos_i,
                point3r_tag=point3r_tag,
            )

            assert len(dec) == self.dec_depth + 1
            head_input = [
                dec[0].float(),
                dec[self.dec_depth * 2 // 4][:, n_pose_tokens:].float(),
                dec[self.dec_depth * 3 // 4][:, n_pose_tokens:].float(),
                dec[self.dec_depth].float(),
            ]

            res = self._downstream_head(head_input, shape[i], pos=pos_i)
            ress.append(res)

            # The update mask is all False, so mem_feat keeps its value and the
            # decoder's memory output is weighted by zero.
            # NOTE(review): presumably this keeps new_mem_feat in the autograd
            # graph (e.g. to avoid unused-parameter errors); confirm.
            update_mask_mem = torch.tensor([False] * mem_feat.shape[0], device=mem_feat.device)
            update_mask_mem = update_mask_mem[:, None, None].float()
            mem_feat = new_mem_feat * update_mask_mem + mem_feat * (1 - update_mask_mem)

            if mask_mem is not None:
                mem_feat, pos_decode_mem = self._unpad_memory(mem_feat, pos_decode_mem, mask_mem)

            if point3r_tag:
                this_pts3d = res['pts3d_in_other_view'].clone().detach()
                if pos_decode_mem is not None:
                    if isinstance(pos_decode_mem, torch.Tensor):
                        pos_decode_mem = pos_decode_mem.clone().detach()
                    else:
                        pos_decode_mem = [p.clone().detach() for p in pos_decode_mem]
                mem_feat, pos_decode_mem, pos_decode_img = self._forward_addmem_merge(
                    i,
                    pts3d=this_pts3d,
                    mem_feat=mem_feat,
                    mem_pos=pos_decode_mem,
                    feat_i=feat_i.clone().detach(),
                    dec_i=dec[-1][:, n_pose_tokens:].clone().detach(),
                    shape_i=views[i]['true_shape'],
                )

        return ARCroco3DStereoOutput(ress=ress, views=views)


    