###############################################################
# Code for the paper:
# Dynamic COVID risk assessment accounting for community virus exposure from a spatial-temporal transmission model

###############################################################


import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torch.utils.data import Dataset, DataLoader, ConcatDataset

import os
import pandas as pd
import math
import sys



# Spatial-temporal transmission model
class STmodel(torch.nn.Module):
    """Spatial-temporal transmission model for area-level case counts.

    Learnable parameters (created in ``__init__``):
        beta0, beta : per-time intercept and covariate coefficient of the
            conditional mean of Z; ``_cal_theta`` reshapes each to
            ``(n_time, 1)``, so each must hold ``n_time`` values.
        rho         : weight of the spatial term in ``_cal_theta``
            (presumably a scalar -- TODO confirm).
        tau0        : per-time scale of Z; indexed along time in ``cal_loss``.
        loga        : log of the factor ``a``; indexed as ``loga[t, i]``
            (time, area).

    Calling the module (``forward``) rebuilds the latent state tensors
    ``N``, ``Y``, ``M``, ``E`` and ``Z``, each of shape ``(n_time, n_area)``,
    and stores them on ``self``; it returns ``None``.  ``cal_loss`` reads
    that state, so ``forward`` must run before every ``cal_loss`` call.
    """


    def __init__(self, n_time, n_area, t0, C, n_beta, loglik_out, initial_a, initial_beta0, initial_beta, initial_rho, initial_tau0): 
        """Store the problem dimensions and create the learnable parameters.

        Args:
            n_time: total number of time points (days), counted from the
                first day with N(t) = 1 in an area; days are aligned across
                all areas.
            n_area: number of areas.
            t0: per-area index of the first day with an infected case
                (``t0[i]`` is used as a time index for area ``i``).
            C: look-back window in days (maximum incubation period); also
                the truncation point of ``Surv``.
            n_beta: number of feature variables per area.  It is only stored;
                ``_cal_theta`` currently uses just feature 0.
            loglik_out: number of leading time points excluded from the
                latent-Z likelihood terms in ``cal_loss``.
            initial_a: initial value of ``a``.  It must be positive because
                the parameter is stored as ``log(a)``.
            initial_beta0, initial_beta, initial_rho, initial_tau0: initial
                values of the matching parameters, passed to ``torch.tensor``.
        """

        super().__init__()

        self.n_time = n_time
        self.t0 = t0
        self.C = C
        self.n_area = n_area
        self.n_beta = n_beta
        self.loglik_out = loglik_out


        ## parameters
        self.beta0 = torch.nn.Parameter(torch.tensor(initial_beta0))
        self.beta = torch.nn.Parameter(torch.tensor(initial_beta))
        self.rho = torch.nn.Parameter(torch.tensor(initial_rho))
        self.tau0 = torch.nn.Parameter(torch.tensor(initial_tau0))
        # `a` is optimised on the log scale, which keeps exp(loga) positive.
        self.loga = torch.nn.Parameter(torch.tensor(initial_a).log())


    def _cal_Mt_Yt(self, N, t, t0, C, Surv, surv_mean):
        """Compute Mt and Yt for one area at time index ``t``.

        A case that started on day ``s`` (with ``s`` in the window
        ``[max(t0, t - C), t]``, day ``t`` included) is weighted by
        ``Surv(t - s)`` in Mt.  In Yt it is weighted by
        ``Surv(t - s) - Surv(t - s + 1)``.  Mt can be read as the cases still
        infectious at ``t`` and Yt as those leaving on day ``t``; this reading
        is an interpretation of the formula.

        Args:
            N: 1-D series of new latent cases for this area (length n_time).
            t: current time index.
            t0: first time index with an infected case in this area.
            C: maximum look-back in days.
            Surv: survival function, called as ``Surv(days, surv_mean)``.
            surv_mean: forwarded to ``Surv``.

        Returns:
            dict with scalar tensors under keys ``'Mt'`` and ``'Yt'``.
        """

        ndays = t - t0

        if ndays < C:
            # Fewer than C days since the first case: use every day from t0.
            # The aranges run backwards so that N[t0 + k] pairs with
            # Surv(t - (t0 + k)).
            S = Surv(torch.arange(ndays, -1, step=-1).float(), surv_mean)
            S_1 = Surv(torch.arange(ndays+1, 0, step=-1).float(), surv_mean)
            Mt = (N[t0: (t+1)] * S).sum()  # including day t   
            Yt = (N[t0: (t+1)] * (S - S_1)).sum()
        else:
            # Full window: only the last C + 1 days (t - C .. t) contribute.
            S = Surv(torch.arange(C, -1, step=-1).float(), surv_mean)
            S_1 = Surv(torch.arange(C+1, 0, step=-1).float(), surv_mean)
            Mt = (N[(t-C) : (t+1)] * S).sum()
            Yt = (N[(t-C) : (t+1)] * (S - S_1)).sum()

        return {'Mt':Mt, 'Yt':Yt}


    def _udpate_all_surv(self, surv_mean):
        """Run the forward recursion that fills M, Y, N, E and Z in place.

        For each area ``i`` and each ``t`` from ``t0[i]`` to ``n_time - 1``:
          * ``M[t, i]`` and ``Y[t, i]`` come from ``_cal_Mt_Yt``.
          * If ``t`` is not the last day, the next day is derived from
            ``M[t, i] - Y[t, i]``:
              - ``E[t+1, i] = (M - Y) * mean_j exp(loga[t, j])``
              - ``N[t+1, i] = exp(loga[t, i]) * (M - Y)``
              - ``Z[t+1, i] = log(exp(loga[t+1, i]) / mean_j exp(loga[t, j]))``

        The ascending ``t`` order is required because ``N[t+1]`` is read by
        the window sum at later steps.  The name's typo ("udpate") is kept so
        that existing callers keep working.

        Args:
            surv_mean: mean infectious days of the exponential distribution
                used by ``Surv``.
        """

        for i in range(self.n_area):
            for t in range(self.t0[i], self.n_time):  # start differs for each area
                Mt_Yt = self._cal_Mt_Yt(self.N[:,i], t, self.t0[i], self.C, self.Surv, surv_mean)
                self.M[t, i] = Mt_Yt['Mt']
                self.Y[t, i] = Mt_Yt['Yt']
                if t < self.n_time - 1:
                    self.E[t+1, i] = (self.M[t, i] - self.Y[t, i]) * (self.loga[t,:].exp().mean())
                    self.N[t+1, i] = self.loga[t,i].exp() * (self.M[t,i] - self.Y[t,i])  # here used exp(Z); the code uses exp(loga[t, i])
                    # NOTE(review): the numerator uses loga[t+1] while the denominator uses loga[t] -- confirm this offset is intended.
                    self.Z[t+1, i] = torch.log(self.loga[t+1,i].exp()/torch.mean(self.loga[t,:].exp()))



    def Surv(self, t, mean=5.2):
        """Survival function of the infectious period.

        This is an exponential survival curve ``exp(-t / mean)`` truncated at
        ``C`` and rescaled, so that ``Surv(0) == 1`` and ``Surv(C) == 0``.  It
        is zero for ``t > C``.

        Args:
            t: float tensor of elapsed days.
            mean: mean (scale) of the untruncated exponential distribution.

        Returns:
            Tensor broadcast from ``t`` and a 1-element tensor, holding the
            survival values.
        """
        return ((torch.exp(-t / mean) - torch.exp(torch.tensor([-self.C / mean]))) / (1 - torch.exp(torch.tensor([-self.C / mean])))) * (t <= self.C)


    def _cal_Ht(self, t, neighbor):
        """Return the (n_area, n_area) weight matrix H at time index ``t``.

        ``H[i, j] = sqrt(E[t, j] / E[t, i]) * neighbor[i, j]``
        """
        Nt_mat = self.E[t,].expand(self.n_area, -1)             # [i, j] = E[t, j]
        Nt_mat_dev = (Nt_mat) / (Nt_mat).transpose(0, 1)         # [i, j] = E[t, j] / E[t, i]
        return torch.pow(Nt_mat_dev, 0.5) * neighbor 


    def _cal_theta(self, X, neighbor):
        """Compute the conditional mean ``theta`` of Z for every time and area.

        ``miu[t, i] = beta[t] * X[t, i, 0] + beta0[t]``; only feature 0 of X
        is used.  Then, for each t:
        ``theta[t, i] = miu[t, i] + rho * sum_j (Z[t, i] - miu[t, j]) * H[i, j] * neighbor[i, j]``

        Args:
            X: covariates of shape (n_time, n_area, n_feature).
            neighbor: (n_area, n_area) neighbour matrix.

        Returns:
            Tensor of shape (n_time, n_area).
        """

        # General multi-feature version (n_beta > 1), currently disabled:
        #miu = torch.zeros((self.n_time, self.n_area)) 
        #for i in range(self.n_area):
            #miu[:,i] = torch.sum(X[:,i,:] * self.beta, 1) + self.beta0   # for all time of this area
        miu = self.beta.reshape((self.n_time,1)) * X[:,:,0].reshape((self.n_time, self.n_area)) + self.beta0.reshape((self.n_time,1))


        theta = torch.zeros((self.n_time, self.n_area))
        for t in range(self.n_time):
            Ht = self._cal_Ht(t, neighbor)
            # NOTE(review): because of .transpose(0,1), entry [i, j] of the bracket is Z[t, i] - miu[t, j]
            # (the area's own Z minus the neighbour's mean).  A CAR-style mean would normally use
            # Z[t, j] - miu[t, j] -- confirm against the paper.  `neighbor` is also applied twice
            # (already inside Ht); that is harmless only if its entries are 0/1.
            theta[t,] = miu[t,] + ((self.Z[t,].expand(self.n_area, -1).transpose(0,1) - miu[t,]) * Ht * neighbor * self.rho).sum(1)            
        return theta


    def cal_loss(self, X, neighbor, Y_obs, penalty=None, lambd=None, lambd_a=None):
        """Return the training loss for the state built by the last ``forward``.

        Components:
          * first_term: mean of ``(sqrt(Y_obs) - sqrt(Y))**2`` over all
            times and areas.
          * second_term + second_term2: Gaussian negative log-likelihood, up
            to an additive constant, of ``Z`` around ``theta`` with variance
            ``tau0[t]**2 / E[t, i]``.  It is averaged over rows
            ``t >= loglik_out``.
          * If ``penalty == "L1"``, two smoothness penalties are added:
              - ``lambd_a`` times the sum over t of the mean absolute change
                of ``exp(loga)`` between days t and t+1, over all
                consecutive days.
              - ``lambd`` times the sum, over ``t >= loglik_out``, of
                ``0.25*|d beta| + 0.25*|d beta0| + |d tau0|``.
            Any other ``penalty`` value adds no penalty.

        Args:
            X: covariates, (n_time, n_area, n_feature).
            neighbor: (n_area, n_area) neighbour matrix.
            Y_obs: observed counts, (n_time, n_area).
            penalty: ``"L1"`` to enable the penalties, otherwise ignored.
            lambd, lambd_a: penalty weights.  They must be numbers when
                ``penalty == "L1"``; the ``None`` defaults would raise
                TypeError in that branch.

        Returns:
            Loss tensor: shape (1,) in the "L1" branch (the penalty
            accumulators are 1-element tensors), 0-d otherwise.
        """

        theta = self._cal_theta(X, neighbor)
        first_term = (torch.sqrt(Y_obs) - torch.sqrt(self.Y)).pow(2).mean()
        # (Z - theta)^2 / (2 * tau0[t]^2 / E[t, i]), averaged over rows t >= loglik_out
        second_term = ((self.Z - theta)[self.loglik_out:,].pow(2) / (self.tau0[self.loglik_out:].pow(2).expand(self.n_area, -1).transpose(0, 1)/self.E[self.loglik_out:])/ 2).mean()
        # log of the standard deviation: log|tau0[t]| - 0.5 * log E[t, i]
        second_term2 = self.tau0.expand(self.n_area, -1).transpose(0, 1)[self.loglik_out:,].abs().log().mean() - self.E[self.loglik_out:].log().mean()/2 

        if penalty=="L1":
            loss_l1 = nn.L1Loss()
            penalty_term2 = torch.tensor([0.])
            penalty_term = torch.tensor([0.])
            # Smoothness of exp(loga) over the excluded leading days ...
            for t in range(0, self.loglik_out):
                penalty_term = penalty_term + loss_l1(self.loga[t+1,].exp(), self.loga[t,].exp())

            # ... and over the remaining days, together with beta / beta0 / tau0 smoothness.
            for t in range(self.loglik_out, self.n_time-1):
                penalty_term = penalty_term + loss_l1(self.loga[t+1,].exp(), self.loga[t,].exp())
                penalty_term2 = penalty_term2 + 0.25*loss_l1(self.beta[t+1], self.beta[t]) + 0.25*loss_l1(self.beta0[t+1], self.beta0[t])  + loss_l1(self.tau0[t+1], self.tau0[t])

            loss_all = first_term + second_term + second_term2 + penalty_term2*lambd + penalty_term*lambd_a

        else:
            loss_all = first_term + second_term + second_term2 
        return loss_all      



    def forward(self, surv_mean=5.2):
        """Rebuild the latent state for the current parameter values.

        The method:
          * Allocates fresh (n_time, n_area) tensors: N, Y and M as zeros, E
            as ones, and Z as 0.01.
          * Seeds every area at its first-case day ``t0[i]`` with
            N = M = E = 1 and Z = 0.1.  ``M[t0[i], i]`` is recomputed by
            ``_udpate_all_surv``, whose loop starts at ``t0[i]``.
          * Runs the recursion in ``_udpate_all_surv``.

        The results are stored on ``self``; nothing is returned.

        Args:
            surv_mean: mean infectious days passed to ``Surv``.
        """

        ## N, Y, M: 2d tensor, time*area
        self.N = torch.zeros((self.n_time, self.n_area))
        self.Y = torch.zeros((self.n_time, self.n_area))
        self.M = torch.zeros((self.n_time, self.n_area))
        self.E = torch.ones((self.n_time, self.n_area))
        self.Z = torch.zeros((self.n_time, self.n_area))+ 0.01

        # assign the first latent case
        for i in range(self.n_area):
            self.N[self.t0[i],i] = 1
            self.M[self.t0[i],i] = 1
            self.E[self.t0[i],i] = 1 
            self.Z[self.t0[i],i] = 0.1

        self._udpate_all_surv(surv_mean)

