import torch
import e3nn
import math
from torch import nn
from e3nn import o3
from e3nn.util.jit import compile_mode
import warnings

from sfm.models.psm.equivariant.tensor_product import Simple_TensorProduct
from sfm.models.psm.equivariant.layer_norm import (  # ,\; EquivariantInstanceNorm,EquivariantGraphNorm; EquivariantRMSNormArraySphericalHarmonicsV2,
    get_norm_layer,
)
from sfm.models.psm.equivariant.equiformer_v2.so3 import SO3_Linear_e2former
from sfm.models.psm.equivariant.equiformer_v2.equiformer_v2_oc20 import SO3_Grid

from sfm.models.psm.equivariant.equiformer.graph_attention_transformer import (
    irreps2gate,
    sort_irreps_even_first,
)

from fairchem.core.models.equiformer_v2.activation import GateActivation
from fairchem.core.models.equiformer_v2.so3 import SO3_LinearV2

from fairchem.core.models.escn.so3 import SO3_Embedding

def drop_path_BL(x, drop_prob: float = 0.0, training: bool = False):
    """Drop paths (Stochastic Depth) independently for every ``(dim0, dim1)`` slice.

    Variant of the classic per-sample drop path for tensors laid out as
    ``[B, L, ...]``: the Bernoulli keep-mask has shape
    ``(x.shape[0], x.shape[1], 1, ..., 1)``, so each ``x[b, l]`` is either kept
    (and rescaled by ``1 / keep_prob``) or zeroed as a whole.
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Args:
        x: input tensor with at least two dimensions.
        drop_prob: probability of dropping each ``(b, l)`` slice; ``None`` or
            ``0`` disables dropping.
        training: when False, ``x`` is returned unchanged.

    Returns:
        ``x`` itself when nothing can be dropped, otherwise a new tensor with
        the same shape as ``x``.
    """
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    # The mask varies over the first two dims and broadcasts over the rest.
    shape = (x.shape[0], x.shape[1]) + (1,) * (x.ndim - 2)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    if keep_prob > 0.0:
        return x.div(keep_prob) * random_tensor
    # drop_prob == 1: everything is dropped. Skip the 1 / keep_prob rescale,
    # which would otherwise produce inf * 0 = NaN.
    return x * random_tensor

