import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
import math
import os
import sys 
# Make the project packages importable when this module is loaded directly:
# the parent and the grandparent of this file's directory are added to
# sys.path (the imports below rely on this). ``rootpath`` ends up holding the
# grandparent directory. Paths already present are not appended again.
current_file_dir = os.path.dirname(os.path.abspath(__file__))
rootpath = os.path.abspath(os.path.join(current_file_dir, '../'))
if rootpath not in sys.path:
    sys.path.append(rootpath)
rootpath = os.path.abspath(os.path.join(rootpath, '../'))
if rootpath not in sys.path:
    sys.path.append(rootpath)
from nfn_moe.common.weight_space import MoENetworkSpec, MoEWeightSpaceFeatures, NetworkSpec, LinearWeightSpaceFeatures
from layers.layer_utils import (
    set_init_,
    set_init_einsum_,
)
class HNPSLinear(nn.Module):
    """Linear map from weight-space features to weight-space features.

    For every layer ``i`` of the described network a set of ``EinsumLayer``
    terms is registered (their names are part of the state_dict):

    * ``i == 0``: ``layer_0_Y_W`` + ``layer_0_Y_b`` produce the new weight and
      ``layer_0_z_W`` + ``layer_0_z_b`` the new bias;
    * ``i == L - 1``: ``layer_{i}_Y_W`` produces the new weight and
      ``layer_{i}_z_b`` + ``layer_{i}_z_tau`` the new bias;
    * otherwise: ``layer_{i}_Y_W`` (weight) and ``layer_{i}_z_b`` (bias), which
      only mix channels.

    As fixed by the einsum equations, weight features are 4-D
    ``(batch, filter_fac * C, rows, cols)`` and bias features are 3-D
    ``(batch, C, rows)``. Filter dims appear to be folded into the weight's
    channel axis (TODO confirm against the feature producer).

    NOTE(review): for a single-layer spec (``L == 1``) only the first-layer
    terms are built, so the output bias has no ``z_tau`` term.

    Args:
        in_network_spec: spec of the processed network; must provide
            ``get_matrices_shape()``, ``weight_spec`` and ``__len__``.
        in_channels: channel count of the incoming features.
        out_channels: channel count of the produced features.
        init_type: forwarded to ``set_init_einsum_``.
    """

    def __init__(self, in_network_spec: NetworkSpec, in_channels, out_channels, init_type="pytorch_default"):
        super().__init__()
        self.c_in, self.c_out = in_channels, out_channels
        self.in_network_spec = in_network_spec

        layer_weight_shapes = in_network_spec.get_matrices_shape()
        layer_in_weight_shape, layer_out_weight_shape = layer_weight_shapes[0], layer_weight_shapes[-1]
        self.L = len(in_network_spec)
        # Product of the trailing (filter) dims of each weight spec; 1 for dense layers.
        in_filter_facs = [int(np.prod(spec.shape[2:])) for spec in in_network_spec.weight_spec]
        out_filter_facs = in_filter_facs

        # Module registration and initialisation order is kept identical per
        # layer so parameter order and RNG consumption are reproducible.
        for i in range(self.L):
            in_fac, out_fac = in_filter_facs[i], out_filter_facs[i]
            if i == 0:
                self._build_first_layer(in_fac, out_fac, layer_in_weight_shape[-1], init_type)
            elif i == self.L - 1:
                self._build_last_layer(i, in_fac, out_fac, layer_out_weight_shape[-2], init_type)
            else:
                self._build_hidden_layer(i, in_fac, out_fac, init_type)

    def _build_first_layer(self, in_fac, out_fac, n_in, init_type):
        """Register and initialise the four einsum terms of layer 0."""
        c_in, c_out = self.c_in, self.c_out
        self.layer_0_Y_W = EinsumLayer(equation="mnqk, bnjq -> bmjk",
                                       weight_shape=[out_fac * c_out, in_fac * c_in, n_in, n_in],
                                       fan_in_mask=[0, 1, 1, 0])
        self.layer_0_Y_b = EinsumLayer(equation="mnk, bnj -> bmjk",
                                       weight_shape=[out_fac * c_out, c_in, n_in],
                                       fan_in_mask=[0, 1, 0])
        self.layer_0_z_W = EinsumLayer(equation="mnq, bnjq -> bmj",
                                       weight_shape=[c_out, in_fac * c_in, n_in],
                                       fan_in_mask=[0, 1, 1])
        self.layer_0_z_b = EinsumLayer(equation="mn, bnj -> bmj",
                                       weight_shape=[c_out, c_in],
                                       fan_in_mask=[0, 1])
        set_init_einsum_(self.layer_0_Y_W, self.layer_0_Y_b, init_type=init_type)
        set_init_einsum_(self.layer_0_z_W, self.layer_0_z_b, init_type=init_type)

    def _build_last_layer(self, i, in_fac, out_fac, n_out, init_type):
        """Register and initialise the weight, bias and constant terms of the last layer."""
        c_in, c_out = self.c_in, self.c_out
        self.add_module(f"layer_{i}_Y_W",
                        EinsumLayer(equation="mnpj, bnpk -> bmjk",
                                    weight_shape=[out_fac * c_out, in_fac * c_in, n_out, n_out],
                                    fan_in_mask=[0, 1, 1, 0]))
        # Bias features have no filter dims (same as layer_0_Y_b / hidden z_b),
        # so the input-channel axis is c_in, not in_fac * c_in.
        self.add_module(f"layer_{i}_z_b",
                        EinsumLayer(equation="mnpj, bnp -> bmj",
                                    weight_shape=[c_out, c_in, n_out, n_out],
                                    fan_in_mask=[0, 1, 1, 0]))
        self.add_module(f"layer_{i}_z_tau",
                        EinsumLayer(equation="mj, b-> bmj",
                                    weight_shape=[c_out, n_out],
                                    fan_in_mask=[0, 0]))
        set_init_einsum_(getattr(self, f"layer_{i}_Y_W"), init_type=init_type)
        set_init_einsum_(getattr(self, f"layer_{i}_z_b"),
                         getattr(self, f"layer_{i}_z_tau"),
                         init_type=init_type)

    def _build_hidden_layer(self, i, in_fac, out_fac, init_type):
        """Register and initialise the channel-mixing terms of a hidden layer."""
        self.add_module(f"layer_{i}_Y_W",
                        EinsumLayer(equation="mn, bnjk -> bmjk",
                                    weight_shape=[out_fac * self.c_out, in_fac * self.c_in],
                                    fan_in_mask=[0, 1]))
        self.add_module(f"layer_{i}_z_b",
                        EinsumLayer(equation="mn, bnj -> bmj",
                                    weight_shape=[self.c_out, self.c_in],
                                    fan_in_mask=[0, 1]))
        set_init_einsum_(getattr(self, f"layer_{i}_Y_W"), init_type=init_type)
        set_init_einsum_(getattr(self, f"layer_{i}_z_b"), init_type=init_type)

    def forward(self, wsfeat: LinearWeightSpaceFeatures) -> LinearWeightSpaceFeatures:
        """Map ``wsfeat`` (indexable as ``(weight, bias)`` per layer) to new features."""
        out_weights, out_biases = [], []
        for i in range(self.L):
            weight, bias = wsfeat[i]
            if i == 0:
                out_weights.append(self.layer_0_Y_W(weight) + self.layer_0_Y_b(bias))
                out_biases.append(self.layer_0_z_W(weight) + self.layer_0_z_b(bias))
            elif i == self.L - 1:
                out_weights.append(getattr(self, f"layer_{i}_Y_W")(weight))
                z_b = getattr(self, f"layer_{i}_z_b")(bias)
                tau_layer = getattr(self, f"layer_{i}_z_tau")
                # z_tau is input-independent: contract its size-1 batch axis
                # against a single 1 of the parameter's own dtype/device so
                # einsum never sees an integer operand.
                z_tau = tau_layer(tau_layer.weight.new_ones(1))
                out_biases.append(z_b + z_tau)
            else:
                out_weights.append(getattr(self, f"layer_{i}_Y_W")(weight))
                out_biases.append(getattr(self, f"layer_{i}_z_b")(bias))
        return LinearWeightSpaceFeatures(out_weights, out_biases)

class ElementwiseParamNormalize(nn.Module):
    """L2-normalise features along their channel axis, then apply a learned affine map.

    Only ``mode_normalize == "param_mul_L2"`` is supported. Inputs are handled by rank:

    * 6-D ``(b, c, i, j, k, l)``: normalised jointly over ``(c, k, l)``, so
      ``hidden`` must equal ``c * k * l``;
    * 4-D ``(b, c, i, j)`` and 3-D ``(b, c, j)``: normalised over ``c``, so
      ``hidden`` must equal ``c``.

    The learned ``weight``/``bias`` (length ``hidden``) scale and shift each
    normalised entry. The output has the same shape as the input.

    Raises:
        ValueError: for an unsupported ``mode_normalize`` or input rank.
    """

    def __init__(self, hidden, mode_normalize) -> None:
        super().__init__()
        self.hidden = hidden
        self.mode_normalize = mode_normalize
        self.weight = nn.Parameter(torch.ones(hidden))
        self.bias = nn.Parameter(torch.ones(hidden))
        # Both affine parameters are drawn from a standard normal.
        nn.init.normal_(self.weight)
        nn.init.normal_(self.bias)

    def _affine(self, feats):
        """L2-normalise ``feats`` over the last axis and apply the learned affine map."""
        return self.weight * F.normalize(feats, p=2.0, dim=-1) + self.bias

    def forward(self, input):
        if self.mode_normalize != "param_mul_L2":
            raise ValueError(
                f"Unsupported mode_normalize {self.mode_normalize!r}; expected 'param_mul_L2'.")
        if input.dim() == 6:  # CNN weights
            b, c, i, j, k, l = input.shape
            # (b, c, i, j, k, l) -> (b, i, j, c*k*l), with c outermost in the merged axis.
            feats = input.permute(0, 2, 3, 1, 4, 5).reshape(b, i, j, c * k * l)
            feats = self._affine(feats)
            return feats.reshape(b, i, j, c, k, l).permute(0, 3, 1, 2, 4, 5)
        if input.dim() in (3, 4):  # MLP weights (4-D) or biases (3-D)
            return self._affine(input.movedim(1, -1)).movedim(-1, 1)
        raise ValueError(
            f"ElementwiseParamNormalize expects a 3-, 4- or 6-D input, got {input.dim()}-D.")

