import inspect
import numbers
from typing import Optional, Tuple

# import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init

# import torch_geometric.data
import torch_geometric.graphgym.register as register
from torch.nn.parameter import Parameter
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import GNNPreMP
from torch_geometric.graphgym.models.layer import (
    BatchNorm1dNode,
    new_layer_config,
)
from torch_geometric.graphgym.register import register_network

from graphgps.layer.gps_layer import GPSLayer
from graphgps.slt.custom_attn import CustomMultiheadAttention
from graphgps.slt.custom_generalmultilayer import CustomGNNPreMP  # added
from graphgps.slt.monarch_linear import MonarchLinear
from graphgps.slt.sparse_modules import (
    SLT_BondEncoder,
    SparseLinear,
    SparseLinearMulti_mask,
)
from graphgps.tome.utils import parse_r

# import time


def percentile(t, q):
    """Return the ``q``-th percentile of all elements of ``t`` as a float.

    The tensor is flattened and the ``k``-th smallest element is returned,
    where ``k = 1 + round(q / 100 * (numel - 1))`` (no interpolation
    between neighbouring values).

    Args:
        t (torch.Tensor): Tensor of any shape with at least one element.
        q (float): Percentile; 0 selects the minimum, 100 the maximum.

    Returns:
        float: The selected element as a Python scalar.
    """
    k = 1 + round(0.01 * float(q) * (t.numel() - 1))
    # reshape() instead of view(): view(-1) raises on non-contiguous input
    # (e.g. a transposed tensor), reshape copies only when it has to.
    return t.reshape(-1).kthvalue(k).values.item()


