library(frechet)
library(Matrix)

## --- Log-Cholesky distance function ---
log_cholesky_distance <- function(S1, S2) {
  ## Log-Cholesky distance between two covariance matrices.
  ## S1, S2: q by q symmetric positive (semi-)definite matrices of equal size.
  ## Returns: sqrt(sum of squared differences of the strictly lower-triangular
  ##          Cholesky entries + sum of squared differences of the log-diagonals).

  reg <- 1e-8

  ## Lower-triangular Cholesky factor of a symmetrized, ridge-regularized matrix.
  lower_factor <- function(S) {
    S <- (S + t(S)) / 2
    S <- S + reg * diag(nrow(S))
    ## chol() returns the UPPER factor R with S = t(R) %*% R, so transpose it.
    t(chol(S))
  }

  L1 <- lower_factor(S1)
  L2 <- lower_factor(S2)

  ## Off-diagonal part: compare the strictly lower-triangular entries.
  ## (Masking the lower triangle of the lower factor would leave only zeros.)
  strict_mask <- lower.tri(L1, diag = FALSE)
  off_dist_sq <- sum((L1[strict_mask] - L2[strict_mask])^2)

  ## Diagonal part: compare on the log scale.
  logD_dist_sq <- sum((log(diag(L1)) - log(diag(L2)))^2)

  sqrt(off_dist_sq + logD_dist_sq)
}

# Main Function : Single Index F-regression with covariance response with log-Cholesky metric
SIdxCovReg <- function(xin, Min, bw = NULL, M = NULL, ker = ker_gauss, lower = -Inf, upper = Inf, iter = 1000,
                       verbose = TRUE) {
  ## Estimates the index direction by random search over unit vectors.
  ## xin: n by p matrix of input (n: number of inputs, p: dimension of predictors)
  ## Min: q by q by n array where Min[,,i] contains the i-th covariance matrix
  ## bw: bandwidth b; chosen by CovTuning() when NULL
  ## M: number of binned points; chosen by CovTuning() when NULL
  ## ker: kernel function (e.g. ker_gauss), forwarded to CovTuning() and CovDirLocLin()
  ## lower, upper: forwarded unchanged to CovDirLocLin()
  ## iter: number of random candidate directions to evaluate
  ## verbose: print a progress line every 100 candidates?
  ## Returns: list(est = unit-norm best direction, bw = bandwidth used, M = bin count used)

  if (is.vector(xin)) {
    stop("The number of observations is too small")
  }
  if (!is.matrix(xin)) {
    stop("xin should be matrix.")
  }
  if (!is.array(Min)) {
    stop("Min should be array.")
  }
  ## This check must live outside the branch above, otherwise stop() makes it unreachable.
  if (length(dim(Min)) != 3) {
    stop("Min should be 3-dimensional array.")
  }

  p <- ncol(xin)

  ## Parameter (bandwidth, bin size) choice using cross-validation,
  ## run only when at least one of them was not supplied.
  if (is.null(M) || is.null(bw)) {
    param <- CovTuning(xin, Min, normalize(rep(1, p)), ker = ker)
  }
  bw2 <- if (is.null(bw)) param[1] else bw
  M2 <- if (is.null(M)) param[2] else M

  ## Candidate directions: Gaussian draws scaled to unit norm, with the
  ## first coordinate made non-negative.
  coords_mat <- matrix(rnorm(iter * p), nrow = iter, ncol = p)
  coords_mat <- coords_mat / sqrt(rowSums(coords_mat^2))
  coords_mat[, 1] <- abs(coords_mat[, 1])

  ## Find the single index: keep the direction with the smallest mean squared
  ## log-Cholesky distance between the local fit and the binned responses.
  fdi_curr <- Inf
  direc_curr <- NULL

  for (i in seq_len(iter)) {
    direc_new <- coords_mat[i, ]

    binned_dat <- CovBinned_data(xin, Min, direc_new, M2)
    proj_binned <- binned_dat$binned_xmean %*% direc_new

    err <- 0
    for (l in seq_len(M2)) {
      res <- CovDirLocLin(
        xin, Min,
        direc_new,
        proj_binned[l],
        bw2, ker = ker, lower = lower, upper = upper)
      err <- err + log_cholesky_distance(res, binned_dat$binned_Mmean[, , l])^2
    }
    fdi_new <- err / M2

    ## Candidates with a missing criterion value are skipped instead of
    ## aborting the whole search inside if().
    if (!is.na(fdi_new) && fdi_new < fdi_curr) {
      fdi_curr <- fdi_new
      direc_curr <- direc_new
    }

    if (verbose && i %% 100 == 0) {
      print(paste("Iteration number:", i,"/",iter))
    }
  }

  if (is.null(direc_curr)) {
    stop("No candidate direction produced a finite criterion value.")
  }

  list(est = normalize(direc_curr), bw = bw2, M = M2)
}


