## Read argument ----
## The first trailing command-line argument is the replicate index; it seeds
## the RNG and selects the replicate label below.
args <- commandArgs(TRUE)
print("begin")
print(args)
print("end")
# Parse as a plain integer rather than eval(parse()): the latter would
# execute any R expression supplied on the command line.
N_index <- suppressWarnings(as.integer(args[1]))
if (is.na(N_index)) {
  stop("First command-line argument must be an integer index, got: ",
       args[1], call. = FALSE)
}

library(rstan)

load("GMM/gmm_data.Rdata")

# Replicate labels; N_index (parsed from the command line) selects one.
# This deliberately rebinds `args`: the raw command-line arguments are no
# longer needed once N_index has been parsed.
args <- 1:20
if (is.na(N_index) || N_index < 1 || N_index > length(args)) {
  # Without this check an out-of-range index gives rep_1 = NA and the
  # results are silently saved under an "_iter_NA" file name.
  stop("Replicate index must be between 1 and ", length(args),
       ", got: ", N_index, call. = FALSE)
}
# Selects the sampler branch below (== 1 vs. anything else) and is part of
# the output file name.
alpha_1 <- 0
rep_1 <- args[N_index]
print(rep_1)

###### Config ####
iters <- 30000   # number of samples / number of events
warmup <- 20000  # number of samples/events to burn
poly_order <- 3

set.seed(N_index)
x_init <- rnorm(2, mean = 5, sd = 5)
theta_init <- c(1, 1)


## ct HMC
## Minimal 1-iteration, 1-chain fit; it is used only as a handle for
## grad_log_prob() below.
stan_fit_eval <- stan("GMM/ct_hmc_eval.stan", data = dat, iter = 1, chains = 1)

## Setup ct-ZZ
## Evaluate the Stan model at parameter vector c(x, beta) (unconstrained
## scale, per rstan::grad_log_prob). Returns list(log_q, d_log_q), where
## d_log_q is the gradient restricted to its first two components.
eval_stan_log_q <- function(x, beta) {
  stan_ev <- grad_log_prob(stan_fit_eval, c(x, beta))
  list(log_q = attr(stan_ev, "log_prob"),
       d_log_q = as.numeric(stan_ev)[1:2])
}

## Log density and gradient with the last model coordinate fixed at 1.
target <- function(x) {
  eval_stan_log_q(x, 1)
}

## Log density and gradient with the last model coordinate fixed at 0.
temper <- function(x) {
  eval_stan_log_q(x, 0)
}

source("temper_zigzag_hess.R")

## Squared half-range of each row of mu_mat: ((max - min) / 2)^2.
mean_diffs <- (apply(mu_mat, 1, max) - apply(mu_mat, 1, min))^2 / 4

## Per-coordinate bounds for the two densities; presumably Hessian bounds
## consumed by the sourced sampler (file name "temper_zigzag_hess") --
## confirm against temper_zigzag_hess.R.
C_1 <- 1 / dat$sigma[1]
hess_q1 <- abs(C_1 * (1 + C_1 * mean_diffs))
hess_q0 <- 1 / dat$sigma0
# Each row of the 2 x 2 matrix must come from exactly one source vector;
# matrix() would otherwise recycle/truncate silently.
stopifnot(length(hess_q0) == 2L, length(hess_q1) == 2L)
(hess <- abs(matrix(c(hess_q0,
                      hess_q1), 2, 2, byrow = TRUE)))
rownames(hess) <- c("q0", "q1")


## ct-zigzag

## Weights for the first `n` draws of a gen_samples() result.
## For each draw x (rows 1:2 of zz_samples$samples), d is the Stan log
## density at c(x, 0) minus the log density at c(x, 1), mapped to
## w = d / (exp(d) - 1).
## Returns list(w = raw weights, w_norm = weights normalised to sum to 1).
compute_tempering_weights <- function(zz_samples, n) {
  triangle <- vapply(seq_len(n), function(i) {
    x_i <- zz_samples$samples[1:2, i]
    g_1 <- grad_log_prob(stan_fit_eval, c(x_i, 1))
    g_0 <- grad_log_prob(stan_fit_eval, c(x_i, 0))
    attr(g_0, "log_prob") - attr(g_1, "log_prob")
  }, numeric(1))
  # expm1() keeps precision when |d| is small. The removable singularity at
  # d = 0 (limit 1) is filled in explicitly: plain d / (exp(d) - 1) is NaN
  # there, which would make every normalised weight NaN.
  w <- ifelse(triangle == 0, 1, triangle / expm1(triangle))
  list(w = w, w_norm = w / sum(w))
}

