from __future__ import division
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt
from utils import generate_dataset, get_normalized_adj, get_Laplace, calculate_random_walk_matrix,gauss_loss
from model import *
import random,os,copy
import math
import tqdm
from scipy.stats import norm
import pickle as pk
import os
# Select which physical GPU is visible to CUDA for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

# ---- Reproducibility and device ----
torch.manual_seed(0)
device = torch.device('cuda')

# ---- Window sizes: the input window has the same length as the output ----
num_timesteps_output = 4
num_timesteps_input = num_timesteps_output

# ---- Raw arrays (change the folder to match the local layout) ----
A = np.load('STTD-main/ny_data_full_15min/adj_rand0.npy')        # adjacency matrix
X = np.load('STTD-main/ny_data_full_15min/cta_samp_rand0.npy')   # feature matrix

# ---- Hyper-parameters ----
space_dim = X.shape[1]
batch_size = 4
hidden_dim_s = hidden_dim_t = 42
rank_s, rank_t = 20, 4
epochs = 100

# ---- Networks ----
# The construction order is kept exactly as before: together with the manual
# seed above it determines the initial parameter values.
# Modules sized by hidden_dim_t / rank_t.
TCN1 = B_TCN(space_dim, hidden_dim_t, kernel_size=3).to(device)
TCN2 = B_TCN(hidden_dim_t, rank_t, kernel_size=3, activation='linear').to(device)
TCN3 = B_TCN(rank_t, hidden_dim_t, kernel_size=3).to(device)
TNB = GaussNorm(hidden_dim_t, space_dim).to(device)
# Modules sized by hidden_dim_s / rank_s.
SCN1 = D_GCN(num_timesteps_input, hidden_dim_s, 3).to(device)
SCN2 = D_GCN(hidden_dim_s, rank_s, 2, activation='linear').to(device)
SCN3 = D_GCN(rank_s, hidden_dim_s, 2).to(device)
SNB = GaussNorm(hidden_dim_s, num_timesteps_output).to(device)
# Full model assembled from the eight components above.
STmodel = ST_Gau(SCN1, SCN2, SCN3, TCN1, TCN2, TCN3, SNB, TNB).to(device)

# ---- Dataset preparation ----
# Transpose the loaded matrix, cast it to float32 and insert a singleton
# middle axis. The last axis is the one sliced into train/val/test below.
X = X.T.astype(np.float32)
X = X.reshape((X.shape[0], 1, X.shape[1]))

# Split points along the last axis: first 60% train, next 10% validation,
# remainder test.
split_line1 = int(X.shape[2] * 0.6)
split_line2 = int(X.shape[2] * 0.7)

print(X.shape, A.shape)

# Maximum over the training slice only. The previous expression,
# np.max(X.shape[2] * 0.6), took the "max" of a single scalar (60% of the
# series length) and never looked at the data. The value is not applied
# anywhere yet; it is kept for an optional X / max_value normalisation.
max_value = np.max(X[:, :, :split_line1])

train_original_data = X[:, :, :split_line1]
val_original_data = X[:, :, split_line1:split_line2]
test_original_data = X[:, :, split_line2:]

# Every split is windowed with the same input/output lengths.
window_kwargs = dict(num_timesteps_input=num_timesteps_input,
                     num_timesteps_output=num_timesteps_output)
training_input, training_target = generate_dataset(train_original_data, **window_kwargs)
val_input, val_target = generate_dataset(val_original_data, **window_kwargs)
test_input, test_target = generate_dataset(test_original_data, **window_kwargs)
print('input shape: ', training_input.shape, val_input.shape, test_input.shape)

# Normalised adjacency and the two random-walk matrices derived from it (one
# from A_wave, one from its transpose), stored as float32 tensors on `device`.
A_wave = get_normalized_adj(A)
A_q = torch.from_numpy(calculate_random_walk_matrix(A_wave).T.astype('float32')).to(device=device)
A_h = torch.from_numpy(calculate_random_walk_matrix(A_wave.T).T.astype('float32')).to(device=device)

# ---- Optimiser and metric histories ----
optimizer = optim.Adam(STmodel.parameters(), lr=1e-3)
training_nll = []
validation_nll = []
validation_mae = []

# ---- Training (procedure follows the STGCN-style mini-batch loop) ----
# Validation is disabled in this script: only the per-epoch training NLL is
# tracked, pickled and used for model selection.
checkpoint_path = "checkpoints/"
os.makedirs(checkpoint_path, exist_ok=True)
os.makedirs("pth", exist_ok=True)  # torch.save below does not create directories

# Seed the "best" snapshot with the initial weights so that the final
# load_state_dict cannot hit an undefined name when the very first epoch
# already produces a NaN loss.
best_model = copy.deepcopy(STmodel.state_dict())

for epoch in range(epochs):
    ## Step 1, training
    STmodel.train()
    permutation = torch.randperm(training_input.shape[0])  # fresh shuffle per epoch
    epoch_training_losses = []
    for i in range(0, training_input.shape[0], batch_size):
        optimizer.zero_grad()

        indices = permutation[i:i + batch_size]
        X_batch = training_input[indices].to(device=device)
        y_batch = training_target[indices].to(device=device)

        loc_train, scale_train = STmodel(X_batch, A_q, A_h)
        loss = gauss_loss(y_batch, loc_train, scale_train)
        loss.backward()
        optimizer.step()
        epoch_training_losses.append(loss.detach().cpu().numpy())
    training_nll.append(sum(epoch_training_losses)/len(epoch_training_losses))

    print('Epoch: {}'.format(epoch))
    print("Training loss: {}".format(training_nll[-1]))

    # np.asscalar has been removed from NumPy; float() does the same
    # conversion. A NaN loss never compares equal, so it is never selected.
    if float(training_nll[-1]) == min(training_nll):
        best_model = copy.deepcopy(STmodel.state_dict())

    # Rewrite the loss history every epoch so that an interrupted run keeps it.
    with open("checkpoints/losses.pk", "wb") as fd:
        pk.dump(training_nll, fd)

    # Stop as soon as training diverges.
    if np.isnan(training_nll[-1]):
        break

# Restore the best weights and save the whole model object.
STmodel.load_state_dict(best_model)
torch.save(STmodel,'pth/ST_Gauss_route20_30min_in4-out4-h4_20220221.pth')


from __future__ import division
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt
from utils import generate_dataset, get_normalized_adj, get_Laplace, calculate_random_walk_matrix,nb_nll_loss
from model import *
import random,os,copy
import math
import tqdm
from scipy.stats import nbinom
import pickle as pk
import os
# Select which physical GPU is visible to CUDA for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

# ---- Reproducibility and device ----
torch.manual_seed(0)
device = torch.device('cuda')

# ---- Window sizes: the input window has the same length as the output ----
num_timesteps_output = 4
num_timesteps_input = num_timesteps_output

# ---- Raw arrays ----
A = np.load('route20_5min/adj_rand0.npy')
X = np.load('route20_5min/cta_samp_rand0.npy')

# ---- Hyper-parameters ----
space_dim = X.shape[1]
batch_size = 4
hidden_dim_s = hidden_dim_t = 42
rank_s, rank_t = 20, 4
epochs = 100

# ---- Networks ----
# The construction order is kept exactly as before: together with the manual
# seed above it determines the initial parameter values.
# Modules sized by hidden_dim_t / rank_t.
TCN1 = B_TCN(space_dim, hidden_dim_t, kernel_size=3).to(device)
TCN2 = B_TCN(hidden_dim_t, rank_t, kernel_size=3, activation='linear').to(device)
TCN3 = B_TCN(rank_t, hidden_dim_t, kernel_size=3).to(device)
TNB = NBNorm(hidden_dim_t, space_dim).to(device)
# Modules sized by hidden_dim_s / rank_s.
SCN1 = D_GCN(num_timesteps_input, hidden_dim_s, 2).to(device)
SCN2 = D_GCN(hidden_dim_s, rank_s, 2, activation='linear').to(device)
SCN3 = D_GCN(rank_s, hidden_dim_s, 2).to(device)
SNB = NBNorm(hidden_dim_s, num_timesteps_output).to(device)
# Full model assembled from the eight components above.
STmodel = ST_NB(SCN1, SCN2, SCN3, TCN1, TCN2, TCN3, SNB, TNB).to(device)

# ---- Dataset preparation ----
# Transpose the loaded matrix, cast it to float32 and insert a singleton
# middle axis. The last axis is the one sliced into train/val/test below.
X = X.T.astype(np.float32)
X = X.reshape((X.shape[0], 1, X.shape[1]))

# Split points along the last axis: first 60% train, next 10% validation,
# remainder test.
split_line1 = int(X.shape[2] * 0.6)
split_line2 = int(X.shape[2] * 0.7)

print(X.shape, A.shape)

# Maximum over the training slice only. The previous expression,
# np.max(X.shape[2] * 0.6), took the "max" of a single scalar (60% of the
# series length) and never looked at the data. The value is not applied
# anywhere yet; it is kept for an optional X / max_value normalisation.
max_value = np.max(X[:, :, :split_line1])

train_original_data = X[:, :, :split_line1]
val_original_data = X[:, :, split_line1:split_line2]
test_original_data = X[:, :, split_line2:]

# Every split is windowed with the same input/output lengths.
window_kwargs = dict(num_timesteps_input=num_timesteps_input,
                     num_timesteps_output=num_timesteps_output)
training_input, training_target = generate_dataset(train_original_data, **window_kwargs)
val_input, val_target = generate_dataset(val_original_data, **window_kwargs)
test_input, test_target = generate_dataset(test_original_data, **window_kwargs)
print('input shape: ', training_input.shape, val_input.shape, test_input.shape)

# Normalised adjacency and the two random-walk matrices derived from it (one
# from A_wave, one from its transpose), stored as float32 tensors on `device`.
A_wave = get_normalized_adj(A)
A_q = torch.from_numpy(calculate_random_walk_matrix(A_wave).T.astype('float32')).to(device=device)
A_h = torch.from_numpy(calculate_random_walk_matrix(A_wave.T).T.astype('float32')).to(device=device)

# ---- Optimiser and metric histories ----
optimizer = optim.Adam(STmodel.parameters(), lr=1e-3)
training_nll = []
validation_nll = []
validation_mae = []

# ---- Training (procedure follows the STGCN-style mini-batch loop) ----
# Validation is disabled in this script: only the per-epoch training NLL is
# tracked, pickled and used for model selection.
checkpoint_path = "checkpoints/"
os.makedirs(checkpoint_path, exist_ok=True)
os.makedirs("pth", exist_ok=True)  # torch.save below does not create directories

# Seed the "best" snapshot with the initial weights so that the final
# load_state_dict cannot hit an undefined name when the very first epoch
# already produces a NaN loss.
best_model = copy.deepcopy(STmodel.state_dict())

for epoch in range(epochs):
    ## Step 1, training
    STmodel.train()
    permutation = torch.randperm(training_input.shape[0])  # fresh shuffle per epoch
    epoch_training_losses = []
    # The leftover per-batch debug print of the batch offset has been dropped;
    # progress is reported once per epoch below.
    for i in range(0, training_input.shape[0], batch_size):
        optimizer.zero_grad()

        indices = permutation[i:i + batch_size]
        X_batch = training_input[indices].to(device=device)
        y_batch = training_target[indices].to(device=device)

        # The model returns the two parameters passed on to nb_nll_loss.
        n_train, p_train = STmodel(X_batch, A_q, A_h)
        loss = nb_nll_loss(y_batch, n_train, p_train)
        loss.backward()
        optimizer.step()
        epoch_training_losses.append(loss.detach().cpu().numpy())
    training_nll.append(sum(epoch_training_losses)/len(epoch_training_losses))

    print('Epoch: {}'.format(epoch))
    print("Training loss: {}".format(training_nll[-1]))

    # np.asscalar has been removed from NumPy; float() does the same
    # conversion. A NaN loss never compares equal, so it is never selected.
    if float(training_nll[-1]) == min(training_nll):
        best_model = copy.deepcopy(STmodel.state_dict())

    # Rewrite the loss history every epoch so that an interrupted run keeps it.
    with open("checkpoints/losses.pk", "wb") as fd:
        pk.dump(training_nll, fd)

    # Stop as soon as training diverges.
    if np.isnan(training_nll[-1]):
        break

# Restore the best weights and save the whole model object.
STmodel.load_state_dict(best_model)
torch.save(STmodel,'pth/STNB_route20_5min_h4_20220331.pth')


from __future__ import division
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt
from utils import generate_dataset, get_normalized_adj, get_Laplace, calculate_random_walk_matrix,gauss_loss
from model import *
import random,os,copy
import math
import tqdm
from scipy.stats import truncnorm
import pickle as pk
import os
# Select which physical GPU is visible to CUDA for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

# ---- Reproducibility and device ----
torch.manual_seed(0)
device = torch.device('cuda')

# ---- Window sizes: the input window has the same length as the output ----
num_timesteps_output = 12
num_timesteps_input = num_timesteps_output

# ---- Hyper-parameters ----
# NOTE(review): space_dim is hard-coded here, whereas the sibling scripts use
# X.shape[1]; confirm it matches the dataset loaded further down.
space_dim = 100
batch_size = 12
hidden_dim_s = hidden_dim_t = 42
rank_s, rank_t = 20, 4
epochs = 35

