rm(list=ls())
library(data.table)
library(mvtnorm)
library(zoo)
library(lubridate)
library(tidyverse)
library(ggplot2)
library(ggsci)
library(ggjoy)
library(devtools)
library(gamlss.mx)
# GitHub-only dependencies: install each one on first use, then attach it.
# (The previous deeptrafo branch installed and attached mixdistreg instead,
# so `deeptrafo()` was unavailable on a fresh machine.)
github_pkgs <- c(
  deepregression = "neural-structured-additive-learning/deepregression",
  mixdistreg = "neural-structured-additive-learning/mixdistreg",
  deeptrafo = "neural-structured-additive-learning/deeptrafo"
)
for (pkg in names(github_pkgs)) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    install_github(github_pkgs[[pkg]])
  }
  library(pkg, character.only = TRUE)
}
load_all("../../neat/R")

### Begin Sim

# Start the simulation from an empty workspace and a fixed RNG state so the
# simulated data below are reproducible.
rm(list = ls(all.names = FALSE))
set.seed(seed = 1L)

# The distribution is only evaluated on a coarse grid, so approximate the
# median by the grid point whose value in `x` lies closest to 0.5.
# NOTE(review): assumes `x` holds CDF values evaluated at the grid points `y`
# (same length, same order) -- confirm against the caller.
# Returns `y` at the first such position (NAs in `x` are skipped).
closest_to_median <- function(x, y) {
  distance_to_half <- abs(x - 0.5)
  y[which.min(distance_to_half)]
}

# Simulate an AR(1) series whose mean is shifted by a regime x_t drawn
# uniformly from {-shift, +shift} at every step:
#   y_t ~ N(ar_coef * y_{t-1} + x_t, 1),   y_1 = 0.
# Only `time` (not x_t) is passed to the models below, so y | time is
# approximately a two-component mixture.
n <- 1003        # length of the simulated series
n_obs <- 1000    # number of leading observations kept as data
shift <- 2
ar_coef <- 0.1
# ar_coefs <- seq(0.1,0.9, by = 0.1)
# shift_eff <- seq(0, 10, by = 1)

x_series <- sample(c(-shift, shift), n, replace = TRUE)
table(x_series)

y_series <- numeric(n)  # all zeros, so the start value y_series[1] is 0
for (t_idx in 2:n) {
  y_series[t_idx] <- rnorm(1, y_series[t_idx - 1] * ar_coef + x_series[t_idx])
}

hist(y_series, breaks = 50)
data <- data.frame(value = y_series[seq_len(n_obs)], time = seq_len(n_obs))

# Permute the rows so that any trailing-fraction validation split used during
# fitting is not simply the last block of time points.
data <- data[sample(seq_len(nrow(data))), ]

### End Sim

# Split dataset: random 80/20 train/test split of the rows of `data`.
set.seed(32)

idx_train <- sample(seq_len(nrow(data)), nrow(data) * 0.8, replace = FALSE)
train_data <- data[idx_train, ]
test_data <- data[-idx_train, ]

# One-column matrices (column names "time" and "value") as consumed by the
# model-fitting and scoring calls below.
X_train <- train_data %>% dplyr::select(time) %>% as.matrix()
y_train <- train_data %>% dplyr::select(value) %>% as.matrix()
X_test <- test_data %>% dplyr::select(time) %>% as.matrix()
y_test <- test_data %>% dplyr::select(value) %>% as.matrix()

# Mixture model (mixdistreg::sammer) with `comps` normal components and
# intercept-only mixture weights (formula_mixture = ~ 1).
# NOTE(review): the two formulas presumably map to the two parameters of the
# normal family (location, scale), each with a smooth effect of time --
# confirm against the sammer documentation.
comps <- 2

mod <- sammer(
  y = y_train,
  list_of_formulas = rep(list(~ 1 + s(time)), 2),
  formula_mixture = ~ 1,
  family = "normal",
  data = data.frame(time = X_train),
  nr_comps = comps
)

# Reuse cached mixture-model weights when present; otherwise fit with early
# stopping (patience 50) on a 10% validation split and cache the weights.
# NOTE(review): the cache is not keyed on the model specification (e.g.
# `comps`); delete weights_mm.hdf5 after changing the model.
if (file.exists("weights_mm.hdf5")) {
  mod$load_weights(filepath = "weights_mm.hdf5", by_name = FALSE)
} else {
  mod %>% fit(
    epochs = 10000,
    validation_split = 0.1,
    view_metrics = FALSE,
    verbose = TRUE,
    early_stopping = TRUE,
    patience = 50
  )
  save_model_weights_hdf5(mod, filepath = "weights_mm.hdf5")
}

