library(ccpdmp)
## Base R functions for running and post-processing the tempered Zig-Zag sampler

## Generate evenly spaced discrete samples from a Zig-Zag skeleton.
##
## positions, theta: event positions and velocities, one column per event.
##   A plain vector is treated as a single coordinate (1 x n matrix).
## times: event times (one per column).
## nsample: number of samples to draw.
## burn: index of the first event kept; time is re-zeroed at that event.
## Samples are taken at times 2*dt, 3*dt, ... with dt = Tmax / (nsample + 2),
## by moving linearly from the start of the segment containing each time.
## Returns list(samples = d x nsample matrix, sample_times = length-nsample
## vector); slots that could not be filled are left at 0.
gen_samples <- function(positions, times, theta,
                        nsample = 10^3, burn = 1){

  # Vector input means a single coordinate. theta must be reshaped as well,
  # otherwise the two-index subsetting below fails.
  if (is.null(dim(positions))) positions <- matrix(positions, nrow = 1)
  if (is.null(dim(theta))) theta <- matrix(theta, nrow = 1)

  kept <- burn:length(times)
  positions <- positions[, kept, drop = FALSE]
  theta <- theta[, kept, drop = FALSE]
  times <- times[kept] - times[burn]
  nsteps <- length(times)

  Tmax <- times[nsteps]
  dt <- Tmax / (nsample + 2)
  t <- dt
  t0 <- times[1]
  x0 <- positions[, 1]
  samples <- matrix(0, nrow = length(x0), ncol = nsample)
  sample_times <- rep(0, nsample)
  n <- 0

  # seq_len(nsteps)[-1] is empty (not c(2, 1)) when only one event is kept.
  for (i in seq_len(nsteps)[-1]) {
    x1 <- positions[, i]
    t1 <- times[i]
    theta0 <- theta[, i - 1]
    # Emit every grid time strictly inside the current segment.
    while (t + dt < t1 && n < nsample) {
      n <- n + 1
      t <- t + dt
      samples[, n] <- x0 + (t - t0) * theta0
      sample_times[n] <- t
    }
    x0 <- x1
    t0 <- t1
  }
  list(samples = samples, sample_times = sample_times)
}

## Calculate the marginal mean of each position coordinate from the pdmp
## trajectory, using only the segments where the inverse temperature (last
## row of positions) sits at b = 1 with zero velocity.
##
## pdmp: list with positions, thetas (one row per coordinate, last row = b)
##   and event times.
## burnin: index of the first segment used.
## eps: tolerance for "b == 1". The sampler detects the boundary within a
##   tolerance, so an exact comparison can miss spike segments.
## Returns a length-(nrow(positions) - 1) vector. It is all zeros when no
## segment qualifies.
path_marginal_mean <- function(pdmp, burnin = 1, eps = 1e-10){
  positions <- pdmp$positions
  thetas <- pdmp$thetas
  times <- pdmp$times
  d <- nrow(positions) - 1
  n_events <- length(times)
  marg_mean <- rep(0, d)

  # Segment i runs from times[i] to times[i + 1]; there are n_events - 1.
  if (burnin > n_events - 1) {
    return(marg_mean)
  }
  segs <- seq.int(burnin, n_events - 1)
  on_spike <- abs(positions[d + 1, segs] - 1) < eps & thetas[d + 1, segs] == 0
  segs <- segs[which(on_spike)]
  if (length(segs) == 0) {
    return(marg_mean)
  }

  tau <- times[segs + 1] - times[segs]
  x <- positions[seq_len(d), segs, drop = FALSE]
  v <- thetas[seq_len(d), segs, drop = FALSE]
  # Exact integral of x + v * s over each segment of length tau.
  integral <- x %*% tau + v %*% (tau^2 / 2)
  as.vector(integral) / sum(tau)
}
## Calculate the marginal second moment E[x_i^2] of each position coordinate
## from the pdmp trajectory, using only the segments where the inverse
## temperature (last row of positions) sits at b = 1 with zero velocity.
##
## pdmp: list with positions, thetas (one row per coordinate, last row = b)
##   and event times.
## burnin: index of the first segment used.
## eps: tolerance for "b == 1". The sampler detects the boundary within a
##   tolerance, so an exact comparison can miss spike segments.
## Returns a length-(nrow(positions) - 1) vector. It is all zeros when no
## segment qualifies.
path_marginal_moment2 <- function(pdmp, burnin = 1, eps = 1e-10){
  positions <- pdmp$positions
  thetas <- pdmp$thetas
  times <- pdmp$times
  d <- nrow(positions) - 1
  n_events <- length(times)
  marg_mom2 <- rep(0, d)

  # Segment i runs from times[i] to times[i + 1]; there are n_events - 1.
  if (burnin > n_events - 1) {
    return(marg_mom2)
  }
  segs <- seq.int(burnin, n_events - 1)
  on_spike <- abs(positions[d + 1, segs] - 1) < eps & thetas[d + 1, segs] == 0
  segs <- segs[which(on_spike)]
  if (length(segs) == 0) {
    return(marg_mom2)
  }

  tau <- times[segs + 1] - times[segs]
  x <- positions[seq_len(d), segs, drop = FALSE]
  v <- thetas[seq_len(d), segs, drop = FALSE]
  # Exact integral of (x + v * s)^2 over each segment of length tau.
  # (x * v) is parenthesised because %*% binds tighter than *.
  integral <- x^2 %*% tau + (x * v) %*% tau^2 + v^2 %*% (tau^3 / 3)
  as.vector(integral) / sum(tau)
}

