from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import torch.nn as nn
import torch
import math

class GraphAttnMultiHead(Module):
    """Multi-head graph attention over a dense adjacency mask.

    Node features are projected into ``num_heads`` subspaces of width
    ``out_features``. For head h, the raw score between target row i and
    neighbour column j is ``s_j . weight_u[h] + s_i . weight_v[h]``. The score
    goes through LeakyReLU and is restricted to the non-zero entries of
    ``adj_mat``. It is then normalised over neighbours j with a sparse softmax
    and used to average the projected neighbour features. Head outputs are
    concatenated, and an optional bias and linear residual are added.
    """

    def __init__(self, in_features, out_features, negative_slope=0.2, num_heads=4, bias=True, residual=True):
        super(GraphAttnMultiHead, self).__init__()
        self.num_heads = num_heads
        self.out_features = out_features
        total_out = num_heads * out_features
        # Tensors are allocated uninitialised here and filled by reset_parameters().
        self.weight = Parameter(torch.FloatTensor(in_features, total_out))
        self.weight_u = Parameter(torch.FloatTensor(num_heads, out_features, 1))
        self.weight_v = Parameter(torch.FloatTensor(num_heads, out_features, 1))
        self.leaky_relu = nn.LeakyReLU(negative_slope=negative_slope)
        self.residual = residual
        # Optional linear skip path from the raw input to the concatenated heads.
        self.project = nn.Linear(in_features, total_out) if residual else None
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.FloatTensor(1, total_out))
        self.reset_parameters()

    def reset_parameters(self):
        """Redraw the projection, bias and attention vectors from symmetric uniforms."""
        # Projection weight and bias share the bound 1/sqrt(num_heads * out_features).
        proj_bound = 1. / math.sqrt(self.weight.size(-1))
        if self.bias is not None:
            self.bias.data.uniform_(-proj_bound, proj_bound)
        self.weight.data.uniform_(-proj_bound, proj_bound)
        # The attention vectors have a trailing dim of 1, so this bound evaluates to 1.
        attn_bound = 1. / math.sqrt(self.weight_u.size(-1))
        for vec in (self.weight_u, self.weight_v):
            vec.data.uniform_(-attn_bound, attn_bound)

    def forward(self, inputs, adj_mat, requires_weight=False):
        """Apply attention-weighted neighbour aggregation.

        Args:
            inputs: 2-D node features ``(N, in_features)``.
            adj_mat: mask broadcastable to ``(num_heads, N, N)``; zero entries
                are excluded from the softmax.
            requires_weight: when true, also return the dense attention tensor.

        Returns:
            ``(out, attn)``, where ``out`` is ``(N, num_heads * out_features)``.
            ``attn`` is the ``(num_heads, N, N, 1)`` attention tensor, or None
            when ``requires_weight`` is false.
        """
        heads, width = self.num_heads, self.out_features
        # (N, in) @ (in, H*F) -> (N, H, F) -> (H, N, F)
        projected = torch.mm(inputs, self.weight).reshape(-1, heads, width).permute(dims=(1, 0, 2))
        # Neighbour term is laid out along columns, target term along rows.
        score_src = torch.matmul(projected, self.weight_u).reshape(heads, 1, -1)
        score_dst = torch.matmul(projected, self.weight_v).reshape(heads, -1, 1)
        scores = self.leaky_relu(score_src + score_dst)
        # to_sparse() keeps only non-zero products, so the softmax over dim 2
        # normalises across retained neighbours only.
        sparse_scores = torch.mul(scores.unsqueeze(-1), adj_mat.unsqueeze(-1)).to_sparse()
        attn_weights = torch.sparse.softmax(sparse_scores, dim=2).to_dense()
        aggregated = torch.matmul(attn_weights.squeeze(-1), projected)
        out = aggregated.permute(dims=(1, 0, 2)).reshape(-1, heads * width)
        if self.bias is not None:
            out = out + self.bias
        if self.residual:
            out = out + self.project(inputs)
        return out, (attn_weights if requires_weight else None)


