import torch
import torch.nn as nn
import torch_scatter
import torch_cluster
import torch.nn.functional as F

from pointcept.models.losses import build_criteria
from pointcept.models.utils.structure import Point
from pointcept.models.utils import offset2batch
from .builder import MODELS, build_model
from collections import OrderedDict


@MODELS.register_module()
class DefaultSegmentor(nn.Module):
    """Segmentor that treats the backbone output directly as segmentation logits.

    Args:
        backbone: config handed to ``build_model`` to create the backbone.
        criteria: config handed to ``build_criteria`` to create the loss.
    """

    def __init__(self, backbone=None, criteria=None):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)

    def forward(self, input_dict):
        """Run the backbone and return loss and/or logits depending on the mode.

        Returns a dict with:
            * ``loss`` only, while training;
            * ``loss`` and ``seg_logits``, when not training but ``"segment"``
              is present in ``input_dict``;
            * ``seg_logits`` only, otherwise.

        Note that ``input_dict["condition"]`` is overwritten in place with its
        first element when the key is present.
        """
        if "condition" in input_dict:
            # currently, only support one batch one condition
            input_dict["condition"] = input_dict["condition"][0]
        seg_logits = self.backbone(input_dict)

        labelled = "segment" in input_dict
        if not (self.training or labelled):
            # test: no ground truth available, logits only
            return dict(seg_logits=seg_logits)

        # train / eval share the loss computation
        output = dict(loss=self.criteria(seg_logits, input_dict["segment"]))
        if not self.training:
            # eval additionally exposes the logits
            output["seg_logits"] = seg_logits
        return output


@MODELS.register_module()
class DefaultSegmentorV2(nn.Module):
    """Backbone plus a single linear segmentation head.

    Args:
        num_classes: output width of the head; a value ``<= 0`` replaces the
            head with ``nn.Identity``.
        backbone_out_channels: input width of the linear head.
        backbone: config handed to ``build_model``.
        criteria: config handed to ``build_criteria``.
        freeze_backbone: when true, backbone parameters get
            ``requires_grad = False`` and the backbone runs without gradient
            tracking.
    """

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
        freeze_backbone=False,
    ):
        super().__init__()
        if num_classes > 0:
            self.seg_head = nn.Linear(backbone_out_channels, num_classes)
        else:
            self.seg_head = nn.Identity()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.freeze_backbone = freeze_backbone
        if self.freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def forward(self, input_dict, return_point=False):
        """Compute segmentation logits and, when labels are usable, the loss.

        Args:
            input_dict: batch dictionary; wrapped into a ``Point`` for the
                backbone. ``"segment"`` is read for the loss.
            return_point: when true, the final ``point`` object is added to the
                result under ``"point"``.

        Returns:
            dict with ``loss`` (training), ``loss`` + ``seg_logits`` (not
            training, ``"segment"`` present) or ``seg_logits`` only.
        """
        point = Point(input_dict)
        # Keep the caller's grad mode unless the backbone is frozen, in which
        # case gradient tracking is switched off for the backbone call.
        track_grad = torch.is_grad_enabled() and not self.freeze_backbone
        with torch.set_grad_enabled(track_grad):
            point = self.backbone(point)

        if isinstance(point, Point):
            # Walk up the pooling chain: gather the current features through
            # the stored inverse index and append them to the parent features.
            while "pooling_parent" in point.keys():
                assert "pooling_inverse" in point.keys()
                child = point
                point = child.pop("pooling_parent")
                gather_index = child.pop("pooling_inverse")
                point.feat = torch.cat([point.feat, child.feat[gather_index]], dim=-1)
            feat = point.feat
        else:
            feat = point
        seg_logits = self.seg_head(feat)

        return_dict = dict()
        if return_point:
            # PCA evaluator parse feat and coord in point
            return_dict["point"] = point

        if self.training or "segment" in input_dict.keys():
            # train / eval
            return_dict["loss"] = self.criteria(seg_logits, input_dict["segment"])
        if not self.training:
            # eval / test
            return_dict["seg_logits"] = seg_logits
        return return_dict