# ---- Networks ----
# The construction order is kept exactly as before: together with the manual
# seed above it determines the initial parameter values.
# Modules sized by hidden_dim_t / rank_t.
TCN1 = B_TCN(space_dim, hidden_dim_t, kernel_size=6).to(device)
TCN2 = B_TCN(hidden_dim_t, rank_t, kernel_size=6, activation='linear').to(device)
TCN3 = B_TCN(rank_t, hidden_dim_t, kernel_size=6).to(device)
TNB = GaussNorm(hidden_dim_t, space_dim).to(device)
# Modules sized by hidden_dim_s / rank_s.
SCN1 = D_GCN(num_timesteps_input, hidden_dim_s, 3).to(device)
SCN2 = D_GCN(hidden_dim_s, rank_s, 2, activation='linear').to(device)
SCN3 = D_GCN(rank_s, hidden_dim_s, 2).to(device)
SNB = GaussNorm(hidden_dim_s, num_timesteps_output).to(device)
# Full model assembled from the eight components above.
STmodel = ST_Gau(SCN1, SCN2, SCN3, TCN1, TCN2, TCN3, SNB, TNB).to(device)

# ---- Dataset loading and preparation ----
A = np.load('ny_data_60min/adj_only10_rand0.npy')
X = np.load('ny_data_60min/cta_samp_only10_rand0.npy')

# Transpose the loaded matrix, cast it to float32 and insert a singleton
# middle axis. The last axis is the one sliced into train/val/test below.
X = X.T.astype(np.float32)
X = X.reshape((X.shape[0], 1, X.shape[1]))

# Split points along the last axis: first 60% train, next 10% validation,
# remainder test.
split_line1 = int(X.shape[2] * 0.6)
split_line2 = int(X.shape[2] * 0.7)

print(X.shape, A.shape)

# Maximum over the training slice only. The previous expression,
# np.max(X.shape[2] * 0.6), took the "max" of a single scalar (60% of the
# series length) and never looked at the data. The value is not applied
# anywhere yet; it is kept for an optional X / max_value normalisation.
max_value = np.max(X[:, :, :split_line1])

train_original_data = X[:, :, :split_line1]
val_original_data = X[:, :, split_line1:split_line2]
test_original_data = X[:, :, split_line2:]

# Every split is windowed with the same input/output lengths.
window_kwargs = dict(num_timesteps_input=num_timesteps_input,
                     num_timesteps_output=num_timesteps_output)
training_input, training_target = generate_dataset(train_original_data, **window_kwargs)
val_input, val_target = generate_dataset(val_original_data, **window_kwargs)
test_input, test_target = generate_dataset(test_original_data, **window_kwargs)
print('input shape: ', training_input.shape, val_input.shape, test_input.shape)

# Normalised adjacency and the two random-walk matrices derived from it (one
# from A_wave, one from its transpose), stored as float32 tensors on `device`.
# Each tensor is moved to the device once; the second, redundant transfer of
# the original has been removed.
A_wave = get_normalized_adj(A)
A_q = torch.from_numpy(calculate_random_walk_matrix(A_wave).T.astype('float32')).to(device=device)
A_h = torch.from_numpy(calculate_random_walk_matrix(A_wave.T).T.astype('float32')).to(device=device)

# ---- Optimiser and metric histories ----
optimizer = optim.Adam(STmodel.parameters(), lr=1e-3)
training_nll = []
validation_nll = []
validation_mae = []

# ---- Training with per-epoch validation (STGCN-style mini-batch loop) ----
checkpoint_path = "checkpoints/"
os.makedirs(checkpoint_path, exist_ok=True)
os.makedirs("pth", exist_ok=True)  # torch.save below does not create directories

# Seed the "best" snapshot with the initial weights so that the final
# load_state_dict cannot hit an undefined name if no epoch is ever selected
# (for example when the first validation loss is NaN).
best_model = copy.deepcopy(STmodel.state_dict())

for epoch in range(epochs):
    ## Step 1, training
    STmodel.train()
    permutation = torch.randperm(training_input.shape[0])  # fresh shuffle per epoch
    epoch_training_losses = []
    for i in range(0, training_input.shape[0], batch_size):
        optimizer.zero_grad()

        indices = permutation[i:i + batch_size]
        X_batch = training_input[indices].to(device=device)
        y_batch = training_target[indices].to(device=device)

        loc_train, scale_train = STmodel(X_batch, A_q, A_h)
        loss = gauss_loss(y_batch, loc_train, scale_train)
        loss.backward()
        optimizer.step()
        epoch_training_losses.append(loss.detach().cpu().numpy())
    training_nll.append(sum(epoch_training_losses)/len(epoch_training_losses))

    ## Step 2, validation (whole validation set in a single forward pass)
    with torch.no_grad():
        STmodel.eval()
        val_input = val_input.to(device=device)
        val_target = val_target.to(device=device)

        loc_val, scale_val = STmodel(val_input, A_q, A_h)
        val_loss = gauss_loss(val_target, loc_val, scale_val).to(device="cpu")
        # Tensor.item() replaces np.asscalar, which has been removed from NumPy.
        validation_nll.append(val_loss.item())

        # Point prediction: mean of the normal truncated to [0, inf).
        # scipy's truncnorm takes its bounds in *standardised* units, i.e.
        # (clip - loc) / scale. Passing a=0 directly, as before, truncated at
        # `loc` instead of at 0.
        loc_np = loc_val.detach().cpu().numpy()
        scale_np = scale_val.detach().cpu().numpy()
        a = (0 - loc_np) / scale_np
        b = np.inf
        val_pred = truncnorm.mean(a=a, b=b, loc=loc_np, scale=scale_np)
        print(val_pred.mean())

        mae = np.mean(np.abs(val_pred - val_target.detach().cpu().numpy()))
        validation_mae.append(mae)

        # Drop the GPU outputs and move the validation tensors back to the CPU
        # so they do not occupy device memory during the next training epoch.
        loc_val, scale_val = None, None
        val_input = val_input.to(device="cpu")
        val_target = val_target.to(device="cpu")

    print('Epoch %d: trainNLL %.5f; valNLL %.5f; mae %.4f'%(epoch,
    training_nll[-1],validation_nll[-1],validation_mae[-1]))

    # Model selection on validation NLL; a NaN never compares equal.
    if validation_nll[-1] == min(validation_nll):
        best_model = copy.deepcopy(STmodel.state_dict())

    # Rewrite the metric histories every epoch so an interrupted run keeps them.
    with open("checkpoints/losses.pk", "wb") as fd:
        pk.dump((training_nll, validation_nll, validation_mae), fd)

# Restore the best weights and save the whole model object.
STmodel.load_state_dict(best_model)
torch.save(STmodel,'pth/ST_Truncnorm_ny_60min_samp_only10_in12-out12-h12_nonorm_20210901.pth')


from __future__ import division
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt
from utils import generate_dataset, get_normalized_adj, get_Laplace, calculate_random_walk_matrix, tweedie_nll_loss
from model import *
import random,os,copy
import math
import tqdm
from scipy.stats import nbinom
import pickle as pk
import os
# The GPU is chosen through the explicit device index below; the
# CUDA_VISIBLE_DEVICES override is not used in this script.

# ---- Reproducibility and device ----
torch.manual_seed(0)
device = torch.device('cuda:2')

# ---- Raw arrays (change the folder to match the local layout) ----
A = np.load('STTD-main/ny_data_full_15min/adj_rand0.npy')
X = np.load('STTD-main/ny_data_full_15min/cta_samp_rand0.npy')

# ---- Window sizes ----
num_timesteps_output = 4
num_timesteps_input = 4

# ---- Hyper-parameters ----
space_dim = X.shape[1]
batch_size = 200
hidden_dim_s = hidden_dim_t = 42
rank_s, rank_t = 20, 4
epochs = 100

# ---- Networks ----
# The construction order is kept exactly as before: together with the manual
# seed above it determines the initial parameter values.
# Modules sized by hidden_dim_t / rank_t.
TCN1 = B_TCN(space_dim, hidden_dim_t, kernel_size=3, device=device).to(device)
TCN2 = B_TCN(hidden_dim_t, rank_t, kernel_size=3, activation='linear', device=device).to(device)
TCN3 = B_TCN(rank_t, hidden_dim_t, kernel_size=3, device=device).to(device)
TNB = NBNorm_ZeroInflated(hidden_dim_t, space_dim).to(device)
# Modules sized by hidden_dim_s / rank_s.
SCN1 = D_GCN(num_timesteps_input, hidden_dim_s, 3).to(device)
SCN2 = D_GCN(hidden_dim_s, rank_s, 2, activation='linear').to(device)
SCN3 = D_GCN(rank_s, hidden_dim_s, 2).to(device)
SNB = NBNorm_ZeroInflated(hidden_dim_s, num_timesteps_output).to(device)
# Full model assembled from the eight components above.
STmodel = ST_TWEEDIE(SCN1, SCN2, SCN3, TCN1, TCN2, TCN3, SNB, TNB).to(device)

# ---- Dataset preparation ----
# Transpose the loaded matrix, cast it to float32 and insert a singleton
# middle axis. The last axis is the one sliced into train/val/test below.
X = X.T.astype(np.float32)
X = X.reshape((X.shape[0], 1, X.shape[1]))

# Split points along the last axis: first 60% train, next 10% validation,
# remainder test.
split_line1 = int(X.shape[2] * 0.60)
split_line2 = int(X.shape[2] * 0.70)
print(X.shape, A.shape)

# Maximum over the training slice (computed but not applied anywhere below).
max_value = np.max(X[:, :, :split_line1])

train_original_data = X[:, :, :split_line1]
val_original_data = X[:, :, split_line1:split_line2]
test_original_data = X[:, :, split_line2:]

# Every split is windowed with the same input/output lengths.
training_input, training_target = generate_dataset(
    train_original_data,
    num_timesteps_input=num_timesteps_input,
    num_timesteps_output=num_timesteps_output)
val_input, val_target = generate_dataset(
    val_original_data,
    num_timesteps_input=num_timesteps_input,
    num_timesteps_output=num_timesteps_output)
test_input, test_target = generate_dataset(
    test_original_data,
    num_timesteps_input=num_timesteps_input,
    num_timesteps_output=num_timesteps_output)
print('input shape: ', training_input.shape, val_input.shape, test_input.shape)

# Normalised adjacency and the two random-walk matrices derived from it (one
# from A_wave, one from its transpose), stored as float32 tensors on `device`.
A_wave = get_normalized_adj(A)
walk_forward = calculate_random_walk_matrix(A_wave).T
walk_backward = calculate_random_walk_matrix(A_wave.T).T
A_q = torch.from_numpy(walk_forward.astype('float32')).to(device=device)
A_h = torch.from_numpy(walk_backward.astype('float32')).to(device=device)

# ---- Optimiser and metric histories ----
optimizer = optim.Adam(STmodel.parameters(), lr=1e-3)
training_nll = []
validation_nll = []
validation_mae = []

# ---- Training with per-epoch validation (STGCN-style mini-batch loop) ----
checkpoint_path = "checkpoints/"
os.makedirs(checkpoint_path, exist_ok=True)
os.makedirs("pth", exist_ok=True)  # torch.save below does not create directories

# Seed the "best" snapshot with the initial weights so that the final
# load_state_dict cannot hit an undefined name when the very first epoch
# already produces a NaN loss.
best_model = copy.deepcopy(STmodel.state_dict())

for epoch in range(epochs):
    ## Step 1, training
    STmodel.train()
    permutation = torch.randperm(training_input.shape[0])  # fresh shuffle per epoch
    epoch_training_losses = []
    # Stage all training inputs on the model's device for this epoch. The
    # original used .cuda(), which targets the *default* GPU rather than the
    # `device` the model lives on (cuda:2), forcing a second copy per batch.
    training_input = training_input.to(device=device)
    for i in range(0, training_input.shape[0], batch_size):
        optimizer.zero_grad()

        indices = permutation[i:i + batch_size]
        X_batch = training_input[indices].to(device=device)  # no-op: already on device
        y_batch = training_target[indices].to(device=device)

        # The fourth model output is not used by the loss.
        n_train, p_train, pi_train, _ = STmodel(X_batch, A_q, A_h)
        loss = tweedie_nll_loss(y_batch, n_train, p_train, pi_train)

        loss.backward()
        optimizer.step()
        epoch_training_losses.append(loss.detach().cpu().numpy())

    training_nll.append(sum(epoch_training_losses)/len(epoch_training_losses))
    # Release the staged inputs before the validation pass.
    training_input = training_input.cpu()
    torch.cuda.empty_cache()

    ## Step 2, validation (whole validation set in a single forward pass)
    with torch.no_grad():
        STmodel.eval()
        val_input = val_input.to(device=device)
        val_target = val_target.to(device=device)

        n_val, p_val, pi_val, _ = STmodel(val_input, A_q, A_h)
        print('Pi_val,mean,min,max',torch.mean(pi_val),torch.min(pi_val),torch.max(pi_val))

        # Use the same loss as training. The original called
        # nb_tweedie_nll_loss here, a name this script never imports.
        val_loss = tweedie_nll_loss(val_target, n_val, p_val, pi_val).to(device="cpu")
        # Tensor.item() replaces np.asscalar, which has been removed from NumPy.
        validation_nll.append(val_loss.item())

        # FIXME (carried over from the original): exp(pi_val) is used as the
        # point prediction for the MAE; confirm this is the intended estimator.
        pi_val = torch.exp(pi_val)
        mae = np.mean(np.abs(pi_val.detach().cpu().numpy() - val_target.detach().cpu().numpy()))
        validation_mae.append(mae)

        # Drop the GPU outputs and move the validation tensors back to the CPU.
        n_val, p_val, pi_val = None, None, None
        val_input = val_input.to(device="cpu")
        val_target = val_target.to(device="cpu")

    print('Epoch: {}'.format(epoch))
    print("Training loss: {}".format(training_nll[-1]))
    print('Epoch %d: trainNLL %.5f; valNLL %.5f; mae %.4f'%(epoch,training_nll[-1],validation_nll[-1],validation_mae[-1]))

    # Model selection is on *training* NLL, as in the original script.
    # A NaN loss never compares equal, so it is never selected.
    if float(training_nll[-1]) == min(training_nll):
        best_model = copy.deepcopy(STmodel.state_dict())

    # Rewrite the metric histories every epoch so an interrupted run keeps them.
    with open("checkpoints/losses.pk", "wb") as fd:
        pk.dump((training_nll, validation_nll, validation_mae), fd)

    # Stop as soon as training diverges.
    if np.isnan(training_nll[-1]):
        break