def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    One Bernoulli keep/drop decision is drawn per entry of ``x``'s first
    dimension and broadcast over all remaining dimensions; kept samples are
    rescaled by ``1 / keep_prob`` so the expectation is unchanged.
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Args:
        x: input tensor with at least one dimension.
        drop_prob: probability of dropping a sample; ``None`` or ``0`` disables
            dropping.
        training: when False, ``x`` is returned unchanged.

    Returns:
        ``x`` itself when nothing can be dropped, otherwise a new tensor with
        the same shape as ``x``.
    """
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    # Work with tensors of any rank: the mask broadcasts over all non-leading dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    if keep_prob > 0.0:
        return x.div(keep_prob) * random_tensor
    # drop_prob == 1: everything is dropped. Skip the 1 / keep_prob rescale,
    # which would otherwise produce inf * 0 = NaN.
    return x * random_tensor


class DropPath_BL(nn.Module):
    """Stochastic depth with one keep/drop decision per sample of a packed batch.

    ``x`` stacks rows from several samples along its first dimension, and
    ``batch[i]`` is the sample index of row ``i``. A mask with one entry per
    sample is drawn and gathered with ``batch``, so all rows of a sample are
    either kept (rescaled by ``1 / (1 - drop_prob)``) or zeroed together.
    """

    def __init__(self, drop_prob=None):
        super(DropPath_BL, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x, batch):
        """Apply per-sample drop path.

        Args:
            x: tensor of any rank >= 1 whose first dimension is indexed by ``batch``.
            batch: integer tensor holding the sample index of each row of ``x``.

        Returns:
            Tensor with the same shape as ``x`` (``x`` itself when nothing is dropped).
        """
        # ``None`` (the constructor default) and 0 mean "no drop path", and nothing
        # is dropped outside training: skip the mask (and the device sync done by
        # ``batch.max()``) in those cases.
        if not self.drop_prob or not self.training:
            return x
        batch_size = int(batch.max()) + 1
        # One mask entry per sample, broadcast over every non-leading dim of x.
        # (For any rank this is exactly the per-sample mask drop_path builds.)
        ones = torch.ones(
            (batch_size,) + (1,) * (x.ndim - 1), dtype=x.dtype, device=x.device
        )
        drop = drop_path(ones, self.drop_prob, self.training)
        return x * drop[batch]

    def extra_repr(self):
        return "drop_prob={}".format(self.drop_prob)


class RadialProfile(nn.Module):
    """Fully connected network applied along the last dimension of its input.

    ``ch_list`` gives the layer widths: ``ch_list[0]`` is the input size and
    ``ch_list[-1]`` the output size. Each consecutive pair of widths becomes an
    ``nn.Linear``; every Linear except the last is followed by an optional
    ``nn.LayerNorm`` and a ``SiLU`` activation.

    Args:
        ch_list: sequence of layer widths.
        use_layer_norm: insert ``nn.LayerNorm`` after each hidden Linear.
        use_offset: whether the Linear layers carry a bias term.
    """

    def __init__(self, ch_list, use_layer_norm=True, use_offset=True):
        super().__init__()
        layers = []
        fan_in = ch_list[0]
        last_layer = len(ch_list) - 1
        for layer_no, fan_out in enumerate(ch_list[1:], start=1):
            layers.append(nn.Linear(fan_in, fan_out, bias=use_offset))
            fan_in = fan_out
            if layer_no == last_layer:
                break  # output layer: no normalization / activation
            if use_layer_norm:
                layers.append(nn.LayerNorm(fan_out))
            layers.append(torch.nn.SiLU())
        self.net = nn.Sequential(*layers)

    def forward(self, f_in):
        """Run ``f_in`` (last dimension of size ``ch_list[0]``) through the MLP."""
        return self.net(f_in)

class SmoothLeakyReLU(torch.nn.Module):
    """Elementwise ``(1 - a) * x * sigmoid(x) + a * x`` with ``a = negative_slope``.

    ``a = 0`` gives SiLU and ``a = 1`` gives the identity; works on tensors of
    any shape.
    """

    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.alpha = negative_slope

    def forward(self, x):
        slope = self.alpha
        gated = (1 - slope) * x * torch.sigmoid(x)
        return gated + slope * x

    def extra_repr(self):
        return f"negative_slope={self.alpha}"

# class SmoothLeakyReLU(torch.nn.Module):
#     def __init__(self, negative_slope=0.2):
#         super().__init__()
#         self.alpha = 0.3 #negative_slope
#         self.func = nn.SiLU()
#     def forward(self, x):
#         ## x could be any dimension.
#         return self.func(x)
#         # return (1-self.alpha) * x * torch.sigmoid(x) + self.alpha * x

#     def extra_repr(self):
#         return "negative_slope={}".format(self.alpha)



class SO3_Linear2Scalar_e2former(torch.nn.Module):
    def __init__(self, in_features, out_features, lmax, bias=True):
        """
        Map features ``[..., (lmax + 1) ** 2, in_features]`` to ``[..., out_features]``.

        Two degree-wise linear maps (``weight`` and ``weight2``: one
        ``[out_features // 2, in_features]`` matrix per degree l, shared by the
        2l + 1 rows of that degree) project the input. For each degree l, the two
        projections are multiplied elementwise and summed over the degree's 2l + 1
        rows, giving ``out_features // 2`` values per degree. The ``lmax + 1``
        groups are concatenated and passed through ``final_linear``
        (Linear -> LayerNorm -> SiLU -> Linear).

        Args:
            in_features: channel count of the input.
            out_features: channel count of the output.
            lmax: maximum degree; the input must have ``(lmax + 1) ** 2`` rows.
            bias: accepted for interface compatibility but not used; a bias on the
                l = 0 row of the first projection is always created.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.lmax = lmax
        half_out = out_features // 2
        bound = 1 / math.sqrt(self.in_features)

        self.weight = torch.nn.Parameter(
            torch.randn((self.lmax + 1), half_out, in_features)
        )
        torch.nn.init.uniform_(self.weight, -bound, bound)
        # Bias added to the l = 0 row of the first projection. Registered right
        # after ``weight`` so the parameter order (and hence optimizer-state
        # layout) matches earlier versions of this module.
        self.bias = torch.nn.Parameter(torch.zeros(1, 1, half_out))

        self.weight2 = torch.nn.Parameter(
            torch.randn((self.lmax + 1), half_out, in_features)
        )
        torch.nn.init.uniform_(self.weight2, -bound, bound)

        # expand_index[row] = degree l of that (l, m) row, used to expand the
        # per-degree weights to one matrix per row.
        expand_index = torch.zeros([(lmax + 1) ** 2]).long()
        for l in range(lmax + 1):
            start_idx = l**2
            length = 2 * l + 1
            expand_index[start_idx : (start_idx + length)] = l
        self.register_buffer("expand_index", expand_index)

        self.final_linear = nn.Sequential(
            nn.Linear(half_out * (lmax + 1), out_features),
            nn.LayerNorm(out_features),
            nn.SiLU(),
            nn.Linear(out_features, out_features),
        )

    def forward(self, input_embedding):
        """
        Args:
            input_embedding: tensor ``[..., (lmax + 1) ** 2, in_features]``.

        Returns:
            Tensor ``[..., out_features]``.
        """
        batch_shape = input_embedding.shape[:-2]
        num_rows, hidden = input_embedding.shape[-2:]
        x = input_embedding.reshape([batch_shape.numel()] + [num_rows, hidden])

        weight = torch.index_select(
            self.weight, dim=0, index=self.expand_index
        )  # [(L_max + 1) ** 2, C_out // 2, C_in]
        weight2 = torch.index_select(
            self.weight2, dim=0, index=self.expand_index
        )  # [(L_max + 1) ** 2, C_out // 2, C_in]
        out = torch.einsum("bmi, moi -> bmo", x, weight)  # [N, (L_max + 1) ** 2, C_out // 2]
        out2 = torch.einsum("bmi, moi -> bmo", x, weight2)  # [N, (L_max + 1) ** 2, C_out // 2]
        # Bias only on the l = 0 row.
        out[:, 0:1, :] = out.narrow(1, 0, 1) + self.bias

        # Per degree: elementwise product summed over the 2l + 1 rows -> [N, C_out // 2].
        per_degree = [
            torch.sum(out[:, l**2 : (l + 1) ** 2] * out2[:, l**2 : (l + 1) ** 2], dim=1)
            for l in range(self.lmax + 1)
        ]
        scalars = self.final_linear(torch.cat(per_degree, dim=-1))
        return scalars.reshape(batch_shape + (self.out_features,))

class Irreps2Scalar(torch.nn.Module):
    def __init__(
        self,
        irreps_in,
        out_dim,
        hidden_dim=None,
        bias=True,
        act="smoothleakyrelu",
        rescale=True,
    ):
        """
        Map irreps features to scalars: ``[..., irreps_in.dim] -> [..., out_dim]``.

        * Every l = 0 irrep gets its own ``nn.Linear(mul, hidden_dim)`` (with bias).
        * Every l > 0 irrep ``mul x l`` gets a bias-free
          ``nn.Linear(mul, 2 * hidden_dim)`` applied along the multiplicity axis;
          the two ``hidden_dim`` halves are multiplied elementwise and summed
          over the ``2l + 1`` components.
        * All contributions are summed and fed to ``output_mlp``: an optional
          SmoothLeakyReLU, then ``nn.Linear(hidden_dim, out_dim)``.

        Args:
            irreps_in: ``o3.Irreps`` (or its string form) describing the input layout.
            out_dim: size of the output features.
            hidden_dim: width of the summed features; defaults to the multiplicity
                of the first irrep in ``irreps_in``.
            bias: stored on the module but not used by it.
            act: ``"smoothleakyrelu"`` applies SmoothLeakyReLU(0.2) before the
                output Linear; any other value applies no activation.
            rescale: stored on the module but not used by it.
        """
        super().__init__()
        self.irreps_in = (
            o3.Irreps(irreps_in) if isinstance(irreps_in, str) else irreps_in
        )
        if hidden_dim is not None:
            self.hidden_dim = hidden_dim
        else:
            self.hidden_dim = self.irreps_in[0][0]  # multiplicity of the first irrep
        self.out_dim = out_dim
        self.act = act
        self.bias = bias
        self.rescale = rescale

        self.vec_proj_list = nn.ModuleList()
        self.lirreps = len(self.irreps_in)
        self.output_mlp = nn.Sequential(
            SmoothLeakyReLU(0.2) if self.act == "smoothleakyrelu" else nn.Identity(),
            nn.Linear(self.hidden_dim, out_dim),  # NOTICE init
        )

        for idx in range(len(self.irreps_in)):
            l = self.irreps_in[idx][1].l
            in_feature = self.irreps_in[idx][0]
            if l == 0:
                vec_proj = nn.Linear(in_feature, self.hidden_dim)
                nn.init.xavier_uniform_(vec_proj.weight)
                vec_proj.bias.data.fill_(0)
            else:
                # Produces both factors of the per-degree inner product at once.
                vec_proj = nn.Linear(in_feature, 2 * (self.hidden_dim), bias=False)
                nn.init.xavier_uniform_(vec_proj.weight)
            self.vec_proj_list.append(vec_proj)

    def forward(self, input_embedding):
        """
        Args:
            input_embedding: tensor ``[..., irreps_in.dim]`` laid out as ``irreps_in``.

        Returns:
            Tensor ``[..., out_dim]``.

        Example::

            from e3nn import o3
            irreps_in = o3.Irreps("100x1e+40x2e+10x3e")
            irreps2scalar = Irreps2Scalar(irreps_in, 128)
            node_embed = irreps_in.randn(200, 30, 5, -1)
            out_scalar = irreps2scalar(node_embed)  # [200, 30, 5, 128]
        """
        shape = list(input_embedding.shape[:-1])
        num = input_embedding.shape[:-1].numel()
        input_embedding = input_embedding.reshape(num, -1)

        start_idx = 0
        scalars = None
        for idx, (mul, ir) in enumerate(self.irreps_in):
            width = mul * (2 * ir.l + 1)
            segment = input_embedding[:, start_idx : start_idx + width]
            start_idx += width
            vec_proj = self.vec_proj_list[idx]
            if ir.l == 0:
                # Scalar irreps (at any position) are projected directly: their
                # Linear is hidden_dim wide and cannot be split into two halves.
                contribution = vec_proj(segment)  # [B, H]
            else:
                vec = segment.reshape(-1, mul, 2 * ir.l + 1).permute(0, 2, 1)  # [B, 2l+1, mul]
                vec1, vec2 = torch.split(vec_proj(vec), self.hidden_dim, dim=-1)  # each [B, 2l+1, H]
                contribution = (vec1 * vec2).sum(dim=1)  # [B, H]
            # TODO: concat instead of sum
            scalars = contribution if scalars is None else scalars + contribution

        if scalars is None:
            raise ValueError("Irreps2Scalar: irreps_in contains no irreps")
        output_embedding = self.output_mlp(scalars)
        return output_embedding.reshape(shape + [self.out_dim])

    def __repr__(self):
        return f"{self.__class__.__name__}(in_features={self.irreps_in}, out_features={self.out_dim})"


# class IrrepsLinear(torch.nn.Module):
#     def __init__(
#         self,
#         irreps_in,
#         irreps_out,
#         hidden_dim=None,
#         bias=True,
#         act="smoothleakyrelu",
#         rescale=_RESCALE,
#     ):
#         """
#         1. from irreps to scalar output: [...,irreps] - > [...,out_dim]
#         2. bias is used for l=0
#         3. act is used for l=0
#         4. rescale is default, e.g. irreps is c0*l0+c1*l1+c2*l2+c3*l3, rescale weight is 1/c0**0.5 1/c1**0.5 ...
#         """
#         super().__init__()
#         self.irreps_in = o3.Irreps(irreps_in) if isinstance(irreps_in,str) else irreps_in
#         self.irreps_out = o3.Irreps(irreps_out) if isinstance(irreps_out,str) else irreps_out

#         self.irreps_in_len = sum([mul*(ir.l*2+1) for mul, ir in self.irreps_in])
#         self.irreps_out_len = sum([mul*(ir.l*2+1) for mul, ir in self.irreps_out])
#         if hidden_dim is not None:
#             self.hidden_dim = hidden_dim
#         else:
#             self.hidden_dim = self.irreps_in[0][0]  # l=0 scalar_dim
#         self.act = act
#         self.bias = bias
#         self.rescale = rescale

#         self.vec_proj_list = nn.ModuleList()
#         # self.irreps_in_len = sum([mul*(ir.l*2+1) for mul, ir in self.irreps_in])
#         # self.scalar_in_len = sum([mul for mul, ir in self.irreps_in])
#         self.output_mlp = nn.Sequential(
#             SmoothLeakyReLU(0.2) if self.act == "smoothleakyrelu" else nn.Identity(),
#             nn.Linear(self.hidden_dim, self.irreps_out[0][0]),
#         )
#         self.weight_list = nn.ParameterList()
#         for idx in range(len(self.irreps_in)):
#             l = self.irreps_in[idx][1].l
#             in_feature = self.irreps_in[idx][0]
#             if l == 0:
#                 vec_proj = nn.Linear(in_feature, self.hidden_dim)
#                 nn.init.xavier_uniform_(vec_proj.weight)
#                 vec_proj.bias.data.fill_(0)
#             else:
#                 vec_proj = nn.Linear(in_feature, 2 * self.hidden_dim, bias=False)
#                 nn.init.xavier_uniform_(vec_proj.weight)

#                 # weight for l>0
#                 out_feature = self.irreps_out[idx][0]
#                 weight = torch.nn.Parameter(
#                                 torch.randn( out_feature,in_feature)
#                             )
#                 bound = 1 / math.sqrt(in_feature) if self.rescale else 1
#                 torch.nn.init.uniform_(weight, -bound, bound)
#                 self.weight_list.append(weight)

#             self.vec_proj_list.append(vec_proj)


#     def forward(self, input_embedding):
#         """
#         from e3nn import o3
#         irreps_in = o3.Irreps("100x1e+40x2e+10x3e")
#         irreps_out = o3.Irreps("20x1e+20x2e+20x3e")
#         irrepslinear = IrrepsLinear(irreps_in, irreps_out)
#         irreps2scalar = Irreps2Scalar(irreps_in, 128)
#         node_embed = irreps_in.randn(200,30,5,-1)
#         out_scalar = irreps2scalar(node_embed)
#         out_irreps = irrepslinear(node_embed)
#         """

#         # if input_embedding.shape[-1]!=self.irreps_in_len:
#         #     raise ValueError("input_embedding should have same length as irreps_in_len")

#         shape = list(input_embedding.shape[:-1])
#         num = input_embedding.shape[:-1].numel()
#         input_embedding = input_embedding.reshape(num, -1)

#         start_idx = 0
#         scalars = self.vec_proj_list[0](input_embedding[..., : self.irreps_in[0][0]])
#         output_embedding = []
#         for idx, (mul, ir) in enumerate(self.irreps_in):
#             if idx == 0:
#                 start_idx += mul * (2 * ir.l + 1)
#                 continue
#             vec_proj = self.vec_proj_list[idx]
#             vec = (
#                 input_embedding[:, start_idx : start_idx + mul * (2 * ir.l + 1)]
#                 .reshape(-1, mul, (2 * ir.l + 1))
#             )  # [B, D, 2l+1]
#             vec1, vec2 = torch.split(
#                 vec_proj(vec.permute(0, 2, 1)), self.hidden_dim, dim=-1
#             )  # [B, 2l+1, D]
#             vec_dot = (vec1 * vec2).sum(dim=1)  # [B, 2l+1, D]

#             scalars = scalars + vec_dot # TODO: concat

#             # linear for l>0
#             weight = self.weight_list[idx-1]
#             out = torch.matmul(weight,vec).reshape(num,-1) # [B*L, -1]
#             output_embedding.append(out)

#             start_idx += mul * (2 * ir.l + 1)
#         try:
#             scalars = self.output_mlp(scalars)
#         except:
#             raise ValueError(f"scalars shape: {scalars.shape}")
#         output_embedding.insert(0, scalars)
#         output_embedding = torch.cat(output_embedding, dim=1)
#         output_embedding = output_embedding.reshape(shape + [self.irreps_out_len])
#         return output_embedding

#     def __repr__(self):
#         return f"{self.__class__.__name__}(in_features={self.irreps_in}, out_features={self.irreps_out}"


class IrrepsLinear(torch.nn.Module):
    def __init__(
        self, irreps_in, irreps_out, bias=True, act="smoothleakyrelu", rescale=True
    ):
        """
        Degree-wise linear map ``[..., irreps_in] -> [..., irreps_out]``.

        For every (input irrep, output irrep) pair with the same irrep (degree and
        parity), a ``[mul_out, mul_in]`` weight mixes multiplicities and is shared
        by all ``2l + 1`` components.

        Args:
            irreps_in: ``o3.Irreps`` (or string) of the input.
            irreps_out: ``o3.Irreps`` (or string); each of its irreps must occur in
                ``irreps_in``.
            bias: learn a bias for l = 0 outputs. Biases of l > 0 outputs are
                kept at zero and frozen, since a constant added to an l > 0
                feature would break rotation equivariance.
            act: ``"smoothleakyrelu"`` applies SmoothLeakyReLU to l = 0 outputs;
                any other value applies no activation.
            rescale: initialise weights from ``U(-1/sqrt(mul_in), 1/sqrt(mul_in))``
                instead of ``U(-1, 1)``.

        NOTE: output blocks are emitted in ``irreps_in`` order, so the result
        matches ``irreps_out``'s layout only when both list their irreps in
        the same order.
        """
        super().__init__()
        self.irreps_in = (
            o3.Irreps(irreps_in) if isinstance(irreps_in, str) else irreps_in
        )
        self.irreps_out = (
            o3.Irreps(irreps_out) if isinstance(irreps_out, str) else irreps_out
        )

        self.act = act
        self.bias = bias
        self.rescale = rescale

        for _, ir_out in self.irreps_out:
            if ir_out not in self.irreps_in:
                raise ValueError(
                    f"Error: each irrep of irreps_out {self.irreps_out} should be in irreps_in {self.irreps_in}. Please check your input and output "
                )

        self.weight_list = nn.ParameterList()
        self.bias_list = nn.ParameterList()
        self.act_list = nn.ModuleList()
        self.irreps_in_len = sum([mul * (ir.l * 2 + 1) for mul, ir in self.irreps_in])
        self.irreps_out_len = sum([mul * (ir.l * 2 + 1) for mul, ir in self.irreps_out])
        # Each instruction: [input index, mul_in, l, start, end] of the input slice.
        self.instructions = []
        start_idx = 0
        for idx1, (mul, ir_in) in enumerate(self.irreps_in):
            l = ir_in.l
            end_idx = start_idx + (l * 2 + 1) * mul
            for out_feature, ir_out in self.irreps_out:
                # Match the full irrep (degree AND parity); matching on the
                # degree alone produced extra blocks and a wrong output length.
                if ir_in != ir_out:
                    continue
                self.instructions.append([idx1, mul, l, start_idx, end_idx])

                weight = torch.nn.Parameter(torch.randn(out_feature, mul))
                bound = 1 / math.sqrt(mul) if self.rescale else 1
                torch.nn.init.uniform_(weight, -bound, bound)
                self.weight_list.append(weight)

                is_scalar = l == 0
                learn_bias = bool(self.bias) and is_scalar
                # Every instruction keeps a bias entry (stable state_dict keys),
                # but only scalar biases are trainable.
                bias_param = torch.nn.Parameter(
                    torch.randn(1, out_feature, 1)
                    if learn_bias
                    else torch.zeros(1, out_feature, 1),
                    requires_grad=learn_bias,
                )
                self.bias_list.append(bias_param)

                activation = (
                    nn.Sequential(SmoothLeakyReLU())
                    if self.act == "smoothleakyrelu" and is_scalar
                    else nn.Sequential()
                )
                self.act_list.append(activation)

            start_idx = end_idx

    def forward(self, input_embedding):
        """
        Args:
            input_embedding: tensor ``[..., irreps_in.dim]``.

        Returns:
            Tensor ``[..., irreps_out.dim]``.

        Raises:
            ValueError: if the last dimension differs from ``irreps_in.dim``.

        Example::

            from e3nn import o3
            irreps_in = o3.Irreps("100x1e+40x2e+10x3e")
            irreps_out = o3.Irreps("20x1e+20x2e+20x3e")
            irrepslinear = IrrepsLinear(irreps_in, irreps_out)
            out_irreps = irrepslinear(irreps_in.randn(200, 30, 5, -1))
        """
        if input_embedding.shape[-1] != self.irreps_in_len:
            raise ValueError("input_embedding should have same length as irreps_in_len")

        shape = list(input_embedding.shape[:-1])
        num = input_embedding.shape[:-1].numel()
        input_embedding = input_embedding.reshape(num, -1)

        output_embedding = []
        for (_, mul, l, start, end), weight, bias, activation in zip(
            self.instructions, self.weight_list, self.bias_list, self.act_list
        ):
            block = input_embedding[:, start:end].reshape(-1, mul, 2 * l + 1)
            # [mul_out, mul_in] @ [B, mul_in, 2l+1] -> [B, mul_out, 2l+1]
            out = torch.matmul(weight, block) + bias
            output_embedding.append(activation(out).reshape(num, -1))

        output_embedding = torch.cat(output_embedding, dim=1)
        return output_embedding.reshape(shape + [self.irreps_out_len])


@compile_mode("script")
class Vec2AttnHeads(torch.nn.Module):
    """
    Split features of shape [..., irreps_head * num_heads] into heads,
    producing [..., num_heads, irreps_head.dim].

    The input is read irrep block by irrep block (each block holds
    ``mul * num_heads`` copies of one irrep). Every block is cut into
    ``num_heads`` equal contiguous chunks, and the per-head chunks of all
    blocks are concatenated along the last axis.
    """

    def __init__(self, irreps_head, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.irreps_head = irreps_head
        self.irreps_mid_in = o3.Irreps(
            [(mul * num_heads, ir) for mul, ir in irreps_head]
        )
        # (start, stop) column range of every irrep block in the flat input.
        self.mid_in_indices = []
        offset = 0
        for mul, ir in self.irreps_mid_in:
            block_dim = mul * ir.dim
            self.mid_in_indices.append((offset, offset + block_dim))
            offset = offset + block_dim

    def forward(self, x):
        lead_shape = list(x.shape[:-1])
        rows = x.shape[:-1].numel()
        flat = x.reshape(rows, -1)

        per_head = []
        for start, stop in self.mid_in_indices:
            block = flat.narrow(1, start, stop - start)
            per_head.append(block.reshape(rows, self.num_heads, -1))
        merged = torch.cat(per_head, dim=2)
        return merged.reshape(lead_shape + [self.num_heads, -1])

    def __repr__(self):
        return "{}(irreps_head={}, num_heads={})".format(
            type(self).__name__, self.irreps_head, self.num_heads
        )


@compile_mode("script")
class AttnHeads2Vec(torch.nn.Module):
    """
    Merge heads: [..., num_heads, irreps_head.dim] -> [..., irreps_head.dim * num_heads].

    For each irrep block of ``irreps_head``, the block's columns from all heads
    are flattened together (head-major), and the blocks are concatenated in
    ``irreps_head`` order. ``num_heads`` is stored only for reference; the head
    count is read from the input's second-to-last dimension.
    """

    def __init__(self, irreps_head, num_heads=-1):
        super().__init__()
        self.irreps_head = irreps_head
        self.num_heads = num_heads
        # (start, stop) column range of every irrep block within one head.
        self.head_indices = []
        offset = 0
        for mul, ir in self.irreps_head:
            width = mul * ir.dim
            self.head_indices.append((offset, offset + width))
            offset = offset + width

    def forward(self, x):
        n_heads = x.shape[-2]
        lead_shape = list(x.shape[:-2])
        rows = x.shape[:-2].numel()
        grouped = x.reshape(rows, n_heads, -1)

        blocks = []
        for start, stop in self.head_indices:
            block = grouped.narrow(2, start, stop - start)
            blocks.append(block.reshape(rows, -1))
        merged = torch.cat(blocks, dim=1)
        return merged.reshape(lead_shape + [-1])

    def __repr__(self):
        return "{}(irreps_head={})".format(type(self).__name__, self.irreps_head)


# class EquivariantDropout(nn.Module):
#     def __init__(self, irreps, drop_prob):
#         """
#         equivariant for irreps: [..., irreps]
#         """

#         super(EquivariantDropout, self).__init__()
#         self.irreps = irreps
#         self.num_irreps = irreps.num_irreps
#         self.drop_prob = drop_prob
#         self.drop = torch.nn.Dropout(drop_prob, True)
#         self.mul = o3.ElementwiseTensorProduct(
#             irreps, o3.Irreps("{}x0e".format(self.num_irreps))
#         )

#     def forward(self, x):
#         """
#         x: [..., irreps]

#         t1 = o3.Irreps("5x0e+4x1e+3x2e")
#         func = EquivariantDropout(t1, 0.5)
#         out = func(t1.randn(2,3,-1))
#         """
#         if not self.training or self.drop_prob == 0.0:
#             return x

#         shape = x.shape
#         N = x.shape[:-1].numel()
#         x = x.reshape(N, -1)
#         mask = torch.ones((N, self.num_irreps), dtype=x.dtype, device=x.device)
#         mask = self.drop(mask)

#         out = self.mul(x, mask)

#         return out.reshape(shape)

class EquivariantDropout(nn.Module):
    def __init__(self, dim, lmax, drop_prob):
        """
        Channel dropout for features laid out as ``[..., (lmax + 1) ** 2, dim]``.

        One dropout mask entry is drawn per (leading index, degree l, channel)
        and shared by all 2l + 1 rows of that degree. Each degree-l block of a
        channel is therefore kept (scaled by ``1 / (1 - drop_prob)``) or zeroed
        as a whole.

        Args:
            dim: channel count (last dimension of the input).
            lmax: maximum degree; the input has ``(lmax + 1) ** 2`` rows.
            drop_prob: dropout probability.
        """
        super(EquivariantDropout, self).__init__()
        self.lmax = lmax
        self.scalar_dim = dim
        self.drop_prob = drop_prob
        # inplace=True is safe: it is only ever applied to a freshly built mask.
        self.drop = torch.nn.Dropout(drop_prob, True)

    def forward(self, x):
        """
        Args:
            x: tensor ``[..., (lmax + 1) ** 2, dim]``.

        Returns:
            Tensor of the same shape; ``x`` itself when not training or when
            ``drop_prob == 0``.

        Example::

            func = EquivariantDropout(dim=16, lmax=2, drop_prob=0.5)
            out = func(torch.randn(2, 3, 9, 16))
        """
        if not self.training or self.drop_prob == 0.0:
            return x
        shape = x.shape
        N = x.shape[:-2].numel()
        x = x.reshape(N, (self.lmax + 1) ** 2, -1)

        mask = torch.ones(
            (N, self.lmax + 1, self.scalar_dim), dtype=x.dtype, device=x.device
        )
        mask = self.drop(mask)
        # Broadcast the degree-l mask row over the 2l + 1 rows of that degree.
        out = torch.cat(
            [
                x[:, l**2 : (l + 1) ** 2] * mask[:, l : l + 1]
                for l in range(self.lmax + 1)
            ],
            dim=1,
        )
        return out.reshape(shape)

class TensorProductRescale(torch.nn.Module):
    """Tensor product with fan-in based weight rescaling and an optional bias
    on every ``0e`` block of the output.

    ``mode="simple"`` wraps ``Simple_TensorProduct``; any other mode wraps
    ``e3nn.o3.TensorProduct`` with ``path_normalization="none"``.
    """

    def __init__(
        self,
        irreps_in1,
        irreps_in2,
        irreps_out,
        instructions,
        bias=True,
        rescale=True,
        internal_weights=None,
        shared_weights=None,
        normalization=None,
        mode="default",
    ):
        """
        Args:
            irreps_in1: irreps of the first input.
            irreps_in2: irreps of the second input.
            irreps_out: ``o3.Irreps`` of the output (``slices()`` and
                ``simplify()`` are called on it).
            instructions: tensor-product instructions forwarded to the wrapped TP.
            bias: add a learnable, zero-initialised bias to every ``0e`` block
                of the simplified output irreps.
            rescale: multiply internal weights by ``1 / sqrt(fan_in)`` of the
                output slice they write to.
            internal_weights, shared_weights, normalization: forwarded to
                ``o3.TensorProduct``; not passed on when ``mode == "simple"``.
            mode: ``"simple"`` selects ``Simple_TensorProduct``.
        """
        super().__init__()

        self.irreps_in1 = irreps_in1
        self.irreps_in2 = irreps_in2
        self.irreps_out = irreps_out
        self.rescale = rescale
        self.use_bias = bias

        # e3nn.__version__ == 0.4.4
        # Use `path_normalization` == 'none' to remove normalization factor
        if mode == "simple":
            # NOTE(review): `rescale` is also forwarded to Simple_TensorProduct;
            # confirm it is not applied twice together with init_rescale_bias.
            self.tp = Simple_TensorProduct(
                irreps_in1=self.irreps_in1,
                irreps_in2=self.irreps_in2,
                irreps_out=self.irreps_out,
                instructions=instructions,
                rescale=rescale,
            )
        else:
            self.tp = o3.TensorProduct(
                irreps_in1=self.irreps_in1,
                irreps_in2=self.irreps_in2,
                irreps_out=self.irreps_out,
                instructions=instructions,
                normalization=normalization,
                internal_weights=internal_weights,
                shared_weights=shared_weights,
                path_normalization="none",
            )

        self.init_rescale_bias()

    def calculate_fan_in(self, ins):
        """Return the fan-in of instruction ``ins`` for its connection mode."""
        return {
            "uvw": (self.irreps_in1[ins.i_in1].mul * self.irreps_in2[ins.i_in2].mul),
            "uvu": self.irreps_in2[ins.i_in2].mul,
            "uvv": self.irreps_in1[ins.i_in1].mul,
            "uuw": self.irreps_in1[ins.i_in1].mul,
            "uuu": 1,
            "uvuv": 1,
            "uvu<v": 1,
            "u<vw": self.irreps_in1[ins.i_in1].mul
            * (self.irreps_in2[ins.i_in2].mul - 1)
            // 2,
        }[ins.connection_mode]

    def init_rescale_bias(self) -> None:
        """Create the ``0e`` biases and rescale the wrapped TP's internal weights."""
        irreps_out = self.irreps_out
        # Degree and multiplicity of every output block, read from the Irreps
        # object instead of being parsed from its string form (slicing such as
        # ``irrep_str[-2]`` only works for single-digit degrees).
        self.irreps_out_orders = [ir.l for _, ir in irreps_out]
        self.irreps_out_dims = [mul for mul, _ in irreps_out]
        self.irreps_out_slices = irreps_out.slices()

        # Store slices and corresponding biases for every 0e output block.
        self.bias = None
        self.bias_slices = []
        self.bias_slice_idx = []
        self.irreps_bias = self.irreps_out.simplify()
        self.irreps_bias_orders = [ir.l for _, ir in self.irreps_bias]
        self.irreps_bias_parity = [
            "e" if ir.p == 1 else "o" for _, ir in self.irreps_bias
        ]
        self.irreps_bias_dims = [mul for mul, _ in self.irreps_bias]
        if self.use_bias:
            self.bias = []
            bias_block_slices = self.irreps_bias.slices()
            for slice_idx, (order, parity, dim) in enumerate(
                zip(
                    self.irreps_bias_orders,
                    self.irreps_bias_parity,
                    self.irreps_bias_dims,
                )
            ):
                if order == 0 and parity == "e":
                    out_bias = torch.nn.Parameter(
                        torch.zeros(dim, dtype=self.tp.weight.dtype)
                    )
                    self.bias += [out_bias]
                    self.bias_slices += [bias_block_slices[slice_idx]]
                    self.bias_slice_idx += [slice_idx]
        self.bias = torch.nn.ParameterList(self.bias)

        self.slices_sqrt_k = {}
        with torch.no_grad():
            # fan_in per output slice; several instructions may write to one slice.
            slices_fan_in = {}
            for instr in self.tp.instructions:
                slice_idx = instr[2]
                slices_fan_in[slice_idx] = slices_fan_in.get(
                    slice_idx, 0
                ) + self.calculate_fan_in(instr)
            for instr in self.tp.instructions:
                slice_idx = instr[2]
                if self.rescale:
                    sqrt_k = 1 / slices_fan_in[slice_idx] ** 0.5
                else:
                    sqrt_k = 1.0
                self.slices_sqrt_k[slice_idx] = (
                    self.irreps_out_slices[slice_idx],
                    sqrt_k,
                )

            # Re-initialize weights in each instruction. The TP weights start
            # from e3nn's default init; to be consistent with torch.nn.Linear
            # they are scaled by sqrt(k), with k = 1 / fan_in of the output slice.
            if self.tp.internal_weights and self.rescale:
                for weight, instr in zip(self.tp.weight_views(), self.tp.instructions):
                    weight.data.mul_(1 / slices_fan_in[instr[2]] ** 0.5)

    def forward_tp_rescale_bias(self, x, y, weight=None):
        """Run the tensor product, then add each ``0e`` bias in place on its slice."""
        out = self.tp(x, y, weight)
        if self.use_bias:
            for out_slice, bias in zip(self.bias_slices, self.bias):
                out.narrow(-1, out_slice.start, out_slice.stop - out_slice.start).add_(bias)
        return out

    def forward(self, x, y, weight=None):
        """Equivalent to :meth:`forward_tp_rescale_bias`."""
        out = self.forward_tp_rescale_bias(x, y, weight)
        return out


class SeparableFCTP(torch.nn.Module):
    """
    Separable (depthwise) tensor product followed by a linear projection.

    [..., irreps_x] tp [..., irreps_y] -> [..., irreps_out]

    For every pair (input irrep, edge irrep) each coupled irrep that also
    appears in ``irreps_out`` becomes one tensor-product path (keeping the input
    multiplicity). The tensor-product output is mapped by ``IrrepsLinear`` to the
    requested output irreps, then optionally normalised and gated.

    ``forward`` supports two layouts:
      * e3nn layout (``eqv2=False``): ``irreps_x`` is ``[..., irreps_x.dim]``.
      * eqv2 layout (``eqv2=True``): ``irreps_x`` is ``[..., (lmax+1)**2, C]``;
        the output conversion assumes every degree of ``irreps_out`` shares the
        multiplicity of its last irrep.
    """

    def __init__(
        self,
        irreps_x,
        irreps_y,
        irreps_out,
        fc_neurons,
        use_activation=False,
        norm_layer="graph",
        internal_weights=False,
        mode="default",
        connection_mode='uvu',
        rescale=True,
        eqv2=False
    ):
        """
        Use separable FCTP for spatial convolution.

        Args:
            irreps_x: irreps of the node input (anything accepted by ``o3.Irreps``).
            irreps_y: irreps of the edge attribute.
            irreps_out: requested output irreps.
            fc_neurons: hidden sizes of a ``RadialProfile`` producing weights
                from scalar features, or None to disable it (not needed in
                e2former; a warning is emitted when it is given).
            use_activation: append a gate activation after the linear layer.
            norm_layer: name passed to ``get_norm_layer``; None disables the norm.
            internal_weights: whether the tensor product owns its weights; must
                be True when ``mode != "default"``.
            mode: forwarded to ``TensorProductRescale``.
            connection_mode: connection mode of every instruction.
            rescale: forwarded to ``TensorProductRescale``.
            eqv2: build the variant used with ``forward(..., eqv2=True)``. The
                tensor product uses shared weights, and the radial profile (if any)
                produces one scale per output irrep channel instead of full
                tensor-product weights.

        Raises:
            ValueError: if ``mode != "default"`` and ``internal_weights`` is False.
        """
        super().__init__()
        self.irreps_node_input = o3.Irreps(irreps_x)
        self.irreps_edge_attr = o3.Irreps(irreps_y)
        self.irreps_node_output = o3.Irreps(irreps_out)
        norm = get_norm_layer(norm_layer)

        irreps_output = []
        instructions = []
        for i, (mul, ir_in) in enumerate(self.irreps_node_input):
            for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    if ir_out in self.irreps_node_output:
                        k = len(irreps_output)
                        irreps_output.append((mul, ir_out))
                        instructions.append((i, j, k, connection_mode, True))

        irreps_output = o3.Irreps(irreps_output)
        irreps_output, p, _ = sort_irreps_even_first(irreps_output)
        # Re-point every instruction at the position of its output after sorting.
        instructions = [
            (i_1, i_2, p[i_out], conn, train)
            for i_1, i_2, i_out, conn, train in instructions
        ]
        if mode != "default" and internal_weights is False:
            raise ValueError("tp not support some parameter, please check your code.")

        self.dtp = TensorProductRescale(
            self.irreps_node_input,
            self.irreps_edge_attr,
            irreps_output,
            instructions,
            internal_weights=internal_weights,
            shared_weights=True if eqv2 else internal_weights,
            bias=False,
            rescale=rescale,
            mode=mode,
        )

        self.dtp_rad = None
        self.fc_neurons = fc_neurons
        if fc_neurons is not None:
            warnings.warn("NOTICEL: fc_neurons is not needed in e2former")
            if eqv2:
                # One scale per output irrep channel, applied in ``_scale_irreps_by_weight``.
                self.dtp_rad = RadialProfile(fc_neurons + [self.dtp.tp.irreps_out.num_irreps])
            else:
                # Full per-sample tensor-product weights.
                self.dtp_rad = RadialProfile(fc_neurons + [self.dtp.tp.weight_numel])
                for weight_slice, slice_sqrt_k in self.dtp.slices_sqrt_k.values():
                    self.dtp_rad.net[-1].weight.data[weight_slice, :] *= slice_sqrt_k
                    self.dtp_rad.offset.data[weight_slice] *= slice_sqrt_k

        self._build_output_head(use_activation, norm, norm_layer)

    def _build_output_head(self, use_activation, norm, norm_layer):
        """Create ``self.lin``, ``self.norm`` and ``self.gate`` that follow the tensor product."""
        irreps_lin_output = self.irreps_node_output
        if use_activation:
            # A gated activation needs extra scalar gates in front of the gated irreps.
            irreps_scalars, irreps_gates, irreps_gated = irreps2gate(
                self.irreps_node_output
            )
            irreps_lin_output = (irreps_scalars + irreps_gates + irreps_gated).simplify()

        self.lin = IrrepsLinear(
            self.dtp.irreps_out.simplify(), irreps_lin_output, bias=False, act=None
        )

        self.norm = None
        if norm_layer is not None:
            # The norm consumes the output of ``self.lin``, so it must be built for
            # ``irreps_lin_output`` (which includes the gate scalars when
            # ``use_activation`` is set), not for ``irreps_node_output``.
            self.norm = norm(irreps_lin_output)

        self.gate = None
        if use_activation:
            if irreps_gated.num_irreps == 0:
                self.gate = Activation(self.irreps_node_output, acts=[torch.nn.SiLU()])
            else:
                self.gate = Gate(
                    irreps_scalars,
                    [torch.nn.SiLU() for _ in irreps_scalars],  # scalars
                    irreps_gates,
                    [torch.sigmoid for _ in irreps_gates],  # gates (scalars)
                    irreps_gated,  # gated tensors
                )

    def _output_head(self, out, batch):
        """Apply the linear layer, then the optional norm and gate, to flat e3nn features."""
        out = self.lin(out)
        if self.norm is not None:
            out = self.norm(out, batch=batch)
        if self.gate is not None:
            out = self.gate(out)
        return out

    def _scale_irreps_by_weight(self, out, weight):
        """Multiply every (channel, irrep) block of ``out`` by its scalar in ``weight``.

        ``out`` is ``[N, dtp.tp.irreps_out.dim]`` in e3nn (channel-major) layout,
        ``weight`` is ``[N, dtp.tp.irreps_out.num_irreps]``.
        """
        rows = out.shape[0]
        chunks = []
        start = 0
        start_scalar = 0
        for mul, (l, _) in self.dtp.tp.irreps_out.simplify():
            width = 2 * l + 1
            block = out[:, start:start + mul * width].reshape(rows, mul, width)
            scale = weight[:, start_scalar:start_scalar + mul].unsqueeze(-1)
            chunks.append((block * scale).reshape(rows, mul * width))
            start += mul * width
            start_scalar += mul
        return torch.cat(chunks, dim=-1)

    def forward(self, irreps_x, irreps_y, xy_scalar_fea, batch=None, eqv2=False, **kwargs):
        """
        Apply the separable tensor product.

        Args:
            irreps_x: node features, ``[..., irreps_x.dim]`` (e3nn layout) or
                ``[..., (lmax+1)**2, C]`` when ``eqv2`` is True.
            irreps_y: edge attributes, flattened to ``[N, -1]`` where ``N`` is
                the number of leading elements of ``irreps_x``.
            xy_scalar_fea: scalar inputs of ``self.dtp_rad``; ignored when there
                is no radial profile or when it is None.
            batch: forwarded to the norm layer.
            eqv2: use the eqv2 layout for input and output.

        Returns:
            Tensor with the same leading dims as ``irreps_x``. The trailing dim is
            the flat output features (e3nn layout), or the trailing dims are
            ``[(lmax_out+1)**2, C_out]`` (eqv2 layout).

        Example::

            irreps_in = o3.Irreps("256x0e+64x1e+32x2e")
            sep_tp = SeparableFCTP(irreps_in, "1x1e", irreps_in, fc_neurons=None,
                                   use_activation=False, norm_layer=None,
                                   internal_weights=True)
            out = sep_tp(irreps_in.randn(100, 10, -1), torch.randn(100, 10, 3), None)
            print(out.shape)
        """
        if eqv2:
            lead_shape = irreps_x.shape[:-2]
            N = lead_shape.numel()
            # Flatten all leading dims so the layout converters see [N, coeffs, C].
            irreps_x = irreps_x.reshape(N, irreps_x.shape[-2], irreps_x.shape[-1])
            irreps_x = self.from_eqv2toe3nn(irreps_x)
            irreps_y = irreps_y.reshape(N, -1)

            out = self.dtp(irreps_x, irreps_y, None)
            if self.dtp_rad is not None and xy_scalar_fea is not None:
                weight = self.dtp_rad(xy_scalar_fea.reshape(N, -1))
                out = self._scale_irreps_by_weight(out, weight)
            out = self._output_head(out, batch)
            out = self.from_e3nntoeqv2(out)
            return out.reshape(list(lead_shape) + list(out.shape[-2:]))

        lead_shape = irreps_x.shape[:-1]
        N = lead_shape.numel()
        irreps_x = irreps_x.reshape(N, irreps_x.shape[-1])
        irreps_y = irreps_y.reshape(N, -1)

        weight = None
        if self.dtp_rad is not None and xy_scalar_fea is not None:
            weight = self.dtp_rad(xy_scalar_fea.reshape(N, -1))
        out = self.dtp(irreps_x, irreps_y, weight)
        out = self._output_head(out, batch)
        return out.reshape(list(lead_shape) + [out.shape[-1]])

    def from_eqv2toe3nn(self, embedding):
        """Convert ``[BL, (lmax+1)**2, C]`` to e3nn layout ``[BL, sum_l C*(2l+1)]``.

        ``lmax`` is the degree of the last irrep of ``irreps_node_input``. Within
        each degree the result is channel-major (``C`` blocks of ``2l+1``).
        """
        BL = embedding.shape[0]
        channels = embedding.shape[-1]
        lmax = self.irreps_node_input[-1][1][0]
        start = 0
        out = []
        for l in range(1 + lmax):
            width = 2 * l + 1
            block = embedding[:, start:start + width, :].permute(0, 2, 1)
            out.append(block.reshape(BL, channels * width))
            start += width
        return torch.cat(out, dim=-1)

    def from_e3nntoeqv2(self, embedding):
        """Convert e3nn layout ``[BL, sum_l mul*(2l+1)]`` back to ``[BL, (lmax+1)**2, mul]``.

        ``lmax`` and ``mul`` are read from the last irrep of ``irreps_node_output``,
        so every degree is assumed to share that multiplicity.
        """
        BL = embedding.shape[0]
        lmax = self.irreps_node_output[-1][1][0]
        mul = self.irreps_node_output[-1][0]

        start = 0
        out = []
        for l in range(1 + lmax):
            width = 2 * l + 1
            block = embedding[:, start:start + mul * width].reshape(BL, mul, width)
            out.append(block.permute(0, 2, 1))
            start += mul * width
        return torch.cat(out, dim=1)

class CosineCutoff(torch.nn.Module):
    r"""Applies a cosine cutoff to the input distances.

    .. math::
        \text{cutoffs} =
        \begin{cases}
        0.5 * (\cos(\frac{\text{distances} * \pi}{\text{cutoff}}) + 1.0),
        & \text{if } \text{distances} < \text{cutoff} \\
        0, & \text{otherwise}
        \end{cases}

    Args:
        cutoff (float): A scalar that determines the point at which the cutoff
            is applied.
    """

    def __init__(self, cutoff: float) -> None:
        super().__init__()
        self.cutoff = cutoff

    def forward(self, distances):
        r"""Applies a cosine cutoff to the input distances.

        Args:
            distances (torch.Tensor): A tensor of distances.

        Returns:
            cutoffs (torch.Tensor): A tensor where the cosine function
                has been applied to the distances,
                but any values that exceed the cutoff are set to 0.
                Its dtype matches the cosine result (e.g. bfloat16 input stays
                bfloat16).
        """
        cutoffs = 0.5 * ((distances * math.pi / self.cutoff).cos() + 1.0)
        # Cast the mask to the result dtype. A hard-coded ``.float()`` would promote
        # half/bfloat16 results to float32 through type promotion.
        inside = (distances < self.cutoff).to(cutoffs.dtype)
        return cutoffs * inside





def get_mul_0(irreps):
    """Return the summed multiplicity of all even scalar irreps (l == 0, p == 1).

    ``irreps`` is iterated as ``(mul, ir)`` pairs; ``ir`` must expose ``l`` and ``p``.
    """
    return sum(mul for mul, ir in irreps if ir.l == 0 and ir.p == 1)




@compile_mode("trace")
class Activation(torch.nn.Module):
    """
    Apply an activation to each scalar irrep; irreps whose act is None pass through.

    Every non-None activation is wrapped with ``e3nn.math.normalize2mom``. The
    output parity of each activated scalar is derived from whether the
    activation is even or odd.
    """

    def __init__(self, irreps_in, acts):
        super().__init__()
        # Accept strings, lists of (mul, ir) pairs or o3.Irreps alike; ``forward``
        # relies on ``ir.dim``, which only o3.Irreps entries provide.
        irreps_in = o3.Irreps(irreps_in)
        assert len(irreps_in) == len(acts), (irreps_in, acts)

        # normalize the second moment
        acts = [
            e3nn.math.normalize2mom(act) if act is not None else None for act in acts
        ]

        from e3nn.util._argtools import _get_device

        irreps_out = []
        for (mul, (l_in, p_in)), act in zip(irreps_in, acts):
            if act is None:
                irreps_out.append((mul, (l_in, p_in)))
                continue

            if l_in != 0:
                raise ValueError(
                    "Activation: cannot apply an activation function to a non-scalar input."
                )

            # Probe the activation on a symmetric range to detect even/odd symmetry.
            x = torch.linspace(0, 10, 256, device=_get_device(act))
            a1, a2 = act(x), act(-x)
            if (a1 - a2).abs().max() < 1e-5:
                p_act = 1
            elif (a1 + a2).abs().max() < 1e-5:
                p_act = -1
            else:
                p_act = 0

            p_out = p_act if p_in == -1 else p_in
            if p_out == 0:
                raise ValueError(
                    "Activation: the parity is violated! The input scalar is odd but the activation is neither even nor odd."
                )
            irreps_out.append((mul, (0, p_out)))

        self.irreps_in = irreps_in
        self.irreps_out = o3.Irreps(irreps_out)
        self.acts = torch.nn.ModuleList(acts)
        assert len(self.irreps_in) == len(self.acts)

    def extra_repr(self):
        output_str = super(Activation, self).extra_repr()
        output_str = output_str + "{} -> {}, ".format(self.irreps_in, self.irreps_out)
        return output_str

    def forward(self, features, dim=-1):
        """Apply the activations block-wise along ``dim``.

        A single activated irrep is applied to ``features`` directly. A single
        pass-through irrep (act None) falls through to the generic path instead
        of calling ``None``.
        """
        if len(self.acts) == 1 and self.acts[0] is not None:
            return self.acts[0](features)

        output = []
        index = 0
        for (mul, ir), act in zip(self.irreps_in, self.acts):
            if act is not None:
                output.append(act(features.narrow(dim, index, mul)))
            else:
                output.append(features.narrow(dim, index, mul * ir.dim))
            index += mul * ir.dim

        if len(output) > 1:
            return torch.cat(output, dim=dim)
        elif len(output) == 1:
            return output[0]
        else:
            return torch.zeros_like(features)


@compile_mode("script")
class Gate(torch.nn.Module):
    """
    Gated nonlinearity on flat e3nn features laid out as [scalars | gates | gated].

    Scalars go through ``act_scalars``. Gates go through ``act_gates`` and then
    multiply the gated irreps element-wise (one gate scalar per gated irrep).
    The output is [activated scalars | gated irreps].
    """

    def __init__(
        self, irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated
    ):
        super().__init__()
        irreps_scalars = o3.Irreps(irreps_scalars)
        irreps_gates = o3.Irreps(irreps_gates)
        irreps_gated = o3.Irreps(irreps_gated)

        if len(irreps_gates) > 0 and irreps_gates.lmax > 0:
            raise ValueError(
                f"Gate scalars must be scalars, instead got irreps_gates = {irreps_gates}"
            )
        if len(irreps_scalars) > 0 and irreps_scalars.lmax > 0:
            raise ValueError(
                f"Scalars must be scalars, instead got irreps_scalars = {irreps_scalars}"
            )
        if irreps_gates.num_irreps != irreps_gated.num_irreps:
            raise ValueError(
                f"There are {irreps_gated.num_irreps} irreps in irreps_gated, but a different number ({irreps_gates.num_irreps}) of gate scalars in irreps_gates"
            )

        self.irreps_scalars = irreps_scalars
        self.irreps_gates = irreps_gates
        self.irreps_gated = irreps_gated
        self._irreps_in = (irreps_scalars + irreps_gates + irreps_gated).simplify()

        self.act_scalars = Activation(irreps_scalars, act_scalars)
        self.act_gates = Activation(irreps_gates, act_gates)
        # Element-wise product of each gated irrep with its (activated) gate scalar.
        self.mul = o3.ElementwiseTensorProduct(irreps_gated, self.act_gates.irreps_out)

        self._irreps_out = self.act_scalars.irreps_out + self.mul.irreps_out

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.irreps_in} -> {self.irreps_out})"

    def forward(self, features):
        n_scalar = self.irreps_scalars.dim
        n_gate = self.irreps_gates.dim
        n_gated = self.irreps_in.dim - n_scalar - n_gate

        gates = features.narrow(-1, n_scalar, n_gate)
        gated = features.narrow(-1, n_scalar + n_gate, n_gated)
        scalars = self.act_scalars(features.narrow(-1, 0, n_scalar))

        # Without gates there is nothing to modulate: only the scalars remain.
        if gates.shape[-1] == 0:
            return scalars
        gated = self.mul(gated, self.act_gates(gates))
        return torch.cat([scalars, gated], dim=-1)

    @property
    def irreps_in(self):
        """Input representations."""
        return self._irreps_in

    @property
    def irreps_out(self):
        """Output representations."""
        return self._irreps_out