@MODELS.register_module()  # MLP
class DefaultSegmentorV3(nn.Module):
    """Backbone plus a multi-layer perceptron segmentation head.

    Each hidden stage is ``Linear -> activation -> BatchNorm1d -> Dropout``
    (dropout only when ``dropout > 0``), followed by a final ``Linear`` to
    ``num_classes``.

    Args:
        num_classes: output width of the head; a value ``<= 0`` replaces the
            head with ``nn.Identity``.
        backbone_out_channels: input width of the first linear layer.
        backbone: config handed to ``build_model``.
        criteria: config handed to ``build_criteria``.
        freeze_backbone: when true, backbone parameters get
            ``requires_grad = False`` and the backbone runs under
            ``torch.no_grad()``.
        hidden_dims: widths of the hidden layers, in order. The default is an
            immutable tuple so it cannot be shared and mutated across calls.
        dropout: dropout probability; no dropout layer is added when ``<= 0``.
        activation: one of ``"relu"``, ``"gelu"``, ``"leakyrelu"``.
            NOTE(review): any other value silently adds no activation layer;
            confirm whether that is intended before turning it into an error.
    """

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
        freeze_backbone=False,
        hidden_dims=(512, 256),
        dropout=0.1,
        activation="relu",
    ):
        super().__init__()

        if num_classes > 0:
            layers = []
            in_dim = backbone_out_channels

            for hidden_dim in hidden_dims:
                layers.append(nn.Linear(in_dim, hidden_dim))

                if activation == "relu":
                    layers.append(nn.ReLU())
                elif activation == "gelu":
                    layers.append(nn.GELU())
                elif activation == "leakyrelu":
                    layers.append(nn.LeakyReLU(0.2))

                layers.append(nn.BatchNorm1d(hidden_dim))

                if dropout > 0:
                    layers.append(nn.Dropout(dropout))

                in_dim = hidden_dim

            layers.append(nn.Linear(in_dim, num_classes))
            self.seg_head = nn.Sequential(*layers)
        else:
            self.seg_head = nn.Identity()

        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.freeze_backbone = freeze_backbone

        if self.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def forward(self, input_dict, return_point=False):
        """Compute segmentation logits and, when labels are usable, the loss.

        Args:
            input_dict: batch dictionary; wrapped into a ``Point`` for the
                backbone. ``"segment"`` is read for the loss.
            return_point: when true, the final ``point`` object is added to the
                result under ``"point"``.

        Returns:
            dict with ``loss`` (training), ``loss`` + ``seg_logits`` (not
            training, ``"segment"`` present) or ``seg_logits`` only.
        """
        point = Point(input_dict)

        if self.freeze_backbone:
            with torch.no_grad():
                point = self.backbone(point)
        else:
            point = self.backbone(point)

        if isinstance(point, Point):
            # Walk up the pooling chain, appending the gathered child features
            # to each parent's features.
            while "pooling_parent" in point.keys():
                assert "pooling_inverse" in point.keys()
                parent = point.pop("pooling_parent")
                inverse = point.pop("pooling_inverse")
                parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
                point = parent
            feat = point.feat
        else:
            feat = point

        seg_logits = self.seg_head(feat)

        return_dict = dict()
        if return_point:
            return_dict["point"] = point

        if self.training:
            # train
            return_dict["loss"] = self.criteria(seg_logits, input_dict["segment"])
        elif "segment" in input_dict.keys():
            # eval
            return_dict["loss"] = self.criteria(seg_logits, input_dict["segment"])
            return_dict["seg_logits"] = seg_logits
        else:
            # test
            return_dict["seg_logits"] = seg_logits

        return return_dict


from peft import LoraConfig, get_peft_model