#### Directional local F-regression for covariance responses, averaged in log-Cholesky coordinates,
#### given the direction along which to compute the projection and a bandwidth choice
CovDirLocLin <- function(xin, Min, direc, xout, bw, ker = ker_gauss,
                         lower = -Inf, upper = Inf) {
  ## Local linear estimate of the covariance response at one projected point,
  ## obtained as a weighted average in log-Cholesky coordinates.
  ## xin: n by p matrix of input (n: number of inputs, p: dimension of predictors)
  ## Min: q by q by n array where Min[,,i] contains the i-th covariance matrix
  ## direc: directional vector, length p
  ## xout: evaluation point on the projected scale xin %*% direc
  ##       (callers in this file pass a single number)
  ## bw: bandwidth b
  ## ker: kernel function (e.g. ker_gauss)
  ## lower, upper: accepted for interface compatibility; not used in this function
  ## Returns: q by q matrix

  if (is.vector(xin)) {
    stop("The number of observations is too small")
  }
  if (!is.array(Min)) {
    stop("Min should be array.")
  }
  if (!is.matrix(xin)) stop("xin should be a matrix.")
  if (length(dim(Min)) != 3) stop("Min should be a 3-dimensional array.")

  n <- nrow(xin)
  if (n < 3) {
    stop("The number of observations is too small")
  }
  q <- dim(Min)[1]

  ## Local linear weights s_i = K_i * (1 - (mu1 / mu2) * (x_i - xout)).
  ## They are later normalized by sum(s), so only their relative size matters.
  centred <- as.vector(xin %*% direc) - xout
  kern <- ker(centred / bw)
  mu1 <- mean(kern * centred)
  mu2 <- mean(kern * centred^2)
  if (!is.finite(mu2) || mu2 == 0) {
    stop("Kernel-weighted second moment is zero or not finite; cannot form local linear weights.")
  }
  s <- kern * (1 - (mu1 * (1 / mu2)) * centred)
  sL <- sum(s)

  ## Weighted sums of the Cholesky factors: strict off-diagonal part on the
  ## original scale, diagonal on the log scale. chol() returns the upper factor R
  ## (M = t(R) %*% R); the same convention is used for the reconstruction below.
  U <- matrix(0, nrow = q, ncol = q)
  E <- numeric(q)
  ridge <- 1e-8 * diag(q)

  for (i in seq_len(n)) {
    M_i <- Min[, , i]
    ## Ensure symmetric and positive definite before factorizing
    M_i <- (M_i + t(M_i)) / 2 + ridge
    R_i <- chol(M_i)
    d_i <- diag(R_i)
    ## nrow = q keeps diag() from building an identity matrix when q == 1
    U <- U + s[i] * (R_i - diag(d_i, nrow = q))
    E <- E + s[i] * log(d_i)
  }

  ## Reconstruct the covariance matrix from the averaged factor
  SS <- U / sL + diag(exp(E / sL), nrow = q)
  M_res <- t(SS) %*% SS

  ## Ensure positive definiteness
  M_res <- (M_res + t(M_res)) / 2
  M_res <- M_res + 1e-8 * diag(q)
  M_res <- as.matrix(Matrix::nearPD(M_res)$mat)

  M_res
}

