###############################################################
# Code for paper 
# Dynamic COVID risk assessment accounting for community virus exposure from a spatial-temporal transmission model

# This file contains code to run the experiment from simulated data. 
# Data are generated in the R file "simulateData.R"
###############################################################


import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torch.utils.data import Dataset, DataLoader, ConcatDataset

import os
import pandas as pd
import math
import sys

from functions import STmodel, PrepareData




# Problem dimensions for the simulated experiment.
n_time = 160  # number of time points
n_area = 176  # number of areas
C = 21        # forwarded to STmodel as its `C` argument
n_beta = 1    # number of covariates per area
# One zero entry per area (Python indices start from 0).
# NOTE(review): presumably the per-area starting time index -- confirm in STmodel.
t0 = torch.zeros(n_area, dtype=torch.long)

# data_X: 3d array (n_time * n_area * n_beta) storing the covariate
# information for each area at each time point:
#   - each 2d slice data_X[t] corresponds to one time point,
#   - each row of a slice holds the covariates of one area,
#   - each column of a slice holds one covariate across all areas.
# The covariates are constant over time here, so the same matrix is
# broadcast to every time point. If X changes over time, load one matrix per
# time point instead (e.g. "coMatrix/cov_<t>.csv").
#
# np.atleast_2d(...).T reproduces the former np.mat(...).T orientation
# (a 1-d file of length n_area becomes an (n_area, 1) column) without relying
# on np.mat, which was removed from the main namespace in NumPy 2.0.
X_temp = np.atleast_2d(np.loadtxt("Xmatrix.txt")).T
data_X = np.ones((n_time, n_area, n_beta))
data_X[:] = X_temp  # raises ValueError if X_temp is not (n_area, n_beta)
data_X_temp = data_X  # alias kept so the historical name still resolves
print(data_X.shape)


# Identifier of the simulated replicate to fit; it selects the files
# 'dataY/<subid>.txt' and 'initial_a/<subid>.txt'.
# It is taken from the first command-line argument unless the surrounding
# session has already defined `subid` (e.g. when the script is exec'd from a
# driver or notebook).
if "subid" not in globals():
    if len(sys.argv) < 2:
        raise SystemExit("usage: python <this script> SUBID")
    subid = sys.argv[1]

# data_Y: 2d array (n_time * n_area), each column contains the observed case
# numbers of one area over the n_time days.
data_Y = np.loadtxt(f"dataY/{subid}.txt")
print(data_Y.shape)

# data_nei: neighborhood matrix indicating whether two areas are neighbors.
data_nei = np.loadtxt("simul_nei.txt")
print(data_nei.shape)

# Initial values for a(t); these can be taken from the survival convolution
# model.
initial_a = np.loadtxt(f"initial_a/{subid}.txt")
print(initial_a.shape)


# Starting value for rho.
initial_rho = 0.1

# Starting values for the remaining parameters: read the table from disk and
# perturb every entry with independent uniform noise in [-0.025, 0.025).
# NOTE(review): the noise shape is hard-coded to (160, 4); confirm it matches
# the layout of 'initial_para.txt' and the size STmodel expects.
para_mat = np.loadtxt('initial_para.txt')
uniform_noise = 0.05 * np.random.rand(160, 4)
para_mat = para_mat + uniform_noise - 0.025

# Columns 0, 1 and 2 hold the starting values of beta0, beta and tau0.
initial_beta0, initial_beta, initial_tau0 = (
    para_mat[:, 0],
    para_mat[:, 1],
    para_mat[:, 2],
)




# `train_STmodel` is not part of the file-level imports, so bring it into
# scope here before it is called.
# NOTE(review): assumes it lives in `functions` next to STmodel/PrepareData --
# confirm.
from functions import train_STmodel

# Create the dataset from the observations, covariates and neighborhood matrix.
dataset = PrepareData(data_Y, data_X, data_nei)

# Initialize the model with the dimensions and starting values defined above.
stmodel = STmodel(
    n_time=n_time,
    n_area=n_area,
    t0=t0,
    C=C,
    n_beta=n_beta,  # use the shared constant instead of a repeated literal
    loglik_out=14,
    initial_a=initial_a,
    initial_beta0=initial_beta0,
    initial_beta=initial_beta,
    initial_rho=initial_rho,
    initial_tau0=initial_tau0,
)

# Train the model from the initial values, using Adam together with a
# scheduler that halves the learning rate when the monitored quantity stops
# decreasing.
opt = optim.Adam(stmodel.parameters(), lr=0.05)
# NOTE(review): `verbose` is deprecated in recent PyTorch releases -- confirm
# against the installed version before upgrading.
sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt,
    mode='min',
    factor=0.5,
    patience=1,
    threshold=0.0001,
    threshold_mode='rel',
    cooldown=5,
    min_lr=0,
    eps=1e-03,
    verbose=True,
)
num_epochs = 1000
# lambd: penalty weight for the parameters; lambd_a: penalty weight for a(t).
model_train = train_STmodel(
    num_epochs,
    stmodel,
    opt,
    sch,
    dataset,
    penalty="L1",
    lambd=0.8,
    lambd_a=0,
)