# Wrap the input data (given as NumPy arrays) as float32 tensors.
class PrepareData(Dataset):
    """Hold the observed data as float32 tensors.

    Inputs are NumPy arrays:
        Y_obs    -- observed counts, (n_time, n_area)
        X        -- covariates, (n_time, n_area, n_feature)
        neighbor -- area neighbour matrix, (n_area, n_area)

    One dataset item corresponds to one area, taken across all time points.
    """

    def __init__(self, Y_obs, X, neighbor):
        def _to_float_tensor(array):
            # from_numpy shares memory; .float() copies only if the dtype differs.
            return torch.from_numpy(array).float()

        self.Y_obs = _to_float_tensor(Y_obs)
        self.X = _to_float_tensor(X)
        self.neighbor = _to_float_tensor(neighbor)

    def __len__(self):
        # Number of areas = number of rows of the neighbour matrix.
        return self.neighbor.shape[0]

    # Per-area access; the training loop reads the full tensors instead.
    def __getitem__(self, idx):
        area_counts = self.Y_obs[:, idx]
        area_features = self.X[:, idx, :]
        area_neighbors = self.neighbor[idx, :]
        return (area_counts, area_features, area_neighbors)

def train_STmodel(n_epochs, model, optimizer, scheduler, dataset, penalty="L1", lambd=1, lambd_a=0.1, surv_mean=5.2, log_path="loss_record.txt"):
    """Fit an ``STmodel`` by full-batch gradient descent.

    Each epoch does the following, in order:
      1. Re-runs the model's forward pass, which rebuilds its latent state.
      2. Computes the loss on the whole dataset.
      3. Back-propagates and takes one optimizer step.
      4. Clears the gradients.
      5. Advances the scheduler using the loss as its metric.
      6. Appends one line with the loss to ``log_path``.

    Args:
        n_epochs: number of epochs (one optimizer step each).
        model: called as ``model(surv_mean)``, then ``model.cal_loss(...)``.
        optimizer: a ``torch.optim`` optimizer over the model's parameters.
        scheduler: a scheduler whose ``step`` takes the loss as a metric
            (e.g. ``torch.optim.lr_scheduler.ReduceLROnPlateau``).
        dataset: object exposing ``X``, ``neighbor`` and ``Y_obs`` tensors
            (e.g. a ``PrepareData`` instance).
        penalty, lambd, lambd_a: forwarded to ``model.cal_loss``.
        surv_mean: mean infectious days forwarded to the forward pass.
        log_path: file to which one loss line per epoch is appended.
    """
    for epoch in range(n_epochs):
        model(surv_mean)
        loss = model.cal_loss(dataset.X, dataset.neighbor, dataset.Y_obs, penalty=penalty, lambd=lambd, lambd_a=lambd_a)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        scheduler.step(loss)

        # The file is reopened each epoch so the log stays current on disk during
        # long runs; `with` closes it (the original leaked one handle per epoch).
        # The total uses `n_epochs`; the original referenced an undefined `num_epochs`.
        with open(log_path, "a") as log_file:
            print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, n_epochs, loss.item()), file=log_file)