@compile_mode("script")
class Gate_s3(torch.nn.Module):
    """
    Gate activation for features shaped ``(..., (lmax + 1) ** 2, sphere_channels)``.

    A linear layer on the l = 0 row produces ``lmax + 1`` groups of
    ``sphere_channels`` values:
      * group 0, through (second-moment normalised) SiLU, is the new l = 0 row;
      * group l >= 1, through (second-moment normalised) sigmoid, scales the
        ``2l + 1`` rows of degree l channel-wise.
    """

    def __init__(self, sphere_channels, lmax, act_scalars="silu", act_vector="sigmoid"):
        super().__init__()

        self.sphere_channels = sphere_channels
        self.lmax = lmax
        self.gates = torch.nn.Linear(sphere_channels, sphere_channels * (lmax + 1))
        bound = 1 / math.sqrt(sphere_channels)
        torch.nn.init.uniform_(self.gates.weight, -bound, bound)

        if act_scalars == "silu":
            self.act_scalars = e3nn.math.normalize2mom(torch.nn.SiLU())
        else:
            raise ValueError("in Gate, only support silu")

        if act_vector == "sigmoid":
            self.act_vector = e3nn.math.normalize2mom(torch.nn.Sigmoid())
        else:
            raise ValueError("in Gate, only support sigmoid for vector")

    def __repr__(self):
        return f"{self.__class__.__name__} sph ({self.sphere_channels} lmax {self.lmax})"

    def forward(self, features):
        input_shape = features.shape
        features = features.reshape(input_shape[:-2].numel(), -1, input_shape[-1])

        # [N, 1, C * (lmax + 1)]: one group of C gate values per degree.
        scalars = self.gates(features[:, 0:1])
        out = [self.act_scalars(scalars[:, :, : self.sphere_channels])]

        start = 1
        for l in range(1, self.lmax + 1):
            lo = l * self.sphere_channels
            gate = self.act_vector(scalars[:, :, lo:lo + self.sphere_channels])  # N x 1 x C
            out.append(gate * features[:, start:start + 2 * l + 1, :])  # N x (2l+1) x C
            start += 2 * l + 1

        out = torch.cat(out, dim=1)
        return out.reshape(input_shape)

    @property
    def irreps_in(self):
        """Expected trailing input shape ``((lmax + 1) ** 2, sphere_channels)``.

        The previous implementation returned the nonexistent ``self.out`` and
        therefore always raised AttributeError.
        """
        return ((self.lmax + 1) ** 2, self.sphere_channels)



