import math
import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F


class MLP(nn.Module):
    """
    Multilayer perceptron applied along the last axis of its input.

    The network is ``Linear -> activation -> Dropout`` repeated for the
    input layer and every hidden-to-hidden layer, followed by one final
    ``Linear`` projecting to ``f_out``.

    Args:
        f_in: size of the last axis of the input.
        f_out: size of the last axis of the output.
        hidden_dim: width of every intermediate layer.
        hidden_layers: total number of ``Linear`` layers; values below 2
            still produce two ``Linear`` layers (input and output).
        dropout: dropout probability used after every activation.
        activation: ``'relu'`` or ``'tanh'``.

    Raises:
        NotImplementedError: if ``activation`` is not one of the above.
    """

    def __init__(self,
                 f_in,
                 f_out,
                 hidden_dim=128,
                 hidden_layers=2,
                 dropout=0.05,
                 activation='tanh'):
        super().__init__()
        self.f_in = f_in
        self.f_out = f_out
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.dropout = dropout

        if activation == 'tanh':
            chosen = nn.Tanh()
        elif activation == 'relu':
            chosen = nn.ReLU()
        else:
            raise NotImplementedError
        # A single activation instance is shared by every stage.
        self.activation = chosen

        # One input stage plus (hidden_layers - 2) hidden-to-hidden stages;
        # a negative count collapses to zero extra stages.
        n_stages = 1 + max(self.hidden_layers - 2, 0)
        stack = []
        width_in = self.f_in
        for _ in range(n_stages):
            stack.append(nn.Linear(width_in, self.hidden_dim))
            stack.append(self.activation)
            stack.append(nn.Dropout(self.dropout))
            width_in = self.hidden_dim
        stack.append(nn.Linear(width_in, self.f_out))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        # x: (..., f_in)  ->  (..., f_out); e.g. B x S x f_in -> B x S x f_out
        return self.layers(x)


class KPLayerApprox(nn.Module):
    def __init__(self):
        super(KPLayerApprox, self).__init__()
        self.K = None
        self.K_step = None

    def forward(self, z):
        B, N, input_len, hidden_dim = z.shape
        pred_len = input_len
        z = rearrange(z, 'b n pn m -> (b n) pn m')
        x, y = z[:, :-1], z[:, 1:]

        self.K = torch.linalg.lstsq(x, y).solution

        if torch.isnan(self.K).any():
            print('Encounter K with nan, replace K by identity matrix')
            self.K = torch.eye(self.K.shape[1]).to(
                self.K.device).unsqueeze(0).repeat(B, 1, 1)

        self.K_step = torch.linalg.matrix_power(self.K, pred_len)
        if torch.isnan(self.K_step).any():
            print('Encounter multistep K with nan, replace it by identity matrix')
            self.K_step = torch.eye(self.K_step.shape[1]).to(
                self.K_step.device).unsqueeze(0).repeat(B, 1, 1)
        z_pred = torch.bmm(z[:, -pred_len:, :], self.K_step)
        return z_pred


class KTDlayer_aba(nn.Module):
    """
        Koopman Temporal Detector layer.

        Splits the last axis of the input into snapshots of ``snap_size``
        values, encodes each snapshot with an MLP, advances the encoded
        sequence with ``KPLayerApprox`` and decodes the result back to
        snapshots of ``snap_size`` values.

        ``configs`` is accepted but not used by this class.
    """

    def __init__(self, configs,
                 enc_in, snap_size, proj_dim, hidden_dim, hidden_layers):
        super().__init__()
        self.enc_in = enc_in
        self.snap_size = snap_size
        self.dynamics = KPLayerApprox()

        mlp_shape = dict(hidden_dim=hidden_dim, hidden_layers=hidden_layers)
        self.encoder = MLP(f_in=snap_size, f_out=proj_dim, **mlp_shape)
        self.decoder = MLP(f_in=proj_dim, f_out=snap_size, **mlp_shape)

        # Number of values needed to round enc_in up to a multiple of
        # snap_size (0 when it already divides evenly).
        remainder = enc_in % snap_size
        if remainder != 0:
            self.padding_len = snap_size - remainder
        else:
            self.padding_len = 0

    def forward(self, x):
        # x is 3-D: (batch, n, width).
        # NOTE(review): padding_len is derived from enc_in, so this
        # presumably expects width == enc_in - confirm against the caller.
        batch, _, width = x.shape

        # Prepend a copy of the last `padding_len` values of the final axis.
        tail = x[:, :, width - self.padding_len:]
        padded = torch.cat((tail, x), dim=-1)

        # b n p_n p: p_n snapshots of snap_size values each
        snapshots = rearrange(padded, 'b n (p_n p) -> b n p_n p', p=self.snap_size)

        # b n p_n m, m is the encoder output size
        latent = self.encoder(snapshots)

        # (b n) f_n m, f_n is the number of predicted snapshots
        advanced = self.dynamics(latent)

        # (b n) f_n p
        decoded = self.decoder(advanced)

        # The flattened output keeps the padded length (f_n * p).
        return rearrange(decoded, '(b n) f_n p -> b n (f_n p)', b=batch)

