import numbers
import warnings
from typing import Optional, Tuple

import numpy as np
import opt_einsum as oe
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch_geometric as pyg
from torch.nn.parameter import Parameter
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import *
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_scatter import scatter, scatter_add, scatter_max
from yacs.config import CfgNode as CN

from grit.slt.sparse_modules import SparseLinear, SparseLinearMulti_mask

# from grit.utils import negate_edge_index


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension of the input.

    Each vector along the last axis is divided by
    ``sqrt(mean(x ** 2) + eps)`` and, when ``elementwise_affine`` is set,
    multiplied by a learnable ``weight`` of shape ``normalized_shape``
    (initialised to ones).

    Note that the statistics are always reduced over ``dim=-1`` only;
    ``normalized_shape`` determines the shape of ``weight`` but not the
    reduction axes.

    Args:
        normalized_shape: An int or an iterable of ints giving the shape of
            the learnable scale.
        eps: Value added inside the square root. Defaults to the machine
            epsilon of ``torch.float32`` when ``None``.
        elementwise_affine: Whether to create and apply the learnable scale.
        device: Device on which ``weight`` is allocated.
        dtype: Dtype with which ``weight`` is allocated.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: Optional[float]
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape,
        eps: Optional[float] = None,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps if eps is not None else torch.finfo(torch.float32).eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            self.weight = Parameter(
                torch.empty(self.normalized_shape, device=device, dtype=dtype)
            )
        else:
            # Keep the attribute present (as ``None``) so ``self.weight`` can
            # always be read.
            self.register_parameter("weight", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the learnable scale (if any) to ones."""
        if self.elementwise_affine:
            init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled to unit RMS along its last dimension."""
        # ``weight`` is a registered parameter, so it is moved together with
        # the module by ``.to()`` / ``.cuda()``; it must not be re-assigned
        # here on every call.
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        x = x / rms
        if self.elementwise_affine:
            x = x * self.weight
        return x

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )


class PairNorm(torch.nn.Module):
    """Parameter-free normalization that centers and rescales a 2-D input.

    The input is centered by subtracting the mean over ``dim=0`` and then
    divided by ``sqrt(1e-6 + mean_over_rows(sum_over_dim1(x ** 2)))``.

    Args:
        device: Optional device the input is moved to before normalizing.
            When ``None`` (the default) the input stays where it is.
        dtype: Stored and shown in ``repr`` only; it is not applied to the
            input.
    """

    def __init__(self, device=None, dtype=None):
        super(PairNorm, self).__init__()
        self.device = device
        self.dtype = dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the centered and rescaled version of ``x``."""
        # Only relocate the input when a target device was configured;
        # ``Tensor.to`` returns the tensor unchanged if it is already there.
        if self.device is not None:
            x = x.to(self.device)

        col_mean = x.mean(dim=0)
        x = x - col_mean
        # Square root of the mean squared row norm; 1e-6 avoids division by
        # zero for an all-constant input.
        rownorm_mean = (1e-6 + x.pow(2).sum(dim=1).mean()).sqrt()
        x = x / rownorm_mean
        return x

    def extra_repr(self) -> str:
        return f"device={self.device}, dtype={self.dtype}"


def pyg_softmax(src, index, num_nodes=None):
    r"""Softmax evaluated independently inside each index group.

    Entries of :attr:`src` that share the same value in :attr:`index` form
    one group (grouping happens along dimension 0); the softmax is computed
    separately within every group.

    Args:
        src (Tensor): Values to normalize.
        index (LongTensor): Group id of each entry along the first dimension
            of :attr:`src`.
        num_nodes (int, optional): Number of groups, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    group_count = maybe_num_nodes(index, num_nodes)

    # Subtract the per-group maximum before exponentiating.
    group_max, _ = scatter_max(src, index, dim=0, dim_size=group_count)
    exp_shifted = (src - group_max[index]).exp()

    # Per-group normalizer, gathered back to one value per entry; the small
    # constant keeps the division well defined.
    group_sum = scatter_add(exp_shifted, index, dim=0, dim_size=group_count)
    denominator = group_sum[index] + 1e-16

    return exp_shifted / denominator