# log_bf is reset before every sampler call but never read in this script;
# presumably zigzag_temp() (from temper_zigzag_hess.R) reads or updates it
# as a global -- confirm before removing these resets.
log_bf <- 0
if (alpha_1 == 1) {
  ## Warm-up run
  log_bf <- 0
  set.seed(N_index)
  # NOTE(review): poly_coef has length 2 here but 3 in the other branch --
  # confirm which length zigzag_temp() expects.
  zigzag_fit <- zigzag_temp(max_events = warmup, x0 = c(x_init, 1),
                            theta0 = c(theta_init, 0),
                            alphas = c(0, 1), tau_max = 1,
                            poly_order = poly_order, echo = FALSE,
                            poly_coef = rep(0, 2),
                            nits_max = Inf)
  ## Main ctZZ run, restarted from the last warm-up positions/thetas
  x_init <- tail(t(zigzag_fit$positions), 1)
  theta_init <- tail(t(zigzag_fit$thetas), 1)
  log_bf <- 0
  set.seed(N_index)
  zigzag_fit_ada <- zigzag_temp(max_events = iters,
                                x0 = x_init,
                                theta0 = theta_init,
                                alphas = c(0, 1),
                                tau_max = 1,
                                poly_order = poly_order, echo = FALSE,
                                poly_coef = rep(0, 2),
                                nits_max = Inf)

  zigzag_samples <- gen_samples(zigzag_fit_ada$positions, zigzag_fit_ada$times,
                                nsample = iters, burn = 1)
  save(zigzag_fit_ada, zigzag_samples,
       file = paste0("GMM/gmm_alpha_", alpha_1, "_iter_", rep_1, ".Rdata"))
} else {
  ## Warm-up run
  log_bf <- 0
  set.seed(N_index)
  zigzag_fit <- zigzag_temp(max_events = warmup, x0 = c(x_init, .1),
                            theta0 = c(theta_init, 1),
                            alphas = c(0, 0), tau_max = 1,
                            poly_order = poly_order, echo = FALSE,
                            poly_coef = rep(0, 3),
                            nits_max = Inf)
  ## Main ctZZ run, restarted from the last warm-up positions/thetas
  x_init <- tail(t(zigzag_fit$positions), 1)
  theta_init <- tail(t(zigzag_fit$thetas), 1)
  log_bf <- 0
  set.seed(N_index)
  zigzag_fit_ada <- zigzag_temp(max_events = iters,
                                x0 = x_init,
                                theta0 = theta_init,
                                alphas = c(0, 0),
                                tau_max = 1,
                                poly_order = poly_order, echo = FALSE,
                                poly_coef = rep(0, 3),
                                nits_max = Inf)

  ## Weights from a `warmup`-sized set of draws ...
  zigzag_samples <- gen_samples(zigzag_fit_ada$positions, zigzag_fit_ada$times,
                                nsample = warmup, burn = 1)
  w_short <- compute_tempering_weights(zigzag_samples, warmup)
  w_1_zz <- w_short$w
  w_1_zz_norm <- w_short$w_norm

  ## ... and from an `iters`-sized set of draws ("l" suffix).
  zigzag_samplesl <- gen_samples(zigzag_fit_ada$positions, zigzag_fit_ada$times,
                                 nsample = iters, burn = 1)
  w_long <- compute_tempering_weights(zigzag_samplesl, iters)
  w_1_zzl <- w_long$w
  w_1_zz_norml <- w_long$w_norm

  save(zigzag_fit_ada, zigzag_samples, w_1_zz, w_1_zz_norm,
       zigzag_samplesl, w_1_zzl, w_1_zz_norml,
       file = paste0("GMM/gmm_alpha_", 0, "_iter_", rep_1, ".Rdata"))
}
