## Data and Stan model. The `dat` object passed to stan() below is expected
## to be provided by this .Rdata file.
load("gmm/gmm_data.Rdata")
library(rstan)
## ct HMC
# In the code shown here the resulting stanfit object is used only as the
# model handle for grad_log_prob().
# NOTE(review): iter = 0 presumably means "instantiate the model without
# sampling" - confirm this is accepted by the rstan version in use.
stan_fit_eval <- stan(file = "gmm/model_eval.stan", data = dat, iter = 0)

## Setup ct-ZZ
# Evaluate the Stan model at the unconstrained parameter vector c(x, beta)
# and return its log density together with the gradient w.r.t. `x` only.
#   x    - numeric vector of sampler coordinates (2-d in this script).
#   beta - value placed in the final model coordinate; target() uses 1 and
#          temper() uses 0 (presumably the inverse temperature - confirm
#          against model_eval.stan).
# Returns list(log_q = <log density>, d_log_q = <gradient w.r.t. x>).
eval_log_q <- function(x, beta) {
  stan_ev <- grad_log_prob(stan_fit_eval, c(x, beta))
  list(
    # grad_log_prob() attaches the log density as the "log_prob" attribute
    log_q = attr(stan_ev, "log_prob"),
    # keep only the components for x, dropping the trailing beta component
    d_log_q = as.numeric(stan_ev)[seq_along(x)]
  )
}

# Log density and gradient of the target (final coordinate fixed at 1).
target <- function(x) {
  eval_log_q(x, beta = 1)
}

# Log density and gradient of the tempered model (final coordinate fixed at 0).
temper <- function(x) {
  eval_log_q(x, beta = 0)
}

## Evaluate exp(target(x)$log_q) on a 300 x 300 grid over [-1, 12]^2, draw it
## with image(), and overlay the start of the alpha = 1 Zig-Zag trajectory.
np <- 3e2
x_v <- seq(from = -1, to = 12, length.out = np)
y_v <- seq(from = -1, to = 12, length.out = np)
xy <- expand.grid(x = x_v, y = y_v)
f <- function(x) {
  exp(target(x)$log_q)
}
# expand.grid() varies x fastest, so filling the matrix column-major gives
# z[i, j] = f(c(x_v[i], y_v[j])), with rows indexed by x as image() expects.
z <- matrix(apply(as.matrix(xy), 1, f), length(x_v), length(y_v))
par(mfrow = c(1, 2), mar = c(2, 2, 1, 2))
image(x_v, y_v, z, las = 1, xlab = "x1", ylab = "x2")
load("gmm/Cluser Code/ZZ/gmm_alpha_1_iter_1.Rdata")
## only take the first 30,000 iterations (fewer if the run is shorter);
## rows 1:2 of `positions` are the two plotted coordinates.
n_keep <- min(30000, ncol(zigzag_fit_ada$positions))
lines(t(zigzag_fit_ada$positions[1:2, seq_len(n_keep), drop = FALSE]),
      col = rgb(red = 0, green = 0, blue = 1, alpha = 0.5))
#legend('topleft', legend = c('Zig Zag'), col = rgb(red = 0, green = 0, blue = 1, alpha = 0.5), lwd = 1)

## Convert the target-density grid `z` to long format and draw it as a
## greyscale raster (`plot_1`), the base layer for the trajectory plots.
colnames(z) <- y_v
rownames(z) <- x_v

library(tidyverse)
# NOTE: this overwrites the `dat` object that was passed to stan() as data.
# expand.grid() varies x_v fastest, matching the column-major layout of `z`,
# so as.vector(z) lines up row-for-row. The coordinates stay exact instead of
# being parsed back from the character dimnames.
dat <- expand.grid(x_v = x_v, y_v = y_v) %>%
  mutate(value = as.vector(z),
         value_range = cut(value, 8))

