from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.graphgym.register as register
import torch_geometric.nn as pyg_nn
from numpy.random import default_rng
from torch import Tensor
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.layer import LayerConfig
from torch_geometric.graphgym.register import register_layer
from torch_geometric.nn.conv import GINEConv

# from torch_sparse import SparseTensor, matmul
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from torch_scatter import scatter

from graphgps.slt.monarch_linear import MonarchLinear
from graphgps.slt.sparse_modules import (  # 追加
    BitLinear,
    NMSparseMultiLinear,
    SparseLinear,
    SparseLinearMulti_mask,
)


def percentile(t, q):
    """Return the element of ``t`` sitting at percentile ``q`` (0-100).

    ``t`` is flattened and its ``k``-th smallest value is returned, where
    ``k - 1`` is ``q`` percent of the way from the first to the last rank
    (``q=0`` -> smallest element, ``q=100`` -> largest). Python's ``round``
    is used, so exact halves round to the nearest even offset.
    """
    offset = round(0.01 * float(q) * (t.numel() - 1))
    kth_smallest = t.view(-1).kthvalue(offset + 1)
    return kth_smallest.values.item()


def prune_edge_index(edge_index, prune_ratio, seed=None):
    """
    Randomly drop a fraction of the edges (columns) of ``edge_index``.

    Args:
    - edge_index (Tensor): Edge index tensor; edges are laid out along
      dimension 1 (``edge_index.size(1)`` is the number of edges).
    - prune_ratio (float): Fraction of edges to drop, in ``[0, 1]``. The
      number dropped is ``int(num_edges * prune_ratio)`` (rounded down).
    - seed (int, optional): Seed for the random generator. Defaults to
      ``cfg.seed``. With a fixed seed, the same column positions are dropped
      on every call for a given number of edges.

    Returns:
    - Tensor: ``edge_index`` restricted to the kept columns, in their
      original order.

    Raises:
    - ValueError: If ``prune_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= prune_ratio <= 1.0:
        raise ValueError(
            f"prune_ratio must be in [0, 1], got {prune_ratio}"
        )
    if seed is None:
        seed = cfg.seed

    num_edges = edge_index.size(1)
    num_edges_to_prune = int(num_edges * prune_ratio)

    # A fresh local generator keeps the selection reproducible without
    # touching numpy's global RNG state.
    rng = default_rng(seed)
    indices_to_prune = rng.choice(num_edges, num_edges_to_prune, replace=False)

    # Build the keep-mask on the same device as edge_index so that boolean
    # indexing does not mix CPU and accelerator tensors.
    mask = torch.ones(num_edges, dtype=torch.bool, device=edge_index.device)
    mask[
        torch.as_tensor(
            indices_to_prune, dtype=torch.long, device=edge_index.device
        )
    ] = False

    return edge_index[:, mask]


class Custom_GatedGCNLayer(pyg_nn.conv.MessagePassing):
    """
    GatedGCN layer
    Residual Gated Graph ConvNets
    https://arxiv.org/pdf/1711.07553.pdf

    The five projections ``A``..``E`` are built from a layer class selected
    by the ``cfg.slt`` / ``cfg.monarch`` flags. Each is called as
    ``layer(input, threshold)``, where the threshold comes from
    :meth:`get_threshold` (layerwise pruning) or from the caller.
    """

    def calculate_sparsity(
        self, sparsity, epoch, max_epoch_half, schedule_type
    ):
        """Return the annealed sparsity for ``epoch`` (single-mask mode).

        Args:
            sparsity: Target sparsity.
            epoch: Current epoch. ``get_threshold`` only calls this while
                ``epoch < max_epoch_half``.
            max_epoch_half: Length of the annealing phase in epochs.
            schedule_type: ``"linear"``, ``"curve"``, ``"step"`` or
                ``"reverse"``. Any other value returns ``sparsity`` as is.
        """
        if schedule_type == "linear":
            return sparsity * (epoch / max_epoch_half)
        elif schedule_type == "curve":
            curve_proportion = (max_epoch_half - epoch) / max_epoch_half
            scale_factor = np.sqrt(1 - curve_proportion**2)
            return sparsity * scale_factor
        elif schedule_type == "step":
            current_step = self._current_step(epoch, max_epoch_half)
            return (sparsity / 10) * current_step
        elif schedule_type == "reverse":
            return 1 - ((1 - sparsity) * (epoch / max_epoch_half))
        return sparsity

    def calculate_sparsities(
        self, sparsity, epoch, max_epoch_half, schedule_type
    ):
        """Return annealed sparsities for multi-mask mode (``cfg.slt.mm``).

        Same schedules as :meth:`calculate_sparsity`, applied to every entry
        of the ``sparsity`` sequence. Returns a list, or ``sparsity``
        unchanged for an unknown ``schedule_type``.
        """
        if schedule_type == "linear":
            return [value * (epoch / max_epoch_half) for value in sparsity]
        elif schedule_type == "curve":
            curve_proportion = (max_epoch_half - epoch) / max_epoch_half
            scale_factor = np.sqrt(1 - curve_proportion**2)
            increments = [
                (sparsity[i] - sparsity[0]) / (1 - sparsity[0])
                for i in range(1, len(sparsity))
            ]
            sparsities = []
            for i, value in enumerate(sparsity):
                if i == 0:
                    sparsities.append(value * scale_factor)
                else:
                    start_point = increments[i - 1]
                    sparsities.append(
                        start_point + (value - start_point) * scale_factor
                    )
            return sparsities
        elif schedule_type == "step":
            current_step = self._current_step(epoch, max_epoch_half)
            # Same ramp as the single-mask "step" schedule. Previously each
            # value was multiplied by itself as well, so it never reached
            # its target.
            return [(value / 10) * current_step for value in sparsity]
        elif schedule_type == "reverse":
            return [
                1 - ((1 - value) * (epoch / max_epoch_half))
                for value in sparsity
            ]
        return sparsity

    @staticmethod
    def _current_step(epoch, max_epoch_half):
        """Return the 1-based index (1..10) of the "step" schedule."""
        # max(1, ...) avoids a zero step size when max_epoch_half < 10.
        # min(10, ...) stops the ramp from overshooting the target when
        # max_epoch_half is not a multiple of 10.
        step_size = max(1, max_epoch_half // 10)
        return min(10, (epoch // step_size) + 1)

    def _score_values(self, keyword):
        """Return all score parameters whose name contains ``keyword``.

        The parameters are detached, flattened and concatenated into one 1-D
        tensor. Absolute values are used when ``cfg.slt.enable_abs_pruning``
        is set.
        """
        scores = torch.cat(
            [
                p.detach().flatten()
                for name, p in self.named_parameters()
                if getattr(p, "is_score", False) and keyword in name
            ]
        )
        return scores.abs() if cfg.slt.enable_abs_pruning else scores

    def get_threshold(self, sparsity, epoch=None, keyword=""):
        """Compute pruning threshold(s) from this layer's score parameters.

        Args:
            sparsity: Sequence of target sparsities in ``[0, 1]``. In
                multi-mask mode (``cfg.slt.mm``) each entry yields one
                threshold. In single-mask mode (``cfg.slt.sm``) only
                ``sparsity[0]`` is used.
            epoch: Current epoch. While training and before
                ``cfg.optim.max_epoch // 2``, the sparsities are annealed
                with ``cfg.slt.sparse_scheduling``. ``None`` disables
                annealing.
            keyword: Only score parameters whose name contains this
                substring are considered.

        Returns:
            A list of thresholds (mm), a single float (sm), or ``None`` when
            neither mode is enabled.
        """
        max_epoch_half = cfg.optim.max_epoch // 2
        scheduled = (
            self.training and epoch is not None and epoch < max_epoch_half
        )

        if cfg.slt.mm:
            sparsities = sparsity
            if scheduled:
                sparsities = self.calculate_sparsities(
                    sparsity, epoch, max_epoch_half, cfg.slt.sparse_scheduling
                )
            # The scores do not depend on the sparsity, so gather them once.
            scores = self._score_values(keyword)
            return [percentile(scores, value * 100) for value in sparsities]

        if cfg.slt.sm:
            sparsity_value = sparsity[0]
            if scheduled:
                sparsity_value = self.calculate_sparsity(
                    sparsity[0],
                    epoch,
                    max_epoch_half,
                    cfg.slt.sparse_scheduling,
                )
            return percentile(
                self._score_values(keyword), sparsity_value * 100
            )

        return None

    def __init__(
        self,
        in_dim,
        out_dim,
        dropout,
        residual,
        act="relu",
        equivstable_pe=False,
        **kwargs
    ):
        """Build the five projections, the optional PE MLP and norms.

        Raises:
            ValueError: If none of ``cfg.slt.bitlinear``,
                ``cfg.monarch.mpnn``, ``cfg.slt.srste``, ``cfg.slt.sm`` or
                ``cfg.slt.mm`` is enabled.
        """
        super().__init__(**kwargs)
        self.activation = register.act_dict[act]
        if cfg.slt.bitlinear:
            layer_class = BitLinear
        elif cfg.monarch.mpnn:
            layer_class = MonarchLinear
        elif cfg.slt.srste:
            layer_class = NMSparseMultiLinear
        elif cfg.slt.sm:
            layer_class = SparseLinear
        elif cfg.slt.mm:
            layer_class = SparseLinearMulti_mask
        else:
            # forward() calls every projection as layer(x, threshold), so a
            # plain nn.Linear is not a usable fallback.
            raise ValueError(
                "Custom_GatedGCNLayer requires one of cfg.slt.bitlinear, "
                "cfg.monarch.mpnn, cfg.slt.srste, cfg.slt.sm or cfg.slt.mm"
            )

        common_args = (in_dim, out_dim)
        extra_kwargs = {}
        if layer_class in (SparseLinearMulti_mask, SparseLinear):
            extra_kwargs = {
                "gain": "linear",
                "init_mode_weight": "signed_kaiming_uniform_constant",
            }

        self.A = layer_class(*common_args, bias=False, **extra_kwargs)
        self.B = layer_class(*common_args, bias=False, **extra_kwargs)
        self.C = layer_class(*common_args, bias=False, **extra_kwargs)
        self.D = layer_class(*common_args, bias=False, **extra_kwargs)
        self.E = layer_class(*common_args, bias=False, **extra_kwargs)

        # Handling for Equivariant and Stable PE using LapPE
        # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
        self.EquivStablePE = equivstable_pe
        if self.EquivStablePE:
            self.mlp_r_ij = nn.Sequential(
                nn.Linear(1, out_dim),
                self.activation(),
                nn.Linear(out_dim, 1),
                nn.Sigmoid(),
            )

        self.bn_node_x = nn.BatchNorm1d(out_dim, affine=cfg.slt.batch_affine)
        self.bn_edge_e = nn.BatchNorm1d(out_dim, affine=cfg.slt.batch_affine)
        self.act_fn_x = self.activation()
        self.act_fn_e = self.activation()
        self.dropout = dropout
        self.residual = residual
        # Edge features produced by message(), handed over to update().
        self.e = None

    def forward(self, batch, cur_epoch=None, mpnn_threshold=None):
        """Run one GatedGCN step and write the results back into ``batch``.

        x               : [n_nodes, in_dim]
        e               : [n_edges, in_dim]
        edge_index      : [2, n_edges]

        Args:
            batch: Graph batch providing ``x``, ``edge_attr``, ``edge_index``
                and, with EquivStablePE, ``pe_EquivStableLapPE``.
            cur_epoch: Current epoch, used for sparsity annealing when
                ``cfg.slt.pruning == "layerwise"``.
            mpnn_threshold: Threshold passed to the projections when pruning
                is not layerwise.

        Returns:
            ``batch`` with ``x`` and ``edge_attr`` replaced.
        """
        x, e, edge_index = batch.x, batch.edge_attr, batch.edge_index

        if self.residual:
            x_in = x
            e_in = e

        if cfg.slt.pruning == "layerwise":
            threshold = self.get_threshold(
                cfg.slt.linear_sparsity, cur_epoch, ""
            )
        else:
            threshold = mpnn_threshold

        Ax = self.A(x, threshold)
        Bx = self.B(x, threshold)
        Ce = self.C(e, threshold)
        Dx = self.D(x, threshold)
        Ex = self.E(x, threshold)

        # Handling for Equivariant and Stable PE using LapPE
        # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
        pe_LapPE = batch.pe_EquivStableLapPE if self.EquivStablePE else None

        if cfg.slt.adj_rand_pruning is True:
            # NOTE(review): only edge_index is pruned here. Ce (and e) keep
            # one row per original edge, so message() will add tensors with
            # different edge counts whenever at least one edge is dropped.
            # TODO confirm intended semantics.
            edge_index = prune_edge_index(edge_index, cfg.slt.adj_pruning_rate)

        x, e = self.propagate(
            edge_index, Bx=Bx, Dx=Dx, Ex=Ex, Ce=Ce, e=e, Ax=Ax, PE=pe_LapPE
        )

        x = self.bn_node_x(x)
        e = self.bn_edge_e(e)

        x = self.act_fn_x(x)
        e = self.act_fn_e(e)

        x = F.dropout(x, self.dropout, training=self.training)
        e = F.dropout(e, self.dropout, training=self.training)

        if self.residual:
            x = x_in + x
            e = e_in + e

        batch.x = x
        batch.edge_attr = e

        return batch

    def message(self, Dx_i, Ex_j, PE_i, PE_j, Ce):
        """Compute per-edge gates and stash the new edge features.

        {}x_i           : [n_edges, out_dim]
        {}x_j           : [n_edges, out_dim]
        {}e             : [n_edges, out_dim]
        """
        e_ij = Dx_i + Ex_j + Ce
        sigma_ij = torch.sigmoid(e_ij)

        # Handling for Equivariant and Stable PE using LapPE
        # ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
        if self.EquivStablePE:
            r_ij = ((PE_i - PE_j) ** 2).sum(dim=-1, keepdim=True)
            # The MLP maps 1 dim --> hidden_dim --> 1 dim.
            r_ij = self.mlp_r_ij(r_ij)
            sigma_ij = sigma_ij * r_ij

        # Kept for update(), which returns it as the layer's edge output.
        self.e = e_ij
        return sigma_ij

    def aggregate(self, sigma_ij, index, Bx_j, Bx):
        """Gate-weighted mean of neighbour features per target node.

        sigma_ij        : [n_edges, out_dim]  ; output of message()
        index           : [n_edges]
        {}x_j           : [n_edges, out_dim]
        """
        # Bx holds one row per node, so sizing the scatter by it yields one
        # output row per node, including nodes that receive no edges.
        dim_size = Bx.shape[0]

        sum_sigma_x = sigma_ij * Bx_j
        numerator_eta_xj = scatter(
            sum_sigma_x, index, 0, None, dim_size, reduce="sum"
        )

        sum_sigma = sigma_ij
        denominator_eta_xj = scatter(
            sum_sigma, index, 0, None, dim_size, reduce="sum"
        )

        # The epsilon avoids dividing by zero for nodes without edges.
        out = numerator_eta_xj / (denominator_eta_xj + 1e-6)
        return out

    def update(self, aggr_out, Ax):
        """Combine the node update with the edge features from message().

        aggr_out        : [n_nodes, out_dim] ; output of aggregate()
        {}x             : [n_nodes, out_dim]
        """
        x = Ax + aggr_out
        e_out = self.e
        # Drop the reference so the tensor can be freed, but keep the
        # attribute defined as it is after __init__.
        self.e = None
        return x, e_out


@register_layer("slt_gatedgcnconv")
class SLT_GatedGCNGraphGymLayer(nn.Module):
    """GraphGym wrapper around ``Custom_GatedGCNLayer``.

    Residual Gated Graph ConvNets
    https://arxiv.org/pdf/1711.07553.pdf
    """

    def __init__(self, layer_config: LayerConfig, **kwargs):
        super().__init__()
        self.model = Custom_GatedGCNLayer(
            in_dim=layer_config.dim_in,
            out_dim=layer_config.dim_out,
            dropout=0.0,  # Dropout is handled by GraphGym's `GeneralLayer` wrapper
            # Residual connections are handled by GraphGym's `GNNStackStage` wrapper
            residual=False,
            act=layer_config.act,
            **kwargs,
        )

    def forward(self, batch, cur_epoch=None, mpnn_threshold=None):
        """Apply the wrapped layer to ``batch``.

        ``cur_epoch`` and ``mpnn_threshold`` are passed through so the
        wrapped layer can anneal its sparsity during layerwise pruning, or
        use an externally computed threshold otherwise.
        """
        return self.model(
            batch, cur_epoch=cur_epoch, mpnn_threshold=mpnn_threshold
        )


# for ZINC dataset
class SLT_GINEConv(GINEConv):
    """GINEConv whose MLP layers receive a pruning threshold.

    :meth:`forward` calls ``self.nn[0](out, threshold)``, ``self.nn[1](out)``
    and ``self.nn[2](out, threshold)``. ``nn`` is therefore presumably a
    3-element sequential of (sparse linear, activation, sparse linear).
    TODO confirm against the model builder.
    """

    @staticmethod
    def percentile(t, q):
        """Return the element of ``t`` at percentile ``q`` (0-100).

        ``t`` is flattened and its ``k``-th smallest value is returned, with
        ``k = 1 + round(q / 100 * (numel - 1))``.
        """
        k = 1 + round(0.01 * float(q) * (t.numel() - 1))
        return t.view(-1).kthvalue(k).values.item()

    def calculate_sparsity(
        self, sparsity, epoch, max_epoch_half, schedule_type
    ):
        """Return the annealed sparsity for ``epoch`` (single-mask mode).

        Args:
            sparsity: Target sparsity.
            epoch: Current epoch. ``get_threshold`` only calls this while
                ``epoch < max_epoch_half``.
            max_epoch_half: Length of the annealing phase in epochs.
            schedule_type: ``"linear"``, ``"curve"``, ``"step"`` or
                ``"reverse"``. Any other value returns ``sparsity`` as is.
        """
        if schedule_type == "linear":
            return sparsity * (epoch / max_epoch_half)
        elif schedule_type == "curve":
            curve_proportion = (max_epoch_half - epoch) / max_epoch_half
            scale_factor = np.sqrt(1 - curve_proportion**2)
            return sparsity * scale_factor
        elif schedule_type == "step":
            current_step = self._current_step(epoch, max_epoch_half)
            return (sparsity / 10) * current_step
        elif schedule_type == "reverse":
            return 1 - ((1 - sparsity) * (epoch / max_epoch_half))
        return sparsity

    def calculate_sparsities(
        self, sparsity, epoch, max_epoch_half, schedule_type
    ):
        """Return annealed sparsities for multi-mask mode (``cfg.slt.mm``).

        Same schedules as :meth:`calculate_sparsity`, applied to every entry
        of the ``sparsity`` sequence. Returns a list, or ``sparsity``
        unchanged for an unknown ``schedule_type``.
        """
        if schedule_type == "linear":
            return [value * (epoch / max_epoch_half) for value in sparsity]
        elif schedule_type == "curve":
            curve_proportion = (max_epoch_half - epoch) / max_epoch_half
            scale_factor = np.sqrt(1 - curve_proportion**2)
            increments = [
                (sparsity[i] - sparsity[0]) / (1 - sparsity[0])
                for i in range(1, len(sparsity))
            ]
            sparsities = []
            for i, value in enumerate(sparsity):
                if i == 0:
                    sparsities.append(value * scale_factor)
                else:
                    start_point = increments[i - 1]
                    sparsities.append(
                        start_point + (value - start_point) * scale_factor
                    )
            return sparsities
        elif schedule_type == "step":
            current_step = self._current_step(epoch, max_epoch_half)
            # Same ramp as the single-mask "step" schedule. Previously each
            # value was multiplied by itself as well, so it never reached
            # its target.
            return [(value / 10) * current_step for value in sparsity]
        elif schedule_type == "reverse":
            return [
                1 - ((1 - value) * (epoch / max_epoch_half))
                for value in sparsity
            ]
        return sparsity

    @staticmethod
    def _current_step(epoch, max_epoch_half):
        """Return the 1-based index (1..10) of the "step" schedule."""
        # max(1, ...) avoids a zero step size when max_epoch_half < 10.
        # min(10, ...) stops the ramp from overshooting the target when
        # max_epoch_half is not a multiple of 10.
        step_size = max(1, max_epoch_half // 10)
        return min(10, (epoch // step_size) + 1)

    def _score_values(self, keyword):
        """Return all score parameters whose name contains ``keyword``.

        The parameters are detached, flattened and concatenated into one 1-D
        tensor. Absolute values are used when ``cfg.slt.enable_abs_pruning``
        is set.
        """
        scores = torch.cat(
            [
                p.detach().flatten()
                for name, p in self.named_parameters()
                if getattr(p, "is_score", False) and keyword in name
            ]
        )
        return scores.abs() if cfg.slt.enable_abs_pruning else scores

    def get_threshold(self, sparsity, epoch=None, keyword=""):
        """Compute pruning threshold(s) from this layer's score parameters.

        Args:
            sparsity: Sequence of target sparsities in ``[0, 1]``. In
                multi-mask mode (``cfg.slt.mm``) each entry yields one
                threshold. In single-mask mode (``cfg.slt.sm``) only
                ``sparsity[0]`` is used.
            epoch: Current epoch. While training and before
                ``cfg.optim.max_epoch // 2``, the sparsities are annealed
                with ``cfg.slt.sparse_scheduling``. ``None`` disables
                annealing.
            keyword: Only score parameters whose name contains this
                substring are considered.

        Returns:
            A list of thresholds (mm), a single float (sm), or ``None`` when
            neither mode is enabled.
        """
        max_epoch_half = cfg.optim.max_epoch // 2
        scheduled = (
            self.training and epoch is not None and epoch < max_epoch_half
        )

        if cfg.slt.mm:
            sparsities = sparsity
            if scheduled:
                sparsities = self.calculate_sparsities(
                    sparsity, epoch, max_epoch_half, cfg.slt.sparse_scheduling
                )
            # The scores do not depend on the sparsity, so gather them once.
            scores = self._score_values(keyword)
            return [
                self.percentile(scores, value * 100) for value in sparsities
            ]

        if cfg.slt.sm:
            sparsity_value = sparsity[0]
            if scheduled:
                sparsity_value = self.calculate_sparsity(
                    sparsity[0],
                    epoch,
                    max_epoch_half,
                    cfg.slt.sparse_scheduling,
                )
            return self.percentile(
                self._score_values(keyword), sparsity_value * 100
            )

        return None

    def __init__(self, nn, eps=0.0, train_eps=False, edge_dim=None, **kwargs):
        super().__init__(nn, eps, train_eps, edge_dim, **kwargs)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        cur_epoch=None,
        size: Size = None,
        mpnn_threshold=None,
    ) -> Tensor:
        """GINE aggregation followed by the thresholded MLP ``self.nn``.

        Args:
            x: Node features, or a (source, target) pair of them.
            edge_index: Graph connectivity.
            edge_attr: Edge features passed to the GINE message.
            cur_epoch: Current epoch, used for sparsity annealing when
                ``cfg.slt.pruning == "layerwise"``.
            size: Optional bipartite size forwarded to ``propagate``.
            mpnn_threshold: Threshold given to the MLP layers when pruning is
                not layerwise. Previously ``threshold`` was left undefined in
                that case.
        """
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

        x_r = x[1]
        if x_r is not None:
            out = out + (1 + self.eps) * x_r

        if cfg.slt.pruning == "layerwise":
            threshold = self.get_threshold(
                cfg.slt.linear_sparsity, cur_epoch, "nn"
            )
        else:
            threshold = mpnn_threshold

        out = self.nn[0](out, threshold)
        out = self.nn[1](out)
        out = self.nn[2](out, threshold)

        return out
