import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from timm.models.layers import DropPath, trunc_normal_
from itertools import product
from .build import MODELS
from utils import misc
from utils.logger import *
from extensions.chamfer_dist import ChamferDistanceL2
from models.transformer import TransformerEncoder, TransformerDecoder, FusionTransformer, SeqInvariantTransformer
from models.point_transformer import Group, Encoder, PointEncoder, FourierPosEmbedding
from models.prompt import TextEncoder, ImageEncoder
from models.vqgan import VQGAN
from modules.voxelization import Voxelization, voxel_to_point


def random_mask(x):
    """Draw an independent random token mask for every sample of ``x``.

    For each row the number of masked tokens is
    ``int(cos(u * pi / 2) * N)`` with ``u`` drawn from ``np.random.random()``;
    the masked positions are then placed with ``np.random.shuffle``.

    Args:
        x: tensor unpacked as ``(B, N, _)``; only its shape and device are used.

    Returns:
        Bool tensor of shape ``(B, N)`` on ``x.device``; ``True`` marks a
        masked position.
    """
    batch_size, num_tokens, _ = x.shape

    rows = []
    for _ in range(batch_size):
        # One uniform draw followed by one shuffle per row (numpy global RNG).
        masked_count = int(np.cos(np.random.random() * np.pi * 0.5) * num_tokens)
        row = np.concatenate((np.zeros(num_tokens - masked_count), np.ones(masked_count)))
        np.random.shuffle(row)
        rows.append(row)

    stacked = np.array(rows).reshape(batch_size, num_tokens)
    return torch.from_numpy(stacked).to(torch.bool).to(x.device)


class GridSmoother(nn.Module):
    """Refine coarse voxel-grid points with a sequence-invariant transformer.

    Training (``forward``): each sample is voxelized, its voxels are converted
    back to points with ``voxel_to_point``, refined by the transformer, and
    compared with ``ChamferDistanceL2`` against a ``misc.fps`` subsample of the
    input cloud containing the same number of points.
    Inference (``inference``): refines a batch of grid points directly.
    """

    def __init__(self, config, **kwargs):
        super().__init__()
        self.config = config
        self.trans_dim = config.trans_dim
        self.depth = config.depth
        self.num_heads = config.num_heads
        print_log(f'[args] {config}', logger='GridSmoother')

        # xyz -> token embedding -> transformer -> xyz
        self.embed = nn.Linear(3, self.trans_dim, bias=False)
        self.projection = nn.Linear(self.trans_dim, 3, bias=False)
        self.voxelization = Voxelization(config=config)

        dpr = [x.item() for x in torch.linspace(0, 0.1, self.depth)]
        self.blocks = SeqInvariantTransformer(
            embed_dim=self.trans_dim,
            depth=self.depth,
            drop_path_rate=dpr,
            num_heads=self.num_heads,
        )

        self.rec_point_loss = ChamferDistanceL2().cuda()
        # NOTE(review): not used by forward(); kept so external attribute access keeps working.
        self.homogeneity_loss = HomogeneityLoss(k=5)

    def forward(self, pts):
        """Return the batch-averaged Chamfer-L2 loss of the refined grid points.

        Args:
            pts: ``(B, N, 3)`` point clouds (callers slice xyz, e.g. ``pts[:, :, :3]``).

        Returns:
            Tensor of shape ``(1,)`` on ``pts.device``.
        """
        B = pts.shape[0]
        voxel = self.voxelization(pts)
        rec_point_loss = torch.zeros(1, device=pts.device)
        for i in range(B):
            grid_point = voxel_to_point(voxel[i])
            N = grid_point.shape[0]

            grid_point = self.inference(grid_point.unsqueeze(dim=0))

            # FPS needs contiguous memory; pts[i] is a strided view when pts was
            # sliced out of a wider tensor (e.g. xyz taken from xyz+rgb).
            center_point = misc.fps(pts[i].unsqueeze(dim=0).contiguous(), N)
            rec_point_loss += self.rec_point_loss(grid_point, center_point) / B
        return rec_point_loss

    def inference(self, grid_point):
        """Refine ``(B, P, 3)`` grid points; returns refined ``(B, P, 3)`` points."""
        x = self.embed(grid_point)
        x = self.blocks(x)
        pred_point = self.projection(x)
        return pred_point