@MODELS.register_module()
class DefaultLORASegmentorV2(nn.Module):
    """Linear-head segmentor whose backbone encoder can be wrapped with LoRA.

    Args:
        num_classes: output width of the head; a value ``<= 0`` replaces the
            head with ``nn.Identity``.
        backbone_out_channels: input width of the linear head.
        backbone: config handed to ``build_model``.
        criteria: config handed to ``build_criteria``.
        freeze_backbone: when true, all backbone parameters get
            ``requires_grad = False`` (LoRA parameters are re-enabled
            afterwards when ``use_lora`` is set).
        use_lora: when true, ``self.backbone.enc`` is wrapped with
            ``get_peft_model`` targeting modules named ``"qkv"``.
        lora_r, lora_alpha, lora_dropout: forwarded to ``LoraConfig``.
        backbone_path: checkpoint file with a ``"state_dict"`` entry. When
            ``None`` no pretrained weights are loaded.
        keywords: optional substring of checkpoint keys to rewrite.
        replacements: replacement text for ``keywords``.
    """

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
        freeze_backbone=False,
        use_lora=False,
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        backbone_path=None,
        keywords=None,
        replacements=None,
    ):
        super().__init__()
        self.seg_head = (
            nn.Linear(backbone_out_channels, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.keywords = keywords
        self.replacements = replacements
        self.backbone = build_model(backbone)
        if backbone_path is not None:
            # Deserialize on CPU: load_state_dict copies the values into the
            # already-allocated parameters, so no GPU is required here.
            backbone_weight = torch.load(backbone_path, map_location="cpu")
            self.backbone_load(backbone_weight)

        self.criteria = build_criteria(criteria)
        self.freeze_backbone = freeze_backbone
        self.use_lora = use_lora

        if self.use_lora:
            lora_config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=["qkv"],
                lora_dropout=lora_dropout,
                bias="none",
            )
            self.backbone.enc = get_peft_model(self.backbone.enc, lora_config)

        if self.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False
        if self.use_lora:
            # Re-enable the adapter weights that the freeze above switched off.
            for name, param in self.backbone.named_parameters():
                if "lora_" in name:
                    param.requires_grad = True
            # Only a PEFT-wrapped encoder provides this method.
            self.backbone.enc.print_trainable_parameters()

    def backbone_load(self, checkpoint):
        """Load ``checkpoint["state_dict"]`` into the backbone (non-strict).

        Keys are normalized to carry a ``module.`` prefix, optionally rewritten
        via ``self.keywords`` -> ``self.replacements``, then stripped of the
        ``module.`` and ``backbone.`` prefixes before loading. Missing and
        unexpected keys are printed.
        """
        weight = OrderedDict()
        for key, value in checkpoint["state_dict"].items():
            if not key.startswith("module."):
                key = "module." + key  # xxx.xxx -> module.xxx.xxx
            # Now all keys contain "module." no matter DDP or not.
            if self.keywords is not None and self.keywords in key:
                key = key.replace(self.keywords, self.replacements)
            key = key[7:]  # module.xxx.xxx -> xxx.xxx
            if key.startswith("backbone."):
                key = key[9:]
            weight[key] = value

        load_state_info = self.backbone.load_state_dict(weight, strict=False)
        print(f"Missing keys: {load_state_info[0]}")
        print(f"Unexpected keys: {load_state_info[1]}")

    def forward(self, input_dict, return_point=False):
        """Compute segmentation logits and, when labels are usable, the loss.

        The backbone runs under ``torch.no_grad()`` only when it is frozen and
        LoRA is disabled; with LoRA the adapter weights need gradients.

        Returns:
            dict with ``loss`` (training), ``loss`` + ``seg_logits`` (not
            training, ``"segment"`` present) or ``seg_logits`` only, plus
            ``point`` when ``return_point`` is true.
        """
        point = Point(input_dict)
        if self.freeze_backbone and not self.use_lora:
            with torch.no_grad():
                point = self.backbone(point)
        else:
            point = self.backbone(point)

        if isinstance(point, Point):
            # Walk up the pooling chain, appending the gathered child features
            # to each parent's features.
            while "pooling_parent" in point.keys():
                assert "pooling_inverse" in point.keys()
                parent = point.pop("pooling_parent")
                inverse = point.pop("pooling_inverse")
                parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
                point = parent
            feat = point.feat
        else:
            feat = point

        seg_logits = self.seg_head(feat)
        return_dict = dict()
        if return_point:
            return_dict["point"] = point

        if self.training:
            # train
            return_dict["loss"] = self.criteria(seg_logits, input_dict["segment"])
        elif "segment" in input_dict.keys():
            # eval
            return_dict["loss"] = self.criteria(seg_logits, input_dict["segment"])
            return_dict["seg_logits"] = seg_logits
        else:
            # test
            return_dict["seg_logits"] = seg_logits
        return return_dict


from transformers import AutoTokenizer, AutoModel


@MODELS.register_module()
class SIGLIPZSSegmentorV2(nn.Module):
    """Zero-shot segmentor scoring projected point features against text prompts.

    Args:
        text_model_path: path or identifier passed to
            ``AutoTokenizer.from_pretrained`` / ``AutoModel.from_pretrained``;
            only the loaded model's ``text_model`` is kept.
        backbone_out_channels: input width of the projection layer.
        _num_channels: output width of the projection layer (presumably the
            text embedding width — TODO confirm).
        backbone: config handed to ``build_model``.
        freeze_backbone: when true, backbone parameters get
            ``requires_grad = False`` and the backbone runs under
            ``torch.no_grad()``.
        label_list: class names, one text prompt is built per entry.
    """

    def __init__(
        self,
        text_model_path,
        backbone_out_channels,
        _num_channels,
        backbone=None,
        freeze_backbone=False,
        label_list=(),
    ):
        super().__init__()
        self.label_list = list(label_list)
        self.tokenizer = AutoTokenizer.from_pretrained(text_model_path)
        model = AutoModel.from_pretrained(text_model_path)
        self.text_model = model.text_model
        self.text_model.requires_grad_(False)
        # NOTE(review): the projection is created frozen; presumably its
        # weights are expected to come from a loaded checkpoint — confirm.
        self.patch_proj = torch.nn.Linear(backbone_out_channels, _num_channels)
        self.patch_proj.requires_grad_(False)
        self.backbone = build_model(backbone)
        self.freeze_backbone = freeze_backbone
        if self.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def get_embeddings(self, texts):
        """Return L2-normalized, token-averaged text embeddings for ``texts``."""
        # Follow the text model's device instead of assuming "cuda".
        device = next(self.text_model.parameters()).device
        inputs = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        with torch.no_grad():
            outputs = self.text_model(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        return embeddings

    def forward(self, input_dict, return_point=False):
        """Score every point feature against the prompt embedding of each label.

        Returns:
            dict with ``seg_logits`` and, when ``return_point`` is true,
            ``point``.
        """
        point = Point(input_dict)
        if self.freeze_backbone:
            with torch.no_grad():
                point = self.backbone(point)
        else:
            point = self.backbone(point)
        # Backbone added after v1.5.0 return Point instead of feat and use DefaultSegmentorV2
        # TODO: remove this part after make all backbone return Point only.
        if isinstance(point, Point):
            while "pooling_parent" in point.keys():
                assert "pooling_inverse" in point.keys()
                parent = point.pop("pooling_parent")
                inverse = point.pop("pooling_inverse")
                parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
                point = parent
            feat = point.feat
        else:
            feat = point
        feat = self.patch_proj(feat)
        feat = feat / feat.norm(dim=-1, keepdim=True)
        texts = [f"This is a photo of {label}." for label in self.label_list]
        text_embeddings = self.get_embeddings(texts)
        # One row per feature row, one column per label.
        logits_per_point = feat @ text_embeddings.T
        # NOTE(review): both operands are unit-normalized, so entries can be
        # negative; dividing by the row sum may then flip signs. Kept as in the
        # original — confirm this normalization is intended.
        logits_per_point = logits_per_point / logits_per_point.sum(dim=-1, keepdim=True)
        return_dict = dict()
        if return_point:
            # PCA evaluator parse feat and coord in point
            return_dict["point"] = point
        # test
        return_dict["seg_logits"] = logits_per_point
        return return_dict


@MODELS.register_module()  # text intra align
class SIGLIPSegmentorV2(nn.Module):
    """Segmentor trained to align point features with per-class text embeddings.

    The text embeddings are computed once in ``__init__`` from ``label_list``
    and stored as a plain tensor attribute (not a registered buffer).

    Args:
        text_model_path: path or identifier passed to
            ``AutoTokenizer.from_pretrained`` / ``AutoModel.from_pretrained``;
            only the loaded model's ``text_model`` is kept (moved to CUDA).
        backbone_out_channels: input width of the linear head.
        _num_channels: output width of the linear head (presumably the text
            embedding width — TODO confirm).
        backbone: config handed to ``build_model``.
        freeze_backbone: when true, backbone parameters get
            ``requires_grad = False`` and the backbone runs under
            ``torch.no_grad()``.
        label_list: class names; entry ``i`` provides the embedding used for
            label value ``i``.
        ignore_index: label values excluded from the loss.
        enc_mode: when true, features are gathered by walking
            ``pooling_parent`` links ``up_cast_level`` times; otherwise by
            walking ``unpooling_parent`` links.
        up_cast_level: number of pooling levels unrolled in ``enc_mode``.
    """

    def __init__(
        self,
        text_model_path,
        backbone_out_channels,
        _num_channels,
        backbone=None,
        freeze_backbone=False,
        label_list=(),
        ignore_index=(-1,),
        enc_mode=True,
        up_cast_level=4,
    ):
        super().__init__()
        self.enc_mode = enc_mode
        self.up_cast_level = up_cast_level
        self.seg_head = nn.Linear(backbone_out_channels, _num_channels)
        self.tokenizer = AutoTokenizer.from_pretrained(text_model_path)
        model = AutoModel.from_pretrained(text_model_path)
        self.text_model = model.text_model.cuda()
        self.text_model.requires_grad_(False)
        self.backbone = build_model(backbone)
        self.freeze_backbone = freeze_backbone
        self.label_list = list(label_list)
        self.ignore_index = torch.tensor(ignore_index).cuda()
        # Raw label names are used as prompts (no template).
        self.text_embeddings = self.get_embeddings(list(self.label_list))
        if self.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def up_cast(self, point, normalize=False, upcast_level=None):
        """Unroll ``upcast_level`` pooling levels, concatenating features upward.

        Args:
            point: object exposing ``pooling_parent`` / ``pooling_inverse``.
            normalize: unused; kept for interface compatibility.
            upcast_level: number of levels; defaults to ``self.up_cast_level``.

        Returns:
            The parent reached after the last level, whose ``feat`` holds the
            concatenated features.
        """
        levels = self.up_cast_level if upcast_level is None else upcast_level
        for _ in range(levels):
            assert "pooling_parent" in point.keys()
            assert "pooling_inverse" in point.keys()
            parent = point.pop("pooling_parent")
            inverse = point.pop("pooling_inverse")
            parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
            point = parent
        return point

    def get_embeddings(self, texts):
        """Return token-averaged (un-normalized) text embeddings for ``texts``."""
        # Follow the text model's device rather than a hard-coded string.
        device = next(self.text_model.parameters()).device
        inputs = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        with torch.no_grad():
            outputs = self.text_model(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)
        return embeddings

    def _alignment_loss(self, seg_logits, segment):
        """Mean ``1 - cosine`` between kept rows and their label's text embedding."""
        keep = ~torch.isin(segment, self.ignore_index)
        targets = self.text_embeddings[segment[keep]]
        cosine = F.cosine_similarity(seg_logits[keep], targets, dim=1, eps=1e-6)
        return (1 - cosine).mean()

    def _class_scores(self, seg_logits, chunk_size):
        """Softmax over classes of the cosine similarity to every text embedding.

        Rows are processed ``chunk_size`` at a time to bound the size of the
        broadcast intermediate.
        """
        text = self.text_embeddings.unsqueeze(0)  # [1, n_classes, dim]
        chunks = []
        for start in range(0, seg_logits.size(0), chunk_size):
            chunk = seg_logits[start : start + chunk_size].unsqueeze(1)
            chunks.append(F.cosine_similarity(text, chunk, dim=2))
        similarities = torch.cat(chunks, dim=0)  # [n_rows, n_classes]
        # Normalize across classes (dim=1). The previous dim=0 normalized each
        # class column across points, which can change the per-point argmax.
        return F.softmax(similarities, dim=1)

    def forward(self, input_dict, return_point=False):
        """Return the alignment loss and/or per-class scores.

        Returns:
            dict with ``loss`` (training), ``loss`` + ``seg_logits`` (not
            training, ``"segment"`` present) or ``seg_logits`` only, plus
            ``point`` when ``return_point`` is true.
        """
        point = Point(input_dict)
        if self.freeze_backbone:
            with torch.no_grad():
                point = self.backbone(point)
        else:
            point = self.backbone(point)

        if self.enc_mode:
            point_ = self.up_cast(point)
            feat = point_.feat
        elif isinstance(point, Point):
            inverse_list = []
            point_feat = [point.feat]
            while "unpooling_parent" in point.keys():
                parent = point.pop("unpooling_parent")
                assert "pooling_inverse" in parent.keys()
                inverse = parent.pop("pooling_inverse")
                inverse_list.append(inverse)
                # Apply every inverse collected so far, newest first, to
                # re-index the parent features.
                for inverse_i in inverse_list[::-1]:
                    parent.feat = parent.feat[inverse_i]
                point_feat.append(parent.feat)
                point = parent
            feat = torch.cat(point_feat, dim=-1)
        else:
            feat = point
        seg_logits = self.seg_head(feat)

        return_dict = dict()
        if return_point:
            # PCA evaluator parse feat and coord in point
            return_dict["point"] = point
        if self.training:
            # train
            return_dict["loss"] = self._alignment_loss(
                seg_logits, input_dict["segment"]
            )
        elif "segment" in input_dict.keys():
            # eval
            return_dict["loss"] = self._alignment_loss(
                seg_logits, input_dict["segment"]
            )
            return_dict["seg_logits"] = self._class_scores(seg_logits, 1024)
        else:
            # test: no labels, so nothing label-dependent may be touched here
            return_dict["seg_logits"] = self._class_scores(seg_logits, 10240)
        return return_dict


from pointcept.models.concerto_lseg_pro.models.lseg_net import LSegNet
import clip


@MODELS.register_module()  # text intra align
class CLIPSegmentorV2(nn.Module):
    """Segmentor trained to align point features with CLIP text embeddings.

    The text encoder is taken from an ``LSegNet`` (``clip_pretrained``), and
    the per-class embeddings are computed once in ``__init__`` and stored as a
    plain tensor attribute (not a registered buffer).

    Args:
        text_weight_name: ``backbone`` argument forwarded to ``LSegNet``.
        text_weight_path: checkpoint file holding a ``"state_dict"`` entry.
        backbone_out_channels: input width of the linear head.
        _num_channels: output width of the linear head (presumably the text
            embedding width — TODO confirm).
        backbone: config handed to ``build_model``.
        freeze_backbone: when true, backbone parameters get
            ``requires_grad = False`` and the backbone runs under
            ``torch.no_grad()``.
        labels_list: class names; entry ``i`` provides the embedding used for
            label value ``i``.
        ignore_index: label values excluded from the loss.
        enc_mode: when true, features are gathered by walking
            ``pooling_parent`` links ``up_cast_level`` times; otherwise by
            walking ``unpooling_parent`` links.
        up_cast_level: number of pooling levels unrolled in ``enc_mode``.
    """

    def __init__(
        self,
        text_weight_name,
        text_weight_path,
        backbone_out_channels,
        _num_channels,
        backbone=None,
        freeze_backbone=False,
        labels_list=(),
        ignore_index=(-1,),
        enc_mode=True,
        up_cast_level=4,
    ):
        super().__init__()
        self.labels_list = list(labels_list)
        self.enc_mode = enc_mode
        self.up_cast_level = up_cast_level
        self.seg_head = nn.Linear(backbone_out_channels, _num_channels)
        self.lseg_model = self.load_lseg(text_weight_name, text_weight_path)
        self.lseg_model.requires_grad_(False)
        self.clip_text_encoder = self.lseg_model.clip_pretrained.encode_text
        self.backbone = build_model(backbone)
        self.freeze_backbone = freeze_backbone
        self.ignore_index = torch.tensor(ignore_index).cuda()
        self.CLIPText_forward()
        if self.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    def up_cast(self, point, normalize=False, upcast_level=None):
        """Unroll ``upcast_level`` pooling levels, concatenating features upward.

        Args:
            point: object exposing ``pooling_parent`` / ``pooling_inverse``.
            normalize: unused; kept for interface compatibility.
            upcast_level: number of levels; defaults to ``self.up_cast_level``.

        Returns:
            The parent reached after the last level, whose ``feat`` holds the
            concatenated features.
        """
        levels = self.up_cast_level if upcast_level is None else upcast_level
        for _ in range(levels):
            assert "pooling_parent" in point.keys()
            assert "pooling_inverse" in point.keys()
            parent = point.pop("pooling_parent")
            inverse = point.pop("pooling_inverse")
            parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
            point = parent
        return point

    def load_lseg(
        self,
        model_name,
        model_weight="openclip/open_clip_pytorch_model.bin",
    ):
        """Build an ``LSegNet``, load its checkpoint, and return it in eval mode on CUDA.

        A leading ``net.`` is stripped from checkpoint keys before loading.
        """
        net = LSegNet(
            labels=self.labels_list,
            backbone=model_name,
            features=256,
            crop_size=480,
            arch_option=0,
            block_depth=0,
            activation="lrelu",
        )
        # Deserialize on CPU; the network is moved to CUDA after loading.
        weight = torch.load(str(model_weight), map_location="cpu")["state_dict"]
        prefix = "net."
        for key in list(weight.keys()):
            if key.startswith(prefix):
                # Strip the prefix only; str.replace would also rewrite any
                # later "net." occurring inside the key.
                weight[key[len(prefix) :]] = weight.pop(key)
        net.load_state_dict(weight)
        net.eval()
        net.cuda()
        return net

    def CLIPText_forward(self):
        """Encode every label with the CLIP text encoder into ``self.text_embeddings``."""
        with torch.no_grad():
            normalized = []
            for label in self.labels_list:
                tokens = clip.tokenize(label).cuda()
                normalized.append(F.normalize(self.clip_text_encoder(tokens)))
            self.text_embeddings = torch.cat(normalized, dim=0)

    def _alignment_loss(self, seg_logits, segment):
        """Mean ``1 - cosine`` between kept rows and their label's text embedding."""
        keep = ~torch.isin(segment, self.ignore_index)
        targets = self.text_embeddings[segment[keep]]
        cosine = F.cosine_similarity(seg_logits[keep], targets, dim=1, eps=1e-6)
        return (1 - cosine).mean()

    def _class_scores(self, seg_logits, chunk_size):
        """Softmax over classes of the cosine similarity to every text embedding.

        Rows are processed ``chunk_size`` at a time to bound the size of the
        broadcast intermediate.
        """
        text = self.text_embeddings.unsqueeze(0)  # [1, n_classes, dim]
        chunks = []
        for start in range(0, seg_logits.size(0), chunk_size):
            chunk = seg_logits[start : start + chunk_size].unsqueeze(1)
            chunks.append(F.cosine_similarity(text, chunk, dim=2))
        similarities = torch.cat(chunks, dim=0)  # [n_rows, n_classes]
        # Normalize across classes (dim=1). The previous dim=0 normalized each
        # class column across points, which can change the per-point argmax.
        return F.softmax(similarities, dim=1)

    def forward(self, input_dict, return_point=False):
        """Return the alignment loss and/or per-class scores.

        Returns:
            dict with ``loss`` (training), ``loss`` + ``seg_logits`` (not
            training, ``"segment"`` present) or ``seg_logits`` only, plus
            ``point`` when ``return_point`` is true.
        """
        point = Point(input_dict)
        if self.freeze_backbone:
            with torch.no_grad():
                point = self.backbone(point)
        else:
            point = self.backbone(point)
        # Backbone added after v1.5.0 return Point instead of feat and use DefaultSegmentorV2
        # TODO: remove this part after make all backbone return Point only.
        if self.enc_mode:
            point_ = self.up_cast(point)
            feat = point_.feat
        elif isinstance(point, Point):
            inverse_list = []
            point_feat = [point.feat]
            while "unpooling_parent" in point.keys():
                parent = point.pop("unpooling_parent")
                assert "pooling_inverse" in parent.keys()
                inverse = parent.pop("pooling_inverse")
                inverse_list.append(inverse)
                # Apply every inverse collected so far, newest first, to
                # re-index the parent features.
                for inverse_i in inverse_list[::-1]:
                    parent.feat = parent.feat[inverse_i]
                point_feat.append(parent.feat)
                point = parent
            feat = torch.cat(point_feat, dim=-1)
        else:
            feat = point
        seg_logits = self.seg_head(feat)

        return_dict = dict()
        if return_point:
            # PCA evaluator parse feat and coord in point
            return_dict["point"] = point
        if self.training:
            # train
            return_dict["loss"] = self._alignment_loss(
                seg_logits, input_dict["segment"]
            )
        elif "segment" in input_dict.keys():
            # eval
            return_dict["loss"] = self._alignment_loss(
                seg_logits, input_dict["segment"]
            )
            return_dict["seg_logits"] = self._class_scores(seg_logits, 1024)
        else:
            # test: no labels, so nothing label-dependent may be touched here
            return_dict["seg_logits"] = self._class_scores(seg_logits, 10240)
        return return_dict


@MODELS.register_module()
class DINOEnhancedSegmentor(nn.Module):
    """Linear-head segmentor fed with backbone features plus nearest DINO features.

    Args:
        num_classes: output width of the head; a value ``<= 0`` replaces the
            head with ``nn.Identity``.
        backbone_out_channels: input width of the linear head.
        backbone: config handed to ``build_model``; ``None`` disables the
            backbone so only the DINO features reach the head.
        criteria: config handed to ``build_criteria``.
        freeze_backbone: when true (and a backbone exists), its parameters get
            ``requires_grad = False`` and it runs without gradient tracking.
    """

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
        freeze_backbone=False,
    ):
        super().__init__()
        if num_classes > 0:
            self.seg_head = nn.Linear(backbone_out_channels, num_classes)
        else:
            self.seg_head = nn.Identity()
        self.backbone = None if backbone is None else build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.freeze_backbone = freeze_backbone
        if self.freeze_backbone and self.backbone is not None:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def forward(self, input_dict, return_point=False):
        """Compute segmentation logits and, when labels are usable, the loss.

        Reads ``dino_coord``, ``dino_feat`` and ``dino_offset`` from
        ``input_dict`` and attaches to every point the DINO feature of its
        single nearest DINO coordinate (``torch_cluster.knn`` with ``k=1``).

        Returns:
            dict with ``loss`` (training), ``loss`` + ``seg_logits`` (not
            training, ``"segment"`` present) or ``seg_logits`` only, plus
            ``point`` when ``return_point`` is true.
        """
        point = Point(input_dict)
        feat_parts = []
        if self.backbone is not None:
            # Keep the caller's grad mode unless the backbone is frozen.
            track_grad = torch.is_grad_enabled() and not self.freeze_backbone
            with torch.set_grad_enabled(track_grad):
                point = self.backbone(point)

            # Collect the chain of unpooling parents starting at the output.
            chain = [point]
            while "unpooling_parent" in chain[-1].keys():
                chain.append(chain[-1].pop("unpooling_parent"))
            # From the far end of the chain back to the output, append each
            # entry's gathered features to the entry before it.
            for source, target in zip(reversed(chain[1:]), reversed(chain[:-1])):
                assert "pooling_inverse" in source.keys()
                gather_index = source.pooling_inverse
                target.feat = torch.cat(
                    [target.feat, source.feat[gather_index]], dim=-1
                )

            # Then climb the pooling parents from the output point.
            point = chain[0]
            while "pooling_parent" in point.keys():
                assert "pooling_inverse" in point.keys()
                child = point
                point = child.pop("pooling_parent")
                gather_index = child.pooling_inverse
                point.feat = torch.cat([point.feat, child.feat[gather_index]], dim=-1)
            feat_parts.append(point.feat)

        dino_coord, dino_feat, dino_offset = (
            input_dict[key] for key in ("dino_coord", "dino_feat", "dino_offset")
        )
        # Second row of the knn result indexes into the DINO set (x).
        nearest = torch_cluster.knn(
            x=dino_coord,
            y=point.origin_coord,
            batch_x=offset2batch(dino_offset),
            batch_y=offset2batch(point.origin_offset),
            k=1,
        )[1]

        feat_parts.append(dino_feat[nearest])
        seg_logits = self.seg_head(torch.cat(feat_parts, dim=-1))

        return_dict = dict()
        if return_point:
            # PCA evaluator parse feat and coord in point
            return_dict["point"] = point

        if self.training or "segment" in input_dict.keys():
            # train / eval
            return_dict["loss"] = self.criteria(seg_logits, input_dict["segment"])
        if not self.training:
            # eval / test
            return_dict["seg_logits"] = seg_logits
        return return_dict


@MODELS.register_module()
class SONATAEnhancedSegmentor(nn.Module):
    """Linear-probe segmentor fed by a frozen SONATA backbone.

    Per-point features from the SONATA backbone (weights loaded from a
    checkpoint, all parameters frozen) are concatenated with the features of
    an optional second ``backbone`` and passed to a single linear
    segmentation head.

    Args:
        num_classes: Number of output classes. If ``<= 0`` the head is an
            ``nn.Identity`` and the concatenated features are returned as
            ``seg_logits``.
        backbone_out_channels: Input width of the linear head, i.e. the
            channel count of the concatenated feature vector.
        backbone: Optional config of an additional backbone. ``None``
            disables it (SONATA features only).
        sonata_backbone: Config of the SONATA backbone. Required.
        sonata_backbone_path: Checkpoint path for the SONATA backbone.
            ``None`` falls back to ``"sonata_release/ckpt/sonata.pth"``.
        criteria: Loss config handed to ``build_criteria``.
        freeze_backbone: If True, ``backbone`` parameters are frozen and its
            forward pass runs under ``torch.no_grad()``.

    Raises:
        ValueError: If ``sonata_backbone`` is ``None``.
    """

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        sonata_backbone=None,
        sonata_backbone_path=None,
        criteria=None,
        freeze_backbone=False,
    ):
        super().__init__()
        self.seg_head = (
            nn.Linear(backbone_out_channels, num_classes)
            if num_classes > 0
            else nn.Identity()
        )

        self.backbone = build_model(backbone) if backbone is not None else None

        # The SONATA backbone is mandatory (forward always calls it). Its
        # construction must not depend on whether the optional `backbone`
        # is configured.
        if sonata_backbone is None:
            raise ValueError(
                "SONATAEnhancedSegmentor requires a `sonata_backbone` config."
            )
        self.sonata_backbone = build_model(sonata_backbone)
        if sonata_backbone_path is not None:
            self._load_sonata(self.sonata_backbone, sonata_backbone_path)
        else:
            self._load_sonata(self.sonata_backbone)
        for p in self.sonata_backbone.parameters():
            p.requires_grad = False

        self.criteria = build_criteria(criteria)
        self.freeze_backbone = freeze_backbone
        if self.backbone is not None and self.freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

    @staticmethod
    def _load_sonata(model, path="sonata_release/ckpt/sonata.pth"):
        """Load SONATA weights from ``path`` into ``model`` and return it.

        Two checkpoint layouts are handled:

        * a dict with a ``"state_dict"`` entry whose keys contain
          ``"module.student.backbone."``: only those entries are kept and the
          prefix is stripped;
        * anything else: the state dict is loaded as is, after dropping
          ``"embedding.mask_token"`` when present.
        """
        # Load to CPU: `load_state_dict` copies values into the model's own
        # parameters, so the checkpoint does not need to live on a GPU (and
        # must not require one to be present).
        checkpoint = torch.load(path, map_location="cpu")

        weight = {}
        if "state_dict" in checkpoint.keys():
            checkpoint = checkpoint["state_dict"]
            for key, value in checkpoint.items():
                if "module.student.backbone." in key:
                    key = key.replace("module.student.backbone.", "module.")
                    key = key[7:]  # module.xxx.xxx -> xxx.xxx
                    weight[key] = value

        if weight:
            load_state_info = model.load_state_dict(weight)
        else:
            # Tolerate checkpoints that do not carry the mask token.
            checkpoint.pop("embedding.mask_token", None)
            load_state_info = model.load_state_dict(checkpoint)
        print(f"Missing keys: {load_state_info[0]}")
        print(f"Unexpected keys: {load_state_info[1]}")
        return model

    @staticmethod
    def _upcast_feat(point):
        """Fold features along the parent chains of ``point`` and return the
        last point reached.

        First the ``"unpooling_parent"`` chain is popped into a list; walking
        that list from its end to its start, each entry's features are
        gathered with its own ``pooling_inverse`` and concatenated (last dim)
        onto the previous entry. Then, starting from the original ``point``,
        the ``"pooling_parent"`` chain is followed the same way until a point
        without ``"pooling_parent"`` is reached; that point is returned.

        NOTE(review): presumably this restores features at the backbone's
        input resolution — confirm against the backbone's pooling layers.
        The chain entries are mutated in place (keys popped, ``feat`` grown).
        """
        point_list = [point]
        while "unpooling_parent" in point_list[-1].keys():
            point_list.append(point_list[-1].pop("unpooling_parent"))
        for i in reversed(range(1, len(point_list))):
            source = point_list[i]
            target = point_list[i - 1]
            assert "pooling_inverse" in source.keys()
            inverse = source.pooling_inverse
            target.feat = torch.cat([target.feat, source.feat[inverse]], dim=-1)

        point = point_list[0]
        while "pooling_parent" in point.keys():
            assert "pooling_inverse" in point.keys()
            parent = point.pop("pooling_parent")
            inverse = point.pooling_inverse
            parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
            point = parent
        return point

    def forward(self, input_dict, return_point=False):
        """Run both backbones, classify the concatenated features.

        Args:
            input_dict: Batch dict used to build a ``Point``. Must contain
                ``"segment"`` in training mode; if present in eval mode a
                loss is computed as well.
            return_point: If True, the up-cast point is added to the result
                under ``"point"`` (the optional backbone's point when that
                backbone exists, otherwise the SONATA point).

        Returns:
            dict with ``"loss"`` (train), ``"loss"`` and ``"seg_logits"``
            (eval with labels) or ``"seg_logits"`` (test), plus ``"point"``
            when requested.
        """
        point_input = Point(input_dict)
        feat = []
        point = None
        if self.backbone is not None:
            if self.freeze_backbone:
                with torch.no_grad():
                    point = self.backbone(point_input)
            else:
                point = self.backbone(point_input)
            point = self._upcast_feat(point)
            feat.append(point.feat)

        sonata_point = self._upcast_feat(self.sonata_backbone(point_input))
        feat.append(sonata_point.feat)

        feat = torch.cat(feat, dim=-1)
        seg_logits = self.seg_head(feat)
        return_dict = dict()
        if return_point:
            # PCA evaluator parse feat and coord in point
            # Without the optional backbone there is no `point`; expose the
            # SONATA point instead of failing on an unbound name.
            return_dict["point"] = point if point is not None else sonata_point
        # train
        if self.training:
            loss = self.criteria(seg_logits, input_dict["segment"])
            return_dict["loss"] = loss
        # eval
        elif "segment" in input_dict.keys():
            loss = self.criteria(seg_logits, input_dict["segment"])
            return_dict["loss"] = loss
            return_dict["seg_logits"] = seg_logits
        # test
        else:
            return_dict["seg_logits"] = seg_logits
        return return_dict