class MultiHeadAttentionLayerGritSparse(nn.Module):
    """
    Proposed Attention Computation for GRIT

    Multi-head attention over the edges listed in ``batch.edge_index``, with
    optional edge features. The Q/K/E/V projections are ``SparseLinear`` when
    ``cfg.slt.sm`` is true, ``SparseLinearMulti_mask`` when ``cfg.slt.mm`` is
    true, and plain ``nn.Linear`` otherwise.

    Args:
        in_dim: Feature size of ``batch.x`` and ``batch.edge_attr``.
        out_dim: Feature size per attention head.
        num_heads: Number of attention heads.
        use_bias: Whether the dense K and V projections use a bias (the dense
            Q and E projections always do; sparse projections never do).
        clamp: Attention logits are clamped to ``[-|clamp|, |clamp|]``;
            ``None`` disables clamping.
        dropout: Dropout probability applied to the attention weights.
        act: Key into ``act_dict`` for the activation applied to the edge
            score; ``None`` means identity.
        edge_enhance: Whether to add the attention-weighted edge
            representation (projected by ``VeRow``) to the node output.
        sqrt_relu: Accepted for interface compatibility; not used.
        signed_sqrt: Accepted for interface compatibility; not used (the
            signed square root is always applied when edge features exist).
        cfg: Config node providing ``cfg.slt.sm``, ``cfg.slt.mm`` and
            ``cfg.slt.attention_scaling``.
        **kwargs: Ignored.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        num_heads,
        use_bias,
        clamp=5.0,
        dropout=0.0,
        act=None,
        edge_enhance=True,
        sqrt_relu=False,
        signed_sqrt=True,
        cfg=cfg,
        **kwargs,
    ):
        super().__init__()

        self.out_dim = out_dim
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.clamp = np.abs(clamp) if clamp is not None else None
        self.edge_enhance = edge_enhance

        if cfg.slt.sm is True:
            sparse_cls = SparseLinear
        elif cfg.slt.mm is True:
            sparse_cls = SparseLinearMulti_mask
        else:
            sparse_cls = None

        # Projections are created in the order Q, K, E, V in every branch so
        # that random initialisation stays reproducible.
        if sparse_cls is not None:

            def make_projection(width):
                return sparse_cls(
                    in_dim,
                    width,
                    bias=False,
                    attention_scaling=cfg.slt.attention_scaling,
                    gain="linear",
                    init_mode_weight="signed_xavier_normal_constant_SF",
                )

            self.Q = make_projection(out_dim * num_heads)
            self.K = make_projection(out_dim * num_heads)
            # E holds a multiplicative and an additive part, hence 2x width.
            self.E = make_projection(out_dim * num_heads * 2)
            self.V = make_projection(out_dim * num_heads)
        else:
            self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True)
            self.K = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
            self.E = nn.Linear(in_dim, out_dim * num_heads * 2, bias=True)
            self.V = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
            nn.init.xavier_normal_(self.Q.weight)
            nn.init.xavier_normal_(self.K.weight)
            nn.init.xavier_normal_(self.E.weight)
            nn.init.xavier_normal_(self.V.weight)

        # Per-head vector that reduces the edge score to one attention logit.
        self.Aw = nn.Parameter(
            torch.zeros(self.out_dim, self.num_heads, 1), requires_grad=True
        )
        nn.init.xavier_normal_(self.Aw)

        if act is None:
            self.act = nn.Identity()
        else:
            self.act = act_dict[act]()

        if self.edge_enhance:
            self.VeRow = nn.Parameter(
                torch.zeros(self.out_dim, self.num_heads, self.out_dim),
                requires_grad=True,
            )
            nn.init.xavier_normal_(self.VeRow)

    @staticmethod
    def _project(layer, x, threshold, sparsity):
        """Apply ``layer`` to ``x``, forwarding the pruning arguments only to
        sparse layers and only when a threshold was supplied."""
        if threshold is not None and isinstance(
            layer, (SparseLinear, SparseLinearMulti_mask)
        ):
            return layer(x, threshold, sparsity)
        return layer(x)

    def propagate_attention(self, batch, threshold=None):
        """Compute attention along the edges and aggregate onto the nodes.

        Reads ``batch.Q_h``, ``batch.K_h``, ``batch.V_h``, ``batch.edge_index``
        and, if present, ``batch.E``. Writes ``batch.attn``, ``batch.wV`` and,
        when edge features exist, ``batch.E`` (reshaped) and ``batch.wE``.
        ``threshold`` is accepted for interface compatibility and not used.
        """
        src = batch.K_h[
            batch.edge_index[0]
        ]  # (num relative) x num_heads x out_dim
        dest = batch.Q_h[
            batch.edge_index[1]
        ]  # (num relative) x num_heads x out_dim
        score = src + dest  # element-wise sum of key and query projections

        has_edge = batch.get("E", None) is not None
        if has_edge:
            batch.E = batch.E.view(-1, self.num_heads, self.out_dim * 2)
            E_w, E_b = (
                batch.E[:, :, : self.out_dim],
                batch.E[:, :, self.out_dim :],
            )
            # (num relative) x num_heads x out_dim
            score = score * E_w
            # Signed square root: sign(score) * sqrt(|score|).
            score = torch.sqrt(torch.relu(score)) - torch.sqrt(
                torch.relu(-score)
            )
            score = score + E_b

        score = self.act(score)
        e_t = score

        # output edge
        if has_edge:
            batch.wE = score.flatten(1)

        # final attn
        score = oe.contract("ehd, dhc->ehc", score, self.Aw, backend="torch")
        if self.clamp is not None:
            score = torch.clamp(score, min=-self.clamp, max=self.clamp)

        # Normalize over all edges that share the same destination node.
        score = pyg_softmax(
            score, batch.edge_index[1]
        )  # (num relative) x num_heads x 1
        score = self.dropout(score)
        batch.attn = score

        # Aggregate with Attn-Score
        msg = (
            batch.V_h[batch.edge_index[0]] * score
        )  # (num relative) x num_heads x out_dim
        batch.wV = torch.zeros_like(
            batch.V_h
        )  # (num nodes in batch) x num_heads x out_dim
        scatter(msg, batch.edge_index[1], dim=0, out=batch.wV, reduce="add")

        if self.edge_enhance and has_edge:
            rowV = scatter(
                e_t * score, batch.edge_index[1], dim=0, reduce="add"
            )
            rowV = oe.contract(
                "nhd, dhc -> nhc", rowV, self.VeRow, backend="torch"
            )
            batch.wV = batch.wV + rowV

    def forward(self, batch, threshold=None, sparsity=None):
        """Run attention on ``batch`` and return ``(h_out, e_out)``.

        ``h_out`` is ``batch.wV``; ``e_out`` is ``batch.wE`` if it exists on
        the batch, else ``None``. ``threshold`` and ``sparsity`` are forwarded
        to the sparse projections when ``threshold`` is not ``None``.
        """
        Q_h = self._project(self.Q, batch.x, threshold, sparsity)
        K_h = self._project(self.K, batch.x, threshold, sparsity)
        V_h = self._project(self.V, batch.x, threshold, sparsity)

        if batch.get("edge_attr", None) is not None:
            batch.E = self._project(
                self.E, batch.edge_attr, threshold, sparsity
            )
        else:
            batch.E = None

        batch.Q_h = Q_h.view(-1, self.num_heads, self.out_dim)
        batch.K_h = K_h.view(-1, self.num_heads, self.out_dim)
        batch.V_h = V_h.view(-1, self.num_heads, self.out_dim)
        if threshold is not None:
            self.propagate_attention(batch, threshold)
        else:
            self.propagate_attention(batch)

        h_out = batch.wV
        e_out = batch.get("wE", None)

        return h_out, e_out


@register_layer("GritTransformer")
class GritTransformerLayer(nn.Module):
    """
    Proposed Transformer Layer for GRIT

    One block made of multi-head attention followed by a two-layer
    feed-forward network on the node features, each with optional residual
    connection, layer norm and batch norm. Edge features, when present, are
    updated by the attention part only.

    Args:
        in_dim: Input feature size.
        out_dim: Output feature size; each head gets ``out_dim // num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability after attention and inside the FFN.
        attn_dropout: Dropout probability on the attention weights.
        layer_norm: Whether to apply ``LayerNorm`` after each sub-block.
        batch_norm: Whether to apply ``BatchNorm1d`` after each sub-block.
        residual: Whether to add residual connections.
        act: Key into ``act_dict`` for the FFN activation; ``None`` means
            identity.
        norm_e: Whether edge features are normalized too.
        O_e: Whether edge features get an output projection (identity if not).
        cfg: Config node; ``cfg.gt``, ``cfg.slt`` and (optionally)
            ``cfg.attn`` are read. If ``cfg.attn`` is missing, an empty node
            is created on ``cfg``.
        **kwargs: Ignored.
    """

    # Attribute names of the optional pre-attention norms, in the priority
    # order in which ``__init__`` considers them (at most one is created).
    _MSA_NORM_NAMES = (
        "batchnorm_msa",
        "layernorm_msa",
        "pairnorm_msa",
        "rmsnorm_msa",
    )

    def __init__(
        self,
        in_dim,
        out_dim,
        num_heads,
        dropout=0.0,
        attn_dropout=0.0,
        layer_norm=False,
        batch_norm=True,
        residual=True,
        act="relu",
        norm_e=True,
        O_e=True,
        cfg=cfg,
        **kwargs,
    ):
        super().__init__()

        self.debug = False
        self.in_channels = in_dim
        self.out_channels = out_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.residual = residual
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm

        # -------
        self.update_e = cfg.get("update_e", True)
        self.bn_momentum = cfg.gt.bn_momentum
        self.bn_no_runner = cfg.gt.bn_no_runner
        self.rezero = cfg.get("rezero", False)

        self.act = act_dict[act]() if act is not None else nn.Identity()
        if cfg.get("attn", None) is None:
            cfg.attn = CN()
        self.use_attn = cfg.attn.get("use", True)
        self.deg_scaler = cfg.attn.get("deg_scaler", True)

        # The same cfg is handed to the attention sub-layer so that both
        # read their sparsity settings from one place.
        self.attention = MultiHeadAttentionLayerGritSparse(
            in_dim=in_dim,
            out_dim=out_dim // num_heads,
            num_heads=num_heads,
            use_bias=cfg.attn.get("use_bias", False),
            dropout=attn_dropout,
            clamp=cfg.attn.get("clamp", 5.0),
            act=cfg.attn.get("act", "relu"),
            edge_enhance=cfg.attn.get("edge_enhance", True),
            sqrt_relu=cfg.attn.get("sqrt_relu", False),
            signed_sqrt=cfg.attn.get("signed_sqrt", False),
            scaled_attn=cfg.attn.get("scaled_attn", False),
            no_qk=cfg.attn.get("no_qk", False),
            cfg=cfg,
        )

        # NOTE(review): MultiHeadAttentionLayerGraphormerSparse is not defined
        # in the visible part of this file -- confirm it is available before
        # enabling ``graphormer_attn``.
        if cfg.attn.get("graphormer_attn", False):
            self.attention = MultiHeadAttentionLayerGraphormerSparse(
                in_dim=in_dim,
                out_dim=out_dim // num_heads,
                num_heads=num_heads,
                use_bias=cfg.attn.get("use_bias", False),
                dropout=attn_dropout,
                clamp=cfg.attn.get("clamp", 5.0),
                act=cfg.attn.get("act", "relu"),
                edge_enhance=True,
                sqrt_relu=cfg.attn.get("sqrt_relu", False),
                signed_sqrt=cfg.attn.get("signed_sqrt", False),
                scaled_attn=cfg.attn.get("scaled_attn", False),
                no_qk=cfg.attn.get("no_qk", False),
            )

        # -------- Output projections of the attention block ------
        attn_width = out_dim // num_heads * num_heads
        if cfg.slt.sm is True and cfg.slt.msa is True:
            msa_sparse_cls = SparseLinear
        elif cfg.slt.mm is True and cfg.slt.msa is True:
            msa_sparse_cls = SparseLinearMulti_mask
        else:
            msa_sparse_cls = None

        def make_out_projection():
            if msa_sparse_cls is None:
                return nn.Linear(attn_width, out_dim)
            return msa_sparse_cls(
                attn_width,
                out_dim,
                bias=False,
                attention_scaling=cfg.slt.attention_scaling,
                gain="linear",
                init_mode_weight="signed_kaiming_uniform_constant_SF",
            )

        self.O_h = make_out_projection()
        self.O_e = make_out_projection() if O_e else nn.Identity()

        # -------- Deg Scaler Option ------

        if self.deg_scaler:
            self.deg_coef = nn.Parameter(torch.zeros(1, attn_width, 2))
            nn.init.xavier_normal_(self.deg_coef)

        if self.layer_norm:
            self.layer_norm1_h = nn.LayerNorm(out_dim)
            self.layer_norm1_e = (
                nn.LayerNorm(out_dim) if norm_e else nn.Identity()
            )

        if self.batch_norm:
            # when the batch_size is really small, use smaller momentum to
            # avoid bad mini-batch leading to extremely bad val/test loss (NaN)
            self.batch_norm1_h = nn.BatchNorm1d(
                out_dim,
                track_running_stats=not self.bn_no_runner,
                eps=1e-5,
                momentum=cfg.gt.bn_momentum,
            )
            self.batch_norm1_e = (
                nn.BatchNorm1d(
                    out_dim,
                    track_running_stats=not self.bn_no_runner,
                    eps=1e-5,
                    momentum=cfg.gt.bn_momentum,
                )
                if norm_e
                else nn.Identity()
            )

        # -------- Feed-forward network on node features ------
        if cfg.slt.sm is True and cfg.slt.ffn is True:
            ffn_sparse_cls = SparseLinear
        elif cfg.slt.mm is True and cfg.slt.ffn is True:
            ffn_sparse_cls = SparseLinearMulti_mask
        else:
            ffn_sparse_cls = None

        def make_ffn_layer(fan_in, fan_out):
            if ffn_sparse_cls is None:
                return nn.Linear(fan_in, fan_out)
            return ffn_sparse_cls(
                fan_in,
                fan_out,
                bias=False,
                attention_scaling=cfg.slt.attention_scaling,
                init_mode_weight="signed_kaiming_uniform_constant_SF",
            )

        self.FFN_h_layer1 = make_ffn_layer(out_dim, out_dim * 2)
        self.FFN_h_layer2 = make_ffn_layer(out_dim * 2, out_dim)

        if self.layer_norm:
            self.layer_norm2_h = nn.LayerNorm(out_dim)

        if self.batch_norm:
            self.batch_norm2_h = nn.BatchNorm1d(
                out_dim,
                track_running_stats=not self.bn_no_runner,
                eps=1e-5,
                momentum=cfg.gt.bn_momentum,
            )

        if self.rezero:
            self.alpha1_h = nn.Parameter(torch.zeros(1, 1))
            self.alpha2_h = nn.Parameter(torch.zeros(1, 1))
            self.alpha1_e = nn.Parameter(torch.zeros(1, 1))

        # Optional pre-norms. They are registered submodules, so they follow
        # the parent module's device placement; no explicit ``.cuda()`` here,
        # which would fail on machines without a GPU.
        if cfg.slt.batchnorm_msa:
            self.batchnorm_msa = torch.nn.BatchNorm1d(cfg.gt.dim_hidden)
        elif cfg.slt.layernorm_msa:
            self.layernorm_msa = torch.nn.LayerNorm(cfg.gt.dim_hidden)
        elif cfg.slt.pairnorm_msa:
            self.pairnorm_msa = PairNorm()
        elif cfg.slt.rmsnorm_msa:
            self.rmsnorm_msa = RMSNorm(cfg.gt.dim_hidden)

        if cfg.slt.batchnorm_ffn:
            self.batchnorm_ffn = torch.nn.BatchNorm1d(cfg.gt.dim_hidden)
        elif cfg.slt.layernorm_ffn:
            self.layernorm_ffn = torch.nn.LayerNorm(cfg.gt.dim_hidden)
        elif cfg.slt.pairnorm_ffn:
            self.pairnorm_ffn = PairNorm()
        elif cfg.slt.rmsnorm_ffn:
            self.rmsnorm_ffn = RMSNorm(cfg.gt.dim_hidden)

    @staticmethod
    def _apply_linear(layer, x, threshold, sparsity):
        """Apply ``layer`` to ``x``, forwarding the pruning arguments only to
        sparse layers and only when a threshold was supplied."""
        if threshold is not None and isinstance(
            layer, (SparseLinear, SparseLinearMulti_mask)
        ):
            return layer(x, threshold, sparsity)
        return layer(x)

    def forward(self, batch, threshold=None, sparsity=None):
        """Update ``batch.x`` (and ``batch.edge_attr``) in place.

        Args:
            batch: Graph batch providing ``x``, ``num_nodes``, ``edge_index``
                and optionally ``edge_attr``.
            threshold: When not ``None``, forwarded together with
                ``sparsity`` to the attention layer and to the sparse linear
                layers of this block.
            sparsity: See ``threshold``.

        Returns:
            The same ``batch`` object with updated ``x`` and ``edge_attr``.
        """
        h = batch.x
        num_nodes = batch.num_nodes

        h_in1 = h  # for first residual connection
        e_in1 = batch.get("edge_attr", None)
        e = None

        # Optional pre-attention norm: apply whichever one was built.
        for norm_name in self._MSA_NORM_NAMES:
            pre_norm = getattr(self, norm_name, None)
            if pre_norm is not None:
                batch.x = pre_norm(batch.x)
                break

        # multi-head attention out
        h_attn_out, e_attn_out = (
            self.attention(batch, threshold, sparsity)
            if threshold is not None
            else self.attention(batch)
        )

        h = h_attn_out.view(num_nodes, -1)
        h = F.dropout(h, self.dropout, training=self.training)

        # degree scaler
        if self.deg_scaler:
            log_deg = get_log_deg(batch)
            h = torch.stack([h, h * log_deg], dim=-1)
            h = (h * self.deg_coef).sum(dim=-1)

        h = self._apply_linear(self.O_h, h, threshold, sparsity)

        if e_attn_out is not None:
            e = e_attn_out.flatten(1)
            e = F.dropout(e, self.dropout, training=self.training)
            e = self._apply_linear(self.O_e, e, threshold, sparsity)

        if self.residual:
            if self.rezero:
                h = h * self.alpha1_h
            h = h_in1 + h  # residual connection
            if e is not None:
                if self.rezero:
                    e = e * self.alpha1_e
                e = e + e_in1

        if self.layer_norm:
            h = self.layer_norm1_h(h)
            if e is not None:
                e = self.layer_norm1_e(e)

        if self.batch_norm:
            h = self.batch_norm1_h(h)
            if e is not None:
                e = self.batch_norm1_e(e)

        # FFN for h
        h_in2 = h  # for second residual connection

        # NOTE(review): the ``*_ffn`` pre-norm modules that ``__init__`` may
        # create are not applied in this forward pass.
        h = self._apply_linear(self.FFN_h_layer1, h, threshold, sparsity)
        h = self.act(h)
        h = F.dropout(h, self.dropout, training=self.training)
        h = self._apply_linear(self.FFN_h_layer2, h, threshold, sparsity)

        if self.residual:
            if self.rezero:
                h = h * self.alpha2_h
            h = h_in2 + h  # residual connection

        if self.layer_norm:
            h = self.layer_norm2_h(h)

        if self.batch_norm:
            h = self.batch_norm2_h(h)

        batch.x = h
        if self.update_e:
            batch.edge_attr = e
        else:
            batch.edge_attr = e_in1

        return batch

    def __repr__(self):
        return "{}(in_channels={}, out_channels={}, heads={}, residual={})\n[{}]".format(
            self.__class__.__name__,
            self.in_channels,
            self.out_channels,
            self.num_heads,
            self.residual,
            super().__repr__(),
        )


@torch.no_grad()
def get_log_deg(batch):
    """Return the per-node log-degree of ``batch`` with shape
    ``(batch.num_nodes, 1)``.

    The value is taken from the first available source:

    1. ``batch.log_deg`` if the batch contains ``"log_deg"``;
    2. ``log(batch.deg + 1)`` if the batch contains ``"deg"``;
    3. otherwise the degree is counted from ``batch.edge_index[1]`` and a
       warning is emitted.

    Runs under ``torch.no_grad()``.
    """
    if "log_deg" in batch:
        log_deg = batch.log_deg
    elif "deg" in batch:
        log_deg = torch.log(batch.deg + 1).unsqueeze(-1)
    else:
        warnings.warn(
            "Compute the degree on the fly; Might be problematic if have applied edge-padding to complete graphs"
        )
        deg = pyg.utils.degree(
            batch.edge_index[1], num_nodes=batch.num_nodes, dtype=torch.float
        )
        log_deg = torch.log(deg + 1)
    # Every source is reshaped to one column per node.
    return log_deg.view(batch.num_nodes, 1)
