import torch
from torch.autograd import Function
import math
import numpy as np
import pywt
import torch.nn as nn
import torch.nn.functional as F
from .EMA import EMA3D

class DWTFunction_3D(Function):
    """
    Autograd function computing one level of a separable 3D DWT with matrix products.

    For an input laid out as (N, C, D, H, W) the ``*_0`` matrices filter the
    height axis (left multiplication), the ``*_1`` matrices filter the width
    axis (right multiplication) and the ``*_2`` matrices filter the depth axis
    (left multiplication after swapping dims 2 and 3). Sub-band names therefore
    read (depth, height, width): ``LHL`` is low-pass along depth, high-pass
    along height and low-pass along width.

    The six matrices are treated as constants; they receive no gradient.
    """

    @staticmethod
    def forward(ctx, input,
                matrix_Low_0, matrix_Low_1, matrix_Low_2,
                matrix_High_0, matrix_High_1, matrix_High_2):
        """
        :param input: tensor to decompose; dims 2 and 3 are swapped internally, so it needs at least 4 dims
        :param matrix_Low_0, matrix_High_0: filters applied by left multiplication on the last two dims
        :param matrix_Low_1, matrix_High_1: filters applied by right multiplication on the last two dims
        :param matrix_Low_2, matrix_High_2: filters applied by left multiplication after swapping dims 2 and 3
        :return: the eight sub-bands (LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH)
        """
        ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_Low_2,
                              matrix_High_0, matrix_High_1, matrix_High_2)
        # Filter along the second-to-last axis.
        L = torch.matmul(matrix_Low_0, input)
        H = torch.matmul(matrix_High_0, input)
        # Filter along the last axis, then swap dims 2 and 3 so that the
        # remaining axis becomes the row axis of the next matmul.
        LL = torch.matmul(L, matrix_Low_1).transpose(dim0=2, dim1=3)
        LH = torch.matmul(L, matrix_High_1).transpose(dim0=2, dim1=3)
        HL = torch.matmul(H, matrix_Low_1).transpose(dim0=2, dim1=3)
        HH = torch.matmul(H, matrix_High_1).transpose(dim0=2, dim1=3)
        # Filter along the original dim 2 and restore the original axis order.
        LLL = torch.matmul(matrix_Low_2, LL).transpose(dim0=2, dim1=3)
        LLH = torch.matmul(matrix_Low_2, LH).transpose(dim0=2, dim1=3)
        LHL = torch.matmul(matrix_Low_2, HL).transpose(dim0=2, dim1=3)
        LHH = torch.matmul(matrix_Low_2, HH).transpose(dim0=2, dim1=3)
        HLL = torch.matmul(matrix_High_2, LL).transpose(dim0=2, dim1=3)
        HLH = torch.matmul(matrix_High_2, LH).transpose(dim0=2, dim1=3)
        HHL = torch.matmul(matrix_High_2, HL).transpose(dim0=2, dim1=3)
        HHH = torch.matmul(matrix_High_2, HH).transpose(dim0=2, dim1=3)
        return LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH

    @staticmethod
    def backward(ctx, grad_LLL, grad_LLH, grad_LHL, grad_LHH,
                 grad_HLL, grad_HLH, grad_HHL, grad_HHH):
        """
        Apply the transposed filter matrices in reverse order.

        :return: the gradient w.r.t. ``input`` followed by one ``None`` per
                 filter matrix (seven values, matching the forward inputs).
        """
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is deprecated.
        (matrix_Low_0, matrix_Low_1, matrix_Low_2,
         matrix_High_0, matrix_High_1, matrix_High_2) = ctx.saved_tensors

        def merge_third_axis(grad_low, grad_high):
            # Adjoint of the last forward step, including both dim swaps.
            return torch.add(
                torch.matmul(matrix_Low_2.t(), grad_low.transpose(dim0=2, dim1=3)),
                torch.matmul(matrix_High_2.t(), grad_high.transpose(dim0=2, dim1=3)),
            ).transpose(dim0=2, dim1=3)

        grad_LL = merge_third_axis(grad_LLL, grad_HLL)
        grad_LH = merge_third_axis(grad_LLH, grad_HLH)
        grad_HL = merge_third_axis(grad_LHL, grad_HHL)
        grad_HH = merge_third_axis(grad_LHH, grad_HHH)

        grad_L = torch.add(torch.matmul(grad_LL, matrix_Low_1.t()),
                           torch.matmul(grad_LH, matrix_High_1.t()))
        grad_H = torch.add(torch.matmul(grad_HL, matrix_Low_1.t()),
                           torch.matmul(grad_HH, matrix_High_1.t()))
        grad_input = torch.add(torch.matmul(matrix_Low_0.t(), grad_L),
                               torch.matmul(matrix_High_0.t(), grad_H))
        # One entry per forward input: input + six constant matrices.
        return grad_input, None, None, None, None, None, None