@compile_mode("script")
class FeedForwardNetwork_s3(torch.nn.Module):
    """
    Feed-forward block on spherical features: SO3 linear -> Gate_s3 -> SO3 linear.

    Channels go ``sphere_channels -> hidden_channels -> output_channels``; both
    linear layers carry a bias and share ``lmax``.
    """

    def __init__(
        self,
        sphere_channels,
        hidden_channels,
        output_channels,
        lmax,
    ):
        super().__init__()
        self.sphere_channels = sphere_channels
        self.hidden_channels = hidden_channels
        self.output_channels = output_channels

        self.slinear_1 = SO3_Linear_e2former(
            sphere_channels, hidden_channels, lmax=lmax, bias=True
        )
        self.gate = Gate_s3(
            hidden_channels, lmax=lmax, act_scalars="silu", act_vector="sigmoid"
        )
        self.slinear_2 = SO3_Linear_e2former(
            hidden_channels, output_channels, lmax=lmax, bias=True
        )

    def forward(self, node_input, **kwargs):
        """
        Run the feed-forward block; extra keyword arguments are ignored.

        ``Gate_s3`` reshapes its input as ``(..., rows, hidden_channels)``, so
        ``node_input`` is presumably ``(..., (lmax + 1) ** 2, sphere_channels)``
        (TODO confirm against SO3_Linear_e2former).
        """
        hidden = self.slinear_1(node_input)
        gated = self.gate(hidden)
        return self.slinear_2(gated)


