"""
@Description :   碎片拼接网络的特征提取部分
@Author      :   tqychy 
@Time        :   2024/12/30 16:55:55
"""
import sys

sys.path.append("./")

import torch
import torch.nn as nn

from nets.utils.gcn import DeepGCN


# ------------ ResGCN ------------ #
class FlattenNet(nn.Module):
    """
    Encode every patch into a per-point feature vector with one linear layer.

    The input must be 4-D ``(bs, n, h, w)``. Each ``h * w`` patch is flattened
    (``h * w`` must equal ``input_dim``), projected by a bias-free
    ``nn.Linear(input_dim, output_dim)`` and passed through ReLU, giving
    ``(bs, n, output_dim)``.
    """

    def __init__(self, input_dim, output_dim):
        super(FlattenNet, self).__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_dim, output_dim, bias=False)
        self.activate = nn.ReLU(inplace=True)

    def forward(self, x):
        batch, points, _, _ = x.size()  # also enforces a 4-D input
        flat = x.view(batch, points, -1)
        projected = self.fc1(flat)
        return self.activate(projected)


class Conv(nn.Module):
    """
    Conv2d -> BatchNorm2d -> optional ReLU building block.

    The convolution is created without a bias because a BatchNorm layer
    directly follows it. When ``padding`` is None it defaults to
    ``kernel_size // 2``. For odd kernels with stride 1 this keeps the spatial
    size unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, groups=1, activation=True):
        super(Conv, self).__init__()
        if padding is None:
            padding = kernel_size // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              padding, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        if activation:
            self.act = nn.ReLU(inplace=True)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)


class FlattenNet_average(nn.Module):
    """
    Encode every single-channel patch into a 64-d per-point feature vector.

    Each patch goes through ``Conv(1, 64, kernel_size=3, padding=0)``
    (conv + BN + ReLU) and then global average pooling. For an input of shape
    ``(bs, n, h, w)`` the output is ``(bs, n, 64)``. Patches must be at least
    3x3 for the unpadded 3x3 convolution, and need not be square.

    Note: ``fc1`` and ``activate`` are registered submodules but are NOT used
    by ``forward``. They are kept so the state_dict keys (``fc1.weight``)
    remain unchanged for existing checkpoints. ``input_dim`` and
    ``output_dim`` therefore only size the unused ``fc1``; the returned
    feature width is always 64.
    """

    def __init__(self, input_dim, output_dim):
        super(FlattenNet_average, self).__init__()
        self.fc1 = nn.Linear(input_dim,
                             output_dim, bias=False)
        self.activate = nn.ReLU(inplace=True)

        self.conv1 = Conv(1, 64, kernel_size=3, stride=1, padding=0)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        # Height and width are unpacked separately so non-square patches work.
        bs, n, h, w = x.size()
        # reshape (not view) so non-contiguous inputs are accepted too.
        x = x.reshape(bs * n, 1, h, w)
        x = self.conv1(x)     # (bs*n, 64, h-2, w-2)
        x = self.avg_pool(x)  # (bs*n, 64, 1, 1)
        return x.view(bs, n, -1)


class PatchEncoder_average(nn.Module):
    """
    Encode 3-channel texture patches into a 64-d descriptor per patch.

    The network is two ``Conv`` blocks (3 -> 32 with padding 1, then 32 -> 64
    without padding) followed by global average pooling. An input of shape
    ``(N, 3, H, W)`` yields ``(N, 64, 1, 1)``.
    """

    def __init__(self):
        super(PatchEncoder_average, self).__init__()
        # Submodule names/order are part of the state_dict; keep them stable.
        self.conv1 = Conv(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = Conv(32, 64, kernel_size=3, stride=1, padding=0)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        for stage in (self.conv1, self.conv2, self.avg_pool):
            x = stage(x)
        return x


class ResGCNFeatureExtract(nn.Module):
    """
    Contour + texture feature extractor built on two DeepGCN encoders.

    Contour branch: ``c_input`` patches are encoded by ``flatten_net``. The
    result is concatenated with the centred 2-D point coordinates, projected
    back to ``FEATURE_EXTRACT_DIM`` by ``fc``, and refined by ``encoder_c``.

    Texture branch: ``t_input`` is reshaped to 3x7x7 patches, encoded by
    ``encoder_t``, and refined by ``gcn_t``.

    The two feature sets are returned separately; they are not concatenated
    here.
    """

    def __init__(self, *args):
        super().__init__()
        self.cfg, self.logger = args
        input_dim = self.cfg.NET.PATCH_SIZE ** 2
        output_dim = self.cfg.NET.FEATURE_EXTRACT_DIM

        # Registered first so the module/parameter order (and hence any saved
        # optimizer state ordering) stays the same.
        # NOTE: FlattenNet_average, as written, always outputs 64 features,
        # so ``fc`` below only fits when FEATURE_EXTRACT_DIM == 64.
        self.flatten_net = FlattenNet_average(input_dim, output_dim)
        self.encoder_c = DeepGCN(self.cfg, self.logger)

        self.encoder_t = PatchEncoder_average()

        self.gcn_t = DeepGCN(self.cfg, self.logger)
        # +2 for the centred (x, y) coordinate appended to each contour feature.
        self.fc = nn.Linear(output_dim + 2, output_dim)
        self.c_feature = output_dim

    def forward(self, c_input: torch.Tensor, t_input: torch.Tensor, pcd: torch.Tensor, adj: torch.Tensor):
        """
        Args:
            c_input: contour patches, 4-D ``(bs, n, h, w)``.
            t_input: texture patches. Must hold ``bs * n * 3 * 7 * 7``
                elements, because it is reshaped to ``(bs * n, 3, 7, 7)``.
            pcd: floating-point point coordinates ``(bs, n, 2)``. Not modified.
            adj: adjacency, passed unchanged to both DeepGCN encoders.

        Returns:
            ``(l_c, l_t)``: contour and texture features, each ``(bs, n, -1)``.
        """
        bs, n, _, _ = c_input.shape

        # Centre each fragment's points on their mean, then shift by -1.
        # This is computed out-of-place so the caller's ``pcd`` is left
        # untouched and the module works on any device.
        contour_in_c = pcd - pcd.mean(dim=1, keepdim=True) - 1

        flatted_c = self.flatten_net(c_input).reshape(bs, n, -1)
        flatted_c = torch.cat((flatted_c, contour_in_c), dim=-1)
        flatted_c = self.fc(flatted_c).reshape(-1, self.c_feature)
        l_c = self.encoder_c(flatted_c, adj).reshape(bs, n, -1)

        # Texture patches are assumed to be 3x7x7 (hard-coded size).
        l_t = self.encoder_t(t_input.reshape(bs * n, 3, 7, 7))  # (bs*n, C, 1, 1)
        l_t = self.gcn_t(l_t.reshape(bs * n, -1), adj)
        l_t = l_t.reshape(bs, n, -1)

        return l_c, l_t


# ------------ ViT ------------ #
class BaseViTFeatueExtract(nn.Module):
    """
    Transformer encoder over a sequence of flattened image patches.

    Each ``channels x PATCH_SIZE x PATCH_SIZE`` patch is linearly embedded to
    ``FEATURE_EXTRACT_DIM``. A learned positional embedding is added, and the
    sequence runs through ``BLOCKS`` stacked ``nn.TransformerEncoderLayer``.

    Args:
        channels: number of channels per patch.
        *args: ``(cfg, logger)``.
    """

    def __init__(self, channels, *args):
        super().__init__()
        self.cfg, self.logger = args

        dim = self.cfg.NET.FEATURE_EXTRACT_DIM
        depth = self.cfg.NET.BLOCKS
        seq_len = self.cfg.DATASET.CONTOUR_MAX_LEN
        num_heads = self.cfg.NET.VIT.NUM_HEADS
        self.patch_size = self.cfg.NET.PATCH_SIZE

        self.patch_embed = nn.Linear(channels * self.patch_size ** 2, dim)

        # Learned positional embedding: one row per position, up to
        # CONTOUR_MAX_LEN. It is initialised to zeros.
        # TODO: derive it from the 2-D point coordinates with an MLP.
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, dim))

        # Sequence-first layout (batch_first=False); forward permutes to match.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=num_heads,
            batch_first=False
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=depth)

    def forward(self, patch_inputs: torch.Tensor, pos_embeds: torch.Tensor):
        """
        Args:
            patch_inputs: ``(bs, seq_len, channels, ps, ps)``, or any layout
                whose trailing elements per token total ``channels * ps * ps``.
                ``seq_len`` may be anything up to ``CONTOUR_MAX_LEN``.
            pos_embeds: per-point coordinates, presumably ``(bs, seq_len, 2)``.
                Currently unused; the learned ``pos_embed`` is applied instead.

        Returns:
            ``(bs, seq_len, FEATURE_EXTRACT_DIM)``.
        """
        bs, seq_len = patch_inputs.size(0), patch_inputs.size(1)

        # Flatten each patch and project it; nn.Linear acts on the last dim.
        # reshape (not view) also accepts non-contiguous inputs.
        tokens = self.patch_embed(patch_inputs.reshape(bs, seq_len, -1))

        # Slice the learned table so sequences shorter than CONTOUR_MAX_LEN
        # broadcast correctly.
        tokens = tokens + self.pos_embed[:, :seq_len]

        # (bs, seq_len, dim) -> (seq_len, bs, dim) for the sequence-first
        # encoder, then back again.
        tokens = self.transformer_encoder(tokens.permute(1, 0, 2))
        return tokens.permute(1, 0, 2)


class ViTFeatureExtract(nn.Module):
    """
    Contour + texture feature extractor built on two ViT-style encoders.

    ``contour_feature_extract`` handles 1-channel contour patches and
    ``texture_feature_extract`` handles 3-channel texture patches. Both receive
    the centred point coordinates as their ``pos_embeds`` argument.
    """

    def __init__(self, *args):
        super().__init__()
        self.cfg, self.logger = args
        self.texture_feature_extract = BaseViTFeatueExtract(3, *args)
        self.contour_feature_extract = BaseViTFeatueExtract(1, *args)

    def forward(self, c_input, t_input, pcd):
        """
        Args:
            c_input: contour patches (1 channel), e.g.
                ``(bs, seq_len, 1, ps, ps)`` or ``(bs, seq_len, ps, ps)``.
            t_input: texture patches (3 channels), ``(bs, seq_len, 3, ps, ps)``.
            pcd: floating-point point coordinates ``(bs, seq_len, 2)``.
                Not modified.

        Returns:
            ``(fc, ft)``: contour and texture features from the two encoders.
        """
        # Centre on the per-sample mean and shift by -1. This is computed
        # out-of-place so the caller's ``pcd`` is left untouched and any
        # device works.
        contour_in_c = pcd - pcd.mean(dim=1, keepdim=True) - 1

        fc = self.contour_feature_extract(c_input, contour_in_c)
        ft = self.texture_feature_extract(t_input, contour_in_c)

        return fc, ft