class FeatureEncoder(torch.nn.Module):
    """
    Encoding node and edge features

    Depending on ``cfg.dataset``, up to four sub-modules are registered, in
    this order: ``node_encoder``, ``node_encoder_bn``, ``edge_encoder``,
    ``edge_encoder_bn``. ``forward`` applies them in registration order.

    Args:
        dim_in (int): Input feature dimension
    """

    def __init__(self, dim_in):
        super().__init__()
        self.dim_in = dim_in

        if cfg.dataset.node_encoder:
            # Node encoder class is looked up by name in the registry.
            node_encoder_cls = register.node_encoder_dict[
                cfg.dataset.node_encoder_name
            ]
            self.node_encoder = node_encoder_cls(cfg.gnn.dim_inner)
            if cfg.dataset.node_encoder_bn:
                self.node_encoder_bn = self._make_batch_norm(
                    cfg.gnn.dim_inner
                )
            # Node features now live in the inner dimension.
            self.dim_in = cfg.gnn.dim_inner

        if cfg.dataset.edge_encoder:
            # Hard-limit max edge dim for PNA. Note that this writes the
            # chosen edge width back into the global config.
            cfg.gnn.dim_edge = (
                min(128, cfg.gnn.dim_inner)
                if "PNA" in cfg.gt.layer_type
                else cfg.gnn.dim_inner
            )
            # The registry lookup always happens, even when the SLT bond
            # encoder ends up being used instead.
            edge_encoder_cls = register.edge_encoder_dict[
                cfg.dataset.edge_encoder_name
            ]
            wants_slt_bond_encoder = (
                "ogbg" in cfg.dataset.name
                and cfg.slt.encoder is True
                and (cfg.slt.sm is True or cfg.slt.mm is True)
                and cfg.slt.embedding is True
            )
            if wants_slt_bond_encoder:
                edge_encoder_cls = SLT_BondEncoder
            self.edge_encoder = edge_encoder_cls(cfg.gnn.dim_edge)

            if cfg.dataset.edge_encoder_bn:
                self.edge_encoder_bn = self._make_batch_norm(cfg.gnn.dim_edge)

    @staticmethod
    def _make_batch_norm(dim):
        """Build the activation-free, bias-free batch norm of width ``dim``."""
        layer_cfg = new_layer_config(
            dim, -1, -1, has_act=False, has_bias=False, cfg=cfg
        )
        return BatchNorm1dNode(layer_cfg)

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Run every registered sub-module on ``batch`` and return the result.

        Modules whose class is named ``BondEncoder`` or ``AtomEncoder`` are
        always called with the batch only. Every other module additionally
        receives the epoch and threshold keyword arguments when
        ``cfg.slt.sm`` or ``cfg.slt.mm`` is ``True``.
        """
        schedule = dict(
            cur_epoch=cur_epoch,
            mpnn_th=mpnn_th,
            msa_th=msa_th,
            ffn_th=ffn_th,
            encoder_th=encoder_th,
            pred_th=pred_th,
            global_th=global_th,
        )
        for module in self.children():
            batch_only = type(module).__name__ in (
                "BondEncoder",
                "AtomEncoder",
            ) or not (cfg.slt.sm is True or cfg.slt.mm is True)
            if batch_only:
                batch = module(batch)
            else:
                batch = module(batch, **schedule)

        return batch


class GraphSequential(nn.Module):
    """Graph-level head: pool node features, then apply a stack of layers.

    The layers passed to the constructor are applied in order to the pooled
    graph embedding. ``nn.ReLU`` instances are called with the embedding
    only; every other layer is called as ``layer(embedding, threshold)``,
    where the threshold is chosen by ``cfg.slt.pruning``.

    Args:
        *args (nn.Module): Layers to apply after pooling, in order.
    """

    def __init__(self, *args):
        super().__init__()
        self.mlpmodules = nn.ModuleList(args)
        # Pooling function is looked up by name in the registry.
        self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling]

    def _apply_index(self, batch):
        """Return the ``(prediction, label)`` pair stored on ``batch``."""
        return batch.graph_feature, batch.y

    @staticmethod
    def _pruning_threshold(pred_th, global_th):
        """Pick the threshold that matches ``cfg.slt.pruning``.

        Raises:
            ValueError: If the pruning mode is neither ``"blockwise"`` nor
                ``"global"``.
        """
        mode = cfg.slt.pruning
        if mode == "blockwise":
            return pred_th
        if mode == "global":
            return global_th
        raise ValueError(
            f"Unsupported cfg.slt.pruning mode: {mode!r} "
            "(expected 'blockwise' or 'global')"
        )

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Pool ``batch.x`` per graph and run the layer stack on the result.

        Only ``pred_th`` and ``global_th`` are used here; the remaining
        keyword arguments are accepted so the head can be called with the
        same signature as the other modules of the model.

        Returns:
            tuple: ``(batch.graph_feature, batch.y)`` after storing the
            final embedding in ``batch.graph_feature``.

        Raises:
            ValueError: If a non-ReLU layer is present and
                ``cfg.slt.pruning`` is not a supported mode.
        """
        graph_emb = self.pooling_fun(batch.x, batch.batch)
        for module in self.mlpmodules:
            if isinstance(module, nn.ReLU):
                graph_emb = module(graph_emb)
                continue
            # Previously an unknown pruning mode silently skipped the layer,
            # so the head returned ReLU(pooled features); fail loudly instead.
            threshold = self._pruning_threshold(pred_th, global_th)
            graph_emb = module(graph_emb, threshold)

        batch.graph_feature = graph_emb
        pred, label = self._apply_index(batch)
        return pred, label


