import flax.linen as nn
import typing
from typing import List, Dict
import jax.numpy as jnp
import numpy as np
import jax
import einops
import importlib
from flax.core.frozen_dict import FrozenDict
from flax import struct
import pickle
import time
import os, sys
from dataclasses import replace
from functools import partial

# Put the parent of this file's directory at the front of sys.path (once) so the
# absolute `util.*` / `dinov2_jax.*` imports below can be resolved from there.
BASEDIR = os.path.dirname(os.path.dirname(__file__))
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)

# import util.cvx_util as cxutil
import util.latent_obj_util as loutil
import util.ev_util.ev_util as eutil
import util.ev_util.ev_layers as evl
import util.ev_util.rotm_util as rmutil
import util.transform_util as tutil
import util.camera_util as cutil
import util.diffusion_util as dfutil
import util.structs as structs
from dinov2_jax.vit import DinoViT
from dinov2_jax.dino_weights import dino_config_by_size

# NOTE(review): not referenced in the visible part of this file; presumably a lower
# bound for the diffusion c_skip coefficient -- confirm against its users.
C_SKIP_MIN_VAL = 0.1

class ImageFeature(nn.Module):
    '''
    Convolutional encoder / decoder producing a per-pixel feature map.

    NOTE: another class named ImageFeature is defined further down in this module;
    that later definition rebinds the name, so this version is shadowed after import.

    x : (... NI NJ C) image. uint8 / int16 / int32 inputs are cast to float32 and divided by 255.
    train : unused.
    return : (... NI NJ 16*base_dim) -- concatenation of four 4*base_dim maps taken at
        1/8, 1/4, 1/2 and full resolution, each repeated element-wise (x8, x4, x2, x1).
        # assumes NI and NJ are divisible by 8 so the skip additions line up -- TODO confirm
    '''
    base_dim:int=8
    depth:int=1

    @nn.compact
    def __call__(self, x, train=False):

        def conv_stack(width, ksize, n_layers, feat):
            # n_layers x (Conv -> ReLU) at constant width
            for _ in range(n_layers):
                feat = nn.relu(nn.Conv(width, (ksize,ksize))(feat))
            return feat

        def project_and_repeat(feat, factor, out_dim):
            # per-pixel projection followed by element repetition along both spatial axes
            feat = nn.relu(nn.Dense(out_dim)(feat))
            return einops.repeat(feat, '... i j k -> ... (i r1) (j r2) k', r1=factor, r2=factor)

        if x.dtype in [jnp.uint8, jnp.int16, jnp.int32]:
            x = x.astype(jnp.float32)/255.

        # stem: two 5x5 convolutions
        x = conv_stack(self.base_dim, 5, 2, x)

        # encoder: keep the pre-downsampling activation of every stage as a skip
        skips = []
        for mult in (2, 4, 8):
            x = conv_stack(mult*self.base_dim, 3, self.depth, x)
            skips.append(x)
            x = nn.relu(nn.Conv(mult*self.base_dim, (3,3), strides=(2,2))(x))

        # bottleneck: dilated 3x3 and 5x5 convolutions, then a plain conv stack
        for ksize in (3, 5):
            x = nn.relu(nn.Conv(8*self.base_dim, (ksize,ksize), kernel_dilation=(2,2))(x))
        x = conv_stack(8*self.base_dim, 3, self.depth, x)

        # decoder: emit a full-resolution feature at every scale while upsampling
        out_dim = 4*self.base_dim
        pyramid = []
        for mult, factor, skip in zip((8, 4, 2), (8, 4, 2), reversed(skips)):
            pyramid.append(project_and_repeat(x, factor, out_dim))
            x = nn.relu(nn.ConvTranspose(mult*self.base_dim, (3,3), strides=(2,2))(x)) + skip
        pyramid.append(project_and_repeat(x, 1, out_dim))

        return jnp.concatenate(pyramid, axis=-1)

class ImageFeatureEntire(nn.Module):
    '''
    DINOv2 ViT patch tokens refined by a small convolutional head.

    dino_size : one of 's', 'b', 'l'; selects the ViT configuration and the head width
        (160 channels for 's', 256 for 'b' / 'l').

    __call__
        imgs : (... NI NJ C). uint8 / int16 / int32 inputs are cast to float32 and divided by 255.
        cam_posquat : stored unchanged in the returned struct.
        cam_intrinsic : rescaled from the input image size to the output feature-map size.
        train : when True, stop_gradient is applied to the ViT outputs.
        return : structs.ImgFeatures with intrinsic, cam_posquat, img_feat and rgb (= imgs).

    raises ValueError if dino_size is not one of the supported values.
    '''
    dino_size:str='s'
    
    @nn.compact
    def __call__(self, imgs, cam_posquat, cam_intrinsic, train=False):

        if imgs.dtype == jnp.uint8 or imgs.dtype == jnp.int16 or imgs.dtype == jnp.int32:
            imgs = imgs.astype(jnp.float32)/255.
        
        # images whose absolute pixel sum is (almost) zero are treated as empty;
        # their features are zeroed below
        valid_img_mask = jnp.sum(jnp.abs(imgs), axis=(-1,-2,-3), keepdims=True) > 1e-3

        # the ViT input is resized so that both sides are multiples of 14 (rounded up)
        original_img_size = imgs.shape[-3:-1]
        patch_size = [int(np.ceil(s/14)) for s in original_img_size]
        img_size = [int(s*14) for s in patch_size]

        # cam_intrinsic_resized = cutil.resize_intrinsic(cam_intrinsic, original_img_size, img_size)
        imgs_resized = cutil.resize_img(imgs, img_size, method='linear', base_operator='jnp')

        # load predefined DINOv2 ViT model
        num_heads, embed_dim, depth = dino_config_by_size(self.dino_size)
        img_feat = DinoViT(num_heads=num_heads,
            embed_dim=embed_dim,
            depth=depth,
            mlp_ratio=4,
            img_size=img_size,
            register_tokens=True)(imgs_resized, out_depth=None, train=False)
        if train:
            img_feat = jax.lax.stop_gradient(img_feat)
        img_feat = img_feat['x_norm_patchtokens']

        # unflatten the token axis into the 2D patch grid
        img_feat = img_feat.reshape(img_feat.shape[:-2]+(*patch_size, img_feat.shape[-1]))
        img_feat = jnp.where(valid_img_mask, img_feat, 0.0)

        if self.dino_size == 's':
            img_feat_dim = 160
        elif self.dino_size == 'b' or self.dino_size == 'l':
            img_feat_dim = 256
        else:
            # previously fell through and failed later with an unbound img_feat_dim
            raise ValueError(f"unsupported dino_size '{self.dino_size}'; expected one of ['s', 'b', 'l']")
        for i in range(3):
            img_feat = nn.Conv(img_feat_dim, (5,5), padding='SAME')(img_feat)
            img_feat = nn.gelu(img_feat)
            if i==0:
                # output of the first conv is re-added after the dilated convs
                skip_feat = img_feat
        
        img_feat = nn.Conv(img_feat_dim, (3,3), kernel_dilation=(2,2))(img_feat)
        img_feat = nn.gelu(img_feat)
        img_feat = nn.Conv(img_feat_dim, (3,3), kernel_dilation=(2,2))(img_feat)
        img_feat = nn.gelu(img_feat)
        img_feat = img_feat + skip_feat

        # up img feature resolution (stride-2 transposed conv)
        img_feat = nn.ConvTranspose(img_feat_dim, (3,3), strides=(2,2), padding='SAME')(img_feat)
        img_feat = nn.gelu(img_feat)

        img_feat = jnp.where(valid_img_mask, img_feat, 0.0)

        # intrinsics must describe the feature-map resolution, not the input image resolution
        cam_intrinsic_resized = cutil.resize_intrinsic(cam_intrinsic, original_img_size, img_feat.shape[-3:-1])
        img_feat_structs = structs.ImgFeatures(intrinsic=cam_intrinsic_resized, cam_posquat=cam_posquat, img_feat=img_feat)
        # img_feat_structs = SpatialPE()(img_feat_structs)

        img_feat_structs = img_feat_structs.replace(rgb=imgs)

        return img_feat_structs

class ResConvUnit(nn.Module):
    '''
    Pre-activation residual unit: x + Conv(gelu(Conv(gelu(x)))).
    Both 3x3 'SAME' convolutions keep the channel count of the input.
    '''
    @nn.compact
    def __call__(self, x):
        channels = x.shape[-1]
        residual = x
        for _ in range(2):
            residual = nn.Conv(channels, (3,3), padding='SAME')(nn.gelu(residual))
        return x + residual

class DPT(nn.Module):
    """Decoder for the DPT model.

    encoder_features is indexed along axis -4 to obtain one (..., H, W, C) map per layer.
    Each map is projected to `decoder_channel` channels with a 1x1 convolution and
    linearly upsampled by the matching entry of [4, 4, 2, 1]. The maps are then fused
    from the last layer to the first with transposed convolutions and residual units.
    `train` is unused.
    """
    # decoder_channels: List[int]
    decoder_channel:int = 256

    @nn.compact
    def __call__(self, encoder_features:jnp.ndarray, train: bool = True):
        # per-layer upsampling factors; more than four layers would index past this list
        scales = [4, 4, 2, 1]
        n_layers = encoder_features.shape[-4]

        pyramid = []
        for layer_idx in range(n_layers):
            factor = scales[layer_idx]
            level = nn.Conv(features=self.decoder_channel, kernel_size=(1, 1))(encoder_features[...,layer_idx,:,:,:])
            if factor > 1:
                *lead, height, width, channels = level.shape
                level = jax.image.resize(level, (*lead, height*factor, width*factor, channels), method='linear')
            pyramid.append(level)

        # start from the last (lowest-resolution) level and walk towards the first one
        x = ResConvUnit()(pyramid[-1])
        for idx in range(len(pyramid) - 2, -1, -1):
            # upsample x to the spatial resolution of pyramid[idx]
            stride = scales[idx]//scales[idx + 1]
            x = nn.ConvTranspose(
                features=self.decoder_channel,
                kernel_size=(3, 3),
                strides=(stride, stride),
                padding='SAME',
            )(x)
            x = nn.relu(x)
            # merge with the lateral feature, then refine
            lateral = ResConvUnit()(pyramid[idx])
            x = ResConvUnit()(x + lateral)
        return x