# ---- Dynamic graph learning: DGL and its helper modules ----
class DGL(nn.Module):
    """
    Dynamic graph learning block.

    Runs ``DynamicGraphUpdate`` on the input and then maps its ``d_len``
    output channels to ``configs.d_model`` channels with a kernel-size-1
    ``Conv1d``.

    Args:
        configs: object providing ``d_model`` plus the attributes read by
            ``DynamicGraphUpdate``.
        d_len: channel count produced by the graph update.
        hops: number of propagation hops passed to ``DynamicGraphUpdate``.
    """

    def __init__(self, configs, d_len, hops):
        super().__init__()
        self.d_len = d_len
        self.dynamicGNN = DynamicGraphUpdate(configs, d_len, hops)
        self.agg_mlp = nn.Conv1d(
            in_channels=d_len,
            out_channels=configs.d_model,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

    def forward(self, x):  # x: [B,N,D]
        # graph_feat: [B,D,N]; adj_structure is the learned adjacency.
        graph_feat, adj_structure = self.dynamicGNN(x)
        projected = self.agg_mlp(graph_feat)
        return projected, adj_structure
class DynamicGraphUpdate(nn.Module):
    """
    Builds an input-dependent adjacency matrix and propagates features
    over it.

    Two learnable node embeddings are each shifted by a gated linear
    projection of the input. The product of the two resulting node
    vectors, passed through ReLU and a row-wise softmax, is the adjacency
    used by ``GraphConv``.

    Args:
        configs: object providing ``enc_in``, ``d_model``, ``dropout`` and
            ``nodedim``.
        deep_len: size of the last axis of the input.
        hops: number of propagation hops for ``GraphConv``.
    """

    def __init__(self, configs, deep_len, hops):
        super().__init__()
        self.enc_in = configs.enc_in
        self.d_model = configs.d_model
        self.deep_len = deep_len
        self.dropout = configs.dropout
        self.nd = configs.nodedim

        self.nodeEmbedding_1 = nn.Parameter(torch.randn(self.enc_in, self.nd))
        self.nodeEmbedding_2 = nn.Parameter(torch.randn(self.nd, self.enc_in))

        self.nodeEmb_gate1 = self._make_gate(self.deep_len + self.nd)
        self.nodeEmb_gate2 = self._make_gate(self.deep_len + self.nd)

        self.nodeLinear1 = nn.Linear(self.deep_len, self.nd)
        self.nodeLinear2 = nn.Linear(self.deep_len, self.nd)

        self.mhGNN = GraphConv(self.deep_len, self.deep_len, self.dropout, multiHop=hops)

    @staticmethod
    def _make_gate(in_features):
        # Scalar gate per node: Linear -> Tanh -> ReLU, i.e. a value in [0, 1).
        return nn.Sequential(nn.Linear(in_features, 1), nn.Tanh(), nn.ReLU())

    def forward(self, x):
        # x: [B, enc_in, deep_len]
        batch, _, _ = x.size()

        # Static embeddings replicated over the batch.
        emb_a = self.nodeEmbedding_1.view(1, self.enc_in, self.nd).repeat(batch, 1, 1)
        emb_b = self.nodeEmbedding_2.view(1, self.nd, self.enc_in).repeat(batch, 1, 1)

        # Each gate sees the input concatenated with its node embedding.
        gate_a = self.nodeEmb_gate1(torch.cat([x, emb_a], dim=-1))
        gate_b = self.nodeEmb_gate2(torch.cat([x, emb_b.transpose(1, 2)], dim=-1))

        # Input-dependent offsets added to the static embeddings.
        offset_a = gate_a * self.nodeLinear1(x)
        offset_b = gate_b * self.nodeLinear2(x)
        vec_a = emb_a + offset_a                    # [B, enc_in, nd]
        vec_b = emb_b + offset_b.transpose(1, 2)    # [B, nd, enc_in]

        # Row-normalised, non-negative adjacency: [B, enc_in, enc_in].
        scores = torch.matmul(vec_a, vec_b)
        adj_output = F.softmax(F.relu(scores), dim=-1)

        # GraphConv takes channels-first features and a list of adjacencies.
        propagated = self.mhGNN(x.transpose(1, 2), [adj_output])
        return propagated, adj_output
class gconv(nn.Module):
    """Parameter-free batched graph propagation step."""

    def __init__(self):
        super().__init__()

    def forward(self, x, A):
        # out[b, f, v] = sum over n of x[b, f, n] * A[b, n, v]
        propagated = torch.einsum('bfn,bnv->bfv', x, A)
        return propagated.contiguous()
class GraphConv(nn.Module):
    """
    Multi-hop graph convolution.

    The input is concatenated along the channel axis with its 1..multiHop
    hop propagations over the adjacency, then mixed by a kernel-size-1
    ``Conv1d`` and passed through ReLU.

    Args:
        c_in: channel count of the input features.
        c_out: channel count of the output features.
        dropout: stored on the module; ``forward`` does not apply it.
        multiHop: number of propagation hops. ``0`` means the input is
            only passed through the linear mix.

    Note:
        The mixing layer is sized for ``(multiHop + 1) * c_in`` channels,
        which corresponds to exactly one adjacency matrix in ``adj``.
    """

    def __init__(self, c_in, c_out, dropout, multiHop=2):
        super(GraphConv, self).__init__()
        self.gconv = gconv()
        c_in = (multiHop + 1) * c_in
        self.linear = torch.nn.Conv1d(c_in, c_out, kernel_size=1, padding=0, stride=1, bias=True)
        self.dropout = dropout
        self.multiHop = multiHop

    def forward(self, x, adj):  # x: [B,D,N], adj: iterable of [B,N,N]
        multi_X = [x]
        for a in adj:
            # Collect exactly `multiHop` successive propagations of x, so
            # the channel count matches self.linear even when multiHop == 0.
            hop = x
            for _ in range(self.multiHop):
                hop = self.gconv(hop, a)
                multi_X.append(hop)

        x_cat = torch.cat(multi_X, dim=1)
        x_cat = self.linear(x_cat)  # [B,D,N]
        return F.relu(x_cat)
# ---- End of dynamic graph learning ----
