import numpy as np
import torch
import torch.nn as nn
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_head

from graphgps.slt.monarch_linear import MonarchLinear
from graphgps.slt.sparse_modules import (
    NMSparseMultiLinear,
    SparseLinear,
    SparseLinearMulti_mask,
)


def percentile(t, q):
    """Return the ``q``-th percentile of tensor ``t`` as a Python scalar.

    The value is picked by rank, without interpolation: the element at
    1-based rank ``1 + round(0.01 * q * (t.numel() - 1))`` in ascending
    order, so ``q=0`` yields the minimum and ``q=100`` the maximum.

    Args:
        t (torch.Tensor): Tensor of any shape; all elements are considered.
        q (float): Percentile in the range [0, 100].

    Returns:
        The selected element, converted with ``.item()``.
    """
    k = 1 + round(0.01 * float(q) * (t.numel() - 1))
    # reshape() flattens non-contiguous tensors too (view() would raise on
    # them) and is still a no-copy view when the input is contiguous.
    return t.reshape(-1).kthvalue(k).values.item()


@register_head("san_graph")
class SANGraphHead(nn.Module):
    """
    SAN prediction head for graph prediction tasks.

    Node features are pooled into one embedding per graph, which is then
    passed through ``L`` hidden layers (each halving the feature width,
    followed by the configured activation) and a final output layer.

    When ``cfg.slt.pred`` is True the layers are the sparse / structured
    variants selected by ``cfg.monarch.pred``, ``cfg.slt.srste``,
    ``cfg.slt.sm`` or ``cfg.slt.mm`` (all built without bias); otherwise
    plain ``nn.Linear`` layers with bias are used.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension. For binary prediction, dim_out=1.
        L (int): Number of hidden layers.
    """

    @staticmethod
    def _current_step(epoch, max_epoch_half, num_steps=10):
        """Return the 1-based index of the current level of the "step" schedule.

        The warm-up (``max_epoch_half`` epochs) is divided into ``num_steps``
        equally long levels. The result is clamped to ``num_steps`` so that
        a warm-up length that is not a multiple of ``num_steps`` can never
        push the scheduled sparsity past its target.
        """
        # A warm-up shorter than num_steps epochs would give a step size of
        # zero and a ZeroDivisionError below; use one epoch per step instead.
        step_size = max(1, max_epoch_half // num_steps)
        return min((epoch // step_size) + 1, num_steps)

    def calculate_sparsity(
        self, sparsity, epoch, max_epoch_half, schedule_type
    ):
        """Calculate the scheduled sparsity for a single target value.

        Supported ``schedule_type`` values:

        * ``"linear"``: ramps from 0 towards ``sparsity`` proportionally to
          ``epoch / max_epoch_half``.
        * ``"curve"``: ramps from 0 to ``sparsity`` along
          ``sqrt(1 - ((max_epoch_half - epoch) / max_epoch_half) ** 2)``.
        * ``"step"``: raises the sparsity in ten equal increments.
        * ``"reverse"``: starts at 1 and decreases linearly towards
          ``sparsity``.

        Any other value returns ``sparsity`` unchanged.

        Args:
            sparsity (float): Target sparsity.
            epoch (int): Current epoch.
            max_epoch_half (int): Number of epochs the schedule spans.
            schedule_type (str): Name of the schedule.

        Returns:
            float: The sparsity to use at ``epoch``.
        """
        if schedule_type == "linear":
            return sparsity * (epoch / max_epoch_half)
        if schedule_type == "curve":
            curve_proportion = (max_epoch_half - epoch) / max_epoch_half
            scale_factor = np.sqrt(1 - curve_proportion**2)
            return sparsity * scale_factor
        if schedule_type == "step":
            current_step = self._current_step(epoch, max_epoch_half)
            return (sparsity / 10) * current_step
        if schedule_type == "reverse":
            return 1 - ((1 - sparsity) * (epoch / max_epoch_half))
        return sparsity

    def calculate_sparsities(
        self, sparsity, epoch, max_epoch_half, schedule_type
    ):
        """Calculate scheduled sparsities for multi-mask (mm) mode.

        Applies the schedule to every entry of ``sparsity``. ``"linear"``,
        ``"step"`` and ``"reverse"`` treat each entry exactly like
        :meth:`calculate_sparsity`. ``"curve"`` scales the first entry from
        0, while every later entry ``i`` starts from
        ``(sparsity[i] - sparsity[0]) / (1 - sparsity[0])`` and moves towards
        ``sparsity[i]`` along the same curve.

        Args:
            sparsity (Sequence[float]): Target sparsities, one per mask.
            epoch (int): Current epoch.
            max_epoch_half (int): Number of epochs the schedule spans.
            schedule_type (str): Name of the schedule.

        Returns:
            list[float]: Scheduled sparsities; for an unknown
            ``schedule_type`` the input ``sparsity`` is returned unchanged.
        """
        if schedule_type == "linear":
            progress = epoch / max_epoch_half
            return [value * progress for value in sparsity]
        if schedule_type == "curve":
            curve_proportion = (max_epoch_half - epoch) / max_epoch_half
            scale_factor = np.sqrt(1 - curve_proportion**2)
            first = sparsity[0]
            sparsities = []
            for i, value in enumerate(sparsity):
                if i == 0:
                    sparsities.append(value * scale_factor)
                else:
                    start_point = (value - first) / (1 - first)
                    sparsities.append(
                        start_point + (value - start_point) * scale_factor
                    )
            return sparsities
        if schedule_type == "step":
            current_step = self._current_step(epoch, max_epoch_half)
            # Same rule as calculate_sparsity: a tenth of the target per
            # step. (Previously the target was multiplied in twice, so the
            # schedule ended at value**2 instead of value.)
            return [(value / 10) * current_step for value in sparsity]
        if schedule_type == "reverse":
            progress = epoch / max_epoch_half
            return [1 - ((1 - value) * progress) for value in sparsity]
        return sparsity

    def _collect_scores(self, keyword=""):
        """Return all score values of this head as one flat, detached tensor.

        A parameter counts as a score when it carries a truthy ``is_score``
        attribute and ``keyword`` occurs in its name. Absolute values are
        returned when ``cfg.slt.enable_abs_pruning`` is set.

        Raises:
            RuntimeError: If no matching score parameter exists.
        """
        chunks = [
            param.detach().flatten()
            for name, param in self.named_parameters()
            if getattr(param, "is_score", False) and keyword in name
        ]
        if not chunks:
            raise RuntimeError(
                f"No score parameters matching keyword {keyword!r} found "
                "in SANGraphHead; cannot compute a pruning threshold."
            )
        scores = torch.cat(chunks)
        return scores.abs() if cfg.slt.enable_abs_pruning else scores

    def get_threshold(self, sparsity, epoch=None, keyword=""):
        """Compute pruning threshold(s) from this head's score parameters.

        While the module is in training mode and ``epoch`` lies in the first
        half of ``cfg.optim.max_epoch``, the target sparsity is ramped
        according to ``cfg.slt.sparse_scheduling``; otherwise the target is
        used as given. Each threshold is the score percentile that
        corresponds to the (scheduled) sparsity.

        Args:
            sparsity (Sequence[float]): Target sparsities. Multi-mask mode
                (``cfg.slt.mm``) uses every entry; single-mask mode
                (``cfg.slt.sm``) uses only ``sparsity[0]``.
            epoch (int | None): Current epoch. ``None`` disables the ramp.
            keyword (str): Only parameters whose name contains this
                substring are considered.

        Returns:
            A list of thresholds in multi-mask mode, a single threshold in
            single-mask mode, and ``None`` when neither mode is enabled.
        """
        max_epoch_half = cfg.optim.max_epoch // 2
        # Without a known epoch there is nothing to schedule against, so
        # fall back to the final target instead of failing on `None < int`.
        ramping = (
            self.training and epoch is not None and epoch < max_epoch_half
        )

        if cfg.slt.mm:
            targets = sparsity
            if ramping:
                targets = self.calculate_sparsities(
                    sparsity, epoch, max_epoch_half, cfg.slt.sparse_scheduling
                )
            if not targets:
                return []
            # The scores do not change between percentiles: gather them once.
            scores = self._collect_scores(keyword)
            return [percentile(scores, target * 100) for target in targets]

        if cfg.slt.sm:
            target = sparsity[0]
            if ramping:
                target = self.calculate_sparsity(
                    target, epoch, max_epoch_half, cfg.slt.sparse_scheduling
                )
            return percentile(self._collect_scores(keyword), target * 100)

        return None

    def __init__(self, dim_in, dim_out, L=2):
        super().__init__()
        self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling]

        if cfg.slt.pred is True:
            if cfg.monarch.pred:
                layer_class = MonarchLinear
            elif cfg.slt.srste:
                layer_class = NMSparseMultiLinear
            elif cfg.slt.sm:
                layer_class = SparseLinear
            elif cfg.slt.mm:
                layer_class = SparseLinearMulti_mask
            else:
                # Fail with a clear message rather than an UnboundLocalError.
                raise ValueError(
                    "cfg.slt.pred is True but none of cfg.monarch.pred, "
                    "cfg.slt.srste, cfg.slt.sm or cfg.slt.mm is enabled."
                )
            use_bias = False
        else:
            layer_class = nn.Linear
            use_bias = True

        # Feature width after 0, 1, ..., L halvings.
        widths = [dim_in // 2**layer for layer in range(L + 1)]
        list_FC_layers = [
            layer_class(width_in, width_out, bias=use_bias)
            for width_in, width_out in zip(widths, widths[1:])
        ]
        list_FC_layers.append(layer_class(widths[L], dim_out, bias=use_bias))

        self.FC_layers = nn.ModuleList(list_FC_layers)
        self.L = L
        self.activation = register.act_dict[cfg.gnn.act]()

    def _apply_index(self, batch):
        """Return the ``(prediction, label)`` pair stored on ``batch``."""
        return batch.graph_feature, batch.y

    def _run_layers(self, graph_emb, *layer_args):
        """Apply the hidden layers (with activation) and the output layer.

        ``layer_args`` is forwarded unchanged to every layer call; the
        sparse layers receive their pruning threshold this way.
        """
        for layer in range(self.L):
            graph_emb = self.FC_layers[layer](graph_emb, *layer_args)
            graph_emb = self.activation(graph_emb)
        return self.FC_layers[self.L](graph_emb, *layer_args)

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Pool the node features and predict one output per graph.

        Args:
            batch: Batch object providing ``x``, ``batch`` and ``y``.
            cur_epoch (int | None): Current epoch, used to schedule the
                sparsity for layerwise pruning.
            mpnn_th, msa_th, ffn_th, encoder_th: Unused by this head.
            pred_th: Threshold(s) for the head's layers under blockwise
                pruning. Recomputed locally under layerwise pruning.
            global_th: Threshold(s) used under global pruning.

        Returns:
            tuple: ``(pred, label)``; ``pred`` is also stored in
            ``batch.graph_feature``.
        """
        graph_emb = self.pooling_fun(batch.x, batch.batch)

        if cfg.slt.pred is True:
            pruning = cfg.slt.pruning
            if pruning == "layerwise":
                pred_th = self.get_threshold(
                    cfg.slt.linear_sparsity, cur_epoch
                )
            # NOTE(review): with cfg.monarch.pred / cfg.slt.srste (neither
            # sm nor mm) no layer is applied here, as in the original code;
            # confirm the intended call signature of those layers.
            if cfg.slt.sm is True or cfg.slt.mm is True:
                if pruning in ("blockwise", "layerwise"):
                    # Layerwise previously computed pred_th but then skipped
                    # the layers entirely; both modes prune with pred_th.
                    graph_emb = self._run_layers(graph_emb, pred_th)
                elif pruning == "global":
                    graph_emb = self._run_layers(graph_emb, global_th)
        else:
            graph_emb = self._run_layers(graph_emb)

        batch.graph_feature = graph_emb
        pred, label = self._apply_index(batch)
        return pred, label
