import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class EvolutionaryStateModeling(nn.Module):
    """Recurrent block mixing a decayed key/value state with GRU-like gating.

    For every step ``t`` along the second-to-last axis of the input the block

    1. projects ``x[..., t, :]`` to query/key/value and applies ``exp`` to the
       query and key,
    2. updates two running states with learned per-channel coefficients
       ``a = sigmoid(a0)`` and ``w = sigmoid(w0)``: a ``(D, D)`` key/value
       outer-product state ``s`` and a ``(1, D)`` key state ``m``,
    3. reads the state out as ``(query @ s) / (query @ m^T + 1e-5)``, layer
       normalises it and scales it by ``gamma = 1 / local_gnn_decay``,
    4. feeds that readout, the current input and the previous hidden state
       into reset/update/candidate gates to produce the new hidden state.

    The stacked hidden states are projected by ``out_proj``, added to the
    input (residual connection) and passed through the chosen activation.

    Args:
        dim: Feature size ``D`` of the input and of every internal projection.
        n_nodes: Number of steps iterated over along the second-to-last axis.
        n_layers: Stored on the instance; not used by this class itself.
        dropout: Probability passed to ``nn.Dropout``.
        activation: ``'relu'``, ``'tanh'`` or ``'leaky'``; any other value
            leaves the residual output without a final activation.
        local_gnn_decay: ``gamma = 1 / local_gnn_decay`` scales the
            normalised readout before it enters the gates.

    Optional recording hooks: if the instance has list attributes named
    ``attn_weight``, ``attn_value`` or ``out_value``, each forward pass
    appends detached tensors to them (the comments in the code mark this as
    being for visualization).
    """

    def __init__(self, dim, n_nodes, n_layers, dropout, activation, local_gnn_decay=1):
        super().__init__()

        self.dim = dim
        self.n_nodes = n_nodes
        self.n_layers = n_layers
        self.activation = activation
        self.local_gnn_decay = local_gnn_decay

        # Scale applied to the normalised state readout.
        self.gamma = 1 / self.local_gnn_decay

        # NOTE: the creation order below is kept as-is on purpose; it fixes
        # both RNG consumption and the parameter order seen by
        # state_dict() / optimizers.
        self.q = nn.Linear(self.dim, self.dim)
        self.k = nn.Linear(self.dim, self.dim)
        self.v = nn.Linear(self.dim, self.dim)

        # Pre-sigmoid per-channel coefficients of the state recurrence.
        self.w0 = nn.Parameter(torch.randn(1, self.dim))
        self.a0 = nn.Parameter(torch.randn(1, self.dim))

        # Gate projections: h_* act on the hidden state, x_* on the input.
        self.h_r = nn.Linear(self.dim, self.dim)
        self.x_r = nn.Linear(self.dim, self.dim)
        self.h_z = nn.Linear(self.dim, self.dim)
        self.x_z = nn.Linear(self.dim, self.dim)
        self.h_m = nn.Linear(self.dim, self.dim)
        self.x_m = nn.Linear(self.dim, self.dim)

        self.out_proj = nn.Linear(self.dim, self.dim)

        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

        self.norm = nn.LayerNorm(self.dim)

    def reset_parameters(self):
        """Re-initialise the six gate projections uniformly in ±1/sqrt(dim).

        Both weights and biases of the gate layers are re-drawn; every other
        parameter keeps its default initialisation.
        """
        stdv = 1.0 / math.sqrt(self.dim)

        for module in [self.h_r, self.x_r, self.h_z, self.x_z, self.h_m, self.x_m]:
            for weight in module.parameters():
                nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x):
        """Run the recurrence over the step axis and return the residual output.

        Args:
            x: Tensor of shape ``(..., N, D)`` with ``D == dim``. Steps
                ``0 .. n_nodes - 1`` of axis ``-2`` are processed, so ``N``
                is expected to equal ``n_nodes``.

        Returns:
            Tensor with the same shape as ``x``: ``activation(x + out_proj(h))``
            where ``h`` stacks the per-step hidden states along axis ``-2``.
        """
        lead_shape = x.shape[:-2]

        # Running states are created with x's dtype and device so that
        # reduced-precision inputs do not end up mixing dtypes in matmul.
        h = torch.zeros_like(x[..., 0, :])  # (..., D)
        s = x.new_zeros((*lead_shape, self.dim, self.dim))  # (..., D, D)
        m = x.new_zeros((*lead_shape, 1, self.dim))  # (..., 1, D)

        # Loop invariants: recurrence coefficients and recording switches.
        a = self.a0.sigmoid()
        w = self.w0.sigmoid()
        record_weight = hasattr(self, 'attn_weight')  # for visualization
        record_value = hasattr(self, 'attn_value')

        output = []
        attn_weight, attn_value = [], []
        for t in range(self.n_nodes):
            xt = x[..., t, :]
            # exp keeps query/key positive but may overflow to inf for large
            # pre-activations (e.g. 1e26 and beyond); see the guard below.
            query = torch.exp(self.q(xt)).unsqueeze(dim=-2)  # (..., 1, D)
            key = torch.exp(self.k(xt))  # (..., D)
            value = self.v(xt).unsqueeze(dim=-2)  # (..., 1, D)

            s = a * s + w * key.unsqueeze(dim=-1) * value  # (..., D, D)
            m = a * m + w * key.unsqueeze(dim=-2)  # (..., 1, D)

            # Dropout is applied to m first and s second, as before, so the
            # random stream in training mode is consumed in the same order.
            div = torch.matmul(query, self.dropout(m).transpose(-1, -2)) + 1e-5
            out = torch.matmul(query, self.dropout(s)) / div
            if record_weight:
                # Index instead of squeeze(): squeeze() would also drop a
                # leading dimension of size 1 and break the stack below.
                attn_weight.append(div[..., 0, 0].clone().detach())
            if record_value:
                attn_value.append(out.clone().detach().sum(dim=-1))

            # Overflow yields inf, and inf / inf yields NaN; zero out both so
            # a single bad step cannot poison LayerNorm and later steps.
            out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
            out = out.squeeze(dim=-2)  # (..., D)

            out = self.gamma * self.norm(out)

            # GRU-like gates with the state readout added to every
            # pre-activation.
            r = torch.sigmoid(out + self.x_r(xt) + self.h_r(h))
            z = torch.sigmoid(out + self.x_z(xt) + self.h_z(h))
            n = torch.tanh(out + self.x_m(xt) + r * self.h_m(h))
            h = (1 - z) * n + z * h
            # The dropped-out hidden state is also what is carried forward.
            h = self.dropout(h)
            output.append(h)

        output = self.out_proj(torch.stack(output, dim=-2))  # (..., T, D)
        x = x + output
        if record_weight:
            self.attn_weight.append(torch.stack(attn_weight, dim=-1))  # (..., T)
        if record_value:
            self.attn_value.append(torch.stack(attn_value, dim=-2))  # (..., T, 1)
        if hasattr(self, 'out_value'):
            self.out_value.append(output.clone().detach().sum(dim=-1))

        if self.activation == 'relu':
            x = torch.relu(x)
        elif self.activation == 'tanh':
            x = torch.tanh(x)
        elif self.activation == 'leaky':
            x = F.leaky_relu(x)

        return x