import math
from math import pi as PI
from tkinter import NONE
from typing import List, Tuple

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init
from torch.nn.modules.linear import Linear
from torch_geometric.utils import scatter
from torch_geometric.utils import softmax as pyg_softmax
from torch_geometric.utils import to_dense_adj, to_dense_batch

try:
    from torch_scatter import scatter
except:
    from torch_geometric.utils import scatter

from fragfm.model.layer import *


class GlobalGTLayer(nn.Module):
    """
    Graph-transformer layer that jointly updates node, edge and global features.

    Node features attend over edges (multi-head scaled dot product, softmax over
    the edges sharing a target node). Edge features modulate the attention
    logits via FiLM, and global features modulate the new edge and node
    features via FiLM. Global features are updated from their previous value
    plus pooled node and edge features.

    Args:
        h_dim: Node feature size. Must be divisible by ``n_head``.
        e_dim: Edge feature size.
        g_dim: Global feature size.
        n_head: Number of attention heads (must be positive).
        dropout: Dropout rate forwarded to the output MLPs.
        layer_norm: Layer-norm flag forwarded to the output MLPs.
        activation: One of ``"none"``, ``"relu"``, ``"silu"``, ``"leaky_relu"``,
            ``"softplus"``; forwarded to the output MLPs.
        init_method: Initialization scheme forwarded to the output MLPs.

    Raises:
        NotImplementedError: If ``activation`` is not a supported name.
        ValueError: If ``n_head`` is not positive or does not divide ``h_dim``.
    """

    # Supported activation names -> module constructors.
    _ACTIVATIONS = {
        "none": nn.Identity,
        "relu": nn.ReLU,
        "silu": nn.SiLU,
        "leaky_relu": nn.LeakyReLU,
        "softplus": nn.Softplus,
    }

    def __init__(
        self,
        h_dim: int,
        e_dim: int,
        g_dim: int,
        n_head: int,
        dropout: float,
        layer_norm: bool,
        activation: str,
        init_method: str,
    ):
        super().__init__()

        if activation not in self._ACTIVATIONS:
            raise NotImplementedError(
                f"Unsupported activation {activation!r}; expected one of "
                f"{sorted(self._ACTIVATIONS)}"
            )
        # NOTE: ``self.act`` is not used in ``forward``; it is kept for external
        # readers. "none" maps to Identity so the attribute always exists.
        self.act = self._ACTIVATIONS[activation]()

        # Explicit check (not ``assert``) so it survives ``python -O`` and runs
        # before the division below.
        if n_head <= 0 or h_dim % n_head != 0:
            raise ValueError(
                f"h_dim ({h_dim}) must be divisible by a positive n_head ({n_head})"
            )
        self.h_dim, self.e_dim, self.g_dim = h_dim, e_dim, g_dim
        self.n_head = n_head
        self.d = h_dim // n_head  # per-head feature size

        # Query / key / value projections for node features.
        self.q = Linear(h_dim, h_dim)
        self.k = Linear(h_dim, h_dim)
        self.v = Linear(h_dim, h_dim)

        # FiLM: edge features -> attention logits
        self.e_add = Linear(e_dim, h_dim)
        self.e_mul = Linear(e_dim, h_dim)

        # FiLM: global features -> edge features
        self.g_e_mul = Linear(g_dim, h_dim)
        self.g_e_add = Linear(g_dim, h_dim)

        # FiLM: global features -> node features
        self.g_h_mul = Linear(g_dim, h_dim)
        self.g_h_add = Linear(g_dim, h_dim)

        # Global feature update from g, pooled h and pooled e.
        self.g_g = Linear(g_dim, g_dim)
        # NOTE(review): these use a fixed "xavier" init rather than
        # ``init_method``; confirm this is intentional.
        self.h_g = NodeToGlobal(h_dim, g_dim, "xavier")
        self.e_g = EdgeToGlobal(e_dim, g_dim, "xavier")

        # Output projections. e_out maps the flattened per-head messages
        # (h_dim) back to the edge feature size.
        self.h_out = Linear(h_dim, h_dim)
        self.e_out = Linear(h_dim, e_dim)
        # NOTE(review): g_out always uses ReLU, independent of ``activation``.
        self.g_out = nn.Sequential(
            nn.Linear(g_dim, g_dim), nn.ReLU(), nn.Linear(g_dim, g_dim)
        )

        self.fc_h = MLP(
            dims=[h_dim, h_dim, h_dim],
            dropout=dropout,
            layer_norm=layer_norm,
            activation=activation,
            init_method=init_method,
        )
        self.fc_e = MLP(
            dims=[e_dim, e_dim, e_dim],
            dropout=dropout,
            layer_norm=layer_norm,
            activation=activation,
            init_method=init_method,
        )
        self.fc_g = MLP(
            dims=[g_dim, g_dim, g_dim],
            dropout=dropout,
            layer_norm=layer_norm,
            activation=activation,
            init_method=init_method,
        )

    def forward(self, h, e_index, e, g, batch):
        """
        Run one round of attention-based message passing with global features.

        Args:
            h: Node features; reshaped after projection to
                ``[n_node, n_head, d]``, so presumably ``[n_node, h_dim]``.
            e_index: Edge index; row 0 is the message source, row 1 the target.
            e: Edge features, presumably ``[n_edge, e_dim]``.
            g: Global features, indexed by graph id (presumably
                ``[n_graph, g_dim]``).
            batch: Graph id of each node, used to broadcast ``g``.

        Returns:
            Tuple ``(h, e, g)`` of updated node, edge and global features, each
            passed through its output MLP (``fc_h``, ``fc_e``, ``fc_g``).
        """
        n_node, n_edge = h.size(0), e.size(0)
        # Each edge belongs to the graph of its target node.
        e_batch = batch[e_index[1]]

        # Per-head query/key projections: [n_node, n_head, d]
        q = self.q(h).reshape(n_node, self.n_head, self.d)
        k = self.k(h).reshape(n_node, self.n_head, self.d)

        # Unnormalized, scaled per-feature attention: [n_edge, n_head, d]
        m = q[e_index[0]] * k[e_index[1]]
        m = m / math.sqrt(self.d)

        # FiLM: edge features modulate the attention logits.
        e_mul = self.e_mul(e).reshape(n_edge, self.n_head, self.d)
        e_add = self.e_add(e).reshape(n_edge, self.n_head, self.d)
        m = m * (e_mul + 1) + e_add

        # FiLM: global features modulate the new edge features.
        g_e_mul = self.g_e_mul(g)
        g_e_add = self.g_e_add(g)
        new_e = m.flatten(start_dim=1)  # [n_edge, h_dim]
        new_e = new_e * (g_e_mul[e_batch] + 1) + g_e_add[e_batch]
        new_e = self.e_out(new_e)

        # One logit per head; softmax over the edges sharing a target node.
        attn_score = m.sum(dim=-1)  # [n_edge, n_head]
        attn = pyg_softmax(attn_score, e_index[1], num_nodes=n_node, dim=0)

        # Attention-weighted source values, summed at each target node.
        v = self.v(h).reshape(n_node, self.n_head, self.d)
        wv = attn.unsqueeze(-1) * v[e_index[0]]  # [n_edge, n_head, d]
        wv_aggr = scatter(wv, e_index[1], dim=0, dim_size=n_node, reduce="sum")
        wv_aggr = wv_aggr.flatten(start_dim=1)  # [n_node, h_dim]

        # FiLM: global features modulate the new node features.
        g_h_mul = self.g_h_mul(g)
        g_h_add = self.g_h_add(g)
        new_h = wv_aggr * (g_h_mul[batch] + 1) + g_h_add[batch]
        new_h = self.h_out(new_h)

        # Update global features from g and the *input* h and e.
        new_g = self.g_g(g) + self.h_g(h, batch) + self.e_g(e_index, e, batch)
        new_g = self.g_out(new_g)

        return self.fc_h(new_h), self.fc_e(new_e), self.fc_g(new_g)