@MODELS.register_module()
class DefaultClassifier(nn.Module):
    """Classifier: a backbone followed by an MLP head.

    The head is ``Linear -> BatchNorm1d -> ReLU -> Dropout(0.5)`` for hidden
    widths 256 and 128, followed by a final ``Linear`` to ``num_classes``.

    Args:
        backbone: Backbone config handed to ``build_model``.
        criteria: Loss config handed to ``build_criteria``.
        num_classes: Width of the final linear layer.
        backbone_embed_dim: Input width of the first linear layer.
    """

    def __init__(
        self,
        backbone=None,
        criteria=None,
        num_classes=40,
        backbone_embed_dim=256,
    ):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.num_classes = num_classes
        self.backbone_embed_dim = backbone_embed_dim

        # Assemble the head layer by layer; the final ordering (and thus the
        # Sequential indices) is Linear/BN/ReLU/Dropout x2 + Linear.
        head_layers = []
        in_dim = backbone_embed_dim
        for hidden_dim in (256, 128):
            head_layers += [
                nn.Linear(in_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(p=0.5),
            ]
            in_dim = hidden_dim
        head_layers.append(nn.Linear(in_dim, num_classes))
        self.cls_head = nn.Sequential(*head_layers)

    def forward(self, input_dict):
        """Classify a batch.

        Returns a dict with ``"loss"`` in training mode, ``"loss"`` and
        ``"cls_logits"`` in eval mode when ``"category"`` is in
        ``input_dict``, and only ``"cls_logits"`` otherwise.
        """
        output = self.backbone(Point(input_dict))

        # Backbones added after v1.5.0 return a Point rather than a feature
        # tensor; for those the per-segment mean over `offset` is taken here.
        # TODO: remove the tensor branch once every backbone returns Point.
        if isinstance(output, Point):
            indptr = F.pad(output.offset, (1, 0))
            output.feat = torch_scatter.segment_csr(
                src=output.feat, indptr=indptr, reduce="mean"
            )
            pooled = output.feat
        else:
            pooled = output

        cls_logits = self.cls_head(pooled)

        if self.training:
            return {"loss": self.criteria(cls_logits, input_dict["category"])}
        if "category" in input_dict.keys():
            return {
                "loss": self.criteria(cls_logits, input_dict["category"]),
                "cls_logits": cls_logits,
            }
        return {"cls_logits": cls_logits}