class HNPSPool(nn.Module):
    """Pool weight-space features into one flat feature vector per channel.

    For every layer ``i``, the weight features are first normalised and then averaged:

    * first layer: over rows (dim 2);
    * last layer: over columns (dim 3);
    * hidden layers: over both.

    Bias features are normalised and averaged over their last axis, except for
    the last layer, whose bias is passed through unchanged. Every piece is
    flattened from dim 2 onward and the pieces are concatenated along the last
    axis. The result has shape ``(batch, channels, F)``, where ``F`` is intended
    to equal :meth:`get_num_outs` for the same spec.

    Args:
        network_spec: spec of the network whose features are pooled.
        nfn_channels: channel count of the incoming features; sizes the learned
            normalisers in ``"param_mul_L2"`` mode.
        mode_pooling: ``"param_mul_L2"`` (learned affine after L2 normalisation,
            via ``ElementwiseParamNormalize``) or one of the parameter-free modes
            ``"L1"``, ``"L2"``, ``"L2_square"`` (``F.normalize`` along dim 1).
    """

    # Parameter-free modes: mode -> (norm order for F.normalize, square the result?).
    _PARAM_FREE_MODES = {"L1": (1.0, False), "L2": (2.0, False), "L2_square": (2.0, True)}

    def __init__(self, network_spec: NetworkSpec, nfn_channels, mode_pooling="param_mul_L2"):
        super().__init__()
        self.network_spec = network_spec
        self.mode_pooling = mode_pooling
        self.nfn_channels = nfn_channels
        if self.mode_pooling == "param_mul_L2":
            for i in range(len(network_spec)):
                spec_shape = network_spec.weight_spec[i].shape
                if len(spec_shape) == 4:  # CNN: filter dims are normalised together with channels
                    hidden_w = nfn_channels * math.prod(spec_shape[-2:])
                else:
                    hidden_w = nfn_channels
                self.add_module(f"regularize_W_{i}",
                                ElementwiseParamNormalize(hidden_w, mode_normalize=mode_pooling))
                self.add_module(f"regularize_b_{i}",
                                ElementwiseParamNormalize(nfn_channels, mode_normalize=mode_pooling))

    def forward(self, wsfeat: LinearWeightSpaceFeatures) -> torch.Tensor:
        out = []
        last = len(wsfeat) - 1
        for i in range(len(self.network_spec)):
            weight, bias = wsfeat[i]
            if self.mode_pooling == "param_mul_L2":
                regularizer_w = getattr(self, f"regularize_W_{i}")
                regularizer_b = getattr(self, f"regularize_b_{i}")
            else:
                regularizer_w = regularizer_b = self.regularize_without_param

            weight_regularized = regularizer_w(weight)
            if i == 0:
                out.append(weight_regularized.mean(dim=2))  # average over rows
            elif i == last:
                out.append(weight_regularized.mean(dim=3))  # average over cols
            else:
                out.append(weight_regularized.mean(dim=(2, 3)).unsqueeze(-1))

            if i == last:
                out.append(bias)
            else:
                out.append(regularizer_b(bias).mean(dim=-1).unsqueeze(-1))

        return torch.cat([torch.flatten(o, start_dim=2) for o in out], dim=-1)

    def regularize_without_param(self, weight):
        """Normalise ``weight`` along dim 1 according to ``self.mode_pooling``.

        For 6-D CNN features ``(b, c, i, j, k, l)`` the channel and filter axes
        are normalised jointly. ``"L1"`` uses the 1-norm, while ``"L2"`` and
        ``"L2_square"`` use the 2-norm. ``"L2_square"`` additionally squares the
        normalised values.

        Raises:
            ValueError: if ``self.mode_pooling`` is not a parameter-free mode.
        """
        try:
            p, square = self._PARAM_FREE_MODES[self.mode_pooling]
        except KeyError:
            raise ValueError(
                f"Unsupported mode_pooling {self.mode_pooling!r}; expected 'param_mul_L2' or one of "
                f"{sorted(self._PARAM_FREE_MODES)}.") from None

        def _normalize(x):
            normed = F.normalize(x, dim=1, p=p)
            return normed ** 2 if square else normed

        if weight.dim() == 6:
            n_c, n_k, n_l = weight.shape[1], weight.shape[-2], weight.shape[-1]
            flat = rearrange(weight, 'b c i j k l -> b (c k l) i j')
            return rearrange(_normalize(flat), 'b (c k l) i j -> b c i j k l',
                             c=n_c, k=n_k, l=n_l)
        return _normalize(weight)

    @staticmethod
    def get_num_outs(network_spec):
        """Return the per-channel output length ``F`` produced by :meth:`forward`.

        Per layer with filter factor ``fac``:

        * first layer: ``n_in * fac`` weight entries plus 1 bias entry;
        * last layer: ``n_out * fac`` weight entries plus ``n_out`` bias entries;
        * hidden layers: ``fac`` weight entries plus 1 bias entry.
        """
        filter_facs = [int(np.prod(spec.shape[2:])) for spec in network_spec.weight_spec]
        n_in, n_out = network_spec.get_io()
        num_outs = 0
        for i, fac in enumerate(filter_facs):
            if i == 0:
                num_outs += n_in * fac + 1
            elif i == len(filter_facs) - 1:
                num_outs += n_out * fac + n_out
            else:
                num_outs += fac + 1
        return num_outs

class EinsumLayer(nn.Module):
    """Learnable ``torch.einsum`` contraction between an optional input and a weight.

    The equation has either one operand (the weight alone) or two operands.
    With two operands, the operand whose subscripts contain ``b`` is the input
    and the other is ``self.weight``.

    Args:
        equation: einsum equation. An empty string skips all setup (the module
            then has no weight).
        weight_shape: shape of the learnable ``weight``. It is allocated
            uninitialised; callers in this file initialise it via
            ``set_init_einsum_``.
        input_shape: shape used for the fan-in/fan-out computation; defaults to
            ``weight_shape``. ``fan_out`` takes an absolute value, presumably
            because a ``-1`` batch placeholder may appear (TODO confirm).
        fan_in_mask: per-axis flags; entries >= 0.5 count toward ``fan_in``, the
            rest toward ``fan_out``. An empty selection gives 0.
        unsqueeze_dims: dims inserted into the result, in order.
        flag: ``None``, ``"sub_mean"`` (subtract the weight's mean over dim 2) or
            ``"div_mean"`` (divide by it). Only applied for two-operand equations.
    """

    def __init__(self, equation="", weight_shape=None, input_shape=None, fan_in_mask=None, unsqueeze_dims=None, flag = None) -> None:
        super().__init__()
        self.equation = equation
        if len(self.equation) == 0:
            return

        self.weight_shape_list = weight_shape
        self.weight_shape_tensor = torch.tensor(weight_shape, dtype=torch.int)
        if input_shape is not None:
            self.input_shape_tensor = torch.tensor(input_shape, dtype=torch.int)
        else:
            self.input_shape_tensor = self.weight_shape_tensor

        # Fan-in / fan-out over the axes selected by fan_in_mask (consumed by the initialiser).
        self.fan_in_mask = torch.tensor(fan_in_mask).ge(0.5)
        if self.fan_in_mask.any():
            self.fan_in = torch.prod(self.input_shape_tensor[self.fan_in_mask])
        else:
            self.fan_in = 0
        self.fan_out_mask = torch.tensor(fan_in_mask).lt(0.5)
        if self.fan_out_mask.any():
            self.fan_out = torch.abs(torch.prod(self.input_shape_tensor[self.fan_out_mask]))
        else:
            self.fan_out = 0

        self.weight = nn.Parameter(torch.empty(self.weight_shape_list))
        self.unsqueeze_dims = unsqueeze_dims if unsqueeze_dims is not None else []
        self.input_parts = self.equation.split('->')[0].split(',')
        self.num_inputs = len(self.input_parts)
        self.flag = flag

    def _processed_weight(self):
        """Return the weight after applying ``self.flag``.

        Raises:
            ValueError: if ``self.flag`` is truthy but not a known flag.
        """
        if not self.flag:
            return self.weight
        if self.flag == "sub_mean":
            # Subtract the mean over dim 2 (the n_experts dimension).
            return self.weight - torch.mean(self.weight, dim=2, keepdim=True)
        if self.flag == "div_mean":
            # Divide by the mean over dim 2 (the n_experts dimension).
            return self.weight / torch.mean(self.weight, dim=2, keepdim=True)
        raise ValueError(
            f"Unknown EinsumLayer flag {self.flag!r}; expected None, 'sub_mean' or 'div_mean'.")

    def forward(self, input=None):
        if self.num_inputs == 1:
            # Weight-only term: no input is used and the flag is not applied.
            result = torch.einsum(self.equation, self.weight)

        elif self.num_inputs == 2:
            weight = self._processed_weight()
            if 'b' in self.input_parts[0]:  # The first part is the input tensor
                result = torch.einsum(self.equation, input, weight)
            elif 'b' in self.input_parts[1]:  # The second part is the input tensor
                result = torch.einsum(self.equation, weight, input)
            else:
                raise ValueError("No batch dimension 'b' found in the einsum equation input parts.")

        else:
            raise ValueError(f"Unexpected number of input parts ({self.num_inputs}) in einsum equation: {self.equation}")

        for dim in self.unsqueeze_dims:
            result = result.unsqueeze(dim)

        return result

