

# from sfm.models.psm.equivariant.scalable.src.utils import *
from sfm.models.psm.equivariant.scalable.src.utils.data_preprocess import *
from sfm.models.psm.equivariant.scalable.src.utils.graph_utils import *
import torch_geometric
from e3nn.o3._spherical_harmonics import _spherical_harmonics
from sfm.models.psm.equivariant.scalable.src.EScAIP import *
from fairchem.core.models.base import HydraModel
import yaml
from torch import nn

class EScAIPBackbone(torch.nn.Module):
    """
    Backbone of the Efficiently Scaled Attention Interatomic Potential (EScAIP).

    The module runs an input block, then a stack of graph-attention blocks,
    takes a readout after every stage (input block included), concatenates
    all readouts along the last axis and projects them to the final node and
    edge features.
    """

    def __init__(
        self,
        cfg,
    ):
        """
        Build every sub-module from ``cfg``.

        ``cfg`` must expose the attributes ``global_cfg``,
        ``molecular_graph_cfg``, ``gnn_cfg`` and ``reg_cfg``.
        """
        super().__init__()

        # Keep the individual config sections on the module.
        self.global_cfg = cfg.global_cfg
        self.molecular_graph_cfg = cfg.molecular_graph_cfg
        self.gnn_cfg = cfg.gnn_cfg
        self.reg_cfg = cfg.reg_cfg

        # Flags mirrored as plain attributes (the original notes these are
        # "for trainer").
        self.regress_forces = self.global_cfg.regress_forces
        self.use_pbc = self.molecular_graph_cfg.use_pbc

        # Keyword arguments shared by the sub-modules below.
        shared_kwargs = dict(
            global_cfg=self.global_cfg,
            gnn_cfg=self.gnn_cfg,
            reg_cfg=self.reg_cfg,
        )
        graph_kwargs = dict(
            shared_kwargs, molecular_graph_cfg=self.molecular_graph_cfg
        )
        depth = self.gnn_cfg.num_layers

        # NOTE: sub-modules are created in the same order as before so that
        # parameter registration and initialisation order stay unchanged.

        # Input block
        self.input_block = InputBlock(**graph_kwargs)

        # Graph-attention (transformer) blocks
        attention_stack = [
            EfficientGraphAttentionBlock(**graph_kwargs) for _ in range(depth)
        ]
        self.transformer_blocks = nn.ModuleList(attention_stack)

        # One readout per stage: the input block plus every attention block
        readout_stack = [ReadoutBlock(**shared_kwargs) for _ in range(depth + 1)]
        self.readout_layers = nn.ModuleList(readout_stack)

        # Projection applied to the concatenated readouts
        self.output_projection = OutputProjection(**shared_kwargs)

        # Weight initialisation over all sub-modules
        self.apply(init_linear_weights)

        # Process-wide matmul precision is raised to "high" unless the fp16
        # backbone is in use.
        if not self.global_cfg.use_fp16_backbone:
            torch.set_float32_matmul_precision("high")
        # Ask torch to log recompilations.
        torch._logging.set_logs(recompiles=True)

        # Pick the compiled or the eager implementation once, at build time.
        if self.global_cfg.use_compile:
            self.forward_fn = torch.compile(self.compiled_forward)
        else:
            self.forward_fn = self.compiled_forward

    def compiled_forward(self, data: GraphAttentionData):
        """
        Run the full backbone on ``data``.

        Returns:
            The ``(node_features, edge_features)`` pair produced by the
            output projection.
        """
        # Stage 0: input block followed by its readout.
        node_features, edge_features = self.input_block(data)
        stage_readout = self.readout_layers[0](node_features, edge_features)
        node_readouts = [stage_readout[0]]
        edge_readouts = [stage_readout[1]]

        # Stages 1..num_layers: attention block followed by its readout.
        for stage, block in enumerate(self.transformer_blocks, start=1):
            node_features, edge_features = block(
                data, node_features, edge_features
            )
            stage_readout = self.readout_layers[stage](node_features, edge_features)
            node_readouts.append(stage_readout[0])
            edge_readouts.append(stage_readout[1])

        # Fuse the readouts of all stages along the last axis and project.
        stacked_nodes = torch.cat(node_readouts, dim=-1)
        stacked_edges = torch.cat(edge_readouts, dim=-1)
        node_features, edge_features = self.output_projection(
            node_readouts=stacked_nodes,
            edge_readouts=stacked_edges,
        )

        return node_features, edge_features

    def forward(self, data: torch_geometric.data.Batch):
        """
        Hand ``data`` unchanged to ``self.forward_fn`` and return its result.

        No preprocessing happens here; ``data`` must already be in the form
        ``compiled_forward`` expects.
        """
        return self.forward_fn(data)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Delegate to the module-level ``no_weight_decay`` helper."""
        return no_weight_decay(self)



@registry.register_model("EScAIP_direct_force_head")
class EScAIPDirectForceHead(torch.nn.Module):
    """
    Head that predicts per-node force vectors directly from backbone features.

    The prediction is the product of a direction term, aggregated from the
    edge features over each node's neighbour slots, and a magnitude term
    computed from the node features.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: config object exposing ``global_cfg``, ``gnn_cfg`` and
                ``reg_cfg``.
        """
        super().__init__()
        self.global_cfg = cfg.global_cfg
        self.gnn_cfg = cfg.gnn_cfg
        self.reg_cfg = cfg.reg_cfg
        # Applied to edge features; its output is multiplied with the edge
        # directions.
        self.force_direction_layer = OutputLayer(
            global_cfg=self.global_cfg,
            gnn_cfg=self.gnn_cfg,
            reg_cfg=self.reg_cfg,
            output_type="Vector",
        )
        # Applied to node features; scales the aggregated direction.
        self.force_magnitude_layer = OutputLayer(
            global_cfg=self.global_cfg,
            gnn_cfg=self.gnn_cfg,
            reg_cfg=self.reg_cfg,
            output_type="Scalar",
        )

        self.post_init()

    def compiled_forward(self, edge_features, node_features, data: GraphAttentionData):
        """
        Compute the force prediction.

        Shape notes below are carried over from the original author's
        comments.

        Args:
            edge_features: edge features from the backbone.
            node_features: node features from the backbone.
            data: graph data; ``edge_direction`` and ``neighbor_mask`` are read.

        Returns:
            Force tensor, noted by the original author as ``(num_nodes, 3)``.
        """
        # get force direction from edge features
        force_direction = self.force_direction_layer(
            edge_features
        )  # (num_nodes, max_neighbor, 3)
        force_direction = (
            force_direction * data.edge_direction
        )  # (num_nodes, max_neighbor, 3)
        # Zero out masked neighbour slots, then sum over the neighbour axis.
        force_direction = (force_direction * data.neighbor_mask.unsqueeze(-1)).sum(
            dim=1
        )  # (num_nodes, 3)
        # get force magnitude from node readouts
        force_magnitude = self.force_magnitude_layer(node_features)  # (num_nodes, 1)
        # get output force
        return force_direction * force_magnitude  # (num_nodes, 3)

    def post_init(self, gain=1.0):
        """
        Initialise weights and select the forward implementation.

        Args:
            gain: forwarded to ``init_linear_weights`` as ``gain``.
        """
        # init weights
        self.apply(partial(init_linear_weights, gain=gain))

        # Wrap with torch.compile only when the config asks for it.
        self.forward_fn = (
            torch.compile(self.compiled_forward)
            if self.global_cfg.use_compile
            else self.compiled_forward
        )

    def forward(self, data, node_features, edge_features) -> torch.Tensor:
        """
        Return the predicted forces for every node.

        The result is the tensor produced by ``compiled_forward``, not a
        dict, so the return annotation is ``torch.Tensor``.
        """
        force_output = self.forward_fn(
            edge_features=edge_features,
            node_features=node_features,
            data=data,
        )

        return force_output


@registry.register_model("EScAIP_energy_head")
class EScAIPEnergyHead(torch.nn.Module):
    """
    Head that maps backbone node features to an energy output via a single
    scalar ``OutputLayer``. No reduction over nodes is performed here.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: config object exposing ``global_cfg``, ``gnn_cfg`` and
                ``reg_cfg``.
        """
        super().__init__()
        self.global_cfg = cfg.global_cfg
        self.gnn_cfg = cfg.gnn_cfg
        self.reg_cfg = cfg.reg_cfg

        self.energy_layer = OutputLayer(
            global_cfg=self.global_cfg,
            gnn_cfg=self.gnn_cfg,
            reg_cfg=self.reg_cfg,
            output_type="Scalar",
        )

        # This head is initialised with gain 0.01 instead of the default 1.0.
        self.post_init(gain=0.01)

    def post_init(self, gain=1.0):
        """
        Initialise weights and select the forward implementation.

        Args:
            gain: forwarded to ``init_linear_weights`` as ``gain``.
        """
        # init weights
        self.apply(partial(init_linear_weights, gain=gain))

        # Wrap with torch.compile only when the config asks for it.
        self.forward_fn = (
            torch.compile(self.compiled_forward)
            if self.global_cfg.use_compile
            else self.compiled_forward
        )

    def compiled_forward(self, node_features, data: GraphAttentionData):
        """
        Apply the energy layer to ``node_features``.

        ``data`` is accepted for signature parity with the other heads but
        is not read.
        """
        energy_output = self.energy_layer(node_features)

        return energy_output

    def forward(self, data, node_features, edge_features) -> torch.Tensor:
        """
        Return the energy layer output for ``node_features``.

        ``edge_features`` is accepted so that all heads share one call
        signature; it is not used. The result is a tensor, not a dict, so
        the return annotation is ``torch.Tensor``.
        """
        energy_output = self.forward_fn(
            node_features=node_features,
            data=data,
        )
        return energy_output

@torch.jit.script
def get_node_direction_expansion_topK(
    distance_vec: torch.Tensor, lmax: int
):
    """
    Calculate Bond-Orientational Order (BOO) for each node from a dense
    per-node neighbour table.
    Ref: Steinhardt, et al. "Bond-orientational order in liquids and glasses." Physical Review B 28.2 (1983): 784.

    Args:
        distance_vec: 3-D tensor indexed as ``[node, neighbour, xyz]``; the
            last axis is normalised to unit length before use.
        lmax: highest spherical-harmonic degree to include.

    Return: one value per node and per degree ``l = 0..lmax`` (the final
        scatter groups the harmonic components by degree along ``dim=1``).

    NOTE(review): every neighbour slot contributes to the sum over ``dim=1``;
    no mask is applied here, so padded or out-of-range slots, if present in
    ``distance_vec``, are included. Confirm this is intended.
    """
    distance_vec = torch.nn.functional.normalize(distance_vec, dim=-1)
    edge_sh = _spherical_harmonics(
        lmax=lmax,
        x=distance_vec[:, :, 0],
        y=distance_vec[:, :, 1],
        z=distance_vec[:, :, 2],
    )
    # Sum the harmonics over the neighbour axis.
    node_boo = torch.sum(edge_sh, dim=1)
    # Degree index of every harmonic component: degree l occupies 2l + 1
    # consecutive entries.
    sh_index = torch.arange(lmax + 1, device=node_boo.device)
    sh_index = torch.repeat_interleave(sh_index, 2 * sh_index + 1)
    # Per-degree L2 norm of the summed components.
    node_boo = scatter(node_boo**2, sh_index, dim=1, reduce="sum").sqrt()
    return node_boo


class SCA_model(nn.Module):
    """
    Wrapper that turns padded, batched atom data into the flattened
    neighbour-table representation consumed by the EScAIP backbone, runs the
    backbone with its energy and force heads, and scatters the predictions
    back into padded ``(B, L, ...)`` tensors.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: mapping read with ``cfg[...]`` for the keys
                ``"distance_function"`` (one of ``"gaussian"``,
                ``"sigmoid"``, ``"linear_sigmoid"``, ``"silu"``),
                ``"max_neighbors"``, ``"max_radius"``, ``"atten_num_heads"``
                and ``"edge_distance_expansion_size"``. It is also passed to
                ``init_configs`` to build the EScAIP config.

        Raises:
            KeyError: if a key is missing or ``"distance_function"`` is not
                one of the supported names.
        """
        super().__init__()

        # edge distance expansion
        expansion_func = {
            "gaussian": GaussianSmearing,
            "sigmoid": SigmoidSmearing,
            "linear_sigmoid": LinearSigmoidSmearing,
            "silu": SiLUSmearing,
        }[cfg["distance_function"]]

        self.max_neighbors = cfg["max_neighbors"]
        self.max_radius = cfg["max_radius"]
        self.atten_num_heads = cfg["atten_num_heads"]
        self.number_of_basis = cfg["edge_distance_expansion_size"]
        self.edge_distance_expansion_func = expansion_func(
            0.0,
            self.max_radius,
            self.number_of_basis,  # number of basis
            basis_width_scalar=2.0,
        )

        sca_cfg = init_configs(EScAIPConfigs, cfg)

        self.backbone = EScAIPBackbone(sca_cfg)
        self.force_head = EScAIPDirectForceHead(sca_cfg)
        self.energy_head = EScAIPEnergyHead(sca_cfg)

    def forward(
        self,
        batched_data,
        token_embedding: torch.Tensor,
        mixed_attn_bias=None,
        padding_mask: torch.Tensor = None,
        pbc_expand_batched=None,
        time_embed=None,
        sepFN=False,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict a per-atom energy term and a per-atom force vector.

        Args:
            batched_data: mapping; the keys ``"pos"``, ``"non_atom_mask"``,
                ``"masked_token_type"``, ``"is_periodic"`` and
                ``"is_protein"`` are read and ``"mol_type"`` is written.
            token_embedding: not used by this method.
            mixed_attn_bias: not used by this method.
            padding_mask: required despite the ``None`` default; only its
                device and its first two dimensions ``(B, L)`` are used.
            pbc_expand_batched: required when any entry of
                ``batched_data["is_periodic"]`` is true; the keys
                ``"outcell_index"``, ``"expand_pos"`` and ``"expand_mask"``
                are read.
            time_embed: not used by this method.
            sepFN: not used by this method.
            **kwargs: ignored.

        Returns:
            ``(node_attr, node_vec)``: tensors of shape ``(B, L, 1)`` and
            ``(B, L, 3)`` holding the energy-head and force-head outputs at
            atom positions and zeros elsewhere.

        Side effects:
            * sets ``requires_grad`` on ``batched_data["pos"]``;
            * writes ``batched_data["mol_type"]``;
            * in the periodic branch, overwrites masked entries of the
              expanded positions in place;
            * toggles the process-wide CUDA scaled-dot-product-attention
              backend flags on every call.
        """
        device = padding_mask.device
        B, L = padding_mask.shape[:2]

        node_pos = batched_data["pos"]
        node_pos.requires_grad = True

        # padding mask and non_atom_mask diff lie in cell 8 point
        non_atom_mask = batched_data["non_atom_mask"]
        # Park non-atom entries far away (999 on every coordinate) so they
        # fall outside the cutoff radius. The mask broadcasts over xyz.
        node_pos = torch.where(non_atom_mask.unsqueeze(dim=-1), 999., node_pos)
        node_mask = torch.logical_not(non_atom_mask)
        atomic_numbers = batched_data["masked_token_type"].reshape(B, L)[node_mask]
        # ptr[b] is the offset of graph b's first atom in the flattened atom
        # list; ptr[B] is the total atom count.
        zero_offset = torch.zeros(1, dtype=torch.int32, device=device)
        ptr = torch.cat(
            [zero_offset, torch.cumsum(torch.sum(node_mask, dim=-1), dim=-1)],
            dim=0,
        )
        f_node_pos = node_pos[node_mask]

        # Defaults for the non-periodic case: the neighbour candidates are
        # the atoms themselves.
        expand_node_pos = node_pos
        expand_ptr = ptr
        outcell_index = torch.arange(L, device=device).unsqueeze(dim=0).repeat(B, 1)
        f_exp_node_pos = f_node_pos
        mol_type = 0
        L2 = L
        if torch.any(batched_data["is_periodic"]):
            mol_type = 1
            # Candidates are the periodic images; outcell_index maps each
            # expanded position back to an index in the original cell.
            L2 = pbc_expand_batched["outcell_index"].shape[1]
            outcell_index = pbc_expand_batched["outcell_index"]
            # NOTE(review): ``.float()`` returns the same tensor when the
            # input is already float32, so the in-place write below can
            # modify ``pbc_expand_batched["expand_pos"]`` itself.
            expand_node_pos = pbc_expand_batched["expand_pos"].float()
            expand_node_pos[pbc_expand_batched["expand_mask"]] = 999  # set expand node pos padding to 999
            expand_node_mask = torch.logical_not(pbc_expand_batched["expand_mask"])
            expand_ptr = torch.cat(
                [
                    zero_offset,
                    torch.cumsum(torch.sum(expand_node_mask, dim=-1), dim=-1),
                ],
                dim=0,
            )
            f_exp_node_pos = expand_node_pos[expand_node_mask]
        if torch.any(batched_data["is_protein"]):
            mol_type = 2
        batched_data["mol_type"] = mol_type

        # Dense pairwise distances between atoms and neighbour candidates.
        edge_vec = node_pos.unsqueeze(2) - expand_node_pos.unsqueeze(1)
        dist = torch.norm(edge_vec, dim=-1)  # B*L*L2 Attention: ego-connection is 0 here
        # Push (near-)zero distances, i.e. self connections, out of range.
        dist = torch.where(dist < 1e-4, 1000, dist)
        _, neighbor_indices = dist.sort(dim=-1)
        topK = min(L2, self.max_neighbors)
        neighbor_indices = neighbor_indices[:, :, :topK]  # Shape: B*L*K
        dist = torch.gather(dist, dim=-1, index=neighbor_indices)  # Shape: B*L*topK
        # True where a neighbour slot is NOT a valid neighbour.
        attn_mask = (dist > self.max_radius) | (dist < 1e-4)
        attn_mask = attn_mask[node_mask]

        # Flattened index of each neighbour in the atom list. ``expand`` is
        # a broadcast view, so the B*L*L2 table is never materialised.
        f_sparse_idx_node = (
            torch.gather(
                outcell_index.unsqueeze(1).expand(-1, L, -1),
                2,
                neighbor_indices,
            )
            + ptr[:B, None, None]
        )[node_mask]
        # Slots that point at padding can exceed the valid range; clamp them
        # (they are masked out by ``neighbor_mask``).
        f_sparse_idx_node = torch.clamp(f_sparse_idx_node, max=ptr[B] - 1)
        # Same, but indexing the flattened list of expanded positions.
        f_sparse_idx_expnode = (neighbor_indices + expand_ptr[:B, None, None])[node_mask]
        f_sparse_idx_expnode = torch.clamp(f_sparse_idx_expnode, max=expand_ptr[B] - 1)
        f_edge_vec = f_node_pos.unsqueeze(dim=1) - f_exp_node_pos[f_sparse_idx_expnode]
        ##############↑↑↑↑↑↑↑ upper is for data process↑↑↑↑↑↑↑#################

        ############ this is for SCA
        neighbor_mask = ~attn_mask  # True where the slot is a real neighbour
        atomic_numbers = atomic_numbers.long()

        f_N = f_node_pos.shape[0]

        # NOTE(review): lmax is hard-coded to 9; confirm it matches the
        # node direction expansion size the backbone config expects.
        node_direction_expansion = get_node_direction_expansion_topK(f_edge_vec, lmax=9)
        neighbor_list = f_sparse_idx_node

        # Edge length is computed once and reused for the radial basis and
        # the unit direction.
        edge_length = torch.norm(f_edge_vec, dim=-1)
        edge_distance_expansion = self.edge_distance_expansion_func(edge_length)
        edge_distance_expansion = edge_distance_expansion.reshape(f_N, topK, -1)
        edge_direction = -f_edge_vec / torch.clamp(edge_length.unsqueeze(-1), min=1e-8)

        # get attention mask
        attn_mask_inf, angle_embedding = get_attn_mask(
            edge_direction=edge_direction,
            neighbor_mask=neighbor_mask,
            num_heads=self.atten_num_heads,
            use_angle_embedding=None,
            filled_value=-1e6,
        )

        # Select the scaled-dot-product-attention backend (process-wide).
        atten_name = "memory_efficient"
        torch.backends.cuda.enable_flash_sdp(atten_name == "flash")
        torch.backends.cuda.enable_mem_efficient_sdp(atten_name == "memory_efficient")
        torch.backends.cuda.enable_math_sdp(atten_name == "math")

        # construct input data
        data = GraphAttentionData(
            atomic_numbers=atomic_numbers,
            node_direction_expansion=node_direction_expansion,
            edge_distance_expansion=edge_distance_expansion,
            edge_direction=edge_direction,
            attn_mask=attn_mask_inf,
            angle_embedding=angle_embedding,
            neighbor_list=neighbor_list,
            neighbor_mask=neighbor_mask,
            node_batch=None,
            node_padding_mask=torch.ones_like(atomic_numbers, dtype=torch.bool),
            # Created on ``device`` like every other field; it used to be
            # allocated on the CPU regardless of where the model runs.
            graph_padding_mask=torch.ones(B, dtype=torch.bool, device=device),
        )

        node_fea, edge_fea = self.backbone(data)
        pred_energy = self.energy_head(data, node_fea, edge_fea)
        pred_force = self.force_head(data, node_fea, edge_fea)

        # Scatter the flattened predictions back into padded layout.
        node_attr = torch.zeros((B, L, 1), device=device)
        node_vec = torch.zeros((B, L, 3), device=device)
        node_attr[node_mask] = pred_energy
        node_vec[node_mask] = pred_force  # the part of order 0

        return node_attr, node_vec