# Restore the best weights and save the whole model object.
STmodel.load_state_dict(best_model)
torch.save(STmodel,'pth/STZINB_ny_full_5min.pth')


from __future__ import division
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt
from utils import generate_dataset, get_normalized_adj, get_Laplace, calculate_random_walk_matrix,nb_zeroinflated_nll_loss,nb_zeroinflated_draw, nb_tweedie_nll_loss
from model import *
from utils import *
import heapq
import random,os,copy
import math
import tqdm
from scipy.stats import nbinom
import pickle as pk
import os
# The GPU is chosen through the explicit device index below; the
# CUDA_VISIBLE_DEVICES override is not used in this script.

# ---- Reproducibility and device ----
torch.manual_seed(0)
device = torch.device('cuda:0')

# ---- Raw arrays (change the folder to match the local layout) ----
# Adjacency matrix first, then the feature matrix.
A = np.load('STTD-main/ny_data_full_15min/adj_rand0.npy')
X = np.load('STTD-main/ny_data_full_15min/cta_samp_rand0.npy')

# ---- Window sizes ----
num_timesteps_output = 3                        # length of the predicted window
num_timesteps_input = num_timesteps_output      # length of the input window

# ---- Hyper-parameters ----
space_dim = X.shape[1]
batch_size = 40          # samples per training batch
hidden_dim_s = 42        # hidden size on the spatial side
hidden_dim_t = 42        # hidden size on the temporal side
rank_s = 20              # TODO (carried over from the original)
rank_t = 4               # TODO (carried over from the original)
epochs = 1

# ---- Networks ----
# The construction order is kept exactly as before: together with the manual
# seed above it determines the initial parameter values.
# TCN*: temporal encoding of every node.
TCN1 = B_TCN(space_dim, hidden_dim_t, kernel_size=3, device=device).to(device)
TCN2 = B_TCN(hidden_dim_t, rank_t, kernel_size=3, activation='linear', device=device).to(device)
TCN3 = B_TCN(rank_t, hidden_dim_t, kernel_size=3, device=device).to(device)
# TNB: maps the temporal output into the fixed parameter form.
TNB = NBNorm_ZeroInflated(hidden_dim_t, space_dim).to(device)
# SCN*: spatial encoding.
SCN1 = D_GCN(num_timesteps_input, hidden_dim_s, 3).to(device)
SCN2 = D_GCN(hidden_dim_s, rank_s, 2, activation='linear').to(device)
SCN3 = D_GCN(rank_s, hidden_dim_s, 2).to(device)
# SNB: maps the spatial output into the fixed parameter form.
SNB = NBNorm_ZeroInflated(hidden_dim_s, num_timesteps_output).to(device)
# Full model: derives the final distribution parameters from both sides.
STmodel = ST_NB_ZeroInflated(SCN1, SCN2, SCN3, TCN1, TCN2, TCN3, SNB, TNB).to(device)

# ---- Dataset preparation ----
# Transpose the loaded matrix, cast it to float32 and insert a singleton
# middle axis. The last axis is the one sliced into train/val/test below.
X = X.T.astype(np.float32)
X = X.reshape((X.shape[0], 1, X.shape[1]))

# Train / validation / test split points along the last axis: first 60%
# train, next 10% validation, remainder test.
split_line1 = int(X.shape[2] * 0.6)
split_line2 = int(X.shape[2] * 0.7)
print(X.shape, A.shape)

# Maximum value of the training set (computed but not applied anywhere below).
max_value = np.max(X[:, :, :split_line1])

train_original_data = X[:, :, :split_line1]
val_original_data = X[:, :, split_line1:split_line2]
test_original_data = X[:, :, split_line2:]

# generate_dataset turns each raw split into the fixed (input, target)
# format expected by the model.
training_input, training_target = generate_dataset(
    train_original_data,
    num_timesteps_input=num_timesteps_input,
    num_timesteps_output=num_timesteps_output)
val_input, val_target = generate_dataset(
    val_original_data,
    num_timesteps_input=num_timesteps_input,
    num_timesteps_output=num_timesteps_output)
test_input, test_target = generate_dataset(
    test_original_data,
    num_timesteps_input=num_timesteps_input,
    num_timesteps_output=num_timesteps_output)
print('input shape: ', training_input.shape, val_input.shape, test_input.shape)

# Normalise the adjacency matrix, then derive the two random-walk matrices
# (one from A_wave, one from its transpose) as float32 tensors on `device`.
A_wave = get_normalized_adj(A)
walk_forward = calculate_random_walk_matrix(A_wave).T
walk_backward = calculate_random_walk_matrix(A_wave.T).T
A_q = torch.from_numpy(walk_forward.astype('float32')).to(device=device)
A_h = torch.from_numpy(walk_backward.astype('float32')).to(device=device)

# ---- Optimiser and metric histories ----
optimizer = optim.Adam(STmodel.parameters(), lr=1e-4)
training_nll = []
validation_nll = []
validation_mae = []

# ---- Training with per-epoch validation (STGCN-style mini-batch loop) ----
checkpoint_path = "checkpoints/"
os.makedirs(checkpoint_path, exist_ok=True)
os.makedirs("pth", exist_ok=True)  # torch.save below does not create directories

# Seed the "best" snapshot with the initial weights so that the final
# load_state_dict cannot hit an undefined name when the very first epoch
# already produces a NaN loss.
best_model = copy.deepcopy(STmodel.state_dict())

for epoch in range(epochs):
    ## Step 1, training
    STmodel.train()
    permutation = torch.randperm(training_input.shape[0])  # random sample order
    epoch_training_losses = []
    for i in range(0, training_input.shape[0], batch_size):
        optimizer.zero_grad()

        indices = permutation[i:i + batch_size]
        X_batch = training_input[indices].to(device=device)
        y_batch = training_target[indices].to(device=device)

        # X_batch is already on `device`; the extra .cuda() call of the
        # original was redundant and has been removed.
        n_train, p_train, pi_train = STmodel(X_batch, A_q, A_h)        # forward pass
        loss = nb_zeroinflated_nll_loss(y_batch, n_train, p_train, pi_train)  # loss
        loss.backward()
        optimizer.step()
        epoch_training_losses.append(loss.detach().cpu().numpy())
    training_nll.append(sum(epoch_training_losses)/len(epoch_training_losses))
    torch.cuda.empty_cache()

    ## Step 2, validation (whole validation set in a single forward pass)
    with torch.no_grad():
        STmodel.eval()
        val_input = val_input.to(device=device)
        val_target = val_target.to(device=device)

        n_val, p_val, pi_val = STmodel(val_input, A_q, A_h)
        print('Pi_val,mean,min,max',torch.mean(pi_val),torch.min(pi_val),torch.max(pi_val))
        val_loss = nb_zeroinflated_nll_loss(val_target, n_val, p_val, pi_val).to(device="cpu")
        # Tensor.item() replaces np.asscalar, which has been removed from NumPy.
        validation_nll.append(val_loss.item())

        # Expected value used as point prediction: (1 - pi) * (n / p - n).
        n_np = n_val.detach().cpu().numpy()
        p_np = p_val.detach().cpu().numpy()
        pi_np = pi_val.detach().cpu().numpy()
        val_pred = (1 - pi_np) * (n_np / p_np - n_np)
        mae = np.mean(np.abs(val_pred - val_target.detach().cpu().numpy()))

        print_errors(val_target.detach().cpu().numpy(), val_pred)

        print(mae)
        validation_mae.append(mae)

        # Drop the GPU outputs and move the validation tensors back to the CPU.
        n_val, p_val, pi_val = None, None, None
        val_input = val_input.to(device="cpu")
        val_target = val_target.to(device="cpu")

        # Hit-rate diagnostics: flag the top `hit_rate` fraction of entries by
        # predicted value and compare them with the entries whose target is > 0.
        hit_rate = 0.2
        val_target1 = val_target.reshape(-1)
        non_zero_mask = val_target1.detach().cpu().numpy()>0
        pred1 = val_pred.reshape(-1)
        pred_index = heapq.nlargest(int(hit_rate*len(pred1)), range(len(pred1)), pred1.__getitem__)

        hit_mask = np.zeros_like(non_zero_mask)
        hit_mask[pred_index] = 1

        hits = (hit_mask * non_zero_mask).sum()
        ans = hits / hit_mask.sum()         # share of flagged entries that are non-zero
        ans2 = hits / non_zero_mask.sum()   # share of non-zero entries that were flagged
        # The unused "UP" ratio of the original was never printed or stored
        # and has been removed.
        print("HR:", ans, "Find Ratio:", ans2)

    print('Epoch: {}'.format(epoch))
    print("Training loss: {}".format(training_nll[-1]))
    print('Epoch %d: trainNLL %.5f; valNLL %.5f; mae %.4f'%(epoch,training_nll[-1],validation_nll[-1],validation_mae[-1]))

    # Model selection is on *training* NLL, as in the original script.
    # A NaN loss never compares equal, so it is never selected.
    if float(training_nll[-1]) == min(training_nll):
        best_model = copy.deepcopy(STmodel.state_dict())

    # Rewrite the metric histories every epoch so an interrupted run keeps them.
    with open("checkpoints/losses.pk", "wb") as fd:
        pk.dump((training_nll, validation_nll, validation_mae), fd)

    # Stop as soon as training diverges.
    if np.isnan(training_nll[-1]):
        break

# Restore the best weights and save the whole model object.
STmodel.load_state_dict(best_model)
torch.save(STmodel,'pth/STZINB_ny_full_5min.pth')


from __future__ import division
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt
from utils import generate_dataset, get_normalized_adj, get_Laplace, calculate_random_walk_matrix,nb_zeroinflated_nll_loss,nb_zeroinflated_draw, nb_tweedie_nll_loss,nb_newtweedie_nll_loss,nb_zitweedie_nll_loss,nb_zitd_nll
from model import *
from utils import *
import random,os,copy
import math
import tqdm
from scipy.stats import nbinom
import pickle as pk
import os
# os.environ["CUDA_VISIBLE_DEVICES"]='2'
# Parameters
# Seed torch's RNG before any network below is constructed.
torch.manual_seed(0)
# Device used for every network and for the random-walk matrices.
device = torch.device('cuda:3')
# A: adjacency matrix; X: the sampled data matrix (transposed further below).
A = np.load('STZINB-main/ny_data_full_15min/adj_rand0.npy') # change the loading folder
X = np.load('STZINB-main/ny_data_full_15min/cta_samp_rand0.npy')

# Length of the forecast window and of the input window, in time steps.
num_timesteps_output = 4
num_timesteps_input = 4

# Read from X *before* it is transposed below, so this is the second axis of
# the file as stored (presumably the number of nodes - TODO confirm).
space_dim = X.shape[1]
batch_size = 10
# Hidden widths and bottleneck ranks of the spatial (_s) and temporal (_t) branches.
hidden_dim_s = 20
hidden_dim_t = 20
rank_s = 20
rank_t = 4

epochs = 2000

# Initial networks
# Temporal branch: three B_TCN layers (space_dim -> hidden -> rank -> hidden)
# followed by a four-output head (four=True).
TCN1 = B_TCN(space_dim, hidden_dim_t, kernel_size=3, device=device).to(device=device)
TCN2 = B_TCN(hidden_dim_t, rank_t, kernel_size = 3, activation = 'linear', device=device).to(device=device)
TCN3 = B_TCN(rank_t, hidden_dim_t, kernel_size= 3, device=device).to(device=device)
TNB = NBNorm_ZeroInflated(hidden_dim_t,space_dim, four=True).to(device=device)
# Spatial branch: three D_GCN layers (timesteps -> hidden -> rank -> hidden),
# the last two with the attention option enabled, followed by the same kind of head.
SCN1 = D_GCN(num_timesteps_input, hidden_dim_s, 3, att=False).to(device=device)
SCN2 = D_GCN(hidden_dim_s, rank_s, 3, activation = 'linear', att=True).to(device=device)
SCN3 = D_GCN(rank_s, hidden_dim_s, 2, att=True).to(device=device)
SNB = NBNorm_ZeroInflated(hidden_dim_s,num_timesteps_output, four=True).to(device=device)
# Wrapper combining both branches; it is called as STmodel(X, A_q, A_h) and
# returns four tensors in the training loop below.
STmodel = ST_new_TWEEDIE_ZeroInflated(SCN1, SCN2, SCN3, TCN1, TCN2, TCN3, SNB,TNB, four=True).to(device=device)