class SharedEinsumLayer(nn.Module):
    """Combine two einsum contractions that share a single learnable weight.

    Both equations take the same two operands (the input, identified by the
    subscript ``b``, and ``self.weight``). ``equation_2`` produces the same
    output as ``equation_1`` plus one additional axis, the shared dim. For
    example, ``equation_1 = "bdnp, edjp -> bej"`` pairs with
    ``equation_2 = "bdnp, edjp -> benj"``.

    The forward pass returns ``out_1 - n * out_2``, where
    ``n = input.shape[shared_dim]``. Every entry of ``unsqueeze_dims`` is applied
    to ``out_1``, and all but the first to ``out_2`` (which already carries the
    extra axis), so the two broadcast together.

    Args:
        equation_1, equation_2: the two einsum equations. If either is empty,
            all setup is skipped.
        weight_shape: shape of the shared, uninitialised ``weight``.
        input_shape: shape used for fan-in/fan-out; defaults to ``weight_shape``.
        fan_in_mask: per-axis flags; entries >= 0.5 count toward ``fan_in``.
        unsqueeze_dims: dims inserted into the outputs (see above).
        shared_dim: input axis whose size scales the ``equation_2`` term.
    """

    def __init__(self, equation_1="", equation_2="", weight_shape=None, input_shape=None, fan_in_mask=None, unsqueeze_dims=None, shared_dim = 2) -> None:
        super().__init__()
        self.equation_1 = equation_1
        self.equation_2 = equation_2
        if len(self.equation_1) == 0 or len(self.equation_2) == 0:
            return

        self.weight_shape_list = weight_shape
        self.weight_shape_tensor = torch.tensor(weight_shape, dtype=torch.int)
        if input_shape is not None:
            self.input_shape_tensor = torch.tensor(input_shape, dtype=torch.int)
        else:
            self.input_shape_tensor = self.weight_shape_tensor

        # Fan-in / fan-out over the axes selected by fan_in_mask (consumed by the initialiser).
        self.fan_in_mask = torch.tensor(fan_in_mask).ge(0.5)
        if self.fan_in_mask.any():
            self.fan_in = torch.prod(self.input_shape_tensor[self.fan_in_mask])
        else:
            self.fan_in = 0
        self.fan_out_mask = torch.tensor(fan_in_mask).lt(0.5)
        if self.fan_out_mask.any():
            self.fan_out = torch.abs(torch.prod(self.input_shape_tensor[self.fan_out_mask]))
        else:
            self.fan_out = 0

        self.weight = nn.Parameter(torch.empty(self.weight_shape_list))
        self.unsqueeze_dims = unsqueeze_dims if unsqueeze_dims is not None else []
        self.input_parts = self.equation_1.split('->')[0].split(',')
        self.num_inputs = len(self.input_parts)
        self.shared_dim = shared_dim

    def forward(self, input=None):
        if self.num_inputs != 2:
            raise ValueError(
                f"Unexpected number of input parts ({self.num_inputs}) in einsum equation: {self.equation_1}")

        if 'b' in self.input_parts[0]:  # The first part is the input tensor
            out_1 = torch.einsum(self.equation_1, input, self.weight)
            out_2 = torch.einsum(self.equation_2, input, self.weight)
        elif 'b' in self.input_parts[1]:  # The second part is the input tensor
            out_1 = torch.einsum(self.equation_1, self.weight, input)
            out_2 = torch.einsum(self.equation_2, self.weight, input)
        else:
            raise ValueError("No batch dimension 'b' found in the einsum equation input parts.")
        out_2 = out_2 * -input.shape[self.shared_dim]

        # out_2 already has the shared axis, so it skips the first unsqueeze.
        for i, dim in enumerate(self.unsqueeze_dims):
            out_1 = out_1.unsqueeze(dim)
            if i > 0:
                out_2 = out_2.unsqueeze(dim)
        return out_1 + out_2
    

