import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_scatter import scatter, scatter_max, scatter_add

from torch_geometric.graphgym.register import *
import opt_einsum as oe

from yacs.config import CfgNode as CN

import warnings

def pyg_softmax(src, index, num_nodes=None):
    r"""Softmax evaluated independently inside each index group.

    Rows of :attr:`src` (dimension 0) that share the same value in
    :attr:`index` form one group. Within each group the values are shifted
    by the group maximum, exponentiated, and divided by the group sum of
    exponentials (plus ``1e-16`` to avoid a zero denominator).

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): Group id of every row of :attr:`src`.
        num_nodes (int, optional): Number of groups, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. Resolved through
            ``maybe_num_nodes`` when :obj:`None`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    n_groups = maybe_num_nodes(index, num_nodes)

    # Subtract the per-group maximum before exponentiating so that the
    # largest exponent in every group is exactly 0.
    group_max, _ = scatter_max(src, index, dim=0, dim_size=n_groups)
    exp_shifted = (src - group_max[index]).exp()

    # Per-group normaliser, gathered back to one value per input row.
    group_sum = scatter_add(exp_shifted, index, dim=0, dim_size=n_groups)
    return exp_shifted / (group_sum[index] + 1e-16)



class MultiHeadAttentionLayerGritSparse(nn.Module):
    """
        Proposed Attention Computation for GRIT

        Sparse multi-head attention evaluated on the edges listed in
        ``batch.edge_index``. For every edge the projected key of the source
        node and the projected query of the destination node are *added*,
        optionally modulated by the projected edge attributes, reduced to one
        logit per head with ``Aw`` and normalised with a softmax over the
        edges that share a destination node.

        Args:
            in_dim: Input width of the node / edge (/ feature) tensors.
            out_dim: Width of a single attention head.
            num_heads: Number of attention heads.
            use_bias: Whether the key and value projections carry a bias
                (the query and edge projections always do).
            features: If True, a second stream read from ``batch.x_features``
                gets its own Q/K/V projections; its keys and queries are
                added to the main ones and its values are aggregated with
                the same attention weights.
            clamp: Attention logits are clamped to ``[-|clamp|, |clamp|]``;
                ``None`` disables clamping.
            dropout: Dropout probability applied to the attention weights.
            act: Key into ``act_dict`` for the activation applied to the
                edge score; ``None`` means identity.
            edge_enhance: If True (and edge attributes are present) the
                attention-weighted edge scores are aggregated per node,
                mixed through ``VeRow`` and added to the node output.
            sqrt_relu, signed_sqrt, cfg, **kwargs: Accepted for interface
                compatibility; not read by this class. Note that the signed
                square root is applied whenever edge attributes are present,
                independently of ``signed_sqrt``.
    """

    def __init__(self, in_dim, out_dim, num_heads, use_bias, features=False,
                 clamp=5., dropout=0., act=None,
                 edge_enhance=True,
                 sqrt_relu=False,
                 signed_sqrt=True,
                 cfg=CN(),
                 **kwargs):
        super().__init__()

        self.out_dim = out_dim
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.clamp = np.abs(clamp) if clamp is not None else None
        self.edge_enhance = edge_enhance
        self.features = features

        self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True)
        self.K = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
        # The edge projection is twice as wide: one half is used as a
        # multiplicative weight, the other half as an additive bias.
        self.E = nn.Linear(in_dim, out_dim * num_heads * 2, bias=True)
        self.V = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
        nn.init.xavier_normal_(self.Q.weight)
        nn.init.xavier_normal_(self.K.weight)
        nn.init.xavier_normal_(self.E.weight)
        nn.init.xavier_normal_(self.V.weight)

        if self.features:
            self.Q_f = nn.Linear(in_dim, out_dim * num_heads, bias=True)
            self.K_f = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
            self.V_f = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
            nn.init.xavier_normal_(self.Q_f.weight)
            nn.init.xavier_normal_(self.K_f.weight)
            nn.init.xavier_normal_(self.V_f.weight)

        # Per-head vector reducing an out_dim-wide edge score to one logit.
        self.Aw = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, 1), requires_grad=True)
        nn.init.xavier_normal_(self.Aw)

        if act is None:
            self.act = nn.Identity()
        else:
            self.act = act_dict[act]()

        if self.edge_enhance:
            # Per-head out_dim x out_dim map applied to aggregated edge scores.
            self.VeRow = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, self.out_dim), requires_grad=True)
            nn.init.xavier_normal_(self.VeRow)

    def propagate_attention(self, batch):
        """Compute attention on ``batch.edge_index`` and store results on ``batch``.

        Reads ``batch.K_h``, ``batch.Q_h``, ``batch.V_h`` (plus the ``*_f``
        counterparts when ``features`` is set) and, if present, ``batch.E``.
        Writes ``batch.attn``, ``batch.wV``, ``batch.wE`` (only when
        ``batch.E`` is present) and ``batch.wF`` (only when ``features``).
        """
        src_idx = batch.edge_index[0]
        dst_idx = batch.edge_index[1]
        has_edge_feats = batch.get("E", None) is not None

        src = batch.K_h[src_idx]      # (num relative) x num_heads x out_dim
        if self.features:
            src = src + batch.K_f[src_idx]
        dest = batch.Q_h[dst_idx]     # (num relative) x num_heads x out_dim
        if self.features:
            dest = dest + batch.Q_f[dst_idx]
        score = src + dest            # additive combination of key and query

        if has_edge_feats:
            batch.E = batch.E.view(-1, self.num_heads, self.out_dim * 2)
            E_w, E_b = batch.E[:, :, :self.out_dim], batch.E[:, :, self.out_dim:]
            # (num relative) x num_heads x out_dim
            score = score * E_w
            # signed square root: sign(score) * sqrt(|score|)
            score = torch.sqrt(torch.relu(score)) - torch.sqrt(torch.relu(-score))
            score = score + E_b

        score = self.act(score)
        e_t = score

        # output edge
        if has_edge_feats:
            batch.wE = score.flatten(1)

        # final attn
        score = oe.contract("ehd, dhc->ehc", score, self.Aw, backend="torch")
        if self.clamp is not None:
            score = torch.clamp(score, min=-self.clamp, max=self.clamp)

        score = pyg_softmax(score, dst_idx)  # (num relative) x num_heads x 1
        score = self.dropout(score)
        batch.attn = score

        # Aggregate with Attn-Score
        msg = batch.V_h[src_idx] * score  # (num relative) x num_heads x out_dim
        batch.wV = torch.zeros_like(batch.V_h)  # (num nodes in batch) x num_heads x out_dim
        scatter(msg, dst_idx, dim=0, out=batch.wV, reduce='add')

        if self.edge_enhance and has_edge_feats:
            # dim_size pins the result to one row per node; without it the
            # output only has max(dst_idx) + 1 rows and the additions below
            # fail when trailing nodes have no incoming edge.
            rowV = scatter(e_t * score, dst_idx, dim=0,
                           dim_size=batch.V_h.size(0), reduce="add")
            rowV = oe.contract("nhd, dhc -> nhc", rowV, self.VeRow, backend="torch")
            batch.wV = batch.wV + rowV

        if self.features:
            msg = batch.V_f[src_idx] * score  # (num relative) x num_heads x out_dim
            batch.wF = torch.zeros_like(batch.V_f)  # (num nodes in batch) x num_heads x out_dim
            scatter(msg, dst_idx, dim=0, out=batch.wF, reduce='add')

            if self.edge_enhance and has_edge_feats:
                batch.wF = batch.wF + rowV

    def forward(self, batch):
        """Project the inputs, run the attention and return ``(h_out, e_out, f_out)``.

        ``h_out`` is ``batch.wV``. ``e_out`` is the edge output of this call,
        or ``None`` when ``batch`` has no ``edge_attr``. ``f_out`` is
        ``batch.wF`` when ``features`` is set, else ``None``. The
        intermediate projections are stored on ``batch`` as a side effect.
        """
        Q_h = self.Q(batch.x)
        K_h = self.K(batch.x)

        if self.features:
            batch.Q_f = self.Q_f(batch.x_features).view(-1, self.num_heads, self.out_dim)
            batch.K_f = self.K_f(batch.x_features).view(-1, self.num_heads, self.out_dim)
            batch.V_f = self.V_f(batch.x_features).view(-1, self.num_heads, self.out_dim)

        V_h = self.V(batch.x)
        has_edge_feats = batch.get("edge_attr", None) is not None
        if has_edge_feats:
            batch.E = self.E(batch.edge_attr)
        else:
            batch.E = None

        batch.Q_h = Q_h.view(-1, self.num_heads, self.out_dim)
        batch.K_h = K_h.view(-1, self.num_heads, self.out_dim)
        batch.V_h = V_h.view(-1, self.num_heads, self.out_dim)
        self.propagate_attention(batch)
        h_out = batch.wV
        # Only report an edge output produced by this call: `batch.wE` may
        # still hold the result of an earlier layer applied to the same batch.
        e_out = batch.get('wE', None) if has_edge_feats else None
        f_out = batch.get('wF', None) if self.features else None

        return h_out, e_out, f_out


@register_layer("GritTransformer")
class GritTransformerLayer(nn.Module):
    """
        Proposed Transformer Layer for GRIT

        One layer consists of sparse multi-head attention
        (``MultiHeadAttentionLayerGritSparse``), an optional degree scaler,
        output projections, a residual connection and normalisation,
        followed by a two-layer feed-forward block on the node stream (and
        on the feature stream when ``features`` is set). Edge attributes are
        updated by the attention block only.

        Args:
            in_dim, out_dim: Input / output width. The attention uses
                ``out_dim // num_heads`` channels per head.
            num_heads: Number of attention heads.
            dropout: Dropout probability after attention and inside the FFN.
            attn_dropout: Dropout probability on the attention weights.
            layer_norm, batch_norm: Enable the respective normalisation.
            residual: Enable the two residual connections.
            act: Key into ``act_dict`` for the FFN activation (``None`` for
                identity).
            norm_e: Whether edge outputs are normalised.
            O_e: Whether edge outputs pass through a linear projection.
            features: Enable the second node stream ``batch.x_features``.
            coupling: Modulate ``x`` and ``x_features`` with each other
                before attention (only when ``x_features`` is present).
            conditioning: Modulate residual scales and layer-norm outputs
                using ``batch.c_x`` / ``batch.c_e`` / ``batch.c_f``. These are
                applied without graph indexing, so presumably they already
                have one row per node / edge -- TODO confirm against caller.
            spec_conditioning: Additional modulation from ``batch.c_spec``,
                which is gathered through ``batch.batch`` (presumably one row
                per graph -- TODO confirm).
            ignore_edges: Delete ``batch.edge_attr`` before processing.
            cfg: Config object. It must support ``.get`` and expose
                ``bn_momentum`` and ``bn_no_runner`` as attributes; the plain
                ``dict()`` default does not satisfy this, so a config has to
                be passed explicitly. When ``cfg`` has no ``attn`` entry an
                empty one is stored on it.
    """
    def __init__(self, in_dim, out_dim, num_heads,
                 dropout=0.0,
                 attn_dropout=0.0,
                 layer_norm=False, batch_norm=True,
                 residual=True,
                 act='relu',
                 norm_e=True,
                 O_e=True,
                 features=False,
                 coupling=False,
                 conditioning=True,
                 spec_conditioning=False,
                 ignore_edges=False,
                 cfg=dict(),
                 **kwargs):
        super().__init__()

        self.debug = False
        self.in_channels = in_dim
        self.out_channels = out_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.residual = residual
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm
        self.features = features
        self.coupling = coupling
        self.ignore_edges = ignore_edges

        # Conditioning
        self.conditional = conditioning
        self.spec_conditional = spec_conditioning

        def init_weights(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.zeros_(m.weight)
                m.bias.data.fill_(0.0)

        def make_mapping(n_chunks, zero_init=True):
            # Two-layer SiLU MLP whose output is later split into `n_chunks`
            # pieces of width out_dim. With zero_init every Linear starts at
            # zero, so the mapping initially outputs zeros.
            mapping = torch.nn.Sequential(
                nn.Linear(out_dim, out_dim, bias=True),
                nn.SiLU(),
                nn.Linear(out_dim, n_chunks * out_dim, bias=True),
            )
            if zero_init:
                mapping.apply(init_weights)
            return mapping

        if self.conditional:
            self.c_mapping_ln_x = make_mapping(4)
            self.c_mapping_ln_e = make_mapping(2)
            self.c_mapping_res_x = make_mapping(2)
            self.c_mapping_res_e = make_mapping(1)

            if self.features:
                self.c_mapping_ln_f = make_mapping(4)
                self.c_mapping_res_f = make_mapping(2)

        if self.coupling:
            # The coupling maps keep PyTorch's default initialisation.
            self.mapping_f_to_x = make_mapping(2, zero_init=False)
            self.mapping_x_to_f = make_mapping(2, zero_init=False)

        if self.spec_conditional:
            self.c_mapping_ln_xspec = make_mapping(4)
            self.c_mapping_res_xspec = make_mapping(2)
            self.c_mapping_ln_fspec = make_mapping(4)
            self.c_mapping_res_fspec = make_mapping(2)
            self.c_mapping_ln_espec = make_mapping(2)
            self.c_mapping_res_espec = make_mapping(1)

        # -------
        self.update_e = cfg.get("update_e", True)
        self.bn_momentum = cfg.bn_momentum
        self.bn_no_runner = cfg.bn_no_runner
        self.rezero = cfg.get("rezero", False)

        self.act = act_dict[act]() if act is not None else nn.Identity()
        if cfg.get("attn", None) is None:
            cfg.attn = dict()
        self.use_attn = cfg.attn.get("use", True)
        self.deg_scaler = cfg.attn.get("deg_scaler", True)

        self.attention = MultiHeadAttentionLayerGritSparse(
            in_dim=in_dim,
            out_dim=out_dim // num_heads,
            num_heads=num_heads,
            use_bias=cfg.attn.get("use_bias", False),
            dropout=attn_dropout,
            clamp=cfg.attn.get("clamp", 5.),
            act=cfg.attn.get("act", "relu"),
            edge_enhance=cfg.attn.get("edge_enhance", True),
            sqrt_relu=cfg.attn.get("sqrt_relu", False),
            signed_sqrt=cfg.attn.get("signed_sqrt", False),
            scaled_attn=cfg.attn.get("scaled_attn", False),
            no_qk=cfg.attn.get("no_qk", False),
            features=self.features
        )

        # Width of the concatenated attention heads.
        heads_dim = out_dim // num_heads * num_heads

        self.O_h = nn.Linear(heads_dim, out_dim)
        if O_e:
            self.O_e = nn.Linear(heads_dim, out_dim)
        else:
            self.O_e = nn.Identity()

        # -------- Deg Scaler Option ------

        if self.deg_scaler:
            self.deg_coef = nn.Parameter(torch.zeros(1, heads_dim, 2))
            nn.init.xavier_normal_(self.deg_coef)

        if self.layer_norm:
            self.layer_norm1_h = nn.LayerNorm(out_dim, elementwise_affine=False, bias=False)
            self.layer_norm1_e = nn.LayerNorm(out_dim, elementwise_affine=False, bias=False) if norm_e else nn.Identity()

        if self.batch_norm:
            # when the batch_size is really small, use smaller momentum to avoid bad mini-batch leading to extremely bad val/test loss (NaN)
            self.batch_norm1_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum)
            self.batch_norm1_e = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum) if norm_e else nn.Identity()

        # FFN for h
        self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2)
        self.FFN_h_layer2 = nn.Linear(out_dim * 2, out_dim)

        if self.layer_norm:
            self.layer_norm2_h = nn.LayerNorm(out_dim, elementwise_affine=False, bias=False)

        if self.batch_norm:
            self.batch_norm2_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum)

        if self.rezero:
            # Registered as parameters; forward() does not read them.
            self.alpha1_h = nn.Parameter(torch.zeros(1, 1))
            self.alpha2_h = nn.Parameter(torch.zeros(1, 1))
            self.alpha1_e = nn.Parameter(torch.zeros(1, 1))

        if self.features:
            self.O_f = nn.Linear(heads_dim, out_dim)
            # FFN for f
            self.FFN_f_layer1 = nn.Linear(out_dim, out_dim * 2)
            self.FFN_f_layer2 = nn.Linear(out_dim * 2, out_dim)
            if self.layer_norm:
                self.layer_norm1_f = nn.LayerNorm(out_dim, elementwise_affine=False, bias=False)
                self.layer_norm2_f = nn.LayerNorm(out_dim, elementwise_affine=False, bias=False)
            if self.batch_norm:
                self.batch_norm1_f = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum)
                self.batch_norm2_f = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=cfg.bn_momentum)

    def forward(self, batch):
        """Apply the layer to ``batch`` in place and return it.

        Updates ``batch.x``, ``batch.edge_attr`` (the new edge output when
        ``update_e`` is set, otherwise the incoming value) and, when a
        feature stream was produced, ``batch.x_features``.
        """

        if self.ignore_edges and (batch.get("edge_attr", None) is not None):
            del batch.edge_attr

        if self.conditional:
            gamma1_x, beta1_x, gamma2_x, beta2_x = self.c_mapping_ln_x(batch.c_x).chunk(4, dim=-1)
            gamma_e, beta_e = self.c_mapping_ln_e(batch.c_e).chunk(2, dim=-1)
            alpha1_x, alpha2_x = self.c_mapping_res_x(batch.c_x).chunk(2, dim=-1)
            alpha_e = self.c_mapping_res_e(batch.c_e)
            if self.features:
                gamma1_f, beta1_f, gamma2_f, beta2_f = self.c_mapping_ln_f(batch.c_f).chunk(4, dim=-1)
                alpha1_f, alpha2_f = self.c_mapping_res_f(batch.c_f).chunk(2, dim=-1)

        if self.spec_conditional:
            gamma1_xspec, beta1_xspec, gamma2_xspec, beta2_xspec = self.c_mapping_ln_xspec(batch.c_spec).chunk(4, dim=-1)
            gamma1_fspec, beta1_fspec, gamma2_fspec, beta2_fspec = self.c_mapping_ln_fspec(batch.c_spec).chunk(4, dim=-1)
            gamma_espec, beta_espec = self.c_mapping_ln_espec(batch.c_spec).chunk(2, dim=-1)
            alpha1_xspec, alpha2_xspec = self.c_mapping_res_xspec(batch.c_spec).chunk(2, dim=-1)
            alpha1_fspec, alpha2_fspec = self.c_mapping_res_fspec(batch.c_spec).chunk(2, dim=-1)
            alpha_espec = self.c_mapping_res_espec(batch.c_spec)
            # Row of the c_spec-derived tensors to use for every node / edge.
            node_graph = batch.batch
            edge_graph = node_graph[batch.edge_index[0]]

        h = batch.x
        num_nodes = batch.num_nodes
        log_deg = get_log_deg(batch)

        h_in1 = h  # for first residual connection
        e_in1 = batch.get("edge_attr", None)
        e = None
        f = None

        f_in1 = batch.get("x_features", None)
        if (f_in1 is not None) and self.coupling:
            # Cross-modulate the two node streams before attention.
            gamma_f_x, beta_f_x = self.mapping_f_to_x(f_in1).chunk(2, dim=-1)
            gamma_x_f, beta_x_f = self.mapping_x_to_f(h_in1).chunk(2, dim=-1)
            f_in1 = f_in1 * (1 + gamma_x_f) + beta_x_f
            h_in1 = h_in1 * (1 + gamma_f_x) + beta_f_x
            batch.x = h_in1
            batch.x_features = f_in1

        # multi-head attention out
        h_attn_out, e_attn_out, f_attn_out = self.attention(batch)

        h = h_attn_out.view(num_nodes, -1)
        h = F.dropout(h, self.dropout, training=self.training)

        # degree scaler
        if self.deg_scaler:
            h = torch.stack([h, h * log_deg], dim=-1)
            h = (h * self.deg_coef).sum(dim=-1)

        h = self.O_h(h)
        if e_attn_out is not None:
            e = e_attn_out.flatten(1)
            e = F.dropout(e, self.dropout, training=self.training)
            e = self.O_e(e)

        if f_attn_out is not None:
            f = f_attn_out.flatten(1)
            f = F.dropout(f, self.dropout, training=self.training)
            if self.deg_scaler:
                # The feature stream shares deg_coef with the node stream.
                f = torch.stack([f, f * log_deg], dim=-1)
                f = (f * self.deg_coef).sum(dim=-1)
            f = self.O_f(f)

        if self.residual:
            if self.conditional:
                h = h * alpha1_x
            if self.spec_conditional:
                h = h * alpha1_xspec[node_graph]
            h = h_in1 + h  # residual connection
            if e is not None:
                if self.conditional:
                    e = e * alpha_e
                if self.spec_conditional:
                    e = e * alpha_espec[edge_graph]
                e = e + e_in1
            if f is not None:
                if self.conditional:
                    f = f * alpha1_f
                if self.spec_conditional:
                    f = f * alpha1_fspec[node_graph]
                f = f + f_in1

        if self.layer_norm:
            h = self.layer_norm1_h(h)
            if self.conditional:
                h = h * (1 + gamma1_x) + beta1_x
            if self.spec_conditional:
                h = h * (1 + gamma1_xspec[node_graph]) + beta1_xspec[node_graph]
            if e is not None:
                e = self.layer_norm1_e(e)
                if self.conditional:
                    e = e * (1 + gamma_e) + beta_e
                if self.spec_conditional:
                    e = e * (1 + gamma_espec[edge_graph]) + beta_espec[edge_graph]

            if f is not None:
                f = self.layer_norm1_f(f)
                if self.conditional:
                    f = f * (1 + gamma1_f) + beta1_f
                if self.spec_conditional:
                    f = f * (1 + gamma1_fspec[node_graph]) + beta1_fspec[node_graph]

        if self.batch_norm:
            h = self.batch_norm1_h(h)
            if e is not None:
                e = self.batch_norm1_e(e)
            if f is not None:
                f = self.batch_norm1_f(f)

        # FFN for h
        h_in2 = h  # for second residual connection
        h = self.FFN_h_layer1(h)
        h = self.act(h)
        h = F.dropout(h, self.dropout, training=self.training)
        h = self.FFN_h_layer2(h)

        # FFN for f
        if f is not None:
            f_in2 = f
            f = self.FFN_f_layer1(f)
            f = self.act(f)
            f = F.dropout(f, self.dropout, training=self.training)
            f = self.FFN_f_layer2(f)

        if self.residual:
            if self.conditional:
                h = h * alpha2_x
            if self.spec_conditional:
                h = h * alpha2_xspec[node_graph]
            h = h_in2 + h  # residual connection
            if f is not None:
                if self.conditional:
                    f = f * alpha2_f
                if self.spec_conditional:
                    f = f * alpha2_fspec[node_graph]
                f = f_in2 + f  # residual connection

        if self.layer_norm:
            h = self.layer_norm2_h(h)
            if self.conditional:
                h = h * (1 + gamma2_x) + beta2_x
            if self.spec_conditional:
                h = h * (1 + gamma2_xspec[node_graph]) + beta2_xspec[node_graph]
            if f is not None:
                f = self.layer_norm2_f(f)
                if self.conditional:
                    f = f * (1 + gamma2_f) + beta2_f
                # Guarded by `f is not None`: the feature stream is absent
                # when `features` is off, and modulating None would raise.
                if self.spec_conditional:
                    f = f * (1 + gamma2_fspec[node_graph]) + beta2_fspec[node_graph]

        if self.batch_norm:
            h = self.batch_norm2_h(h)
            if f is not None:
                f = self.batch_norm2_f(f)

        batch.x = h
        if self.update_e:
            batch.edge_attr = e
        else:
            batch.edge_attr = e_in1
        if f is not None:
            batch.x_features = f

        return batch

    def __repr__(self):
        return '{}(in_channels={}, out_channels={}, heads={}, residual={})\n[{}]'.format(
            self.__class__.__name__,
            self.in_channels,
            self.out_channels, self.num_heads, self.residual,
            super().__repr__(),
        )


@torch.no_grad()
def get_log_deg(batch):
    """Return ``log(deg + 1)`` for every node as a ``(num_nodes, 1)`` tensor.

    Sources are tried in this order: a ``log_deg`` entry on ``batch`` (used
    as-is), a ``deg`` entry (log-transformed here), and finally the degree
    counted from ``batch.edge_index[1]``, which emits a warning.
    """
    if "log_deg" in batch:
        result = batch.log_deg
        return result.view(batch.num_nodes, 1)

    if "deg" in batch:
        result = torch.log(batch.deg + 1).unsqueeze(-1)
        return result.view(batch.num_nodes, 1)

    warnings.warn("Compute the degree on the fly; Might be problematric if have applied edge-padding to complete graphs")
    counted = pyg.utils.degree(
        batch.edge_index[1],
        num_nodes=batch.num_nodes,
        dtype=torch.float,
    )
    result = torch.log(counted + 1)
    return result.view(batch.num_nodes, 1)