# Load data

# Transpose, cast to float32 and insert a singleton feature axis so that X has
# shape (X.shape[0], 1, X.shape[1]); the last axis is the one split and
# windowed below (presumably time - TODO confirm against the data file).
X = X.T
X = X.astype(np.float32)
X = X.reshape((X.shape[0],1,X.shape[1]))

# Chronological split along the last axis: first 60% train, next 10%
# validation, remaining 30% test.
split_line1 = int(X.shape[2] * 0.60)
split_line2 = int(X.shape[2] * 0.70)
print(X.shape,A.shape)

# normalization
# NOTE(review): max_value is computed on the training slice only and saved with
# the results at the end of the script; no rescaling of X happens in this block.
max_value = np.max(X[:, :, :split_line1])

train_original_data = X[:, :, :split_line1]
val_original_data = X[:, :, split_line1:split_line2]
test_original_data = X[:, :, split_line2:]
# Cut each split into sliding (input window, target window) pairs.
training_input, training_target = generate_dataset(train_original_data,
                                                    num_timesteps_input=num_timesteps_input,
                                                    num_timesteps_output=num_timesteps_output)
val_input, val_target = generate_dataset(val_original_data,
                                            num_timesteps_input=num_timesteps_input,
                                            num_timesteps_output=num_timesteps_output)
test_input, test_target = generate_dataset(test_original_data,
                                            num_timesteps_input=num_timesteps_input,
                                            num_timesteps_output=num_timesteps_output)
print('input shape: ',training_input.shape,val_input.shape,test_input.shape)


# Normalise the adjacency matrix, then build the two transposed random-walk
# matrices (from A_wave and from A_wave.T) that the model receives as A_q and A_h.
A_wave = get_normalized_adj(A)
A_q = torch.from_numpy((calculate_random_walk_matrix(A_wave).T).astype('float32'))
A_h = torch.from_numpy((calculate_random_walk_matrix(A_wave.T).T).astype('float32'))
A_q = A_q.to(device=device)
A_h = A_h.to(device=device)
# Define the training process
# criterion = nn.MSELoss()
# Adam over every parameter of the combined model, with L2 weight decay.
optimizer = optim.Adam(STmodel.parameters(), lr=1e-4, weight_decay=1e-4)
# Per-epoch histories; all three lists are pickled to checkpoints/losses.pk
# at the end of every epoch.
training_nll   = []
validation_nll = []
validation_mae = []

for epoch in range(epochs):
    ## Step 1, training
    # Train one epoch over shuffled mini-batches (same procedure as STGCN).
    #   training_input:  all training windows, indexed by sample along dim 0.
    #   training_target: targets aligned with training_input along dim 0.
    #   batch_size:      number of windows per optimisation step.
    permutation = torch.randperm(training_input.shape[0])
    epoch_training_losses = []
    # Stage the inputs on the device the model lives on. A bare .cuda() would
    # use the default CUDA device, which is not necessarily `device`.
    training_input = training_input.to(device=device)
    STmodel.train()
    for i in range(0, training_input.shape[0], batch_size):
        optimizer.zero_grad()

        indices = permutation[i:i + batch_size]
        X_batch, y_batch = training_input[indices], training_target[indices]
        X_batch = X_batch.to(device=device)
        y_batch = y_batch.to(device=device)

        n_train, p_train, pi_train, zi_train = STmodel(X_batch, A_q, A_h)
        # TODO/FIXME (original author): remember to change this clipping range back.
        pi_train = torch.clip(pi_train, -15, 5)
        loss = nb_zitd_nll(y_batch, n_train, p_train, pi_train, zi_train)
        loss.backward()
        optimizer.step()
        epoch_training_losses.append(loss.detach().cpu().numpy())

    # Every 10th epoch (1, 11, 21, ...) plot the parameters of the last batch.
    if epoch % 10 == 1:
        draw_3d_graph(y_batch.reshape(-1), phi=n_train.reshape(-1), rho=p_train.reshape(-1), mu=pi_train.reshape(-1))

    training_nll.append(sum(epoch_training_losses)/len(epoch_training_losses))
    # Free GPU memory before the whole validation set is pushed through at once.
    training_input = training_input.cpu()
    torch.cuda.empty_cache()

    ## Step 2, validation
    with torch.no_grad():
        STmodel.eval()
        val_input = val_input.to(device=device)
        val_target = val_target.to(device=device)

        n_val, p_val, pi_val, zi_val = STmodel(val_input, A_q, A_h)
        val_loss = nb_zitd_nll(val_target, n_val, p_val, pi_val, zi_val).to(device="cpu")

        # The third model output is exponentiated (after clipping) before it is
        # used as the mean of the non-zero component below.
        pi_val = torch.clip(pi_val, -10, 4)
        pi_val = torch.exp(pi_val)

        # NOTE(review): the first two lines report the same tensor (pi_val);
        # kept as-is so the log format does not change.
        print('Distribution_val,mean,min,max',torch.mean(pi_val),torch.min(pi_val),torch.max(pi_val))
        print('Pi_val,mean,min,max',torch.mean(pi_val),torch.min(pi_val),torch.max(pi_val))
        print('phi_val,mean,min,max',torch.mean(n_val),torch.min(n_val),torch.max(n_val))
        print('rou_val,mean,min,max',torch.mean(p_val),torch.min(p_val),torch.max(p_val))

        # Tensor.item() replaces np.asscalar, which no longer exists in NumPy.
        validation_nll.append(val_loss.item())

        # Point prediction: (1 - zero probability) * mean.
        val_truth = val_target.detach().cpu().numpy()
        val_mean = (1-zi_val.detach().cpu().numpy()) * pi_val.detach().cpu().numpy()
        print_errors(val_truth, val_mean)
        mae = np.mean(np.abs(val_mean - val_truth))
        validation_mae.append(mae)

        # Drop references to the large outputs and move the data back to the CPU.
        n_val,p_val,pi_val = None,None,None
        val_input = val_input.to(device="cpu")
        val_target = val_target.to(device="cpu")
    print('Epoch: {}'.format(epoch))
    print("Training loss: {}".format(training_nll[-1]))
    print('Epoch %d: trainNLL %.5f; valNLL %.5f; mae %.4f'%(epoch,training_nll[-1],validation_nll[-1],validation_mae[-1]))
    # Snapshot the weights whenever this epoch has the lowest training NLL so far.
    if float(training_nll[-1]) == min(training_nll):
        best_model = copy.deepcopy(STmodel.state_dict())
    checkpoint_path = "checkpoints/"
    os.makedirs(checkpoint_path, exist_ok=True)
    with open("checkpoints/losses.pk", "wb") as fd:
        pk.dump((training_nll, validation_nll, validation_mae), fd)
    # Stop as soon as training has diverged.
    if np.isnan(training_nll[-1]):
        break

    # Test pass at epochs 400, 1400, ...
    if epoch % 1000 == 400:
        ##### Test #####
        print("TEST")
        STmodel.eval()
        with torch.no_grad():
            # The test tensors stay on the CPU; batches are moved to the device one at a time.
            test_input = test_input.to(device='cpu')
            test_target = test_target.to(device='cpu')
            print(test_input.is_cuda, A_q.is_cuda, A_h.is_cuda)

            test_loss_all = []
            test_pred_all = np.zeros_like(test_target)
            n_test_all = np.zeros_like(test_target)
            p_test_all = np.zeros_like(test_target)
            pi_test_all = np.zeros_like(test_target)
            print(test_input.shape, test_target.shape)
            for i in range(0, test_input.shape[0], batch_size):
                x_batch = test_input[i:i + batch_size]
                x_batch = x_batch.to(device)
                n_test, p_test, pi_test, zi_test = STmodel(x_batch, A_q, A_h)
                pi_test = torch.clip(pi_test, -10, 4)
                test_loss = nb_zitd_nll(test_target[i:i + batch_size].to(device), n_test, p_test,
                                        pi_test, zi_test).to(device="cpu")
                test_loss = test_loss.item()

                pi_test = torch.exp(pi_test)
                mean_pred = (1-zi_test.detach().cpu().numpy()) * pi_test.detach().cpu().numpy()

                test_pred_all[i:i + batch_size] = mean_pred  # test_pred_all holds the predicted mean
                n_test_all[i:i + batch_size] = n_test.detach().cpu().numpy()
                p_test_all[i:i + batch_size] = p_test.detach().cpu().numpy()
                pi_test_all[i:i + batch_size] = pi_test.detach().cpu().numpy()  # fixme
                test_loss_all.append(test_loss)

            # The error of each horizon
            test_truth = test_target.detach().cpu().numpy()
            mae_list = []
            rmse_list = []
            mape_list = []
            for horizon in range(test_pred_all.shape[2]):
                err = test_pred_all[:, :, horizon] - test_truth[:, :, horizon]
                mae = np.mean(np.abs(err))
                # Root of the mean *squared* error; the square was missing before.
                rmse = np.sqrt(np.mean(err ** 2))
                mape = np.mean(np.abs(err / (test_truth[:, :, horizon] + 1e-5)))
                mae_list.append(mae)
                rmse_list.append(rmse)
                mape_list.append(mape)
                print('Horizon %d MAE:%.4f RMSE:%.4f MAPE:%.4f' % (horizon, mae, rmse, mape))
            # Report the average of the per-batch losses rather than only the last batch.
            print('Overall score: NLL %.5f; mae %.4f; rmse %.4f; mape %.4f' % (
                np.mean(test_loss_all), np.mean(mae_list), np.mean(rmse_list), np.mean(mape_list)))

# Persist the latest test-set predictions and the best model.
# The test arrays only exist if a test pass ran (epochs 400, 1400, ...); when
# training stopped earlier, skip the dump instead of dying with a NameError
# before the model has been saved.
os.makedirs('output', exist_ok=True)
os.makedirs('pth', exist_ok=True)
if 'test_pred_all' in globals():
    np.savez_compressed('output/ny_full_5min_ZISTNB', target=test_target.detach().cpu().numpy(), max_value=max_value,
                        mean_pred=test_pred_all, n=n_test_all, p=p_test_all, pi=pi_test_all)
else:
    print('No test pass was run; skipping the prediction dump.')

# best_model is only bound once an epoch finished with a finite, lowest-so-far
# training NLL; otherwise the current weights are saved unchanged.
if 'best_model' in globals():
    STmodel.load_state_dict(best_model)
torch.save(STmodel,'pth/STZINB_ny_full_5min.pth')


import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import sys
import math
from scipy.stats import nbinom
from torch.nn.utils import weight_norm

# Define the NB class first, not mixture version
class NBNorm(nn.Module):
    """Output head producing the two negative-binomial parameters (n, p).

    Two independent 1x1 convolutions map ``c_in`` feature channels to
    ``c_out`` values each; ``n`` is made positive with softplus and ``p`` is
    squashed into (0, 1) with a sigmoid.
    """

    def __init__(self, c_in, c_out):
        """
        :param c_in: Number of input feature channels.
        :param c_out: Number of output channels (the output horizon).
        """
        super(NBNorm, self).__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.n_conv = nn.Conv2d(in_channels=c_in,
                                out_channels=c_out,
                                kernel_size=(1, 1),
                                bias=True)

        self.p_conv = nn.Conv2d(in_channels=c_in,
                                out_channels=c_out,
                                kernel_size=(1, 1),
                                bias=True)
        self.out_dim = c_out  # output horizon

    def forward(self, x):
        """
        :param x: 4-D input; axis 2 must have ``c_in`` entries (it becomes the
            channel axis after the permute) and the last axis is squeezed, so
            it is expected to have size 1.
        :return: ``(n, p)``, each of shape (batch, N, c_out) where N is axis 1
            of ``x``.
        """
        x = x.permute(0, 2, 1, 3)
        (B, _, N, _) = x.shape  # B: batch_size; N: input nodes
        n = self.n_conv(x).squeeze_(-1)
        p = self.p_conv(x).squeeze_(-1)

        # Reshape
        n = n.view([B, self.out_dim, N])
        p = p.view([B, self.out_dim, N])

        # Ensure n is positive and p between 0 and 1
        n = F.softplus(n)  # Some parameters can be tuned here
        p = torch.sigmoid(p)
        return n.permute([0, 2, 1]), p.permute([0, 2, 1])

    def likelihood_loss(self, y, n, p, y_mask=None):
        """Summed negative log-likelihood of ``y`` under NB(n, p).

        :param y: True values.
        :param n: Positive NB parameter, broadcastable to ``y``.
        :param p: NB parameter in (0, 1), broadcastable to ``y``.
        :param y_mask: Optional multiplicative mask; entries multiplied by 0
            do not contribute to the loss.
        :return: Scalar tensor with the summed NLL.
        """
        nll = torch.lgamma(n) + torch.lgamma(y+1) - torch.lgamma(n+y) - n*torch.log(p) - y*torch.log(1-p)
        if y_mask is not None:
            nll = nll*y_mask
        return torch.sum(nll)

    def mean(self, n, p):
        """Expected value of NB(n, p) in the parameterisation of ``likelihood_loss``.

        That loss corresponds to pmf(y) proportional to p**n * (1-p)**y, whose
        mean is ``n * (1 - p) / p``. This method used to be an empty stub.

        :param n: Positive NB parameter.
        :param p: NB parameter in (0, 1).
        :return: Tensor of the broadcast shape of ``n`` and ``p``.
        """
        return n * (1 - p) / p

