import torch
import torch.nn as nn
import MinkowskiEngine as ME
from MinkowskiEngine import SparseTensor
from timm.layers import trunc_normal_

from .mink_layers import MinkConvBNRelu, MinkResBlock
from .swin3d_layers import GridDownsample, GridKNNDownsample, BasicLayer, Upsample
from pointcept.models.builder import MODELS
from pointcept.models.utils import offset2batch, batch2offset


@MODELS.register_module("Swin3D-v1m1")
class Swin3DUNet(nn.Module):
    """U-Net shaped sparse network built from Swin3D layers.

    The model consists of a Minkowski convolution stem, a stack of
    ``BasicLayer`` encoder stages (each optionally followed by a downsample),
    a mirrored stack of ``Upsample`` decoder stages with skip connections,
    and a small MLP classifier that maps ``channels[0]`` features to
    ``num_classes`` logits.

    Args:
        in_channels: Number of input feature channels given to the stem conv.
        num_classes: Output width of the classifier head.
        base_grid_size: Divisor applied to ``coord`` in :meth:`forward`.
        depths: Per-stage depth forwarded to ``BasicLayer``; ``sum(depths)``
            also sets the length of the stochastic-depth schedule.
        channels: Per-stage channel widths.
        num_heads: Per-stage head counts forwarded to ``BasicLayer`` and
            ``Upsample``.
        window_sizes: Per-stage window sizes forwarded to ``BasicLayer`` and
            ``Upsample``.
        quant_size: Forwarded unchanged to ``BasicLayer`` and ``Upsample``.
        drop_path_rate: Upper end of the drop-path schedule, which is
            linearly spaced from 0 over ``sum(depths)`` entries.
        up_k: Forwarded unchanged to ``Upsample``.
        num_layers: Number of stages; indexes into the per-stage sequences.
        stem_transformer: If true, the stem is a single ``MinkConvBNRelu`` and
            stage 0 is a ``BasicLayer``. Otherwise the stem is
            ``MinkConvBNRelu`` + ``MinkResBlock`` followed by a downsample,
            and ``BasicLayer`` stages start at index 1.
        down_stride: Stride of the first downsample; later stages use 2.
        upsample: Upsampling mode string; attention is enabled in ``Upsample``
            when it contains ``"attn"``.
        knn_down: Selects ``GridKNNDownsample`` (true) or ``GridDownsample``.
        cRSE: Forwarded unchanged to ``BasicLayer`` and ``Upsample``
            (NOTE(review): presumably names the relative signal encoding
            variant -- confirm against ``swin3d_layers``).
        fp16_mode: Forwarded unchanged to ``BasicLayer`` and ``Upsample``.
    """

    def __init__(
        self,
        in_channels,
        num_classes,
        base_grid_size,
        depths,
        channels,
        num_heads,
        window_sizes,
        quant_size,
        drop_path_rate=0.2,
        up_k=3,
        num_layers=5,
        stem_transformer=True,
        down_stride=2,
        upsample="linear",
        knn_down=True,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
        super().__init__()
        # Stochastic depth decay rule: one rate per block, rising linearly.
        dpr = torch.linspace(0, drop_path_rate, sum(depths)).tolist()
        downsample = GridKNNDownsample if knn_down else GridDownsample

        self.cRSE = cRSE
        if stem_transformer:
            self.stem_layer = MinkConvBNRelu(
                in_channels=in_channels,
                out_channels=channels[0],
                kernel_size=3,
                stride=1,
            )
            self.layer_start = 0
        else:
            self.stem_layer = nn.Sequential(
                MinkConvBNRelu(
                    in_channels=in_channels,
                    out_channels=channels[0],
                    kernel_size=3,
                    stride=1,
                ),
                MinkResBlock(in_channels=channels[0], out_channels=channels[0]),
            )
            # The convolutional stem replaces stage 0, so it needs its own
            # downsample into the channel width of stage 1.
            self.downsample = downsample(
                channels[0], channels[1], kernel_size=down_stride, stride=down_stride
            )
            self.layer_start = 1

        last_stage = num_layers - 1
        self.layers = nn.ModuleList(
            [
                BasicLayer(
                    dim=channels[i],
                    depth=depths[i],
                    num_heads=num_heads[i],
                    window_size=window_sizes[i],
                    quant_size=quant_size,
                    # Slice of the global schedule belonging to stage i.
                    drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
                    # The deepest stage has nothing below it to downsample to.
                    downsample=downsample if i < last_stage else None,
                    down_stride=down_stride if i == 0 else 2,
                    out_channels=channels[i + 1] if i < last_stage else None,
                    cRSE=cRSE,
                    fp16_mode=fp16_mode,
                )
                for i in range(self.layer_start, num_layers)
            ]
        )

        up_attn = "attn" in upsample

        # Decoder stages run from the deepest level back up to level 0.
        self.upsamples = nn.ModuleList(
            [
                Upsample(
                    channels[i],
                    channels[i - 1],
                    num_heads[i - 1],
                    window_sizes[i - 1],
                    quant_size,
                    attn=up_attn,
                    up_k=up_k,
                    cRSE=cRSE,
                    fp16_mode=fp16_mode,
                )
                for i in range(last_stage, 0, -1)
            ]
        )

        self.classifier = nn.Sequential(
            nn.Linear(channels[0], channels[0]),
            nn.BatchNorm1d(channels[0]),
            nn.ReLU(inplace=True),
            nn.Linear(channels[0], num_classes),
        )
        self.num_classes = num_classes
        self.base_grid_size = base_grid_size
        self.init_weights()

    def forward(self, data_dict):
        """Run the network on one batch and return classifier logits.

        Args:
            data_dict: Mapping that must provide the keys ``"grid_coord"``,
                ``"feat"``, ``"coord_feat"``, ``"coord"`` and ``"offset"``.
                ``offset`` is converted to a per-row batch index with
                ``offset2batch``.

        Returns:
            The output of ``self.classifier`` applied to the decoder features
            sliced back onto the input tensor field.
        """
        grid_coord = data_dict["grid_coord"]
        feat = data_dict["feat"]
        coord_feat = data_dict["coord_feat"]
        coord = data_dict["coord"]
        offset = data_dict["offset"]
        batch = offset2batch(offset)

        # Feature layout per row: [batch | coord / grid | coord_feat | feat].
        # The coordinate-related columns ride along as features so that they
        # are reduced by the same quantization as ``feat``.
        in_field = ME.TensorField(
            features=torch.cat(
                [
                    batch.unsqueeze(-1),
                    coord / self.base_grid_size,
                    # NOTE(review): the 1.001 divisor is undocumented here;
                    # presumably keeps values strictly inside a range -- confirm.
                    coord_feat / 1.001,
                    feat,
                ],
                dim=1,
            ),
            coordinates=torch.cat([batch.unsqueeze(-1).int(), grid_coord.int()], dim=1),
            quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
            minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
            device=feat.device,
        )

        sp = in_field.sparse()
        # Split point between coordinate columns and real features: 1 batch
        # column + coord columns + coord_feat columns (the constant 4 assumes
        # ``coord`` has 3 columns -- TODO confirm).
        num_coord_cols = coord_feat.shape[-1] + 4
        coords_sp = SparseTensor(
            features=sp.F[:, :num_coord_cols],
            coordinate_map_key=sp.coordinate_map_key,
            coordinate_manager=sp.coordinate_manager,
        )
        sp = SparseTensor(
            features=sp.F[:, num_coord_cols:],
            coordinate_map_key=sp.coordinate_map_key,
            coordinate_manager=sp.coordinate_manager,
        )

        # Skip-connection stacks: features and coordinate tensors per level.
        sp_stack = []
        coords_sp_stack = []
        sp = self.stem_layer(sp)
        if self.layer_start > 0:
            # Convolutional stem acts as level 0 of the encoder.
            sp_stack.append(sp)
            coords_sp_stack.append(coords_sp)
            sp, coords_sp = self.downsample(sp, coords_sp)

        for layer in self.layers:
            coords_sp_stack.append(coords_sp)
            sp, sp_down, coords_sp = layer(sp, coords_sp)
            sp_stack.append(sp)
            # Sanity check: features and coordinate tensor must stay aligned.
            assert (coords_sp.C == sp_down.C).all()
            sp = sp_down

        # Start decoding from the deepest level and merge skips upwards.
        sp = sp_stack.pop()
        coords_sp = coords_sp_stack.pop()
        for upsample in self.upsamples:
            sp_i = sp_stack.pop()
            coords_sp_i = coords_sp_stack.pop()
            sp = upsample(sp, coords_sp, sp_i, coords_sp_i)
            coords_sp = coords_sp_i

        output = self.classifier(sp.slice(in_field).F)
        return output

    def init_weights(self):
        """Initialize the weights in backbone.

        ``nn.Linear`` weights get a truncated normal (std 0.02) and zero bias;
        ``nn.LayerNorm`` / ``nn.BatchNorm1d`` get unit weight and zero bias.
        Parameters that are absent (``None``) are skipped.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
                # Norm layers created without affine parameters expose
                # ``weight``/``bias`` as None; initializing them would raise.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)
