# -*- coding: utf-8 -*-
from math import sqrt
from typing import List, Optional, Union

import torch
from e3nn import o3
from e3nn.o3._tensor_product._instruction import Instruction
from e3nn.util import prod
from e3nn.util.jit import compile_mode
from torch import nn
from e3nn.util.codegen import CodeGenMixin
import torch
from torch import fx
from collections import OrderedDict

from e3nn.util import prod
from e3nn import o3
from opt_einsum_fx import optimize_einsums_full

def _sum_tensors(
    xs: List[torch.Tensor], shape: torch.Size, like: torch.Tensor
):
    if len(xs) > 0:
        out = xs[0]
        for x in xs[1:]:
            out = out + x
        return out
    return like.new_zeros(shape)



def slices_basis(irreps):
    r"""List of slices corresponding to indices for each irrep.

    Examples
    --------

    >>> Irreps('2x0e + 1e').slices()
    [slice(0, 2, None), slice(2, 5, None)]
    """
    s = []
    i = 0
    for mul_ir in irreps:
        s.append(slice(i, i + mul_ir[1][0]*2+1))
        i += mul_ir[1][0]*2+1
    return s


def CODEGEN_MAIN_LEFT_RIGHT(self__irreps_in1,
        self__irreps_in2,
        self__irreps_out,
        self__instructions,
                     ) -> fx.GraphModule:
    # 初始化输出形状
    graph = fx.Graph()

    # = Function definitions =
    tracer = fx.proxy.GraphAppendingTracer(graph)
    constants = OrderedDict()

    x1s = fx.Proxy(graph.placeholder("x1", torch.Tensor), tracer=tracer)
    x2s = fx.Proxy(graph.placeholder("x2", torch.Tensor), tracer=tracer)
    weights = fx.Proxy(graph.placeholder("w", torch.Tensor), tracer=tracer)

    empty = fx.Proxy(graph.call_function(torch.empty, ((),), dict(device="cpu")), tracer=tracer)
    output_shape = x1s.shape[:-2]
    # torch.broadcast_tensors(
    #     empty.expand(x1s.shape[:-2]), empty.expand(x2s.shape[:-2])
    # )[0].shape
    del empty


    # # 广播输入张量
    # x1s, x2s = x1s.broadcast_to(output_shape + (-1,-1)), x2s.broadcast_to(
    #     output_shape + (-1,-1)
    # )
    x1s = x1s.reshape(-1, self__irreps_in1.dim//self__irreps_in1[0].mul, self__irreps_in1[0].mul)
    x2s = x2s.reshape(-1, self__irreps_in2.dim//self__irreps_in2[0].mul, self__irreps_in2[0].mul)
    # x1s = x1s.reshape(output_shape.numel(), -1, self__irreps_in1[0].mul)
    # x2s = x2s.reshape(output_shape.numel(), -1, self__irreps_in2[0].mul)
    batch_numel = x1s.shape[0]

    # 提取每个不可约表示的输入
    x1_list = [
        x1s[:, i].reshape(batch_numel, mul_ir.ir.dim, mul_ir.mul)
        for i, mul_ir in zip(slices_basis(self__irreps_in1), self__irreps_in1)
    ]
    x2_list = [
        x2s[:, i].reshape(batch_numel, mul_ir.ir.dim, mul_ir.mul)
        for i, mul_ir in zip(slices_basis(self__irreps_in2), self__irreps_in2)
    ]

    outputs = []
    flat_weight_index = 0

    # 仅处理 "uvu" 的情况
    for idx,ins in enumerate(self__instructions):
        mul_ir_in1 = self__irreps_in1[ins.i_in1]
        mul_ir_in2 = self__irreps_in2[ins.i_in2]
        mul_ir_out = self__irreps_out[ins.i_out]

        # 跳过维度为 0 的情况
        # if mul_ir_in1.dim == 0 or mul_ir_in2.dim == 0 or mul_ir_out.dim == 0:
        #     continue

        x1 = x1_list[ins.i_in1]
        x2 = x2_list[ins.i_in2]
        
        # Create a proxy & request for the relevant wigner w3j
        # If not used (because of specialized code), will get removed later.
        w3j_name = f"_w3j_{mul_ir_in1.ir.l}_{mul_ir_in2.ir.l}_{mul_ir_out.ir.l}"
        w3j = fx.Proxy(graph.get_attr(w3j_name), tracer=tracer)



        l1l2l3 = (mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l)

        if ins.has_weight:
            # Extract the weight from the flattened weight tensor
            w = weights[ flat_weight_index : flat_weight_index + prod(ins.path_shape)].reshape(
                tuple(ins.path_shape))
    
            flat_weight_index += prod(ins.path_shape)
        # # 取相应的weights
        # w = weights[
        #     :, flat_weight_index : flat_weight_index + prod(ins.path_shape)
        # ].reshape((-1,) + tuple(ins.path_shape))
        # flat_weight_index += prod(ins.path_shape)

        if ins.connection_mode == "uvw":
            xx = torch.einsum("ziu,zjv,ijk->zku", x1, x2, w3j)
            assert ins.has_weight #and x2.shape[-1]==1
            w = w.squeeze()
            result = torch.matmul(xx,w) 
            # result = torch.einsum(f"zuvw,zkuv->zkw", w, xx)
        elif ins.connection_mode == "uvu":
            assert mul_ir_in1.mul == mul_ir_out.mul
            # not so useful operation because v is summed
            if ins.has_weight:
                xx = torch.einsum("ziu,zjv,ijk->zkuv", x1, x2, w3j)
                result = torch.einsum("uv,zkuv->zku", w, xx)
            else:
                result = torch.einsum("ziu,zjv,ijk->zku", x1, x2, w3j)
                # result = torch.sum(xx, dim = -1)
        elif ins.connection_mode == "uuu":
            result = torch.einsum("ziu,zju,ijk->zku", x1, x2, w3j)
            # result = torch.sum(xx, dim = -1)
        
                
        # # 计算结果
        # result = torch.einsum("zuv,ijk,zuvij->zuk", w, w3j, xx)

        result = ins.path_weight * result
        outputs.append(result.reshape(batch_numel, mul_ir_out.ir.l * 2 + 1, mul_ir_out.mul))
        
        
        w3j_name = f"_w3j_{mul_ir_in1.ir.l}_{mul_ir_in2.ir.l}_{mul_ir_out.ir.l}"
        # Remove unused w3js:
        if len(w3j.node.users) == 0:
            # The w3j nodes are reshapes, so we have to remove them from the graph
            # Although they are dead code, they try to reshape to dimensions that don't exist
            # (since the corresponding w3js are not in w3j)
            # so they screw up the shape propagation, even though they would be removed later as dead code by TorchScript.
            graph.erase_node(w3j.node)
        else:
            if w3j_name not in constants:
                constants[w3j_name] = o3.wigner_3j(mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l)

    # 汇总输出
    outputs = [
        _sum_tensors(
            [
                out
                for ins, out in zip(self__instructions, outputs)
                if ins.i_out == i_out
            ],
            shape=(batch_numel, mul_ir_out.dim),
            like=x1s,
        )
        for i_out, mul_ir_out in enumerate(self__irreps_out)
        if mul_ir_out.mul > 0
    ]
    
    outputs = torch.cat(outputs, dim=1) if len(outputs) > 1 else outputs[0]
    outputs = outputs.reshape(output_shape+(outputs.shape[-2], outputs.shape[-1]))

    graph.output(outputs.node, torch.Tensor)

    # check graphs
    graph.lint()

    # Make GraphModules

    # By putting the constants in a Module rather than a dict,
    # we force FX to copy them as buffers instead of as attributes.
    #
    # FX seems to have resolved this issue for dicts in 1.9, but we support all the way back to 1.8.0.
    constants_root = torch.nn.Module()
    for key, value in constants.items():
        constants_root.register_buffer(key, value)
    graphmod = fx.GraphModule(constants_root, graph, class_name="tp_forward")

    # == Optimize ==
    # TODO: when eliminate_dead_code() is in PyTorch stable, use that
    batchdim = 4
    example_inputs = (
        torch.zeros((batchdim, self__irreps_in1.dim//self__irreps_in1[0].mul, self__irreps_in1[0].mul)),
        torch.zeros((batchdim, self__irreps_in2.dim//self__irreps_in2[0].mul, self__irreps_in2[0].mul)),
        torch.zeros(
            flat_weight_index,
        ),
    )
    graphmod = optimize_einsums_full(graphmod, example_inputs)

    return graphmod




class Simple_TensorProduct(torch.nn.Module):
    def __init__(
        self,
        irreps_in1: o3.Irreps,
        irreps_in2: o3.Irreps,
        irreps_out: o3.Irreps,
        instructions: List[tuple],
        in1_var: Optional[Union[List[float], torch.Tensor]] = None,
        in2_var: Optional[Union[List[float], torch.Tensor]] = None,
        out_var: Optional[Union[List[float], torch.Tensor]] = None,
        irrep_normalization: str = "component",
        path_normalization: str = "element",
        internal_weights=True,
        path_weight_sqrt=True,
        rescale = True,
        use_bias=False
    ):
        super().__init__()
        self.rescale = rescale,

        self.use_bias = use_bias
        if self.use_bias:
            raise ValueError("Not implemented yet, the bias is for order 0 irreps only, it  only works for TensorProductRescale")
        self.irreps_in1 = o3.Irreps(irreps_in1)
        self.irreps_in2 = o3.Irreps(irreps_in2)
        self.irreps_out = o3.Irreps(irreps_out)

        instructions = [x if len(x) == 6 else x + (1.0,) for x in instructions]
        instructions = [
            Instruction(
                i_in1=i_in1,
                i_in2=i_in2,
                i_out=i_out,
                connection_mode=connection_mode,
                has_weight=has_weight,
                path_weight=path_weight,
                path_shape={
                    "uvw": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul, self.irreps_out[i_out].mul),
                    "uvu": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul),
                    "uvv": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul),
                    "uuw": (self.irreps_in1[i_in1].mul, self.irreps_out[i_out].mul),
                    "uuu": (self.irreps_in1[i_in1].mul,),
                    "uvuv": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul),
                    "uvu<v": (self.irreps_in1[i_in1].mul * (self.irreps_in2[i_in2].mul - 1) // 2,),
                    "u<vw": (self.irreps_in1[i_in1].mul * (self.irreps_in2[i_in2].mul - 1) // 2, self.irreps_out[i_out].mul),
                }[connection_mode],
            )
            for i_in1, i_in2, i_out, connection_mode, has_weight, path_weight in instructions
        ]

        if in1_var is None:
            in1_var = [1.0 for _ in range(len(self.irreps_in1))]

        if in2_var is None:
            in2_var = [1.0 for _ in range(len(self.irreps_in2))]

        if out_var is None:
            out_var = [1.0 for _ in range(len(self.irreps_out))]

        def num_elements(ins):
            return {
                "uvw": (
                    self.irreps_in1[ins.i_in1].mul * self.irreps_in2[ins.i_in2].mul
                ),
                "uvu": self.irreps_in2[ins.i_in2].mul,
                "uuu": 1,
            }[ins.connection_mode]

        normalization_coefficients = []
        for ins in instructions:
            mul_ir_in1 = self.irreps_in1[ins.i_in1]
            mul_ir_in2 = self.irreps_in2[ins.i_in2]
            mul_ir_out = self.irreps_out[ins.i_out]

            if irrep_normalization == "component":
                alpha = mul_ir_out.ir.dim
            if irrep_normalization == "norm":
                alpha = mul_ir_in1.ir.dim * mul_ir_in2.ir.dim
            if irrep_normalization == "none":
                alpha = 1

            if path_normalization == "element":
                x = sum(in1_var[i.i_in1] * in2_var[i.i_in2] * num_elements(i) for i in instructions if i.i_out == ins.i_out)
            if path_normalization == "path":
                x = in1_var[ins.i_in1] * in2_var[ins.i_in2] * num_elements(ins)
                x *= len([i for i in instructions if i.i_out == ins.i_out])
            if path_normalization == "none":
                x = 1


            alpha /= x
            alpha *= out_var[ins.i_out]
            alpha = sqrt(alpha)
            alpha *= sqrt(ins.path_weight) if path_weight_sqrt else ins.path_weight
            normalization_coefficients += [alpha]

        self.instructions = [
            Instruction(
                ins.i_in1,
                ins.i_in2,
                ins.i_out,
                ins.connection_mode,
                ins.has_weight,
                alpha,
                ins.path_shape,
            )
            for ins, alpha in zip(instructions, normalization_coefficients)
        ]

        self._in1_dim = self.irreps_in1.dim
        self._in2_dim = self.irreps_in2.dim

        self.weight_numel = sum(
            prod(ins.path_shape) for ins in self.instructions if ins.has_weight
        )

        self.internal_weights = internal_weights
        self.shared_weights = internal_weights
        if internal_weights and self.weight_numel > 0:
            assert self.shared_weights, "Having internal weights impose shared weights"
            self.weight = torch.nn.Parameter(torch.randn(self.weight_numel))
        else:
            # For TorchScript, there always has to be some kind of defined .weight
            self.register_buffer("weight", torch.Tensor())

    
        self.init_rescale_bias()

    def calculate_fan_in(self, ins):
        return {
            "uvw": (self.irreps_in1[ins.i_in1].mul * self.irreps_in2[ins.i_in2].mul),
            "uvu": self.irreps_in2[ins.i_in2].mul,
            "uvv": self.irreps_in1[ins.i_in1].mul,
            "uuw": self.irreps_in1[ins.i_in1].mul,
            "uuu": 1,
            "uvuv": 1,
            "uvu<v": 1,
            "u<vw": self.irreps_in1[ins.i_in1].mul
            * (self.irreps_in2[ins.i_in2].mul - 1)
            // 2,
        }[ins.connection_mode]

    def init_rescale_bias(self) -> None:
        irreps_out = self.irreps_out
        # For each zeroth order output irrep we need a bias
        # Determine the order for each output tensor and their dims
        self.irreps_out_orders = [
            int(irrep_str[-2]) for irrep_str in str(irreps_out).split("+")
        ]
        self.irreps_out_dims = [
            int(irrep_str.split("x")[0]) for irrep_str in str(irreps_out).split("+")
        ]
        self.irreps_out_slices = irreps_out.slices()

        # Store tuples of slices and corresponding biases in a list
        self.bias = None
        self.bias_slices = []
        self.bias_slice_idx = []
        self.irreps_bias = self.irreps_out.simplify()
        self.irreps_bias_orders = [
            int(irrep_str[-2]) for irrep_str in str(self.irreps_bias).split("+")
        ]
        self.irreps_bias_parity = [
            irrep_str[-1] for irrep_str in str(self.irreps_bias).split("+")
        ]
        self.irreps_bias_dims = [
            int(irrep_str.split("x")[0])
            for irrep_str in str(self.irreps_bias).split("+")
        ]
        if self.use_bias:
            self.bias = []
            for slice_idx in range(len(self.irreps_bias_orders)):
                if (
                    self.irreps_bias_orders[slice_idx] == 0
                    and self.irreps_bias_parity[slice_idx] == "e"
                ):
                    out_slice = self.irreps_bias.slices()[slice_idx]
                    out_bias = torch.nn.Parameter(
                        torch.zeros(
                            self.irreps_bias_dims[slice_idx], dtype=self.weight.dtype
                        )
                    )
                    self.bias += [out_bias]
                    self.bias_slices += [out_slice]
                    self.bias_slice_idx += [slice_idx]
        self.bias = torch.nn.ParameterList(self.bias)

        self.slices_sqrt_k = {}
        with torch.no_grad():
            # Determine fan_in for each slice, it could be that each output slice is updated via several instructions
            slices_fan_in = {}  # fan_in per slice
            for instr in self.instructions:
                slice_idx = instr[2]
                fan_in = self.calculate_fan_in(instr)
                slices_fan_in[slice_idx] = (
                    slices_fan_in[slice_idx] + fan_in
                    if slice_idx in slices_fan_in.keys()
                    else fan_in
                )
            for instr in self.instructions:
                slice_idx = instr[2]
                if self.rescale:
                    sqrt_k = 1 / slices_fan_in[slice_idx] ** 0.5
                else:
                    sqrt_k = 1.0
                self.slices_sqrt_k[slice_idx] = (
                    self.irreps_out_slices[slice_idx],
                    sqrt_k,
                )

            # Re-initialize weights in each instruction
            if self.internal_weights:
                for weight, instr in zip(self.weight_views(), self.instructions):
                    # The tensor product in e3nn already normalizes proportional to 1 / sqrt(fan_in), and the weights are by
                    # default initialized with unif(-1,1). However, we want to be consistent with torch.nn.Linear and
                    # initialize the weights with unif(-sqrt(k),sqrt(k)), with k = 1 / fan_in
                    slice_idx = instr[2]
                    if self.rescale:
                        sqrt_k = 1 / slices_fan_in[slice_idx] ** 0.5
                        weight.data.mul_(sqrt_k)
                    # else:
                    #    sqrt_k = 1.
                    #
                    # if self.rescale:
                    # weight.data.uniform_(-sqrt_k, sqrt_k)
                    #    weight.data.mul_(sqrt_k)
                    # self.slices_sqrt_k[slice_idx] = (self.irreps_out_slices[slice_idx], sqrt_k)

            # Initialize the biases
            # for (out_slice_idx, out_slice, out_bias) in zip(self.bias_slice_idx, self.bias_slices, self.bias):
            #    sqrt_k = 1 / slices_fan_in[out_slice_idx] ** 0.5
            #    out_bias.uniform_(-sqrt_k, sqrt_k)



    @torch.jit.unused
    def _prep_weights_python(
        self, weight: Optional[Union[torch.Tensor, List[torch.Tensor]]]
    ) -> Optional[torch.Tensor]:
        if isinstance(weight, list):
            weight_shapes = [
                ins.path_shape for ins in self.instructions if ins.has_weight
            ]
            if not self.shared_weights:
                weight = [
                    w.reshape(-1, prod(shape))
                    for w, shape in zip(weight, weight_shapes)
                ]
            else:
                weight = [
                    w.reshape(prod(shape)) for w, shape in zip(weight, weight_shapes)
                ]
            return torch.cat(weight, dim=-1)
        else:
            return weight

    def _get_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor:
        if not torch.jit.is_scripting():
            # If we're not scripting, then we're in Python and `weight` could be a List[Tensor]
            # deal with that:
            weight = self._prep_weights_python(weight)
        if weight is None:
            if self.weight_numel > 0 and not self.internal_weights:
                raise RuntimeError(
                    "Weights must be provided when the TensorProduct does not have `internal_weights`"
                )
            return self.weight
        else:
            if self.shared_weights:
                assert weight.shape == (self.weight_numel,), "Invalid weight shape"
            else:
                assert weight.shape[-1] == self.weight_numel, "Invalid weight shape"
                assert (
                    weight.ndim > 1
                ), "When shared weights is false, weights must have batch dimension"
            return weight

    def weight_views(
        self, weight: Optional[torch.Tensor] = None, yield_instruction: bool = False
    ):
        r"""Iterator over weight views for each weighted instruction.

        Parameters
        ----------
        weight : `torch.Tensor`, optional
            like ``weight`` argument to ``forward()``

        yield_instruction : `bool`, default False
            Whether to also yield the corresponding instruction.

        Yields
        ------
        If ``yield_instruction`` is ``True``, yields ``(instruction_index, instruction, weight_view)``.
        Otherwise, yields ``weight_view``.
        """
        weight = self._get_weights(weight)
        batchshape = weight.shape[:-1]
        offset = 0
        for ins_i, ins in enumerate(self.instructions):
            if ins.has_weight:
                flatsize = prod(ins.path_shape)
                this_weight = weight.narrow(-1, offset, flatsize).view(
                    batchshape + ins.path_shape
                )
                offset += flatsize
                if yield_instruction:
                    yield ins_i, ins, this_weight
                else:
                    yield this_weight

    def _sum_tensors(
        self, xs: List[torch.Tensor], shape: torch.Size, like: torch.Tensor
    ):
        if len(xs) > 0:
            out = xs[0]
            for x in xs[1:]:
                out = out + x
            return out
        return like.new_zeros(shape)

    def _main_left_right(self, x1s, x2s, weights):
        # 初始化输出形状
        empty = torch.empty((), device="cpu")
        output_shape = torch.broadcast_tensors(
            empty.expand(x1s.shape[:-1]), empty.expand(x2s.shape[:-1])
        )[0].shape
        del empty

        # 广播输入张量
        x1s, x2s = x1s.broadcast_to(output_shape + (-1,)), x2s.broadcast_to(
            output_shape + (-1,)
        )
        output_shape = output_shape + (self.irreps_out.dim,)
        x1s = x1s.reshape(-1, self.irreps_in1.dim)
        x2s = x2s.reshape(-1, self.irreps_in2.dim)
        batch_numel = x1s.shape[0]

        # 提取和重塑权重
        if self.weight_numel > 0:
            weights = weights.reshape(-1, self.weight_numel)

        # 提取每个不可约表示的输入
        x1_list = [
            x1s[:, i].reshape(batch_numel, mul_ir.mul, mul_ir.ir.dim)
            for i, mul_ir in zip(self.irreps_in1.slices(), self.irreps_in1)
        ]
        x2_list = [
            x2s[:, i].reshape(batch_numel, mul_ir.mul, mul_ir.ir.dim)
            for i, mul_ir in zip(self.irreps_in2.slices(), self.irreps_in2)
        ]

        outputs = []
        flat_weight_index = 0

        # 仅处理 "uvu" 的情况
        for ins in self.instructions:
            mul_ir_in1 = self.irreps_in1[ins.i_in1]
            mul_ir_in2 = self.irreps_in2[ins.i_in2]
            mul_ir_out = self.irreps_out[ins.i_out]

            # 跳过维度为 0 的情况
            # if mul_ir_in1.dim == 0 or mul_ir_in2.dim == 0 or mul_ir_out.dim == 0:
            #     continue

            x1 = x1_list[ins.i_in1]
            x2 = x2_list[ins.i_in2]

            # 计算 x1 和 x2 的外积
            xx = torch.einsum("zui,zvj->zuvij", x1, x2)
            # 获取 Wigner 3j 符号
            w3j = o3.wigner_3j(mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l).to(
                x1s.device
            )
            
            l1l2l3 = (mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l)

            if ins.has_weight:
                # Extract the weight from the flattened weight tensor
                w = weights[:, flat_weight_index : flat_weight_index + prod(ins.path_shape)].reshape(
                    (-1,) + tuple(ins.path_shape))
                
                flat_weight_index += prod(ins.path_shape)
            
            # # 取相应的weights
            # w = weights[
            #     :, flat_weight_index : flat_weight_index + prod(ins.path_shape)
            # ].reshape((-1,) + tuple(ins.path_shape))
            # flat_weight_index += prod(ins.path_shape)
 
            if ins.connection_mode == "uvw":
                assert ins.has_weight
                if l1l2l3 == (0, 0, 0):
                    result = torch.einsum(
                        f"zuvw,zu,zv->zw", w, x1.reshape(batch_numel, mul_ir_in1.dim), x2.reshape(batch_numel, mul_ir_in2.dim)
                    )
                elif mul_ir_in1.ir.l == 0:
                    result = torch.einsum(f"zuvw,zu,zvj->zwj", w, x1.reshape(batch_numel, mul_ir_in1.dim), x2) / sqrt(
                        mul_ir_out.ir.dim
                    )
                elif mul_ir_in2.ir.l == 0:
                    result = torch.einsum(f"zuvw,zui,zv->zwi", w, x1, x2.reshape(batch_numel, mul_ir_in2.dim)) / sqrt(
                        mul_ir_out.ir.dim
                    )
                elif mul_ir_out.ir.l == 0:
                    result = torch.einsum(f"zuvw,zui,zvi->zw", w, x1, x2) / sqrt(mul_ir_in1.ir.dim)
                else:
                    result = torch.einsum(f"zuvw,ijk,zuvij->zwk", w, w3j, xx)
            if ins.connection_mode == "uvu":
                assert mul_ir_in1.mul == mul_ir_out.mul
                if ins.has_weight:
                    if l1l2l3 == (0, 0, 0):
                        result = torch.einsum(
                            f"zuv,zu,zv->zu", w, x1.reshape(batch_numel, mul_ir_in1.dim), x2.reshape(batch_numel, mul_ir_in2.dim)
                        )
                    elif mul_ir_in1.ir.l == 0:
                        result = torch.einsum(f"zuv,zu,zvj->zuj", w, x1.reshape(batch_numel, mul_ir_in1.dim), x2) / sqrt(
                            mul_ir_out.ir.dim
                        )
                    elif mul_ir_in2.ir.l == 0:
                        result = torch.einsum(f"zuv,zui,zv->zui", w, x1, x2.reshape(batch_numel, mul_ir_in2.dim)) / sqrt(
                            mul_ir_out.ir.dim
                        )
                    elif mul_ir_out.ir.l == 0:
                        result = torch.einsum(f"zuv,zui,zvi->zu", w, x1, x2) / sqrt(mul_ir_in1.ir.dim)
                    else:
                        result = torch.einsum(f"zuv,ijk,zuvij->zuk", w, w3j, xx)
                else:
                    # not so useful operation because v is summed
                    result = torch.einsum("ijk,zuvij->zuk", w3j, xx)
            # # 计算结果
            # result = torch.einsum("zuv,ijk,zuvij->zuk", w, w3j, xx)

            result = ins.path_weight * result
            outputs.append(result.reshape(batch_numel, mul_ir_out.dim))

        # 汇总输出
        outputs = [
            self._sum_tensors(
                [
                    out
                    for ins, out in zip(self.instructions, outputs)
                    if ins.i_out == i_out
                ],
                shape=(batch_numel, mul_ir_out.dim),
                like=x1s,
            )
            for i_out, mul_ir_out in enumerate(self.irreps_out)
            if mul_ir_out.mul > 0
        ]
        outputs = torch.cat(outputs, dim=1) if len(outputs) > 1 else outputs[0]
        return outputs.reshape(output_shape)

    def forward(self, x, y, weight: Optional[torch.Tensor] = None):
        assert x.shape[-1] == self._in1_dim, "Incorrect last dimension for x"
        assert y.shape[-1] == self._in2_dim, "Incorrect last dimension for y"
        weight = self.weight
        return self._main_left_right(x, y, weight)




class Simple_TensorProduct_oTchannel(torch.nn.Module,CodeGenMixin):
    def __init__(
        self,
        irreps_in1: o3.Irreps,
        irreps_in2: o3.Irreps,
        irreps_out: o3.Irreps,
        instructions: List[tuple]=None,
        learnable_weight = None,
        connection_mode = 'uvu',
        reduce_same_order = False,
        in1_var: Optional[Union[List[float], torch.Tensor]] = None,
        in2_var: Optional[Union[List[float], torch.Tensor]] = None,
        out_var: Optional[Union[List[float], torch.Tensor]] = None,
        irrep_normalization: str = "component",
        path_normalization: str = "element",
        internal_weights=True,
        path_weight_sqrt=True,
        rescale = True,
        use_bias=False
    ):
        super().__init__()
        self.rescale = rescale

        self.use_bias = use_bias
        if self.use_bias:
            raise ValueError("Not implemented yet, the bias is for order 0 irreps only, it  only works for TensorProductRescale")
        self.irreps_in1 = o3.Irreps(irreps_in1) if isinstance(irreps_in1,str) else irreps_in1
        self.irreps_in2 = o3.Irreps(irreps_in2) if isinstance(irreps_in2,str) else irreps_in2
        self.irreps_out = o3.Irreps(irreps_out) if isinstance(irreps_out,str) else irreps_out

        for i in range(1,len(self.irreps_in1)):
            if self.irreps_in1[i][0]!=self.irreps_in1[i-1][0]:
                raise ValueError("The input channel must have the same channel")
        
        if instructions is None and learnable_weight is None:
            raise ValueError("please set instructions or learable weight")
        if instructions is not None and learnable_weight is not None:
            raise ValueError("please set instructions or learable weight")
        if instructions is None:
            instructions,irreps_output = self._get_instruction(irreps_in1,
                                                 irreps_in2,
                                                 irreps_out,
                                                 learnable_weight=learnable_weight,
                                                 connection_mode=connection_mode)
            self.irreps_out = irreps_output
        instructions = [x if len(x) == 6 else x + (1.0,) for x in instructions]
        instructions = [
            Instruction(
                i_in1=i_in1,
                i_in2=i_in2,
                i_out=i_out,
                connection_mode=connection_mode,
                has_weight=has_weight,
                path_weight=path_weight,
                path_shape={
                    "uvw": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul, self.irreps_out[i_out].mul),
                    "uvu": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul),
                    "uvv": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul),
                    "uuw": (self.irreps_in1[i_in1].mul, self.irreps_out[i_out].mul),
                    "uuu": (self.irreps_in1[i_in1].mul,),
                    "uvuv": (self.irreps_in1[i_in1].mul, self.irreps_in2[i_in2].mul),
                    "uvu<v": (self.irreps_in1[i_in1].mul * (self.irreps_in2[i_in2].mul - 1) // 2,),
                    "u<vw": (self.irreps_in1[i_in1].mul * (self.irreps_in2[i_in2].mul - 1) // 2, self.irreps_out[i_out].mul),
                }[connection_mode],
            )
            for i_in1, i_in2, i_out, connection_mode, has_weight, path_weight in instructions
        ]

        if in1_var is None:
            in1_var = [1.0 for _ in range(len(self.irreps_in1))]

        if in2_var is None:
            in2_var = [1.0 for _ in range(len(self.irreps_in2))]

        if out_var is None:
            out_var = [1.0 for _ in range(len(self.irreps_out))]

        def num_elements(ins):
            return {
                "uvw": (
                    self.irreps_in1[ins.i_in1].mul * self.irreps_in2[ins.i_in2].mul
                ),
                "uvu": self.irreps_in2[ins.i_in2].mul,
                "uuu": self.irreps_in2[ins.i_in2].mul,
            }[ins.connection_mode]

        normalization_coefficients = []
        for ins in instructions:
            mul_ir_in1 = self.irreps_in1[ins.i_in1]
            mul_ir_in2 = self.irreps_in2[ins.i_in2]
            mul_ir_out = self.irreps_out[ins.i_out]

            if irrep_normalization == "component":
                alpha = mul_ir_out.ir.dim
            if irrep_normalization == "norm":
                alpha = mul_ir_in1.ir.dim * mul_ir_in2.ir.dim
            if irrep_normalization == "none":
                alpha = 1

            if path_normalization == "element":
                x = sum(in1_var[i.i_in1] * in2_var[i.i_in2] * num_elements(i) for i in instructions if i.i_out == ins.i_out)
            if path_normalization == "path":
                x = in1_var[ins.i_in1] * in2_var[ins.i_in2] * num_elements(ins)
                x *= len([i for i in instructions if i.i_out == ins.i_out])
            if path_normalization == "none":
                x = 1


            alpha /= x
            alpha *= out_var[ins.i_out]
            alpha = sqrt(alpha)
            alpha *= sqrt(ins.path_weight) if path_weight_sqrt else ins.path_weight
            normalization_coefficients += [alpha]

        self.instructions = [
            Instruction(
                ins.i_in1,
                ins.i_in2,
                ins.i_out,
                ins.connection_mode,
                ins.has_weight,
                alpha,
                ins.path_shape,
            )
            for ins, alpha in zip(instructions, normalization_coefficients)
        ]

        self._in1_dim = self.irreps_in1.dim
        self._in2_dim = self.irreps_in2.dim

        self.weight_numel = sum(
            prod(ins.path_shape) for ins in self.instructions if ins.has_weight
        )

        self.internal_weights = internal_weights
        self.shared_weights = internal_weights
        if internal_weights and self.weight_numel > 0:
            assert self.shared_weights, "Having internal weights impose shared weights"
            self.weight = torch.nn.Parameter(torch.randn(self.weight_numel))
        else:
            # For TorchScript, there always has to be some kind of defined .weight
            self.register_buffer("weight", torch.Tensor())

    
        self.init_rescale_bias()

        graphmod_left_right = CODEGEN_MAIN_LEFT_RIGHT(self.irreps_in1,
            self.irreps_in2,
            self.irreps_out,
            self.instructions)
            
                
        assert graphmod_left_right is not None    
        self._codegen_register({"_compiled_main_left_right": graphmod_left_right})

    # this is specific tensor product without reduce operation on same order
    def _get_instruction(self,input1,input2,output,
                         learnable_weight = True,
                        connection_mode = 'uvu',
                        reduce_sameorder = True):
        if learnable_weight == False:
            connection_mode = 'uvu'
        if reduce_sameorder:
            irreps_output = []
            instructions = []

            for i, (mul, ir_in) in enumerate(input1):
                for j, (_, ir_edge) in enumerate(input2):
                    for ir_out in ir_in * ir_edge:
                        if ir_out in output: # or ir_out == o3.Irrep(0, 1):
                            k = len(irreps_output)
                            irreps_output.append((mul, ir_out))
                            instructions.append((i, j, k, connection_mode, learnable_weight))
            irreps_output = o3.Irreps(irreps_output)

            return instructions,irreps_output

        # elif woreduce_sort:
        #     irreps_output = []
        #     instructions = []

        #     for i, (mul, ir_in) in enumerate(self.irreps_node_input):
        #         for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
        #             for ir_out in ir_in * ir_edge:
        #                 if ir_out in self.irreps_node_output: # or ir_out == o3.Irrep(0, 1):
        #                     k = len(irreps_output)
        #                     irreps_output.append((mul, ir_out))
        #                     instructions.append((i, j, k, connection_mode, learnable_weight))

        #     irreps_output = o3.Irreps(irreps_output)
        #     instructions = [
        #         (i_1, i_2, p[i_out], mode, train)
        #         for i_1, i_2, i_out, mode, train in instructions
        #     ]
        #     return instructions,irreps_output

        # elif reduce_sort:
        #     irreps_output = []
        #     instructions = []

        #     for i, (mul, ir_in) in enumerate(self.irreps_node_input):
        #         for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
        #             for ir_out in ir_in * ir_edge:
        #                 if ir_out in self.irreps_node_output: # or ir_out == o3.Irrep(0, 1):
        #                     k = len(irreps_output)
        #                     irreps_output.append((mul, ir_out))
        #                     instructions.append((i, j, k, connection_mode, learnable_weight))

        #     irreps_output = o3.Irreps(irreps_output)
        #     irreps_output, p, _ = sort_irreps_even_first(irreps_output)  # irreps_output.sort()
        #     instructions = [
        #         (i_1, i_2, p[i_out], mode, train)
        #         for i_1, i_2, i_out, mode, train in instructions
        #     ]
        #     return instructions,irreps_output
    def calculate_fan_in(self, ins):
        return {
            "uvw": (self.irreps_in1[ins.i_in1].mul * self.irreps_in2[ins.i_in2].mul),
            "uvu": self.irreps_in2[ins.i_in2].mul,
            "uvv": self.irreps_in1[ins.i_in1].mul,
            "uuw": self.irreps_in1[ins.i_in1].mul,
            "uuu": 1,
            "uvuv": 1,
            "uvu<v": 1,
            "u<vw": self.irreps_in1[ins.i_in1].mul
            * (self.irreps_in2[ins.i_in2].mul - 1)
            // 2,
        }[ins.connection_mode]

    def init_rescale_bias(self) -> None:
        irreps_out = self.irreps_out
        # For each zeroth order output irrep we need a bias
        # Determine the order for each output tensor and their dims
        self.irreps_out_orders = [
            int(irrep_str[-2]) for irrep_str in str(irreps_out).split("+")
        ]
        self.irreps_out_dims = [
            int(irrep_str.split("x")[0]) for irrep_str in str(irreps_out).split("+")
        ]
        self.irreps_out_slices = irreps_out.slices()

        # Store tuples of slices and corresponding biases in a list
        self.bias = None
        self.bias_slices = []
        self.bias_slice_idx = []
        self.irreps_bias = self.irreps_out.simplify()
        self.irreps_bias_orders = [
            int(irrep_str[-2]) for irrep_str in str(self.irreps_bias).split("+")
        ]
        self.irreps_bias_parity = [
            irrep_str[-1] for irrep_str in str(self.irreps_bias).split("+")
        ]
        self.irreps_bias_dims = [
            int(irrep_str.split("x")[0])
            for irrep_str in str(self.irreps_bias).split("+")
        ]
        if self.use_bias:
            self.bias = []
            for slice_idx in range(len(self.irreps_bias_orders)):
                if (
                    self.irreps_bias_orders[slice_idx] == 0
                    and self.irreps_bias_parity[slice_idx] == "e"
                ):
                    out_slice = self.irreps_bias.slices()[slice_idx]
                    out_bias = torch.nn.Parameter(
                        torch.zeros(
                            self.irreps_bias_dims[slice_idx], dtype=self.weight.dtype
                        )
                    )
                    self.bias += [out_bias]
                    self.bias_slices += [out_slice]
                    self.bias_slice_idx += [slice_idx]
        self.bias = torch.nn.ParameterList(self.bias)

        self.slices_sqrt_k = {}
        with torch.no_grad():
            # Determine fan_in for each slice, it could be that each output slice is updated via several instructions
            slices_fan_in = {}  # fan_in per slice
            for instr in self.instructions:
                slice_idx = instr[2]
                fan_in = self.calculate_fan_in(instr)
                slices_fan_in[slice_idx] = (
                    slices_fan_in[slice_idx] + fan_in
                    if slice_idx in slices_fan_in.keys()
                    else fan_in
                )
            for instr in self.instructions:
                slice_idx = instr[2]
                if self.rescale:
                    sqrt_k = 1 / slices_fan_in[slice_idx] ** 0.5
                else:
                    sqrt_k = 1.0
                self.slices_sqrt_k[slice_idx] = (
                    self.irreps_out_slices[slice_idx],
                    sqrt_k,
                )

            # Re-initialize weights in each instruction
            if self.internal_weights:
                for weight, instr in zip(self.weight_views(), self.instructions):
                    # The tensor product in e3nn already normalizes proportional to 1 / sqrt(fan_in), and the weights are by
                    # default initialized with unif(-1,1). However, we want to be consistent with torch.nn.Linear and
                    # initialize the weights with unif(-sqrt(k),sqrt(k)), with k = 1 / fan_in
                    slice_idx = instr[2]
                    if self.rescale:
                        sqrt_k = 1 / slices_fan_in[slice_idx] ** 0.5
                        weight.data.mul_(sqrt_k)
                    # else:
                    #    sqrt_k = 1.
                    #
                    # if self.rescale:
                    # weight.data.uniform_(-sqrt_k, sqrt_k)
                    #    weight.data.mul_(sqrt_k)
                    # self.slices_sqrt_k[slice_idx] = (self.irreps_out_slices[slice_idx], sqrt_k)

            # Initialize the biases
            # for (out_slice_idx, out_slice, out_bias) in zip(self.bias_slice_idx, self.bias_slices, self.bias):
            #    sqrt_k = 1 / slices_fan_in[out_slice_idx] ** 0.5
            #    out_bias.uniform_(-sqrt_k, sqrt_k)



    @torch.jit.unused
    def _prep_weights_python(
        self, weight: Optional[Union[torch.Tensor, List[torch.Tensor]]]
    ) -> Optional[torch.Tensor]:
        if isinstance(weight, list):
            weight_shapes = [
                ins.path_shape for ins in self.instructions if ins.has_weight
            ]
            if not self.shared_weights:
                weight = [
                    w.reshape(-1, prod(shape))
                    for w, shape in zip(weight, weight_shapes)
                ]
            else:
                weight = [
                    w.reshape(prod(shape)) for w, shape in zip(weight, weight_shapes)
                ]
            return torch.cat(weight, dim=-1)
        else:
            return weight

    def _get_weights(self, weight: Optional[torch.Tensor]) -> torch.Tensor:
        if not torch.jit.is_scripting():
            # If we're not scripting, then we're in Python and `weight` could be a List[Tensor]
            # deal with that:
            weight = self._prep_weights_python(weight)
        if weight is None:
            if self.weight_numel > 0 and not self.internal_weights:
                raise RuntimeError(
                    "Weights must be provided when the TensorProduct does not have `internal_weights`"
                )
            return self.weight
        else:
            if self.shared_weights:
                assert weight.shape == (self.weight_numel,), "Invalid weight shape"
            else:
                assert weight.shape[-1] == self.weight_numel, "Invalid weight shape"
                assert (
                    weight.ndim > 1
                ), "When shared weights is false, weights must have batch dimension"
            return weight

    def weight_views(
        self, weight: Optional[torch.Tensor] = None, yield_instruction: bool = False
    ):
        r"""Iterator over weight views for each weighted instruction.

        Parameters
        ----------
        weight : `torch.Tensor`, optional
            like ``weight`` argument to ``forward()``

        yield_instruction : `bool`, default False
            Whether to also yield the corresponding instruction.

        Yields
        ------
        If ``yield_instruction`` is ``True``, yields ``(instruction_index, instruction, weight_view)``.
        Otherwise, yields ``weight_view``.
        """
        weight = self._get_weights(weight)
        batchshape = weight.shape[:-1]
        offset = 0
        for ins_i, ins in enumerate(self.instructions):
            if ins.has_weight:
                flatsize = prod(ins.path_shape)
                this_weight = weight.narrow(-1, offset, flatsize).view(
                    batchshape + ins.path_shape
                )
                offset += flatsize
                if yield_instruction:
                    yield ins_i, ins, this_weight
                else:
                    yield this_weight

    def forward(self, x, y, weight: Optional[torch.Tensor] = None):
        assert x.shape[-2:].numel() == self._in1_dim, "Incorrect last dimension for x"
        assert y.shape[-2:].numel() == self._in2_dim, "Incorrect last dimension for y"

        if weight is not None and self.weight_numel>0:
            weight = self.weight.reshape(1,self.weight_numel)*weight.reshape(-1,self.weight_numel)
        else:
            weight = self.weight
        # return self._main_left_right(x, y, weight)
        return self._compiled_main_left_right(x, y, weight)




        # def forward(self, x1 : torch.Tensor, x2 : torch.Tensor, weight : torch.Tensor) -> torch.Tensor:

    #     w = weight
    #     empty = torch.empty((), device = 'cpu')
    #     getattr_1 = x1.shape
    #     getitem = getattr_1[slice(None, -1, None)];  getattr_1 = None
    #     expand = empty.expand(getitem);  getitem = None
    #     getattr_2 = x2.shape
    #     getitem_1 = getattr_2[slice(None, -1, None)];  getattr_2 = None
    #     expand_1 = empty.expand(getitem_1);  empty = getitem_1 = None
    #     broadcast_tensors = torch.functional.broadcast_tensors(expand, expand_1);  expand = expand_1 = None
    #     getitem_2 = broadcast_tensors[0];  broadcast_tensors = None
    #     getattr_3 = getitem_2.shape;  getitem_2 = None
    #     add = getattr_3 + (-1,)
    #     broadcast_to = x1.broadcast_to(add);  x1 = add = None
    #     add_1 = getattr_3 + (-1,)
    #     broadcast_to_1 = x2.broadcast_to(add_1);  x2 = add_1 = None
    #     add_2 = getattr_3 + (7168,);  getattr_3 = None
    #     reshape = broadcast_to.reshape(-1, 4096);  broadcast_to = None
    #     reshape_1 = broadcast_to_1.reshape(-1, 3);  broadcast_to_1 = None
    #     getattr_4 = reshape.shape
    #     getitem_3 = getattr_4[0];  getattr_4 = None
    #     reshape_2 = w.reshape(-1, 3072);  w = None
    #     getitem_4 = reshape[(slice(None, None, None), slice(0, 1024, None))]
    #     reshape_3 = getitem_4.reshape(getitem_3, 1024, 1);  getitem_4 = None
    #     getitem_5 = reshape[(slice(None, None, None), slice(1024, 4096, None))];  reshape = None
    #     reshape_4 = getitem_5.reshape(getitem_3, 1024, 3);  getitem_5 = None
    #     reshape_5 = reshape_1.reshape(getitem_3, 1, 3);  reshape_1 = None
    #     getitem_6 = reshape_2[(slice(None, None, None), slice(0, 1024, None))]
    #     reshape_6 = getitem_6.reshape((1024, 1));  getitem_6 = None
    #     einsum = torch.functional.einsum('edb,eca->ecdab', reshape_5, reshape_3)
    #     reshape_7 = reshape_3.reshape(getitem_3, 1024);  reshape_3 = None
    #     einsum_1 = torch.functional.einsum('ca,ab->cab', reshape_7, reshape_6);  reshape_7 = reshape_6 = None
    #     einsum_2 = torch.functional.einsum('dbc,dca->dba', einsum_1, reshape_5);  einsum_1 = None
    #     reshape_8 = einsum_2.reshape(getitem_3, 3072);  einsum_2 = None
    #     getitem_7 = reshape_2[(slice(None, None, None), slice(1024, 2048, None))]
    #     reshape_9 = getitem_7.reshape((1024, 1));  getitem_7 = None
    #     mul = reshape_5 * 0.5773502691896258
    #     einsum_3 = torch.functional.einsum('dca,dba->dcb', mul, reshape_4);  mul = None
    #     einsum_4 = torch.functional.einsum('cba,ab->ca', einsum_3, reshape_9);  einsum_3 = reshape_9 = None
    #     reshape_10 = einsum_4.reshape(getitem_3, 1024);  einsum_4 = None
    #     getitem_8 = reshape_2[(slice(None, None, None), slice(2048, 3072, None))];  reshape_2 = None
    #     reshape_11 = getitem_8.reshape((1024, 1));  getitem_8 = None
    #     _w3j_1_1_1 = o3.wigner_3j(1,1,1).to(reshape_11.device)
    #     einsum_5 = torch.functional.einsum('dba,bc->dbac', reshape_4, reshape_11);  reshape_4 = reshape_11 = None
    #     mul_5 = reshape_5 * 1.7320508075688772;  reshape_5 = None
    #     tensordot = torch.functional.tensordot(mul_5, _w3j_1_1_1, dims = ((2,), (1,)), out = None);  mul_5 = _w3j_1_1_1 = None
    #     einsum_6 = torch.functional.einsum('edab,ecad->ecb', tensordot, einsum_5);  tensordot = einsum_5 = None
    #     reshape_12 = einsum_6.reshape(getitem_3, 3072);  einsum_6 = getitem_3 = None
    #     cat = torch.cat([reshape_10, reshape_8, reshape_12], dim = 1);  reshape_10 = reshape_8 = reshape_12 = None
    #     reshape_13 = cat.reshape(add_2);  cat = add_2 = None
    #     return reshape_13


# @compile_mode("script")
# class FeedForwardVec2Scalar(torch.nn.Module):
#     """
#     Use two (FCTP + Gate)
#     """

#     def __init__(
#         self,
#         irreps_node_input,
#         irreps_node_output,
#         proj_drop=0.1,
#     ):
#         super().__init__()
#         self.irreps_node_input = (
#             o3.Irreps(irreps_node_input)
#             if isinstance(irreps_node_input, str)
#             else irreps_node_input
#         )

#         self.irreps_node_output = (
#             o3.Irreps(irreps_node_output)
#             if isinstance(irreps_node_output, str)
#             else irreps_node_output
#         )

#         self.scalar_dim = self.irreps_node_input[0][0]  # l=0 scalar_dim

#         self.ir2scalar = Irreps2Scalar(
#             self.irreps_node_input[1:],
#             self.scalar_dim,
#             bias=True,
#             act="smoothleakyrelu",
#         )

#         self.slinear_1 = IrrepsLinear(
#             self.irreps_node_input, self.irreps_node_input, bias=True, act=None
#         )

#         self.slinear_2 = IrrepsLinear(
#             self.irreps_node_input, self.irreps_node_output, bias=True, act=None
#         )

#         self.scalar_linear = nn.Linear(self.scalar_dim, self.scalar_dim * 2)
#         nn.Sigmoid()

#         if proj_drop != 0.0:
#             self.proj_drop = EquivariantDropout(
#                 self.irreps_node_output, drop_prob=proj_drop
#             )

#     def forward(self, node_input, **kwargs):
#         """
#         irreps_in = o3.Irreps("128x0e+32x1e")
#         func =  FeedForwardNetwork(
#                 irreps_in,
#                 irreps_in,
#                 proj_drop=0.1,
#             )
#         out = func(irreps_in.randn(10,20,-1))
#         """
#         node_output = self.slinear_1(node_input)
#         scalar = node_output[..., : self.scalar_dim]
#         vec = node_output[..., self.scalar_dim :]
#         scalar1, scalar2 = torch.split(
#             self.scalar_linear(scalar), self.scalar_dim, dim=-1
#         )
#         scalar = scalar1 + self.ir2scalar(vec) * scalar2  # vec 2 scalar
#         node_output = torch.cat([scalar, vec], dim=-1)

#         node_output = self.slinear_2(node_output)

#         return node_output