class S2Activation(torch.nn.Module):
    """
    Point-wise SiLU on a spherical grid, for a single (lmax, mmax) resolution.

    Coefficients are projected to the grid with ``SO3_grid[lmax][mmax]``,
    activated, and projected back.
    """

    def __init__(self, lmax, mmax):
        super().__init__()
        self.lmax = lmax
        self.mmax = mmax
        self.act = torch.nn.SiLU()

    def forward(self, inputs, SO3_grid):
        grid = SO3_grid[self.lmax][self.mmax]
        to_grid = grid.get_to_grid_mat(device=None)  # `device` is not used
        from_grid = grid.get_from_grid_mat(device=None)

        on_grid = self.act(torch.einsum("bai, zic -> zbac", to_grid, inputs))
        return torch.einsum("bai, zbac -> zic", from_grid, on_grid)
    
class SeparableS2Activation(torch.nn.Module):
    """
    SiLU on a separate scalar tensor, S2 (grid) activation on the spherical tensor.

    The activated scalars replace row 0 of the S2-activated tensor; rows 1..
    come from the S2 activation.
    """

    def __init__(self, lmax, mmax):
        super().__init__()

        self.lmax = lmax
        self.mmax = mmax

        self.scalar_act = torch.nn.SiLU()
        self.s2_act = S2Activation(self.lmax, self.mmax)

    def forward(self, input_scalars, input_tensors, SO3_grid):
        act_scalars = self.scalar_act(input_scalars)
        # Give the scalars a singleton coefficient axis so they stack as row 0.
        act_scalars = act_scalars.reshape(act_scalars.shape[0], 1, act_scalars.shape[-1])

        act_tensors = self.s2_act(input_tensors, SO3_grid)
        higher_rows = act_tensors.narrow(1, 1, act_tensors.shape[1] - 1)
        return torch.cat((act_scalars, higher_rows), dim=1)
    
    