class VoxelGenerator(nn.Module):
    """Prompt-conditioned masked-token transformer over the VQGAN latent grid.

    Tokens sit on an ``R x R x R`` grid with
    ``R = vqgan_config.resolution // vqgan_config.down_sample`` and the model
    predicts, per position, logits over ``codebook_num`` codebook entries.
    Training (``forward``) replaces a random subset of positions with a learned
    mask token. ``inference`` / ``part_generation`` generate tokens over
    ``steps`` passes, committing a cosine-scheduled, growing number of the most
    confident sampled codes at each pass.
    """

    def __init__(self, config, **kwargs):
        super().__init__()
        self.config = config
        self.trans_dim = config.vqgan_config.codebook_dim
        self.depth = config.depth
        self.num_heads = config.num_heads

        self.cfg_ratio = config.get('cfg_ratio', 0.0)

        self.codebook_resolution = config.vqgan_config.resolution // config.vqgan_config.down_sample
        self.codebook_num = config.vqgan_config.codebook_num

        self.prompt_dim = config.prompt_dim
        self.temperature = config.get('temperature', 1.5)
        print_log(f'[args] {config}', logger='VoxelGenerator')

        # Token positions: regular grid over [-1, 1]^3 of shape (R**3, 3), in
        # itertools.product(x, y, z) order (last axis varies fastest).
        axis = torch.linspace(-1, 1, self.codebook_resolution)
        self.grid = nn.Parameter(torch.cartesian_prod(axis, axis, axis).reshape(-1, 3), requires_grad=False)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
        self.projection = nn.Linear(self.trans_dim, self.codebook_num, bias=False)

        self.pos_embed = nn.Sequential(
            nn.Linear(3, 128),
            nn.GELU(),
            nn.Linear(128, self.trans_dim),
        )

        dpr = [x.item() for x in torch.linspace(0, 0.1, self.depth)]
        self.blocks = FusionTransformer(
            embed_dim=self.trans_dim,
            prompt_dim=self.prompt_dim,
            depth=self.depth,
            drop_path_rate=dpr,
            num_heads=self.num_heads,
            cfg_ratio=self.cfg_ratio
        )

        self.norm = nn.LayerNorm(self.trans_dim)
        self.apply(self._init_weights)
        trunc_normal_(self.mask_token, std=.02)

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights / zero bias for Linear and Conv1d; identity LayerNorm."""
        if isinstance(m, (nn.Linear, nn.Conv1d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, prompt_features, encoded_voxel):
        """Predict codebook logits for a randomly masked token grid.

        Args:
            prompt_features: conditioning passed to ``self.blocks``.
            encoded_voxel: ``(B, C, R, R, R)`` token grid; ``R`` must equal
                ``self.codebook_resolution`` so positions line up with ``self.grid``.
                It is not modified.

        Returns:
            ``(logits, mask)``: ``(B, R**3, codebook_num)`` logits and the
            ``(B, R**3)`` bool mask (``True`` = replaced by the mask token).
        """
        B, C, R, _, _ = encoded_voxel.shape
        tokens = encoded_voxel.reshape(B, C, R * R * R).transpose(1, 2)  # B, N, C
        mask = random_mask(tokens)
        # Out-of-place so the caller's tensor (tokens is a view of it) is untouched.
        tokens = torch.where(mask.unsqueeze(-1), self.mask_token.to(tokens.dtype), tokens)

        pos = self.pos_embed(self.grid)
        x = self.blocks(tokens, pos, prompt_features)
        x = self.norm(x)

        logits = self.projection(x)

        return logits, mask

    def _temperature_grid(self):
        """Per-position temperature scale ``(1, R**3, 1)``: 1.01 at the origin down to 0.01 at the corners."""
        return (1 - torch.mean(self.grid ** 2, dim=-1) + 1e-2).reshape(1, -1, 1)

    def _generation_pass(self, tokens, pos_full, prompt_features, codebooks, temperature):
        """Run one transformer pass and sample a code for every position.

        Returns:
            ``(prob, quant_embeding, score)``: the tempered softmax
            ``(B, N, codebook_num)``, ``codebooks`` applied to the sampled indices,
            and the probability of each sampled code ``(B, N)``.
        """
        x = self.norm(self.blocks(tokens, pos_full, prompt_features))
        logits = self.projection(x)
        prob = F.softmax(logits / temperature, dim=-1)  # B, N, codebook_num
        index = self.probabilistic_select(prob)  # B, N, 1
        quant_embeding = codebooks(index.squeeze(dim=-1))  # B, N, codebook_dim
        score = torch.gather(prob, 2, index.long()).squeeze(dim=-1)  # B, N
        return prob, quant_embeding, score

    @staticmethod
    def _scatter_rows(canvas, source, indices):
        """Copy ``source[b, indices[b, j]]`` into ``canvas[b, indices[b, j]]`` in place; return ``canvas``.

        ``canvas`` and ``source`` are ``(B, N, D)``; ``indices`` is ``(B, K)`` int64.
        """
        index = indices.unsqueeze(-1).expand(-1, -1, canvas.shape[-1])
        return canvas.scatter_(1, index, torch.gather(source, 1, index).to(canvas.dtype))

    def _codes_to_features(self, prob, codebooks):
        """Embed the most probable code per position as ``(B, codebook_dim, R, R, R)`` features."""
        B = prob.shape[0]
        R = self.codebook_resolution
        _, top1 = torch.topk(prob, k=1, dim=-1)
        features = codebooks(top1.squeeze(dim=-1)).transpose(1, 2)
        return features.reshape(B, self.trans_dim, R, R, R)

    def inference(self, prompt_features, codebooks, steps):
        """Generate a complete token grid in ``steps`` passes.

        At pass ``s`` the ``int((1 - cos((s + 1) / steps * pi / 2)) * R**3)``
        highest-scoring positions are committed; positions committed at the
        previous pass get score 1.0 (the maximum probability) so they rank at
        the top. All other positions are reset to the mask token.

        Args:
            prompt_features: ``(B, D)`` conditioning.
            codebooks: callable mapping ``(B, N)`` code indices to
                ``(B, N, codebook_dim)`` embeddings (callers pass the VQGAN
                codebook embedding).
            steps: number of passes.

        Returns:
            ``(B, codebook_dim, R, R, R)`` features from the last pass's argmax codes.
        """
        B, _ = prompt_features.shape
        R = self.codebook_resolution
        N = R * R * R
        device = next(self.parameters()).device

        tokens = self.mask_token.expand(B, N, -1).clone()
        pos_full = self.pos_embed(self.grid)
        prob = torch.zeros(B, N, self.codebook_num, device=device)  # only used if steps == 0
        temp_grid = self._temperature_grid()
        topk_indices = None

        for step in range(steps):
            mask_ratio = float(np.cos((step + 1) / steps * np.pi * 0.5))
            fix_len = int((1 - mask_ratio) * R * R * R)

            temperature = temp_grid * self.temperature * (1 - step / steps)
            prob, quant_embeding, score = self._generation_pass(
                tokens, pos_full, prompt_features, codebooks, temperature)

            if topk_indices is not None:
                score.scatter_(1, topk_indices, 1.0)
            _, topk_indices = torch.topk(score, k=fix_len, dim=-1)  # B, fix_len

            tokens = self._scatter_rows(self.mask_token.expand(B, N, -1).clone(), quant_embeding, topk_indices)

        return self._codes_to_features(prob, codebooks)

    def part_generation(self, tokens, mask, prompt_features, codebooks, steps):
        """Complete a partially known token grid.

        Positions where ``mask`` is ``False`` are known: their score is forced to
        1.0 every pass so they are always committed. Masked positions are filled
        following the same cosine schedule as ``inference``.

        Args:
            tokens: ``(B, C, R, R, R)`` channel-first features (the layout
                ``forward`` consumes) or already flattened ``(B, R**3, C)``.
            mask: ``(B, R**3)`` bool, ``True`` = position to generate. The masked
                count is read from ``mask[0]`` and assumed equal for all samples.
            prompt_features: ``(B, D)`` conditioning.
            codebooks: callable mapping code indices to embeddings.
            steps: number of passes.

        Returns:
            ``(B, codebook_dim, R, R, R)`` features from the last pass's argmax codes.
        """
        B, _ = prompt_features.shape
        R = self.codebook_resolution
        N = R * R * R
        device = next(self.parameters()).device

        mask = mask.to(device=device, dtype=torch.bool)
        num_masked = int(mask[0].sum())
        part_num = N - num_masked

        if tokens.dim() == 5:
            # Flatten the grid and move channels last; a raw reshape of a
            # channel-first tensor would interleave channels and positions.
            tokens = tokens.reshape(B, tokens.shape[1], N).transpose(1, 2)
        else:
            tokens = tokens.reshape(B, N, -1)
        # The first pass sees known tokens plus the mask token at masked
        # positions, the same kind of input forward() trains on.
        tokens = torch.where(mask.unsqueeze(-1), self.mask_token.to(tokens.dtype), tokens)

        pos_full = self.pos_embed(self.grid)
        prob = torch.zeros(B, N, self.codebook_num, device=device)  # only used if steps == 0
        temp_grid = self._temperature_grid()
        topk_indices = None

        for step in range(steps):
            mask_ratio = float(np.cos((step + 1) / steps * np.pi * 0.5))
            fix_len = int((1 - mask_ratio) * num_masked) + part_num

            temperature = temp_grid * self.temperature * (1 - step / steps)
            prob, quant_embeding, score = self._generation_pass(
                tokens, pos_full, prompt_features, codebooks, temperature)

            score[~mask] = 1.0  # known positions are always committed
            if topk_indices is not None:
                score.scatter_(1, topk_indices, 1.0)
            _, topk_indices = torch.topk(score, k=fix_len, dim=-1)  # B, fix_len

            tokens = self._scatter_rows(self.mask_token.expand(B, N, -1).clone(), quant_embeding, topk_indices)

        return self._codes_to_features(prob, codebooks)

    def probabilistic_select(self, logits, top_k=50):
        """Sample one code per position from its ``top_k`` most probable codes.

        Args:
            logits: ``(B, N, K)`` non-negative weights (callers pass softmax
                probabilities); rows need not sum to 1.
            top_k: candidate pool size, clamped to ``K`` so small codebooks work.

        Returns:
            ``(B, N, 1)`` int64 code indices.
        """
        B, N, K = logits.shape
        probs = logits.reshape(B * N, K)
        topk_probs, topk_indices = torch.topk(probs, k=min(top_k, K), dim=1)
        index = torch.multinomial(topk_probs, num_samples=1)
        select = torch.gather(topk_indices, dim=1, index=index)
        return select.reshape(B, N, 1)


class PointUpsampler(nn.Module):
    """Masked-token transformer over point-group features.

    Training (``forward``): random groups get a learned mask token and every
    token is re-predicted from the visible tokens plus positional embeddings of
    the group centres. Inference: every token starts as the mask token.
    """

    def __init__(self, config, **kwargs):
        super().__init__()
        self.config = config
        # define the transformer argparse
        self.trans_dim = config.trans_dim
        self.depth = config.depth
        self.num_heads = config.num_heads
        print_log(f'[args] {config}', logger='TokenGenerator')

        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
        self.pos_embed = nn.Sequential(
            nn.Linear(3, 128),
            nn.GELU(),
            nn.Linear(128, self.trans_dim),
        )

        dpr = [x.item() for x in torch.linspace(0, 0.1, self.depth)]
        self.blocks = TransformerEncoder(
            embed_dim=self.trans_dim,
            depth=self.depth,
            drop_path_rate=dpr,
            num_heads=self.num_heads,
        )

        self.norm = nn.LayerNorm(self.trans_dim)
        self.apply(self._init_weights)
        trunc_normal_(self.mask_token, std=.02)

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights / zero bias for Linear and Conv1d; identity LayerNorm."""
        if isinstance(m, (nn.Linear, nn.Conv1d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @staticmethod
    def _apply_mask(features, mask, mask_token):
        """Return ``features`` with ``mask`` positions replaced by ``mask_token``; the input is not modified."""
        return torch.where(mask.unsqueeze(-1), mask_token.to(features.dtype), features)

    def forward(self, features, center):
        """Re-predict group tokens from a randomly masked copy of ``features``.

        Args:
            features: ``(B, G, trans_dim)`` group features. Left unmodified, so
                callers can use it as the reconstruction target.
            center: ``(B, G, 3)`` group centres.

        Returns:
            ``(x, bool_masked_pos)``: normalised predictions for every group and
            the ``(B, G)`` bool mask (``True`` = masked).
        """
        bool_masked_pos = random_mask(center)
        x = self._apply_mask(features, bool_masked_pos, self.mask_token)
        pos = self.pos_embed(center)

        # transformer
        x = self.blocks(x, pos)
        x = self.norm(x)

        return x, bool_masked_pos

    def inference(self, center):
        """Predict ``(B, G, trans_dim)`` tokens from ``(B, G, 3)`` centres alone."""
        B, G, _ = center.shape
        x_full = self.mask_token.expand(B, G, -1)
        pos_full = self.pos_embed(center)
        # transformer
        x = self.blocks(x_full, pos_full)
        x = self.norm(x)

        return x


class Decoder(nn.Module):
    """Transformer decoder turning per-group tokens into flattened point patches.

    The head is a 1x1 ``Conv1d`` with ``channel * group_size`` outputs, where
    ``channel`` is 6 when ``config.with_color`` is set and 3 otherwise.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.trans_dim = config.point_config.trans_dim
        self.group_size = config.group_size
        self.with_color = config.with_color
        self.channel = 6 if self.with_color else 3
        self.decoder_depth = config.decoder_depth
        self.decoder_num_heads = config.decoder_num_heads

        # Sub-modules are built in a fixed order (pos-embed, blocks, head);
        # changing it would change default initialisation under a fixed seed.
        self.decoder_pos_embed = nn.Sequential(
            nn.Linear(3, 128),
            nn.GELU(),
            nn.Linear(128, self.trans_dim),
        )
        drop_path_rates = torch.linspace(0, 0.1, self.decoder_depth).tolist()
        self.blocks = TransformerDecoder(
            embed_dim=self.trans_dim,
            depth=self.decoder_depth,
            drop_path_rate=drop_path_rates,
            num_heads=self.decoder_num_heads,
        )
        # prediction head
        self.increase_dim = nn.Conv1d(self.trans_dim, self.channel * self.group_size, 1)

    def forward(self, center, token):
        """Decode ``token`` conditioned on positional embeddings of ``center``.

        Returns the head output with the feature axis last, i.e.
        ``channel * group_size`` values per token (assumes ``self.blocks``
        returns ``(B, G, trans_dim)`` — TODO confirm).
        """
        positional = self.decoder_pos_embed(center)
        hidden = self.blocks(token, positional)
        # Conv1d operates on (B, C, L): move features to dim 1 and back.
        channels_first = hidden.transpose(1, 2)
        return self.increase_dim(channels_first).transpose(1, 2)


class HomogeneityLoss(nn.Module):
    """Penalise non-uniform point spacing.

    For each point the mean distance to its ``k`` nearest neighbours (the
    closest match, normally the point itself, is skipped) is turned into a
    distribution over points with a softmax. The loss is ``KL(uniform || p)``
    per sample, averaged over the batch.
    """

    def __init__(self, k, **kwargs):
        super().__init__()
        self.k = k

    def forward(self, center):
        """Compute the loss for ``(B, N, D)`` points (requires ``N > k``).

        Returns:
            Scalar tensor on ``center.device``.
        """
        _, N, _ = center.shape
        dist_matrix = torch.cdist(center, center)  # B, N, N
        _, indices = torch.topk(dist_matrix, k=self.k + 1, dim=-1, largest=False)
        neighbor_indices = indices[..., 1:]  # drop the nearest match (the point itself)
        mean_dists = torch.gather(dist_matrix, dim=-1, index=neighbor_indices).mean(dim=-1)  # B, N

        # log_softmax is the numerically stable form of softmax(...).log()
        log_p = F.log_softmax(mean_dists, dim=-1)
        q = torch.full_like(log_p, 1.0 / N)
        # kl_div(input=log p, target=q) sums q * (log q - log p) = KL(q || p)
        loss = F.kl_div(log_p, q, reduction='none').sum(dim=-1)  # B
        return loss.mean()


@MODELS.register_module()
class MaskDream(nn.Module):
    """Training wrapper; ``config.mode`` selects the stage being trained.

    * ``'coarse'``: frozen text/image encoders and VQGAN; trains ``VoxelGenerator``.
    * ``'smoother'``: trains ``GridSmoother``.
    * ``'fine'``: frozen ``PointEncoder``; trains ``PointUpsampler`` and ``Decoder``.

    Raises:
        NotImplementedError: for any other mode.
    """

    def __init__(self, config):
        super().__init__()
        print_log(f'[MaskDream] ', logger='MaskDream')
        self.config = config
        self.mode = config.mode
        self.with_color = config.get('with_color', False)
        print_log(f'Training with color: {self.with_color}', logger='MaskDream')

        if self.mode == 'coarse':
            # load condition encoder (frozen)
            self.text_encoder = TextEncoder(config=config)
            self.image_encoder = ImageEncoder(config=config)
            for p in self.text_encoder.parameters():
                p.requires_grad = False
            for p in self.image_encoder.parameters():
                p.requires_grad = False

            # load VQGAN (frozen)
            config.vqgan_config.with_color = self.with_color
            self.vqgan = VQGAN(config=config.vqgan_config)
            self.vqgan.load_model_from_ckpt(ckpt_path=config.vqgan_config.ckpt_path)
            for p in self.vqgan.parameters():
                p.requires_grad = False

            # load voxel_generator
            config.voxel_config.prompt_dim = self.text_encoder.embed_dim
            config.voxel_config['vqgan_config'] = config.vqgan_config
            self.voxel_generator = VoxelGenerator(config=config.voxel_config)

            # loss
            self.rec_codebook_loss = nn.CrossEntropyLoss()

        elif self.mode == 'smoother':
            # load grid_smoother
            self.grid_smoother = GridSmoother(config=config)

        elif self.mode == 'fine':
            # load point_encoder (frozen)
            config.encoder_config.num_group = config.num_group
            config.encoder_config.group_size = config.group_size
            self.encoder = PointEncoder(config=config.encoder_config)
            self.encoder.load_model_from_ckpt(ckpt_path=config.encoder_config.ckpt_path)
            for p in self.encoder.parameters():
                p.requires_grad = False

            # load point_upsampler
            self.group_size = config.point_config.group_size = config.group_size
            self.num_group = config.point_config.num_group = config.num_group
            config.point_config.with_color = self.with_color
            self.point_upsampler = PointUpsampler(config=config.point_config)

            # load decoder
            self.decoder = Decoder(config=config)

            # loss
            self.rec_token_loss = nn.SmoothL1Loss()
            self.rec_point_loss = ChamferDistanceL2().cuda()
        else:
            raise NotImplementedError

    def training_voxel_generator(self, pts, image, text, **kwargs):
        """Masked codebook-prediction loss for the 'coarse' stage.

        Returns:
            ``(rec_codebook_loss, voxel, decoded_voxel)``: cross-entropy on the
            masked positions, the VQGAN voxelization of ``pts``, and the VQGAN
            decoding of the arg-max predicted codes.
        """
        voxel, encoded_voxel, codebook_indices, _ = self.vqgan.encode(pts)

        text_features = self.text_encoder(text)
        image_features = self.image_encoder(image)
        # fixed image/text blend of the conditioning
        prompt_features = 0.7 * image_features + 0.3 * text_features

        logits, mask = self.voxel_generator(prompt_features, encoded_voxel)
        B, N = mask.shape
        codebook_indices = codebook_indices.reshape(B, N)
        rec_codebook_loss = self.rec_codebook_loss(logits[mask], codebook_indices[mask])

        _, pred_encoding_indices = torch.topk(logits, k=1, dim=-1)
        pred_encoding_indices = pred_encoding_indices.squeeze(dim=-1)
        z_q = self.vqgan.codebook.embedding(pred_encoding_indices)
        z_q = z_q.transpose(1, 2)
        z_q = z_q.view(encoded_voxel.shape)
        decoded_voxel = self.vqgan.decode(z_q)

        return rec_codebook_loss, voxel, decoded_voxel

    def training_grid_smoother(self, pts, **kwargs):
        """Chamfer loss of the grid smoother on the xyz channels of ``pts``."""
        grid_smoother_loss = self.grid_smoother(pts[:, :, :3])
        return grid_smoother_loss

    def training_point_upsampler(self, pts, **kwargs):
        """'fine' stage loss: SmoothL1 on masked tokens + Chamfer-L2 on decoded patches."""
        point_features, neighborhood, center = self.encoder(pts[:, :, :3])
        B, M, _ = center.shape

        # Give the upsampler its own copy so masking can never overwrite the
        # reconstruction target used below.
        point_features_rec, mask = self.point_upsampler(point_features.clone(), center)
        rec_token_loss = self.rec_token_loss(point_features_rec[mask], point_features[mask])

        rebuild_points = self.decoder(center, point_features_rec).reshape(B * M, -1, self.decoder.channel)
        gt_points = neighborhood.reshape(B * M, -1, self.decoder.channel)
        rec_point_loss = self.rec_point_loss(rebuild_points, gt_points)

        return rec_token_loss + rec_point_loss


@MODELS.register_module()
class MaskDreamInference(nn.Module):
    """Frozen end-to-end generator: prompt -> token grid -> voxels -> points.

    ``forward``:
    1. ``VoxelGenerator.inference`` produces codebook features.
    2. The VQGAN decodes them to voxels.
    3. Voxels become coarse centres (resampled to ``num_group`` when ``B > 1``).
    4. ``GridSmoother`` refines the centres.
    5. ``PointUpsampler`` + ``Decoder`` emit ``group_size`` points per centre.
    """

    def __init__(self, config):
        super().__init__()
        print_log(f'[MaskDreamInference] ', logger='MaskDreamInference')
        self.config = config
        self.group_size = config.get('group_size', 32)
        self.npoints = config.get('npoints', 4096)
        self.num_group = int(self.npoints / self.group_size)
        self.steps = config.get('steps', 8)
        self.with_color = config.with_color
        self.channel = 6 if self.with_color else 3

        # load text and image encoders
        if 'prompt_encoder' in config.keys():
            self.text_encoder = TextEncoder(config=config)
            self.image_encoder = ImageEncoder(config=config)
            self.prompt_dim = self.text_encoder.embed_dim

        # load vqgan and voxel_generator
        # NOTE(review): relies on self.prompt_dim, i.e. 'prompt_encoder' must also be configured.
        if 'voxel_config' in config.keys():
            config.vqgan_config.with_color = self.with_color
            self.vqgan = VQGAN(config=config.vqgan_config)
            config.voxel_config.prompt_dim = self.prompt_dim
            config.voxel_config['vqgan_config'] = config.vqgan_config
            self.voxel_generator = VoxelGenerator(config=config.voxel_config)

        # load grid_smoother
        if 'smooth_config' in config.keys():
            config.smooth_config.resolution = config.vqgan_config.resolution
            self.grid_smoother = GridSmoother(config=config.smooth_config)

        # load point upsampler and decoder
        if 'point_config' in config.keys():
            config.point_config.with_color = self.with_color
            self.point_upsampler = PointUpsampler(config=config.point_config)
            self.decoder = Decoder(config=config)

        self.load_model_from_ckpt()
        for p in self.parameters():
            p.requires_grad = False

    def load_model_from_ckpt(self):
        """Merge the per-stage checkpoints named in the config and load them strictly.

        Each configured stage contributes ``ckpt['base_model']`` with every
        ``"module."`` substring removed from the keys.

        Raises:
            ValueError: if a key of ``self.state_dict()`` is absent from the merged checkpoints.
        """
        ckpt = {}
        stages = (
            ('voxel_config', 'voxel_generator'),
            ('smooth_config', 'grid_smoother'),
            ('point_config', 'point_upsampler'),
        )
        for section, label in stages:
            if section not in self.config.keys():
                continue
            # torch.load unpickles its input: ckpt_path must point at trusted files.
            stage_ckpt = torch.load(getattr(self.config, section).ckpt_path, map_location='cpu')
            ckpt.update({k.replace("module.", ""): v for k, v in stage_ckpt['base_model'].items()})
            print_log(f'[MaskDreamInference] Loading {label}...', logger='MaskDreamInference')

        state_dict = self.state_dict()
        for key in state_dict.keys():
            if key not in ckpt:
                raise ValueError(f"missing ckpt keys: {key}")
            state_dict[key] = ckpt[key]
        self.load_state_dict(state_dict, strict=True)

        print_log(f'[MaskDreamInference] Successful Loading all the ckpt', logger='MaskDreamInference')

    def upsample(self, center, **kwargs):
        """Decode ``group_size`` points around every centre.

        Args:
            center: ``(B, M, 3)`` group centres.

        Returns:
            ``(B, M * group_size, channel)`` points. The xyz channels are offset
            by their centre; any colour channels are left as decoded.
        """
        B, M, _ = center.shape
        point_features = self.point_upsampler.inference(center)
        rebuild_points = self.decoder(center, point_features).reshape(B, M, -1, self.channel)
        rebuild_points[:, :, :, :3] = rebuild_points[:, :, :, :3] + center.unsqueeze(dim=2)
        rebuild_points = rebuild_points.reshape(B, -1, self.channel)

        return rebuild_points

    def _resample_points(self, points, device=None):
        """Bring ``(P, C)`` points to exactly ``num_group`` rows.

        Uses ``misc.fps`` when ``P >= num_group`` and linear interpolation along
        the point axis when ``0 < P < num_group``. When ``P == 0`` it returns
        zeros of shape ``(num_group, channel)``.
        """
        num_points = points.shape[0]
        if num_points >= self.num_group:
            return misc.fps(points.unsqueeze(0), self.num_group).squeeze(0)
        if num_points > 0:
            # interpolate wants (1, C, length): transpose (not reshape) so each
            # channel stays one series over the points.
            series = F.interpolate(points.transpose(0, 1).unsqueeze(0), size=self.num_group, mode='linear')
            return series.squeeze(0).transpose(0, 1)
        return torch.zeros(self.num_group, self.channel, device=points.device if device is None else device)

    def _voxels_to_centers(self, decoded_voxel):
        """Convert each decoded voxel grid into ``(num_group, channel)`` centres; returns ``(B, num_group, channel)``."""
        B = decoded_voxel.shape[0]
        grid_center = torch.zeros(B, self.num_group, self.channel, device=decoded_voxel.device)
        for i in range(B):
            points = voxel_to_point(decoded_voxel[i]).contiguous()
            grid_center[i] = self._resample_points(points, device=decoded_voxel.device)
        return grid_center

    def part_generation(self, pts, mask_axis=0, mask_direct=True, mask_pos=0, text="a 3d model."):
        """Regenerate one side of a shape, conditioned on the rest and on ``text``.

        Args:
            pts: ``(B, N, C)`` input points (``C`` should equal ``self.channel``).
            mask_axis: coordinate axis along which the shape is split.
            mask_direct: ``True`` regenerates the side with coordinate
                ``> mask_pos``, ``False`` the side ``< mask_pos``.
            mask_pos: split position along ``mask_axis``.
            text: prompt for the text encoder.

        Returns:
            ``(part_points, rebuild_points)``: the kept input points, and those
            points concatenated with the generated points on the regenerated side.
            NOTE(review): boolean selections are reshaped to ``(B, -1, C)``, so
            every sample must select the same number of points (safe for B == 1).
        """
        B, N, C = pts.shape
        voxel = self.vqgan.encoder.voxelization(pts)
        _, _, R, _, _ = voxel.shape
        voxel_channel = 4 if self.with_color else 1
        direction = 1 if mask_direct else -1

        num_voxel = R * R * R
        # NOTE(review): this fraction is always taken from the high-index end of
        # mask_axis whatever the direction; confirm it matches the point-space split below.
        num_mask = int(R * R * R * (1 + mask_pos * direction) / 2)

        mask = np.hstack([
            np.zeros(num_voxel - num_mask),
            np.ones(num_mask),
        ])
        mask = torch.from_numpy(mask).to(torch.bool).reshape(R, R, R)
        mask = mask.transpose(0, mask_axis).to(voxel.device)
        mask = mask.reshape(1, 1, R, R, R).expand(B, voxel_channel, R, R, R)
        voxel[mask] = 0

        down_sample = self.vqgan.config.down_sample
        mask = mask[:, 0, ::down_sample, ::down_sample, ::down_sample].reshape(B, -1)

        text_features = self.text_encoder(text)
        codebooks = self.vqgan.codebook.embedding
        token = self.vqgan.encoder.blocks(voxel)

        features = self.voxel_generator.part_generation(token, mask, text_features, codebooks, self.steps)
        decoded_voxel = self.vqgan.decode(features)

        coordinate = self._voxels_to_centers(decoded_voxel)[:, :, :3]
        center = self.grid_smoother.inference(coordinate)
        rebuild_points = self.upsample(center)

        # Each point set is split by its own coordinates: rebuild_points has a
        # different number and order of points than pts.
        if mask_direct:
            mask_index = pts[:, :, mask_axis] > mask_pos
            gen_index = rebuild_points[:, :, mask_axis] > mask_pos
        else:
            mask_index = pts[:, :, mask_axis] < mask_pos
            gen_index = rebuild_points[:, :, mask_axis] < mask_pos
        part_points = pts[~mask_index].reshape(B, -1, C)
        pre_points = rebuild_points[gen_index].reshape(B, -1, rebuild_points.shape[-1])
        rebuild_points = torch.cat([part_points, pre_points], dim=1)

        return part_points, rebuild_points

    def forward(self, prompt_features, **kwargs):
        """Generate point clouds from ``(B, D)`` prompt features.

        Returns:
            ``(B, P, channel)`` points. When ``B == 1`` every voxel point becomes
            a centre, so ``P`` depends on the decoded shape. Otherwise
            ``P = num_group * group_size``.
        """
        B, _ = prompt_features.shape

        codebooks = self.vqgan.codebook.embedding
        quant_embeding = self.voxel_generator.inference(prompt_features, codebooks, self.steps)
        decoded_voxel = self.vqgan.decode(quant_embeding)

        if B == 1:
            grid_center = voxel_to_point(decoded_voxel[0]).contiguous().unsqueeze(0)
        else:
            grid_center = self._voxels_to_centers(decoded_voxel)

        coordinate = grid_center[:, :, :3]
        center = self.grid_smoother.inference(coordinate)
        return self.upsample(center)

    def text_condition_generation(self, text):
        """Generate points conditioned on ``text`` via the text encoder."""
        text_features = self.text_encoder(text)
        return self.forward(text_features)

    def image_condition_generation(self, img):
        """Generate points conditioned on ``img`` via the image encoder."""
        img_features = self.image_encoder(img)
        return self.forward(img_features)