class MoELinearEquiv(nn.Module):
    def __init__(self, encoder_weight_spec: MoENetworkSpec, in_channels, out_channels, init_type="pytorch_default", scale_degree = 1):
        super().__init__()
        self.d, self.e = in_channels, out_channels
        self.encoder_weight_spec = encoder_weight_spec

        D, D_q, D_k, D_v, n_e, D_e, h =  encoder_weight_spec.get_all_dims() #not yet implemented
        
        self.L = len(encoder_weight_spec)
        for i in range(self.L):
            # -----------------------------------
            #            W_Q Terms
            # -----------------------------------
            self.add_module(f"layer_{i}_W_Q",EinsumLayer(equation="bdhpk, edjp -> behjk",
                                        weight_shape=[self.e, self.d, D, D],
                                        input_shape=[-1, self.d, h, D, D_q],
                                        fan_in_mask=[0, 1, 0, 0, 1]
                                        ))
            set_init_einsum_(getattr(self, f"layer_{i}_W_Q"), init_type=init_type)
            # -----------------------------------
            #            W_K Terms
            # -----------------------------------
            self.add_module(f"layer_{i}_W_K",EinsumLayer(equation="bdhpk, edjp -> behjk",
                                        weight_shape=[self.e, self.d, D, D],
                                        input_shape=[-1, self.d, h, D, D_k],
                                        fan_in_mask=[0, 1, 0, 0, 1]
                                        ))
            set_init_einsum_(getattr(self, f"layer_{i}_W_K"), init_type=init_type)
            # -----------------------------------
            #            W_V Terms
            # -----------------------------------
            self.add_module(f"layer_{i}_W_V",EinsumLayer(equation="bdhpk, edjp -> behjk",
                                        weight_shape=[self.e, self.d, D, D],
                                        input_shape=[-1, self.d, h, D, D_v],
                                        fan_in_mask=[0, 1, 0, 0, 1],
                                        ))
            set_init_einsum_(getattr(self, f"layer_{i}_W_V"), init_type=init_type)
            
            # -----------------------------------
            #            W_O Terms
            # -----------------------------------

            self.add_module(f"layer_{i}_W_O",EinsumLayer(equation="bdhjk, edkq -> behjk",
                                        weight_shape=[self.e, self.d, D, D],
                                        input_shape=[-1, self.d, h, D_v, D],
                                        fan_in_mask=[0, 1, 0, 0, 0],
                                        ))
            set_init_einsum_(getattr(self, f"layer_{i}_W_O"), init_type=init_type)

            # -----------------------------------
            #            W_G Terms
            # -----------------------------------
            # Total of 14 terms
            
            # 1st term
            self.add_module(f"layer_{i}_W_G_W_QK", EinsumLayer(equation="bdhpq, edjpq -> bej",
                                            weight_shape=[self.e, self.d, D, D, D],
                                            input_shape=[-1, self.d, h, D, D],
                                            fan_in_mask=[0, 1, 1, 1, 1],
                                            unsqueeze_dims=[-2]))
            
            # 2nd term
            self.add_module(f"layer_{i}_W_G_W_VO", EinsumLayer(equation="bdhpq, edjpq -> bej",
                                            weight_shape=[self.e, self.d, D, D, D],
                                            input_shape=[-1, self.d, h, D, D],
                                            fan_in_mask=[0, 1, 1, 1, 1],
                                            unsqueeze_dims=[-2]))
            self.add_module(f"layer_{i}_W_G_W_G", SharedEinsumLayer(
                        equation_1="bdnp, edjp -> bej",
                        equation_2="bdnp, edjp -> benj",
                        weight_shape=[self.e, self.d, D, D],
                        input_shape=[-1, self.d, n_e, D],
                        fan_in_mask=[0, 1, 0, 1],
                        unsqueeze_dims=[-2],))


            # 5th term # 1st WA term
            self.add_module(f"layer_{i}_W_G_W_A_1", EinsumLayer(equation="bdnpq, edjp -> bej",
                                            weight_shape=[self.e, self.d, D, D],
                                            input_shape=[-1, self.d, n_e, D, D_e],
                                            fan_in_mask=[0, 1, 1, 1, 1],
                                            unsqueeze_dims=[-2]))

            # 6th term # 2nd WA term
            self.add_module(f"layer_{i}_W_G_W_A_2", EinsumLayer(equation="bdnpq, edjp -> benj",
                                            weight_shape=[self.e, self.d, D, D],
                                            input_shape=[-1, self.d, n_e, D, D_e],
                                            fan_in_mask=[0, 1, 0, 1, 1],))

            # 7th term # 1st WB term
            self.add_module(f"layer_{i}_W_G_W_B_1", EinsumLayer(equation="bdnpq, edjq -> bej",
                                            weight_shape=[self.e, self.d, D, D],
                                            input_shape=[-1, self.d, n_e, D_e, D],
                                            fan_in_mask=[0, 1, 1, 1, 1],
                                            unsqueeze_dims=[-2]))

            # 8th term # 2nd WB term
            self.add_module(f"layer_{i}_W_G_W_B_2", EinsumLayer(equation="bdnpq, edjq -> benj",
                                            weight_shape=[self.e, self.d, D, D],
                                            input_shape=[-1, self.d, n_e, D_e, D],
                                            fan_in_mask=[0, 1, 0, 1, 1]))

            self.add_module(f"layer_{i}_W_G_b_G", SharedEinsumLayer(
                                    equation_1="bdn, edj -> bej",
                                    equation_2="bdn, edj -> benj",
                                    weight_shape=[self.e, self.d, D],
                                    input_shape=[-1, self.d, n_e],
                                    fan_in_mask=[0, 1, 0],
                                    unsqueeze_dims=[-2],))

            # 10th term 
            self.add_module(f"layer_{i}_W_G_b_A_1", EinsumLayer(equation="bdnq, edj -> bej",
                                            weight_shape=[self.e, self.d, D],
                                            input_shape=[-1, self.d, n_e, D_e],
                                            fan_in_mask=[0, 1, 1, 1],
                                            unsqueeze_dims=[-2]))

            # 11th term 
            self.add_module(f"layer_{i}_W_G_b_A_2", EinsumLayer(equation="bdnq, edj -> benj",
                                            weight_shape=[self.e, self.d, D],
                                            input_shape=[-1, self.d, n_e, D_e],
                                            fan_in_mask=[0, 1, 0, 1]))

            # 12th term 
            self.add_module(f"layer_{i}_W_G_b_B_1", EinsumLayer(equation="bdnq, edjq -> bej",
                                            weight_shape=[self.e, self.d, D, D],
                                            input_shape=[-1, self.d, n_e, D],
                                            fan_in_mask=[0, 1, 1, 1],
                                            unsqueeze_dims=[-2]))
            
            # 13th term
            self.add_module(f"layer_{i}_W_G_b_B_2", EinsumLayer(equation="bdnq, edjq -> benj",
                                            weight_shape=[self.e, self.d, D, D],
                                            input_shape=[-1, self.d, n_e, D],
                                            fan_in_mask=[0, 1, 0, 1]))

            # 14th Term
            self.add_module(f"layer_{i}_W_G_bias", EinsumLayer(equation="ej -> ej",
                                        weight_shape=[self.e, D],
                                        input_shape=[self.e, D],
                                        fan_in_mask=[0, 0],
                                        unsqueeze_dims=[0, -2]))
            
            set_init_einsum_(getattr(self, f"layer_{i}_W_G_W_QK"), 
                             getattr(self, f"layer_{i}_W_G_W_VO"),
                                getattr(self, f"layer_{i}_W_G_W_G"),
                                getattr(self, f"layer_{i}_W_G_W_A_1"),
                                getattr(self, f"layer_{i}_W_G_W_A_2"),
                                getattr(self, f"layer_{i}_W_G_W_B_1"),
                                getattr(self, f"layer_{i}_W_G_W_B_2"),
                                getattr(self, f"layer_{i}_W_G_b_G"),
                                getattr(self, f"layer_{i}_W_G_b_A_1"),
                                getattr(self, f"layer_{i}_W_G_b_A_2"),
                                getattr(self, f"layer_{i}_W_G_b_B_1"),
                                getattr(self, f"layer_{i}_W_G_b_B_2"),
                                getattr(self, f"layer_{i}_W_G_bias"),
                                init_type=init_type, scale_degree=scale_degree)
                        
            # -----------------------------------
            #            W_A Terms
            # -----------------------------------
            # Total of 20 terms

            # 1st term
            self.add_module(f"layer_{i}_W_A_W_QK", EinsumLayer(
                                    equation="bdhpq, edjpq -> bej",
                                    weight_shape=[self.e, self.d, D, D, D],
                                    input_shape=[-1, self.d, h, D, D],
                                    fan_in_mask=[0, 1, 1, 1, 1],
                                    unsqueeze_dims=[-2, -1]))

            # 2nd term
            self.add_module(f"layer_{i}_W_A_W_VO", EinsumLayer(
                                    equation="bdhpq, edjpq -> bej",
                                    weight_shape=[self.e, self.d, D, D, D],
                                    input_shape=[-1, self.d, h, D, D],
                                    fan_in_mask=[0, 1, 1, 1, 1],
                                    unsqueeze_dims=[-2, -1]))

            self.add_module(f"layer_{i}_W_A_W_G", SharedEinsumLayer(
                                    equation_1="bdnp, edjp -> bej",
                                    equation_2="bdnp, edjp -> benj",
                                    weight_shape=[self.e, self.d, D, D],
                                    input_shape=[-1, self.d, n_e, D],
                                    fan_in_mask=[0, 1, 0, 1],
                                    unsqueeze_dims=[-2, -1],))
            # 5th term
            self.add_module(f"layer_{i}_W_A_W_A_1", EinsumLayer(
                                    equation="bdnpq, edjp -> bej",
                                    weight_shape=[self.e, self.d, D, D],
                                    input_shape=[-1, self.d, n_e, D, D_e],
                                    fan_in_mask=[0, 1, 1, 1, 1],
                                    unsqueeze_dims=[-2, -1]))

            # 6th term
            self.add_module(f"layer_{i}_W_A_W_A_2", EinsumLayer(
                                    equation="bdnpq, edjp -> benj",
                                    weight_shape=[self.e, self.d, D, D],
                                    input_shape=[-1, self.d, n_e, D, D_e],
                                    fan_in_mask=[0, 1, 0, 1, 1],
                                    unsqueeze_dims=[-1]))

            # 8th term
            self.add_module(f"layer_{i}_W_A_W_A_3", EinsumLayer(
                                    equation="bdnpk, edjp -> benjk",
                                    weight_shape=[self.e, self.d, D, D],
                                    input_shape=[-1, self.d, n_e, D, D_e],
                                    fan_in_mask=[0, 1, 0, 1, 0]))

            # 9th term
            self.add_module(f"layer_{i}_W_A_W_B_1", EinsumLayer(
                                    equation="bdnpq, edjq -> bej",
                                    weight_shape=[self.e, self.d, D, D],
                                    input_shape=[-1, self.d, n_e, D_e, D],
                                    fan_in_mask=[0, 1, 1, 1, 1],
                                    unsqueeze_dims=[-2, -1]))

            # 10th term
            self.add_module(f"layer_{i}_W_A_W_B_2", EinsumLayer(
                                    equation="bdnpq, edjq -> benj",
                                    weight_shape=[self.e, self.d, D, D],
                                    input_shape=[-1, self.d, n_e, D_e, D],
                                    fan_in_mask=[0, 1, 0, 1, 1],
                                    unsqueeze_dims=[-1]))

            # 12th term
            self.add_module(f"layer_{i}_W_A_W_B_3", EinsumLayer(
                                    equation="bdnkq, edjq -> benjk",
                                    weight_shape=[self.e, self.d, D, D],
                                    input_shape=[-1, self.d, n_e, D_e, D],
                                    fan_in_mask=[0, 1, 0, 0, 1]))

            self.add_module(f"layer_{i}_W_A_b_G", SharedEinsumLayer(
                                    equation_1="bdn, edj -> bej",
                                    equation_2="bdn, edj -> benj",
                                    weight_shape=[self.e, self.d, D],
                                    input_shape=[-1, self.d, n_e],
                                    fan_in_mask=[0, 1, 0],
                                    unsqueeze_dims=[-2, -1],))

            # 14th term
            self.add_module(f"layer_{i}_W_A_b_A_1", EinsumLayer(
                                    equation="bdnq, edj -> bej",
                                    weight_shape=[self.e, self.d, D],
                                    input_shape=[-1, self.d, n_e, D_e],
                                    fan_in_mask=[0, 1, 1, 1],
                                    unsqueeze_dims=[-2, -1]))

            # 15th term
            self.add_module(f"layer_{i}_W_A_b_A_2", EinsumLayer(
                                    equation="bdnq, edj -> benj",
                                    weight_shape=[self.e, self.d, D],
                                    input_shape=[-1, self.d, n_e, D_e],
                                    fan_in_mask=[0, 1, 0, 1],
                                    unsqueeze_dims=[-1]))


            # 17th term
            self.add_module(f"layer_{i}_W_A_b_A_3", EinsumLayer(
                                    equation="bdnk, edj -> benjk",
                                    weight_shape=[self.e, self.d, D],
                                    input_shape=[-1, self.d, n_e, D_e],
                                    fan_in_mask=[0, 1, 0, 0]))

            # 18th term
            self.add_module(f"layer_{i}_W_A_b_B_1", EinsumLayer(
                                    equation="bdnq, edjq -> bej",
                                    weight_shape=[self.e, self.d, D, D],
                                    input_shape=[-1, self.d, n_e, D],
                                    fan_in_mask=[0, 1, 1, 1],
                                    unsqueeze_dims=[-2, -1]))

            # 19th term
            self.add_module(f"layer_{i}_W_A_b_B_2", EinsumLayer(
                                    equation="bdnq, edjq -> benj",
                                    weight_shape=[self.e, self.d, D, D],
                                    input_shape=[-1, self.d, n_e, D],
                                    fan_in_mask=[0, 1, 0, 1],
                                    unsqueeze_dims=[-1]))

            # 20th term
            self.add_module(f"layer_{i}_W_A_bias", EinsumLayer(
                                    equation="ej -> ej",
                                    weight_shape=[self.e, D],
                                    input_shape=[self.e, D],
                                    fan_in_mask=[0, 0],
                                    unsqueeze_dims=[0, -2, -1])) 

            # set init einsum for all WA terms
            set_init_einsum_(getattr(self, f"layer_{i}_W_A_b_B_1"),
                            getattr(self, f"layer_{i}_W_A_b_B_2"),
                            getattr(self, f"layer_{i}_W_A_bias"),
                            init_type=init_type)
            set_init_einsum_(getattr(self, f"layer_{i}_W_A_W_QK"),
                                getattr(self, f"layer_{i}_W_A_W_VO"),
                                getattr(self, f"layer_{i}_W_A_W_G"),
                                getattr(self, f"layer_{i}_W_A_W_A_1"),
                                getattr(self, f"layer_{i}_W_A_W_A_2"),
                                getattr(self, f"layer_{i}_W_A_W_A_3"),
                                getattr(self, f"layer_{i}_W_A_W_B_1"),
                                getattr(self, f"layer_{i}_W_A_W_B_2"),
                                getattr(self, f"layer_{i}_W_A_W_B_3"),
                                getattr(self, f"layer_{i}_W_A_b_G"),
                                getattr(self, f"layer_{i}_W_A_b_A_1"),
                                getattr(self, f"layer_{i}_W_A_b_A_2"),
                                getattr(self, f"layer_{i}_W_A_b_A_3"),
                                init_type=init_type, scale_degree=scale_degree)
            # -----------------------------------
            #            W_B Terms
            # -----------------------------------
            # Total of 20 terms

            # 1st term
            self.add_module(f"layer_{i}_W_B_W_QK", EinsumLayer(
                                            equation="bdhpq, edkpq -> bek",
                                            weight_shape=[self.e, self.d, D, D, D],
                                            input_shape=[-1, self.d, h, D, D],
                                            fan_in_mask=[0, 1, 1, 1, 1],
                                            unsqueeze_dims=[-2, -2]))

            # 2nd term
            self.add_module(f"layer_{i}_W_B_W_VO", EinsumLayer(
                                            equation="bdhpq, edkpq -> bek",
                                            weight_shape=[self.e, self.d, D, D, D],
                                            input_shape=[-1, self.d, h, D, D],
                                            fan_in_mask=[0, 1, 1, 1, 1],
                                            unsqueeze_dims=[-2, -2]))
            self.add_module(f"layer_{i}_W_B_W_G", SharedEinsumLayer(
                                    equation_1="bdnp, edkp -> bek",
                                    equation_2="bdnp, edkp -> benk",
                                    weight_shape=[self.e, self.d, D, D],
                                    input_shape=[-1, self.d, n_e, D],
                                    fan_in_mask=[0, 1, 0, 1],
                                    unsqueeze_dims=[-2, -2],))

            # 5th term
            self.add_module(f"layer_{i}_W_B_W_A_1", EinsumLayer(
                                            equation="bdnpq, edkp -> bek",
                                            weight_shape=[self.e, self.d, D, D],
                                            input_shape=[-1, self.d, n_e, D, D_e],
                                            fan_in_mask=[0, 1, 1, 1, 1],
                                            unsqueeze_dims=[-2, -2]))

            # 6th term
            self.add_module(f"layer_{i}_W_B_W_A_2", EinsumLayer(
                                            equation="bdnpq, edkp -> benk",
                                            weight_shape=[self.e, self.d, D, D],
                                            input_shape=[-1, self.d, n_e, D, D_e],
                                            fan_in_mask=[0, 1, 0, 1, 1],
                                            unsqueeze_dims=[-2]))

            # 8th term
            self.add_module(f"layer_{i}_W_B_W_A_3", EinsumLayer(
                                            equation="bdnpj, edkp -> benjk",
                                            weight_shape=[self.e, self.d, D, D],
                                            input_shape=[-1, self.d, n_e, D, D_e],
                                            fan_in_mask=[0, 1, 0, 1, 0]))

            # 9th term
            self.add_module(f"layer_{i}_W_B_W_B_1", EinsumLayer(
                                            equation="bdnpq, edkq -> bek",
                                            weight_shape=[self.e, self.d, D, D],
                                            input_shape=[-1, self.d, n_e, D_e, D],
                                            fan_in_mask=[0, 1, 1, 1, 1],
                                            unsqueeze_dims=[-2, -2]))

            # 10th term
            self.add_module(f"layer_{i}_W_B_W_B_2", EinsumLayer(
                                            equation="bdnpq, edkq -> benk",
                                            weight_shape=[self.e, self.d, D, D],
                                            input_shape=[-1, self.d, n_e, D_e, D],
                                            fan_in_mask=[0, 1, 0, 1, 1],
                                            unsqueeze_dims=[-2]))

            # 12th term
            self.add_module(f"layer_{i}_W_B_W_B_3", EinsumLayer(
                                            equation="bdnjq, edkq -> benjk",
                                            weight_shape=[self.e, self.d, D, D],
                                            input_shape=[-1, self.d, n_e, D_e, D],
                                            fan_in_mask=[0, 1, 0, 0, 1]))

            # # 13th term
            self.add_module(f"layer_{i}_W_B_b_G", SharedEinsumLayer(
                                    equation_1="bdn, edk -> bek",
                                    equation_2="bdn, edk -> benk",
                                    weight_shape=[self.e, self.d, D],
                                    input_shape=[-1, self.d, n_e],
                                    fan_in_mask=[0, 1, 0],
                                    unsqueeze_dims=[-2, -2],))


            # 14th term
            self.add_module(f"layer_{i}_W_B_b_A_1", EinsumLayer(
                                            equation="bdnq, edk -> bek",
                                            weight_shape=[self.e, self.d, D],
                                            input_shape=[-1, self.d, n_e, D_e],
                                            fan_in_mask=[0, 1, 1, 1],
                                            unsqueeze_dims=[-2, -2]))

            # 15th term
            self.add_module(f"layer_{i}_W_B_b_A_2", EinsumLayer(
                                            equation="bdnq, edk -> benk",
                                            weight_shape=[self.e, self.d, D],
                                            input_shape=[-1, self.d, n_e, D_e],
                                            fan_in_mask=[0, 1, 0, 1],
                                            unsqueeze_dims=[-2]))

            # 17th term
            self.add_module(f"layer_{i}_W_B_b_A_3", EinsumLayer(
                                            equation="bdnj, edk -> benjk",
                                            weight_shape=[self.e, self.d, D],
                                            input_shape=[-1, self.d, n_e, D_e],
                                            fan_in_mask=[0, 1, 0, 0]))

            # 18th term
            self.add_module(f"layer_{i}_W_B_b_B_1", EinsumLayer(
                                            equation="bdnq, edkq -> bek",
                                            weight_shape=[self.e, self.d, D, D],
                                            input_shape=[-1, self.d, n_e, D],
                                            fan_in_mask=[0, 1, 1, 1],
                                            unsqueeze_dims=[-2, -2]))

            # 19th term
            self.add_module(f"layer_{i}_W_B_b_B_2", EinsumLayer(
                                            equation="bdnq, edkq -> benk",
                                            weight_shape=[self.e, self.d, D, D],
                                            input_shape=[-1, self.d, n_e, D],
                                            fan_in_mask=[0, 1, 0, 1],
                                            unsqueeze_dims=[-2]))

            # 20th term
            self.add_module(f"layer_{i}_W_B_bias", EinsumLayer(
                                            equation="ek -> ek",
                                            weight_shape=[self.e, D],
                                            input_shape=[self.e, D],
                                            fan_in_mask=[0, 0],
                                            unsqueeze_dims=[0, -2, -2]))
            # set init einsum for all WB terms
            set_init_einsum_(getattr(self, f"layer_{i}_W_B_b_B_1"),
                                getattr(self, f"layer_{i}_W_B_b_B_2"),
                                getattr(self, f"layer_{i}_W_B_bias"),init_type=init_type)
            set_init_einsum_(getattr(self, f"layer_{i}_W_B_W_QK"),
                                getattr(self, f"layer_{i}_W_B_W_VO"),
                                getattr(self, f"layer_{i}_W_B_W_G"),
                                getattr(self, f"layer_{i}_W_B_W_A_1"),
                                getattr(self, f"layer_{i}_W_B_W_A_2"),
                                getattr(self, f"layer_{i}_W_B_W_A_3"),
                                getattr(self, f"layer_{i}_W_B_W_B_1"),
                                getattr(self, f"layer_{i}_W_B_W_B_2"),
                                getattr(self, f"layer_{i}_W_B_W_B_3"),
                                getattr(self, f"layer_{i}_W_B_b_G"),
                                getattr(self, f"layer_{i}_W_B_b_A_1"),
                                getattr(self, f"layer_{i}_W_B_b_A_2"),
                                getattr(self, f"layer_{i}_W_B_b_A_3"),
                                init_type=init_type, scale_degree=scale_degree)
            
            # -----------------------------------
            #            b_G Terms
            # -----------------------------------
            # Total of 14 terms

            # 1st term
            self.add_module(f"layer_{i}_b_G_W_QK", EinsumLayer(
                                                    equation="bdhpq, edpq -> be",
                                                    weight_shape=[self.e, self.d, D, D],
                                                    input_shape=[-1, self.d, h, D, D],
                                                    fan_in_mask=[0, 1, 1, 1, 1],
                                                    unsqueeze_dims=[-1]))

            # 2nd term
            self.add_module(f"layer_{i}_b_G_W_VO", EinsumLayer(
                                                    equation="bdhpq, edpq -> be",
                                                    weight_shape=[self.e, self.d, D, D],
                                                    input_shape=[-1, self.d, h, D, D],
                                                    fan_in_mask=[0, 1, 1, 1, 1],
                                                    unsqueeze_dims=[-1]))

            # # 3rd term
            self.add_module(f"layer_{i}_b_G_W_G", SharedEinsumLayer(
                                                    equation_1="bdnp, edp -> be",
                                                    equation_2="bdnp, edp -> ben",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D],
                                                    fan_in_mask=[0, 1, 0, 1],
                                                    unsqueeze_dims=[-1]))

            # 5th term # 1st WA term
            self.add_module(f"layer_{i}_b_G_W_A_1", EinsumLayer(
                                                    equation="bdnpq, edp -> be",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D, D_e],
                                                    fan_in_mask=[0, 1, 1, 1, 1],
                                                    unsqueeze_dims=[-1]))

            # 6th term # 2nd WA term
            self.add_module(f"layer_{i}_b_G_W_A_2", EinsumLayer(
                                                    equation="bdnpq, edp -> ben",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D, D_e],
                                                    fan_in_mask=[0, 1, 0, 1, 1]))

            # 7th term # 1st WB term
            self.add_module(f"layer_{i}_b_G_W_B_1", EinsumLayer(
                                                    equation="bdnpq, edq -> be",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D_e, D],
                                                    fan_in_mask=[0, 1, 1, 1, 1],
                                                    unsqueeze_dims=[-1]))

            # 8th term # 2nd WB term
            self.add_module(f"layer_{i}_b_G_W_B_2", EinsumLayer(
                                                    equation="bdnpq, edq -> ben",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D_e, D],
                                                    fan_in_mask=[0, 1, 0, 1, 1]))

            # 9th term
            self.add_module(f"layer_{i}_b_G_b_G", SharedEinsumLayer(
                                                    equation_1="bdn, ed -> be",
                                                    equation_2="bdn, ed -> ben",
                                                    weight_shape=[self.e, self.d],
                                                    input_shape=[-1, self.d, n_e],
                                                    fan_in_mask=[0, 1, 0],
                                                    unsqueeze_dims=[-1]))

            # 10th term
            self.add_module(f"layer_{i}_b_G_b_A_1", EinsumLayer(
                                                    equation="bdnq, ed -> be",
                                                    weight_shape=[self.e, self.d],
                                                    input_shape=[-1, self.d, n_e, D_e],
                                                    fan_in_mask=[0, 1, 1, 1],
                                                    unsqueeze_dims=[-1]))

            # 11th term
            self.add_module(f"layer_{i}_b_G_b_A_2", EinsumLayer(
                                                    equation="bdnq, ed -> ben",
                                                    weight_shape=[self.e, self.d],
                                                    input_shape=[-1, self.d, n_e, D_e],
                                                    fan_in_mask=[0, 1, 0, 1]))

            # 12th term
            self.add_module(f"layer_{i}_b_G_b_B_1", EinsumLayer(
                                                    equation="bdnq, edq -> be",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D],
                                                    fan_in_mask=[0, 1, 1, 1],
                                                    unsqueeze_dims=[-1]))

            # 13th term
            self.add_module(f"layer_{i}_b_G_b_B_2", EinsumLayer(
                                                    equation="bdnq, edq -> ben",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D],
                                                    fan_in_mask=[0, 1, 1, 1]))

            # 14th Term
            self.add_module(f"layer_{i}_b_G_bias", EinsumLayer(
                                                    equation="e -> e",
                                                    weight_shape=[self.e],
                                                    input_shape=[self.e],
                                                    fan_in_mask=[0],
                                                    unsqueeze_dims=[0, -1]))
            
            # set init einsum for all bG terms
            set_init_einsum_(getattr(self, f"layer_{i}_b_G_W_QK"),
                                getattr(self, f"layer_{i}_b_G_W_VO"),
                                getattr(self, f"layer_{i}_b_G_W_G"),
                                getattr(self, f"layer_{i}_b_G_W_A_1"),
                                getattr(self, f"layer_{i}_b_G_W_A_2"),
                                getattr(self, f"layer_{i}_b_G_W_B_1"),
                                getattr(self, f"layer_{i}_b_G_W_B_2"),
                                getattr(self, f"layer_{i}_b_G_b_G"),
                                getattr(self, f"layer_{i}_b_G_b_A_1"),
                                getattr(self, f"layer_{i}_b_G_b_A_2"),
                                getattr(self, f"layer_{i}_b_G_b_B_1"),
                                getattr(self, f"layer_{i}_b_G_b_B_2"),
                                getattr(self, f"layer_{i}_b_G_bias"),
                                init_type=init_type, scale_degree=scale_degree)
            
            # -----------------------------------
            #            b_A Terms
            # -----------------------------------
            # Total of 20 terms

            # 1st term
            self.add_module(f"layer_{i}_b_A_W_QK", EinsumLayer(
                                                    equation="bdhpq, edpq -> be",
                                                    weight_shape=[self.e, self.d, D, D],
                                                    input_shape=[-1, self.d, h, D, D],
                                                    fan_in_mask=[0, 1, 1, 1, 1],
                                                    unsqueeze_dims=[-1, -1]))

            # 2nd term
            self.add_module(f"layer_{i}_b_A_W_VO", EinsumLayer(
                                                    equation="bdhpq, edpq -> be",
                                                    weight_shape=[self.e, self.d, D, D],
                                                    input_shape=[-1, self.d, h, D, D],
                                                    fan_in_mask=[0, 1, 1, 1, 1],
                                                    unsqueeze_dims=[-1, -1]))

            # # 3rd term
            self.add_module(f"layer_{i}_b_A_W_G", SharedEinsumLayer(
                                                    equation_1="bdnp, edp -> be",
                                                    equation_2="bdnp, edp -> ben",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D],
                                                    fan_in_mask=[0, 1, 0, 1],
                                                    unsqueeze_dims=[-1, -1]))

            # 5th term
            self.add_module(f"layer_{i}_b_A_W_A_1", EinsumLayer(
                                                    equation="bdnpq, edp -> be",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D, D_e],
                                                    fan_in_mask=[0, 1, 1, 1, 1],
                                                    unsqueeze_dims=[-1, -1]))

            # 6th term
            self.add_module(f"layer_{i}_b_A_W_A_2", EinsumLayer(
                                                    equation="bdnpq, edp -> ben",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D, D_e],
                                                    fan_in_mask=[0, 1, 0, 1, 1],
                                                    unsqueeze_dims=[-1]))

            # 8th term
            self.add_module(f"layer_{i}_b_A_W_A_3", EinsumLayer(
                                                    equation="bdnpj, edp -> benj",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D, D_e],
                                                    fan_in_mask=[0, 1, 0, 1, 0]))

            # 9th term
            self.add_module(f"layer_{i}_b_A_W_B_1", EinsumLayer(
                                                    equation="bdnpq, edq -> be",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D_e, D],
                                                    fan_in_mask=[0, 1, 1, 1, 1],
                                                    unsqueeze_dims=[-1, -1]))

            # 10th term
            self.add_module(f"layer_{i}_b_A_W_B_2", EinsumLayer(
                                                    equation="bdnpq, edq -> ben",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D_e, D],
                                                    fan_in_mask=[0, 1, 0, 1, 1],
                                                    unsqueeze_dims=[-1]))

            # 12th term
            self.add_module(f"layer_{i}_b_A_W_B_3", EinsumLayer(
                                                    equation="bdnjq, edq -> benj",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D_e, D],
                                                    fan_in_mask=[0, 1, 0, 0, 1]))

            # # 13th term

            self.add_module(f"layer_{i}_b_A_b_G", SharedEinsumLayer(
                                                    equation_1="bdn, ed -> be",
                                                    equation_2="bdn, ed -> ben",
                                                    weight_shape=[self.e, self.d],
                                                    input_shape=[-1, self.d, n_e],
                                                    fan_in_mask=[0, 1, 0],
                                                    unsqueeze_dims=[-1, -1]))

            # 14th term
            self.add_module(f"layer_{i}_b_A_b_A_1", EinsumLayer(
                                                    equation="bdnq, ed -> be",
                                                    weight_shape=[self.e, self.d],
                                                    input_shape=[-1, self.d, n_e, D_e],
                                                    fan_in_mask=[0, 1, 1, 1],
                                                    unsqueeze_dims=[-1, -1]))

            # 15th term
            self.add_module(f"layer_{i}_b_A_b_A_2", EinsumLayer(
                                                    equation="bdnq, ed -> ben",
                                                    weight_shape=[self.e, self.d],
                                                    input_shape=[-1, self.d, n_e, D_e],
                                                    fan_in_mask=[0, 1, 0, 1],
                                                    unsqueeze_dims=[-1]))

            # 17th term
            self.add_module(f"layer_{i}_b_A_b_A_3", EinsumLayer(
                                                    equation="bdnj, ed -> benj",
                                                    weight_shape=[self.e, self.d],
                                                    input_shape=[-1, self.d, n_e, D_e],
                                                    fan_in_mask=[0, 1, 0, 0]))

            # 18th term
            self.add_module(f"layer_{i}_b_A_b_B_1", EinsumLayer(
                                                    equation="bdnq, edq -> be",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D],
                                                    fan_in_mask=[0, 1, 1, 1],
                                                    unsqueeze_dims=[-1, -1]))

            # 19th term
            self.add_module(f"layer_{i}_b_A_b_B_2", EinsumLayer(
                                                    equation="bdnq, edq -> ben",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e, D],
                                                    fan_in_mask=[0, 1, 0, 1],
                                                    unsqueeze_dims=[-1]))

            # 20th term
            self.add_module(f"layer_{i}_b_A_bias", EinsumLayer(
                                                    equation="e -> e",
                                                    weight_shape=[self.e],
                                                    input_shape=[self.e],
                                                    fan_in_mask=[0],
                                                    unsqueeze_dims=[0, -1, -1]))
            
            # set init einsum for all bA terms
            set_init_einsum_(getattr(self, f"layer_{i}_b_A_b_A_1"),
                            getattr(self, f"layer_{i}_b_A_b_A_2"),
                            getattr(self, f"layer_{i}_b_A_b_A_3"),
                            getattr(self, f"layer_{i}_b_A_b_B_1"),
                            getattr(self, f"layer_{i}_b_A_b_B_2"),
                            getattr(self, f"layer_{i}_b_A_bias"), init_type=init_type)
            set_init_einsum_(getattr(self, f"layer_{i}_b_A_W_QK"),
                                getattr(self, f"layer_{i}_b_A_W_VO"),
                                getattr(self, f"layer_{i}_b_A_W_G"),
                                getattr(self, f"layer_{i}_b_A_W_A_1"),
                                getattr(self, f"layer_{i}_b_A_W_A_2"),
                                getattr(self, f"layer_{i}_b_A_W_A_3"),
                                getattr(self, f"layer_{i}_b_A_W_B_1"),
                                getattr(self, f"layer_{i}_b_A_W_B_2"),
                                getattr(self, f"layer_{i}_b_A_W_B_3"),
                                getattr(self, f"layer_{i}_b_A_b_G"),                                
                                getattr(self, f"layer_{i}_b_A_bias"),
                                init_type=init_type, scale_degree=scale_degree)

            # -----------------------------------
            #            b_B Terms
            # -----------------------------------
            # Total of 14 terms

            # 1st term
            self.add_module(f"layer_{i}_b_B_W_QK", EinsumLayer(
                                                            equation="bdhpq, edjpq -> bej",
                                                            weight_shape=[self.e, self.d, D, D, D],
                                                            input_shape=[-1, self.d, h, D, D],
                                                            fan_in_mask=[0, 1, 1, 1, 1],
                                                            unsqueeze_dims=[-2]))

            # 2nd term
            self.add_module(f"layer_{i}_b_B_W_VO", EinsumLayer(
                                                            equation="bdhpq, edjpq -> bej",
                                                            weight_shape=[self.e, self.d, D, D, D],
                                                            input_shape=[-1, self.d, h, D, D],
                                                            fan_in_mask=[0, 1, 1, 1, 1],
                                                            unsqueeze_dims=[-2]))

            # 3rd term
            self.add_module(f"layer_{i}_b_B_W_G", SharedEinsumLayer(
                                                    equation_1="bdnp, edjp -> bej",
                                                    equation_2="bdnp, edjp -> benj",
                                                    weight_shape=[self.e, self.d, D, D],
                                                    input_shape=[-1, self.d, n_e, D],
                                                    fan_in_mask=[0, 1, 0, 1],
                                                    unsqueeze_dims=[-2]))

            # 5th term
            self.add_module(f"layer_{i}_b_B_W_A_1", EinsumLayer(
                                                            equation="bdnpq, edjp -> bej",
                                                            weight_shape=[self.e, self.d, D, D],
                                                            input_shape=[-1, self.d, n_e, D, D_e],
                                                            fan_in_mask=[0, 1, 1, 1, 1],
                                                            unsqueeze_dims=[-2]))

            # 6th term
            self.add_module(f"layer_{i}_b_B_W_A_2", EinsumLayer(
                                                            equation="bdnpq, edjp -> benj",
                                                            weight_shape=[self.e, self.d, D, D],
                                                            input_shape=[-1, self.d, n_e, D, D_e],
                                                            fan_in_mask=[0, 1, 0, 1, 1]))

            # 7th term
            self.add_module(f"layer_{i}_b_B_W_B_1", EinsumLayer(
                                                            equation="bdnpq, edjq -> bej",
                                                            weight_shape=[self.e, self.d, D, D],
                                                            input_shape=[-1, self.d, n_e, D],
                                                            fan_in_mask=[0, 1, 0, 1],
                                                            unsqueeze_dims=[-2]))

            # 8th term
            self.add_module(f"layer_{i}_b_B_W_B_2", EinsumLayer(
                                                            equation="bdnpq, edjq -> benj",
                                                            weight_shape=[self.e, self.d, D, D],
                                                            input_shape=[-1, self.d, n_e, D],
                                                            fan_in_mask=[0, 1, 0, 1]))

            # 9th term
            self.add_module(f"layer_{i}_b_B_b_G", SharedEinsumLayer(
                                                    equation_1="bdn, edj -> bej",
                                                    equation_2="bdn, edj -> benj",
                                                    weight_shape=[self.e, self.d, D],
                                                    input_shape=[-1, self.d, n_e],
                                                    fan_in_mask=[0, 1, 0],
                                                    unsqueeze_dims=[-2]))

            # 10th term
            self.add_module(f"layer_{i}_b_B_b_A_1", EinsumLayer(
                                                            equation="bdnq, edj -> bej",
                                                            weight_shape=[self.e, self.d, D],
                                                            input_shape=[-1, self.d, n_e, D_e],
                                                            fan_in_mask=[0, 1, 1, 1],
                                                            unsqueeze_dims=[-2]))

            # 11th term
            self.add_module(f"layer_{i}_b_B_b_A_2", EinsumLayer(
                                                            equation="bdnq, edj -> benj",
                                                            weight_shape=[self.e, self.d, D],
                                                            input_shape=[-1, self.d, n_e, D_e],
                                                            fan_in_mask=[0, 1, 0, 1]))

            # 12th term
            self.add_module(f"layer_{i}_b_B_b_B_1", EinsumLayer(
                                                            equation="bdnq, edjq -> bej",
                                                            weight_shape=[self.e, self.d, D, D],
                                                            input_shape=[-1, self.d, n_e, D],
                                                            fan_in_mask=[0, 1, 1, 1],
                                                            unsqueeze_dims=[-2]))

            # 13th term
            self.add_module(f"layer_{i}_b_B_b_B_2", EinsumLayer(
                                                            equation="bdnq, edjq -> benj",
                                                            weight_shape=[self.e, self.d, D, D],
                                                            input_shape=[-1, self.d, n_e, D],
                                                            fan_in_mask=[0, 1, 0, 1]))

            # 14th term
            self.add_module(f"layer_{i}_b_B_bias", EinsumLayer(
                                                            equation="ej -> ej",
                                                            weight_shape=[self.e, D],
                                                            input_shape=[self.e, D],
                                                            fan_in_mask=[0, 0],
                                                            unsqueeze_dims=[0, -2]))
            #set init einsum for all bB terms
            set_init_einsum_(getattr(self, f"layer_{i}_b_B_b_A_1"),
                            getattr(self, f"layer_{i}_b_B_b_A_2"),
                            getattr(self, f"layer_{i}_b_B_b_B_1"),
                            getattr(self, f"layer_{i}_b_B_b_B_2"),
                            getattr(self, f"layer_{i}_b_B_bias"), init_type=init_type)
            set_init_einsum_(getattr(self, f"layer_{i}_b_B_W_QK"),
                                getattr(self, f"layer_{i}_b_B_W_VO"),
                                getattr(self, f"layer_{i}_b_B_W_G"),
                                getattr(self, f"layer_{i}_b_B_W_A_1"),
                                getattr(self, f"layer_{i}_b_B_W_A_2"),
                                getattr(self, f"layer_{i}_b_B_W_B_1"),
                                getattr(self, f"layer_{i}_b_B_W_B_2"),
                                getattr(self, f"layer_{i}_b_B_b_G"), 
                                init_type=init_type, scale_degree=scale_degree)
    def forward(self, wsfeat: MoEWeightSpaceFeatures):
        """Map every layer of ``wsfeat`` to a new set of weight-space features.

        Each ``wsfeat[i]`` is unpacked into the ten tensors
        ``(W_q, W_k, W_v, W_o, W_G, W_A, W_B, b_G, b_A, b_B)``.

        * The four attention weights are each passed through their own
          submodule ``layer_{i}_W_Q`` / ``_W_K`` / ``_W_V`` / ``_W_O``.
        * Every other output group ``G`` (``W_G``, ``W_A``, ``W_B``, ``b_G``,
          ``b_A``, ``b_B``) is the sum of the submodules
          ``layer_{i}_G_<term>`` applied to their inputs.
          - Attention weights enter these sums only through the per-head
            products ``W_q W_k^T`` (``WW_qk``) and ``W_v W_o`` (``WW_vo``).
          - ``*_bias`` submodules are called with no input.
        * ``W_G`` and ``b_G`` additionally receive a parameter-free term: the
          input summed over axis 1, with a size-1 axis re-inserted at
          position 1.

        Args:
            wsfeat: weight-space features; ``len(wsfeat)`` is the number of
                layers and ``wsfeat[i]`` must unpack into exactly ten tensors.

        Returns:
            ``MoEWeightSpaceFeatures`` whose ten fields are lists holding one
            output tensor per layer.
        """
        # Term table per output group: (submodule suffix, input name).
        # - An input name of ``None`` marks an input-free bias submodule.
        # - The suffix ``"id"`` marks the parameter-free axis-1 sum.
        # Terms are accumulated strictly left to right in the listed order.
        shared_a_b_terms = (
            ("W_QK", "WW_qk"), ("W_VO", "WW_vo"), ("W_G", "W_G"),
            ("W_A_1", "W_A"), ("W_A_2", "W_A"), ("W_A_3", "W_A"),
            ("W_B_1", "W_B"), ("W_B_2", "W_B"), ("W_B_3", "W_B"),
            ("b_A_1", "b_A"), ("b_A_2", "b_A"), ("b_A_3", "b_A"),
            ("b_G", "b_G"), ("b_B_1", "b_B"), ("b_B_2", "b_B"),
            ("bias", None),
        )
        group_terms = {
            "W_G": (
                ("W_G", "W_G"), ("W_QK", "WW_qk"), ("W_VO", "WW_vo"),
                ("W_A_1", "W_A"), ("W_A_2", "W_A"),
                ("W_B_1", "W_B"), ("W_B_2", "W_B"),
                ("b_G", "b_G"), ("id", "W_G"),
                ("b_A_1", "b_A"), ("b_A_2", "b_A"),
                ("b_B_1", "b_B"), ("b_B_2", "b_B"),
                ("bias", None),
            ),
            "W_A": shared_a_b_terms,
            "W_B": shared_a_b_terms,  # identical term order to W_A
            "b_G": (
                ("W_QK", "WW_qk"), ("W_VO", "WW_vo"), ("W_G", "W_G"),
                ("W_A_1", "W_A"), ("W_A_2", "W_A"),
                ("W_B_1", "W_B"), ("W_B_2", "W_B"),
                ("b_G", "b_G"), ("id", "b_G"), ("bias", None),
                ("b_A_1", "b_A"), ("b_A_2", "b_A"),
                ("b_B_1", "b_B"), ("b_B_2", "b_B"),
            ),
            "b_A": (
                ("W_QK", "WW_qk"), ("W_VO", "WW_vo"), ("W_G", "W_G"),
                ("W_A_1", "W_A"), ("W_A_2", "W_A"), ("W_A_3", "W_A"),
                ("W_B_1", "W_B"), ("W_B_2", "W_B"), ("W_B_3", "W_B"),
                ("b_A_1", "b_A"), ("b_A_2", "b_A"), ("b_A_3", "b_A"),
                ("b_B_1", "b_B"), ("b_B_2", "b_B"), ("b_G", "b_G"),
                ("bias", None),
            ),
            "b_B": (
                ("W_QK", "WW_qk"), ("W_VO", "WW_vo"), ("W_G", "W_G"),
                ("W_A_1", "W_A"), ("W_A_2", "W_A"),
                ("W_B_1", "W_B"), ("W_B_2", "W_B"),
                ("b_G", "b_G"), ("b_A_1", "b_A"), ("b_A_2", "b_A"),
                ("b_B_1", "b_B"), ("b_B_2", "b_B"),
                ("bias", None),
            ),
        }
        # einsum equations of the parameter-free "id" terms (sum over axis 1).
        identity_equations = {"W_G": 'bdnj -> bnj', "b_G": 'bdn -> bn'}

        out_dict = {
            key: []
            for key in ("W_q", "W_k", "W_v", "W_o", "W_G",
                        "W_A", "W_B", "b_G", "b_A", "b_B")
        }

        num_layers = len(wsfeat)
        for i in range(num_layers):
            W_q, W_k, W_v, W_o, W_G, W_A, W_B, b_G, b_A, b_B = wsfeat[i]

            # Attention weights: one dedicated submodule each.
            attention = (("W_q", "W_Q", W_q), ("W_k", "W_K", W_k),
                         ("W_v", "W_V", W_v), ("W_o", "W_O", W_o))
            for out_key, suffix, weight in attention:
                out_dict[out_key].append(getattr(self, f"layer_{i}_{suffix}")(weight))

            inputs = {
                # Per (batch, channel, head): W_q @ W_k^T and W_v @ W_o.
                "WW_qk": torch.einsum('bdhpk, bdhqk -> bdhpq', W_q, W_k),
                "WW_vo": torch.einsum('bdhpk, bdhkq -> bdhpq', W_v, W_o),
                "W_G": W_G, "W_A": W_A, "W_B": W_B,
                "b_G": b_G, "b_A": b_A, "b_B": b_B,
            }

            for group, terms in group_terms.items():
                total = None
                for suffix, source in terms:
                    if suffix == "id":
                        term = torch.einsum(identity_equations[source], inputs[source]).unsqueeze(1)
                    elif source is None:
                        term = getattr(self, f"layer_{i}_{group}_{suffix}")()
                    else:
                        term = getattr(self, f"layer_{i}_{group}_{suffix}")(inputs[source])
                    total = term if total is None else total + term
                out_dict[group].append(total)

        return MoEWeightSpaceFeatures(**out_dict)