## Integrand for the path-quadrature estimate of log z(b).
## xt = c(x, t): positions followed by the inverse temperature. Only x is
## used. Returns log q(x) - log q_0(x), where target() and temper() are
## assumed to be defined globally and to return lists with a log_q element.
phi <- function(xt){
  pos <- xt[-length(xt)]
  log_q_target <- target(pos)$log_q
  log_q_temper <- temper(pos)$log_q
  log_q_target - log_q_temper
}
## Trapezoidal path-quadrature estimate of log z along the inverse
## temperature.
## t: temperatures at which the integrand u was evaluated. Only t strictly
##   inside (t_lower, t_upper) is used, and repeated t values are averaged.
##   The integrand is held constant beyond the first and last grid points.
## Returns list(t, log_z): the grid padded with 0, t_lower, t_upper and 1,
## and the cumulative integral at each grid point.
## When no t lies inside the interval, returns t = c(0, 1), log_z = c(0, 0).
path_quad <- function(t, u, t_lower=0, t_upper=1){
  inside <- (t > t_lower) & (t < t_upper)
  if (sum(inside) == 0) {
    return(list(t=c(0, 1), log_z=c(0, 0)))
  }

  grid <- sort(unique(t[inside]))
  n_grid <- length(grid)
  u_mean <- vapply(grid, function(g) mean(u[t == g]), numeric(1))

  widths <- diff(c(t_lower, grid, t_upper))
  heights <- (c(u_mean, u_mean[n_grid]) + c(u_mean[1], u_mean)) / 2
  log_z <- c(0, cumsum(widths * heights))

  list(t = c(0, t_lower, grid, t_upper, 1),
       log_z = c(log_z[1], log_z, log_z[length(log_z)]))
}

## Gradient of -log pi(x, t), where
##   -log pi(x, t) = -t log q(x) - (1 - t) log q_0(x) + P(t)
## and d_poly_coef holds the coefficients of P'(t), highest degree first.
## target() and temper() are globals assumed to return lists with log_q and
## d_log_q. Returns c(gradient in x, derivative in t).
d_nlogq <- function(x, t, d_poly_coef){
  dens <- target(x)
  dens_0 <- temper(x)

  grad_x <- - t*dens$d_log_q - (1-t)*dens_0$d_log_q
  grad_t <- pracma::polyval(d_poly_coef, t) - dens$log_q + dens_0$log_q

  c(grad_x, grad_t)
}