class RMSNorm(nn.Module):
    """Root-mean-square normalisation with an optional learnable scale.

    The input is divided by ``sqrt(mean(x ** 2) + eps)``, where the mean is
    taken over the trailing ``len(normalized_shape)`` dimensions. When
    ``elementwise_affine`` is set, the result is multiplied by a learnable
    ``weight`` of shape ``normalized_shape`` (initialised to ones).

    Args:
        normalized_shape (int or sequence of int): Shape of the trailing
            dimensions to normalise over; an int is treated as ``(int,)``.
        eps (float, optional): Value added inside the square root. Defaults
            to the float32 machine epsilon when ``None``.
        elementwise_affine (bool): Whether to create the learnable weight.
        device: Device for the weight tensor.
        dtype: Dtype for the weight tensor.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: Optional[float]
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape,
        eps: Optional[float] = None,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps if eps is not None else torch.finfo(torch.float32).eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            self.weight = Parameter(
                torch.empty(self.normalized_shape, device=device, dtype=dtype)
            )
        else:
            self.register_parameter("weight", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the learnable weight (if any) to ones."""
        if self.elementwise_affine:
            init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled to unit RMS over the normalised dimensions."""
        # No per-call device handling: ``weight`` is a registered parameter
        # and therefore follows the module through ``.to()`` / ``.cuda()``.
        # The old ``self.weight = self.weight.to()`` moved nothing and only
        # re-registered the parameter on every forward pass.
        #
        # Reduce over every dimension described by ``normalized_shape``
        # (falling back to the last one for an empty shape), not just the
        # final axis.
        reduce_dims = tuple(range(-len(self.normalized_shape), 0)) or (-1,)
        mean_square = torch.mean(x**2, dim=reduce_dims, keepdim=True)
        x = x / torch.sqrt(mean_square + self.eps)
        if self.elementwise_affine:
            x = x * self.weight
        return x

    def extra_repr(self) -> str:
        """Return the constructor settings shown in the module's repr."""
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )


