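# Compute the fraction of variance explained (FVE) for two-dimensional densities.

#### Usage: GetFVE2D(Estdmatrix, dmatrix, s_list, t_list, metric = 'L2')
#### Input:
# Estdmatrix: list of estimated (reconstructed) density matrices
# dmatrix: list of raw (observed) density matrices
# s_list, t_list: grid points along the first and second dimension of each matrix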
source("trapz2DRcpp.R")
source("FPCAsurface.R")
source('MakeAlphaReg.R')

# Fraction of variance explained (FVE) by estimated two-dimensional densities.
#
# FVE = (vtot - vk) / vtot, where vtot is the mean discrepancy between the
# observed densities and their mean, and vk is the mean discrepancy between
# each observed density and its estimate.
#   * metric = "L2": squared differences integrated over the (s_list, t_list)
#     grid with trapz2DRcpp(); the mean is the pointwise average density.
#   * metric = "W2": densities are rescaled to sum to one and compared with
#     Sinkhorn optimal-transport distances (Barycenter::Sinkhorn,
#     lambda = 0.01). The mean is a barycenter from Barycenter::WaBarycenter,
#     unless fret_mean is supplied.
#
# Args:
#   Estdmatrix: list of estimated density matrices, same length as dmatrix.
#   dmatrix: non-empty list of observed density matrices
#     (length(s_list) x length(t_list)).
#   s_list, t_list: grid points for the rows / columns of each matrix.
#   metric: "L2" or "W2".
#   fret_mean: optional precomputed mean density for "W2"; NA/NULL = compute.
#   costm: optional ground-cost matrix for "W2"; NA/NULL = Euclidean distance
#     between grid points.
# Returns:
#   list(FVE = <numeric scalar>, mean_out = <mean density used>).
GetFVE2D <- function(Estdmatrix, dmatrix, s_list, t_list, metric = 'L2', fret_mean = NA, costm = NA) {
  if (!(metric %in% c("L2", "W2"))) {
    stop("Unrecognized value for metric input.")
  }
  if (length(dmatrix) == 0L) {
    stop("dmatrix must contain at least one density.")
  }
  if (length(Estdmatrix) != length(dmatrix)) {
    stop("Estdmatrix and dmatrix must have the same length.")
  }

  # NA (the default) or NULL mean "not supplied". A bare is.na() inside `if`
  # errors as soon as a real matrix is passed (length > 1 condition).
  is_unset <- function(x) is.null(x) || (length(x) == 1L && is.na(x))

  if (metric == "L2") {
    mu_dens <- Reduce(`+`, dmatrix) / length(dmatrix)
    mean_out <- mu_dens

    sq_dist <- function(a, b) {
      trapz2DRcpp(X = s_list, Y = t_list, Z = (a - b)^2)
    }
    vtot <- mean(vapply(dmatrix, sq_dist, numeric(1), b = mu_dens))
    vk <- mean(vapply(seq_along(dmatrix), function(i) {
      sq_dist(dmatrix[[i]], Estdmatrix[[i]])
    }, numeric(1)))
  } else {
    Ns <- length(s_list)
    Nt <- length(t_list)
    n_cell <- Ns * Nt
    if (is_unset(costm)) {
      # NOTE(review): the ground cost is the Euclidean (not squared) distance,
      # so these Sinkhorn distances are not squared 2-Wasserstein
      # distances. Confirm this matches the intended "W2" metric.
      costm <- as.matrix(dist(expand.grid(s_list, t_list), diag = TRUE, upper = TRUE))
    }
    # Rescale every density to a probability vector on the grid.
    Cendmatrix <- lapply(dmatrix, function(dens) dens / sum(dens))
    CenEstdmatrix <- lapply(Estdmatrix, function(dens) dens / sum(dens))

    if (is_unset(fret_mean)) {
      fret_mean <- Barycenter::WaBarycenter(Cendmatrix, maxIter = 10, lambda = 10, costm = costm)
    }
    fret_mean <- fret_mean / sum(fret_mean)
    mean_out <- fret_mean

    # One column per density, flattened column-major to length Ns * Nt.
    as_columns <- function(dens_list) {
      matrix(vapply(dens_list, as.vector, numeric(n_cell)), nrow = n_cell)
    }
    obs_cols <- as_columns(Cendmatrix)

    # Total variation: distance from the mean to each observed density.
    vtot <- mean(Barycenter::Sinkhorn(matrix(fret_mean, nrow = n_cell, ncol = 1),
                                      obs_cols, costm, lambda = 0.01)$Distance)
    # Unexplained variation: distance from each estimate to its observation.
    vk <- mean(Barycenter::Sinkhorn(as_columns(CenEstdmatrix),
                                    obs_cols, costm, lambda = 0.01)$Distance)
  }

  FVE <- (vtot - vk) / vtot
  list(FVE = FVE, mean_out = mean_out)
}

# Fraction of variance explained for a sequence of truncation levels.
#
# For each i, densities are reconstructed with (K1[i], K2[i]) components. The
# reconstruction uses EstDens() when method = "BLQD" and EstDens_vanilla()
# when method = "Vanilla". The FVE against SampleDens0.list is then computed
# with GetFVE2D() on a regular design, or with GetFVE2D_Irregular() otherwise.
#
# Args:
#   fpcaObj1, fpcaObj2: fitted objects forwarded unchanged to EstDens*().
#   SampleDens0.list: list of observed density matrices.
#   Xmat, Ymat: matrices of evaluation coordinates. The design is treated as
#     regular when every row of Xmat and every column of Ymat is constant;
#     the grids are then Xmat[, 1] and Ymat[1, ].
#   K1, K2: paired numbers of components; K1[i] is used with K2[i].
#   metric: "L2" or "W2", forwarded to the FVE function.
#   alpha: forwarded to EstDens*().
#   n.core: forwarded to EstDens*(); defaults to 1 (the old default was a
#     self-reference that always errored).
#   method: "BLQD" or "Vanilla".
# Returns:
#   Numeric vector with one FVE per (K1[i], K2[i]) pair.
GetFVE2D_sum <- function(fpcaObj1, fpcaObj2, SampleDens0.list, Xmat, Ymat,
                         K1 = 1:3, K2 = 1:3, metric = "L2", alpha = 0,
                         n.core = 1, method = "BLQD") {
  if (!method %in% c("BLQD", "Vanilla")) {
    stop("Unrecognized method")
  }
  if (length(K2) < length(K1)) {
    stop("K2 must have at least as many entries as K1.")
  }

  # Design regularity does not depend on (k1, k2), so decide it once.
  # The regular path below reads the grids from Xmat[, 1] and Ymat[1, ], so
  # both matrices must be constant along the other direction.
  varies <- function(v) length(unique(v)) > 1
  is_irregular <- any(apply(Xmat, 1, varies)) || any(apply(Ymat, 2, varies))
  if (is_irregular) {
    message("Irregular design detected; interpolating densities before computing FVE.")
  }

  # Both estimators share the same calling convention.
  estimate <- if (method == "BLQD") EstDens else EstDens_vanilla

  out <- numeric(length(K1))
  for (i in seq_along(K1)) {
    # No de-regularization is applied to the fitted densities here.
    Dens_fit <- estimate(fpcaObj1, fpcaObj2, Xmat, Ymat,
                         K1 = K1[i], K2 = K2[i],
                         alpha = alpha, n.core = n.core)
    fve <- if (is_irregular) {
      GetFVE2D_Irregular(Dens_fit$EstDens0.list, SampleDens0.list,
                         Xmat, Ymat, metric = metric)
    } else {
      GetFVE2D(Dens_fit$EstDens0.list, SampleDens0.list,
               Xmat[, 1], Ymat[1, ], metric = metric)
    }
    out[i] <- fve$FVE
  }
  out
}

# Fraction of variance explained for densities observed on an irregular
# (scattered) design.
#
# Every estimated and observed density is linearly interpolated with
# akima::interp from the scattered points (X.mat, Y.mat) onto a common
# nx-by-ny regular grid (akima's default output grid). The FVE is then
# computed on that grid by GetFVE2D().
#
# Args:
#   Estdmatrix: list of estimated densities evaluated at (X.mat, Y.mat).
#   dmatrix: non-empty list of observed densities, same length as Estdmatrix.
#   X.mat, Y.mat: coordinates of the evaluation points; flattened with
#     as.vector(), so they must share the layout of the density matrices.
#   metric, fret_mean, costm: forwarded to GetFVE2D().
#   nx, ny: size of the interpolation grid (defaults keep the old 50 x 50).
# Returns:
#   list(FVE, mean_out) as returned by GetFVE2D().
#
# NOTE(review): akima::interp is called without extrap = TRUE. Grid cells
# outside the convex hull of the points may therefore be NA, and NA would
# propagate into the FVE. Confirm the design covers its bounding box.
GetFVE2D_Irregular <- function(Estdmatrix, dmatrix, X.mat, Y.mat, metric = 'L2',
                               fret_mean = NA, costm = NA, nx = 50, ny = 50) {
  # Validate cheaply before the interpolation work.
  if (!(metric %in% c("L2", "W2"))) {
    stop("Unrecognized value for metric input.")
  }
  if (length(Estdmatrix) != length(dmatrix)) {
    stop("Estdmatrix and dmatrix must have the same length.")
  }
  if (length(dmatrix) == 0L) {
    stop("dmatrix must contain at least one density.")
  }

  x_pts <- as.vector(X.mat)
  y_pts <- as.vector(Y.mat)
  # Interpolate one density onto the regular grid. Iterating over the inputs
  # themselves replaces the old `1:n`, which depended on a global `n`.
  to_grid <- function(dens) {
    akima::interp(x = x_pts,
                  y = y_pts,
                  z = as.vector(dens),
                  nx = nx,
                  ny = ny,
                  linear = TRUE,
                  #extrap = TRUE,
                  duplicate = "mean")
  }
  est_grid <- lapply(Estdmatrix, to_grid)
  raw_grid <- lapply(dmatrix, to_grid)

  if (metric == "W2") {
    # Attach Barycenter for the W2 computation (previously done with an
    # unconditional require()); GetFVE2D may call its functions unqualified.
    library(Barycenter)
  }

  # All interpolations share the same output grid, so take it from the first.
  GetFVE2D(lapply(est_grid, `[[`, "z"),
           lapply(raw_grid, `[[`, "z"),
           s_list = est_grid[[1]]$x,
           t_list = est_grid[[1]]$y,
           metric = metric,
           fret_mean = fret_mean,
           costm = costm)
}

# Interpolate one density, observed at scattered points (X.mat, Y.mat), onto
# a regular nx-by-ny grid with akima::interp and return only the matrix of
# interpolated values. The interpolation is linear, does not extrapolate,
# and averages duplicate points.
#
# Args:
#   estdmatrix: density values at the points; flattened with as.vector().
#   X.mat, Y.mat: point coordinates in the same layout as estdmatrix.
#   nx, ny: output grid size; the defaults keep the previous 51 x 51 grid.
# Returns:
#   An nx x ny matrix of interpolated values.
#
# NOTE(review): GetFVE2D_Irregular() interpolates onto a 50 x 50 grid by
# default. Confirm the two sizes are meant to differ.
Trans_Irregular <- function(estdmatrix, X.mat, Y.mat, nx = 51, ny = 51) {
  interp_xy <- akima::interp(x = as.vector(X.mat),
                             y = as.vector(Y.mat),
                             z = as.vector(estdmatrix),
                             nx = nx,
                             ny = ny,
                             linear = TRUE,
                             #extrap = TRUE,
                             duplicate = "mean")
  interp_xy$z
}

# Fraction of variance explained measured on quantile functions, one row per
# Radon slice.
#
# Each row of every density matrix is renormalised with trapzRcpp() to
# integrate to one over rho_grid, then converted to a quantile function with
# fdadensity::dens2quantile. The result is
#   1 - (mean squared quantile error of the reconstructions)
#       / (mean squared deviation of the observed quantiles from their average).
# Squared errors are plain sums over all grid entries.
#
# Args:
#   radon_dens: list of length n; each element an N_slice x N density matrix.
#   radon_rec_lqd: list of length n; each element an N_slice x N matrix
#     (presumably log-quantile-density values) that InverseTransformLqd()
#     maps back to densities.
#   eps, consInte: passed through unchanged to InverseTransformLqd().
#   rho_grid: support grid of length N for the density rows.
# Returns:
#   A single numeric FVE value.
GetFVESW <- function(radon_dens, radon_rec_lqd, eps, consInte, rho_grid) {
  # Both lists are indexed in parallel below, so they must match.
  if (length(radon_dens) != length(radon_rec_lqd)) {
    stop("radon_dens and radon_rec_lqd must have the same length.")
  }
  if (length(radon_dens) == 0L) {
    stop("radon_dens must contain at least one element.")
  }

  # Row-wise: renormalise each density on rho_grid, then map it to quantiles.
  to_quantile_mat <- function(densmat) {
    t(apply(densmat, 1, function(v) {
      v <- v / trapzRcpp(rho_grid, v)
      fdadensity::dens2quantile(v, dSup = rho_grid)
    }))
  }

  radon_rec_quantile <- lapply(radon_rec_lqd, function(lqd_mat) {
    to_quantile_mat(InverseTransformLqd(eps, lqd_mat, consInte, rho_grid))
  })
  radon_quantile <- lapply(radon_dens, to_quantile_mat)

  radon_quantile_avg <- Reduce(`+`, radon_quantile) / length(radon_quantile)

  vTot <- mean(vapply(radon_quantile, function(q) {
    sum((q - radon_quantile_avg)^2)
  }, numeric(1)))
  vEst <- mean(vapply(seq_along(radon_quantile), function(i) {
    sum((radon_quantile[[i]] - radon_rec_quantile[[i]])^2)
  }, numeric(1)))

  1 - vEst / vTot
}