class ImageFeatureEntireV2(nn.Module):
    '''
    Multi-layer DINOv2 ViT tokens (out_depth=[2,5,8,11]) projected by a single Dense layer.

    dino_size : one of 's', 'b', 'l'; selects the ViT configuration and the output width
        (512 channels for 's', 1024 for 'b' / 'l').

    __call__
        imgs : (... NI NJ C). uint8 / int16 / int32 inputs are cast to float32 and divided by 255.
        cam_posquat : stored unchanged in the returned struct.
        cam_intrinsic : rescaled from the input image size to the patch-grid size.
        dino_feat_out : when True, the un-projected token map is also returned as dino_feat.
        train : when True, stop_gradient is applied to the ViT outputs.
        return : structs.ImgFeatures with intrinsic, cam_posquat, img_feat, dino_feat and rgb (= imgs).

    raises ValueError if dino_size is not one of the supported values.
    '''
    dino_size:str='s'
    
    @nn.compact
    def __call__(self, imgs, cam_posquat, cam_intrinsic, dino_feat_out=False, train=False):

        if imgs.dtype == jnp.uint8 or imgs.dtype == jnp.int16 or imgs.dtype == jnp.int32:
            imgs = imgs.astype(jnp.float32)/255.
        
        # images whose absolute pixel sum is (almost) zero are treated as empty;
        # their features are zeroed below
        valid_img_mask = jnp.sum(jnp.abs(imgs), axis=(-1,-2,-3), keepdims=True) > 1e-3

        # the ViT input is resized so that both sides are multiples of 14 (rounded up)
        original_img_size = imgs.shape[-3:-1]
        patch_size = [int(np.ceil(s/14)) for s in original_img_size]
        img_size = [int(s*14) for s in patch_size]

        # cam_intrinsic_resized = cutil.resize_intrinsic(cam_intrinsic, original_img_size, img_size)
        imgs_resized = cutil.resize_img(imgs, img_size, method='linear')

        # load predefined DINOv2 ViT model
        num_heads, embed_dim, depth = dino_config_by_size(self.dino_size)
        dino_tokens = DinoViT(num_heads=num_heads,
            embed_dim=embed_dim,
            depth=depth,
            img_size=img_size,
            register_tokens=True)(imgs_resized, out_depth=[2,5,8,11], train=False)
        if train:
            dino_tokens = jax.lax.stop_gradient(dino_tokens)
        patch_tokens = dino_tokens['x_norm_patchtokens']
        cls_tokens = dino_tokens['x_norm_clstoken'][...,-1,:] # last cls tokens

        # # option 1 - unflatten patch
        # patch_tokens = nn.Dense(7*7*64)(patch_tokens)
        # img_feat = einops.rearrange(patch_tokens, '... i j (r1 r2 k) -> ... j r1 r2 (i k)', r1=7, r2=7)
        # img_feat = einops.rearrange(img_feat, '... (p1 p2) r1 r2 k -> ... (p1 r1) (p2 r2) k', p1=patch_size[0], p2=patch_size[1])

        # option 2 - linear interpolation
        # fold axis -3 into the channel axis
        # (presumably the per-layer axis produced by out_depth -- TODO confirm)
        img_feat = einops.rearrange(patch_tokens, '... i j k -> ... j (i k)')
        # append the cls token to every patch token
        cls_tokens = cls_tokens[...,None,:].repeat(img_feat.shape[-2], axis=-2)
        img_feat = jnp.c_[img_feat, cls_tokens]
        
        # unflatten the token axis into the 2D patch grid
        img_feat = img_feat.reshape(img_feat.shape[:-2]+(*patch_size, img_feat.shape[-1]))

        if dino_feat_out:
            dino_feat = img_feat
        else:
            dino_feat = None
            
        if self.dino_size == 's':
            img_feat_dim = 512
        elif self.dino_size == 'b' or self.dino_size == 'l':
            img_feat_dim = 1024
        else:
            # previously fell through and failed on the next line with an unbound img_feat_dim
            raise ValueError(f"unsupported dino_size '{self.dino_size}'; expected one of ['s', 'b', 'l']")
        img_feat = nn.Dense(img_feat_dim, use_bias=True)(img_feat)

        # img_feat = jax.image.resize(img_feat, (*img_feat.shape[:-3], 2*img_feat.shape[-3], 2*img_feat.shape[-2], img_feat.shape[-1]), method='linear')

        # img_feat = nn.LayerNorm()(img_feat)
        # _ = nn.LayerNorm()(img_feat) # calculated but not used
        
        img_feat = jnp.where(valid_img_mask, img_feat, 0.0)
        # intrinsics must describe the feature-map resolution, not the input image resolution
        cam_intrinsic_resized = cutil.resize_intrinsic(cam_intrinsic, original_img_size, img_feat.shape[-3:-1])
        img_feat_structs = structs.ImgFeatures(intrinsic=cam_intrinsic_resized, cam_posquat=cam_posquat, img_feat=img_feat, dino_feat=dino_feat)
        img_feat_structs = img_feat_structs.replace(rgb=imgs)
        return img_feat_structs


class ImageFeature(nn.Module):
    '''
    Convolutional encoder / decoder producing a per-pixel feature map.

    Decoder activations are resized (cutil.resize_img) to the skip-tensor sizes, and every
    pyramid level is resized to the input image size before concatenation.

    x : (... NI NJ C) image. uint8 / int16 / int32 inputs are cast to float32 and divided by 255.
    train : unused.
    return : (... NI NJ 16*base_dim) -- concatenation of four 4*base_dim maps.
    '''
    base_dim:int=8
    depth:int=1

    @nn.compact
    def __call__(self, x, train=False):
        img_size = x.shape[-3:-1]

        def conv_stack(width, ksize, n_layers, feat):
            # n_layers x (Conv -> ReLU) at constant width
            for _ in range(n_layers):
                feat = nn.relu(nn.Conv(width, (ksize,ksize))(feat))
            return feat

        def project_to_full_res(feat, out_dim):
            # per-pixel projection followed by a resize to the input resolution
            feat = nn.relu(nn.Dense(out_dim)(feat))
            return cutil.resize_img(feat, img_size)

        if x.dtype in [jnp.uint8, jnp.int16, jnp.int32]:
            x = x.astype(jnp.float32)/255.

        # stem: two 5x5 convolutions
        x = conv_stack(self.base_dim, 5, 2, x)

        # encoder: keep the pre-downsampling activation of every stage as a skip
        skips = []
        for mult in (2, 4, 8):
            x = conv_stack(mult*self.base_dim, 3, self.depth, x)
            skips.append(x)
            x = nn.relu(nn.Conv(mult*self.base_dim, (3,3), strides=(2,2))(x))

        # bottleneck: dilated 3x3 and 5x5 convolutions, then a plain conv stack
        for ksize in (3, 5):
            x = nn.relu(nn.Conv(8*self.base_dim, (ksize,ksize), kernel_dilation=(2,2))(x))
        x = conv_stack(8*self.base_dim, 3, self.depth, x)

        # decoder: emit a full-resolution feature at every scale while upsampling
        out_dim = 4*self.base_dim
        pyramid = []
        for mult, skip in zip((8, 4, 2), reversed(skips)):
            pyramid.append(project_to_full_res(x, out_dim))
            x = nn.ConvTranspose(mult*self.base_dim, (3,3), strides=(2,2))(x)
            # the transposed conv output is resized to the skip tensor size before the add
            x = cutil.resize_img(nn.relu(x), skip.shape[-3:-1]) + skip
        pyramid.append(project_to_full_res(x, out_dim))

        return jnp.concatenate(pyramid, axis=-1)


class ImageFeatureEntireCNN(nn.Module):
    '''
    CNN-only counterpart of the DINO based extractors: runs ImageFeature on the images
    and packs the result into structs.ImgFeatures.

    imgs : uint8 / int16 / int32 inputs are cast to float32 and divided by 255.
    cam_posquat : stored unchanged in the returned struct.
    cam_intrinsic : rescaled from the input image size to the feature-map size.
    dino_feat_out : unused.
    '''
    @nn.compact
    def __call__(self, imgs, cam_posquat, cam_intrinsic, train=False, dino_feat_out=False):

        if imgs.dtype in [jnp.uint8, jnp.int16, jnp.int32]:
            imgs = imgs.astype(jnp.float32)/255.

        input_hw = imgs.shape[-3:-1]
        feat_map = ImageFeature()(imgs, train=train)

        # intrinsics have to match the resolution of the returned feature map
        intrinsic_at_feat_res = cutil.resize_intrinsic(cam_intrinsic, input_hw, feat_map.shape[-3:-1])
        out = structs.ImgFeatures(intrinsic=intrinsic_at_feat_res, cam_posquat=cam_posquat, img_feat=feat_map)
        # out = SpatialPE()(out)

        return out.replace(rgb=imgs)