#### Implements the selection of the bandwidth for the local F-regression and,
#### for the selected bandwidth, chooses the bin size by cross-validation
CovTuning <- function(xin, Min, direc, ker = ker_gauss) {
  ## Chooses the bandwidth and the bin size by cross-validation.
  ## xin: n by p matrix of input (n: number of inputs, p: dimension of predictors)
  ## Min: q by q by n array where Min[,,i] contains the i-th covariance matrix
  ## direc: directional vector, length p
  ## ker: kernel function (e.g. ker_gauss), used in every local fit below
  ## Returns: c(bw, M) -- selected bandwidth and selected bin size

  ## CV function to select the bandwidth: mean over folds of the per-fold mean
  ## squared log-Cholesky prediction error.
  bwCV <- function(bw, xin, Min, direc, ker = ker_gauss, lower = -Inf, upper = Inf) {

    if (is.vector(xin)) {
      stop("The number of observations is too small")
    }

    n <- nrow(xin)
    projec <- xin %*% direc
    folds <- split(seq_len(n), rep(1:5, length.out = n))
    cv_err <- 0

    for (held_out in folds) {
      ## drop = FALSE keeps the matrix / array shape when p == 1
      xin_eff <- xin[-held_out, , drop = FALSE]
      Min_eff <- Min[, , -held_out, drop = FALSE]

      ## Accumulate each fold separately: dividing the running total inside
      ## this loop would repeatedly shrink the contribution of earlier folds.
      fold_err <- 0
      for (k in held_out) {
        res <- CovDirLocLin(xin_eff, Min_eff, direc, projec[k], bw, ker, lower = -Inf, upper = Inf)
        fold_err <- fold_err + log_cholesky_distance(res, Min[, , k])^2
      }
      cv_err <- cv_err + fold_err / length(held_out)
    }
    cv_err / length(folds)
  }

  ## CV function to select M: leave-one-bin-out error on the binned data.
  bwCV_M <- function(xin, Min, direc, M, bw, ker = ker_gauss, lower = -Inf, upper = Inf) {
    binned_dat <- CovBinned_data(xin, Min, direc, M)
    xin_binned <- binned_dat$binned_xmean
    Min_binned <- binned_dat$binned_Mmean
    proj_binned <- xin_binned %*% direc

    cv_err <- 0
    for (i in seq_len(M)) {
      xin_eff <- xin_binned[-i, , drop = FALSE]
      Min_eff <- Min_binned[, , -i, drop = FALSE]
      res <- CovDirLocLin(xin_eff, Min_eff, direc, proj_binned[i], bw, ker = ker, lower = -Inf, upper = Inf)
      cv_err <- cv_err + log_cholesky_distance(res, Min_binned[, , i])^2
    }

    ## A NaN error is reported as Inf so that this M is never selected.
    if (is.nan(cv_err)) Inf else cv_err / M
  }

  n <- nrow(xin)
  projec <- xin %*% direc

  ## Bandwidth search interval: from slightly above the largest gap between
  ## consecutive distinct projections up to a third of the projection range.
  xinSt <- unique(sort(projec))
  bw_min <- max(diff(xinSt)) * 1.1
  bw_max <- (max(projec) - min(projec)) / 3
  if (bw_max < bw_min) {
    if (bw_min > bw_max * 3 / 2) {
      warning("Data is too sparse.")
      bw_max <- bw_min * 1.01
    } else {
      bw_max <- bw_max * 3 / 2
    }
  }

  ## bandwidth choice using bwCV
  bw <- optim(par = runif(1, min = bw_min, max = bw_max),
              fn = bwCV, xin = xin, Min = Min, direc = direc, ker = ker,
              method = "Brent",
              lower = bw_min, upper = bw_max)$par

  ## Candidate bin sizes: ceiling(n / k) for k = 2..30, restricted to (15, 60).
  M_range <- ceiling(n / (2:30))
  M_range <- unique(M_range[60 > M_range & M_range > 15])

  ## Default used when there is no candidate or no candidate has a finite error.
  M_best <- 15
  cv_err_curr <- Inf
  for (M_cand in M_range) {
    cv_err_new <- bwCV_M(xin, Min, direc, M_cand, bw, ker = ker)
    if (!is.na(cv_err_new) && cv_err_new < cv_err_curr) {
      cv_err_curr <- cv_err_new
      M_best <- M_cand
    }
  }

  ## Return the best candidate, not the last value of the loop variable.
  c(bw, M_best)
}