# follow eSCN
class FeedForwardNetwork_escn(torch.nn.Module):
    """
    Grid-based feed-forward network (following eSCN).

    The current features and a second ("history") feature tensor are projected
    to a spherical grid with the degree-``lmax`` ``SO3_Grid``, concatenated
    along channels, passed through a point-wise MLP
    (``2C -> C -> C -> output_channels`` with SiLU between layers), and
    projected back to coefficients.

    Args:
        sphere_channels (int):  Number of channels ``C`` of both inputs.
        hidden_channels (int):  Accepted for interface compatibility; not used
                                (the hidden width equals ``sphere_channels``).
        output_channels (int):  Number of output channels.
        lmax (int):             Maximum degree; inputs carry ``(lmax + 1) ** 2``
                                coefficients.
        grid_resolution (int):  Resolution passed to every ``SO3_Grid``.
    """

    def __init__(
        self,
        sphere_channels,
        hidden_channels,
        output_channels,
        lmax,
        grid_resolution = 18,
    ):
        super(FeedForwardNetwork_escn, self).__init__()
        self.sphere_channels = sphere_channels
        self.output_channels = output_channels

        self.so3_grid = torch.nn.ModuleList()
        self.lmax = lmax
        for l in range(lmax + 1):
            SO3_m_grid = nn.ModuleList()
            for m in range(lmax + 1):
                SO3_m_grid.append(
                    SO3_Grid(
                        l, m, resolution=grid_resolution, normalization="component"
                    )
                )
            self.so3_grid.append(SO3_m_grid)

        self.act = nn.SiLU()
        # Non-linear point-wise convolution for the aggregated messages
        self.fc1_sphere = nn.Linear(
            2 * self.sphere_channels, self.sphere_channels, bias=False
        )

        self.fc2_sphere = nn.Linear(
            self.sphere_channels, self.sphere_channels, bias=False
        )

        # The last layer must emit ``output_channels``, because ``forward`` reshapes
        # its result with that width (previously it emitted ``sphere_channels``,
        # which scrambled or broke the output when the two differ).
        self.fc3_sphere = nn.Linear(
            self.sphere_channels, self.output_channels, bias=False
        )

    def forward(self, node_irreps,nore_irreps_his,**kwargs):
        """Apply the grid MLP to the current and history features.

        Example::

            model = FeedForwardNetwork_escn(
                sphere_channels=128,
                hidden_channels=128,
                output_channels=128,
                lmax=4,
                grid_resolution=18,
            )
            node_irreps = torch.randn(100, 3, 25, 128)
            node_irreps_his = torch.randn(100, 3, 25, 128)
            model(node_irreps, node_irreps_his).shape  # (100, 3, 25, 128)

        Args:
            node_irreps: ``(..., (lmax + 1) ** 2, sphere_channels)`` features.
            nore_irreps_his: second ("history") features of the same shape.

        Returns:
            ``(..., -1, output_channels)`` with the leading dims of ``node_irreps``.
        """
        out_shape = node_irreps.shape[:-2]
        num_nodes = out_shape.numel()
        num_coeffs = (self.lmax + 1) ** 2

        node_irreps = node_irreps.reshape(num_nodes, num_coeffs, self.sphere_channels)
        nore_irreps_his = nore_irreps_his.reshape(num_nodes, num_coeffs, self.sphere_channels)

        grid = self.so3_grid[self.lmax][self.lmax]
        to_grid_mat = grid.get_to_grid_mat(device=None)  # `device` is not used
        from_grid_mat = grid.get_from_grid_mat(device=None)

        # Project both inputs to the grid and concatenate along channels (2C).
        x_grid = torch.einsum("bai, zic -> zbac", to_grid_mat, node_irreps)
        x_grid_his = torch.einsum("bai, zic -> zbac", to_grid_mat, nore_irreps_his)
        x_grid = torch.cat([x_grid, x_grid_his], dim=3)

        # Point-wise MLP on every grid point.
        x_grid = self.act(self.fc1_sphere(x_grid))
        x_grid = self.act(self.fc2_sphere(x_grid))
        x_grid = self.fc3_sphere(x_grid)

        node_irreps = torch.einsum("bai, zbac -> zic", from_grid_mat, x_grid)
        return node_irreps.reshape(out_shape + (-1, self.output_channels))

