import torch
from torch import nn
# from models.gcl import *
from models.gcl import E_GCL_X, E_GCL_AT_X, GMNL, GCL
from einops import rearrange


class EGNN_X(nn.Module):
    """E(n)-equivariant GNN over multi-frame coordinates with a learned temporal read-out.

    Node features are embedded once, broadcast to every input time step with an added
    learned time embedding, refined jointly with the coordinates by ``n_layers``
    ``E_GCL_X`` layers, and finally mixed across the time axis with ``softmax(theta)``.
    """

    def __init__(self, num_past, num_future, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.LeakyReLU(0.2),
                 n_layers=4, coords_weight=3.0):
        super(EGNN_X, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers

        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        # (num_future, num_past) temporal mixing logits; filled by reset_parameters().
        self.theta = nn.Parameter(torch.FloatTensor(num_future, num_past))
        for i in range(0, n_layers):
            self.add_module("gcl_%d" % i, E_GCL_X(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf,
                                                  act_fn=act_fn, recurrent=True, coords_weight=coords_weight))

        self.num_past = num_past
        self.TimeEmbedding = nn.Embedding(num_past, self.hidden_nf)

        self.reset_parameters()
        self.to(self.device)

    def reset_parameters(self):
        """Draw the temporal mixing logits ``theta`` uniformly from [-1, 1)."""
        self.theta.data.uniform_(-1, 1)

    def forward(self, h, x, edges, edge_attr):
        """Predict future coordinates.

        Args:
            h: node features of shape (n_nodes_total, in_node_nf), shared by all time steps.
            x: 4-D coordinates (T, n_nodes_total, n_channels, n_axes); T is expected to equal
               ``num_past`` so the time embedding broadcasts. Presumably (time, batch*nodes,
               atoms, xyz) — confirm against the data loader.
            edges, edge_attr: forwarded unchanged to every ``E_GCL_X`` layer.

        Returns:
            ``softmax(theta)``-weighted combination of the refined frames, shape
            (num_future, n_nodes_total, n_channels, n_axes) with the leading axis squeezed
            when it has size 1. If T == 1 the single refined frame is returned directly.
        """
        h = self.embedding(h.unsqueeze(0).repeat(x.shape[0], 1, 1))
        time_embedding = self.TimeEmbedding(torch.arange(self.num_past).to(self.device)).unsqueeze(1)
        h = h + time_embedding

        for i in range(0, self.n_layers):
            h, x = self._modules["gcl_%d" % i](h, x, edges, edge_attr=edge_attr, Fs=None)

        if x.shape[0] == 1:
            return x.squeeze(0)
        # Squeeze the future axis as the single-frame branch does, so a one-step forecast has
        # the same rank regardless of num_past.
        x_hat = torch.einsum("ij,jkts->ikts", torch.softmax(self.theta, dim=1), x)
        return x_hat.squeeze(0)


class ESTAG_X(nn.Module):
    """Equivariant spatio-temporal attentive graph network over multi-channel coordinates.

    Node features are embedded per time step (plus a learned time embedding). They are
    optionally paired with FFT-derived edge/node spectra computed from channel index 1 of
    ``x`` (``fft``). They are then refined by ``n_layers`` ``E_GCL_X`` layers, each
    optionally followed by an ``E_GCL_AT_X`` layer (``eat``). The refined frames are
    mixed over the time axis with the learned ``theta`` weights.
    """

    def __init__(self, num_past, num_future, in_node_nf, in_edge_nf, hidden_nf, fft, eat, device,
                 n_layers, n_nodes, nodes_att_dim=0, act_fn=nn.LeakyReLU(0.2), coords_weight=1.0, with_mask=False,
                 tempo=True):
        super(ESTAG_X, self).__init__()
        self.hidden_nf = hidden_nf
        self.fft = fft
        self.eat = eat
        self.device = device
        self.n_layers = n_layers
        self.n_nodes = n_nodes
        self.num_past = num_past
        self.tempo = tempo

        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        # (num_future, num_past) temporal mixing weights; zeroed by reset_parameters().
        self.theta = nn.Parameter(torch.FloatTensor(num_future, num_past))
        self.TimeEmbedding = nn.Embedding(num_past, self.hidden_nf)

        for i in range(n_layers):
            self.add_module("egcl_%d" % (i * 2 + 1),
                            E_GCL_X(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf,
                                    nodes_att_dim=nodes_att_dim,
                                    act_fn=act_fn, recurrent=True, coords_weight=coords_weight))
            if self.eat:
                self.add_module("egcl_at_%d" % (i * 2 + 2),
                                E_GCL_AT_X(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf,
                                           act_fn=act_fn, recurrent=True, coords_weight=coords_weight,
                                           with_mask=with_mask))

        # Per-time-step gate in (0, 1) that weights each FFT frequency bin in FFT().
        self.attn_mlp = nn.Sequential(
            nn.Linear(hidden_nf, 1),
            nn.Sigmoid())

        self.reset_parameters()
        self.to(self.device)

    def reset_parameters(self):
        """Initialise the temporal mixing weights ``theta`` to exactly zero.

        ``theta`` is allocated with ``torch.FloatTensor(...)``, i.e. uninitialised memory.
        Scaling that memory by zero would turn any NaN/Inf garbage into NaN, so the weights
        are zeroed explicitly.
        """
        nn.init.zeros_(self.theta)

    def FFT(self, h, x, n_nodes, edges):
        """Compute attention-weighted, normalised FFT cross- and auto-spectra.

        Args:
            h: node features (T, n_nodes_total, hidden_nf); ``h[1:]`` gates frequency bins.
            x: 3-D coordinates (T, batch * n_nodes, d), centred per graph before the FFT.
            n_nodes: nodes per graph, used to split the flattened node axis.
            edges: pair of index tensors (source, target) into the node axis.

        Returns:
            ``(edge_attr_norm, Fs_norm)`` of shapes (n_edge, T - 1) and
            (n_nodes_total, T - 1), each normalised to sum to ~1 over frequency bins
            (the DC bin is dropped).
        """
        x_ = rearrange(x, 't (b n) d -> t b n d', n=n_nodes)
        # Remove each graph's centroid so the spectra are translation invariant.
        x_bar = torch.mean(x_, dim=-2, keepdim=True)
        x_norm = x_ - x_bar
        x_norm = rearrange(x_norm, 't b n d -> (b n) d t')

        # (n_nodes_total, d, T)
        F = torch.fft.fftn(x_norm, dim=-1)

        # (n_nodes_total, T - 1)
        attn_val = self.attn_mlp(h[1:]).squeeze(-1).transpose(0, 1)

        F = F[..., 1:]
        F_i = F[edges[0]]
        F_j = F[edges[1]]

        # (n_edge, T - 1): magnitude of the cross-spectrum summed over coordinate axes.
        edge_attr = torch.abs(torch.sum(torch.conj(F_i) * F_j, dim=-2))
        edge_attr = edge_attr * (attn_val[edges[0]] * attn_val[edges[1]])

        edge_attr_norm = edge_attr / (torch.sum(edge_attr, dim=-1, keepdim=True) + 1e-6)

        # (n_nodes_total, T - 1)
        # NOTE(review): `F ** 2` is the complex square, not |F|^2 (the diagonal of the
        # conj(F_i) * F_j term above). Kept unchanged to preserve behaviour — confirm intent.
        Fs = torch.abs(torch.sum(F ** 2, dim=-2))
        Fs = Fs * attn_val

        Fs_norm = Fs / (torch.sum(Fs, dim=-1, keepdim=True) + 1e-6)
        return edge_attr_norm, Fs_norm

    def forward(self, h, x, edges, edge_attr):
        """Predict future coordinates.

        Args:
            h: node features (n_nodes_total, in_node_nf), shared by all time steps.
            x: 4-D coordinates (T, n_nodes_total, n_channels, n_axes) with T expected to equal
               ``num_past``; presumably (time, batch*nodes, atoms, xyz) — confirm.
            edges: edge index pair forwarded to every layer.
            edge_attr: edge features; replaced by the FFT cross-spectra when ``fft`` is set.

        Returns:
            With ``tempo``: ``x[-1] + theta @ (x - x[-1])`` over the time axis. Otherwise:
            ``softmax(theta) @ x``. In both cases the leading future axis is squeezed when
            ``num_future == 1``.
        """
        h = self.embedding(h.unsqueeze(0).repeat(x.shape[0], 1, 1))

        time_embedding = self.TimeEmbedding(torch.arange(self.num_past).to(self.device)).unsqueeze(1)
        h = h + time_embedding

        Fs = None
        if self.fft:
            # Spectra are computed from channel index 1 only (labelled "CA" by the authors).
            edge_attr, Fs = self.FFT(h, x[:, :, 1, :], self.n_nodes, edges=edges)

        for i in range(self.n_layers):
            h, x = self._modules["egcl_%d" % (i * 2 + 1)](h, x, edges, edge_attr, Fs)

            if self.eat:
                h, x = self._modules["egcl_at_%d" % (i * 2 + 2)](h, x)

        if self.tempo:
            x_hat = torch.einsum("ij,jkts->ikts", self.theta, x - x[-1].unsqueeze(0)).squeeze(0) + x[-1]
        else:
            x_hat = torch.einsum("ij,jkts->ikts", torch.softmax(self.theta, dim=1), x).squeeze(0)

        return x_hat


class STFT_X(nn.Module):
    """EGNN variant whose per-step node features are augmented with STFT magnitudes.

    The short-time Fourier transform of every node's coordinate series is computed over
    all channels and axes and RMS-pooled. It is averaged over adjacent frames, upsampled
    back to one row per time step, and concatenated to the embedded node features before
    the ``E_GCL_X`` stack. The output is ``x[-1]`` plus a ``theta``-weighted combination
    of the refined displacements from ``x[-1]``.
    """

    def __init__(self, num_past, num_future, in_node_nf, in_edge_nf, hidden_nf, fft, eat, device, n_layers, n_nodes,
                 hop_length, window_length,
                 n_fft, nodes_att_dim=0,
                 act_fn=nn.SiLU(), coords_weight=1.0, with_mask=False, tempo=True, filter=True):
        super(STFT_X, self).__init__()
        self.hidden_nf = hidden_nf
        self.fft = fft
        self.k = 2
        self.eat = eat
        self.device = device
        self.n_layers = n_layers
        self.n_nodes = n_nodes
        self.num_past = num_past
        self.hop_length = hop_length
        self.window_length = window_length
        self.n_fft = n_fft
        self.tempo = tempo
        self.filter = filter
        self.TimeEmbedding = nn.Embedding(num_past, self.hidden_nf)
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        # NOTE(review): not used by forward(). Removing it would change state_dict keys and
        # the RNG-driven initialisation of modules created after it.
        self.edge_attr = nn.Linear(1, 9)

        # forward() concatenates one magnitude per onesided STFT frequency bin.
        self.hidden_nf += n_fft // 2 + 1

        for i in range(n_layers):
            self.add_module("egcl_%d" % (i * 2 + 1),
                            E_GCL_X(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf,
                                    nodes_att_dim=nodes_att_dim,
                                    act_fn=act_fn, recurrent=True, coords_weight=coords_weight, norm_diff=True,
                                    clamp=True))
        self.theta = nn.Parameter(torch.FloatTensor(num_future, num_past))

        # NOTE(review): attn_mlp and predict_linear are not used by forward(); kept for the
        # same state_dict / initialisation-order reasons as edge_attr.
        self.attn_mlp = nn.Sequential(
            nn.Linear(hidden_nf, 1),
            nn.Sigmoid())

        self.reset_parameters()
        self.seq_len = self.num_past
        self.pred_len = 1
        self.predict_linear = nn.Linear(self.seq_len, self.seq_len + self.pred_len)
        self.to(self.device)

    def reset_parameters(self):
        """Set ``theta`` to zero so the initial prediction is ``x[-1]``.

        The uniform draw is discarded, but it advances the global RNG. Removing it would
        change the initialisation of ``predict_linear``, which is created afterwards.
        """
        self.theta.data.uniform_(-1, 1)
        self.theta.data.zero_()

    @staticmethod
    def _stft_magnitude(x, n_fft, hop_length, win_length):
        """Return RMS-pooled STFT magnitudes aligned to one row per time step.

        Args:
            x: coordinates laid out as (n_nodes_total, n_channels, T, n_axes).
            n_fft, hop_length, win_length: forwarded to ``torch.stft`` with a Hann window.

        Returns:
            Tensor (n_nodes_total, T, n_fft // 2 + 1). Each row is the magnitude spectrum
            RMS-pooled over channels and then over axes, averaged over two adjacent STFT
            frames.

        Raises:
            ValueError: if T is not a multiple of ``hop_length``; the frames could then not
                be upsampled back to exactly T rows.
        """
        n_steps = x.shape[2]
        if n_steps % hop_length != 0:
            raise ValueError(
                f"number of time steps ({n_steps}) must be a multiple of hop_length ({hop_length})"
            )
        window = torch.hann_window(win_length).to(x.device)

        per_axis = []
        for axis in range(x.shape[-1]):
            per_channel = [
                torch.stft(
                    x[:, channel, :, axis],
                    n_fft=n_fft,
                    hop_length=hop_length,
                    win_length=win_length,
                    window=window,
                    return_complex=True,
                ).abs()
                for channel in range(x.shape[1])
            ]
            # RMS over channels for this axis.
            per_axis.append(torch.sqrt(sum(r ** 2 for r in per_channel) / len(per_channel)))
        # RMS over axes: (n_nodes_total, n_freq, n_frames) -> (n_nodes_total, n_frames, n_freq).
        spec = torch.sqrt(sum(r ** 2 for r in per_axis) / len(per_axis)).permute(0, 2, 1)

        # torch.stft (center=True) yields T // hop_length + 1 frames. Average each kept frame
        # with its successor. This is out-of-place so autograd stays valid if x requires grad.
        n_kept = n_steps // hop_length
        spec = (spec[:, :n_kept] + spec[:, 1:n_kept + 1]) / 2
        # Upsample back to one row per time step.
        return spec.repeat_interleave(hop_length, dim=1)

    def forward(self, h, x, edges, edge_attr):
        """Predict future coordinates.

        Args:
            h: node features (n_nodes_total, in_node_nf), shared by all time steps.
            x: 4-D coordinates (T, n_nodes_total, n_channels, n_axes) with T expected to equal
               ``num_past``; presumably (time, batch*nodes, atoms, xyz) — confirm.
            edges, edge_attr: forwarded unchanged to every ``E_GCL_X`` layer.

        Returns:
            ``x[-1] + theta @ (x - x[-1])`` over the time axis, with the leading future axis
            squeezed when ``num_future == 1``.

        Raises:
            ValueError: if T is not a multiple of ``hop_length``.
        """
        h = self.embedding(h.unsqueeze(0).repeat(x.shape[0], 1, 1))
        time_embedding = self.TimeEmbedding(torch.arange(self.num_past).to(self.device)).unsqueeze(1)
        h = h + time_embedding

        # (n_nodes_total, T, n_freq) -> (T, n_nodes_total, n_freq), appended to the features.
        spec = self._stft_magnitude(x.permute(1, 2, 0, 3), self.n_fft, self.hop_length, self.window_length)
        h = torch.cat((h, spec.permute(1, 0, 2)), dim=2)

        for i in range(self.n_layers):
            h, x = self._modules["egcl_%d" % (i * 2 + 1)](h, x, edges, edge_attr, None)

        return torch.einsum("ij,jkts->ikts", self.theta, x - x[-1].unsqueeze(0)).squeeze(0) + x[-1]


class NS_EGNN_X(nn.Module):
    """EGNN variant with multi-resolution STFT features of the coordinate trajectories.

    STFT magnitudes are computed at three hop lengths (``hop1``/``hop2``/``hop3``, each with
    window and ``n_fft`` equal to twice the hop). They are concatenated, projected to
    ``hidden_nf // 2`` by ``spectral_embedding``, and appended to the embedded node
    features before the ``egcl_*`` stack. The output is ``x[-1]`` plus a
    ``theta``-weighted combination of the refined displacements from ``x[-1]``.
    """

    def __init__(self, num_past, num_future, in_node_nf, in_edge_nf, hidden_nf, fft, eat, device, n_layers, n_nodes,
                 hop1=5, hop2=10, hop3=15, nodes_att_dim=0,
                 act_fn=nn.SiLU(), coords_weight=1.0, with_mask=False, tempo=True, filter=True):
        super(NS_EGNN_X, self).__init__()
        self.hidden_nf = hidden_nf
        self.fft = fft
        self.k = 2
        self.eat = eat
        self.device = device
        self.n_layers = n_layers
        self.n_nodes = n_nodes
        self.num_past = num_past
        self.tempo = tempo
        self.filter = filter
        self.TimeEmbedding = nn.Embedding(num_past, self.hidden_nf)
        self.hop1 = hop1
        self.hop2 = hop2
        self.hop3 = hop3
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        # Each hop h contributes (2h) // 2 + 1 = h + 1 onesided frequency bins.
        self.spectral_embedding = nn.Linear(hop1 + hop2 + hop3 + 3, self.hidden_nf // 2)

        self.hidden_nf += self.hidden_nf // 2
        # NOTE(review): the prior_egcl_* and egcl_at_* modules are constructed but not applied
        # in forward(). They are kept so state_dict keys and initialisation order are unchanged.
        for i in range(n_layers):
            self.add_module("prior_egcl_%d" % (i * 2 + 1),
                            E_GCL_X(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf,
                                    nodes_att_dim=nodes_att_dim,
                                    act_fn=act_fn, recurrent=True, coords_weight=coords_weight, norm_diff=True,
                                    clamp=True))

        for i in range(n_layers):
            self.add_module("egcl_%d" % (i * 2 + 1),
                            E_GCL_X(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf,
                                    nodes_att_dim=nodes_att_dim,
                                    act_fn=act_fn, recurrent=True, coords_weight=coords_weight, norm_diff=True,
                                    clamp=True))
            if self.eat:
                self.add_module("egcl_at_%d" % (i * 2 + 2),
                                E_GCL_AT_X(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf,
                                           act_fn=act_fn, recurrent=True, coords_weight=coords_weight,
                                           with_mask=with_mask))
        self.theta = nn.Parameter(torch.FloatTensor(num_future, num_past))

        # NOTE(review): attn_mlp and predict_linear are not used by forward().
        self.attn_mlp = nn.Sequential(
            nn.Linear(hidden_nf, 1),
            nn.Sigmoid())

        self.reset_parameters()
        self.seq_len = self.num_past
        self.pred_len = 1
        self.predict_linear = nn.Linear(self.seq_len, self.seq_len + self.pred_len)
        self.to(self.device)

    def reset_parameters(self):
        """Set ``theta`` to zero so the initial prediction is ``x[-1]``.

        The uniform draw is discarded, but it advances the global RNG. Removing it would
        change the initialisation of ``predict_linear``, which is created afterwards.
        """
        self.theta.data.uniform_(-1, 1)
        self.theta.data.zero_()

    @staticmethod
    def _stft_magnitude(x, n_fft, hop_length, win_length):
        """Return RMS-pooled STFT magnitudes aligned to one row per time step.

        Args:
            x: coordinates laid out as (n_nodes_total, n_channels, T, n_axes).
            n_fft, hop_length, win_length: forwarded to ``torch.stft`` with a Hann window.

        Returns:
            Tensor (n_nodes_total, T, n_fft // 2 + 1). Each row is the magnitude spectrum
            RMS-pooled over channels and then over axes, averaged over two adjacent STFT
            frames.

        Raises:
            ValueError: if T is not a multiple of ``hop_length``; the frames could then not
                be upsampled back to exactly T rows.
        """
        n_steps = x.shape[2]
        if n_steps % hop_length != 0:
            raise ValueError(
                f"number of time steps ({n_steps}) must be a multiple of hop_length ({hop_length})"
            )
        window = torch.hann_window(win_length).to(x.device)

        per_axis = []
        for axis in range(x.shape[-1]):
            per_channel = [
                torch.stft(
                    x[:, channel, :, axis],
                    n_fft=n_fft,
                    hop_length=hop_length,
                    win_length=win_length,
                    window=window,
                    return_complex=True,
                ).abs()
                for channel in range(x.shape[1])
            ]
            # RMS over channels for this axis.
            per_axis.append(torch.sqrt(sum(r ** 2 for r in per_channel) / len(per_channel)))
        # RMS over axes: (n_nodes_total, n_freq, n_frames) -> (n_nodes_total, n_frames, n_freq).
        spec = torch.sqrt(sum(r ** 2 for r in per_axis) / len(per_axis)).permute(0, 2, 1)

        # torch.stft (center=True) yields T // hop_length + 1 frames. Average each kept frame
        # with its successor. This is out-of-place so autograd stays valid if x requires grad.
        n_kept = n_steps // hop_length
        spec = (spec[:, :n_kept] + spec[:, 1:n_kept + 1]) / 2
        # Upsample back to one row per time step.
        return spec.repeat_interleave(hop_length, dim=1)

    def forward(self, h, x, edges, edge_attr):
        """Predict future coordinates.

        Args:
            h: node features (n_nodes_total, in_node_nf), shared by all time steps.
            x: 4-D coordinates (T, n_nodes_total, n_channels, n_axes) with T expected to equal
               ``num_past``; presumably (time, batch*nodes, atoms, xyz) — confirm.
            edges, edge_attr: forwarded unchanged to every ``egcl_*`` layer.

        Returns:
            ``x[-1] + theta @ (x - x[-1])`` over the time axis, with the leading future axis
            squeezed when ``num_future == 1``.

        Raises:
            ValueError: if T is not a multiple of every hop length.
        """
        h = self.embedding(h.unsqueeze(0).repeat(x.shape[0], 1, 1))
        time_embedding = self.TimeEmbedding(torch.arange(self.num_past).to(self.device)).unsqueeze(1)
        h = h + time_embedding

        # Layout expected by the STFT helper: (n_nodes_total, n_channels, T, n_axes).
        x_nct = x.permute(1, 2, 0, 3)
        spectra = [
            self._stft_magnitude(x_nct, 2 * hop, hop, 2 * hop)
            for hop in (self.hop1, self.hop2, self.hop3)
        ]
        # (n_nodes_total, T, sum of bins) -> (n_nodes_total, T, hidden_nf // 2).
        spectral = self.spectral_embedding(torch.cat(spectra, dim=2))
        h = torch.cat((h, spectral.permute(1, 0, 2)), dim=2)

        for i in range(self.n_layers):
            h, x = self._modules["egcl_%d" % (i * 2 + 1)](h, x, edges, edge_attr, None)

        return torch.einsum("ij,jkts->ikts", self.theta, x - x[-1].unsqueeze(0)).squeeze(0) + x[-1]


class GMN(nn.Module):
    """Stack of ``GMNL`` layers over multi-frame coordinates with a softmax temporal read-out.

    Node features are embedded, broadcast over the input time steps with a learned time
    embedding, passed through ``n_layers`` ``GMNL`` layers together with the coordinates,
    and the refined frames are finally mixed over time with ``softmax(theta)``.
    """

    def __init__(self, num_past, num_future, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.LeakyReLU(0.2),
                 n_layers=4, coords_weight=3.0):
        super().__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.num_past = num_past

        # Submodule creation order is kept fixed: it determines RNG-driven initialisation
        # and the order of parameters in state_dict / parameters().
        self.TimeEmbedding = nn.Embedding(num_past, hidden_nf)
        self.embedding = nn.Linear(in_node_nf, hidden_nf)
        self.theta = nn.Parameter(torch.FloatTensor(num_future, num_past))
        for layer_idx in range(n_layers):
            layer = GMNL(hidden_nf, hidden_nf, hidden_nf, edges_in_d=in_edge_nf,
                         act_fn=act_fn, recurrent=True, coords_weight=coords_weight)
            self.add_module(f"gmnl_{layer_idx}", layer)

        self.reset_parameters()
        self.to(self.device)

    def reset_parameters(self):
        """Fill the temporal mixing logits with samples from U[-1, 1)."""
        with torch.no_grad():
            self.theta.uniform_(-1, 1)

    def forward(self, h, x, edges, edge_attr):
        """Return predicted coordinates.

        ``h`` holds per-node input features shared across time. ``x`` holds 4-D
        coordinates whose first axis is time (expected length ``num_past``). A
        single-frame ``x`` is returned squeezed. Otherwise the frames are combined with
        ``softmax(theta)`` and the future axis is squeezed when it has size 1.
        """
        steps = torch.arange(self.num_past).to(self.device)
        per_step = h.unsqueeze(0).repeat(x.shape[0], 1, 1)
        h = self.embedding(per_step) + self.TimeEmbedding(steps).unsqueeze(1)

        for layer_idx in range(self.n_layers):
            layer = self._modules[f"gmnl_{layer_idx}"]
            h, x = layer(h, x, edges, edge_attr=edge_attr)

        if x.shape[0] == 1:
            return x.squeeze(0)
        weights = torch.softmax(self.theta, dim=1)
        return torch.einsum("ij,jkts->ikts", weights, x).squeeze(0)


class GNN_X(nn.Module):
    """Non-equivariant GNN baseline that decodes coordinates directly from node features.

    Per-step node features are embedded (plus a learned time embedding) and refined by
    ``n_layers`` ``GCL`` layers. Each node is decoded to 4 x 3 values, and the time axis
    is mixed with ``softmax(theta)``.
    """

    def __init__(self, num_past, num_future, input_dim, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(),
                 n_layers=4, attention=0, recurrent=False):
        super().__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.num_future = num_future
        self.n_layers = n_layers
        # Submodule creation order is kept fixed: it determines RNG-driven initialisation
        # and the order of parameters in state_dict / parameters().
        for layer_idx in range(n_layers):
            layer = GCL(hidden_nf, hidden_nf, hidden_nf, edges_in_nf=in_edge_nf,
                        act_fn=act_fn, attention=attention, recurrent=recurrent)
            self.add_module(f"gcl_{layer_idx}", layer)

        n_out = 4 * 3
        self.decoder = nn.Sequential(
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, n_out),
        )
        self.embedding = nn.Sequential(nn.Linear(input_dim, hidden_nf))
        self.theta = nn.Parameter(torch.FloatTensor(num_future, num_past))
        self.num_past = num_past
        self.TimeEmbedding = nn.Embedding(num_past, hidden_nf)

        self.reset_parameters()
        self.to(self.device)

    def reset_parameters(self):
        """Fill the temporal mixing logits with samples from U[-1, 1)."""
        with torch.no_grad():
            self.theta.uniform_(-1, 1)

    def forward(self, nodes, edges, edge_attr=None):
        """Decode future coordinates of shape (num_future, n_nodes_total, 4, 3).

        ``nodes`` carries per-step input features whose leading axis is time (expected
        length ``num_past``). The per-step decodings are combined with ``softmax(theta)``.
        """
        steps = torch.arange(self.num_past).to(self.device)
        h = self.embedding(nodes) + self.TimeEmbedding(steps).unsqueeze(1)

        for layer_idx in range(self.n_layers):
            layer = self._modules[f"gcl_{layer_idx}"]
            h, _ = layer(h, edges, edge_attr=edge_attr)

        decoded = self.decoder(h)
        weights = torch.softmax(self.theta, dim=1)
        mixed = torch.einsum("ij,jkt->ikt", weights, decoded).squeeze(0)
        return mixed.reshape(self.num_future, -1, 4, 3)