class CustomSequential(torch.nn.Module):
    """Sequential container that forwards schedule kwargs where accepted.

    Sub-modules are registered under the names ``"0"``, ``"1"``, ... in the
    order given. On ``forward`` each child is called in turn; a child whose
    ``forward`` signature has a ``cur_epoch`` parameter also receives the
    epoch and all threshold keyword arguments, any other child is called
    with the running value only.
    """

    def __init__(self, *modules):
        super().__init__()
        for position, layer in enumerate(modules):
            self.add_module(str(position), layer)

    def forward(
        self,
        x,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Feed ``x`` through every child module and return the result."""
        schedule = dict(
            cur_epoch=cur_epoch,
            mpnn_th=mpnn_th,
            msa_th=msa_th,
            ffn_th=ffn_th,
            encoder_th=encoder_th,
            pred_th=pred_th,
            global_th=global_th,
        )
        for layer in self.children():
            accepted = inspect.signature(layer.forward).parameters
            # Only schedule-aware layers declare ``cur_epoch``.
            extra = schedule if "cur_epoch" in accepted else {}
            x = layer(x, **extra)
        return x


@register_network("GPSModel")
class GPSModel(torch.nn.Module):
    """General-Powerful-Scalable graph transformer.
    https://arxiv.org/abs/2205.12454
    Rampasek, L., Galkin, M., Dwivedi, V. P., Luu, A. T., Wolf, G., & Beaini, D.
    Recipe for a general, powerful, scalable graph transformer. (NeurIPS 2022)

    Child modules are registered in the order ``encoder``, ``pre_mp`` (only
    when ``cfg.gnn.layers_pre_mp > 0``), ``layers`` and ``post_mp``;
    ``forward`` runs them in exactly that order.

    Args:
        dim_in (int): Input feature dimension handed to the feature encoder.
        dim_out (int): Output dimension of the prediction head.

    Raises:
        ValueError: If the hidden / inner dimensions disagree, or if
            ``cfg.gt.layer_type`` is not of the form ``"<local>+<global>"``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        # "SLT mode": either of the two sparse variants is switched on.
        slt_enabled = cfg.slt.sm is True or cfg.slt.mm is True

        if cfg.gnn.layers_pre_mp > 0:
            pre_mp_cls = CustomGNNPreMP if slt_enabled else GNNPreMP
            self.pre_mp = pre_mp_cls(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp
            )
            dim_in = cfg.gnn.dim_inner

        if not cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in:
            raise ValueError(
                f"The inner and hidden dims must match: "
                f"embed_dim={cfg.gt.dim_hidden} dim_inner={cfg.gnn.dim_inner} "
                f"dim_in={dim_in}"
            )

        try:
            local_gnn_type, global_model_type = cfg.gt.layer_type.split("+")
        except ValueError as err:
            raise ValueError(
                f"Unexpected layer type: {cfg.gt.layer_type}"
            ) from err

        # With folding a single block is built; ``forward`` re-applies it
        # ``cfg.gt.layers`` times (in SLT mode only).
        # NOTE(review): with folding on but SLT mode off, the single block is
        # applied just once -- confirm that this combination is intended.
        num_blocks = 1 if cfg.slt.folding else cfg.gt.layers
        layers = [
            self._make_gps_layer(local_gnn_type, global_model_type)
            for _ in range(num_blocks)
        ]
        if slt_enabled:
            self.layers = CustomSequential(*layers)
        else:
            self.layers = torch.nn.Sequential(*layers)

        GNNHead = register.head_dict[cfg.gnn.head]
        custom_graph_head = (
            (slt_enabled or cfg.monarch.pred is True)
            and cfg.gnn.head == "graph"
            and cfg.slt.pred is True
        )
        if custom_graph_head:
            # Precedence: Monarch, then single-mask, then multi-mask. One of
            # the three is guaranteed by ``custom_graph_head``.
            if cfg.monarch.pred is True:
                linear_cls = MonarchLinear
            elif cfg.slt.sm is True:
                linear_cls = SparseLinear
            else:
                linear_cls = SparseLinearMulti_mask
            self.post_mp = GraphSequential(
                linear_cls(cfg.gnn.dim_inner, cfg.gnn.dim_inner, bias=False),
                nn.ReLU(),
                linear_cls(cfg.gnn.dim_inner, dim_out, bias=False),
            )
        else:
            self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    @staticmethod
    def _make_gps_layer(local_gnn_type, global_model_type):
        """Build one GPS block from the global config."""
        return GPSLayer(
            dim_h=cfg.gt.dim_hidden,
            local_gnn_type=local_gnn_type,
            global_model_type=global_model_type,
            num_heads=cfg.gt.n_heads,
            act=cfg.gnn.act,
            pna_degrees=cfg.gt.pna_degrees,
            equivstable_pe=cfg.posenc_EquivStableLapPE.enable,
            dropout=cfg.gt.dropout,
            attn_dropout=cfg.gt.attn_dropout,
            layer_norm=cfg.gt.layer_norm,
            batch_norm=cfg.gt.batch_norm,
            bigbird_cfg=cfg.gt.bigbird,
            log_attn_weights=cfg.train.mode == "log-attn-weights",
        )

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Run ``batch`` through all child modules in registration order.

        Dispatch rules per child:

        * A module whose class is named ``GNNGraphHead`` receives the epoch
          and thresholds only in SLT mode with ``cfg.slt.pred`` set;
          otherwise it is called with the batch only.
        * Any other module receives the epoch and thresholds whenever SLT
          mode is on, and the batch only when it is off.
        * In SLT mode with ``cfg.slt.folding`` set, the ``CustomSequential``
          stack is applied ``cfg.gt.layers`` times in a row.

        Returns:
            Whatever the last child module returns.
        """
        thresholds = dict(
            mpnn_th=mpnn_th,
            msa_th=msa_th,
            ffn_th=ffn_th,
            encoder_th=encoder_th,
            pred_th=pred_th,
            global_th=global_th,
        )
        slt_enabled = cfg.slt.sm is True or cfg.slt.mm is True

        for module in self.children():
            kind = type(module).__name__
            repeats = 1
            if kind == "GNNGraphHead":
                with_schedule = slt_enabled and cfg.slt.pred is True
            else:
                with_schedule = slt_enabled
                if (
                    slt_enabled
                    and kind == "CustomSequential"
                    and cfg.slt.folding
                ):
                    # Folded model: reuse the same block for every layer.
                    repeats = cfg.gt.layers

            for _ in range(repeats):
                if with_schedule:
                    batch = module(batch, cur_epoch, **thresholds)
                else:
                    batch = module(batch)

        return batch