## Evaluate the polynomial upper bound on every coordinate's switching rate
## at the look-ahead times tau_grid.
## Coordinates 1..(nx - 1) of x are positions; coordinate nx is the inverse
## temperature. Returns an nx x length(tau_grid) matrix, one row per
## coordinate.
## NOTE(review): reads a global `hess`. Row 2 is paired with the target q and
## row 1 with the tempering density q_0 below; confirm that these are the
## per-coordinate curvature bounds where `hess` is defined.
return_rates_hess <- function(x, theta, tau_grid, d_poly_coef){
  n_grid <- length(tau_grid)
  nx <- length(x)
  rates_eval <- matrix(0, nx, n_grid)

  x_pos <- x[-nx]
  v_pos <- theta[-nx]
  inv_temp <- x[nx]
  v_temp <- theta[nx]

  q <- target(x_pos)
  q_0 <- temper(x_pos)

  ## Position coordinates: bound k0 + k1*tau + k2*tau^2
  slope_q <- -q$d_log_q*v_pos
  slope_q_0 <- -q_0$d_log_q*v_pos
  curv_q <- hess[2,]*abs(v_pos)
  curv_q_0 <- hess[1,]*abs(v_pos)
  k0 <- inv_temp*slope_q + (1-inv_temp)*slope_q_0
  k1 <- v_temp*(slope_q - slope_q_0) + curv_q*inv_temp + (1-inv_temp)*curv_q_0
  k2 <- v_temp*(curv_q - curv_q_0)
  coef_pos <- rbind(k2, k1, k0)

  ## Inverse temperature: quadratic bound, highest degree first
  lvl_q <- -v_temp*q$log_q
  lvl_q_0 <- v_temp*q_0$log_q
  drift_q <- -v_temp*sum(v_pos*q$d_log_q)
  drift_q_0 <- v_temp*sum(v_pos*q_0$d_log_q)
  bend_q <- 0.5*sum(hess[2,]*(v_pos)^2)*abs(v_temp)
  bend_q_0 <- 0.5*sum(hess[1,]*(v_pos)^2)*abs(v_temp)
  coef_temp <- c(bend_q + bend_q_0, drift_q + drift_q_0, lvl_q + lvl_q_0)

  for (k in 1:(nx-1)) {
    rates_eval[k, ] <- pracma::polyval(coef_pos[, k], tau_grid)
  }
  ## The tempering polynomial's derivative is evaluated along b's trajectory
  rates_eval[nx, ] <- pracma::polyval(coef_temp, tau_grid) +
    v_temp*pracma::polyval(d_poly_coef, inv_temp + tau_grid*v_temp)
  rates_eval
}

## Time until an inverse temperature at x, moving with velocity theta,
## reaches a boundary: 1 when theta > 0 and 0 when theta < 0.
## Returns Inf when theta == 0, because a frozen temperature never arrives.
time_to_bndry <- function(x, theta){
  if (theta == 0) return(Inf)
  if (theta < 0) return(-x/theta)
  (1-x)/theta
}