#### Binning step: given the data and a direction, selects M representative
#### observations (rows of xin and matching slices of Min) along the projection
CovBinned_data <- function(xin, Min, direc, M) {
  ## Picks M representative observations along the projection xin %*% direc.
  ## xin: n by p matrix of input
  ## Min: q by q by n array where Min[,,i] contains the i-th covariance matrix
  ## direc: directional vector, length p
  ## M: number of representative points (at least 4, at most n)
  ## Returns: list(projec = n by 1 projections,
  ##               binned_xmean = M by p matrix of selected rows of xin,
  ##               binned_Mmean = q by q by M array of the matching slices of Min)

  if (M < 4) {
    stop("The number of binned data should be greater than 3.")
  }

  n <- nrow(xin)
  p <- ncol(xin)
  q <- dim(Min)[1]

  if (n < M) {
    stop("The number of binned data cannot exceed the number of observations.")
  }

  projec <- xin %*% direc
  if (!all(is.finite(projec))) {
    stop("Projections of xin onto direc must be finite.")
  }

  ## Representative observations: the minimum, the observations whose projection
  ## has rank (n * l) %/% M for l = 2, ..., M - 1, and the maximum.
  ## For tied projections the first matching row is used.
  sorted_proj <- sort(projec)
  inner_rank <- (n * (2:(M - 1))) %/% M
  rows <- c(which.min(projec),
            match(sorted_proj[inner_rank], projec),
            which.max(projec))

  binned_xmean <- matrix(NA, M, p)
  binned_Mmean <- array(NA, dim = c(q, q, M))
  for (l in seq_along(rows)) {
    binned_xmean[l, ] <- xin[rows[l], ]
    binned_Mmean[, , l] <- Min[, , rows[l]]
  }

  list(projec = projec, binned_xmean = binned_xmean, binned_Mmean = binned_Mmean)
}

#### Additional functions ####
ker_gauss <- function(x) {
  ## Gaussian kernel: exp(-x^2 / 2) scaled by 1 / sqrt(2 * pi), elementwise in x.
  norm_const <- sqrt(2 * pi)
  exp(-0.5 * x^2) / norm_const
}

normalize <- function(x) {
  ## Divide x by its Euclidean norm.
  euclid_len <- sqrt(sum(x^2))
  x / euclid_len
}

CovGen_data_setting <- function(n, true_beta, link) {
  ## Simulates predictors and diagonal 3 by 3 covariance responses.
  ## n: number of observations
  ## true_beta: index vector; its length d sets the number of predictors
  ## link: function applied to the scalar projection sum(true_beta * xin[i, ])
  ## Returns: list(xin = n by d matrix with entries in [-1, 1],
  ##               Min = 3 by 3 by n array of diagonal matrices)

  d <- length(true_beta)
  rho <- 1 / 4
  Sigma <- (1 - rho) * diag(d) + matrix(rho, d, d)

  ## Equicorrelated Gaussian draws mapped to [-1, 1] through 2 * pnorm(.) - 1.
  z <- MASS::mvrnorm(n, mu = rep(0, d), Sigma = Sigma)
  ## mvrnorm() returns a plain vector when n == 1; restore the n by d shape
  ## so that row indexing below works.
  z <- matrix(z, nrow = n, ncol = d)
  xin <- 2 * pnorm(z) - 1

  q <- 3
  Min <- array(0, c(q, q, n))
  for (i in seq_len(n)) {
    ## Evaluate the link once per observation
    g <- link(sum(true_beta * xin[i, ]))
    Min[, , i] <- diag(c(exp(g), exp(g / 2), exp(-g)))
  }

  list(xin = xin, Min = Min)
}

#### Test ####
#for(rep in 1:10){
#  set.seed(rep+999)
#  dat <- CovGen_data_setting(100, b0, function(x) x)
#  res_cov <- SIdxCovReg(dat$xin, dat$Min, iter = 500, M = 10, bw = 0.25, verbose = F)
#  print(res_cov$est)
#}
#save(res_cov, file = "res_cov.RData")

# b <- c(3, -1.3, -3, 1.7)
# b0 <- normalize(b)
# b0 #0.6313342 -0.2735781 -0.6313342  0.3577560
# 
# set.seed(999)
# dat <- CovGen_data_setting(500, b0, function(x) x)
# res_cov <- SIdxCovReg(dat$xin, dat$Min)


