import numpy as np
import torch
import torch.nn as nn

# import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_node_encoder

from graphgps.slt.monarch_linear import MonarchLinear
from graphgps.slt.sparse_modules import (
    NMSparseMultiLinear,
    SparseLinear,
    SparseLinearMulti_mask,
)


def percentile(t, q):
    """Return the ``q``-th percentile of tensor ``t`` as a Python scalar.

    The percentile is taken as the k-th smallest element of the flattened
    tensor, with ``k = 1 + round(q / 100 * (numel - 1))``. Python's built-in
    ``round`` is used, so exact ``.5`` ties round to the nearest even rank.

    Args:
        t: Tensor to rank; it is flattened with ``view(-1)``.
        q: Percentile in the range 0..100 (anything ``float()`` accepts).

    Returns:
        The selected element, converted with ``.item()``.
    """
    fraction = 0.01 * float(q)
    rank = 1 + round(fraction * (t.numel() - 1))
    flat = t.view(-1)
    kth = flat.kthvalue(rank)
    return kth.values.item()


class KernelPENodeEncoder(torch.nn.Module):
    """Configurable kernel-based Positional Encoding node encoder.

    The choice of which kernel-based statistics to use is configurable through
    setting of `kernel_type`. Based on this, the appropriate config is selected,
    and also the appropriate variable with precomputed kernel stats is then
    selected from PyG Data graphs in `forward` function.
    E.g., supported are 'RWSE', 'HKdiagSE', 'ElstaticSE'.

    PE of size `dim_pe` will get appended to each node feature vector.
    If `expand_x` set True, original node features will be first linearly
    projected to (dim_emb - dim_pe) size and then concatenated with PE.

    When `cfg.slt.encoder` is enabled and the PE model is 'linear', the PE
    projection is one of the sparse / Monarch layers selected in `__init__`,
    and `forward` passes it a pruning threshold chosen by `cfg.slt.pruning`.

    Args:
        dim_emb: Size of final node embedding
        expand_x: Expand node features `x` from dim_in to (dim_emb - dim_pe)
    """

    kernel_type = None  # Instantiated type of the KernelPE, e.g. RWSE

    @staticmethod
    def _current_step(epoch, max_epoch_half):
        """Return the 1-based index (1..10) of the "step" schedule stage.

        The warm-up phase of `max_epoch_half` epochs is split into 10 stages.
        The stage length is kept at >= 1 epoch so that short runs
        (`max_epoch_half < 10`) do not divide by zero, and the index is capped
        at 10 so that the scheduled sparsity never exceeds its target when
        `max_epoch_half` is not a multiple of 10.
        """
        num_steps = 10
        step_size = max(1, max_epoch_half // num_steps)
        return min(num_steps, (epoch // step_size) + 1)

    def calculate_sparsity(
        self, sparsity, epoch, max_epoch_half, schedule_type
    ):
        """Calculate sparsity based on the scheduling type.

        Args:
            sparsity: Target sparsity (a single number).
            epoch: Current epoch.
            max_epoch_half: Length of the scheduling phase in epochs.
            schedule_type: One of "linear", "curve", "step", "reverse"; any
                other value returns `sparsity` unchanged.

        Returns:
            The scheduled sparsity for `epoch`.
        """
        if schedule_type == "linear":
            return sparsity * (epoch / max_epoch_half)
        elif schedule_type == "curve":
            curve_proportion = (max_epoch_half - epoch) / max_epoch_half
            scale_factor = np.sqrt(1 - curve_proportion**2)
            return sparsity * scale_factor
        elif schedule_type == "step":
            current_step = self._current_step(epoch, max_epoch_half)
            return (sparsity / 10) * current_step
        elif schedule_type == "reverse":
            # Starts fully sparse (1.0) and relaxes towards the target.
            return 1 - ((1 - sparsity) * (epoch / max_epoch_half))
        return sparsity

    def calculate_sparsities(
        self, sparsity, epoch, max_epoch_half, schedule_type
    ):
        """Calculate sparsities for multi-mode (mm) based on the scheduling type.

        Args:
            sparsity: Sequence of target sparsities, one per mask.
            epoch: Current epoch.
            max_epoch_half: Length of the scheduling phase in epochs.
            schedule_type: One of "linear", "curve", "step", "reverse"; any
                other value returns `sparsity` unchanged.

        Returns:
            A list with one scheduled sparsity per entry of `sparsity` (or
            `sparsity` itself for an unrecognized `schedule_type`).
        """
        if schedule_type == "linear":
            return [value * (epoch / max_epoch_half) for value in sparsity]
        elif schedule_type == "curve":
            curve_proportion = (max_epoch_half - epoch) / max_epoch_half
            scale_factor = np.sqrt(1 - curve_proportion**2)
            # Start points of every entry after the first, expressed relative
            # to the first entry.
            increments = [
                (sparsity[i] - sparsity[0]) / (1 - sparsity[0])
                for i in range(1, len(sparsity))
            ]
            sparsities = []
            for i, value in enumerate(sparsity):
                if i == 0:
                    sparsities.append(value * scale_factor)
                else:
                    start_point = increments[i - 1]
                    sparsities.append(
                        start_point + (value - start_point) * scale_factor
                    )
            return sparsities
        elif schedule_type == "step":
            current_step = self._current_step(epoch, max_epoch_half)
            # Same per-value formula as `calculate_sparsity`, so the last
            # stage reaches each target value.
            return [(value / 10) * current_step for value in sparsity]
        elif schedule_type == "reverse":
            return [
                1 - ((1 - value) * (epoch / max_epoch_half))
                for value in sparsity
            ]
        return sparsity

    def _score_values(self, keyword):
        """Return the flattened score parameters used for thresholding.

        Collects every parameter of this module that carries a truthy
        `is_score` attribute and whose name contains `keyword`. Absolute
        values are returned when `cfg.slt.enable_abs_pruning` is set.
        """
        scores = torch.cat(
            [
                p.detach().flatten()
                for name, p in self.named_parameters()
                if getattr(p, "is_score", False) and keyword in name
            ]
        )
        return scores.abs() if cfg.slt.enable_abs_pruning else scores

    def get_threshold(self, sparsity, epoch=None, keyword=""):
        """Compute pruning threshold(s) from this module's score parameters.

        While the module is in training mode and `epoch` lies in the first
        half of `cfg.optim.max_epoch`, the target sparsity is replaced by the
        value scheduled via `cfg.slt.sparse_scheduling`. If `epoch` is None no
        scheduling is applied and the target sparsity is used directly.

        Args:
            sparsity: Sequence of target sparsities. In multi-mask mode
                (`cfg.slt.mm`) every entry is used; in single-mask mode
                (`cfg.slt.sm`) only the first one.
            epoch: Current epoch, or None.
            keyword: Only score parameters whose name contains this string
                are considered.

        Returns:
            A list of thresholds (mm mode), a single threshold (sm mode), or
            None when neither mode is enabled.
        """
        max_epoch_half = cfg.optim.max_epoch // 2
        in_schedule = (
            self.training and epoch is not None and epoch < max_epoch_half
        )

        if cfg.slt.mm:
            sparsities = sparsity
            if in_schedule:
                sparsities = self.calculate_sparsities(
                    sparsity, epoch, max_epoch_half, cfg.slt.sparse_scheduling
                )
            if len(sparsities) == 0:
                return []
            # The scores do not depend on the sparsity level: gather once.
            scores = self._score_values(keyword)
            return [percentile(scores, value * 100) for value in sparsities]

        if cfg.slt.sm:
            sparsity_value = sparsity[0]
            if in_schedule:
                sparsity_value = self.calculate_sparsity(
                    sparsity[0],
                    epoch,
                    max_epoch_half,
                    cfg.slt.sparse_scheduling,
                )
            return percentile(
                self._score_values(keyword), sparsity_value * 100
            )

        return None

    def __init__(self, dim_emb, expand_x=True):
        """Build the raw-PE normalization, PE encoder and feature projection.

        Raises:
            ValueError: If `kernel_type` was not set by a subclass, if
                `dim_pe` exceeds `dim_emb`, if the configured PE model type is
                unsupported, or if `cfg.slt.encoder` is enabled for the
                'linear' model without any sparse layer variant selected.
        """
        super().__init__()
        if self.kernel_type is None:
            raise ValueError(
                f"{self.__class__.__name__} has to be "
                f"preconfigured by setting 'kernel_type' class "
                f"variable before calling the constructor."
            )

        dim_in = cfg.share.dim_in  # Expected original input node features dim

        pecfg = getattr(cfg, f"posenc_{self.kernel_type}")
        dim_pe = pecfg.dim_pe  # Size of the kernel-based PE embedding
        num_rw_steps = len(pecfg.kernel.times)
        model_type = pecfg.model.lower()  # Encoder NN model type for PEs
        n_layers = pecfg.layers  # Num. layers in PE encoder model
        # Raw PE normalization layer type
        norm_type = pecfg.raw_norm_type.lower()
        # Pass PE also as a separate variable
        self.pass_as_var = pecfg.pass_as_var

        # Formerly `< 1`, but a zero-size projected feature part is allowed.
        if dim_emb - dim_pe < 0:
            raise ValueError(
                f"PE dim size {dim_pe} is too large for "
                f"desired embedding size of {dim_emb}."
            )

        # Node features are only projected when there is room left for them.
        self.expand_x = expand_x and dim_emb - dim_pe > 0
        if self.expand_x:
            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe)

        if norm_type == "batchnorm":
            self.raw_norm = nn.BatchNorm1d(
                num_rw_steps, affine=cfg.slt.batch_affine
            )
        else:
            self.raw_norm = None

        # Activation is fixed to ReLU rather than looked up from cfg.gnn.act.
        activation = nn.ReLU
        if model_type == "mlp":
            layers = []
            if n_layers == 1:
                layers.append(nn.Linear(num_rw_steps, dim_pe))
                layers.append(activation())
            else:
                layers.append(nn.Linear(num_rw_steps, 2 * dim_pe))
                layers.append(activation())
                for _ in range(n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(activation())
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
                layers.append(activation())
            self.pe_encoder = nn.Sequential(*layers)
        elif model_type == "linear":
            if cfg.slt.encoder is True:
                # Priority order: Monarch, N:M sparse, single-mask, multi-mask.
                if cfg.monarch.encoder is True:
                    self.pe_encoder = MonarchLinear(
                        num_rw_steps, dim_pe, bias=False
                    )
                elif cfg.slt.srste is True:
                    self.pe_encoder = NMSparseMultiLinear(num_rw_steps, dim_pe)
                elif cfg.slt.sm is True:
                    self.pe_encoder = SparseLinear(
                        num_rw_steps, dim_pe, bias=False
                    )
                elif cfg.slt.mm is True:
                    self.pe_encoder = SparseLinearMulti_mask(
                        num_rw_steps, dim_pe, bias=False
                    )
                else:
                    # Fail at construction instead of with a missing
                    # `pe_encoder` attribute on the first forward pass.
                    raise ValueError(
                        f"{self.__class__.__name__}: 'cfg.slt.encoder' is "
                        f"enabled but none of 'cfg.monarch.encoder', "
                        f"'cfg.slt.srste', 'cfg.slt.sm' or 'cfg.slt.mm' "
                        f"is True."
                    )
            else:
                self.pe_encoder = nn.Linear(num_rw_steps, dim_pe)
        else:
            raise ValueError(
                f"{self.__class__.__name__}: Does not support "
                f"'{model_type}' encoder model."
            )

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Encode the precomputed kernel statistics and append them to `batch.x`.

        Args:
            batch: Graph batch providing `x` and `pestat_<kernel_type>`.
            cur_epoch: Current epoch, forwarded to `get_threshold` for
                layerwise pruning.
            encoder_th: Threshold passed to the PE encoder for blockwise
                pruning.
            global_th: Threshold passed to the PE encoder for global pruning.
            mpnn_th, msa_th, ffn_th, pred_th: Accepted but not used by this
                encoder.

        Returns:
            The same `batch` with `x` replaced by the concatenation of the
            (optionally projected) node features and the encoded PE, and with
            `pe_<kernel_type>` set when `pass_as_var` is enabled.

        Raises:
            ValueError: If the precomputed statistics are missing from
                `batch`, or if a thresholded encoder is in use and
                `cfg.slt.pruning` is not 'layerwise', 'blockwise' or 'global'.
        """
        pestat_var = f"pestat_{self.kernel_type}"
        if not hasattr(batch, pestat_var):
            raise ValueError(
                f"Precomputed '{pestat_var}' variable is "
                f"required for {self.__class__.__name__}; set "
                f"config 'posenc_{self.kernel_type}.enable' to "
                f"True, and also set 'posenc.kernel.times' values"
            )

        pos_enc = getattr(batch, pestat_var)  # (Num nodes) x (Num kernel times)
        if self.raw_norm is not None:
            pos_enc = self.raw_norm(pos_enc)

        uses_threshold = (
            cfg.slt.sm is True or cfg.slt.mm is True
        ) and cfg.slt.encoder is True
        if uses_threshold:
            pruning = cfg.slt.pruning
            if pruning == "layerwise":
                threshold = self.get_threshold(
                    cfg.slt.linear_sparsity, cur_epoch, "pe_encoder"
                )
            elif pruning == "blockwise":
                threshold = encoder_th
            elif pruning == "global":
                threshold = global_th
            else:
                # Skipping the encoder here would concatenate un-projected
                # statistics of the wrong width, so reject the setting.
                raise ValueError(
                    f"{self.__class__.__name__}: Does not support "
                    f"'{pruning}' pruning mode."
                )
            pos_enc = self.pe_encoder(pos_enc, threshold)
        else:
            pos_enc = self.pe_encoder(pos_enc)  # (Num nodes) x dim_pe

        # Expand node features if needed
        if self.expand_x:
            h = self.linear_x(batch.x)
        else:
            h = batch.x
        # Concatenate final PEs to input embedding
        batch.x = torch.cat((h, pos_enc), 1)
        # Keep PE also separate in a variable (e.g. for skip connections to input)
        if self.pass_as_var:
            setattr(batch, f"pe_{self.kernel_type}", pos_enc)
        return batch


@register_node_encoder("RWSE")
class RWSENodeEncoder(KernelPENodeEncoder):
    """Random Walk Structural Encoding node encoder.

    Passed to `register_node_encoder` under the key "RWSE". All behaviour is
    inherited from `KernelPENodeEncoder`; this subclass only pins
    `kernel_type`, which the base class uses to read the `posenc_RWSE` config
    section and the `pestat_RWSE` attribute of each batch.
    """

    # Kernel identifier consumed by KernelPENodeEncoder.
    kernel_type = "RWSE"


@register_node_encoder("HKdiagSE")
class HKdiagSENodeEncoder(KernelPENodeEncoder):
    """Heat kernel (diagonal) Structural Encoding node encoder.

    Passed to `register_node_encoder` under the key "HKdiagSE". All behaviour
    is inherited from `KernelPENodeEncoder`; this subclass only pins
    `kernel_type`, which the base class uses to read the `posenc_HKdiagSE`
    config section and the `pestat_HKdiagSE` attribute of each batch.
    """

    # Kernel identifier consumed by KernelPENodeEncoder.
    kernel_type = "HKdiagSE"


@register_node_encoder("ElstaticSE")
class ElstaticSENodeEncoder(KernelPENodeEncoder):
    """Electrostatic interactions Structural Encoding node encoder.

    Passed to `register_node_encoder` under the key "ElstaticSE". All
    behaviour is inherited from `KernelPENodeEncoder`; this subclass only pins
    `kernel_type`, which the base class uses to read the `posenc_ElstaticSE`
    config section and the `pestat_ElstaticSE` attribute of each batch.
    """

    # Kernel identifier consumed by KernelPENodeEncoder.
    kernel_type = "ElstaticSE"