# Define the Gaussian 
class GaussNorm(nn.Module):
    """Output head producing a Gaussian location and scale per output channel.

    Two independent 1x1 convolutions map ``c_in`` channels to ``c_out``
    values; the location goes through softplus (positive) and the scale
    through a sigmoid (between 0 and 1).
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        self.c_in, self.c_out = c_in, c_out
        conv_kwargs = dict(in_channels=c_in, out_channels=c_out,
                           kernel_size=(1, 1), bias=True)
        # Construction order (location first, then scale) is kept so that
        # seeded initialisation gives the same weights as before.
        self.n_conv = nn.Conv2d(**conv_kwargs)  # location head
        self.p_conv = nn.Conv2d(**conv_kwargs)  # scale head
        self.out_dim = c_out  # output horizon

    def forward(self, x):
        """
        :param x: 4-D input; axis 2 holds the ``c_in`` features and the last
            axis is squeezed away (expected size 1).
        :return: ``(loc, scale)``, each of shape (batch, N, c_out) with N
            being axis 1 of ``x``.
        """
        feats = x.permute(0, 2, 1, 3)
        batch, _, nodes, _ = feats.shape
        target_shape = [batch, self.out_dim, nodes]

        # loc is the mean (kept positive, count data are never negative);
        # scale is the standard deviation, bounded to (0, 1).
        loc = F.softplus(self.n_conv(feats).squeeze_(-1).view(target_shape))
        scale = torch.sigmoid(self.p_conv(feats).squeeze_(-1).view(target_shape))
        return loc.permute([0, 2, 1]), scale.permute([0, 2, 1])

# Define the zero-inflated NB head (optionally with a fourth, zero-probability output)
class NBNorm_ZeroInflated(nn.Module):
    """Output head with three (n, p, pi) or four (n, p, pi, zi) outputs.

    Every output comes from its own 1x1 convolution. With ``four=False`` the
    three outputs are constrained (softplus for ``n``, sigmoid for ``p`` and
    ``pi``). With ``four=True`` the first three are returned *unconstrained*
    and only the extra ``zi`` output is passed through a sigmoid.
    """

    def __init__(self, c_in, c_out, four=False):
        super().__init__()
        self.c_in, self.c_out = c_in, c_out
        conv_kwargs = dict(in_channels=c_in, out_channels=c_out,
                           kernel_size=(1, 1), bias=True)
        # Construction order is kept (n, p, pi, then zero) so that seeded
        # initialisation yields the same weights as before.
        self.n_conv = nn.Conv2d(**conv_kwargs)
        self.p_conv = nn.Conv2d(**conv_kwargs)
        self.pi_conv = nn.Conv2d(**conv_kwargs)
        self.four = four
        if four:
            self.zero_conv = nn.Conv2d(**conv_kwargs)
        self.out_dim = c_out  # output horizon

    def forward(self, x):
        """
        :param x: 4-D input; axis 2 holds the ``c_in`` features and the last
            axis is squeezed away (expected size 1).
        :return: Tuple of 3 or 4 tensors, each of shape (batch, N, c_out)
            with N being axis 1 of ``x``.
        """
        feats = x.permute(0, 2, 1, 3)
        batch, _, nodes, _ = feats.shape
        target_shape = [batch, self.out_dim, nodes]

        def head(conv):
            # 1x1 conv, drop the trailing singleton, fix the layout.
            return conv(feats).squeeze_(-1).view(target_shape)

        n = head(self.n_conv)
        p = head(self.p_conv)
        pi = head(self.pi_conv)

        if self.four:
            # Four-output mode: n, p and pi stay raw; only zi is a probability.
            zi = torch.sigmoid(head(self.zero_conv))
            return (n.permute([0, 2, 1]), p.permute([0, 2, 1]),
                    pi.permute([0, 2, 1]), zi.permute([0, 2, 1]))

        # Three-output mode: n positive, p and pi in (0, 1).
        n = F.softplus(n)
        p = torch.sigmoid(p)
        pi = torch.sigmoid(pi)
        return n.permute([0, 2, 1]), p.permute([0, 2, 1]), pi.permute([0, 2, 1])

class D_GCN(nn.Module):
    """
    Diffusion graph convolution block: mixes each node's features with those
    reached through repeated application of a forward and a backward
    random-walk matrix, then projects them with one shared weight matrix.
    """

    def __init__(self, in_channels, out_channels, orders, activation = 'relu', att=False):
        """
        :param in_channels: Size of the last input axis (number of time steps).
        :param out_channels: Desired number of output features at each node.
        :param orders: Number of diffusion steps per random-walk matrix.
        :param activation: 'relu', 'selu', or anything else for no activation.
        :param att: If True, re-weight the projected features with a softmax
            taken over the node axis before the bias is added.
        """
        super(D_GCN, self).__init__()
        self.orders = orders
        self.activation = activation
        # One identity term plus `orders` terms for each of the two matrices.
        self.num_matrices = 2 * self.orders + 1
        self.Theta1 = nn.Parameter(
            torch.FloatTensor(in_channels * self.num_matrices, out_channels))
        self.bias = nn.Parameter(torch.FloatTensor(out_channels))
        self.att = att
        self.reset_parameters()

    def reset_parameters(self):
        """Uniformly initialise the weights, then the bias, in 1/sqrt(size) ranges."""
        weight_bound = 1. / math.sqrt(self.Theta1.shape[1])
        self.Theta1.data.uniform_(-weight_bound, weight_bound)
        bias_bound = 1. / math.sqrt(self.bias.shape[0])
        self.bias.data.uniform_(-bias_bound, bias_bound)

    def _concat(self, x, x_):
        """Append ``x_`` to ``x`` as one more slice along a new leading axis."""
        return torch.cat([x, x_.unsqueeze(0)], dim=0)

    def forward(self, X, A_q, A_h):
        """
        :param X: Input data of shape (batch_size, num_nodes, num_timesteps)
        :A_q: The forward random walk matrix (num_nodes, num_nodes)
        :A_h: The backward random walk matrix (num_nodes, num_nodes)
        :return: Output data of shape (batch_size, num_nodes, num_features)
        """
        n_batch = X.shape[0]
        n_nodes = X.shape[1]
        n_steps = X.size(2)

        # Fold time and batch into one axis so that every diffusion step is a
        # single (num_nodes x num_nodes) matrix product.
        prev = torch.reshape(X.permute(1, 2, 0), shape=[n_nodes, n_steps * n_batch])
        terms = [prev]
        # NOTE(review): `prev` is not reset between the two matrices, so for
        # orders >= 2 the second matrix starts from the state left behind by
        # the first one. Kept exactly as before so trained weights still match.
        for walk in (A_q, A_h):
            cur = torch.mm(walk, prev)
            terms.append(cur)
            for _ in range(2, self.orders + 1):
                nxt = 2 * torch.mm(walk, cur) - prev
                terms.append(nxt)
                cur, prev = nxt, cur

        stacked = torch.stack(terms, dim=0)
        stacked = torch.reshape(stacked, shape=[self.num_matrices, n_nodes, n_steps, n_batch])
        stacked = stacked.permute(3, 1, 2, 0)  # (batch_size, num_nodes, input_size, order)
        flat = torch.reshape(stacked, shape=[n_batch, n_nodes, n_steps * self.num_matrices])
        out = torch.matmul(flat, self.Theta1)  # (batch_size, num_nodes, output_size)

        if self.att:
            # attention layer: softmax over the node axis
            out = out * torch.softmax(torch.tanh(out), dim=1)

        out = out + self.bias
        if self.activation == 'relu':
            return F.relu(out)
        if self.activation == 'selu':
            return F.selu(out)
        return out

## Code of BTCN from Yuankai
class B_TCN(nn.Module):
    """
    Neural network block that applies a bidirectional temporal convolution to each node of
    a graph: one gated convolution runs over the sequence as given, a second
    one over the time-reversed sequence, and the two results are added.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3,activation = 'relu',device='cuda:0'):
        """
        :param in_channels: Number of nodes in the graph.
        :param out_channels: Desired number of output features.
        :param kernel_size: Size of the 1D temporal kernel.
        :param activation: 'relu', 'sigmoid', or anything else for no activation.
        :param device: Kept for backward compatibility and stored on the
            module; ``forward`` now creates its helper tensors on the device
            of its input instead of on this one.
        """
        super(B_TCN, self).__init__()
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.activation = activation
        self.device = device
        # forward direction temporal convolution
        self.conv1 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv2 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv3 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))

        # backward direction temporal convolution
        self.conv1b = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv2b = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv3b = nn.Conv2d(in_channels, out_channels, (1, kernel_size))

    def forward(self, X):
        """
        :param X: Input data of shape (batch_size, num_timesteps, num_nodes)
        :return: Output data of shape (batch_size, num_timesteps, num_features)
        """
        batch_size = X.shape[0]
        seq_len = X.shape[1]
        conv_len = seq_len - self.kernel_size + 1  # length after an unpadded convolution

        Xf = X.unsqueeze(1)  # (batch_size, 1, num_timesteps, num_nodes)
        # Reverse the direction of time. torch.flip needs no index tensor, so
        # nothing here depends on the device given at construction time.
        Xb = torch.flip(Xf, dims=[2])

        Xf = Xf.permute(0, 3, 1, 2)
        Xb = Xb.permute(0, 3, 1, 2)  # (batch_size, num_nodes, 1, num_timesteps)

        # Gated convolution plus a linear shortcut convolution, per direction.
        # NOTE(review): the conv output is (batch, out_channels, 1, conv_len)
        # and is reshaped (not permuted) to (batch, conv_len, out_channels);
        # confirm that axis interleaving is intended. Kept unchanged so that
        # existing checkpoints behave identically.
        outf = self.conv1(Xf) * torch.sigmoid(self.conv2(Xf)) + self.conv3(Xf)
        outf = outf.reshape([batch_size, conv_len, self.out_channels])

        outb = self.conv1b(Xb) * torch.sigmoid(self.conv2b(Xb)) + self.conv3b(Xb)
        outb = outb.reshape([batch_size, conv_len, self.out_channels])

        # Zero-pad both directions back to num_timesteps, on the data's own device.
        rec = outf.new_zeros([batch_size, self.kernel_size - 1, self.out_channels])
        outf = torch.cat((outf, rec), dim=1)
        outb = torch.cat((outb, rec), dim=1)  # (batch_size, num_timesteps, out_features)

        # Undo the time reversal of the backward branch before combining.
        outb = torch.flip(outb, dims=[1])

        if self.activation == 'relu':
            return F.relu(outf) + F.relu(outb)
        if self.activation == 'sigmoid':
            return torch.sigmoid(outf) + torch.sigmoid(outb)
        return outf + outb


class ST_NB(nn.Module):
    """
    Two-branch model producing negative-binomial parameters (n, p).

    The temporal branch (TC1 -> TC2 -> TC3 -> TNB) and the spatial branch
    (SC1 -> SC2 -> SC3 -> SNB) run on the same input; their outputs are
    combined by element-wise multiplication.

  wx_t  + wx_s
    |       |
   TC4     SC4
    |       |
   TC3     SC3
    |       |
   z_t     z_s
    |       |
   TC2     SC2
    |       |
   TC1     SC1
    |       |
   x_m     x_m
    """
    def __init__(self, SC1, SC2, SC3, TC1, TC2, TC3, SNB,TNB):
        super(ST_NB, self).__init__()
        self.TC1 = TC1
        self.TC2 = TC2
        self.TC3 = TC3
        self.TNB = TNB

        self.SC1 = SC1
        self.SC2 = SC2
        self.SC3 = SC3
        self.SNB = SNB

    def forward(self, X, A_q, A_h):
        """
        :param X: 4-D input whose last axis is a dummy feature axis; only
            index 0 of it is used.
        :A_q: Forward random-walk matrix, passed to the spatial layers.
        :A_h: Backward random-walk matrix, passed to the spatial layers.
        :return: ``(n_res, p_res)``, the element-wise products of the
            temporal and spatial head outputs.

        Side effect: the outputs of TC2 and SC2 are stored on the module as
        ``temporal_factors`` and ``space_factors``.
        """
        X = X[:,:,:,0] # Dummy dimension deleted

        # Temporal branch works on the input with its last two axes swapped.
        X_T = X.permute(0,2,1)
        X_t1 = self.TC1(X_T)
        X_t2 = self.TC2(X_t1) #num_time, rank
        self.temporal_factors = X_t2
        X_t3 = self.TC3(X_t2)
        _b,_h,_ht = X_t3.shape
        n_t_nb,p_t_nb = self.TNB(X_t3.view(_b,_h,_ht,1))

        # Spatial branch works on the input as given.
        X_s1 = self.SC1(X, A_q, A_h)
        X_s2 = self.SC2(X_s1, A_q, A_h) #num_nodes, rank
        self.space_factors = X_s2
        X_s3 = self.SC3(X_s2, A_q, A_h)
        _b,_n,_hs = X_s3.shape
        n_s_nb,p_s_nb = self.SNB(X_s3.view(_b,_n,_hs,1))

        # Swap the temporal outputs back so both branches share one layout.
        n_res = n_t_nb.permute(0, 2, 1) * n_s_nb
        p_res = p_t_nb.permute(0, 2, 1) * p_s_nb

        return n_res,p_res