class PairNorm(nn.Module):
    """PairNorm-style normalisation of a 2-D feature matrix.

    Column means are taken over dim 0 and row norms over dim 1. A 1e-6 term
    inside every square root guards against division by zero.

    Modes:
        'None'   -- return the input unchanged.
        'PN'     -- centre columns, then divide by the root of the mean squared
                    row norm and multiply by ``scale``.
        'PN-SI'  -- centre columns, then rescale each row to norm ``scale``.
        'PN-SCS' -- rescale each row to norm ``scale``, then subtract the column
                    mean of the original input.
    """

    def __init__(self, mode='PN', scale=1):
        assert mode in ['None', 'PN', 'PN-SI', 'PN-SCS']
        super(PairNorm, self).__init__()
        self.mode = mode
        self.scale = scale

    def forward(self, x):
        mode = self.mode
        if mode == 'None':
            return x
        col_mean = x.mean(dim=0)
        if mode == 'PN':
            centred = x - col_mean
            mean_row_norm = (1e-6 + centred.pow(2).sum(dim=1).mean()).sqrt()
            return self.scale * centred / mean_row_norm
        if mode == 'PN-SI':
            centred = x - col_mean
            row_norm = (1e-6 + centred.pow(2).sum(dim=1, keepdim=True)).sqrt()
            return self.scale * centred / row_norm
        if mode == 'PN-SCS':
            row_norm = (1e-6 + x.pow(2).sum(dim=1, keepdim=True)).sqrt()
            return self.scale * x / row_norm - col_mean
        # Unknown modes (only reachable if self.mode is mutated) pass through.
        return x


class GraphAttnSemIndividual(Module):
    """Semantic attention that fuses several embeddings stacked along dim 1.

    A small MLP (Linear -> act -> bias-free Linear to one unit) scores each
    embedding. The scores are softmax-normalised over dim 1 and used as weights
    in a sum over that dim. The input's last dim must equal ``in_features``.
    """

    def __init__(self, in_features, hidden_size=128, act=nn.Tanh()):
        super(GraphAttnSemIndividual, self).__init__()
        # NOTE: the default ``act`` instance is created once at definition time and
        # shared by every instance that uses it; Tanh holds no state, so this is harmless.
        layers = [
            nn.Linear(in_features, hidden_size),
            act,
            nn.Linear(hidden_size, 1, bias=False),
        ]
        self.project = nn.Sequential(*layers)

    def forward(self, inputs, requires_weight=False):
        """Return ``(fused, beta)``; ``beta`` is None unless ``requires_weight``."""
        scores = self.project(inputs)
        beta = torch.softmax(scores, dim=1)
        fused = (beta * inputs).sum(1)
        return fused, (beta if requires_weight else None)