## Tempered Zig-Zag sampler on the extended state (x, b).
## The first length(x0) - 1 coordinates are positions; the last is the
## inverse temperature b, which moves in [0, 1], bounces at 0, and may stick
## to the spike at b = 1.
##
## max_events   number of event columns allocated for the output
## x0, theta0   initial state and velocities (last entry belongs to b)
## alpha        b leaves the spike at rate (1 - alpha) / (2 * alpha), further
##              multiplied by |theta0_b| when that is non-zero
## tau_max      maximum look-ahead horizon for the rate bounds
## return_rates function(x, theta, tau_grid, d_poly_coef) returning a matrix
##              of upper-bounding rates, one row per coordinate
## poly_order   polynomial order passed to ccpdmp::sim_rate_poly
## echo         print the event count every 100 events
## poly_coef    polynomial in b (highest degree first); its derivative enters
##              the rate for b
## nits_max     cap on the iteration counter
## max_stochT   stop once the most recently stored event is after this time
##
## Returns list(positions, thetas (nvel x num_events matrices), times, nits,
## poly_coef, alpha). nits counts loop iterations, excluding position
## proposals that needed no gradient evaluation.
zigzag_temp <- function(max_events, x0, theta0 = c(rep(1, length(x0)-1), 0.1),
                        alpha = 0.2, tau_max = 1,
                        return_rates = return_rates_hess,
                        poly_order = 2, echo = FALSE,
                        poly_coef = c(0,0), nits_max = Inf, max_stochT = Inf){

  poly_order_temp <- length(poly_coef)

  # Derivative coefficients, i.e. sum(i*a_i*t^{i-1}), highest degree first
  d_poly_coef <- c(poly_order_temp:1 - 1)*poly_coef
  d_poly_coef <- d_poly_coef[-poly_order_temp]

  ## Init
  t <- 0; eps <- 1e-10; nits <- 0
  x <- x0; theta <- theta0; nvel <- length(x)

  thetas <- positions <- matrix(0, nrow = nvel, ncol = max_events)
  times <- rep(0, max_events); thetas[, 1] <- theta; positions[, 1] <- x

  num_evts <- 1
  event <- FALSE

  ## Rate at which b jumps away from the spike at b = 1
  if (abs(theta0[nvel]) < 1e-10) {
    rate_jump_from_spike <- (1-alpha)/(2*alpha)
  } else {
    rate_jump_from_spike <- (1-alpha)/(2*alpha)*abs(theta0[nvel])
  }

  ## Draw fresh candidate event times for the coordinates in idx from the
  ## polynomial upper bound on their rates. The horizon is at most tau_max
  ## and stops where b would reach 0 or 1.
  ## Returns the updated taus, u_s and f_s vectors.
  draw_clocks <- function(x, theta, idx, taus, u_s, f_s) {
    temp_to_bndry <- time_to_bndry(x[nvel], theta[nvel])
    tau_grid <- seq(0, to = min(tau_max, temp_to_bndry),
                    length.out = poly_order + 1)
    rates <- return_rates(x, theta, tau_grid, d_poly_coef)
    for (i in idx) {
      tus <- sim_rate_poly(eval_times = tau_grid, eval_rates = rates[i, ],
                           poly_order)
      taus[i] <- tus$t
      u_s[i] <- tus$u
      f_s[i] <- tus$f_evall
    }
    list(taus = taus, u_s = u_s, f_s = f_s)
  }

  # Simulate times
  taus <- rep(Inf, nvel); u_s <- rexp(nvel); f_s <- rep(Inf, nvel)
  clocks <- draw_clocks(x, theta, seq_len(nvel), taus, u_s, f_s)
  taus <- clocks$taus; u_s <- clocks$u_s; f_s <- clocks$f_s

  # b is on the spike exactly when its velocity is zero
  sampling_spike <- abs(theta[nvel]) < eps
  if (sampling_spike) {
    taus[nvel] <- rexp(1)/rate_jump_from_spike
  }

  while (num_evts < max_events && nits < nits_max) {

    mini_x <- which.min(taus)
    tau <- taus[mini_x]

    ## Did b's clock fire (boundary hit, spike release or candidate event)?
    update_temp <- abs(tau - taus[nvel]) < eps

    x <- x + tau*theta
    t <- t + tau

    if (update_temp) {
      b_hit_1 <- abs(x[nvel]-1) < eps
      b_hit_0 <- abs(x[nvel]) < eps

      ## Bounce off b = 0. Snap to the boundary so float drift is not stored.
      if (b_hit_0) {
        x[nvel] <- 0
        theta[nvel] <- abs(theta0[nvel])
        event <- TRUE
      }
      ## At b = 1, either stick to the spike or leave it
      if (b_hit_1) {
        x[nvel] <- 1
        if (sampling_spike) {
          theta[nvel] <- -abs(theta0[nvel])
          sampling_spike <- FALSE
        } else {
          theta[nvel] <- 0
          sampling_spike <- TRUE
        }
        event <- TRUE
      }
      ## Candidate switching event for b: accept by thinning
      if (!b_hit_1 && !b_hit_0) {
        if (u_s[nvel] < 1e-10) {
          grad <- d_nlogq(x[-nvel], x[nvel], d_poly_coef)
          rate <- grad[nvel]*theta[nvel]

          acc_prb <- rate/f_s[nvel]
          if (acc_prb > 1.0001) {
            print(paste("Invalid thinning on inverse temp, thinning prob:",acc_prb))
          }
          if (runif(1) <= acc_prb) {
            theta[nvel] <- -theta[nvel]
            event <- TRUE
          }
        }
      }

    } else {
      ## Candidate switching event for a position coordinate
      if (u_s[mini_x] < 1e-10) {
        grad <- d_nlogq(x[-nvel], x[nvel], d_poly_coef)
        rate <- grad[mini_x]*theta[mini_x]

        acc_prb <- rate/f_s[mini_x]
        if (acc_prb > 1.0001) {
          print(paste("Invalid thinning on x, thinning prob:",acc_prb))
        }
        if (runif(1) <= acc_prb) {
          theta[mini_x] <- -theta[mini_x]
          event <- TRUE
        }
      } else {
        # No gradient was evaluated, so this iteration is not counted
        nits <- nits - 1
      }
    }

    if (event) {
      # Store event info
      num_evts <- num_evts + 1
      thetas[, num_evts] <- theta; positions[, num_evts] <- x
      times[num_evts] <- t

      event <- FALSE
      if (echo && (num_evts %% 100 == 0)) {
        print(num_evts)
      }

      # A velocity changed: redraw every clock
      clocks <- draw_clocks(x, theta, seq_len(nvel), taus, u_s, f_s)
      taus <- clocks$taus; u_s <- clocks$u_s; f_s <- clocks$f_s
    } else {
      # No event: shift the clocks and redraw only those that have expired
      taus <- taus - tau
      update_rates <- which(taus <= 0)
      clocks <- draw_clocks(x, theta, update_rates, taus, u_s, f_s)
      taus <- clocks$taus; u_s <- clocks$u_s; f_s <- clocks$f_s
    }

    # Time for b to leave the spike
    if (sampling_spike) {
      taus[nvel] <- rexp(1)/rate_jump_from_spike
    }

    if (times[num_evts] > max_stochT) {
      break
    }
    nits <- nits + 1
  }

  # drop = FALSE keeps matrices even with a single event or coordinate
  kept <- seq_len(num_evts)
  list(positions = positions[, kept, drop = FALSE],
       thetas = thetas[, kept, drop = FALSE],
       times = times[kept],
       nits = nits, poly_coef = poly_coef, alpha = alpha)
}