class ST_Gau(nn.Module):
    """
    Two-branch model producing Gaussian parameters (location, scale).

    A temporal stack (TC1 -> TC2 -> TC3 -> TGau) and a spatial stack
    (SC1 -> SC2 -> SC3 -> SGau) process the same input; the two heads'
    outputs are multiplied element-wise.

  wx_t  + wx_s
    |       |
   TC4     SC4
    |       |
   TC3     SC3
    |       |
   z_t     z_s
    |       |
   TC2     SC2
    |       |
   TC1     SC1
    |       |
   x_m     x_m
    """
    def __init__(self, SC1, SC2, SC3, TC1, TC2, TC3, SGau,TGau):
        super().__init__()
        # temporal stack
        self.TC1, self.TC2, self.TC3 = TC1, TC2, TC3
        self.TGau = TGau

        # spatial stack
        self.SC1, self.SC2, self.SC3 = SC1, SC2, SC3
        self.SGau = SGau

    def forward(self, X, A_q, A_h):
        """
        :param X: 4-D input; index 0 of its last (dummy) axis is used.
        :A_q: Forward random-walk matrix handed to the spatial layers.
        :A_h: Backward random-walk matrix handed to the spatial layers.
        :return: ``(loc_res, scale_res)``.

        Also stores the TC2 / SC2 outputs as ``temporal_factors`` /
        ``space_factors``.
        """
        signal = X[:, :, :, 0]  # drop the dummy axis

        # --- temporal stack (last two axes swapped) ---
        latent_t = self.TC2(self.TC1(signal.permute(0, 2, 1)))  # num_time, rank
        self.temporal_factors = latent_t
        decoded_t = self.TC3(latent_t)
        d0, d1, d2 = decoded_t.shape
        loc_t, scale_t = self.TGau(decoded_t.view(d0, d1, d2, 1))

        # --- spatial stack ---
        latent_s = self.SC2(self.SC1(signal, A_q, A_h), A_q, A_h)  # num_nodes, rank
        self.space_factors = latent_s
        decoded_s = self.SC3(latent_s, A_q, A_h)
        d0, d1, d2 = decoded_s.shape
        loc_s, scale_s = self.SGau(decoded_s.view(d0, d1, d2, 1))

        # Align the temporal outputs with the spatial layout and combine.
        return (loc_t.permute(0, 2, 1) * loc_s,
                scale_t.permute(0, 2, 1) * scale_s)

class ST_NB_ZeroInflated(nn.Module):
    """
    Two-branch model producing zero-inflated NB parameters (n, p, pi).

    A temporal stack (TC1 -> TC2 -> TC3 -> TNB) and a spatial stack
    (SC1 -> SC2 -> SC3 -> SNB) process the same input; each of the three
    head outputs is the element-wise product of both branches.

  wx_t  + wx_s
    |       |
   TC4     SC4
    |       |
   TC3     SC3
    |       |
   z_t     z_s
    |       |
   TC2     SC2
    |       |
   TC1     SC1
    |       |
   x_m     x_m
    """
    def __init__(self, SC1, SC2, SC3, TC1, TC2, TC3, SNB,TNB):
        super().__init__()
        # temporal stack
        self.TC1, self.TC2, self.TC3 = TC1, TC2, TC3
        self.TNB = TNB

        # spatial stack
        self.SC1, self.SC2, self.SC3 = SC1, SC2, SC3
        self.SNB = SNB

    def forward(self, X, A_q, A_h):
        """
        :param X: 4-D input; index 0 of its last (dummy) axis is used.
        :A_q: Forward random-walk matrix handed to the spatial layers.
        :A_h: Backward random-walk matrix handed to the spatial layers.
        :return: ``(n_res, p_res, pi_res)``.

        Also stores the TC2 / SC2 outputs as ``temporal_factors`` /
        ``space_factors``.
        """
        signal = X[:, :, :, 0]  # drop the dummy axis

        # --- temporal stack (last two axes swapped) ---
        latent_t = self.TC2(self.TC1(signal.permute(0, 2, 1)))  # num_time, rank
        self.temporal_factors = latent_t
        decoded_t = self.TC3(latent_t)
        d0, d1, d2 = decoded_t.shape
        n_t, p_t, pi_t = self.TNB(decoded_t.view(d0, d1, d2, 1))

        # --- spatial stack ---
        latent_s = self.SC2(self.SC1(signal, A_q, A_h), A_q, A_h)  # num_nodes, rank
        self.space_factors = latent_s
        decoded_s = self.SC3(latent_s, A_q, A_h)
        d0, d1, d2 = decoded_s.shape
        n_s, p_s, pi_s = self.SNB(decoded_s.view(d0, d1, d2, 1))

        # Align the temporal outputs with the spatial layout and combine.
        return (n_t.permute(0, 2, 1) * n_s,
                p_t.permute(0, 2, 1) * p_s,
                pi_t.permute(0, 2, 1) * pi_s)

class ST_TWEEDIE(nn.Module):
    """
    Two-branch model producing Tweedie parameters (phi, rou, mu) and, when
    ``four`` is set, an additional zero-inflation output zi.

    A temporal stack (TC1 -> TC2 -> TC3 -> TNB) and a spatial stack
    (SC1 -> SC2 -> SC3 -> SNB) process the same input; matching head outputs
    are multiplied element-wise, after which phi is passed through ReLU and
    rou through ``sigmoid + 1`` (so it lies between 1 and 2).

  wx_t  + wx_s
    |       |
   TC4     SC4
    |       |
   TC3     SC3
    |       |
   z_t     z_s
    |       |
   TC2     SC2
    |       |
   TC1     SC1
    |       |
   x_m     x_m
    """
    def __init__(self, SC1, SC2, SC3, TC1, TC2, TC3, SNB,TNB, four=False):
        super().__init__()
        # temporal stack
        self.TC1, self.TC2, self.TC3 = TC1, TC2, TC3
        self.TNB = TNB

        # spatial stack
        self.SC1, self.SC2, self.SC3 = SC1, SC2, SC3
        self.SNB = SNB
        # Whether the two heads return four outputs instead of three.
        self.four = four

    def forward(self, X, A_q, A_h):
        """
        :param X: 4-D input; index 0 of its last (dummy) axis is used.
        :A_q: Forward random-walk matrix handed to the spatial layers.
        :A_h: Backward random-walk matrix handed to the spatial layers.
        :return: ``(phi, rou, mu)`` or, with ``four``, ``(phi, rou, mu, zi)``.

        Also stores the TC2 / SC2 outputs as ``temporal_factors`` /
        ``space_factors``.
        """
        signal = X[:, :, :, 0]  # drop the dummy axis

        # --- temporal stack (last two axes swapped) ---
        latent_t = self.TC2(self.TC1(signal.permute(0, 2, 1)))  # num_time, rank
        self.temporal_factors = latent_t
        decoded_t = self.TC3(latent_t)
        d0, d1, d2 = decoded_t.shape
        heads_t = self.TNB(decoded_t.view(d0, d1, d2, 1))
        if self.four:
            phi_t, rou_t, mu_t, zi_t = heads_t
        else:
            phi_t, rou_t, mu_t = heads_t

        # --- spatial stack ---
        latent_s = self.SC2(self.SC1(signal, A_q, A_h), A_q, A_h)  # num_nodes, rank
        self.space_factors = latent_s
        decoded_s = self.SC3(latent_s, A_q, A_h)
        d0, d1, d2 = decoded_s.shape
        heads_s = self.SNB(decoded_s.view(d0, d1, d2, 1))
        if self.four:
            phi_s, rou_s, mu_s, zi_s = heads_s
        else:
            phi_s, rou_s, mu_s = heads_s

        # n, p, pi, zi => phi, rou, mu, zi
        phi_res = torch.relu(phi_t.permute(0, 2, 1) * phi_s)
        rou_res = torch.sigmoid(rou_t.permute(0, 2, 1) * rou_s) + 1
        mu_res = mu_t.permute(0, 2, 1) * mu_s

        if not self.four:
            return phi_res, rou_res, mu_res
        # zi is the plain product; no extra squashing is applied here.
        zi_res = zi_t.permute(0, 2, 1) * zi_s
        return phi_res, rou_res, mu_res, zi_res


from __future__ import division
import os
import zipfile
import numpy as np
import scipy.sparse as sp
import pandas as pd
from math import radians, cos, sin, asin, sqrt
# from sklearn.externals import joblib
import joblib
import scipy.io
import torch
from torch import nn
from scipy.stats import nbinom,norm
rand = np.random.RandomState(0)
import tweedie
import torch.distributions
from sklearn.preprocessing import MinMaxScaler
import math

"""
Geographical information calculation
"""
def get_long_lat(sensor_index,loc = None):
    """
    Look up the longitude and latitude of the given sensor index label(s).

    :param sensor_index: Index label(s) passed to ``.loc`` on the location
        table (the original note mentions indices 0-206).
    :param loc: Optional table with 'longitude' and 'latitude' columns; when
        omitted, 'data/metr/graph_sensor_locations.csv' is read.
    :return: ``(longitudes, latitudes)`` as NumPy arrays.
    """
    table = pd.read_csv('data/metr/graph_sensor_locations.csv') if loc is None else loc
    longitudes, latitudes = (
        table[column].loc[sensor_index] for column in ('longitude', 'latitude')
    )
    return longitudes.to_numpy(), latitudes.to_numpy()

def haversine(lon1, lat1, lon2, lat2):
    """
    Great-circle distance in metres between two points given in decimal
    degrees, using the haversine formula with an Earth radius of 6371 km.
    """
    lam1, phi1, lam2, phi2 = (radians(deg) for deg in (lon1, lat1, lon2, lat2))

    delta_lam = lam2 - lam1
    delta_phi = phi2 - phi1
    # haversine of the central angle
    hav = sin(delta_phi/2)**2 + cos(phi1) * cos(phi2) * sin(delta_lam/2)**2
    central_angle = 2 * asin(sqrt(hav))
    earth_radius_km = 6371
    return central_angle * earth_radius_km * 1000


"""
Generate the training sample for forecasting task, same idea from STGCN
"""

def generate_dataset(X, num_timesteps_input, num_timesteps_output, origional_feature=True):
    """
    Slice node features into overlapping (input, target) samples by sliding a
    window of length (num_timesteps_input + num_timesteps_output) along the
    time axis in steps of 1.

    :param X: Node features of shape (num_vertices, num_features,
    num_timesteps)
    :param origional_feature: Accepted but not used.
    :return:
        - Inputs of shape
          (num_samples, num_vertices, num_timesteps_input, num_features).
        - Targets of shape (num_samples, num_vertices, num_timesteps_output),
          taken from feature 0 only.
    """
    window = num_timesteps_input + num_timesteps_output
    # One sample per start position at which the whole window still fits.
    starts = range(X.shape[2] - window + 1)

    features = [
        X[:, :, start: start + num_timesteps_input].transpose((0, 2, 1))
        for start in starts
    ]
    target = [
        X[:, 0, start + num_timesteps_input: start + window]
        for start in starts
    ]

    return torch.from_numpy(np.array(features)), \
           torch.from_numpy(np.array(target))


"""
Dynamically construct the adjacency matrix
"""

def get_Laplace(A):
    """
    Returns the symmetrically normalised adjacency D^-1/2 * A * D^-1/2
    without self-loops. This is for C_GCN.

    :param A: square adjacency array. If ``A[0, 0] == 1`` the diagonal is
        assumed to already hold self-loops of weight 1 and an identity
        matrix is subtracted first.
    :return: normalised matrix with the same shape as ``A``. The input is
        not modified.
    """
    # Integer/bool input must be promoted first: otherwise the degree vector
    # is an integer array, the small clamp below is truncated to 0 and
    # isolated nodes produce inf/NaN entries.
    if not np.issubdtype(A.dtype, np.floating):
        A = A.astype(np.float64)
    if A[0, 0] == 1:
        A = A - np.eye(A.shape[0], dtype=np.float32)  # strip the added self-loops
    degree = np.array(np.sum(A, axis=1)).reshape((-1,))
    degree[degree <= 10e-5] = 10e-5    # prevent division by zero for isolated nodes
    inv_sqrt_degree = np.reciprocal(np.sqrt(degree))
    A_wave = np.multiply(np.multiply(inv_sqrt_degree.reshape((-1, 1)), A),
                         inv_sqrt_degree.reshape((1, -1)))
    return A_wave

