import torch.nn as nn
import torch_scatter

from pointcept.models.losses import build_criteria, LJLoss
from pointcept.models.utils.structure import Point
from .builder import MODELS, build_model


@MODELS.register_module()
class DefaultSegmentor(nn.Module):
    """Segmentor whose backbone directly produces per-point logits.

    The backbone is called on the raw ``input_dict`` and must return a
    ``(seg_logits, loss_feat)`` pair. Whenever a loss is computed,
    ``loss_feat`` is passed to ``LJLoss`` and that term is added to the
    criteria loss.

    Args:
        backbone: config handed to ``build_model``.
        criteria: config handed to ``build_criteria``.
    """

    def __init__(self, backbone=None, criteria=None):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.lj_loss = LJLoss()

    def forward(self, input_dict):
        """Run the backbone and, when labels are available, the losses.

        Returns:
            training: ``dict(loss=..., lj_loss=...)``
            eval (``"segment"`` present): ``dict(loss=..., seg_logits=...)``
            test (no ``"segment"``): ``dict(seg_logits=...)``
        """
        if "condition" in input_dict:
            # PPT (https://arxiv.org/abs/2308.09718): only a single condition
            # per batch is supported, so keep just the first entry.
            input_dict["condition"] = input_dict["condition"][0]
        seg_logits, loss_feat = self.backbone(input_dict)

        # Test mode: no labels available, only logits are returned.
        if not self.training and "segment" not in input_dict:
            return dict(seg_logits=seg_logits)

        # Train and eval share one objective: criteria loss + LJ term.
        lj_term = self.lj_loss(loss_feat)
        total = self.criteria(seg_logits, input_dict["segment"]) + lj_term
        if self.training:
            return dict(loss=total, lj_loss=lj_term)
        return dict(loss=total, seg_logits=seg_logits)



@MODELS.register_module()
class DefaultSegmentorV2(nn.Module):
    """Segmentor for backbones that return a ``Point`` object.

    A linear head maps ``point.feat`` to class logits; it is the identity
    when ``num_classes <= 0``. The backbone must also store ``"loss_feat"``
    on the returned ``Point``. Whenever a loss is computed, ``loss_feat`` is
    passed to ``LJLoss`` and that term is added to the criteria loss.

    Args:
        num_classes: number of output classes (``<= 0`` disables the head).
        backbone_out_channels: width of ``point.feat`` produced by the backbone.
        backbone: config handed to ``build_model``.
        criteria: config handed to ``build_criteria``.
    """

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
    ):
        super().__init__()
        self.seg_head = (
            nn.Linear(backbone_out_channels, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.lj_loss = LJLoss()

    def forward(self, input_dict):
        """Run backbone + head and, when labels are available, the losses.

        Returns:
            training: ``dict(loss=..., lj_loss=...)``
            eval (``"segment"`` present): ``dict(loss=..., seg_logits=...)``
            test (no ``"segment"``): ``dict(seg_logits=...)``

        Raises:
            TypeError: the backbone did not return a ``Point``.
            KeyError: the returned ``Point`` has no ``"loss_feat"`` entry.
        """
        point = self.backbone(Point(input_dict))

        # LJLoss needs the auxiliary "loss_feat", and only a Point output can
        # carry it. The legacy "backbone returns a bare feature tensor" path
        # therefore can never work with this segmentor. Validate explicitly:
        # an `assert` placed before the type check crashed a bare tensor with
        # an opaque AttributeError on `.keys()`, and `assert` is stripped
        # entirely under `python -O`.
        if not isinstance(point, Point):
            raise TypeError(
                "DefaultSegmentorV2 expects the backbone to return a Point, "
                f"got {type(point).__name__}."
            )
        if "loss_feat" not in point:
            raise KeyError(
                "DefaultSegmentorV2 requires the backbone to set "
                "point['loss_feat'] for LJLoss."
            )
        loss_feat = point["loss_feat"]
        seg_logits = self.seg_head(point.feat)

        # Test mode: no labels available, only logits are returned.
        if not self.training and "segment" not in input_dict:
            return dict(seg_logits=seg_logits)

        # Train and eval share one objective: criteria loss + LJ term.
        lj_loss = self.lj_loss(loss_feat)
        loss = self.criteria(seg_logits, input_dict["segment"]) + lj_loss
        if self.training:
            return dict(loss=loss, lj_loss=lj_loss)
        return dict(loss=loss, seg_logits=seg_logits)


@MODELS.register_module()
class DefaultClassifier(nn.Module):
    """Classifier that mean-pools backbone point features and applies an MLP head.

    The backbone must return a ``Point`` carrying both ``feat`` and
    ``"loss_feat"``. ``feat`` is mean-pooled over the segments delimited by
    ``point.offset`` and then fed to a
    Linear-BN-ReLU-Dropout x2 + Linear head. Whenever a loss is computed,
    ``loss_feat`` is passed to ``LJLoss`` and that term is added to the
    criteria loss.

    Args:
        backbone: config handed to ``build_model``.
        criteria: config handed to ``build_criteria``.
        num_classes: output width of the classification head.
        backbone_embed_dim: width of ``point.feat`` produced by the backbone.
    """

    def __init__(
        self,
        backbone=None,
        criteria=None,
        num_classes=40,
        backbone_embed_dim=256,
    ):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.num_classes = num_classes
        self.backbone_embed_dim = backbone_embed_dim
        self.cls_head = nn.Sequential(
            nn.Linear(backbone_embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(128, num_classes),
        )
        self.lj_loss = LJLoss()

    def forward(self, input_dict):
        """Run backbone, pooling and head; compute losses when labels exist.

        Returns:
            training: ``dict(loss=..., lj_loss=...)``
            eval (``"category"`` present): ``dict(loss=..., cls_logits=...)``
            test (no ``"category"``): ``dict(cls_logits=...)``

        Raises:
            TypeError: the backbone did not return a ``Point``.
            KeyError: the returned ``Point`` has no ``"loss_feat"`` entry.
        """
        point = self.backbone(Point(input_dict))

        # LJLoss needs the auxiliary "loss_feat", and only a Point output can
        # carry it. The legacy bare-tensor backbone path therefore can never
        # work here. Validate explicitly instead of with an `assert` that
        # (a) ran before the type check, so a tensor crashed on `.keys()`,
        # and (b) is stripped under `python -O`.
        if not isinstance(point, Point):
            raise TypeError(
                "DefaultClassifier expects the backbone to return a Point, "
                f"got {type(point).__name__}."
            )
        if "loss_feat" not in point:
            raise KeyError(
                "DefaultClassifier requires the backbone to set "
                "point['loss_feat'] for LJLoss."
            )
        loss_feat = point["loss_feat"]

        # Feature aggregation for classification is done here, not in the
        # backbone. Prepending 0 to `offset` turns its cumulative end indices
        # into CSR row pointers, giving one mean-pooled vector per segment
        # (presumably one per sample in the batch).
        feat = torch_scatter.segment_csr(
            src=point.feat,
            indptr=nn.functional.pad(point.offset, (1, 0)),
            reduce="mean",
        )
        cls_logits = self.cls_head(feat)

        # Test mode: no labels available, only logits are returned.
        if not self.training and "category" not in input_dict:
            return dict(cls_logits=cls_logits)

        # Train and eval share one objective: criteria loss + LJ term.
        lj_loss = self.lj_loss(loss_feat)
        loss = self.criteria(cls_logits, input_dict["category"]) + lj_loss
        if self.training:
            return dict(loss=loss, lj_loss=lj_loss)
        return dict(loss=loss, cls_logits=cls_logits)