class MoELinearInv(nn.Module):
    """Linear readout mapping MoE weight-space features to a flat vector.

    For every encoder layer ``i``:

    1. The per-layer tensors ``(W_q, W_k, W_v, W_o, W_G, W_A, W_B, b_G, b_A,
       b_B)`` are contracted by nine ``EinsumLayer`` terms.
       - Attention weights enter only through ``W_q W_k^T`` and ``W_v W_o``.
       - The terms are summed into a ``[b, out_channels, out_dim_inv]``
         tensor.
    2. That tensor is flattened to ``[b, out_channels * out_dim_inv]``.
    3. Optionally, the flattened tensor is layer-normalised.

    The per-layer vectors are then concatenated along the feature axis.

    Args:
        encoder_weight_spec: spec of the encoder. ``len()`` gives the number
            of layers and ``get_all_dims()`` returns
            ``(D, D_q, D_k, D_v, n_e, D_e, h)``.
        in_channels: number of input feature channels ``d``.
        out_channels: number of output channels ``e``.
        out_dim_inv: per-channel output size ``D_inv``.
        init_type: forwarded to ``set_init_einsum_`` for every term.
        scale_degree: forwarded to ``set_init_einsum_`` for the
            weight-derived terms (QK, VO, W_G, W_A, W_B, b_G).
        layer_norm: if True, apply a per-layer ``nn.LayerNorm`` over the
            flattened ``out_channels * out_dim_inv`` features.
    """

    def __init__(self, encoder_weight_spec: MoENetworkSpec, in_channels, out_channels, out_dim_inv=5,
                 init_type="pytorch_default", scale_degree=2, layer_norm=True):
        super().__init__()
        self.d, self.e = in_channels, out_channels
        self.encoder_weight_spec = encoder_weight_spec
        # Stored before the loop so the flag is always defined, even for a
        # spec with zero layers.
        self.layer_norm = layer_norm

        D, D_q, D_k, D_v, n_e, D_e, h = encoder_weight_spec.get_all_dims()
        self.L = len(encoder_weight_spec)
        D_inv = out_dim_inv
        for i in range(self.L):
            # Term 1: W_q W_k^T, input [b, d, h, D, D].
            self.add_module(f"layer_{i}_QK",
                            EinsumLayer(
                                equation="bdhpq, edkpq -> bek",
                                weight_shape=[self.e, self.d, D_inv, D, D],
                                input_shape=[-1, self.d, h, D, D],
                                fan_in_mask=[0, 1, 1, 1, 1]
                            ))

            # Term 2: W_v W_o, input [b, d, h, D, D].
            self.add_module(f"layer_{i}_VO",
                            EinsumLayer(
                                equation="bdhpq, edkpq -> bek",
                                weight_shape=[self.e, self.d, D_inv, D, D],
                                input_shape=[-1, self.d, h, D, D],
                                fan_in_mask=[0, 1, 1, 1, 1]
                            ))

            # Term 3: W_G, input [b, d, n_e, D].
            self.add_module(f"layer_{i}_W_G",
                            EinsumLayer(
                                equation="bdnp, edkp -> bek",
                                weight_shape=[self.e, self.d, D_inv, D],
                                input_shape=[-1, self.d, n_e, D],
                                fan_in_mask=[0, 1, 1, 1]
                            ))

            # Term 4: W_A, input [b, d, n_e, D, D_e]; contracts the D axis.
            self.add_module(f"layer_{i}_W_A",
                            EinsumLayer(
                                equation="bdnpq, edkp -> bek",
                                weight_shape=[self.e, self.d, D_inv, D],
                                input_shape=[-1, self.d, n_e, D, D_e],
                                fan_in_mask=[0, 1, 1, 1, 1]
                            ))

            # Term 5: W_B, input [b, d, n_e, D_e, D]; contracts the D axis.
            self.add_module(f"layer_{i}_W_B",
                            EinsumLayer(
                                equation="bdnpq, edkq -> bek",
                                weight_shape=[self.e, self.d, D_inv, D],
                                input_shape=[-1, self.d, n_e, D_e, D],
                                fan_in_mask=[0, 1, 1, 1, 1]
                            ))

            # Term 6: b_G, input [b, d, n_e].
            self.add_module(f"layer_{i}_b_G",
                            EinsumLayer(
                                equation="bdn, edk -> bek",
                                weight_shape=[self.e, self.d, D_inv],
                                input_shape=[-1, self.d, n_e],
                                fan_in_mask=[0, 1, 1]
                            ))

            # Term 7: b_A, input [b, d, n_e, D_e].
            self.add_module(f"layer_{i}_b_A",
                            EinsumLayer(
                                equation="bdnq, edk -> bek",
                                weight_shape=[self.e, self.d, D_inv],
                                input_shape=[-1, self.d, n_e, D_e],
                                fan_in_mask=[0, 1, 1, 1]
                            ))

            # Term 8: b_B, input [b, d, n_e, D].
            self.add_module(f"layer_{i}_b_B",
                            EinsumLayer(
                                equation="bdnq, edkq -> bek",
                                weight_shape=[self.e, self.d, D_inv, D],
                                input_shape=[-1, self.d, n_e, D],
                                fan_in_mask=[0, 1, 1, 1]
                            ))

            # Term 9: input-free learned bias of shape [e, D_inv].
            self.add_module(f"layer_{i}_bias",
                            EinsumLayer(
                                equation=" ek -> ek",
                                weight_shape=[self.e, D_inv],
                                input_shape=[self.e, D_inv],
                                fan_in_mask=[0, 0],
                                unsqueeze_dims=[0]
                            ))

            # Initialise all terms: bias-derived terms without scale_degree,
            # weight-derived terms with it.
            set_init_einsum_(getattr(self, f"layer_{i}_b_A"),
                             getattr(self, f"layer_{i}_b_B"),
                             getattr(self, f"layer_{i}_bias"),
                             init_type=init_type)
            set_init_einsum_(getattr(self, f"layer_{i}_QK"),
                             getattr(self, f"layer_{i}_VO"),
                             getattr(self, f"layer_{i}_W_G"),
                             getattr(self, f"layer_{i}_W_A"),
                             getattr(self, f"layer_{i}_W_B"),
                             getattr(self, f"layer_{i}_b_G"),
                             init_type=init_type, scale_degree=scale_degree)
            if layer_norm:
                self.add_module(f"LayerNorm_{i}", nn.LayerNorm(out_channels * out_dim_inv))

    def forward(self, wsfeat: MoEWeightSpaceFeatures):
        """Return the concatenated per-layer readout features.

        Args:
            wsfeat: weight-space features covering ``self.L`` layers.
                ``wsfeat[i]`` unpacks into
                ``(W_q, W_k, W_v, W_o, W_G, W_A, W_B, b_G, b_A, b_B)``.

        Returns:
            Tensor of shape ``[b, L * out_channels * out_dim_inv]``.
        """
        out = []

        for i in range(self.L):
            W_q, W_k, W_v, W_o, W_G, W_A, W_B, b_G, b_A, b_B = wsfeat[i]

            # Per (batch, channel, head): W_q @ W_k^T and W_v @ W_o.
            WW_qk = torch.einsum('bdhpk, bdhqk -> bdhpq', W_q, W_k)
            WW_vo = torch.einsum('bdhpk, bdhkq -> bdhpq', W_v, W_o)

            # Each einsum term outputs "bek" = [b, e, D_inv].
            # The whole sum is parenthesised so it is one expression. Without
            # the parentheses, a line break silently drops trailing terms;
            # here all nine terms, including W_G and b_G, contribute.
            I_U = (
                getattr(self, f"layer_{i}_QK")(WW_qk)
                + getattr(self, f"layer_{i}_VO")(WW_vo)
                + getattr(self, f"layer_{i}_W_A")(W_A)
                + getattr(self, f"layer_{i}_W_B")(W_B)
                + getattr(self, f"layer_{i}_b_A")(b_A)
                + getattr(self, f"layer_{i}_b_B")(b_B)
                + getattr(self, f"layer_{i}_bias")()
                + getattr(self, f"layer_{i}_W_G")(W_G)
                + getattr(self, f"layer_{i}_b_G")(b_G)
            )
            I_U = I_U.reshape(I_U.shape[0], -1)  # [b, e * D_inv]
            if self.layer_norm:
                I_U = getattr(self, f"LayerNorm_{i}")(I_U)
            out.append(I_U)

        return rearrange(out, 'L b ek -> b (L ek)')