
import torch.nn as nn
from torch.nn import Module


from .backbone import Backbone
from .PSA import PSA
from .MPGM import MPGM
from .Decoder import Decoder

class Denoiser_Block(Module):
    """One refinement step producing a denoised seed and a denoised mirror point.

    ``Denoise`` stacks ``cfg.Layer_Num`` of these and feeds each block's output
    dict into the next one.

    Args:
        PSA: callable ``PSA(feature, position) -> feature`` that updates the
            features given the current (noisy) positions.
        Decoder: callable ``Decoder(feature) -> displacement``. Its output is
            added to positions, so it must be broadcast-compatible with them.
        MPGM: callable ``MPGM(feature, position, mirror_position) -> feature``.
            Its output is decoded into a correction for the mirror point.
        w_1: scale applied to the decoded displacement to get the denoised seed
            position. Defaults to 1 (the original hard-coded value).
        w_2: scale used to place the mirror point along the same displacement.
            Defaults to 2 (the original hard-coded value). With ``w_1 == 1``,
            the mirror point is the reflection of the noisy point through the seed.
    """

    def __init__(self, PSA, Decoder, MPGM, *, w_1=1, w_2=2):
        super().__init__()

        # Parameter names mirror (and locally shadow) the imported classes.
        # They are kept unchanged so existing keyword callers keep working.
        self.PSA = PSA
        self.Decoder = Decoder
        self.MPGM = MPGM
        # The defaults are plain Python ints, so they are not trainable and are
        # not stored in the state_dict.
        self.w_1 = w_1
        self.w_2 = w_2

    def forward(self, x):
        """Run one refinement step.

        Args:
            x: dict with keys ``'feature'`` and ``'position'`` (the current noisy
                positions). Shapes are whatever PSA/Decoder/MPGM expect; they
                are not checked here.

        Returns:
            dict with ``'feature'`` (PSA-updated features), ``'position'``
            (denoised seed positions) and ``'position_DMP'`` (denoised
            mirror-point positions).
        """
        feature, position_noisy = x['feature'], x['position']

        feature = self.PSA(feature, position_noisy)

        d_i = self.Decoder(feature)
        position_denoised_seed = position_noisy + d_i * self.w_1

        # Move further along the same displacement to form a mirror point.
        # MPGM + Decoder then predict a correction that is added to it.
        position_mirror_point = position_noisy + d_i * self.w_2
        d_i_tilde = self.Decoder(self.MPGM(feature, position_noisy, position_mirror_point))
        position_denoised_mirror_point = position_mirror_point + d_i_tilde

        return {'feature': feature,
                'position': position_denoised_seed,
                'position_DMP': position_denoised_mirror_point}


class Denoise(Module):
    """Point-cloud denoiser: a backbone followed by a stack of ``Denoiser_Block``s.

    Args:
        cfg: config object read for ``cfg.Backbone``, ``cfg.PSA``,
            ``cfg.Decoder``, ``cfg.MPGM`` (sub-configs passed to the respective
            constructors) and ``cfg.Layer_Num`` (number of stacked blocks).
    """

    def __init__(self, cfg):
        super().__init__()
        self.Backbone = Backbone(cfg.Backbone)
        self.Denoiser_Blocks = nn.ModuleList()

        # NOTE: one PSA / Decoder / MPGM instance is built here and handed to
        # every block, so all blocks share the same weights.
        # NOTE(review): presumably intentional (tied-weight iterative
        # refinement). Confirm before giving each block its own modules.
        # The tuple is evaluated left to right, which preserves the original
        # construction (and weight-init RNG) order: PSA, Decoder, MPGM.
        shared_modules = (PSA(cfg.PSA), Decoder(cfg.Decoder), MPGM(cfg.MPGM))
        for _ in range(cfg.Layer_Num):
            self.Denoiser_Blocks.append(Denoiser_Block(*shared_modules))

    def forward(self, data):
        """Denoise ``data['pcl_noisy']`` and collect every block's predictions.

        The backbone output is passed to the first block, so it must be a dict
        providing at least ``'feature'`` and ``'position'``.

        Returns:
            dict with ``'denoised'`` (the list of each block's ``'position'``
            output) and ``'denoised_MP'`` (the list of each block's
            ``'position_DMP'`` output). Both lists are in block order and are
            empty when ``Layer_Num`` is 0.
        """
        state = self.Backbone(data['pcl_noisy'])

        seeds, mirrors = [], []
        for block in self.Denoiser_Blocks:
            state = block(state)
            seeds.append(state['position'])
            mirrors.append(state['position_DMP'])

        return {'denoised': seeds, 'denoised_MP': mirrors}



def get_model(cfg):
    """Factory: build a ``Denoise`` model from ``cfg`` and return it."""
    return Denoise(cfg)