# Test-set score of the mixture model: the negated mean of `log_score` over
# the test observations (printed via the surrounding parentheses).
# NOTE(review): assumes log_score returns per-observation log-densities, so
# this is a mean negative log-likelihood (lower is better) -- confirm.
(ls_mm <- -mean(log_score(mod,
                          data = data.frame(time = X_test),
                          this_y = y_test)))

# Evaluate every mixture component's density on a (time, y) grid:
# 20 time points spanning the test range x 100 y values between the 1% and
# 99% quantiles of the test response.
plotdf <- expand.grid(
  time = seq(min(X_test), max(X_test), length.out = 20),
  y = seq(quantile(y_test, 0.01), quantile(y_test, 0.99), length.out = 100)
)

dist <- mod %>% get_distribution(data = plotdf)

# Each component is evaluated at the same grid y values; the grid size is
# taken from plotdf instead of being hard-coded (previously 2000 = 20 * 100).
# NOTE(review): assumes submodules[[1]] holds the component distributions
# with batch shape (nrow(plotdf), 1, comps) -- confirm.
pdfs <- as.matrix(tf$squeeze(dist$submodules[[1]]$prob(
  value = array(rep(plotdf$y, comps), dim = c(nrow(plotdf), 1, comps))
)))
# c(pdfs) is column-major: all grid points of component 1, then component 2...
dfpdf <- cbind(plotdf,
               dens_y = c(pdfs),
               comp = rep(seq_len(comps), each = nrow(plotdf)))
dfpdf$comp <- factor(dfpdf$comp, labels = paste0("Mix. Comp. ", seq_len(comps)))

# Ridge ("joy") plot of each mixture component's density over time, one facet
# per component. Every ridge is shifted right by time / 1000 * 7 and its
# height scaled by 7.
gg1 <- ggplot() +
  geom_joy(
    mapping = aes(x = y + time / 1000 * 7, y = time, height = dens_y * 7,
                  group = time, fill = time * 0.1),
    data = mutate(dfpdf,
                  comp = factor(comp, levels = levels(comp),
                                labels = paste0("Mixture ", seq_len(comps)))),
    stat = "identity",
    alpha = 0.7,
    colour = rgb(0, 0, 0, 0.4)
  ) +
  facet_wrap(~ comp, ncol = 3, strip.position = "bottom") +
  labs(x = "", y = "time") +
  theme_bw() +
  theme(
    text = element_text(size = 17.5),
    strip.text = element_text(size = 15.5),
    strip.background = element_blank(),
    panel.border = element_blank(),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    # panel.spacing = unit(-1.5, "lines"),
    axis.text.x = element_blank(),
    axis.ticks.x = element_blank(),
    legend.position = "none",
    rect = element_rect(fill = "transparent")
  )

gg1

# Feature network used for `snam(time)` on the right-hand side of the
# deeptrafo formula below: two 64-unit ReLU hidden layers and one linear
# output unit.
feature_net <- function(x) {
  hidden <- layer_dense(x, units = 64, activation = "relu", use_bias = TRUE)
  hidden <- layer_dense(hidden, units = 64, activation = "relu", use_bias = TRUE)
  layer_dense(hidden, units = 1, use_bias = TRUE)
}

# Feature network used for `snamIA(time)` (left of `|` in the deeptrafo
# formula below): two 64-unit ReLU hidden layers, and the single output unit
# also uses ReLU, so its output is non-negative.
feature_net_pos <- function(x) {
  h <- x
  for (i in 1:2) {
    h <- layer_dense(h, units = 64, activation = "relu", use_bias = TRUE)
  }
  layer_dense(h, units = 1, activation = "relu", use_bias = TRUE)
}

# Specify (not yet fit) the NEAT transformation model with deeptrafo: the
# deep terms `snamIA(time)` and `snam(time)` are built by the networks given
# in `list_of_deep_models`.
# NOTE(review): `order = 30` is presumably the order of deeptrafo's response
# basis -- confirm against the deeptrafo documentation.
mod_neat_ls <- deeptrafo(
  formula = y | snamIA(time) ~ 1 + snam(time),
  data = data.frame(y = y_train[, 1], time = X_train[, 1]),
  list_of_deep_models = list(snam = feature_net, snamIA = feature_net_pos),
  order = 30
)

# Reuse cached NEAT weights when present; otherwise fit (batch size 32, up to
# 1000 epochs, early stopping with patience 50 on a 10% validation split) and
# cache the weights.
# NOTE(review): the cache is not keyed on the model specification; delete
# weights_neat.hdf5 after changing the formula, the networks or `order`.
if (file.exists("weights_neat.hdf5")) {
  mod_neat_ls$load_weights(filepath = "weights_neat.hdf5", by_name = FALSE)
} else {
  mod_neat_ls %>% fit(
    batch_size = 32L,
    epochs = 1000L,
    validation_split = 0.1,
    early_stopping = TRUE,
    patience = 50,
    view_metrics = FALSE,
    verbose = TRUE
  )
  save_model_weights_hdf5(mod_neat_ls, filepath = "weights_neat.hdf5")
}