def get_normalized_adj(A):
    """
    Returns the symmetrically degree-normalised adjacency
    D^-1/2 * A * D^-1/2 with self-loops. This is for K_GCN.

    :param A: square adjacency array. If ``A[0, 0] == 0`` the diagonal is
        assumed to be empty and an identity matrix is added first.
    :return: normalised matrix with the same shape as ``A``. The input is
        not modified.
    """
    # Integer/bool input must be promoted first: otherwise the degree vector
    # is an integer array, the small clamp below is truncated to 0 and
    # zero-degree rows produce inf/NaN entries.
    if not np.issubdtype(A.dtype, np.floating):
        A = A.astype(np.float64)
    if A[0, 0] == 0:
        A = A + np.eye(A.shape[0], dtype=np.float32)  # add self-loops
    degree = np.array(np.sum(A, axis=1)).reshape((-1,))
    degree[degree <= 10e-5] = 10e-5    # Prevent infs
    inv_sqrt_degree = np.reciprocal(np.sqrt(degree))
    A_wave = np.multiply(np.multiply(inv_sqrt_degree.reshape((-1, 1)), A),
                         inv_sqrt_degree.reshape((1, -1)))
    return A_wave

def calculate_random_walk_matrix(adj_mx):
    """
    Returns the random walk adjacency matrix D^-1 * A. This is for D_GCN.

    :param adj_mx: square adjacency matrix (dense array or anything
        accepted by ``scipy.sparse.coo_matrix``).
    :return: dense ndarray in which every row is divided by its row sum;
        rows whose sum is zero stay all-zero.
    """
    adj_mx = sp.coo_matrix(adj_mx)
    # np.power(<int array>, -1) raises ValueError, so integer/bool adjacency
    # matrices are promoted to float before inverting the degrees.
    if not np.issubdtype(adj_mx.dtype, np.floating):
        adj_mx = adj_mx.astype(np.float64)
    d = np.array(adj_mx.sum(1))
    # Zero-degree rows give 1/0 = inf here; that is expected and reset below.
    with np.errstate(divide='ignore'):
        d_inv = np.power(d, -1).flatten()
    d_inv[np.isinf(d_inv)] = 0.
    d_mat_inv = sp.diags(d_inv)
    random_walk_mx = d_mat_inv.dot(adj_mx).tocoo()
    return random_walk_mx.toarray()


def test_error_virtual(STmodel, unknow_set, test_data, A_s, E_maxvalue, Missing0):
    """
    Evaluate spatial imputation on the unknown nodes.

    This function was a line-for-line copy of ``test_error``; it now simply
    delegates to it so the two cannot drift apart.

    :param STmodel: The graph neural networks
    :unknow_set: The unknow locations for spatial prediction
    :test_data: The true value test_data of shape (test_num_timesteps, num_nodes)
    :A_s: The full adjacent matrix
    :E_maxvalue: scale the inputs are divided by and the outputs multiplied by
    :Missing0: True: 0 in original datasets means missing data
    :return: MAE, RMSE, R2 on the unknown nodes and the full imputed array
    """
    return test_error(STmodel, unknow_set, test_data, A_s, E_maxvalue, Missing0)

def test_error(STmodel, unknow_set, test_data, A_s, E_maxvalue, Missing0):
    """
    Evaluate spatial imputation on the unknown nodes, window by window.

    The test series is cut into consecutive, non-overlapping windows of
    ``STmodel.time_dimension`` steps (a trailing incomplete window is
    dropped). In every window the unknown nodes are zeroed out, the window
    is divided by ``E_maxvalue``, passed through the model, and the output
    is multiplied by ``E_maxvalue`` again.

    :param STmodel: The graph neural networks
    :unknow_set: The unknow locations for spatial prediction
    :test_data: The true value test_data of shape (test_num_timesteps, num_nodes)
    :A_s: The full adjacent matrix
    :E_maxvalue: scale the inputs are divided by and the outputs multiplied by
    :Missing0: True: 0 in original datasets means missing data
    :return: MAE, RMSE, R2 computed on the unknown nodes, and ``o``, the
        imputed array in which observed nodes carry their true values
    """
    unknown_cols = list(set(unknow_set))
    time_dim = STmodel.time_dimension
    usable_len = test_data.shape[0]//time_dim*time_dim

    test_omask = np.ones(test_data.shape)
    if Missing0:
        test_omask[test_data == 0] = 0
    test_inputs = (test_data * test_omask).astype('float32')

    # 1 = observed node, 0 = node hidden from the model.
    missing_index = np.ones(np.shape(test_data))
    missing_index[:, unknown_cols] = 0

    # The transition matrices depend only on A_s, so build them once instead
    # of once per window.
    A_q = torch.from_numpy((calculate_random_walk_matrix(A_s).T).astype('float32'))
    A_h = torch.from_numpy((calculate_random_walk_matrix(A_s.T).T).astype('float32'))

    o = np.zeros([usable_len, test_inputs.shape[1]])
    with torch.no_grad():  # pure inference, no autograd graph needed
        for i in range(0, usable_len, time_dim):
            T_inputs = test_inputs[i:i+time_dim, :] * missing_index[i:i+time_dim, :]
            T_inputs = T_inputs/E_maxvalue
            T_inputs = np.expand_dims(T_inputs, axis = 0)
            T_inputs = torch.from_numpy(T_inputs.astype('float32'))

            imputation = STmodel(T_inputs, A_q, A_h)
            o[i:i+time_dim, :] = imputation.detach().cpu().numpy()[0, :, :]

    o = o*E_maxvalue
    truth = test_inputs[0:usable_len]
    observed = missing_index[0:usable_len] == 1
    o[observed] = truth[observed]  # keep the known values as they are

    test_mask = 1 - missing_index[0:usable_len]
    if Missing0:
        test_mask[truth == 0] = 0
        o[truth == 0] = 0

    o_ = o[:, unknown_cols]
    truth_ = truth[:, unknown_cols]
    test_mask_ = test_mask[:, unknown_cols]

    sq_err = (o_ - truth_)*(o_ - truth_)
    MAE = np.sum(np.abs(o_ - truth_))/np.sum(test_mask_)
    RMSE = np.sqrt(np.sum(sq_err)/np.sum(test_mask_))
    centred = truth_ - truth_.mean()
    R2 = 1 - np.sum(sq_err)/np.sum(centred*centred)
    print(truth_.mean())
    return MAE, RMSE, R2, o


def rolling_test_error(STmodel, unknow_set, test_data, A_s, E_maxvalue,Missing0):
    """
    Rolling evaluation: slide a window of ``STmodel.time_dimension`` steps
    over the test series one step at a time and keep only the model output
    for the last step of every window.

    :param STmodel: The graph neural networks
    :unknow_set: The unknow locations for spatial prediction
    :test_data: The true value test_data of shape (test_num_timesteps, num_nodes)
    :A_s: The full adjacent matrix
    :E_maxvalue: scale the inputs are divided by and the outputs multiplied by
    :Missing0: True: 0 in original datasets means missing data
    :return: MAE, RMSE, MAPE and ``o``, the imputed array of shape
        (test_num_timesteps - time_dimension, num_nodes)
    """
    unknown_cols = list(set(unknow_set))
    time_dim = STmodel.time_dimension
    num_steps = test_data.shape[0]

    test_omask = np.ones(test_data.shape)
    if Missing0:
        test_omask[test_data == 0] = 0
    test_inputs = (test_data * test_omask).astype('float32')

    # 1 = observed node, 0 = node hidden from the model.
    missing_index = np.ones(np.shape(test_data))
    missing_index[:, unknown_cols] = 0

    # The transition matrices depend only on A_s, so build them once.
    A_q = torch.from_numpy((calculate_random_walk_matrix(A_s).T).astype('float32'))
    A_h = torch.from_numpy((calculate_random_walk_matrix(A_s.T).T).astype('float32'))

    o = np.zeros([num_steps - time_dim, test_inputs.shape[1]])
    with torch.no_grad():  # pure inference, no autograd graph needed
        for i in range(0, num_steps - time_dim):
            MF_inputs = test_inputs[i:i+time_dim, :] * missing_index[i:i+time_dim, :]
            # NOTE(review): inputs are normalised here the same way test_error
            # does, because the output below is rescaled by E_maxvalue;
            # confirm this matches how the model was trained.
            MF_inputs = MF_inputs/E_maxvalue
            MF_inputs = np.expand_dims(MF_inputs, axis = 0)
            MF_inputs = torch.from_numpy(MF_inputs.astype('float32'))

            imputation = STmodel(MF_inputs, A_q, A_h)
            o[i, :] = imputation.detach().cpu().numpy()[0, time_dim-1, :]

    # De-normalise first, then copy the observed values, so known entries are
    # not multiplied by E_maxvalue a second time.
    o = o*E_maxvalue
    # The reference slice must have the same length as `o`; re-slicing it to
    # the window-aligned length (as test_error does) makes the masks and the
    # error terms below fail on mismatched shapes.
    truth = test_inputs[time_dim:num_steps]
    observed = missing_index[time_dim:num_steps] == 1
    o[observed] = truth[observed]

    test_mask = 1 - missing_index[time_dim:num_steps]
    if Missing0:
        test_mask[truth == 0] = 0
        o[truth == 0] = 0

    MAE = np.sum(np.abs(o - truth))/np.sum(test_mask)
    RMSE = np.sqrt(np.sum((o - truth)*(o - truth))/np.sum(test_mask))
    MAPE = np.sum(np.abs(o - truth)/(truth + 1e-5))/np.sum(test_mask)  #avoid x/0

    return MAE, RMSE, MAPE, o

def test_error_cap(STmodel, unknow_set, full_set, test_set, A,time_dim,capacities):
    """
    Evaluate spatial imputation on the unknown nodes, rescaling both the
    model output and the reference values by ``capacities``.

    Zeros in ``test_set`` are always treated as missing. The series is cut
    into consecutive, non-overlapping windows of ``time_dim`` steps (a
    trailing incomplete window is dropped).

    :param STmodel: The graph neural networks
    :unknow_set: The unknow locations for spatial prediction
    :full_set: accepted for compatibility; not used here
    :test_set: test data of shape (test_num_timesteps, num_nodes)
    :A: The full adjacent matrix
    :time_dim: number of time steps per model input window
    :capacities: factor multiplied onto outputs and truth (presumably
        per-node capacities; must broadcast against (timesteps, num_nodes))
    :return: MAE, RMSE, R2 computed on the unknown nodes, and ``o``, the
        rescaled imputed array
    """
    unknown_cols = list(set(unknow_set))
    usable_len = test_set.shape[0]//time_dim*time_dim

    test_omask = np.ones(test_set.shape)
    test_omask[test_set == 0] = 0
    test_inputs = (test_set * test_omask).astype('float32')

    # 1 = observed node, 0 = node hidden from the model.
    missing_index = np.ones(np.shape(test_inputs))
    missing_index[:, unknown_cols] = 0

    # The transition matrices depend only on A, so build them once instead
    # of once per window.
    A_q = torch.from_numpy((calculate_random_walk_matrix(A).T).astype('float32'))
    A_h = torch.from_numpy((calculate_random_walk_matrix(A.T).T).astype('float32'))

    o = np.zeros([usable_len, test_inputs.shape[1]])
    with torch.no_grad():  # pure inference, no autograd graph needed
        for i in range(0, usable_len, time_dim):
            MF_inputs = test_inputs[i:i+time_dim, :] * missing_index[i:i+time_dim, :]
            MF_inputs = np.expand_dims(MF_inputs, axis = 0)
            MF_inputs = torch.from_numpy(MF_inputs.astype('float32'))

            imputation = STmodel(MF_inputs, A_q, A_h)
            o[i:i+time_dim, :] = imputation.detach().cpu().numpy()[0, :, :]

    o = o*capacities
    truth = test_inputs[0:usable_len]
    truth = truth*capacities
    observed = missing_index[0:usable_len] == 1
    o[observed] = truth[observed]  # keep the known values as they are
    o[truth == 0] = 0

    test_mask = 1 - missing_index[0:usable_len]
    test_mask[truth == 0] = 0

    o_ = o[:, unknown_cols]
    truth_ = truth[:, unknown_cols]
    test_mask_ = test_mask[:, unknown_cols]

    sq_err = (o_ - truth_)*(o_ - truth_)
    MAE = np.sum(np.abs(o_ - truth_))/np.sum(test_mask_)
    RMSE = np.sqrt(np.sum(sq_err)/np.sum(test_mask_))
    centred = truth_ - truth_.mean()
    R2 = 1 - np.sum(sq_err)/np.sum(centred*centred)
    print(truth_.mean())
    return MAE, RMSE, R2, o

def nb_nll_loss(y,n,p,y_mask=None):
    """
    Summed negative log-likelihood of y under a negative binomial with
    parameters (n, p).

    y: true values
    n, p: distribution parameters, broadcastable against y
    y_mask: optional weights multiplied element-wise onto the per-element
        NLL before summing (e.g. 0 for missing entries)
    """
    log_normaliser = torch.lgamma(n) + torch.lgamma(y+1) - torch.lgamma(n+y)
    success_term = n*torch.log(p)
    failure_term = y*torch.log(1-p)
    per_element = log_normaliser - success_term - failure_term

    if y_mask is None:
        return torch.sum(per_element)
    return torch.sum(per_element*y_mask)