class SpatialPE(nn.Module):
    '''
    Per-pixel positional encoding built from points sampled along each pixel's camera ray.

    For every pixel, 40 points between near=0.05 and far=1.5 are taken along the ray,
    flattened to one vector and passed through Dense -> gelu -> Dense with the same width
    as img_feat. The result is stored in the `spatial_PE` field. `train` is unused.
    '''
    @nn.compact
    def __call__(self, img_feat: structs.ImgFeatures, train=False) -> structs.ImgFeatures:
        near, far, nsamples = 0.050, 1.5, 40

        feat_hw = img_feat.img_feat.shape[-3:-1]
        feat_dim = img_feat.img_feat.shape[-1]
        posquat = img_feat.cam_posquat # (... NV 7)
        intrinsic = img_feat.intrinsic # (... NV 6)

        ray_origin, _, ray_dir = cutil.pixel_ray(feat_hw, posquat[...,:3], posquat[...,3:], intrinsic, near=near, far=far)

        # sample points along every ray, then flatten (sample, xyz) into one feature axis
        depths = jnp.linspace(near, far, nsamples)[...,None]
        samples = ray_origin[...,None,:] + ray_dir[...,None,:]*depths
        samples = einops.rearrange(samples, '... i j -> ... (i j)')

        encoding = nn.Dense(feat_dim)(samples)
        encoding = nn.Dense(feat_dim)(nn.gelu(encoding))

        return img_feat.replace(spatial_PE=encoding)


def extract_pixel_features(center, img_fts_structs:structs.ImgFeatures, feat_type='all'):
    '''
    Project 3D query points into every camera and gather interpolated pixel features.

    center : query points, last axis xyz.
        # NOTE(review): originally documented as (... NO 3), but the shape comments below
        # and the 4-axis einops patterns imply (... NO NG 3) -- confirm with callers
    img_fts_structs :
        img_feat / spatial_PE : (... NC NI NJ NF)
        cam_posquat : (... NC 7)
        intrinsic : (... NC 6)
    feat_type : which map to sample when spatial_PE is present
        'all'      -> img_feat + spatial_PE
        'spatial'  -> spatial_PE
        'img_feat' -> img_feat
        When spatial_PE is None, img_feat is always used and 'spatial' is rejected by an assert.

    return :
        img_fts : (... NO NG NC NF) features blended from the four neighbouring pixels
        cam_fts : (... 1 NC NF') camera position, flattened rotation matrix and normalized
            intrinsic terms

    raises ValueError for an unknown feat_type (only checked when spatial_PE is present).
    '''
    # # project points to img space
    # center = jax.lax.stop_gradient(center)

    # cam_posquat, intrinsic = cam_info
    cam_posquat = img_fts_structs.cam_posquat
    intrinsic = img_fts_structs.intrinsic
    # img_fts = img_fts_structs.img_feat
    if img_fts_structs.spatial_PE is None:
        img_fts = img_fts_structs.img_feat
        assert feat_type != 'spatial'
    elif img_fts_structs.spatial_PE is not None and feat_type == 'all':
        img_fts = img_fts_structs.img_feat + img_fts_structs.spatial_PE
    elif feat_type == 'spatial':
        img_fts = img_fts_structs.spatial_PE
    elif feat_type == 'img_feat':
        img_fts = img_fts_structs.img_feat
    else:
        raise ValueError('feat_type should be one of [all, spatial, img_feat]')

    intrinsic_ext = intrinsic[...,None,:,:] # (... 1 NC 6)
    input_img_size = intrinsic_ext[...,:2][...,::-1] #xy -> ij size
    cam_posquat_ext = cam_posquat[...,None,:,:]
    cam_pos_quat_ext = (cam_posquat_ext[...,:3], cam_posquat_ext[...,3:]) # (... NR, ...)
    cam_Rm_ext = einops.rearrange(tutil.q2R(cam_pos_quat_ext[1]), '... i j -> ... (i j)')

    img_ft_size = jnp.array((img_fts.shape[-3], img_fts.shape[-2]))
    img_fts_flat = einops.rearrange(img_fts, '... i j k -> ... (i j) k') # (... NC NIJ NF)
    
    # batch version: insert a singleton axis before the last axis of every leaf.
    # jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
    intrinsic_ext_, cam_posquat_, input_img_size_, img_fts_flat_ = \
        jax.tree_util.tree_map(lambda x: x[...,None,:], (intrinsic_ext, cam_pos_quat_ext, input_img_size, img_fts_flat))
    # img_fts_flat_ : (... NC NIJ 1 NF)
    # q_pnts = local_points_ + qp_dirs # (... NO NG 3)
    q_pnts = center[...,None,:,:]
    px_coord_ctn, out_pnts_indicator = cutil.global_pnts_to_pixel(intrinsic_ext_, cam_posquat_, q_pnts) # (... NO NC NG 2)
    # rescale continuous pixel coordinates from the intrinsic's image size to the feature-map size
    px_coord_ctn = px_coord_ctn/input_img_size_ * img_ft_size
    px_coord = jnp.floor(px_coord_ctn).astype(jnp.int32)
    px_coord = jax.lax.stop_gradient(px_coord)

    # interpolation
    # px_coord_residual = (px_coord_ctn - px_coord) # (... NO NC NG 2)
    def extract_img_fts(px_coord:jnp.ndarray):
        # gather the feature vector at integer pixel coordinates (clipped to the map bounds)
        px_coord = px_coord.clip(0, jnp.array(img_ft_size)-1)
        px_flat_idx = px_coord[...,1] + px_coord[...,0] * img_ft_size[...,1] # (... NO NC)
        selected_img_fts = jnp.take_along_axis(img_fts_flat_.squeeze(-2), einops.rearrange(px_flat_idx, '... i j k -> ... j (i k)')[...,None], axis=-2) # (... NC NO*NG NF)
        selected_img_fts = einops.rearrange(selected_img_fts, '... (r i) j -> ... r i j', r=px_coord.shape[-4]) # (... NC NO NG NF)
        img_fts = einops.rearrange(selected_img_fts, '... i j p k -> ... j i p k') # (... NO NC NG NF)
        return img_fts
    
    # blend the four neighbouring pixels; each weight is the product of the per-axis
    # distances to the opposite corner (+1e-2), normalized to sum to one
    img_fts_list = []
    residuals = []
    px_coord_ctn_offset = (px_coord_ctn -0.5).clip(0, jnp.array(img_ft_size)-1)
    for sft in [np.array([0,0]), np.array([0,1]), np.array([1,0]), np.array([1,1])]:
        img_fts_list.append(extract_img_fts(px_coord+sft))
        resd_ = jnp.abs(px_coord_ctn_offset - (px_coord+(1-sft))) + 1e-2
        resd_ = resd_[...,0:1] * resd_[...,1:2]
        residuals.append(resd_)
    weights = jnp.stack(residuals, axis=-1)
    weights = weights/jnp.sum(weights, axis=-1, keepdims=True) # (NO NC NG 1 4)
    img_fts = jnp.sum(jnp.stack(img_fts_list, axis=-1) * weights, -1) # (NO NC NG NF)
    img_fts = einops.rearrange(img_fts, '... i j p k -> ... i p j k') # (... NO NG NC NF)
    
    # NOTE(review): the same intrinsic slice [2:3] is concatenated twice; verify whether
    # the second entry was meant to be a different component
    intrinsic_cond = jnp.concatenate([intrinsic_ext[...,2:3], intrinsic_ext[...,2:3]], axis=-1)
    cam_fts = jnp.concatenate([cam_pos_quat_ext[0], cam_Rm_ext, intrinsic_cond/intrinsic_ext[...,1:2]], axis=-1) # (... 1 NC NF)

    return img_fts, cam_fts