## State of a Zig-Zag skeleton zz at the given time.
## Finds the last event strictly before `time` and moves linearly from it.
## zz: list with positions, thetas (one column per event) and sorted times.
## Returns list(x = position, theta = velocity at that segment).
## Errors if `time` is not later than the first event time.
get_sample_t <- function(zz, time){
  before <- which(zz$times < time)
  if (length(before) == 0) {
    stop("`time` must be later than the first recorded event time.",
         call. = FALSE)
  }
  ind <- max(before)
  elapsed <- time - zz$times[ind]
  # Use the exact element name: `zz$theta` only worked via partial matching.
  velocity <- zz$thetas[, ind]
  x <- zz$positions[, ind] + elapsed*velocity
  list(x = x, theta = velocity)
}
## Indices adjacent to ind in a ladder of length(temps) temperatures.
## The ends have a single neighbour, and a single temperature has none
## (integer(0)-like empty result).
get_neighbours <- function(ind, temps){
  n_temps <- length(temps)
  stopifnot(ind >= 1, ind <= n_temps)
  candidates <- c(ind - 1, ind + 1)
  candidates[candidates >= 1 & candidates <= n_temps]
}
## Tempered log-density: temp * log q + (1 - temp) * log q_0, where
## logTT = c(log q, log q_0).
l_targ_temp <- function(temp, logTT){
  log_target <- logTT[1]
  log_tempering <- logTT[2]
  temp*log_target + (1-temp)*log_tempering
}