plot_1 <- ggplot(dat, aes(x_v, y_v, fill = value_range)) +
  geom_raster(show.legend = FALSE) +
  # NOTE(review): cut() yields 8 bins but 10 colours are supplied, so at most
  # 8 of them are used by the raster. Layers added to plot_1 that inherit
  # `fill` can contribute extra levels to this scale, which the spare colours
  # may be absorbing - confirm before shrinking the palette.
  scale_fill_manual(values = colorRampPalette(c("white", "black"))(10)) +
  theme_bw() +
  theme(axis.title.x = element_blank(), axis.title.y = element_blank())

## Evaluate exp(temper(x)$log_q) on the same 300 x 300 grid over [-1, 12]^2
## (np^2 = 90,000 Stan gradient evaluations).
np <- 3e2
x_v <- seq(from = -1, to = 12, length.out = np)
y_v <- seq(from = -1, to = 12, length.out = np)
xy <- expand.grid(x = x_v, y = y_v)
f <- function(x) {
  exp(temper(x)$log_q)
}
z <- matrix(apply(as.matrix(xy), 1, f), length(x_v), length(y_v))

# Long format of the tempered grid. This `z` has no dimnames, so the
# coordinates are taken from expand.grid() (x varies fastest, matching the
# column-major fill of `z`) rather than from row/column names.
# NOTE(review): the tempered density is not drawn anywhere in this script;
# confirm whether a plot of it was intended.
dat <- expand.grid(x_v = x_v, y_v = y_v) %>%
  mutate(value = as.vector(z),
         value_range = cut(value, 8))

# Trajectory of the alpha = 1 run (`zigzag_fit_ada` as loaded from
# gmm_alpha_1_iter_1.Rdata above): first 30,000 iterations, or fewer if the
# run is shorter.
n_keep <- min(30000, ncol(zigzag_fit_ada$positions))
df_z <- data.frame(x = zigzag_fit_ada$positions[1, seq_len(n_keep)],
                   y = zigzag_fit_ada$positions[2, seq_len(n_keep)])

# Fixed dark-red path. inherit.aes = FALSE stops the layer from inheriting
# plot_1's `fill = value_range` mapping, which geom_path cannot use.
plot_1z <- plot_1 +
  geom_path(data = df_z, aes(x = x, y = y),
            colour = "darkred", inherit.aes = FALSE)


## Trajectory of the alpha = 0.7 run over the target-density raster, with the
## path coloured by the third row of `positions`.
# This load() replaces the `zigzag_fit_ada` loaded earlier (assuming the file
# defines that object, as the code below expects).
load("gmm/Cluser Code/ZZ/gmm_alpha_0.7_iter_1.Rdata")
n_keep <- min(30000, ncol(zigzag_fit_ada$positions))
keep <- seq_len(n_keep)
# All three rows use the same `keep` index so the colour stays aligned with
# the x/y path (a full-length colour row would make data.frame() error or
# silently recycle x and y).
df_z <- data.frame(x = zigzag_fit_ada$positions[1, keep],
                   y = zigzag_fit_ada$positions[2, keep],
                   col = zigzag_fit_ada$positions[3, keep])

# inherit.aes = FALSE keeps plot_1's `fill` mapping off this path layer.
plot_2z <- plot_1 +
  geom_path(data = df_z, aes(x = x, y = y, colour = col),
            inherit.aes = FALSE) +
  scale_colour_gradient(low = "pink", high = "darkred", na.value = NA)
plot_2z

## Place the two trajectory plots side by side. Adding guides() to a
## patchwork applies it to the last (right-hand) plot, whose colour scale
## gets the beta legend title.
library(patchwork)
library(latex2exp)
beta_guide <- guide_legend(title = TeX(r"($\beta$)"), keyheight = 2)
(plot_1z | plot_2z) + guides(colour = beta_guide)

# df <- data.frame(x=x_v, y=y_v, inv_temp = z)
# ggplot(df, aes(x=x,y=y))
# library(ggmap)
# ggimage(z)
