import torch.nn as nn
import torch
from math import sqrt
import torch.nn.functional as F
from torch.nn.functional import gumbel_softmax
import math
import torch.fft
from einops import rearrange

# dynamic graph learning
class DGL(nn.Module):
    """Dynamic graph learning head.

    Learns an input-dependent adjacency with ``DynamicGraphUpdate``,
    propagates the node features over it, and projects the feature axis
    from ``d_len`` to ``configs.d_model`` with a kernel-size-1 Conv1d.

    Args:
        configs: config object; ``d_model`` is read here, the rest is read
            by ``DynamicGraphUpdate``.
        d_len: feature size of each node in the input.
        hops: passed through to ``DynamicGraphUpdate``.
    """

    def __init__(self, configs, d_len, hops):
        super().__init__()
        self.d_len = d_len
        self.dynamicGNN = DynamicGraphUpdate(configs, d_len, hops)
        self.agg_mlp = torch.nn.Conv1d(
            in_channels=d_len,
            out_channels=configs.d_model,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

    def forward(self, x):
        """Return ``(features, adjacency)``.

        ``x`` is annotated as [B, N, D]; the graph module returns features
        as [B, D, N], which the 1x1 conv maps to [B, d_model, N].
        """
        propagated, adjacency = self.dynamicGNN(x)  # propagated: [B, D, N]
        return self.agg_mlp(propagated), adjacency
class DynamicGraphUpdate(nn.Module):
    """Build an input-conditioned adjacency matrix and run a graph conv over it.

    Two learned node embeddings (``enc_in x nodedim`` and ``nodedim x enc_in``)
    are each shifted by a gated linear projection of the input. Their product,
    passed through ``relu`` and a softmax over the last axis, is the adjacency.
    The input is then propagated over that adjacency by ``GraphConv``.
    """

    def __init__(self, configs, deep_len, hops):
        super().__init__()
        self.enc_in = configs.enc_in
        self.d_model = configs.d_model
        self.deep_len = deep_len
        self.dropout = configs.dropout
        self.nd = configs.nodedim

        # Static (input-independent) source / target node embeddings.
        self.nodeEmbedding_1 = nn.Parameter(torch.randn(self.enc_in, self.nd))
        self.nodeEmbedding_2 = nn.Parameter(torch.randn(self.nd, self.enc_in))

        gate_in = self.deep_len + self.nd
        # Tanh followed by ReLU keeps each per-node gate in [0, 1).
        self.nodeEmb_gate1 = nn.Sequential(nn.Linear(gate_in, 1), nn.Tanh(), nn.ReLU())
        self.nodeEmb_gate2 = nn.Sequential(nn.Linear(gate_in, 1), nn.Tanh(), nn.ReLU())

        self.nodeLinear1 = nn.Linear(self.deep_len, self.nd)
        self.nodeLinear2 = nn.Linear(self.deep_len, self.nd)

        self.mhGNN = GraphConv(self.deep_len, self.deep_len, self.dropout, multiHop=hops)

    def forward(self, x):
        """Return ``(propagated, adjacency)``.

        ``x`` must be [B, enc_in, deep_len]: its first two axes are
        concatenated with the batch-repeated [B, enc_in, nodedim] embedding.
        ``propagated`` is [B, deep_len, enc_in] and ``adjacency`` is
        [B, enc_in, enc_in].
        """
        batch, _, _ = x.size()
        emb_src = self.nodeEmbedding_1.unsqueeze(0).repeat(batch, 1, 1)  # [B, N, nd]
        emb_dst = self.nodeEmbedding_2.unsqueeze(0).repeat(batch, 1, 1)  # [B, nd, N]

        # Scalar gate per node, conditioned on its features and its embedding.
        gate_src = self.nodeEmb_gate1(torch.cat((x, emb_src), dim=-1))
        gate_dst = self.nodeEmb_gate2(torch.cat((x, emb_dst.transpose(1, 2)), dim=-1))

        # Shift the static embeddings by gated projections of the input.
        vec_src = emb_src + gate_src * self.nodeLinear1(x)
        vec_dst = emb_dst + (gate_dst * self.nodeLinear2(x)).transpose(1, 2)

        adjacency = F.softmax(F.relu(vec_src @ vec_dst), dim=-1)

        # GraphConv consumes [B, D, N] features and a list of adjacencies.
        propagated = self.mhGNN(x.transpose(1, 2), [adjacency])
        return propagated, adjacency
class gconv(nn.Module):
    """One graph-propagation step: ``out[b, f, v] = sum_n x[b, f, n] * A[b, n, v]``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, A):
        """Propagate features ``x`` [B, F, N] over batched adjacency ``A`` [B, N, V]."""
        out = torch.einsum('bfn,bnv->bfv', x, A)
        return out.contiguous()
class GraphConv(nn.Module):
    """Multi-hop graph convolution.

    For every adjacency in ``adj`` the input is propagated repeatedly (the
    first hop always, then ``multiHop - 1`` more). The input and every
    propagated copy are concatenated on the feature axis and mixed back to
    ``c_out`` channels by a kernel-size-1 Conv1d followed by ReLU.

    NOTE(review): the Conv1d is sized for ``(multiHop + 1) * c_in`` input
    channels, so ``adj`` must hold exactly one matrix. ``dropout`` is stored
    but never applied in this module.
    """

    def __init__(self, c_in, c_out, dropout, multiHop=2):
        super().__init__()
        self.gconv = gconv()
        stacked_channels = (multiHop + 1) * c_in
        self.linear = torch.nn.Conv1d(
            stacked_channels, c_out, kernel_size=1, padding=0, stride=1, bias=True
        )
        self.dropout = dropout
        self.multiHop = multiHop

    def forward(self, x, adj):
        """Propagate ``x`` [B, D, N] over each matrix in ``adj``; return [B, c_out, N]."""
        hops = [x]
        for support in adj:
            h = self.gconv(x, support)
            hops.append(h)
            for _ in range(self.multiHop - 1):
                h = self.gconv(h, support)
                hops.append(h)

        mixed = self.linear(torch.cat(hops, dim=1))
        return F.relu(mixed)
# end of dynamic graph learning modules

class EncoderLayer(nn.Module):
    """Transformer encoder layer: self-attention followed by a conv feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    The feed-forward is two kernel-size-1 Conv1d layers over the feature
    axis with ``d_ff`` hidden channels (default ``4 * d_model``).
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        hidden = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # Any value other than "relu" selects GELU.
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Return ``(output, attn)``; ``attn_mask``/``tau``/``delta`` go to the attention module."""
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
        x = self.norm1(x + self.dropout(attended))

        # Conv1d expects channels on dim 1, so swap the last axis with dim 1.
        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        ffn_out = self.dropout(self.conv2(hidden).transpose(-1, 1))
        return self.norm2(x + ffn_out), attn


class EncoderLayer_DGL(nn.Module):
    """Encoder layer whose feed-forward sub-layer is a dynamic graph learning module.

    Self-attention (residual + LayerNorm) is followed by ``DGL``. The DGL
    output is added back residually and normalised again. The layer also
    returns the adjacency produced by ``DGL``.

    ``d_ff`` and ``activation`` are accepted for signature compatibility
    with ``EncoderLayer``; no conv feed-forward is built here.
    ``config`` must not be None: ``config.order`` is read as the hop count.
    """

    def __init__(self, attention, d_model, d_ff=None, config=None, dropout=0.1, activation="relu"):
        super().__init__()
        self.attention = attention
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # Stored but unused by forward (kept for parity with EncoderLayer).
        self.activation = F.relu if activation == "relu" else F.gelu
        self.dgl = DGL(config, d_model, config.order)

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Return ``(output, attn, adjacency)``."""
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
        x = self.norm1(x + self.dropout(attended))

        # DGL returns features as [B, D, N]; bring them back to [B, N, D].
        graph_out, adjacency = self.dgl(x)
        graph_out = self.dropout(graph_out).permute(0, 2, 1)
        return self.norm2(x + graph_out), attn, adjacency

class Encoder(nn.Module):
    """Stack of encoder layers with optional conv layers between them.

    Without ``conv_layers`` every attention layer receives the same
    ``attn_mask``/``tau``/``delta``. With ``conv_layers``, attention and conv
    layers alternate pairwise. Only the first pair sees ``delta``. The last
    attention layer then runs once more with neither mask nor delta.
    An optional ``norm_layer`` is applied to the final output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Return ``(x, attns)`` with one attention entry per layer call; x is [B, L, D]."""
        attns = []
        if self.conv_layers is None:
            for layer in self.attn_layers:
                x, attn = layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                attns.append(attn)
        else:
            for idx, (layer, conv) in enumerate(zip(self.attn_layers, self.conv_layers)):
                layer_delta = delta if idx == 0 else None
                x, attn = layer(x, attn_mask=attn_mask, tau=tau, delta=layer_delta)
                x = conv(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
            attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)
        return x, attns

class Encoder_DGL(nn.Module):
    """Encoder stack whose layers also return a learned adjacency matrix.

    Every attention layer must return ``(x, attn, adj)``. Without
    ``conv_layers`` each layer receives ``attn_mask``/``tau``/``delta``.
    With ``conv_layers``, attention and conv layers alternate pairwise. Only
    the first pair sees ``delta``. The last attention layer then runs once
    more with neither mask nor delta. An optional ``norm_layer`` is applied
    to the final output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder_DGL, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        """Run the stack on x ([B, L, D]).

        Returns:
            ``(x, attns, adjs)``. ``attns`` and ``adjs`` each hold one entry
            per attention-layer call, in execution order, so ``adjs[i]``
            belongs to the same call as ``attns[i]`` in both branches.
        """
        attns = []
        adjs = []
        if self.conv_layers is not None:
            for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
                delta = delta if i == 0 else None
                x, attn, adj = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                x = conv_layer(x)
                attns.append(attn)
                # Previously dropped: keep adjs aligned with attns in this branch as well.
                adjs.append(adj)
            x, attn, adj = self.attn_layers[-1](x, tau=tau, delta=None)
            attns.append(attn)
            adjs.append(adj)
        else:
            for attn_layer in self.attn_layers:
                x, attn, adj = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                attns.append(attn)
                adjs.append(adj)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns, adjs

class FullAttention(nn.Module):
    """Scaled dot-product attention with an optional soft mask.

    When ``mask_flag`` is set, the raw scores are multiplied by
    ``attn_mask`` and positions where ``attn_mask == 0`` receive an additive
    ``-log(1e10)`` (about -23). The penalty is scaled by ``scale`` together
    with the scores, so it softens the masked positions rather than removing
    them.

    NOTE(review): ``output_attention`` is ignored and the weights are always
    returned. ``factor``, ``tau`` and ``delta`` are accepted but unused.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super().__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = True  # forced on regardless of the argument (todo: verify)
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Attend ``queries`` [B, L, H, E] over ``keys`` [B, S, H, E] / ``values`` [B, S, H, D].

        Returns ``(out [B, L, H, D], weights [B, H, L, S])``.
        """
        batch, q_len, n_heads, head_dim = queries.shape
        _, src_len, _, v_dim = values.shape
        scale = self.scale or 1.0 / sqrt(head_dim)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            penalty = torch.where(attn_mask == 0, -math.log(1e10), 0)
            scores = scores * attn_mask + penalty

        weights = self.dropout(torch.softmax(scale * scores, dim=-1))
        out = torch.einsum("bhls,bshd->blhd", weights, values).contiguous()
        return out, (weights if self.output_attention else None)

class FullAttention_DGL_CAL(nn.Module):
    """Scaled dot-product attention with an optional soft mask (DGL variant).

    Computes exactly what ``FullAttention`` in this file computes. With
    ``mask_flag`` set, scores are multiplied by ``attn_mask`` and zero
    positions receive an additive ``-log(1e10)`` before the ``scale``
    multiplication and softmax.

    NOTE(review): ``output_attention`` is ignored (weights are always
    returned). ``factor``, ``tau`` and ``delta`` are unused.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super().__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = True  # forced on regardless of the argument (todo: verify)
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Return ``(context [B, L, H, D], probs [B, H, L, S])``."""
        n_batch, n_query, n_head, e_dim = queries.shape
        _, n_src, _, v_dim = values.shape
        softmax_scale = self.scale if self.scale else 1.0 / sqrt(e_dim)

        logits = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            additive = torch.where(attn_mask == 0, -math.log(1e10), 0)
            logits = logits * attn_mask + additive

        probs = torch.softmax(softmax_scale * logits, dim=-1)
        probs = self.dropout(probs)
        context = torch.einsum("bhls,bshd->blhd", probs, values)

        if not self.output_attention:
            return context.contiguous(), None
        return context.contiguous(), probs

class AttentionLayer(nn.Module):
    """Multi-head wrapper around an inner attention module.

    Projects queries/keys/values with separate linear layers and splits them
    into ``n_heads`` heads ([B, len, H, d]). It then calls the inner
    attention, merges the heads and applies an output projection. Per-head
    sizes default to ``d_model // n_heads``.
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None):
        super().__init__()
        per_head_k = d_keys or (d_model // n_heads)
        per_head_v = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, per_head_k * n_heads)
        self.key_projection = nn.Linear(d_model, per_head_k * n_heads)
        self.value_projection = nn.Linear(d_model, per_head_v * n_heads)
        self.out_projection = nn.Linear(per_head_v * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        """Return ``(output [B, L, d_model], attn)`` where ``attn`` comes from the inner module."""
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        q = self.query_projection(queries).view(batch, q_len, heads, -1)
        k = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v = self.value_projection(values).view(batch, kv_len, heads, -1)

        out, attn = self.inner_attention(q, k, v, attn_mask, tau=tau, delta=delta)
        merged = out.view(batch, q_len, -1)
        return self.out_projection(merged), attn


class Mahalanobis_mask(nn.Module):
    """Sample a binary channel-by-channel mask from learned spectral distances.

    The amplitude spectrum ``|rfft(X)|`` is taken along the last axis of
    ``X``. Every pair of channels is compared with the squared distance
    ``||A (|F_i| - |F_j|)||^2`` under a learned square matrix ``A``. The
    row-normalised inverse distance becomes a keep-probability, which is
    sampled with a hard Gumbel-softmax.

    Args:
        input_size: length of the last axis of ``X``; ``A`` is sized for the
            ``input_size // 2 + 1`` rFFT bins.
    """

    def __init__(self, input_size):
        super(Mahalanobis_mask, self).__init__()
        frequency_size = input_size // 2 + 1
        self.A = nn.Parameter(torch.randn(frequency_size, frequency_size), requires_grad=True)

    def calculate_prob_distance(self, X):
        """Return a [B, C, C] matrix of keep-probabilities for ``X`` of shape [B, C, input_size].

        Off-diagonal entries are ``0.99 * s_ij / max_j s_ij`` with
        ``s_ij = 1 / (d_ij + 1e-10)``. Diagonal entries are ``0.99``. With a
        single channel the result is ``0.99`` (previously 0/0 = NaN).
        """
        XF = torch.abs(torch.fft.rfft(X, dim=-1))

        # Pairwise spectral differences: [B, C, C, F].
        diff = XF.unsqueeze(2) - XF.unsqueeze(1)

        temp = torch.einsum("dk,bxck->bxcd", self.A, diff)
        dist = torch.einsum("bxcd,bxcd->bxc", temp, temp)

        # Inverse squared distance as similarity; the epsilon keeps the
        # zero self-distance finite.
        exp_dist = 1 / (dist + 1e-10)

        # Build the identity directly on the input's device/dtype and let it
        # broadcast over the batch (no per-call CPU allocation + copy, no
        # float32/float64 mismatch).
        eye = torch.eye(exp_dist.shape[-1], dtype=exp_dist.dtype, device=exp_dist.device)

        # Zero the diagonal so self-similarity does not define the row max.
        exp_dist = exp_dist * (1 - eye)

        # Normalise each row by its largest off-diagonal similarity; the max is
        # treated as a constant for autograd. With one channel there is no
        # off-diagonal entry and the max is 0, so divide by 1 instead of
        # producing 0/0 = NaN.
        exp_max = torch.amax(exp_dist, dim=-1, keepdim=True).detach()
        exp_max = torch.where(exp_max > 0, exp_max, torch.ones_like(exp_max))
        p = exp_dist / exp_max

        # Diagonal probability 1, then scale by 0.99 so that 1 - p stays > 0
        # and the logits built in bernoulli_gumbel_rsample are finite.
        return (p + eye) * 0.99

    def bernoulli_gumbel_rsample(self, distribution_matrix):
        """Draw a hard Bernoulli sample for each entry of a [b, c, d] probability tensor.

        Each probability ``p`` (expected strictly inside (0, 1)) becomes the
        two-class logits ``[log(p/(1-p)), log((1-p)/p)]``. The sample is
        drawn with ``gumbel_softmax(hard=True)``. The returned tensor has the
        input's shape and holds the value of class 0.
        """
        b, c, d = distribution_matrix.shape

        flatten_matrix = distribution_matrix.reshape(-1, 1)  # [(b*c*d), 1]
        r_flatten_matrix = 1 - flatten_matrix

        log_flatten_matrix = torch.log(flatten_matrix / r_flatten_matrix)
        log_r_flatten_matrix = torch.log(r_flatten_matrix / flatten_matrix)

        new_matrix = torch.cat([log_flatten_matrix, log_r_flatten_matrix], dim=-1)
        resample_matrix = gumbel_softmax(new_matrix, hard=True)

        return resample_matrix[..., 0].reshape(b, c, d)

    def forward(self, X):
        """Return a sampled 0/1 mask of shape [B, 1, C, C] for ``X`` of shape [B, C, input_size].

        The singleton axis is presumably there to broadcast over attention
        heads — confirm against the caller.
        """
        p = self.calculate_prob_distance(X)
        sample = self.bernoulli_gumbel_rsample(p)
        return sample.unsqueeze(1)