class AdaLayerNorm(nn.Module):
    '''
    Adaptive layer norm: LayerNorm(x) modulated by a scale / shift predicted from `context`.

    Related work:
        FiLM (FiLM: Visual Reasoning with a General Conditioning Layer)
        AdaGN in (https://proceedings.neurips.cc/paper_files/paper/2021/hash/49ad23d1ec9fa4bd8d77d02681df5cfa-Abstract.html)
        DiT

    Two modes, selected by comparing the ranks of x and context:
        same rank      : (scale, shift) = split(Dense(context));  out = LN(x) * (1 + scale) + shift
        different rank : every x vector attends over context with a 2-head attention;
                         out = LN(x) * tanh(scale) + shift
    '''

    @nn.compact
    def __call__(self, x, context):
        feat_dim = x.shape[-1]

        if x.ndim != context.ndim:
            # each x vector becomes a length-1 query sequence attending over the context
            query = nn.LayerNorm()(x[...,None,:])
            attention = nn.MultiHeadAttention(num_heads=2, qkv_features=feat_dim//2, out_features=2*feat_dim)
            modulation = attention(query, context, deterministic=True).squeeze(-2)
            # modulation = nn.Dense(2*feat_dim)(context)
            scale, shift = jnp.split(modulation, 2, -1)
            return nn.LayerNorm()(x) * nn.tanh(scale) + shift

        modulation = nn.Dense(2*feat_dim)(context)
        scale, shift = jnp.split(modulation, 2, -1)
        return nn.LayerNorm()(x) * (1 + scale) + shift



# Resnet Blocks
class ResnetBlockFC(nn.Module):
    '''
    Fully connected pre-activation residual block.

    out = shortcut(x) + Dense_out(relu(Dense_h(relu(x))))
    The shortcut is the identity when the input width equals size_out, otherwise a Dense layer.
    `args` is not used inside this block.
    '''
    args:typing.NamedTuple
    size_h:int
    size_out:int

    @nn.compact
    def __call__(self, x):
        hidden = nn.Dense(self.size_h)(nn.relu(x))
        residual = nn.Dense(self.size_out)(nn.relu(hidden))
        # project the input only when its width differs from the output width
        shortcut = x if x.shape[-1] == self.size_out else nn.Dense(self.size_out)(x)
        return shortcut + residual

class Aggregator(nn.Module):
    '''
    Set aggregation over one axis: per-element Dense layers mixed with a pooled summary.

    base_dim : width of every Dense layer
    depth : number of mixing rounds
    axis : axis holding the set elements; it is swapped to position -2 while processing
    pooled_out : if True, the set axis is reduced with pooling_func on return;
        otherwise the per-element features are returned
    pooling_func : reduction called as pooling_func(x, axis=..., keepdims=...), e.g. jnp.max
    final_activation : apply gelu after the last Dense layer as well
    '''
    # args:typing.NamedTuple
    base_dim:int
    depth:int=2
    axis:int=-2
    pooled_out:bool=True
    # the field holds a callable (default jnp.max); it was previously annotated as str
    pooling_func:typing.Callable=jnp.max
    final_activation:bool=True

    @nn.compact
    def __call__(self, x):
        x = jnp.swapaxes(x, -2, self.axis)
        
        # residual input, projected only when the width differs from base_dim
        if x.shape[-1] == self.base_dim:
            skip = x
        else:
            skip = nn.Dense(self.base_dim)(x)

        for k in range(self.depth):
            if k!=0:
                x = nn.LayerNorm()(x)
            x = nn.gelu(nn.Dense(self.base_dim)(x)) + skip
            if x.shape[-2] != 1:
                # append the pooled set summary to every element (skipped for singleton sets)
                x_pooled = self.pooling_func(x, axis=-2, keepdims=True)
                x = jnp.c_[x, einops.repeat(x_pooled, '... r i -> ... (k r) i', k=x.shape[-2])]
            x = nn.Dense(self.base_dim)(x)
            if k<self.depth-1 or self.final_activation:
                x = nn.gelu(x)
        
        x = jnp.swapaxes(x, self.axis, -2)
        if self.pooled_out:
            return self.pooling_func(x, axis=self.axis)
        else:
            return x


class TransformerDecoder(nn.Module):
    # args is read by attribute (mixing_head, base_dim), so it is a config object
    # rather than a plain sequence as annotated.
    args:typing.Sequence
    # dropout rate applied to every residual branch (active only when train=True)
    dropout:float
    # when True, the second (flattened) self-attention stage is skipped
    local_only:bool=False
    
    @nn.compact
    def __call__(self, z, context, img_feat_flat=None, train=False):
        '''
        One decoder block: self-attention over axis -2 of z, an optional second self-attention
        over the flattened (s c) axes, and a two-layer feed-forward. Each residual sum is
        followed by AdaLayerNorm conditioned on `context`.

        If local_only is turned on, only the first attention layer is applied.
            SPS - dif_model_version 1 - self Attention between z with context (extracted img fts) conditioning

        img_feat_flat : must be None when local_only is False (asserted below); otherwise unused.
        train : enables dropout.
        '''
        ncd = z.shape[-2]
        # first mixing
        z2 = nn.SelfAttention(num_heads=self.args.mixing_head, qkv_features=self.args.base_dim, 
                                out_features=self.args.base_dim, dropout_rate=0.0)(z, deterministic=True)
        # NOTE(review): this compares the feature width of z with the *token* axis of z2
        # (shape[-2]); z2.shape[-1] looks like what was intended. As written, the Dense
        # projection below is created whenever z.shape[-1] != z.shape[-2]. Changing the
        # condition would change which parameters exist, so verify against checkpoints first.
        if z.shape[-1] == z2.shape[-2]:
            z = z + nn.Dropout(self.dropout)(z2, deterministic=not train)
        else:
            z = nn.Dense(z2.shape[-1])(z) + nn.Dropout(self.dropout)(z2, deterministic=not train)
        z = AdaLayerNorm()(z, context)

        # second mixing: self-attention across all tokens after merging the two axes before the feature axis
        if not self.local_only:
            z_gb = einops.rearrange(z, '... s c f -> ... (s c) f')
            assert img_feat_flat is None
            z2_gb = nn.SelfAttention(num_heads=self.args.mixing_head, qkv_features=self.args.base_dim, 
                                    out_features=self.args.base_dim, dropout_rate=0.0)(z_gb, deterministic=True)
            # restore the (s, c) layout; c equals the original size of axis -2
            z2 = einops.rearrange(z2_gb, '... (r i) k -> ... r i k', i=ncd)
            z = z + nn.Dropout(self.dropout)(z2, deterministic=not train)
            z = AdaLayerNorm()(z, context)

        # final linear: Dense -> gelu -> dropout -> Dense, added residually
        z2 = nn.Dense(self.args.base_dim)(nn.Dropout(self.dropout)(nn.gelu(nn.Dense(self.args.base_dim)(z)), deterministic=not train))
        z = z + nn.Dropout(self.dropout)(z2, deterministic=not train)
        z = AdaLayerNorm()(z, context)

        return z


class DenoisingModel(nn.Module):
    '''
    Denoiser over latent object sets, conditioned on multi-view image features.

    args fields read here: cam_loss_coef, dm_type, condition_type (0 or 1), base_dim,
    fps_only, add_c_skip, mixing_head, first_mixing_depth (plus whatever
    dfutil.calculate_cs reads). rot_configs is not used inside this block.
    '''
    args:typing.NamedTuple
    rot_configs:typing.Sequence

    @nn.compact
    def __call__(self, objects_ptcl:loutil.LatentObjects, con_feat:structs.ImgFeatures, time, batch_condition_mask=None, previous_emb=None, train=False):
        '''
        objects_ptcl : (nb, ns, ...)
        con_feat : (nb, nc, nf)
        time : (nb, ) # if EDM it is sigma
        batch_condition_mask : optional mask; image conditioning is zeroed where it is False
        previous_emb : optional tuple of four embeddings (same structure as the returned
            embs) added to the freshly computed ones
        train : enables dropout

        return : (obj_pred0, conf, cam_pos_quat_updated, embs)
            When objects_ptcl has a single outer dimension, a leading batch axis is added
            internally and stripped again from every output.

        raises ValueError if args.condition_type is neither 0 nor 1.
        '''
        if self.args.condition_type not in (0, 1):
            # the branches below only define their outputs for these two values
            raise ValueError(f'condition_type should be 0 or 1, got {self.args.condition_type}')

        # camera poses are refined only when the camera loss is enabled
        viewpoint_update = self.args.cam_loss_coef != 0

        if time.ndim == 0:
            time = time[None]

        extended = False
        if len(objects_ptcl.outer_shape) == 1:
            # un-batched input: add a leading batch axis (removed again before returning)
            extended = True
            objects_ptcl, con_feat = jax.tree_util.tree_map(lambda x:x[None], (objects_ptcl, con_feat))

        dropout=0.1
        ncd = objects_ptcl.rel_fps.shape[-2]

        time, (c_skip, c_out, c_in) = dfutil.calculate_cs(time, self.args)

        if self.args.dm_type == 'regression':
            embeddings = jnp.zeros_like(time)[...,None,None,:]
        else:
            # positional embedding to time
            pe_L = 16
            embeddings = np.power(2, np.arange(pe_L)) * np.pi
            embeddings = embeddings*time
            embeddings = jnp.concatenate([jnp.sin(embeddings), jnp.cos(embeddings)], axis=-1)

            if self.args.condition_type == 0:
                context_emb_size = self.args.base_dim
            elif self.args.condition_type == 1:
                context_emb_size = con_feat.img_feat.shape[-1]
            embeddings = nn.Dense(context_emb_size)(embeddings)
            embeddings = nn.gelu(embeddings)
            embeddings = nn.Dense(context_emb_size)(embeddings)
            embeddings = nn.gelu(embeddings)
            embeddings = nn.Dense(context_emb_size)(embeddings)
            embeddings = nn.LayerNorm()(embeddings)
            embeddings = embeddings[...,None,None,:]

        if self.args.fps_only:
            objects_ptcl = objects_ptcl.replace(z=jnp.zeros_like(objects_ptcl.z))
        
        def embedder(x, base_dim):
            # 3-layer MLP followed by LayerNorm
            emb = nn.Dense(base_dim)(x)
            emb = nn.gelu(emb)
            emb = nn.Dense(base_dim)(emb)
            emb = nn.gelu(emb)
            emb = nn.Dense(base_dim)(emb)
            emb = nn.LayerNorm()(emb)
            return emb

        pos_emb = embedder(objects_ptcl.pos, self.args.base_dim)
        cen_emb = embedder(objects_ptcl.rel_fps, self.args.base_dim)
        cen_tf_emb = embedder(objects_ptcl.fps_tf, self.args.base_dim)
        z_emb = embedder(objects_ptcl.z_flat, self.args.base_dim)
        embs = (pos_emb, cen_emb, cen_tf_emb, z_emb)
        if previous_emb is not None:
            embs = jax.tree_util.tree_map(lambda x,y: x+y, embs, previous_emb)

        def feature_extractor(p_, con_feat_, embs, stop_gradient=False):
            # Sample image features at the projected points p_ and assemble the conditioning
            # context. Returns (context, img_feat_flat, viewpoint_fts); img_feat_flat is
            # always None.
            _, _, cen_tf_emb, _ = embs
            # feature extraction
            if stop_gradient:
                p_ = jax.lax.stop_gradient(p_)
            extracted_img_fts, _ = extract_pixel_features(p_, con_feat_, feat_type='img_feat')
            extracted_spatial_PE, _ = extract_pixel_features(p_, con_feat_, feat_type='spatial')
            extracted_img_fts = nn.LayerNorm(name='LayerNorm_img_feat_0')(extracted_img_fts) + extracted_spatial_PE
            
            nviews = extracted_img_fts.shape[-2]
            if viewpoint_update:
                # learned embedding of a one-hot indicator that marks view 0
                ref_positional_embeddings = jnp.zeros(nviews, dtype=jnp.float32)
                ref_positional_embeddings = ref_positional_embeddings.at[0].set(1)[...,None]
                for _ in range(2):
                    ref_positional_embeddings = nn.Dense(extracted_img_fts.shape[-1])(ref_positional_embeddings)
                    ref_positional_embeddings = nn.gelu(ref_positional_embeddings)
                ref_positional_embeddings = nn.Dense(extracted_img_fts.shape[-1])(ref_positional_embeddings)
                extracted_img_fts += ref_positional_embeddings

            # bound before the branches: it is also needed for condition_type == 1 when
            # viewpoint updates are enabled (previously unbound on that path)
            pooling_func = jnp.max
            if self.args.condition_type == 0:
                viewpoint_fts = Aggregator(self.args.base_dim, pooling_func=pooling_func, pooled_out=False)(extracted_img_fts) # (B NO ND d)
                extracted_aggregated_img_fts = jnp.max(viewpoint_fts, axis=-2)
            elif self.args.condition_type == 1:
                viewpoint_fts = None
                extracted_aggregated_img_fts = extracted_img_fts

            if viewpoint_update and viewpoint_fts is None:
                # viewpoint_fts = extracted_img_fts
                viewpoint_fts = Aggregator(self.args.base_dim, pooling_func=pooling_func, pooled_out=False)(extracted_img_fts) # (B NO ND d)
                # viewpoint_fts = einops.rearrange(viewpoint_fts, '... o d v f -> ... (o d) v f')
                # viewpoint_fts = Aggregator(self.args.base_dim, axis=-3, pooled_out=False)(viewpoint_fts) # (B NO ND NV d)
                # viewpoint_fts = einops.rearrange(viewpoint_fts, '... (o d) v f -> ... o d v f', d=ncd)

            if batch_condition_mask is not None:
                # right-pad the mask with singleton axes, then zero the masked-out conditioning
                batch_condition_mask_ext = batch_condition_mask[...,None]
                for _ in range(extracted_aggregated_img_fts.ndim - batch_condition_mask_ext.ndim):
                    batch_condition_mask_ext = batch_condition_mask_ext[...,None]
                extracted_aggregated_img_fts = jnp.where(batch_condition_mask_ext, extracted_aggregated_img_fts, 0)

            if self.args.condition_type == 0:
                # context is a single vector per token (concatenated along the feature axis)
                context = jnp.c_[extracted_aggregated_img_fts, cen_tf_emb, 
                                jnp.broadcast_to(embeddings, jnp.broadcast_shapes(extracted_aggregated_img_fts.shape, embeddings.shape))]
            elif self.args.condition_type == 1:
                # context is a sequence per token (concatenated along axis -2)
                context_emb_size = extracted_aggregated_img_fts.shape[-1]
                context = jnp.concatenate([extracted_aggregated_img_fts, nn.Dense(context_emb_size)(cen_tf_emb[...,None,:]),
                                jnp.broadcast_to(embeddings[...,None,:], jnp.broadcast_shapes(extracted_aggregated_img_fts[...,:1,:].shape, embeddings[...,None,:].shape))], axis=-2)
            img_feat_flat = None
            
            return context, img_feat_flat, viewpoint_fts
        
        def c_p_prediction_head(z, context, pos_emb, cen_emb, conf_out=False):
            # Predict pos, fps_tf and (optionally) a confidence logit from the decoder tokens.
            # heads for the rest
            dz = z
            for _ in range(2):
                dz = nn.Dense(self.args.base_dim)(dz)
                dz = AdaLayerNorm()(dz, context)
                dz = nn.gelu(dz)
                # append the mean over axis -2 to every token
                dz = jnp.c_[dz, einops.repeat(jnp.mean(dz, axis=-2), '... c -> ... r c', r=ncd)]
            # dz = nn.gelu(nn.Dense(self.args.base_dim)(dz)) + z
            dz = z + nn.Dropout(dropout)(nn.gelu(nn.Dense(self.args.base_dim)(dz)), deterministic=not train)
            dcenter = dz
            dpos=conf=jnp.mean(dz, axis=-2)

            # # pos branch
            dpos = jnp.c_[dpos, nn.Dense(dpos.shape[-1])(AdaLayerNorm()(pos_emb, embeddings.squeeze(-2)))]
            dpos = nn.gelu(nn.Dense(self.args.base_dim)(dpos))
            dpos = nn.Dense(3)(dpos)

            if self.args.add_c_skip:
                pos_scaled = c_skip[...,None]*objects_ptcl.pos + c_out[...,None]*dpos
            else:
                # learned skip / output gates in (0, 1) instead of the analytic coefficients
                c_skip_ = nn.sigmoid(nn.Dense(1)(embeddings.squeeze(-2)))
                c_out_ = nn.sigmoid(nn.Dense(1)(embeddings.squeeze(-2)))
                pos_scaled = c_skip_*objects_ptcl.pos + c_out_*dpos

            # # conf branch
            if conf_out:
                conf = nn.gelu(nn.Dense(self.args.base_dim)(conf))
                conf = nn.Dense(1)(conf)
            else:
                conf = None

            # center pred branch
            cen_emb_tmp = AdaLayerNorm()(cen_emb, embeddings)
            dcenter = jnp.c_[jnp.broadcast_to(dcenter, jnp.broadcast_shapes(cen_emb_tmp[...,:1].shape, dcenter.shape)), 
                            jnp.broadcast_to(cen_emb_tmp, jnp.broadcast_shapes(dcenter[...,:1].shape, cen_emb_tmp.shape))]
            dcenter = nn.Dense(self.args.base_dim)(dcenter)
            dcenter = nn.gelu(dcenter)
            dcenter = nn.Dense(3)(dcenter)

            if self.args.add_c_skip:
                center_scaled = c_skip[...,None,None]*objects_ptcl.fps_tf + c_out[...,None,None]*dcenter
                # center_scaled = c_skip[...,None,None]*objects_ptcl.rel_fps + c_out[...,None,None]*dcenter
            else:
                c_skip_ = nn.sigmoid(nn.Dense(1)(embeddings))
                c_out_ = nn.sigmoid(nn.Dense(1)(embeddings))
                center_scaled = c_skip_*objects_ptcl.fps_tf + c_out_*dcenter
                # center_scaled = c_skip_*objects_ptcl.rel_fps + c_out_*dcenter
            return pos_scaled, center_scaled, conf

        context, _, viewpoint_fts = \
            feature_extractor(objects_ptcl.fps_tf, con_feat, embs, stop_gradient=False)

        pos_emb_init = pos_emb
        cen_emb_init = cen_emb
        # generate object queries
        z = jnp.c_[cen_emb, einops.repeat(pos_emb, '... d -> ... r2 d', r2=ncd), z_emb] # (B NO ND d)
        z = AdaLayerNorm()(nn.Dense(self.args.base_dim)(z), context)

        # apply transformer Decoders
        for _ in range(self.args.first_mixing_depth):
            z = TransformerDecoder(self.args, dropout)(z, context, None, train=train)
        pos_scaled, center_scaled, conf = c_p_prediction_head(z, context, pos_emb_init, cen_emb_init, conf_out=True)

        # viewpoint out
        if viewpoint_update:
            # pool per-view features over axis -3 twice, then predict a pose increment per view
            viewpoint_fts = jnp.c_[viewpoint_fts, einops.repeat(z, '... i j k -> ... i j r k', r=viewpoint_fts.shape[-2])]
            viewpoint_fts = Aggregator(self.args.base_dim, axis=-3)(viewpoint_fts)
            viewpoint_fts = Aggregator(self.args.base_dim, axis=-3)(viewpoint_fts)
            viewpoint_dpos = nn.Dense(3)(viewpoint_fts)
            viewpoint_se3 = nn.Dense(3)(viewpoint_fts)

            cam_pos_quat_updated = tutil.pq_multi(con_feat.cam_posquat, jnp.c_[viewpoint_dpos, tutil.qExp(viewpoint_se3)])
        else:
            cam_pos_quat_updated = con_feat.cam_posquat

        if self.args.fps_only:
            z_sh = jnp.zeros_like(objects_ptcl.z)
        else:
            # pred z_sh
            z_sh = jnp.c_[z_emb, z]
            z_sh = nn.Dense(self.args.base_dim)(z_sh)
            z_sh = AdaLayerNorm()(z_sh, context)
            z_sh = ResnetBlockFC(self.args, self.args.base_dim, self.args.base_dim)(z_sh)
            z_sh = nn.Dense(objects_ptcl.z_flat.shape[-1])(z_sh)
            z_sh = einops.rearrange(z_sh, '... (r i) -> ... r i', i=objects_ptcl.nz)
            if self.args.add_c_skip:
                z_sh = c_skip[...,None,None,None]*objects_ptcl.z + c_out[...,None,None,None]*z_sh
            else:
                c_skipout = nn.sigmoid(nn.Dense(2)(embeddings))
                z_sh = c_skipout[...,None,0:1]*objects_ptcl.z + c_skipout[...,None,1:2]*z_sh
        # obj_pred0 = loutil.LatentObjects().replace(rel_fps=center_scaled-jax.lax.stop_gradient(pos_scaled[...,None,:]), pos=pos_scaled, z=z_sh)
        obj_pred0 = replace(loutil.LatentObjects(), pos=pos_scaled, z=z_sh).set_fps_tf(center_scaled)
        obj_pred0 = obj_pred0.set_conf(conf)
        
        if extended:
            # strip the batch axis that was added for un-batched input
            return jax.tree_util.tree_map(lambda x: x[0], (obj_pred0, conf, cam_pos_quat_updated, embs))
        else:
            return obj_pred0, conf, cam_pos_quat_updated, embs


class SegModel(nn.Module):
    '''
    dino feature -> dense single-channel map

    The DINO feature map is linearly projected, bilinearly upsampled by a
    factor of 14 along both spatial axes and passed through a small
    per-pixel MLP that outputs one value per pixel.
    '''
    args:typing.NamedTuple

    @nn.compact
    def __call__(self, img_feat_struct:structs.ImgFeatures, train=False):
        '''
        img_feat_struct : only `dino_feat` is read; its last three axes are
            treated as (height, width, channel)
        train : unused
        return:
            (..., 14*height, 14*width, 1) raw output of the last Dense layer
            (no activation is applied)
        '''
        img_feat = img_feat_struct.dino_feat
        # output resolution: 14 pixels per feature cell
        # (presumably the DINO patch size -- TODO confirm)
        rgb_size = (14*img_feat.shape[-3], 14*img_feat.shape[-2])

        # channel projection before upsampling
        img_feat_dim = 512
        img_feat = nn.Dense(img_feat_dim, use_bias=True)(img_feat)

        # only the two spatial axes change size; leading axes and channels are kept
        img_feat = jax.image.resize(img_feat, img_feat.shape[:-3] + rgb_size + img_feat.shape[-1:], method='linear')

        # per-pixel MLP head
        for _ in range(3):
            img_feat = nn.Dense(self.args.base_dim)(img_feat)
            img_feat = nn.relu(img_feat)
        img_feat = nn.Dense(1)(img_feat)

        return img_feat


class SegModelCNN(nn.Module):
    '''
    image -> dense single-channel map

    Encoder/decoder CNN with additive skip connections. Features from every
    decoder level are projected, tiled back to the input resolution,
    concatenated and fed to a per-pixel MLP.
    '''
    args:typing.NamedTuple

    @nn.compact
    def __call__(self, x, train=False):
        '''
        x : image-like array (..., height, width, channel); uint8/int16/int32
            inputs are converted to float32 and divided by 255.
            The skip additions require the three stride-2 down/up-sampling
            steps to round-trip the spatial size.
        train : unused
        return:
            (..., height, width, 1) raw output of the last Dense layer
        '''
        n_conv = 1
        width = self.args.img_base_dim
        head_width = 4*width

        def conv_stack(features, ksize, repeats, h):
            # `repeats` x (conv -> relu) with a square kernel
            for _ in range(repeats):
                h = nn.relu(nn.Conv(features, (ksize, ksize))(h))
            return h

        def project_and_tile(h, factor, features):
            # Dense -> relu, then nearest-neighbour upsampling by `factor`
            h = nn.relu(nn.Dense(features)(h))
            return einops.repeat(h, '... i j k -> ... (i r1) (j r2) k', r1=factor, r2=factor)

        if x.dtype in [jnp.uint8, jnp.int16, jnp.int32]:
            x = x.astype(jnp.float32)/255.

        # stem: two 5x5 convs at full resolution
        x = conv_stack(width, 5, 2, x)

        # encoder: keep the pre-downsampling activation of each level as a skip
        skips = []
        for mult in (2, 4, 8):
            x = conv_stack(mult*width, 3, n_conv, x)
            skips.append(x)
            x = nn.relu(nn.Conv(mult*width, (3,3), strides=(2,2))(x))

        # bottleneck: dilated 3x3 and 5x5 convs followed by a plain conv stack
        for ksize in (3, 5):
            x = nn.relu(nn.Conv(8*width, (ksize, ksize), kernel_dilation=(2,2))(x))
        x = conv_stack(8*width, 3, n_conv, x)

        # decoder: upsample, add the matching skip (deepest first), and collect
        # a full-resolution copy of every level
        pyramid = [project_and_tile(x, 8, head_width)]
        for mult, factor in ((8, 4), (4, 2), (2, 1)):
            x = nn.ConvTranspose(mult*width, (3,3), strides=(2,2))(x)
            x = nn.relu(x) + skips.pop()
            pyramid.append(project_and_tile(x, factor, head_width))

        x = jnp.concatenate(pyramid, axis=-1)

        # per-pixel MLP head
        for _ in range(3):
            x = nn.relu(nn.Dense(self.args.base_dim)(x))
        return nn.Dense(1)(x)



class FeatureRenderer(nn.Module):
    # base_dim    : width of the Dense / Conv / Aggregator layers built in __call__
    # rot_configs : forwarded unchanged to latent_objs.apply_pq_z
    base_dim:int
    rot_configs:typing.Sequence
    '''
    latent objects -> img feature
    aggregate viewpoints -> novel view feature synthesis
    '''

    @nn.compact
    def __call__(self, latent_objs:loutil.LatentObjects, img_feat_struct:structs.ImgFeatures, cam_pqc, cam_intrinsic, train=False):
        '''
        Scatter per-point object features onto the image plane of the camera
        (cam_pqc, cam_intrinsic) and refine the result with convolutions.

        latent_objs : (nb, ns, ...)
        img_feat_struct : `img_feat` and `spatial_PE` are read here (for the output
            size and channel count only); the struct is also handed to
            extract_pixel_features
            img_feat : (nb, nv, nc1, nc2, nf)
        cam_pqc : (nb, 7), the first three components are used as the camera position
        cam_intrinsic : (nb, 6)
        train : unused
        return:
            feat_res : feature image with the spatial size (nc1, nc2) and channel
            count nf of img_feat. Axis -4 is max-reduced before the final
            convolutions -- presumably the result is (nb, nc1, nc2, nf), TODO confirm
        '''

        # the summed tensor is only used to read the target spatial size and channel count
        img_feat = img_feat_struct.img_feat + img_feat_struct.spatial_PE
        output_pixel_size = img_feat.shape[-3:-1]
        feat_dim = img_feat.shape[-1]

        object_valid_mask = latent_objs.obj_valid_mask

        # extract features for each image 
        # fps points of (nb, ns, nfps, 3) -> extract pixel features (nb, ns, nfps, nv, nf)
        # NOTE(review): cam_fts is not used afterwards
        img_fts_vies, cam_fts = extract_pixel_features(latent_objs.fps_tf, img_feat_struct) # (nb, ns, nfps, nv, nf)

        # aggregate viewpoints
        img_fts_objs = Aggregator(self.base_dim, pooling_func=jnp.max, pooled_out=True)(img_fts_vies) # (nb, ns, nfps, nf)
        img_fts_objs = nn.LayerNorm()(img_fts_objs)

        # embedding of the Euclidean distance between each fps point and the camera position
        distance_from_cam = jnp.linalg.norm(latent_objs.fps_tf - cam_pqc[...,None,None,:3], axis=-1)
        depth_emb = Aggregator(self.base_dim, pooling_func=jnp.mean, pooled_out=False)(distance_from_cam[...,None])
        depth_emb = nn.LayerNorm()(depth_emb)

        # z normalize
        # apply the inverse camera pose to the objects before embedding z
        latent_objs_wrt_cam = latent_objs.apply_pq_z(tutil.pq_inv(cam_pqc[...,None,:]), self.rot_configs)
        z_emb = Aggregator(self.base_dim, pooling_func=jnp.mean, pooled_out=False)(latent_objs_wrt_cam.z_flat)
        z_emb = nn.LayerNorm()(z_emb)

        # novel view feature synthesis
        # project z, c, p and img_feat for each fps into novel image plane (padded images -> fill projected features)
        # I think this is the simplest with minimized complexity... but using ray marching would be better
        # NOTE(review): out_mask is not used afterwards
        pixel_coord, out_mask = cutil.global_pnts_to_pixel(cam_intrinsic[...,None,None,:], cam_pqc[...,None,None,:], latent_objs.fps_tf)
        # no gradient flows back through the projected coordinates
        pixel_coord = jax.lax.stop_gradient(pixel_coord)

        # residual embeddings
        # fractional part of the pixel coordinate, embedded by a small residual MLP
        residuals = pixel_coord - jnp.floor(pixel_coord)
        residual_emb = nn.Dense(self.base_dim)(residuals)
        residual_emb = nn.gelu(residual_emb)
        residual_skip = residual_emb
        residual_emb = nn.Dense(self.base_dim)(residual_emb)
        residual_emb = nn.gelu(residual_emb)
        residual_emb = nn.Dense(self.base_dim)(residual_emb) + residual_skip
        residual_emb = nn.LayerNorm()(residual_emb)

        # fuse all per-point embeddings into one base_dim feature
        feat_objs = jnp.concatenate([img_fts_objs, z_emb, depth_emb, residual_emb], axis=-1)
        feat_objs = nn.Dense(self.base_dim)(feat_objs)
        feat_skip = feat_objs
        feat_objs = nn.gelu(feat_objs)
        feat_objs = nn.Dense(self.base_dim)(feat_objs) + feat_skip

        # mask invalid objs
        feat_objs = jnp.where(object_valid_mask[...,None,None], feat_objs, 0)

        # fill projected features
        # integer pixel coordinate -> flat row-major index (coordinate 0 * width + coordinate 1)
        # NOTE(review): nothing here restricts the coordinates to the image, so a
        # point projected outside may map to an unrelated flat index unless
        # global_pnts_to_pixel already guarantees in-range output -- TODO confirm
        segment_ids = pixel_coord.astype(jnp.int32)
        segment_ids = segment_ids[...,0]*output_pixel_size[1] + segment_ids[...,1]
        origin_outer_shape = segment_ids.shape[:-1]
        # collapse all leading axes so segment_max can be vmapped over a single axis
        flat_inputs = (feat_objs.reshape(-1, *feat_objs.shape[-2:]), segment_ids.reshape(-1, *segment_ids.shape[-1:]))
        # element-wise max over the points that fall into the same pixel
        feat_flat = jax.vmap(partial(jax.ops.segment_max, num_segments=np.prod(output_pixel_size)))(*flat_inputs)
        # feat_flat = jax.vmap(partial(jax.ops.segment_sum, num_segments=np.prod(output_pixel_size)))(*flat_inputs)
        feat_flat = feat_flat.reshape(*origin_outer_shape, *feat_flat.shape[-2:])
        # zero out non-finite or very large entries (presumably the fill value
        # segment_max leaves in pixels that received no point)
        feat_flat = jnp.where(jnp.isfinite(feat_flat), feat_flat, 0)
        feat_flat = jnp.where(jnp.abs(feat_flat)<1e6, feat_flat, 0)
        feat_res = einops.rearrange(feat_flat, '... (i j) k -> ... i j k', i=output_pixel_size[0], j=output_pixel_size[1])

        # # # debug visualization - projected points and rgb images
        # vis_batch_idx = 0
        # for vis_batch_idx in range(feat_res.shape[0]):
        #     vis_feat_res = feat_res[vis_batch_idx]
        #     vis_rgb = img_feat_struct.rgb[vis_batch_idx]
            
        #     # perform PCA for visualization
        #     import matplotlib.pyplot as plt
        #     from sklearn.decomposition import PCA
        #     pca = PCA(n_components=3)
        #     pca.fit(vis_feat_res.reshape(-1, vis_feat_res.shape[-1]))
        #     vis_feat_res_pca = pca.transform(vis_feat_res.reshape(-1, vis_feat_res.shape[-1])).reshape(vis_feat_res.shape[:-1]+(3,))
        #     vis_feat_res_pca = (vis_feat_res_pca - vis_feat_res_pca.min()) / (vis_feat_res_pca.max() - vis_feat_res_pca.min())

        #     plt.figure()
        #     plt.subplot(3, 2, 1)
        #     plt.imshow(vis_rgb[-1])
        #     plt.subplot(3, 2, 2)
        #     plt.imshow(vis_feat_res_pca[0])
        #     plt.subplot(3, 2, 3)
        #     plt.imshow(vis_feat_res_pca[1])
        #     plt.subplot(3, 2, 4)
        #     plt.imshow(vis_feat_res_pca[2])
        #     plt.subplot(3, 2, 5)
        #     plt.imshow(vis_feat_res_pca[3])
        #     # plt.show()
        #     plt.savefig(f'tmp_{vis_batch_idx}.png')
        # # # debug visualization - projected points and rgb images

        # two 5x5 convs applied to each image before merging
        for _ in range(2):
            feat_res = nn.Conv(self.base_dim, (5,5))(feat_res)
            feat_res = nn.gelu(feat_res)
        # merge along axis -4 (presumably the per-object axis) with a max
        feat_res = jnp.max(feat_res, axis=-4)
        # residual block with dilated convs, then a 1x1 conv to the input channel count
        skip = feat_res
        feat_res = nn.Conv(self.base_dim, (3,3), kernel_dilation=(2,2))(feat_res)
        feat_res = nn.gelu(feat_res)
        feat_res = nn.Conv(self.base_dim, (5,5), kernel_dilation=(2,2))(feat_res)
        feat_res = nn.gelu(feat_res)
        feat_res = nn.Conv(self.base_dim, (3,3))(feat_res) + skip
        feat_res = nn.Conv(feat_dim, (1,1))(feat_res)

        return feat_res



class oriCORNRenderer(nn.Module):
    # base_dim    : width of the Dense / Conv / Aggregator / attention layers built in __call__
    # rot_configs : forwarded unchanged to latent_objs.apply_pq_z
    base_dim:int
    rot_configs:typing.Sequence
    '''
    latent objects -> img feature
    aggregate viewpoints -> novel view feature synthesis
    '''

    @nn.compact
    def __call__(self, latent_objs:loutil.LatentObjects, img_feat_struct:structs.ImgFeatures, cam_pqc, cam_intrinsic, output_pixel_size, train=False):
        '''
        Render a 3-channel image from latent objects for the camera
        (cam_pqc, cam_intrinsic).

        latent_objs : (nb, ns, ...)
        img_feat_struct : `img_feat` is read here for its spatial size; the struct
            is also handed to extract_pixel_features
            img_feat : (nb, nv, nc1, nc2, nf)
        cam_pqc : (nb, 7), the first three components are used as the camera position
        cam_intrinsic : (nb, 6)
        output_pixel_size : (height, width) of the returned image; must be a tuple
            because it is concatenated with shape tuples
        train : when False the image is clipped to [0, 1]
        return:
            rgb_img : (..., height, width, 3)
            pixel_fps_points : boolean map derived from the scattered features at
                the intermediate resolution (4x the feature map), see the
                NOTE(review) where it is computed
        '''

        img_feat = img_feat_struct.img_feat
        feat_pixel_size = img_feat.shape[-3:-1]
        # NOTE(review): feat_dim is not used afterwards
        feat_dim = img_feat.shape[-1]
        # points are scattered on a grid 4x finer than the feature map
        intermediate_pixel_size = (feat_pixel_size[0]*4, feat_pixel_size[1]*4)

        # object_valid_mask = latent_objs.obj_valid_mask

        # extract features for each image 
        # fps points of (nb, ns, nfps, 3) -> extract pixel features (nb, ns, nfps, nv, nf)
        # NOTE(review): cam_fts is not used afterwards
        img_fts_vies, cam_fts = extract_pixel_features(latent_objs.fps_tf, img_feat_struct) # (nb, ns, nfps, nv, nf)

        # aggregate viewpoints
        img_fts_objs = Aggregator(self.base_dim, pooling_func=jnp.max, pooled_out=True)(img_fts_vies) # (nb, ns, nfps, nf)
        img_fts_objs = nn.LayerNorm()(img_fts_objs)

        # embedding of the Euclidean distance between each fps point and the camera position
        distance_from_cam = jnp.linalg.norm(latent_objs.fps_tf - cam_pqc[...,None,None,:3], axis=-1)
        depth_emb = Aggregator(self.base_dim, pooling_func=jnp.mean, pooled_out=False)(distance_from_cam[...,None])
        depth_emb = nn.LayerNorm()(depth_emb)

        # z normalize
        # apply the inverse camera pose to the objects before embedding z
        latent_objs_wrt_cam = latent_objs.apply_pq_z(tutil.pq_inv(cam_pqc[...,None,:]), self.rot_configs)
        z_emb = Aggregator(self.base_dim, pooling_func=jnp.mean, pooled_out=False)(latent_objs_wrt_cam.z_flat)
        z_emb = nn.LayerNorm()(z_emb)

        # novel view feature synthesis
        # project z, c, p and img_feat for each fps into novel image plane (padded images -> fill projected features)
        # I think this is the simplest with minimized complexity... but using ray marching would be better
        # intrinsics are rescaled from the feature-map size to the intermediate size before projecting
        # NOTE(review): out_mask is not used afterwards
        cam_intrinsic_proj = cutil.resize_intrinsic(cam_intrinsic, feat_pixel_size, intermediate_pixel_size)
        pixel_coord, out_mask = cutil.global_pnts_to_pixel(cam_intrinsic_proj[...,None,None,:], cam_pqc[...,None,None,:], latent_objs.fps_tf)
        # pixel_coord = jax.lax.stop_gradient(pixel_coord)

        # residual embeddings
        # despite the name, `residuals` is the pixel coordinate divided by the
        # intermediate image size; it is embedded by a small residual MLP
        residuals = pixel_coord/np.array(intermediate_pixel_size).astype(np.float32)
        residual_emb = nn.Dense(self.base_dim)(residuals)
        residual_emb = nn.gelu(residual_emb)
        residual_skip = residual_emb
        residual_emb = nn.Dense(self.base_dim)(residual_emb)
        residual_emb = nn.gelu(residual_emb)
        residual_emb = nn.Dense(self.base_dim)(residual_emb) + residual_skip
        residual_emb = nn.LayerNorm()(residual_emb)

        # fuse all per-point embeddings into one base_dim feature
        feat_objs = jnp.concatenate([img_fts_objs, z_emb, depth_emb, residual_emb], axis=-1)
        feat_objs = nn.Dense(self.base_dim)(feat_objs)
        feat_skip = feat_objs
        feat_objs = nn.gelu(feat_objs)
        feat_objs = nn.Dense(self.base_dim)(feat_objs) + feat_skip # (NB, NO, NFPS, NF)

        # generate one global features
        # one learned 16-d query per feature-map pixel
        img_positional_embeddings = self.param('img_positional_embeddings',
                                    nn.initializers.lecun_normal(), # Initialization function
                                    (feat_pixel_size[0]*feat_pixel_size[1], 16))  # shape info.

        # merge the object and fps axes so every point is one key/value token
        feat_objs_float = einops.rearrange(feat_objs, '... i j f -> ... (i j) f')
        img_positional_embeddings = jnp.broadcast_to(img_positional_embeddings, feat_objs_float.shape[:-2]+img_positional_embeddings.shape[-2:])
        # queries: positional embeddings, keys/values: point features
        base_feature_img = nn.MultiHeadAttention(num_heads=2, qkv_features=self.base_dim//2, out_features=self.base_dim)(img_positional_embeddings, feat_objs_float)
        base_feature_img = base_feature_img.reshape(*base_feature_img.shape[:-2], *feat_pixel_size, *base_feature_img.shape[-1:])
        # upsample the background feature image to the intermediate resolution
        base_feature_img = jax.image.resize(base_feature_img, base_feature_img.shape[:-3]+intermediate_pixel_size+base_feature_img.shape[-1:], method='linear')

        # mask invalid objs
        # feat_objs = jnp.where(object_valid_mask[...,None,None], feat_objs, 0)

        # fill projected features - set by idx
        # def fill_by_idx(values, indices):
        #     filled_res = jnp.zeros((np.prod(intermediate_pixel_size), values.shape[-1]))
        #     return filled_res.at[indices].set(values)
        # segment_ids = pixel_coord.astype(jnp.int32)
        # segment_ids = segment_ids[...,0]*intermediate_pixel_size[1] + segment_ids[...,1]
        # origin_outer_shape = segment_ids.shape[:-1]
        # flat_inputs = (feat_objs.reshape(-1, *feat_objs.shape[-2:]), segment_ids.reshape(-1, *segment_ids.shape[-1:]))
        # feat_flat_per_obj = jax.vmap(fill_by_idx)(*flat_inputs)
        # feat_flat_per_obj = feat_flat_per_obj.reshape(*origin_outer_shape, *feat_flat_per_obj.shape[-2:])
        # feat_res = einops.rearrange(feat_flat_per_obj, '... (i j) k -> ... i j k', i=intermediate_pixel_size[0], j=intermediate_pixel_size[1]) # fps per obj

        # # fill projected features - segment max over obj and fps
        # integer pixel coordinate -> flat row-major index (coordinate 0 * width + coordinate 1)
        # NOTE(review): nothing here restricts the coordinates to the image, so a
        # point projected outside may map to an unrelated flat index unless
        # global_pnts_to_pixel already guarantees in-range output -- TODO confirm
        segment_ids = pixel_coord.astype(jnp.int32)
        segment_ids = segment_ids[...,0]*intermediate_pixel_size[1] + segment_ids[...,1]
        # merge the object and fps axes so the max runs over all points of a scene
        flat_inputs = (feat_objs.reshape(*feat_objs.shape[:-3], -1, feat_objs.shape[-1]), segment_ids.reshape(*segment_ids.shape[:-2], -1))
        # flatten outer shape and recover
        origin_outer_shape = flat_inputs[1].shape[:-1]
        flat_inputs = (flat_inputs[0].reshape(-1, *flat_inputs[0].shape[-2:]), flat_inputs[1].reshape(-1,*flat_inputs[1].shape[-1:]))
        feat_flat = jax.vmap(partial(jax.ops.segment_max, num_segments=np.prod(intermediate_pixel_size)))(*flat_inputs) # aggregate over obj
        feat_flat = feat_flat.reshape(*origin_outer_shape, *intermediate_pixel_size, *feat_flat.shape[-1:])
        # entries that are non-finite or very large (presumably pixels that received
        # no point) take the attention-generated background feature instead
        feat_flat = jnp.where(jnp.isfinite(feat_flat), feat_flat, base_feature_img)
        feat_res = jnp.where(jnp.abs(feat_flat)<1e6, feat_flat, base_feature_img)
        # feat_res = einops.rearrange(feat_flat, '... (i j) k -> ... i j k', i=intermediate_pixel_size[0], j=intermediate_pixel_size[1])
        
        # fps projection results
        # NOTE(review): this is evaluated after empty pixels were filled with
        # base_feature_img, so it does not isolate pixels hit by a point; the
        # second reduction over axis -3 looks like a leftover from the per-object
        # variant above (the object axis is already merged here) -- TODO confirm intent
        pixel_fps_points = jnp.any(jnp.abs(feat_res)>1e-6, axis=-1)
        pixel_fps_points = jnp.any(pixel_fps_points, axis=-3)[...,None]

        # # # debug visualization - projected points and rgb images
        # vis_batch_idx = 0
        # for vis_batch_idx in range(feat_res.shape[0]):
        #     vis_feat_res = feat_res[vis_batch_idx]
        #     vis_rgb = img_feat_struct.rgb[vis_batch_idx]
            
        #     # perform PCA for visualization
        #     import matplotlib.pyplot as plt
        #     from sklearn.decomposition import PCA
        #     pca = PCA(n_components=3)
        #     pca.fit(vis_feat_res.reshape(-1, vis_feat_res.shape[-1]))
        #     vis_feat_res_pca = pca.transform(vis_feat_res.reshape(-1, vis_feat_res.shape[-1])).reshape(vis_feat_res.shape[:-1]+(3,))
        #     vis_feat_res_pca = (vis_feat_res_pca - vis_feat_res_pca.min()) / (vis_feat_res_pca.max() - vis_feat_res_pca.min())

        #     plt.figure()
        #     plt.subplot(3, 2, 1)
        #     plt.imshow(vis_rgb[-1])
        #     plt.subplot(3, 2, 2)
        #     plt.imshow(vis_feat_res_pca[0])
        #     plt.subplot(3, 2, 3)
        #     plt.imshow(vis_feat_res_pca[1])
        #     plt.subplot(3, 2, 4)
        #     plt.imshow(vis_feat_res_pca[2])
        #     # plt.subplot(3, 2, 5)
        #     # plt.imshow(vis_feat_res_pca[3])
        #     # plt.show()
        #     plt.savefig(f'tmp_{vis_batch_idx}.png')
        # # # debug visualization - projected points and rgb images

        # two residual blocks, each made of two 5x5 convs, at the intermediate resolution
        for _ in range(2):
            skip = feat_res
            for _ in range(2):
                feat_res = nn.Conv(self.base_dim, (5,5))(feat_res)
                feat_res = nn.gelu(feat_res)
            feat_res += skip
        # feat_res = jnp.max(feat_res, axis=-4) # aggregate over objects

        # resample to the requested output resolution
        feat_res = jax.image.resize(feat_res, feat_res.shape[:-3] + output_pixel_size + feat_res.shape[-1:], method='linear')
        skip = feat_res

        # apply resnet blocks
        for _ in range(2):
            feat_res = nn.Conv(self.base_dim, (3,3))(feat_res)
            feat_res = nn.gelu(feat_res)
        feat_res += skip

        # rgb construction
        rgb_img = nn.Dense(3)(feat_res)
        # outside training the prediction is clamped to the [0, 1] range
        if not train:
            rgb_img = rgb_img.clip(0, 1)

        return rgb_img, pixel_fps_points


class FeatureDecoder(nn.Module):
    '''
    img feature -> segmented RGB images

    The feature map is upsampled 4x by two stride-2 transposed convolutions
    (each followed by a 3x3 conv), bilinearly resized to `rgb_size`, refined by
    one residual block and projected to 3 channels.
    '''
    base_dim: int

    @nn.compact
    def __call__(self, img_feat_struct: structs.ImgFeatures, rgb_size, train=False):
        '''
        img_feat_struct : only `img_feat` is read; its last three axes are
            treated as (height, width, channel)
        rgb_size : (height, width) of the returned image; any sequence of two ints
        train : when False the image is clipped to [0, 1]
        return:
            rgb_img : (..., rgb_size[0], rgb_size[1], 3)

        NOTE: the resize below previously took `img_feat` as input, which threw
        away the output of the transposed-conv stack. The residual block
        therefore now receives `base_dim` channels; parameters saved with the
        previous behaviour may not match the new kernel shapes of its first
        Conv and of the skip Dense.
        '''
        img_feat = img_feat_struct.img_feat  # (batch_size, height, width, channels)

        # transpose upsampling
        x = nn.ConvTranspose(features=self.base_dim*2, kernel_size=(3, 3), strides=(2, 2),)(img_feat)
        x = nn.gelu(x)
        x = nn.Conv(features=self.base_dim*2, kernel_size=(3, 3))(x)
        x = nn.gelu(x)
        x = nn.ConvTranspose(features=self.base_dim, kernel_size=(3, 3), strides=(2, 2),)(x)
        x = nn.gelu(x)
        x = nn.Conv(features=self.base_dim, kernel_size=(3, 3))(x)
        x = nn.gelu(x)

        # bring the upsampled features (not the raw input) to the exact target size;
        # only the two spatial axes change
        x = jax.image.resize(x, x.shape[:-3] + tuple(rgb_size) + x.shape[-1:], method='linear')

        # one resnet block
        skip = x
        x = nn.Conv(self.base_dim, (3,3))(x)
        x = nn.gelu(x)
        x = nn.Conv(self.base_dim, (3,3))(x)
        x = nn.gelu(x)
        x += nn.Dense(self.base_dim)(skip)

        # per-pixel projection to 3 colour channels
        rgb_img = nn.Dense(3)(x)

        # outside training the prediction is clamped to the [0, 1] range
        if not train:
            rgb_img = rgb_img.clip(0, 1)

        return rgb_img



if __name__ == '__main__':
    # Smoke test: run extract_pixel_features once on a synthetic feature map
    # and random query points.
    np.random.seed(0)

    side = 32
    # every cell of the 2-channel feature map holds a distinct value
    ramp_feat = np.arange(side*side*2).reshape(1, side, side, 2).astype(jnp.float32)
    cond_feat = cutil.default_cond_feat(pixel_size=[side, side]).replace(img_feat=ramp_feat)

    # random query points in [-1, 1)
    dc_centers_tf = np.random.uniform(-1, 1, size=(10000, 4, 3))
    res = extract_pixel_features(dc_centers_tf, cond_feat)

    # grad = jax.grad(lambda x: jnp.sum(extract_pixel_features(*x)[0]))((dc_centers_tf, cond_feat))

    print(1)