# Test-set score of the NEAT model: the negated test log-likelihood divided by
# the number of test observations (printed via the surrounding parentheses).
# NOTE(review): assumes logLik returns the summed log-likelihood over newdata,
# making this a mean negative log-likelihood like `ls_mm` -- confirm.
(ls_neat <- -logLik(mod_neat_ls,
                    newdata = data.frame(time = X_test[, 1],
                                         y = y_test[, 1])) / nrow(X_test))

# Scatter of the NEAT model's test-set predictions against the observed
# responses, coloured by time.
# NOTE(review): predict() is called without `type`, so deeptrafo's default
# prediction type is plotted -- confirm it matches the y-axis label.
pred_neat_ls <- mod_neat_ls %>% predict(newdata = data.frame(time = X_test[, 1],
                                                             y = y_test[, 1]))
stopifnot(length(pred_neat_ls) == nrow(X_test))
# Pass plain vectors: data.frame() names a one-column matrix argument after
# the matrix's own column name ("value"/"time"), so `y_test = y_test` would
# not create a `y_test` column and aes(x = y_test) would silently fall back to
# the global matrix.
df <- data.frame(y_test = y_test[, 1],
                 y_pred = as.vector(pred_neat_ls),
                 time = X_test[, 1])
gg2 <- ggplot(data = df, aes(x = y_test, y = y_pred)) +
  geom_point(aes(color = time), alpha = 0.9, size = 2) +
  theme_bw() + theme(text = element_text(size = 15.5)) +
  xlab("True value") +
  ylab("Inverse flow values") +
  guides(alpha = "none") + labs(color = "Time")

gg2

# ggsave(file = "yeast_res.pdf", width = 5, height = 4)
# ggsave(file = "yeast_res_2.pdf", width = 8, height = 7)

## pdf

# Estimated conditional density of the NEAT model. Only the 20 plotted time
# slices are evaluated: the same time grid as the mixture plot, with 1000 y
# values between the 0.1% and 99.9% test quantiles. The previous version
# predicted on 1000 x 1000 points and then kept 20 slices by float-matching
# interpolated quantiles with %in%.
plotdf <- expand.grid(
  time = seq(min(X_test), max(X_test), length.out = 20),
  y = seq(quantile(y_test, 0.001), quantile(y_test, 0.999), length.out = 1000)
)

pdf <- mod_neat_ls %>% predict(newdata = plotdf, type = "pdf")
dfpdf_neat <- data.frame(time = plotdf$time,
                         dens_y = pdf[, 1],
                         y = plotdf$y)

# `dummy` is a single-level facet variable whose only purpose is the strip
# label; the trailing spaces in it are kept as-is (presumably layout padding).
gg3 <- ggplot() +
  geom_joy(data = dfpdf_neat %>%
             mutate(dummy = "DRIFT              ") %>%
             filter(time > 1),
           aes(height = dens_y * 7, x = y + time / 1000 * 7, y = time,
               group = time, fill = time),
           stat = "identity", alpha = 0.7, colour = rgb(0, 0, 0, 0.4)) +
  theme_bw() + facet_grid(~ dummy, switch = "both") + ylab("Time") + xlab("") +
  theme(panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        # panel.spacing = unit(-1.5, "lines"),
        strip.background = element_blank(),
        panel.border = element_blank(),
        text = element_text(size = 17.5),
        strip.text = element_text(size = 15.5),
        legend.position = "none",
        axis.text.x = element_blank(),
        axis.ticks.x = element_blank(),
        rect = element_rect(fill = "transparent"),
        axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank()) +
  scale_y_continuous(breaks = c(0, 250, 500, 750, 1000)) +
  scale_x_continuous(expand = c(0, 0))

gg3

library(grid)
library(gridExtra)

# Side-by-side figure: mixture components (gg1, 1.5x width) next to the NEAT
# density (gg3), under a shared title; drawn and then saved to PDF.
g <- grid.arrange(
  gg1 + theme(plot.margin = unit(c(0, 1, -1.5, 0.2), "lines")),
  gg3 + theme(plot.margin = unit(c(0, 1, -1.5, 0.2), "lines")),
  widths = c(1.5, 1),
  # `cex` scales the title font. The previous `ex` is not a gpar() parameter,
  # so it had no effect.
  top = textGrob("estimated conditional density",
                 vjust = 1, gp = gpar(cex = 1.2))
)
# Name `filename` explicitly instead of relying on partial matching of `file`.
ggsave(filename = "atm_matrix_fig.pdf", plot = g, width = 6, height = 5)