## Parallel tempering with Zig-Zag moves.
## Each iteration does two things:
##  (**) runs one chain per inverse temperature temp[j] with zigzag_temp for
##       stoch_time units, with b held fixed (velocity 0, alpha = 1);
##  (*)  makes n.temp - 1 swap proposals, either on the deterministic
##       even/odd schedule or between two uniformly chosen chains.
## Records the state of chain n.temp after each iteration.
## NOTE: this name masks stats::pt (Student-t distribution function) once
## the file is sourced.
pt <- function(xinit = c(0,0),
               temp,
               stoch_time = 1,
               Nit = 10, even_odd_kernel = FALSE) {

  print(Sys.time())
  t1 <- Sys.time()
  x.dim <- length(xinit)
  n.temp <- length(temp)
  n_swaps <- n.temp - 1
  out <- matrix(NA, nrow = Nit, ncol = x.dim)
  n_eval <- rep(0, Nit)

  # One row per temperature. matrix() keeps this orientation when
  # x.dim == 1, whereas t(sapply(...)) returned a 1 x n.temp matrix.
  all_chains <- matrix(xinit, nrow = n.temp, ncol = x.dim, byrow = TRUE)
  all_theta <- matrix(sample(c(1, -1), size = n.temp * x.dim, replace = TRUE),
                      nrow = n.temp, ncol = x.dim, byrow = TRUE)

  # Cached (log q, log q_0) at each chain's current position
  log_targ_temp <- matrix(0, nrow = n.temp, ncol = 2)
  n_grad_eval <- 0

  # Even/odd swap schedule: odd indices first, then even ones. The top
  # temperature is excluded because it has no upper neighbour. setdiff()
  # avoids the empty-negative-index pitfall when there are no even indices.
  even_ind <- 2 * seq_len(floor(n.temp / 2))
  odd_ind <- setdiff(seq_len(n.temp), even_ind)
  non_rev_seq <- setdiff(c(odd_ind, even_ind), n.temp)

  prop_swap <- matrix(NA_real_, nrow = Nit * n_swaps, ncol = 3)
  swap_row <- 0
  for (i in seq_len(Nit)) {

    ## Move (**)
    for (j in seq_len(n.temp)) {
      ## Run the ZigZag for stoch_time units.
      x <- all_chains[j, ]
      theta <- all_theta[j, ]
      runZZ <- zigzag_temp(max_events = 1e4,
                           x0 = c(x, temp[j]),
                           theta0 = c(theta, 0),
                           alpha = 1, tau_max = 1,
                           poly_order = 2, echo = FALSE,
                           poly_coef = rep(0, 2),
                           max_stochT = stoch_time)

      samp <- get_sample_t(runZZ, stoch_time)
      new_x <- samp$x[seq_len(x.dim)]
      all_chains[j, ] <- new_x
      all_theta[j, ] <- samp$theta[seq_len(x.dim)]
      n_grad_eval <- n_grad_eval + runZZ$nits

      ## Evaluate the log densities at the new position
      log_targ_temp[j, 1] <- target(new_x)$log_q
      log_targ_temp[j, 2] <- temper(new_x)$log_q
    }

    ## Move (*)
    for (j in seq_len(n_swaps)) {
      ## Propose a swap
      if (even_odd_kernel) {
        ind <- non_rev_seq[j]
        swap_n <- get_neighbours(ind, temp)
        swap_ind <- max(swap_n)
      } else {
        ind <- sample(1:n.temp, size = 1)
        swap_ind <- sample(1:n.temp, size = 1)
      }

      x <- all_chains[ind, ]; x_swap <- all_chains[swap_ind, ]
      theta <- all_theta[ind, ]; theta_swap <- all_theta[swap_ind, ]
      ltt_x <- log_targ_temp[ind, ]; ltt_swap <- log_targ_temp[swap_ind, ]

      l_prop <- l_targ_temp(temp[swap_ind], ltt_x) +
        l_targ_temp(temp[ind], ltt_swap)

      l_curr <- l_targ_temp(temp[ind], ltt_x) +
        l_targ_temp(temp[swap_ind], ltt_swap)
      acc <- 0
      if (runif(1) < exp(l_prop - l_curr)) {
        all_chains[ind, ] <- x_swap
        all_chains[swap_ind, ] <- x

        all_theta[ind, ] <- theta_swap
        all_theta[swap_ind, ] <- theta

        log_targ_temp[ind, ] <- ltt_swap
        log_targ_temp[swap_ind, ] <- ltt_x
        acc <- 1
      }
      swap_row <- swap_row + 1
      prop_swap[swap_row, ] <- c(ind, swap_ind, acc)
    }
    out[i, ] <- all_chains[n.temp, ]
    n_eval[i] <- n_grad_eval
  }
  t2 <- Sys.time()
  print(t2)
  list(out = out, n_eval = n_eval, n_grad_eval = n_grad_eval,
       prop_swap = prop_swap, t1 = t1, t2 = t2)
}
