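#### Visualization of density results
# Each function takes lists of X, Y and Z matrices and returns a faceted
# ggplot object (heat map, or ridgeline plot for VisDens5), one panel per density.

#### Usage: VisDens(X.mat.list, Y.mat.list, Z.mat.list, ...) and its variants
#### Input:
     # X.mat.list: list of X grid matrices
     # Y.mat.list: list of Y grid matrices
     # Z.mat.list: list of density (or LQD) matrices, one per panel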

library(ggplot2)
library(tidyr)
library(tibble)
library(hrbrthemes)
library(dplyr)
library(scico)
library(ggridges)
# Visualize densities observed on irregular (scattered) points.
#
# Each sample i is linearly interpolated with akima::interp() onto a common
# Ngrid x Ngrid regular grid and drawn as one heat-map facet.
#
# X.mat.list, Y.mat.list, Z.mat.list: lists of equal length; element i holds
#   the x coordinates, y coordinates and values of sample i (flattened with
#   as.vector(), so any shape works).
# label: optional list; label[[i]][1] and label[[i]][2] are printed (rounded)
#   as the sample mean in the facet title.
# Scale: optional list; round(Scale[[i]], 2) is appended to the facet title.
#   Only used when `label` is also supplied.
# min, max: limits of the fill scale (also define its midpoint).
# Ngrid: number of regular grid points per axis (default 50).
# Returns a ggplot object.
VisDensIrreg <- function(X.mat.list, Y.mat.list, Z.mat.list, label = NULL,
                         Scale = NULL, min = 0, max = 1, Ngrid = 50) {
  n_samples <- length(Z.mat.list)
  if (n_samples == 0) {
    stop("`Z.mat.list` must contain at least one sample.", call. = FALSE)
  }

  # The output grid spans the coordinate range of the FIRST sample; every
  # sample is interpolated onto this same grid. base:: is explicit because
  # the `min`/`max` arguments shadow the base functions by name.
  x_grid <- seq(base::min(X.mat.list[[1]]), base::max(X.mat.list[[1]]),
                length.out = Ngrid)
  y_grid <- seq(base::min(Y.mat.list[[1]]), base::max(Y.mat.list[[1]]),
                length.out = Ngrid)

  # Facet title for sample i: the bare index, or "<i>Sample, mean = (a,b)"
  # optionally followed by " scale = s".
  facet_title <- function(i) {
    if (is.null(label)) {
      return(i)
    }
    txt <- paste0(i, "Sample, mean = (", round(label[[i]][1], 2), ",",
                  round(label[[i]][2], 2), ")")
    if (!is.null(Scale)) {
      txt <- paste0(txt, " scale = ", round(Scale[[i]], 2))
    }
    txt
  }

  panels <- lapply(seq_len(n_samples), function(i) {
    surf <- akima::interp(x = as.vector(X.mat.list[[i]]),
                          y = as.vector(Y.mat.list[[i]]),
                          z = as.vector(Z.mat.list[[i]]),
                          xo = x_grid,
                          yo = y_grid,
                          linear = TRUE,
                          duplicate = "mean")
    # surf$z[j, k] is the value at (surf$x[j], surf$y[k]); as.vector() walks
    # the matrix column-major, so x varies fastest.
    data.frame(X = rep(surf$x, times = length(surf$y)),
               Y = rep(surf$y, each = length(surf$x)),
               Z = as.vector(surf$z),
               label = facet_title(i))
  })
  plot.df <- do.call(rbind, panels)

  # Keep facets in sample order; plain string titles would sort
  # alphabetically ("10Sample..." before "2Sample...").
  plot.df$label <- factor(plot.df$label, levels = unique(plot.df$label))

  ggplot(plot.df) +
    geom_tile(aes(x = X, y = Y, fill = Z)) +
    facet_wrap(~label, ncol = 3) +
    scale_fill_gradient2(limits = c(min, max), low = "blue", mid = "yellow",
                         high = "red", midpoint = mean(c(min, max)))
}
# Faceted heat map of one density surface per element of Z.mat.list,
# coloured blue -> yellow -> red2 between `min` and `max`.
#
# X.mat.list, Y.mat.list: lists of coordinate matrices, one per panel. A
#   single grid (a length-1 list, or a bare matrix) is shared by all panels.
# Z.mat.list: list of value matrices, one per panel (flattened with as.vector).
# label: per-panel facet labels (defaults to 1, 2, ..., n).
# labelorder: optional facet ordering, used as the levels of the label factor.
# min, max: fill-scale limits; the gradient midpoint is their mean.
# ncol, title, xlab, ylab: facet layout and plot annotations.
# Returns a ggplot object. Errors if Z.mat.list is empty or a grid list has a
# length other than 1 or length(Z.mat.list).
VisDens <- function(X.mat.list, Y.mat.list, Z.mat.list, label = NULL,
                    labelorder = NULL, min = 0, max = 1, ncol = 6,
                    title = "Density", xlab = "maximum temperature",
                    ylab = "temperature range") {
  n <- length(Z.mat.list)
  if (n == 0) {
    stop("`Z.mat.list` must contain at least one matrix.", call. = FALSE)
  }

  # Expand a single shared grid to one copy per panel. The element itself is
  # recycled, not the enclosing list.
  as_panel_list <- function(grid, arg) {
    if (!is.list(grid)) {
      grid <- list(grid)
    }
    if (length(grid) == 1) {
      grid <- rep(grid, n)
    }
    if (length(grid) != n) {
      stop(sprintf("`%s` must have length 1 or %d, not %d.",
                   arg, n, length(grid)), call. = FALSE)
    }
    grid
  }
  X.mat.list <- as_panel_list(X.mat.list, "X.mat.list")
  Y.mat.list <- as_panel_list(Y.mat.list, "Y.mat.list")
  if (is.null(label)) {
    label <- seq_len(n)
  }

  plot.df <- do.call(rbind, lapply(seq_len(n), function(i) {
    data.frame(X = as.vector(X.mat.list[[i]]),
               Y = as.vector(Y.mat.list[[i]]),
               Density = as.vector(Z.mat.list[[i]]),
               label = label[i])
  }))
  if (!is.null(labelorder)) {
    plot.df$label <- factor(plot.df$label, levels = labelorder)
  }

  ggplot(plot.df) +
    geom_tile(aes(x = X, y = Y, fill = Density)) +
    facet_wrap(~label, ncol = ncol) +
    scale_fill_gradient2(limits = c(min, max), low = "blue", mid = "yellow",
                         high = "red2", midpoint = mean(c(min, max))) +
    xlab(xlab) +
    ylab(ylab) +
    ggtitle(title) +
    theme(plot.title = element_text(hjust = 0.5))
}
# Faceted heat map of one density surface per element of Z.mat.list, filled
# with the reversed scico "lajolla" palette.
#
# X.mat.list, Y.mat.list: lists of coordinate matrices; a length-1 list is
#   reused for every panel.
# Z.mat.list: list of value matrices, one per panel (flattened).
# label: per-panel facet labels (defaults to 1, 2, ..., n).
# level: optional facet ordering, used as the levels of the label factor.
# min, max: fill-scale limits.
# ncol, title, xlab, ylab: facet layout and plot annotations.
# Returns a ggplot object.
VisDens2 <- function(X.mat.list, Y.mat.list, Z.mat.list,
                     label = NULL, level = NULL,
                     min = 0, max = 1, ncol = 5,
                     title = "Density difference",
                     xlab = "maximum temperature",
                     ylab = "temperature range") {
  n_panels <- length(Z.mat.list)

  # Replicate a lone grid so that every panel has its own entry.
  share_grid <- function(grid) {
    if (length(grid) != 1) {
      return(grid)
    }
    rep(list(grid[[1]]), n_panels)
  }
  X.mat.list <- share_grid(X.mat.list)
  Y.mat.list <- share_grid(Y.mat.list)

  if (is.null(label)) label <- 1:n_panels

  make_panel <- function(k) {
    data.frame(X = as.vector(X.mat.list[[k]]),
               Y = as.vector(Y.mat.list[[k]]),
               density = as.vector(Z.mat.list[[k]]),
               label = label[k])
  }
  plot.df <- do.call(rbind, lapply(1:n_panels, make_panel))

  if (!is.null(level)) {
    plot.df$label <- factor(plot.df$label, levels = level)
  }

  fill_scale <- scico::scale_fill_scico(limits = c(min, max),
                                        palette = "lajolla",
                                        direction = -1)
  ggplot(plot.df) +
    geom_tile(aes(x = X, y = Y, fill = density)) +
    facet_wrap(~label, ncol = ncol) +
    fill_scale +
    xlab(xlab) + ylab(ylab) + ggtitle(title) +
    theme(plot.title = element_text(hjust = 0.5))
}
# Faceted heat map of one density surface per element of Z.mat.list, coloured
# with a reversed 8-colour rainbow palette between `min` and `max`.
#
# X.mat.list, Y.mat.list: lists of coordinate matrices, one per panel. A
#   single grid (a length-1 list, or a bare matrix) is shared by all panels.
#   (Previously the panel count was taken from X.mat.list, so one shared grid
#   silently dropped every panel after the first.)
# Z.mat.list: list of value matrices; its length sets the number of panels.
# label: per-panel facet labels (defaults to 1, 2, ..., n).
# labelorder: optional facet ordering, used as the levels of the label factor.
# min, max: fill-scale limits.
# ncol, title, xlab, ylab: facet layout and plot annotations.
# Returns a ggplot object. Errors if Z.mat.list is empty or a grid list has a
# length other than 1 or length(Z.mat.list).
VisDens3 <- function(X.mat.list, Y.mat.list, Z.mat.list, label = NULL,
                     labelorder = NULL, min = 0, max = 1, ncol = 6,
                     title = "Density", xlab = "maximum temperature",
                     ylab = "temperature range") {
  n <- length(Z.mat.list)
  if (n == 0) {
    stop("`Z.mat.list` must contain at least one matrix.", call. = FALSE)
  }

  # Expand a single shared grid to one copy per panel.
  as_panel_list <- function(grid, arg) {
    if (!is.list(grid)) {
      grid <- list(grid)
    }
    if (length(grid) == 1) {
      grid <- rep(grid, n)
    }
    if (length(grid) != n) {
      stop(sprintf("`%s` must have length 1 or %d, not %d.",
                   arg, n, length(grid)), call. = FALSE)
    }
    grid
  }
  X.mat.list <- as_panel_list(X.mat.list, "X.mat.list")
  Y.mat.list <- as_panel_list(Y.mat.list, "Y.mat.list")
  if (is.null(label)) {
    label <- seq_len(n)
  }

  plot.df <- do.call(rbind, lapply(seq_len(n), function(i) {
    data.frame(X = as.vector(X.mat.list[[i]]),
               Y = as.vector(Y.mat.list[[i]]),
               Density = as.vector(Z.mat.list[[i]]),
               label = label[i])
  }))
  if (!is.null(labelorder)) {
    plot.df$label <- factor(plot.df$label, levels = labelorder)
  }

  ggplot(plot.df) +
    geom_tile(aes(x = X, y = Y, fill = Density)) +
    facet_wrap(~label, ncol = ncol) +
    scale_fill_gradientn(colours = rev(rainbow(8)), limits = c(min, max)) +
    xlab(xlab) +
    ylab(ylab) +
    ggtitle(title) +
    theme(plot.title = element_text(hjust = 0.5))
}
# Faceted heat map of one difference surface per element of Z.mat.list, on a
# diverging blue -> white -> red scale centred at mean(c(min, max)).
#
# X.mat.list, Y.mat.list: lists of coordinate matrices, one per panel. A
#   single grid (a length-1 list, or a bare matrix) is shared by all panels.
#   (Previously the panel count was taken from X.mat.list, so one shared grid
#   silently dropped every panel after the first.)
# Z.mat.list: list of value matrices; its length sets the number of panels.
# label: per-panel facet labels (defaults to 1, 2, ..., n).
# min, max: fill-scale limits.
# ncol, title, xlab, ylab: facet layout and plot annotations.
# level: optional facet ordering, used as the levels of the label factor
#   (new, trailing, defaults to NULL = previous behaviour).
# Returns a ggplot object. Errors if Z.mat.list is empty or a grid list has a
# length other than 1 or length(Z.mat.list).
VisDens4 <- function(X.mat.list, Y.mat.list, Z.mat.list, label = NULL,
                     min = 0, max = 1, ncol = 5,
                     title = "Density difference",
                     xlab = "maximum temperature",
                     ylab = "temperature range", level = NULL) {
  n <- length(Z.mat.list)
  if (n == 0) {
    stop("`Z.mat.list` must contain at least one matrix.", call. = FALSE)
  }

  # Expand a single shared grid to one copy per panel.
  as_panel_list <- function(grid, arg) {
    if (!is.list(grid)) {
      grid <- list(grid)
    }
    if (length(grid) == 1) {
      grid <- rep(grid, n)
    }
    if (length(grid) != n) {
      stop(sprintf("`%s` must have length 1 or %d, not %d.",
                   arg, n, length(grid)), call. = FALSE)
    }
    grid
  }
  X.mat.list <- as_panel_list(X.mat.list, "X.mat.list")
  Y.mat.list <- as_panel_list(Y.mat.list, "Y.mat.list")
  if (is.null(label)) {
    label <- seq_len(n)
  }

  plot.df <- do.call(rbind, lapply(seq_len(n), function(i) {
    data.frame(X = as.vector(X.mat.list[[i]]),
               Y = as.vector(Y.mat.list[[i]]),
               Diff = as.vector(Z.mat.list[[i]]),
               label = label[i])
  }))
  if (!is.null(level)) {
    plot.df$label <- factor(plot.df$label, levels = level)
  }

  ggplot(plot.df) +
    geom_tile(aes(x = X, y = Y, fill = Diff)) +
    facet_wrap(~label, ncol = ncol) +
    scale_fill_gradient2(limits = c(min, max), mid = "white", high = "red",
                         low = "blue", midpoint = mean(c(min, max))) +
    xlab(xlab) +
    ylab(ylab) +
    ggtitle(title) +
    theme(plot.title = element_text(hjust = 0.5))
}
# Faceted, interpolated raster of one density surface per element of
# Z.mat.list, filled with the scico "lajolla" palette.
#
# X.mat.list, Y.mat.list: lists of coordinate matrices; a length-1 list is
#   reused for every panel.
# Z.mat.list: list of value matrices, one per panel (flattened).
# label: per-panel facet labels (defaults to 1, 2, ..., n).
# level: optional facet ordering, used as the levels of the label factor.
# min, max: fill-scale limits.
# ncol, title, xlab, ylab: facet layout and plot annotations.
# Returns a ggplot object (geom_raster with interpolate = TRUE).
VisDens6 <- function(X.mat.list, Y.mat.list, Z.mat.list,
                     label = NULL, level = NULL,
                     min = 0, max = 1, ncol = 5,
                     title = "Density difference",
                     xlab = "maximum temperature",
                     ylab = "temperature range") {
  n_panels <- length(Z.mat.list)

  # Replicate a lone grid so that every panel has its own entry.
  share_grid <- function(grid) {
    if (length(grid) != 1) {
      return(grid)
    }
    rep(list(grid[[1]]), n_panels)
  }
  X.mat.list <- share_grid(X.mat.list)
  Y.mat.list <- share_grid(Y.mat.list)

  if (is.null(label)) label <- 1:n_panels

  make_panel <- function(k) {
    data.frame(X = as.vector(X.mat.list[[k]]),
               Y = as.vector(Y.mat.list[[k]]),
               density = as.vector(Z.mat.list[[k]]),
               label = label[k])
  }
  plot.df <- do.call(rbind, lapply(1:n_panels, make_panel))

  if (!is.null(level)) {
    plot.df$label <- factor(plot.df$label, levels = level)
  }

  fill_scale <- scico::scale_fill_scico(limits = c(min, max),
                                        palette = "lajolla")
  ggplot(plot.df) +
    geom_raster(aes(x = X, y = Y, fill = density), interpolate = TRUE) +
    facet_wrap(~label, ncol = ncol) +
    fill_scale +
    xlab(xlab) + ylab(ylab) + ggtitle(title) +
    theme(plot.title = element_text(hjust = 0.5))
}



# Faceted ridgeline plot: within each panel, one ridge per distinct Y value,
# with ridge height = density along X. Ridges are filled by their Y value
# using the scico "lajolla" palette, with legend title `Yname`.
#
# X.mat.list, Y.mat.list: lists of coordinate matrices; a length-1 list is
#   reused for every panel.
# Z.mat.list: list of density matrices, one per panel (flattened).
# Yname: legend title for the fill (Y) scale.
# label: per-panel facet labels (defaults to 1, 2, ..., n).
# level: optional facet ordering, used as the levels of the label factor.
# min, max: limits of the fill scale.
# ncol, title, xlab, ylab: facet layout and plot annotations.
# Returns a ggplot object.
#
# NOTE(review): fill is mapped to Y (the predictor), but the fill limits are
# c(min, max). VisDensSlice passes its density limits (dens_min, dens_max)
# here, so Y values outside that range would be drawn as NA. Confirm this is
# intended.
VisDens5 <- function(X.mat.list, Y.mat.list, Z.mat.list,
                     Yname = "predictor", label = NULL, level = NULL,
                     min = 0, max = 1, ncol = 5,
                     title = "Density difference",
                     xlab = "maximum temperature",
                     ylab = "temperature range") {
  n_panels <- length(Z.mat.list)

  # Replicate a lone grid so that every panel has its own entry.
  share_grid <- function(grid) {
    if (length(grid) != 1) {
      return(grid)
    }
    rep(list(grid[[1]]), n_panels)
  }
  X.mat.list <- share_grid(X.mat.list)
  Y.mat.list <- share_grid(Y.mat.list)

  if (is.null(label)) label <- 1:n_panels

  make_panel <- function(k) {
    data.frame(X = as.vector(X.mat.list[[k]]),
               Y = as.vector(Y.mat.list[[k]]),
               density = as.vector(Z.mat.list[[k]]),
               label = label[k])
  }
  plot.df <- do.call(rbind, lapply(1:n_panels, make_panel))

  if (!is.null(level)) {
    plot.df$label <- factor(plot.df$label, levels = level)
  }

  ridge_aes <- aes(x = X, y = Y, height = density, group = Y, fill = Y)
  fill_scale <- scico::scale_fill_scico(limits = c(min, max),
                                        palette = "lajolla",
                                        name = Yname)
  ggplot(plot.df, ridge_aes) +
    geom_ridgeline() +
    facet_wrap(~label, ncol = ncol) +
    fill_scale +
    xlab(xlab) + ylab(ylab) + ggtitle(title) +
    theme(plot.title = element_text(hjust = 0.5))
}

# Visualize the fitted slice-wise densities for selected projection angles.
#
# model: model output from radon_regression. Fields read here: theta_grid,
#   df_frechet (each element's $dens is a matrix), rho_grid, n and N.
# xout: predictor values corresponding to the rows of each model$dens matrix.
# sliceIndex: indices into model$theta_grid (and into model$df_frechet) of
#   the slices to visualize.
# dens_min, dens_max: limits passed to the plotting helper as min / max.
# Yname: legend title, used only by the ridgeline output (VisDens5).
# xlab, ylab, title: plot annotations.
# outOption: "heat" draws heat maps via VisDens2; any other value draws
#   ridgelines via VisDens5.
# Returns the ggplot object produced by VisDens2 or VisDens5.
#
# The coordinate grids are built with base matrix() in place of the
# kron()/ones() helpers the original called; those helpers are not loaded by
# this file.
VisDensSlice <- function(model, xout, sliceIndex = c(27, 39, 1, 14),
                         dens_min = 0, dens_max = 0.002,
                         Yname = "winter temperature", xlab = "rho",
                         ylab = "xout",
                         title = "Regression for Selected Slices",
                         outOption = "heat") {
  sliceLabel <- model$theta_grid[sliceIndex]

  # Stack the per-slice density matrices into a 3-d array; the third index
  # selects the slice (same indexing as model$theta_grid).
  sliceDens <- simplify2array(lapply(model$df_frechet, function(obj) obj$dens))

  Z.grid <- vector("list", length(sliceIndex))
  for (i in seq_along(sliceIndex)) {
    k <- sliceIndex[i]
    if (sliceLabel[i] > 180) {
      # Angles above 180 are shown as theta - 180 with the rho axis (second
      # index) reversed; presumably the theta + 180 slice is the mirrored
      # theta slice.
      sliceLabel[i] <- sliceLabel[i] - 180
      Z.grid[[i]] <- sliceDens[, model$N:1, k]
    } else {
      Z.grid[[i]] <- sliceDens[, , k]
    }
  }

  # model$n rows, each equal to rho_grid (was kron(ones(n, 1), t(rho_grid))).
  X.grid <- matrix(as.vector(model$rho_grid), nrow = model$n,
                   ncol = length(model$rho_grid), byrow = TRUE)
  # Row i is constant xout[i], with model$N columns (was kron(xout, ones(1, N))).
  Y.grid <- matrix(as.vector(xout), nrow = length(xout), ncol = model$N)

  # Facet labels: angles rounded to the nearest multiple of 45 degrees.
  facet_labels <- round(sliceLabel / 45) * 45

  if (outOption == "heat") {
    VisDens2(list(X.grid), list(Y.grid), Z.grid,
             label = facet_labels,
             min = dens_min,
             max = dens_max,
             ncol = length(sliceIndex),
             title = title,
             xlab = xlab,
             ylab = ylab)
  } else {
    VisDens5(list(X.grid), list(Y.grid), Z.grid,
             Yname = Yname,
             label = facet_labels,
             min = dens_min,
             max = dens_max,
             ncol = length(sliceIndex),
             title = title,
             xlab = xlab,
             ylab = ylab)
  }
}


# library(plotly)
# 
# set.seed(123)
# 
# df_cohort_master <- data.frame(
#   rho = rep(seq(-3,3, length.out = 50), times = 30),
#   mu = rep(0.1 * runif(30,0,30), each = 50)
# )
# df_cohort_master$dens = dnorm(df_cohort_master$rho, mean=df_cohort_master$mu)
# 
# plot_ly(df_cohort_master,
#         x = ~mu, 
#         y = ~rho, 
#         z = ~dens, 
#         type = 'scatter3d', 
#         mode = 'lines',
#         fill = 'tozeroy',
#         split = ~mu) %>%
#   layout(
#     title = "3D Scatter plot", 
#     scene = list(
#       camera = list(
#         eye = list(x = 1, y = 2, z = 2)
#       )
#     )
#   )
# 
# 
# library(ggplot2)
# library(ggridges)
# 
# ggplot(diamonds, aes(x = price, y = cut, fill = cut)) +
#   geom_density_ridges(scale = 2, alpha = 0.7) +
#   scale_fill_brewer(guide = guide_legend(reverse = TRUE)) +
#   scale_y_discrete(expand = c(0.01, 0)) +
#   theme_ridges(center = TRUE)
# 
# ggplot(df_cohort_master, 
#        aes(x = rho, y = mu, height = dens, group = mu, fill = mu)) + 
#   geom_ridgeline()
# 
# +
#   scale_fill_brewer(guide = guide_legend(reverse = TRUE)) +
#   scale_y_discrete(expand = c(0.01, 0)) +
#   theme_ridges(center = TRUE)
#   
# 
# 
# Like VisDens2, but each panel's values are first smoothed with a 2-D local
# linear smoother (Lwls2D, Epanechnikov kernel, bandwidth `bandw`), evaluated
# back at the input points. Negative smoothed values are clipped to 0.
#
# X.mat.list, Y.mat.list: lists of coordinate matrices; a length-1 list is
#   reused for every panel.
# Z.mat.list: list of value matrices, one per panel (flattened).
# bandw: bandwidth passed as the first argument of Lwls2D.
#   NOTE(review): the default 0 is passed through unchanged; a positive
#   bandwidth presumably has to be supplied by the caller. Confirm.
# label: per-panel facet labels (defaults to 1, 2, ..., n).
# level: optional facet ordering, used as the levels of the label factor.
# min, max: fill-scale limits (reversed scico "lajolla" palette).
# ncol, title, xlab, ylab: facet layout and plot annotations.
# Returns a ggplot object.
VisDens2_sm <- function(X.mat.list, Y.mat.list, Z.mat.list, bandw = 0,
                        label = NULL, level = NULL,
                        min = 0, max = 1, ncol = 5,
                        title = "Density difference",
                        xlab = "maximum temperature",
                        ylab = "temperature range") {
  n_panels <- length(Z.mat.list)

  # Replicate a lone grid so that every panel has its own entry.
  share_grid <- function(grid) {
    if (length(grid) != 1) {
      return(grid)
    }
    rep(list(grid[[1]]), n_panels)
  }
  X.mat.list <- share_grid(X.mat.list)
  Y.mat.list <- share_grid(Y.mat.list)

  if (is.null(label)) label <- 1:n_panels

  smooth_panel <- function(k) {
    pts <- cbind(as.vector(X.mat.list[[k]]), as.vector(Y.mat.list[[k]]))
    vals <- as.vector(Z.mat.list[[k]])
    fitted <- Lwls2D(bandw, kern = "epan", pts, vals, xout = pts,
                     crosscov = TRUE)
    data.frame(X = pts[, 1],
               Y = pts[, 2],
               density = pmax(fitted, 0),
               label = label[k])
  }
  plot.df <- do.call(rbind, lapply(1:n_panels, smooth_panel))

  if (!is.null(level)) {
    plot.df$label <- factor(plot.df$label, levels = level)
  }

  fill_scale <- scico::scale_fill_scico(limits = c(min, max),
                                        palette = "lajolla",
                                        direction = -1)
  ggplot(plot.df) +
    geom_tile(aes(x = X, y = Y, fill = density)) +
    facet_wrap(~label, ncol = ncol) +
    fill_scale +
    xlab(xlab) + ylab(ylab) + ggtitle(title) +
    theme(plot.title = element_text(hjust = 0.5))
}