def nb_zeroinflated_nll_loss(y,n,p,pi,y_mask=None):
    """
    Negative log-likelihood of y under a zero-inflated negative binomial.

    The result is the mean NLL over the entries with y == 0 plus the mean
    NLL over the entries with y > 0. A group with no entries contributes
    nothing (averaging an empty group would otherwise turn the whole loss
    into NaN).

    y: true values
    n, p: negative binomial parameters (p is clipped to [1e-3, 1-1e-3])
    pi: zero-inflation probability (clipped to [1e-3, 1-1e-3])
    y_mask: accepted for interface compatibility; currently not used
    https://stats.idre.ucla.edu/r/dae/zinb/
    """
    pi = torch.clip(pi, 1e-3, 1-1e-3)
    p = torch.clip(p, 1e-3, 1-1e-3)

    idx_yeq0 = y==0
    idx_yg0  = y>0

    # Zero scalar that stays attached to the graph, so backward() still
    # works when both groups happen to be empty.
    loss = p.sum() * 0.0

    if torch.any(idx_yeq0):
        n_yeq0 = n[idx_yeq0]
        p_yeq0 = p[idx_yeq0]
        pi_yeq0 = pi[idx_yeq0]
        # The 1e-4 offsets keep the logarithms finite.
        L_yeq0 = torch.log(pi_yeq0+1e-4) + torch.log(1e-4+ (1-pi_yeq0)*torch.pow(p_yeq0,n_yeq0))
        loss = loss - torch.mean(L_yeq0)

    if torch.any(idx_yg0):
        n_yg0 = n[idx_yg0]
        p_yg0 = p[idx_yg0]
        pi_yg0 = pi[idx_yg0]
        yg0 = y[idx_yg0]
        L_yg0  = torch.log(1-pi_yg0+1e-4) + torch.lgamma(n_yg0+yg0) - torch.lgamma(yg0+1) - torch.lgamma(n_yg0+1e-4) + n_yg0*torch.log(p_yg0+1e-4) + yg0*torch.log(1-p_yg0+1e-4)
        loss = loss - torch.mean(L_yg0)

    return loss

    # return torch.sum((((1-pi)*(n/p-n)).reshape(-1)-y.reshape(-1))*(((1-pi)*(n/p-n)).reshape(-1)-y.reshape(-1)))

def tweedie_nll_loss(y_true, phi, rou, mu):
    """
    Tweedie-style negative log-likelihood, averaged over all elements.

    y_true: observed values
    phi: dispersion; clamped to [1, 10] below
    rou: power parameter; clamped to [1 + 1e-3, 2 - 1e-3] below
    mu: mean on the log scale (it is exponentiated below)
    returns: scalar tensor, the mean of the per-element loss
    """

    def lower_w_j(y_true, y_pred, j_max, rou, phi):
        # Returns the negated log_w_j term built from y_true, j_max and
        # alpha = (2 - rou) / (1 - rou). `y_pred` and `phi` are not used.
        rou = rou.reshape(-1)
        j_max = j_max.reshape(-1)
        alpha = (2-rou)/(1 - rou)

        log_w_j = -torch.log(y_true + 1e-3) + j_max * (alpha-1) - torch.log(j_max + 1e-3) - 1/2 * torch.log(-alpha + 1e-3)

        return -log_w_j

    rou = torch.clamp(rou, 1 + 1e-3, 2 - 1e-3)
    # 1. TD
    phi = torch.clamp(phi, 1, 10)       # NOTE: dispersion is restricted to [1, 10]
    # 2. STP
    # phi = torch.ones_like(phi)
    # 3. STGM
    # phi = phi + 1
    # 4. STIG
    # phi = phi + 1
    mu = torch.exp(mu)
    # Elements not covered by the masks below keep this initial value of 1.
    ll = torch.ones_like(y_true)
    # After the clamp above every non-NaN rou lies strictly between 1 and 2.
    ll_1to_2_mask = (1 < rou) & (rou < 2)

    if torch.sum(ll_1to_2_mask) > 0:
        # Observations equal to zero
        zeros = y_true == 0
        mask = zeros & ll_1to_2_mask
        ll[mask] = (mu[mask] ** (2 - rou[mask]) / (phi[mask] * (2 - rou[mask])))

        # Non-zero observations
        mask = ~zeros & ll_1to_2_mask
        ll[mask] = -(y_true[mask] * mu[mask] ** (1 - rou[mask]) / (phi[mask] * (1 - rou[mask]))) + (
                    mu[mask] ** (2 - rou[mask]) / (phi[mask] * (2 - rou[mask])))

        j_max = y_true ** (2-rou) / (2-rou) / phi
        # NOTE(review): lower_w_j returns -log_w_j, so this subtraction adds
        # log_w_j to the loss of the non-zero entries; confirm the sign.
        ll[mask] -= lower_w_j(y_true[mask], mu[mask], j_max[mask], rou[mask], phi[mask])

    return ll.mean()


def nb_tcn_nll(y, phi, rou, mu, zi, y_mask=None):
    """
    Mean squared error between exp(mu) and y.

    phi, rou, zi and y_mask are accepted but not used.
    """
    squared_error = (torch.exp(mu) - y)**2
    mse = squared_error.mean()
    return mse.mean()


def nb_zeroinflated_draw(n,p,pi):
    """
    Draw one sample per element from a zero-inflated negative binomial.

    For every element the negative binomial is restricted to the integers
    between its 1% and 99% quantiles, weighted by (1 - pi), and the
    zero-inflation mass pi is assigned to the value 0.

    input: n, p, pi arrays of identical shape
    output: drawn values, same shape as n
    """
    origin_shape = n.shape
    n = n.flatten()
    p = p.flatten()
    pi = pi.flatten()
    nb = nbinom(n,p)
    x_low = nb.ppf(0.01)
    x_up  = nb.ppf(0.99)
    pred = np.zeros_like(n)
    for i in range(len(x_low)):
        lo, hi = x_low[i], x_up[i]
        if hi <= 1:
            hi = 1
        # Same guard as nb_draw: identical quantiles would give an empty
        # support and an IndexError below.
        if hi == lo:
            hi = lo + 1
        x = np.arange(lo, hi)
        prob = (1-pi[i]) * nbinom.pmf(x,n[i],p[i])
        if x[0] == 0:
            prob[0] += pi[i] # zero-inflated
        else:
            # The truncated support does not contain 0: add it explicitly so
            # the inflation mass is not credited to the lowest positive value.
            x = np.concatenate(([0], x))
            prob = np.concatenate(([pi[i]], prob))
        pred[i] = rand.choice(a=x,p=prob/np.sum(prob)) # random seed fixed, defined in the beginning

    return pred.reshape(origin_shape)


def gauss_draw(loc,scale,step=100):
    """
    Draw one sample per element from a Gaussian discretised on a grid.

    For every element the candidates are
    ``np.arange(q01, q99, step)`` (q01/q99 being the 1% and 99% quantiles)
    and one candidate is chosen with probability proportional to the
    Gaussian density.

    input: loc (mean) and scale (standard deviation) arrays of identical
        shape; step is the grid spacing
    output: drawn values, same shape as loc

    NOTE(review): with the default step of 100 any interval narrower than
    100 contains a single candidate (the 1% quantile), so the draw is
    deterministic. Pass a smaller ``step`` for a finer grid; confirm whether
    100 grid points were intended instead.
    """
    origin_shape = loc.shape
    loc = loc.flatten()
    scale = scale.flatten()
    gauss = norm(loc,scale)
    x_low = gauss.ppf(0.01)
    x_up  = gauss.ppf(0.99)
    pred = np.zeros_like(loc)
    for i in range(len(x_low)):
        # Degenerate interval (scale <= 0 or NaN quantiles): there is no grid
        # to sample from, fall back to the mean instead of crashing.
        if not (x_up[i] > x_low[i]):
            pred[i] = loc[i]
            continue
        x = np.arange(x_low[i],x_up[i],step)
        prob = norm.pdf(x,loc[i],scale[i])
        pred[i] = rand.choice(a=x,p=prob/np.sum(prob)) # random seed fixed, defined in the beginning

    return pred.reshape(origin_shape)

def nb_draw(n,p):
    """
    Draw one sample per element from a negative binomial restricted to the
    integers between its 1% and 99% quantiles.

    input: n, p arrays of identical shape
    output: drawn values, same shape as n
    """
    origin_shape = n.shape
    n_flat = n.flatten()
    p_flat = p.flatten()
    frozen = nbinom(n_flat, p_flat)
    lower = frozen.ppf(0.01)
    upper = frozen.ppf(0.99)
    pred = np.zeros_like(n_flat)

    for idx in range(len(lower)):
        lo = lower[idx]
        hi = upper[idx]
        # Make sure the support holds at least one value.
        if hi <= 1:
            hi = 1
        if hi == lo:
            hi = lo + 1
        support = np.arange(lo, hi)
        weights = nbinom.pmf(support, n_flat[idx], p_flat[idx])
        pred[idx] = rand.choice(a=support, p=weights/np.sum(weights)) # random seed fixed, defined in the beginning

    return pred.reshape(origin_shape)

def gauss_loss(y,loc,scale,y_mask=None):
    """
    Mean negative Gaussian log-likelihood of y.

    The location (loc) keyword specifies the mean. The scale (scale) keyword specifies the standard deviation.
    A constant 1e-2 is added to the variance terms for numerical stability
    and the per-element log-likelihood is clipped to [-20, 10] before
    averaging. y_mask is accepted for interface compatibility but not used.
    http://jrmeyer.github.io/machinelearning/2017/08/18/mle.html
    """
    # Use the stdlib constant; assigning to torch.pi would overwrite the
    # library's own attribute with a float32-precision value for every
    # other user of torch in the process.
    import math

    variance = torch.pow(scale,2)
    LL = -1/2 * torch.log(2*math.pi*variance+1e-2) - 1/2*( torch.pow(y-loc,2)/(variance+1e-2) )
    LL = torch.clip(LL, -20, 10)
    return -torch.mean(LL)


def rmse(truth, pred):
    """Root mean squared error between truth and pred."""
    squared_error = (truth - pred) ** 2
    return np.sqrt(squared_error.mean())


def mae(truth, pred):
    """
    Mean absolute error after setting every prediction below 1 to 0.

    NOTE: the thresholding is written into `pred` itself, so the caller's
    array is modified.
    """
    below_one = pred < 1
    pred[below_one] = 0
    abs_error = np.abs(truth - pred)
    return abs_error.mean()


def wape(truth, pred):
    """Sum of absolute errors divided by the sum of truth."""
    total_abs_error = np.abs(np.subtract(pred, truth)).sum()
    total_truth = np.sum(truth)
    return total_abs_error / total_truth


def mape(truth, pred):
    """
    Mean absolute percentage error; 1e-5 is added to both the error and
    the denominator.
    """
    eps = 1e-5
    shifted_error = np.subtract(pred, truth) + eps
    ratio = shifted_error / (truth + eps)
    return np.mean(np.abs(ratio))


def true_zeros(truth, pred):
    """Fraction of the entries with truth == 0 whose prediction is also 0."""
    zero_positions = truth == 0
    matched = pred[zero_positions] == 0
    return np.sum(matched) / np.sum(zero_positions)


def KL_DIV(truth, pred):
    """
    Sum of pred * log(pred / truth), with 1e-5 added to both sides of the
    ratio.
    """
    eps = 1e-5
    log_ratio = np.log((pred + eps) / (truth + eps))
    return np.sum(pred * log_ratio)

def KL_DIV_divide(truth, pred):
    """
    Sum of pred * log(pred / truth), with 1e-1 added to both sides of the
    ratio, divided by the number of elements in truth.
    """
    eps = 1e-1
    log_ratio = np.log((pred + eps) / (truth + eps))
    element_count = np.prod(truth.shape)
    return np.sum(pred * log_ratio) / element_count


def F1_SCORE(truth,pred):
    """
    F1 score for detecting zeros, treating "value == 0" as the positive class.

    precision = correctly predicted zeros / predicted zeros
    recall    = correctly predicted zeros / true zeros

    Returns 0.0 when no zero is predicted correctly (this also covers the
    cases with no predicted zeros or no true zeros, where the ratios would
    be undefined).
    """
    truth_is_zero = truth == 0
    pred_is_zero = pred == 0
    hits = np.sum(pred_is_zero & truth_is_zero)
    if hits == 0:
        return 0.0
    precision = hits / np.sum(pred_is_zero)
    # Recall counts only the zeros that were actually found, not every
    # predicted zero.
    recall = hits / np.sum(truth_is_zero)
    return 2*(precision*recall)/(precision+recall)


def print_errors(truth, pred,string=None):
    """
    Print RMSE, MAE, F1, both KL variants and the true-zero rate on one line,
    prefixed by `string`.

    The metrics are evaluated in exactly this order. `mae` thresholds `pred`
    in place, so every metric evaluated after it sees the thresholded
    predictions, and `pred` is modified for the caller as well.
    """
    scores = (
        rmse(truth, pred),
        mae(truth, pred),
        F1_SCORE(truth, pred),
        KL_DIV(truth, pred),
        KL_DIV_divide(truth, pred),
        true_zeros(truth, pred),
    )
    template = ' RMSE %.4f MAE %.4f F1_SCORE %.4f KL-Div: %.4f, KL-Div-divide: %.4f, true_zeros_rate %.4f : '
    print(string, template % scores)