class FeedForwardNetwork_s2(torch.nn.Module):
    """
    Feed-forward network on spherical-harmonic features with an S2 (grid)
    non-linearity or a gate activation.

    The input is flattened to ``(N, (lmax + 1) ** 2, sphere_channels)``. It is
    mixed by ``so3_linear_1`` (sphere -> hidden channels), passed through the
    non-linearity selected by the flags below, and mixed by ``so3_linear_2``
    (hidden -> output channels).

    Args:
        sphere_channels (int):  Number of input spherical channels.
        hidden_channels (int):  Number of hidden channels used inside the network.
        output_channels (int):  Number of output channels.
        lmax (int):             Maximum degree (l). ``lmax`` is also passed wherever
                                a maximum order (m) is needed (grids, activations).
        mmax (int):             Accepted for interface compatibility but currently
                                unused.
        grid_resolution (int):  ``resolution`` passed to every ``SO3_Grid``.
        use_gate_act (bool):    Only consulted when ``use_grid_mlp`` is ``False``.
                                If ``True``, use ``GateActivation``; otherwise use
                                an S2 activation.
        use_grid_mlp (bool):    If ``True``, project to the grid, apply a 3-layer
                                point-wise MLP, and project back.
        use_sep_s2_act (bool):  Separable variant. With ``use_grid_mlp``, the l=0
                                coefficients are replaced by the output of a
                                separate scalar MLP. Without it,
                                ``SeparableS2Activation`` is used instead of
                                ``S2Activation``.
    """

    def __init__(
        self,
        sphere_channels,
        hidden_channels,
        output_channels,
        lmax,
        mmax=2,
        grid_resolution=18,
        use_gate_act=False,  # [True, False] Switch between gate activation and S2 activation
        use_grid_mlp=True,  # [False, True] If `True`, use projecting to grids and performing MLPs for FFNs.
        use_sep_s2_act=True,  # Separable S2 activation. Used for ablation study.
    ):
        super(FeedForwardNetwork_s2, self).__init__()
        self.sphere_channels = sphere_channels
        self.hidden_channels = hidden_channels
        self.output_channels = output_channels
        self.sphere_channels_all = self.sphere_channels
        self.so3_grid = torch.nn.ModuleList()
        self.lmax = lmax
        self.max_lmax = self.lmax
        self.lmax_list = [lmax]
        # so3_grid[l][m] is the grid for degree l and order m, 0 <= l, m <= lmax.
        for l_deg in range(lmax + 1):
            SO3_m_grid = nn.ModuleList()
            for m_ord in range(lmax + 1):
                SO3_m_grid.append(
                    SO3_Grid(
                        l_deg, m_ord, resolution=grid_resolution, normalization="component"
                    )
                )
            self.so3_grid.append(SO3_m_grid)

        self.use_gate_act = use_gate_act
        self.use_grid_mlp = use_grid_mlp
        self.use_sep_s2_act = use_sep_s2_act

        self.so3_linear_1 = SO3_LinearV2(
            self.sphere_channels_all, self.hidden_channels, lmax=self.lmax
        )
        if self.use_grid_mlp:
            if self.use_sep_s2_act:
                # Separate path for the l=0 (scalar) coefficients.
                self.scalar_mlp = nn.Sequential(
                    nn.Linear(
                        self.sphere_channels_all,
                        self.hidden_channels,
                        bias=True,
                    ),
                    nn.SiLU(),
                )
            else:
                self.scalar_mlp = None
            self.grid_mlp = nn.Sequential(
                nn.Linear(self.hidden_channels, self.hidden_channels, bias=False),
                nn.SiLU(),
                nn.Linear(self.hidden_channels, self.hidden_channels, bias=False),
                nn.SiLU(),
                nn.Linear(self.hidden_channels, self.hidden_channels, bias=False),
            )
        else:
            if self.use_gate_act:
                self.gating_linear = torch.nn.Linear(
                    self.sphere_channels_all,
                    self.lmax * self.hidden_channels,
                )
                self.gate_act = GateActivation(
                    self.lmax, self.lmax, self.hidden_channels
                )
            else:
                if self.use_sep_s2_act:
                    self.gating_linear = torch.nn.Linear(
                        self.sphere_channels_all, self.hidden_channels
                    )
                    self.s2_act = SeparableS2Activation(self.lmax, self.lmax)
                else:
                    self.gating_linear = None
                    self.s2_act = S2Activation(self.lmax, self.lmax)
        self.so3_linear_2 = SO3_LinearV2(
            self.hidden_channels, self.output_channels, lmax=self.lmax
        )

    def forward(self, input_embedding):
        """Apply the network to a plain tensor of spherical-harmonic coefficients.

        Args:
            input_embedding: Tensor whose last two dims are
                ``((lmax + 1) ** 2, sphere_channels)``; leading dims are arbitrary.

        Returns:
            Tensor with the same leading dims, reshaped to
            ``(..., -1, output_channels)``.
        """
        out_shape = input_embedding.shape[:-2]

        input_embedding = input_embedding.reshape(
            out_shape.numel(), (self.lmax + 1) ** 2, self.sphere_channels
        )
        # Wrap the tensor in an SO3_Embedding without allocating a throw-away
        # zero buffer. Construct with length 0, then attach the real tensor
        # via set_embedding, which also updates the stored length.
        x = SO3_Embedding(
            0,
            self.lmax_list,
            self.sphere_channels,
            input_embedding.device,
            input_embedding.dtype,
        )
        x.set_embedding(input_embedding)
        x = self._forward(x)

        return x.embedding.reshape(out_shape + (-1, self.output_channels))

    def _forward(self, input_embedding):
        """Run the network on an embedding object exposing ``.embedding``.

        Args:
            input_embedding: Embedding object whose ``.embedding`` has shape
                ``(N, (lmax + 1) ** 2, sphere_channels)``, as built by ``forward``.

        Returns:
            The embedding object returned by ``so3_linear_2``.
        """
        gating_scalars = None
        if self.use_grid_mlp:
            if self.use_sep_s2_act:
                # Scalar path is computed from the l=0 coefficient of the input.
                gating_scalars = self.scalar_mlp(
                    input_embedding.embedding.narrow(1, 0, 1)
                )
        elif self.gating_linear is not None:
            gating_scalars = self.gating_linear(
                input_embedding.embedding.narrow(1, 0, 1)
            )

        input_embedding = self.so3_linear_1(input_embedding)

        if self.use_grid_mlp:
            # Project to grid
            input_embedding_grid = input_embedding.to_grid(
                self.so3_grid, lmax=self.max_lmax
            )
            # Perform point-wise operations
            input_embedding_grid = self.grid_mlp(input_embedding_grid)
            # Project back to spherical harmonic coefficients (in place)
            input_embedding._from_grid(
                input_embedding_grid, self.so3_grid, lmax=self.max_lmax
            )

            if self.use_sep_s2_act:
                # Replace the l=0 coefficients with the scalar-MLP output.
                input_embedding.embedding = torch.cat(
                    (
                        gating_scalars,
                        input_embedding.embedding.narrow(
                            1, 1, input_embedding.embedding.shape[1] - 1
                        ),
                    ),
                    dim=1,
                )
        else:
            if self.use_gate_act:
                input_embedding.embedding = self.gate_act(
                    gating_scalars, input_embedding.embedding
                )
            else:
                if self.use_sep_s2_act:
                    input_embedding.embedding = self.s2_act(
                        gating_scalars,
                        input_embedding.embedding,
                        self.so3_grid,
                    )
                else:
                    input_embedding.embedding = self.s2_act(
                        input_embedding.embedding, self.so3_grid
                    )

        return self.so3_linear_2(input_embedding)