class IDWTFunction_3D(Function):
    """
    Autograd function computing one level of a separable 3D inverse DWT.

    The eight sub-bands are merged with the transposed filter matrices:
    first the ``*_2`` matrices (applied after swapping dims 2 and 3), then the
    ``*_1`` matrices (right multiplication) and finally the ``*_0`` matrices
    (left multiplication). The six matrices are constants and receive no
    gradient.
    """

    @staticmethod
    def forward(ctx, input_LLL, input_LLH, input_LHL, input_LHH,
                input_HLL, input_HLH, input_HHL, input_HHH,
                matrix_Low_0, matrix_Low_1, matrix_Low_2,
                matrix_High_0, matrix_High_1, matrix_High_2):
        """
        :param input_LLL ... input_HHH: the eight sub-bands to merge
        :param matrix_Low_*, matrix_High_*: analysis-shaped filter matrices; their transposes are applied here
        :return: the reconstructed tensor
        """
        ctx.save_for_backward(matrix_Low_0, matrix_Low_1, matrix_Low_2,
                              matrix_High_0, matrix_High_1, matrix_High_2)

        def merge_third_axis(low, high):
            # Undo the filtering along dim 2: swap dims 2/3, apply the
            # transposed matrices, add both branches, swap back.
            return torch.add(
                torch.matmul(matrix_Low_2.t(), low.transpose(dim0=2, dim1=3)),
                torch.matmul(matrix_High_2.t(), high.transpose(dim0=2, dim1=3)),
            ).transpose(dim0=2, dim1=3)

        input_LL = merge_third_axis(input_LLL, input_HLL)
        input_LH = merge_third_axis(input_LLH, input_HLH)
        input_HL = merge_third_axis(input_LHL, input_HHL)
        input_HH = merge_third_axis(input_LHH, input_HHH)

        input_L = torch.add(torch.matmul(input_LL, matrix_Low_1.t()),
                            torch.matmul(input_LH, matrix_High_1.t()))
        input_H = torch.add(torch.matmul(input_HL, matrix_Low_1.t()),
                            torch.matmul(input_HH, matrix_High_1.t()))
        output = torch.add(torch.matmul(matrix_Low_0.t(), input_L),
                           torch.matmul(matrix_High_0.t(), input_H))
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Split ``grad_output`` back into eight sub-band gradients.

        :return: gradients for the eight sub-bands followed by one ``None``
                 per filter matrix (fourteen values, matching the forward inputs).
        """
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is deprecated.
        (matrix_Low_0, matrix_Low_1, matrix_Low_2,
         matrix_High_0, matrix_High_1, matrix_High_2) = ctx.saved_tensors

        grad_L = torch.matmul(matrix_Low_0, grad_output)
        grad_H = torch.matmul(matrix_High_0, grad_output)
        grad_LL = torch.matmul(grad_L, matrix_Low_1).transpose(dim0=2, dim1=3)
        grad_LH = torch.matmul(grad_L, matrix_High_1).transpose(dim0=2, dim1=3)
        grad_HL = torch.matmul(grad_H, matrix_Low_1).transpose(dim0=2, dim1=3)
        grad_HH = torch.matmul(grad_H, matrix_High_1).transpose(dim0=2, dim1=3)

        def split_third_axis(matrix, grad):
            # Filter along the original dim 2 and restore the axis order.
            return torch.matmul(matrix, grad).transpose(dim0=2, dim1=3)

        grad_LLL = split_third_axis(matrix_Low_2, grad_LL)
        grad_LLH = split_third_axis(matrix_Low_2, grad_LH)
        grad_LHL = split_third_axis(matrix_Low_2, grad_HL)
        grad_LHH = split_third_axis(matrix_Low_2, grad_HH)
        grad_HLL = split_third_axis(matrix_High_2, grad_LL)
        grad_HLH = split_third_axis(matrix_High_2, grad_LH)
        grad_HHL = split_third_axis(matrix_High_2, grad_HL)
        grad_HHH = split_third_axis(matrix_High_2, grad_HH)
        return (grad_LLL, grad_LLH, grad_LHL, grad_LHH,
                grad_HLL, grad_HLH, grad_HHL, grad_HHH,
                None, None, None, None, None, None)

class DWT_3D(nn.Module):
    """
    Single-level 3D discrete wavelet transform layer.

    input: the 3D data to be decomposed -- (N, C, D, H, W)
    output: lfc -- (N, C, D/2, H/2, W/2)
            hfc_llh -- (N, C, D/2, H/2, W/2)
            hfc_lhl -- (N, C, D/2, H/2, W/2)
            hfc_lhh -- (N, C, D/2, H/2, W/2)
            hfc_hll -- (N, C, D/2, H/2, W/2)
            hfc_hlh -- (N, C, D/2, H/2, W/2)
            hfc_hhl -- (N, C, D/2, H/2, W/2)
            hfc_hhh -- (N, C, D/2, H/2, W/2)
    """

    def __init__(self, wavename, device='cuda:0' if torch.cuda.is_available() else 'cpu'):
        """
        3D discrete wavelet transform (DWT) for 3D data decomposition
        :param wavename: pywt.wavelist(); in the paper, 'chx.y' denotes 'biorx.y'.
        :param device: device used by get_matrix() when it is called without an
                       explicit device; forward() uses the input's device instead.
        """
        super(DWT_3D, self).__init__()
        self.device = device
        wavelet = pywt.Wavelet(wavename)
        self.band_low = wavelet.rec_lo
        self.band_high = wavelet.rec_hi
        assert len(self.band_low) == len(self.band_high)
        self.band_length = len(self.band_low)
        assert self.band_length % 2 == 0
        self.band_length_half = math.floor(self.band_length / 2)

    def get_matrix(self, device=None, dtype=None):
        r"""
        generating the matrices: \mathcal{L}, \mathcal{H}

        Reads ``self.input_depth``, ``self.input_height`` and ``self.input_width``
        and stores ``self.matrix_low_{0,1,2}`` / ``self.matrix_high_{0,1,2}``
        (index 0: height, 1: width, stored transposed, 2: depth).

        :param device: target device; defaults to ``self.device``
        :param dtype: target dtype; defaults to ``torch.get_default_dtype()``
        :return: self.matrix_low = \mathcal{L}, self.matrix_high = \mathcal{H}
        """
        device = self.device if device is None else device
        dtype = torch.get_default_dtype() if dtype is None else dtype

        depth = self.input_depth
        height = self.input_height
        width = self.input_width

        # The shared templates are cropped per axis below, so they must be
        # sized for the longest of ALL three axes (depth included).
        L1 = max(depth, height, width)
        L = L1 // 2
        n_cols = L1 + self.band_length - 2
        matrix_h = np.zeros((L, n_cols))
        matrix_g = np.zeros((L1 - L, n_cols))

        # Each row holds one copy of the filter, shifted by two columns per row.
        for template, band in ((matrix_h, self.band_low), (matrix_g, self.band_high)):
            for i in range(template.shape[0]):
                for j in range(self.band_length):
                    template[i, 2 * i + j] = band[j]

        start = self.band_length_half - 1
        end = None if self.band_length_half == 1 else (
            - self.band_length_half + 1)

        def crop(template, n_rows, size):
            # Cut the template down to one axis, then drop the border columns.
            return template[0:n_rows, 0:(size + self.band_length - 2)][:, start:end]

        matrix_h_0 = crop(matrix_h, height // 2, height)
        matrix_h_1 = np.transpose(crop(matrix_h, width // 2, width))
        matrix_h_2 = crop(matrix_h, depth // 2, depth)

        matrix_g_0 = crop(matrix_g, height - height // 2, height)
        matrix_g_1 = np.transpose(crop(matrix_g, width - width // 2, width))
        matrix_g_2 = crop(matrix_g, depth - depth // 2, depth)

        def as_tensor(array):
            # torch.tensor copies, so the results do not alias the templates.
            return torch.tensor(array, dtype=dtype, device=device)

        self.matrix_low_0 = as_tensor(matrix_h_0)
        self.matrix_low_1 = as_tensor(matrix_h_1)
        self.matrix_low_2 = as_tensor(matrix_h_2)
        self.matrix_high_0 = as_tensor(matrix_g_0)
        self.matrix_high_1 = as_tensor(matrix_g_1)
        self.matrix_high_2 = as_tensor(matrix_g_2)

    def forward(self, input):
        """
        :param input: the 3D data to be decomposed
        :return: the eight components of the input data, one low-frequency and seven high-frequency components
        """
        assert len(input.size()) == 5
        self.input_depth = input.size()[-3]
        self.input_height = input.size()[-2]
        self.input_width = input.size()[-1]
        # Build the filters where the data lives so matmul never mixes
        # devices or floating-point dtypes.
        self.get_matrix(device=input.device,
                        dtype=input.dtype if input.is_floating_point() else None)
        return DWTFunction_3D.apply(input, self.matrix_low_0, self.matrix_low_1, self.matrix_low_2,
                                    self.matrix_high_0, self.matrix_high_1, self.matrix_high_2)


class IDWT_3D(nn.Module):
    """
    Single-level 3D inverse discrete wavelet transform layer.

    input:  lfc -- (N, C, D/2, H/2, W/2)
            hfc_llh -- (N, C, D/2, H/2, W/2)
            hfc_lhl -- (N, C, D/2, H/2, W/2)
            hfc_lhh -- (N, C, D/2, H/2, W/2)
            hfc_hll -- (N, C, D/2, H/2, W/2)
            hfc_hlh -- (N, C, D/2, H/2, W/2)
            hfc_hhl -- (N, C, D/2, H/2, W/2)
            hfc_hhh -- (N, C, D/2, H/2, W/2)
    output: the original 3D data -- (N, C, D, H, W)
    """

    def __init__(self, wavename, device='cuda:0' if torch.cuda.is_available() else 'cpu'):
        """
        3D inverse DWT (IDWT) for 3D data reconstruction
        :param wavename: pywt.wavelist(); in the paper, 'chx.y' denotes 'biorx.y'.
        :param device: device used by get_matrix() when it is called without an
                       explicit device; forward() uses the input's device instead.
        """
        super(IDWT_3D, self).__init__()
        wavelet = pywt.Wavelet(wavename)
        self.device = device
        # The decomposition filters are used in reversed order here.
        self.band_low = list(reversed(wavelet.dec_lo))
        self.band_high = list(reversed(wavelet.dec_hi))
        assert len(self.band_low) == len(self.band_high)
        self.band_length = len(self.band_low)
        assert self.band_length % 2 == 0
        self.band_length_half = math.floor(self.band_length / 2)

    def get_matrix(self, device=None, dtype=None):
        r"""
        generating the matrices: \mathcal{L}, \mathcal{H}

        Reads ``self.input_depth``, ``self.input_height`` and ``self.input_width``
        (sizes of the reconstructed volume) and stores
        ``self.matrix_low_{0,1,2}`` / ``self.matrix_high_{0,1,2}``
        (index 0: height, 1: width, stored transposed, 2: depth).

        :param device: target device; defaults to ``self.device``
        :param dtype: target dtype; defaults to ``torch.get_default_dtype()``
        :return: self.matrix_low = \mathcal{L}, self.matrix_high = \mathcal{H}
        """
        device = self.device if device is None else device
        dtype = torch.get_default_dtype() if dtype is None else dtype

        depth = self.input_depth
        height = self.input_height
        width = self.input_width

        # The shared templates are cropped per axis below, so they must be
        # sized for the longest of ALL three axes (depth included).
        L1 = max(depth, height, width)
        L = L1 // 2
        n_cols = L1 + self.band_length - 2
        matrix_h = np.zeros((L, n_cols))
        matrix_g = np.zeros((L1 - L, n_cols))

        # Each row holds one copy of the filter, shifted by two columns per row.
        for template, band in ((matrix_h, self.band_low), (matrix_g, self.band_high)):
            for i in range(template.shape[0]):
                for j in range(self.band_length):
                    template[i, 2 * i + j] = band[j]

        start = self.band_length_half - 1
        end = None if self.band_length_half == 1 else (
            - self.band_length_half + 1)

        def crop(template, n_rows, size):
            # Cut the template down to one axis, then drop the border columns.
            return template[0:n_rows, 0:(size + self.band_length - 2)][:, start:end]

        matrix_h_0 = crop(matrix_h, height // 2, height)
        matrix_h_1 = np.transpose(crop(matrix_h, width // 2, width))
        matrix_h_2 = crop(matrix_h, depth // 2, depth)

        matrix_g_0 = crop(matrix_g, height - height // 2, height)
        matrix_g_1 = np.transpose(crop(matrix_g, width - width // 2, width))
        matrix_g_2 = crop(matrix_g, depth - depth // 2, depth)

        def as_tensor(array):
            # torch.tensor copies, so the results do not alias the templates.
            return torch.tensor(array, dtype=dtype, device=device)

        self.matrix_low_0 = as_tensor(matrix_h_0)
        self.matrix_low_1 = as_tensor(matrix_h_1)
        self.matrix_low_2 = as_tensor(matrix_h_2)
        self.matrix_high_0 = as_tensor(matrix_g_0)
        self.matrix_high_1 = as_tensor(matrix_g_1)
        self.matrix_high_2 = as_tensor(matrix_g_2)

    def forward(self, LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH):
        """
        :param LLL: the low-frequency component, lfc
        :param LLH: the high-frequency component, hfc_llh
        :param LHL: the high-frequency component, hfc_lhl
        :param LHH: the high-frequency component, hfc_lhh
        :param HLL: the high-frequency component, hfc_hll
        :param HLH: the high-frequency component, hfc_hlh
        :param HHL: the high-frequency component, hfc_hhl
        :param HHH: the high-frequency component, hfc_hhh
        :return: the original 3D input data
        """
        assert len(LLL.size()) == len(LLH.size()) == len(
            LHL.size()) == len(LHH.size()) == 5
        assert len(HLL.size()) == len(HLH.size()) == len(
            HHL.size()) == len(HHH.size()) == 5
        # Output size per axis = low-band size + high-band size.
        self.input_depth = LLL.size()[-3] + HHH.size()[-3]
        self.input_height = LLL.size()[-2] + HHH.size()[-2]
        self.input_width = LLL.size()[-1] + HHH.size()[-1]
        # Build the filters where the data lives so matmul never mixes
        # devices or floating-point dtypes.
        self.get_matrix(device=LLL.device,
                        dtype=LLL.dtype if LLL.is_floating_point() else None)
        return IDWTFunction_3D.apply(LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH,
                                     self.matrix_low_0, self.matrix_low_1, self.matrix_low_2,
                                     self.matrix_high_0, self.matrix_high_1, self.matrix_high_2)


class DynamicWaveletUNet3D(nn.Module):
    """
    3D encoder/decoder with skip connections and a wavelet branch on the
    first three encoder stages.

    Encoder: eight stride-2 ``Conv3d`` layers (``conv1`` .. ``conv8``).
    Decoder: eight stride-2 ``ConvTranspose3d`` layers whose outputs are
    concatenated with the matching encoder features. The network takes one
    input channel and produces one ``tanh``-activated output channel.

    For encoder stages 1-3 the feature map is decomposed with ``self.dwt``,
    each of the eight sub-bands is passed through its own conv block, the
    result is recomposed with ``self.idwt`` and mixed back into the stage
    output. ``EMA3D`` modules are applied at e2, e3, the bottleneck, d5 and d7.

    NOTE(review): eight halvings presumably require each spatial size to be a
    multiple of 256 for the skip concatenations to line up -- confirm.
    """

    @staticmethod
    def _subband_block(channels):
        """Conv3d(3x3x3) -> BatchNorm3d -> LeakyReLU(0.2) for one wavelet sub-band."""
        return nn.Sequential(
            nn.Conv3d(channels, channels, 3, 1, 1),
            nn.BatchNorm3d(channels),
            nn.LeakyReLU(0.2, inplace=True)
        )

    @staticmethod
    def _weight_head(channels):
        """Global-average-pool head emitting two softmax-normalised weights per sample."""
        return nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(channels, channels // 4, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv3d(channels // 4, 2, 1, bias=False),
            nn.Softmax(dim=1)
        )

    def __init__(self, d=64, wavelet='haar', device='cuda:0' if torch.cuda.is_available() else 'cpu'):
        """
        :param d: base channel width; deeper stages use 2d, 4d and 8d channels
        :param wavelet: wavelet name forwarded to DWT_3D / IDWT_3D
        :param device: device forwarded to DWT_3D / IDWT_3D
        """
        super(DynamicWaveletUNet3D, self).__init__()

        self.device = device

        self.dwt = DWT_3D(wavelet, device=device)
        self.idwt = IDWT_3D(wavelet, device=device)

        # Encoder (kernel 4, stride 2, padding 1). conv1 and conv8 have no BatchNorm.
        self.conv1 = nn.Conv3d(1, d, 4, 2, 1)
        self.conv2 = nn.Conv3d(d, d * 2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm3d(d * 2)
        self.conv3 = nn.Conv3d(d * 2, d * 4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm3d(d * 4)
        self.conv4 = nn.Conv3d(d * 4, d * 8, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm3d(d * 8)
        self.conv5 = nn.Conv3d(d * 8, d * 8, 4, 2, 1)
        self.conv5_bn = nn.BatchNorm3d(d * 8)
        self.conv6 = nn.Conv3d(d * 8, d * 8, 4, 2, 1)
        self.conv6_bn = nn.BatchNorm3d(d * 8)
        self.conv7 = nn.Conv3d(d * 8, d * 8, 4, 2, 1)
        self.conv7_bn = nn.BatchNorm3d(d * 8)
        self.conv8 = nn.Conv3d(d * 8, d * 8, 4, 2, 1)

        self.ema_bottleneck = EMA3D(d * 8, factor=8)

        # Decoder. Input widths are doubled from deconv2 on because each
        # decoder output is concatenated with an encoder feature map.
        self.deconv1 = nn.ConvTranspose3d(d * 8, d * 8, 4, 2, 1)
        self.deconv1_bn = nn.BatchNorm3d(d * 8)
        self.deconv2 = nn.ConvTranspose3d(d * 8 * 2, d * 8, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm3d(d * 8)
        self.deconv3 = nn.ConvTranspose3d(d * 8 * 2, d * 8, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm3d(d * 8)
        self.deconv4 = nn.ConvTranspose3d(d * 8 * 2, d * 8, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm3d(d * 8)
        self.deconv5 = nn.ConvTranspose3d(d * 8 * 2, d * 4, 4, 2, 1)
        self.deconv5_bn = nn.BatchNorm3d(d * 4)
        self.deconv6 = nn.ConvTranspose3d(d * 4 * 2, d * 2, 4, 2, 1)
        self.deconv6_bn = nn.BatchNorm3d(d * 2)
        self.deconv7 = nn.ConvTranspose3d(d * 2 * 2, d, 4, 2, 1)
        self.deconv7_bn = nn.BatchNorm3d(d)
        self.deconv8 = nn.ConvTranspose3d(d * 2, 1, 4, 2, 1)

        # Per-sub-band processing: one block for the low band, seven for the
        # high bands, at each of the three wavelet-enhanced stages.
        self.wave_process1_low = self._subband_block(d)
        self.wave_process1_high = nn.ModuleList(
            [self._subband_block(d) for _ in range(7)])

        self.wave_process2_low = self._subband_block(d * 2)
        self.wave_process2_high = nn.ModuleList(
            [self._subband_block(d * 2) for _ in range(7)])

        self.wave_process3_low = self._subband_block(d * 4)
        self.wave_process3_high = nn.ModuleList(
            [self._subband_block(d * 4) for _ in range(7)])

        # 1x1x1 convs applied to the reconstructed wavelet features.
        self.wave_fusion1 = nn.Conv3d(d, d, 1, 1, 0)
        self.wave_fusion2 = nn.Conv3d(d * 2, d * 2, 1, 1, 0)
        self.wave_fusion3 = nn.Conv3d(d * 4, d * 4, 1, 1, 0)

        self.ema_e2 = EMA3D(d * 2, factor=8)
        self.ema_e3 = EMA3D(d * 4, factor=8)
        self.ema_d5 = EMA3D(d * 4, factor=8)
        self.ema_d7 = EMA3D(d, factor=8)

        # NOTE(review): dynamic_weight_module1 is not referenced in forward();
        # it is left in place so parameter names / state_dict keys stay unchanged.
        self.dynamic_weight_module1 = self._weight_head(d)
        self.dynamic_weight_module2 = self._weight_head(d * 2)
        self.dynamic_weight_module3 = self._weight_head(d * 4)

        # Raw (pre-sigmoid) mixing logits between plain and wavelet features.
        self.wavelet_bypass1 = nn.Parameter(torch.tensor(0.5))
        self.wavelet_bypass2 = nn.Parameter(torch.tensor(0.5))
        self.wavelet_bypass3 = nn.Parameter(torch.tensor(0.5))

        # Attention blend factors: 1 -> e2, 2 -> e3, 3 -> d5 (see forward()).
        self.attention_intensity1 = nn.Parameter(torch.tensor(1.0))
        self.attention_intensity2 = nn.Parameter(torch.tensor(1.0))
        self.attention_intensity3 = nn.Parameter(torch.tensor(1.0))

    def weight_init(self, mean, std):
        """
        Re-initialise every convolution and batch-norm layer of the network.

        Conv3d / ConvTranspose3d weights are drawn from N(mean, std) with
        zeroed biases; BatchNorm3d weights are drawn from N(1.0, 0.02) with
        zeroed biases. ``self.modules()`` is used so that layers nested inside
        ``nn.Sequential`` / ``nn.ModuleList`` and other submodules are reached
        as well (iterating ``self._modules`` directly only yields name strings).

        :param mean: mean of the normal distribution for convolution weights
        :param std: standard deviation for convolution weights
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
                m.weight.data.normal_(mean, std)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm3d):
                # Non-affine batch norms carry no weight/bias to initialise.
                if m.weight is not None:
                    m.weight.data.normal_(1.0, 0.02)
                if m.bias is not None:
                    m.bias.data.zero_()

    def process_wavelet_components(self, components, processor_low, processor_high_list):
        """
        Run each of the eight sub-bands through its own processor.

        :param components: sequence of eight sub-bands, low-frequency band first
        :param processor_low: callable applied to ``components[0]``
        :param processor_high_list: seven callables applied to ``components[1:8]`` in order
        :return: list of the eight processed sub-bands in the same order
        """
        processed_low = processor_low(components[0])
        processed_high = [processor_high_list[i - 1](components[i])
                          for i in range(1, 8)]
        return [processed_low] + processed_high

    def calculate_dynamic_weights(self, feature, weight_module):
        """Return ``weight_module(feature)``, the per-sample mixing weights."""
        return weight_module(feature)

    def _wavelet_refine(self, feature, processor_low, processor_high_list, fusion):
        """DWT -> per-sub-band processing -> IDWT -> fusion conv for one encoder stage."""
        components = self.dwt(feature)
        processed = self.process_wavelet_components(
            components, processor_low, processor_high_list)
        return fusion(self.idwt(*processed))

    def forward(self, input):
        """
        :param input: tensor fed to ``conv1`` (which expects one input channel)
        :return: ``tanh`` of the final transposed convolution (one output channel)
        """
        # ---- Encoder stage 1: fixed learnable blend of plain and wavelet features.
        e1 = self.conv1(input)
        wave1_enhanced = self._wavelet_refine(
            e1, self.wave_process1_low, self.wave_process1_high, self.wave_fusion1)
        bypass_weight1 = torch.sigmoid(self.wavelet_bypass1)
        e1_enhanced = bypass_weight1 * e1 + (1 - bypass_weight1) * wave1_enhanced

        # ---- Encoder stage 2: input-dependent weights, bypass blend, attention.
        e2 = self.conv2_bn(self.conv2(F.leaky_relu(e1_enhanced, 0.2)))
        wave2_enhanced = self._wavelet_refine(
            e2, self.wave_process2_low, self.wave_process2_high, self.wave_fusion2)
        dynamic_weights2 = self.calculate_dynamic_weights(e2, self.dynamic_weight_module2)
        # Channel 0 weights the plain features, channel 1 the wavelet features.
        e2_weighted = dynamic_weights2[:, 0:1] * e2 + dynamic_weights2[:, 1:2] * wave2_enhanced
        bypass_weight2 = torch.sigmoid(self.wavelet_bypass2)
        e2_enhanced = bypass_weight2 * e2 + (1 - bypass_weight2) * e2_weighted
        e2_attention = self.ema_e2(e2_enhanced)
        e2_enhanced = e2_enhanced + self.attention_intensity1 * (e2_attention - e2_enhanced)

        # ---- Encoder stage 3: same scheme as stage 2.
        e3 = self.conv3_bn(self.conv3(F.leaky_relu(e2_enhanced, 0.2)))
        wave3_enhanced = self._wavelet_refine(
            e3, self.wave_process3_low, self.wave_process3_high, self.wave_fusion3)
        dynamic_weights3 = self.calculate_dynamic_weights(e3, self.dynamic_weight_module3)
        e3_weighted = dynamic_weights3[:, 0:1] * e3 + dynamic_weights3[:, 1:2] * wave3_enhanced
        bypass_weight3 = torch.sigmoid(self.wavelet_bypass3)
        e3_enhanced = bypass_weight3 * e3 + (1 - bypass_weight3) * e3_weighted
        e3_attention = self.ema_e3(e3_enhanced)
        e3_enhanced = e3_enhanced + self.attention_intensity2 * (e3_attention - e3_enhanced)

        # ---- Remaining encoder stages and bottleneck attention.
        e4 = self.conv4_bn(self.conv4(F.leaky_relu(e3_enhanced, 0.2)))
        e5 = self.conv5_bn(self.conv5(F.leaky_relu(e4, 0.2)))
        e6 = self.conv6_bn(self.conv6(F.leaky_relu(e5, 0.2)))
        e7 = self.conv7_bn(self.conv7(F.leaky_relu(e6, 0.2)))
        e8 = self.conv8(F.leaky_relu(e7, 0.2))
        e8 = self.ema_bottleneck(e8)

        # ---- Decoder with skip concatenations along the channel axis.
        d1 = self.deconv1_bn(self.deconv1(F.relu(e8)))
        d1 = torch.cat([d1, e7], 1)

        d2 = self.deconv2_bn(self.deconv2(F.relu(d1)))
        d2 = torch.cat([d2, e6], 1)

        d3 = self.deconv3_bn(self.deconv3(F.relu(d2)))
        d3 = torch.cat([d3, e5], 1)

        d4 = self.deconv4_bn(self.deconv4(F.relu(d3)))
        d4 = torch.cat([d4, e4], 1)

        d5 = self.deconv5_bn(self.deconv5(F.relu(d4)))
        d5_attention = self.ema_d5(d5)
        d5 = d5 + self.attention_intensity3 * (d5_attention - d5)
        d5 = torch.cat([d5, e3_enhanced], 1)

        d6 = self.deconv6_bn(self.deconv6(F.relu(d5)))
        d6 = torch.cat([d6, e2_enhanced], 1)

        d7 = self.deconv7_bn(self.deconv7(F.relu(d6)))
        d7 = self.ema_d7(d7)
        d7 = torch.cat([d7, e1_enhanced], 1)

        d8 = self.deconv8(F.relu(d7))

        o = torch.tanh(d8)

        return o

    def init_wavelet_params(self):
        """
        Set every ``wavelet_bypass*`` parameter to 0.9.

        The value is the raw parameter; forward() passes it through a sigmoid
        before using it as the weight of the plain (non-wavelet) features.
        """
        for name, param in self.named_parameters():
            if 'wavelet_bypass' in name:
                param.data.fill_(0.9)

def normal_init3d(m, mean, std):
    """
    Initialise a single 3D layer in place.

    Conv3d / ConvTranspose3d: weights drawn from N(mean, std), bias (when
    present) zeroed. BatchNorm3d: weights drawn from N(1.0, 0.02), bias zeroed.
    Any other module is left untouched.
    """
    conv_types = (nn.Conv3d, nn.ConvTranspose3d)
    if isinstance(m, conv_types):
        m.weight.data.normal_(mean, std)
        bias = m.bias
        if bias is not None:
            bias.data.zero_()
        return
    if not isinstance(m, nn.BatchNorm3d):
        return
    m.weight.data.normal_(1.0, 0.02)
    m.bias.data.zero_()