class Model(nn.Module):
    """GRU encoder, positive/negative graph attention, and semantic fusion forecaster.

    ``forecast`` runs the core network:

    1. A GRU encodes each sample.
    2. Two multi-head graph attention layers aggregate over the samples of the
       batch (the batch members are the graph nodes).
    3. The three views (self / positive / negative) are fused with semantic
       attention.
    4. The fused embedding is normalised with PairNorm ('PN-SI') and mapped
       through a sigmoid predictor.

    ``forward`` wraps it with Non-stationary-Transformer-style instance
    normalisation.
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.in_features = configs.seq_len
        self.out_features = configs.num_stations
        self.num_heads = configs.n_heads
        self.hidden_dim = configs.d_model
        self.num_vertex = configs.num_stations
        self.num_layers = configs.e_layers
        self.pred_len = configs.pred_len

        self.encoding = nn.GRU(
            input_size=self.in_features,
            hidden_size=self.num_vertex,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=False,
            # GRU dropout only acts between stacked layers; a non-zero value with a
            # single layer has no effect and only makes PyTorch emit a warning.
            dropout=0.1 if self.num_layers > 1 else 0.0,
        )
        # NOTE(review): configs.adj / configs.negative_adj are not used; forecast()
        # builds all-zero adjacency masks instead. Confirm this placeholder is intended.
        self.pos_gat = GraphAttnMultiHead(
            in_features=self.num_vertex,
            out_features=self.num_vertex,
            num_heads=self.num_heads
        )
        self.neg_gat = GraphAttnMultiHead(
            in_features=self.num_vertex,
            out_features=self.num_vertex,
            num_heads=self.num_heads
        )
        self.mlp_self = nn.Linear(self.num_vertex, self.hidden_dim)
        self.mlp_pos = nn.Linear(self.num_vertex * self.num_heads, self.hidden_dim)
        self.mlp_neg = nn.Linear(self.num_vertex * self.num_heads, self.hidden_dim)
        self.pn = PairNorm(mode='PN-SI')
        self.sem_gat = GraphAttnSemIndividual(in_features=self.hidden_dim,
                                              hidden_size=self.hidden_dim,
                                              act=nn.Tanh())
        self.predictor = nn.Sequential(
            nn.Linear(self.hidden_dim, self.out_features),
            nn.Sigmoid()
        )

        # Re-initialise every Linear, including those inside the sub-modules above,
        # with a small-gain Xavier uniform.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.02)

    def forecast(self, inputs, x_mark_enc=None, x_dec=None, x_mark_dec=None, requires_weight=False, target_x=None):
        """Run the network on already-normalised inputs.

        Args:
            inputs: ``(batch, seq_len, N)``. Dim 1 must equal ``seq_len``,
                because after the permute it feeds the GRU's
                ``input_size=seq_len``. N is presumably the station axis;
                verify against the data loader.
            x_mark_enc, x_dec, x_mark_dec, requires_weight, target_x: accepted
                for interface compatibility and currently ignored.

        Returns:
            ``(batch, 1, num_stations)`` sigmoid outputs. The predictor emits
            ``num_stations`` values per sample, reshaped with
            ``num_vertex == num_stations`` as the last dim.
        """
        batch_size = inputs.shape[0]
        # (B, L, N) -> (B, N, L): the GRU steps over dim N with L features per step.
        inputs = inputs.permute(0, 2, 1)
        _, h_n = self.encoding(inputs)
        # h_n is (num_layers, B, num_vertex); keep the top layer's state. Unlike
        # squeeze(), this stays 2-D for num_layers > 1 and for batch_size == 1.
        support = h_n[-1]
        num_nodes = support.shape[0]
        # Allocate the masks on the features' device/dtype instead of forcing .cuda().
        # NOTE(review): both masks are all zeros; confirm this is intended.
        batch_pos_adj = support.new_zeros(self.num_heads, num_nodes, num_nodes)
        batch_neg_adj = support.new_zeros(self.num_heads, num_nodes, num_nodes)
        pos_support, _ = self.pos_gat(support, batch_pos_adj, False)
        neg_support, _ = self.neg_gat(support, batch_neg_adj, False)
        support = self.mlp_self(support)
        pos_support = self.mlp_pos(pos_support)
        neg_support = self.mlp_neg(neg_support)
        # (B, 3, hidden_dim): one row per view, fused by semantic attention.
        all_embedding = torch.stack((support, pos_support, neg_support), dim=1)
        all_embedding, _ = self.sem_gat(all_embedding, False)
        all_embedding = self.pn(all_embedding)
        dec_out = self.predictor(all_embedding).reshape(batch_size, -1, self.num_vertex)
        return dec_out

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, target_x=None):
        """Normalise ``x_enc``, forecast, then de-normalise to ``(batch, pred_len, N)``.

        The mean/std come from ``target_x`` when it is given, otherwise from
        ``x_enc``. Both are taken over dim 1.
        NOTE(review): if ``target_x`` is the ground-truth future window, using its
        statistics leaks label information. Confirm against the caller.
        """
        # Normalization from Non-stationary Transformer
        if target_x is not None:
            means = target_x.mean(1, keepdim=True).detach()
            stdev = torch.sqrt(torch.var(target_x, dim=1, keepdim=True, unbiased=False) + 1e-5)
            x_enc = x_enc - means
        else:
            means = x_enc.mean(1, keepdim=True).detach()
            x_enc = x_enc - means
            stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place divide: avoids in-place mutation of an autograd-tracked tensor.
        x_enc = x_enc / stdev

        dec_out = self.forecast(x_enc)

        # De-Normalization from Non-stationary Transformer.
        # NOTE: dec_out has a single time step; broadcasting it against the
        # pred_len-repeated statistics tiles that one step across the horizon.
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        return dec_out