## This script creates the example dataset "splotdata" of the CAST package.
## It downloads splotopen data points and associated worldclim predictors for South America.
## A lower resolution predictor stack (terra rast) is also created for Chile.
## For more information, please check out the Book Chapter and Repository CAST4Ecology

library(geodata)
library(rnaturalearth)
library(terra)
library(sf)
library(tidyverse)
library(geodata)


##### Download Predictors --------------------------------
## Warning: This downloads ~ 1 GB of data
# recursive = TRUE also creates "data-raw" when it is missing;
# showWarnings = FALSE keeps re-runs quiet when the folder already exists
dir.create("data-raw/raw/", recursive = TRUE, showWarnings = FALSE)

# WorldClim bioclimatic variables and elevation: "wcf"/"elevf" hold the fine
# resolution (res = 0.5), "wc"/"elev" the coarse resolution (res = 5)
wcf <- geodata::worldclim_global(var = "bio", path = "data-raw/raw/", res = 0.5)
wc <- geodata::worldclim_global(var = "bio", path = "data-raw/raw/", res = 5)
elevf <- geodata::elevation_global(res = 0.5, path = "data-raw/raw/")
elev <- geodata::elevation_global(res = 5, path = "data-raw/raw/")

# append elevation as an additional layer of each predictor stack
wcf <- c(wcf, elevf)
wc <- c(wc, elev)

##### Download sPlotOpen -------------------------------------
# only download and unpack when the extracted folder is not there yet
if (!file.exists("data-raw/raw/splotopen")) {
  # mode = "wb": the URL does not end in ".zip", so R cannot infer on Windows
  # that the file is binary; a text-mode download would corrupt the archive
  download.file("https://idata.idiv.de/ddm/Data/DownloadZip/3474?version=5779",
                destfile = "data-raw/raw/splotopen.zip",
                mode = "wb")
  unzip(zipfile = "data-raw/raw/splotopen.zip", exdir = "data-raw/raw/splotopen")
  # the archive contains a second zip that holds the actual .RData file
  unzip(zipfile = "data-raw/raw/splotopen/sPlotOpen.RData(2).zip", exdir = "data-raw/raw/splotopen")
}



##### Clean up and save necessary files ----------------------------------
# Study region: the countries of South America (1:110m outlines)
region <- rnaturalearth::ne_countries(scale = 110,
                                      continent = "South America",
                                      returnclass = "sf")


# Predictor clean up
# layers that are kept in both predictor stacks
p <- c("bio_1", "bio_4", "bio_5", "bio_6", "bio_8", "bio_9",
       "bio_12", "bio_13", "bio_14", "bio_15", "elev")

# coarse stack: crop to the region, strip the file prefix, keep selected layers
wc <- crop(wc, region)
names(wc) <- str_remove(names(wc), pattern = "wc2.1_5m_")
wc <- wc[[p]]

# worldclim in full resolution for extracting the training data
wcf <- crop(wcf, region)
names(wcf) <- str_remove(names(wcf), pattern = "wc2.1_30s_")
wcf <- wcf[[p]]
# add the cell coordinates as layers; they are used further down to detect
# sample points that share a pixel
wcf$lat <- terra::init(wcf, "y")
wcf$lon <- terra::init(wcf, "x")


# Gather Response Variable: sPlotOpen Species Richness for South America
## see Appendix 1 of https://doi.org/10.1111/geb.13346
load("data-raw/raw/splotopen/sPlotOpen.RData")

# plot header of South America as sf points, joined with the species richness
# (an optional filter on Resample_1 == TRUE is deliberately not applied)
splot <- header.oa |>
  filter(Continent == "South America") |>
  st_as_sf(crs = 4326, coords = c("Longitude", "Latitude")) |>
  left_join(select(CWM_CWV.oa, c("PlotObservationID", "Species_richness"))) |>
  select(c("PlotObservationID", "GIVD_ID", "Country", "Biome",
           "Species_richness")) |>
  na.omit()

# extract predictor values and attach to response
splot <- na.omit(
  st_as_sf(
    terra::extract(wcf, splot, ID = FALSE, bind = TRUE)
  )
)


# only keep unique locations
## some reference sample locations are in the same predictor stack pixel
## this can lead to erroneous models and misleading validations
# lat/lon hold the pixel coordinates, so two points share a pixel when the
# (lat, lon) PAIR is identical. The pair has to be tested row-wise:
# concatenating both columns with c() would flag a row as soon as its latitude
# alone matches an earlier one and would yield an index twice as long as the data.
splotdata <- splot[!duplicated(cbind(splot$lat, splot$lon)), ]
splotdata <- na.omit(splotdata)
# the coordinate layers were only needed for the duplicate check
splotdata$lat <- NULL
splotdata$lon <- NULL


# save splotdata
# drop biome levels that do not occur in the remaining data
splotdata$Biome <- droplevels(splotdata$Biome)
save(splotdata, file = "data/splotdata.rda", compress = "xz")

## save predictors for chile
# crop the coarse predictor stack to Chile and store it as 16 bit signed integer
chile <- rnaturalearth::ne_countries(returnclass = "sf", country = "Chile")
wc <- crop(wc, chile)
writeRaster(wc,
            filename = "inst/extdata/predictors_chile.tif",
            overwrite = TRUE,
            datatype = "INT2S")


## Remove downloaded data
unlink("data-raw/raw", recursive = TRUE)


#' Area of Applicability
#' @description
#' This function estimates the Dissimilarity Index (DI) and the derived
#' Area of Applicability (AOA) of spatial prediction models by
#' considering the distance of new data (i.e. a SpatRaster of spatial predictors
#' used in the models) in the predictor variable space to the data used for model
#' training. Predictors can be weighted based on the internal
#' variable importance of the machine learning algorithm used for model training.
#' The AOA is derived by applying a threshold on the DI which is the (outlier-removed)
#' maximum DI of the cross-validated training data.
#' Optionally, the local point density is calculated which indicates the number of similar training data points up to the DI threshold.
#' @param newdata A SpatRaster, stars object or data.frame containing the data
#' the model was meant to make predictions for.
#' @param model A train object created with caret used to extract weights from (based on variable importance) as well as cross-validation folds.
#' See examples for the case that no model is available or for models trained via e.g. mlr3.
#' @param trainDI A trainDI object. Optional if \code{\link{trainDI}} was calculated beforehand.
#' @param train A data.frame containing the data used for model training. Optional. Only required when no model is given
#' @param weight A data.frame containing weights for each variable. Optional. Only required if no model is given.
#' @param variables character vector of predictor variables. if "all" then all variables
#' of the model are used or if no model is given then of the train dataset.
#' @param CVtest list or vector. Either a list where each element contains the data points used for testing during the cross validation iteration (i.e. held back data).
#' Or a vector that contains the ID of the fold for each training point.
#' Only required if no model is given.
#' @param CVtrain list. Each element contains the data points used for training during the cross validation iteration (i.e. the data that is not held back).
#' Only required if no model is given and only required if CVtrain is not the opposite of CVtest (i.e. if a data point is not used for testing, it is used for training).
#' Relevant if some data points are excluded, e.g. when using \code{\link{nndm}}.
#' @param method Character. Method used for distance calculation. Currently euclidean distance (L2) and Mahalanobis distance (MD) are implemented but only L2 is tested. Note that MD takes considerably longer.
#' @param useWeight Logical. Only if a model is given. Weight variables according to importance in the model?
#' @param useCV Logical. Only if a model is given. Use the CV folds to calculate the DI threshold?
#' @param LPD Logical. Indicates whether the local point density should be calculated or not.
#' @param maxLPD numeric or integer. Only if \code{LPD = TRUE}. Number of nearest neighbors to be considered for the calculation of the LPD. Either define a number between 0 and 1 to use a percentage of the number of training samples for the LPD calculation or a whole number larger than 1 and smaller than the number of training samples. CAUTION! If not all training samples are considered, a fitted relationship between LPD and error metric will not make sense (see \code{\link{DItoErrormetric}})
#' @param indices logical. Calculate indices of the training data points that are responsible for the LPD of a new prediction location? Output is a matrix with the dimensions num(raster_cells) x maxLPD. Each row holds the indices of the training data points that are relevant for the specific LPD value at that location. Can be used in combination with exploreAOA(aoa) function from the \href{https://github.com/fab-scm/CASTvis}{CASTvis package} for a better visual interpretation of the results. Note that the matrix can be quite big for examples with a high resolution and a larger number of training samples, which can cause memory issues.
#' @param parallel Logical. Parallelize the process. Only possible if LPD = TRUE. Can reduce computation time significantly.
#' @param cores Integer or Character. Number of cores to use for the parallelization. You can use "auto" to set your cores to \code{detectCores()/2} (see \code{\link[parallel]{detectCores}}).
#' @param verbose Logical. Print progress or not?
#' @param algorithm see \code{\link[FNN]{knnx.dist}} and \code{\link[FNN]{knnx.index}}
#' @details The Dissimilarity Index (DI), the Local Data Point Density (LPD) and the corresponding Area of Applicability (AOA) are calculated.
#' If variables are factors, dummy variables are created prior to weighting and distance calculation.
#'
#' Interpretation of results: If a location is very similar to the properties
#' of the training data it will have a low distance in the predictor variable space
#' (DI towards 0) while locations that are very different in their properties
#' will have a high DI. For easier interpretation see \code{\link{normalize_DI}}
#' See Meyer and Pebesma (2021) for the full documentation of the methodology.
#' @note If classification models are used, currently the variable importance can only
#' be automatically retrieved if models were trained via train(predictors,response) and not via the formula-interface.
#' Will be fixed.
#' @return An object of class \code{aoa} containing:
#'  \item{parameters}{object of class trainDI. see \code{\link{trainDI}}}
#'  \item{DI}{SpatRaster, stars object or data frame. Dissimilarity index of newdata}
#'  \item{LPD}{SpatRaster, stars object or data frame. Local Point Density of newdata.}
#'  \item{AOA}{SpatRaster, stars object or data frame. Area of Applicability of newdata. AOA has values 0 (outside AOA) and 1 (inside AOA)}
#'
#' @importFrom parallel detectCores makeForkCluster clusterExport parLapply stopCluster
#'
#' @author
#' Hanna Meyer, Fabian Schumacher
#' @references Meyer, H., Pebesma, E. (2021): Predicting into unknown space?
#' Estimating the area of applicability of spatial prediction models.
#' Methods in Ecology and Evolution 12: 1620-1633. \doi{10.1111/2041-210X.13650}
#'
#' Schumacher, F., Knoth, C., Ludwig, M., Meyer, H. (2024):
#' Estimation of local training data point densities to support the assessment
#' of spatial prediction uncertainty. EGUsphere. \doi{10.5194/egusphere-2024-2730}.
#'
#' @seealso \code{\link{trainDI}}, \code{\link{normalize_DI}}, \code{\link{errorProfiles}}
#' @examples
#' \dontrun{
#' library(sf)
#' library(terra)
#' library(caret)
#' library(viridis)
#'
#' # prepare sample data:
#' data(cookfarm)
#' dat <- aggregate(cookfarm[,c("VW","Easting","Northing")],
#'    by=list(as.character(cookfarm$SOURCEID)),mean)
#' pts <- st_as_sf(dat,coords=c("Easting","Northing"),crs=26911)
#' pts$ID <- 1:nrow(pts)
#' set.seed(100)
#' pts <- pts[1:30,]
#' studyArea <- rast(system.file("extdata","predictors_2012-03-25.tif",package="CAST"))[[1:8]]
#' trainDat <- extract(studyArea,pts,na.rm=FALSE)
#' trainDat <- merge(trainDat,pts,by.x="ID",by.y="ID")
#'
#' # visualize data spatially:
#' plot(studyArea)
#' plot(studyArea$DEM)
#' plot(pts[,1],add=TRUE,col="black")
#'
#' # train a model:
#' set.seed(100)
#' variables <- c("DEM","NDRE.Sd","TWI")
#' model <- train(trainDat[,which(names(trainDat)%in%variables)],
#' trainDat$VW, method="rf", importance=TRUE, tuneLength=1,
#' trControl=trainControl(method="cv",number=5,savePredictions=T))
#' print(model) #note that this is a quite poor prediction model
#' prediction <- predict(studyArea,model,na.rm=TRUE)
#' plot(varImp(model,scale=FALSE))
#'
#' #...then calculate the AOA of the trained model for the study area:
#' AOA <- aoa(studyArea, model)
#' plot(AOA)
#' plot(AOA$AOA)
#' #... or if preferred calculate the aoa and the LPD of the study area:
#' AOA <- aoa(studyArea, model, LPD = TRUE)
#' plot(AOA$LPD)
#'
#' #note that it is not required to use Random Forests. The method is model agnostic.
#' # Let's change to SVM:
#' model <- train(trainDat[,which(names(trainDat)%in%variables)],
#' trainDat$VW, method="svmRadial", importance=TRUE, tuneLength=1,
#' trControl=trainControl(method="cv",number=5,savePredictions=T))
#' AOA <- aoa(studyArea, model, LPD = TRUE)
#' plot(AOA$LPD)
#'
#' ####
#' #The AOA can also be calculated without a trained model.
#' #All variables are weighted equally in this case:
#' ####
#'
#' AOA <- aoa(studyArea,train=trainDat,variables=variables)
#'
#' ####
#' # The AOA can also be used for models trained via mlr3 (parameters have to be assigned manually):
#' ####
#'
#' library(mlr3)
#' library(mlr3learners)
#' library(mlr3spatial)
#' library(mlr3spatiotempcv)
#' library(mlr3extralearners)
#'
#' # initiate and train model:
#' train_df <- trainDat[, c("DEM","NDRE.Sd","TWI", "VW")]
#' backend <- as_data_backend(train_df)
#' task <- as_task_regr(backend, target = "VW")
#' lrn <- lrn("regr.randomForest", importance = "mse")
#' lrn$train(task)
#'
#' # cross-validation folds
#' rsmp_cv <- rsmp("cv", folds = 5L)$instantiate(task)
#'
#' ## predict:
#' prediction <- predict(studyArea,lrn$model,na.rm=TRUE)
#'
#' ### Estimate AOA
#' AOA <- aoa(studyArea,
#'            train = as.data.frame(task$data()),
#'            variables = task$feature_names,
#'            weight = data.frame(t(lrn$importance())),
#'            CVtest = rsmp_cv$instance[order(row_id)]$fold)
#'
#' }
#' @export aoa
#' @aliases aoa


aoa <- function(newdata,
                model=NA,
                trainDI = NA,
                train=NULL,
                weight=NA,
                variables="all",
                CVtest=NULL,
                CVtrain=NULL,
                method="L2",
                useWeight=TRUE,
                useCV=TRUE,
                LPD = FALSE,
                maxLPD = 1,
                indices = FALSE,
                parallel = FALSE,
                cores = 4,
                verbose = TRUE,
                algorithm = "brute") {

  # handling of different raster formats
  as_stars <- FALSE
  # remember whether names start with a digit; only used further down to give
  # a more specific error when the names do not match the training variables
  leading_digit <- any(grepl("^{1}[0-9]",names(newdata)))

  if (inherits(newdata, "stars")) {
    if (!requireNamespace("stars", quietly = TRUE))
      stop("package stars required: install that first")
    newdata <- methods::as(newdata, "SpatRaster")
    as_stars <- TRUE
  }
  if (inherits(newdata, "Raster")) {
    message("Raster will soon not longer be supported. Use terra or stars instead")
    newdata <- methods::as(newdata, "SpatRaster")
  }

  # LPD is reused as the name of the LPD output below, so keep the flag apart
  calc_LPD <- LPD
  # validate maxLPD input
  if (LPD == TRUE) {
    maxLPD_hint <- " Either define a number between 0 and 1 to use a percentage of the number of training samples for the LPD calculation or a whole number larger than 1 and smaller than the number of training samples."

    if (!is.numeric(maxLPD)) {
      stop("maxLPD must be a number.", maxLPD_hint)
    }
    if (maxLPD <= 0) {
      stop("maxLPD can not be negative or equal to 0.", maxLPD_hint)
    }

    # number of training samples: taken from the caret model, else from
    # 'train', else from a trainDI object that was computed beforehand
    n_train <- if (inherits(model, "train")) {
      length(model$trainingData[[1]])
    } else if (!is.null(train)) {
      length(train[[1]])
    } else if (inherits(trainDI, "trainDI")) {
      NROW(trainDI$train)
    } else {
      0L
    }

    # values up to 1 are a fraction of the number of training samples
    if (maxLPD <= 1) {
      if (n_train > 0) {
        maxLPD <- round(maxLPD * n_train)
      }
      if (maxLPD <= 1) {
        stop("The percentage you provided for maxLPD is too small.")
      }
    }
    # values above 1 are an absolute number of neighbours
    if (maxLPD > 1) {
      if (maxLPD %% 1 == 0) {
        maxLPD <- as.integer(maxLPD)
      } else {
        stop("If maxLPD is bigger than 1, it should be a whole number.", maxLPD_hint)
      }
    }
    if (maxLPD > n_train) {
      stop("maxLPD can not be bigger than the number of training samples.", maxLPD_hint)
    }
  }

  if (parallel && Sys.info()["sysname"] != "Linux") {
    stop("Paralellization only works for UNIX-alike systems. Please use single core computation.")
  }


  # if not provided, compute train DI
  if(!inherits(trainDI, "trainDI")) {
    if (verbose) {
      message("No trainDI provided.")
    }
    # the local variable is not a function here, so this call resolves to the
    # trainDI() function
    trainDI <- trainDI(model, train, variables, weight, CVtest, CVtrain, method, useWeight, useCV, LPD, verbose, algorithm=algorithm)
  }

  if (calc_LPD == TRUE) {
    trainDI$maxLPD <- maxLPD
  }


  # check if variables are in newdata
  if (!all(trainDI$variables %in% names(newdata))) {
    if(leading_digit){
      stop("names of newdata start with leading digits, automatically added 'X' results in mismatching names of train data in the model")
    }
    stop("names of newdata don't match names of train data in the model")
  }


  # Prepare output: a raster template if newdata is a raster, otherwise the
  # results are returned as plain vectors
  out <- NA
  if (inherits(newdata, "SpatRaster")){
    out <- newdata[[1]]
    names(out) <- "DI"
  }



  #### order data:
  if (inherits(newdata, "SpatRaster")){
    if (any(is.factor(newdata))){
      newdata[[which(is.factor(newdata))]] <- as.numeric(newdata[[which(is.factor(newdata))]])
    }
    newdata <- terra::as.data.frame(newdata,na.rm=FALSE)
  }
  newdata <- newdata[,na.omit(match(trainDI$variables, names(newdata))),drop = FALSE]


  ## Handling of categorical predictors:
  catvars <- trainDI$catvars
  if (!inherits(catvars,"error") && length(catvars)>0){
    for (catvar in catvars){
      # mask all unknown levels in newdata as NA (even technically no predictions can be made)
      trainDI$train[,catvar]<-droplevels(trainDI$train[,catvar])
      newdata[,catvar] <- factor(newdata[,catvar])
      newdata[!newdata[,catvar]%in%unique(trainDI$train[,catvar]),catvar] <- NA
      newdata[,catvar] <- droplevels(newdata[,catvar])
      # then create dummy variables for the remaining levels in train:
      dvi_train <- predict(caret::dummyVars(paste0("~",catvar), data = trainDI$train),trainDI$train)
      dvi_newdata <- predict(caret::dummyVars(paste0("~",catvar), data=trainDI$train),newdata)
      dvi_newdata[is.na(newdata[,catvar]),] <- 0
      trainDI$train <- data.frame(trainDI$train,dvi_train)
      newdata <- data.frame(newdata,dvi_newdata)

    }
    # remove the original factor columns; drop = FALSE keeps a data.frame even
    # if only a single column is left
    newdata <- newdata[, !names(newdata) %in% catvars, drop = FALSE]
    trainDI$train <- trainDI$train[, !names(trainDI$train) %in% catvars, drop = FALSE]
  }

  # Multiply every column of a matrix by its weight. The result is rebuilt as a
  # matrix explicitly because sapply() collapses a single-row input to a vector.
  weight_columns <- function(x, w) {
    cols <- lapply(seq_len(ncol(x)), function(j) {
      x[, j] * unlist(w[j])
    })
    matrix(unlist(cols, use.names = FALSE), nrow = nrow(x), ncol = ncol(x))
  }

  # scale and weight new data
  newdata <- scale(newdata,center=trainDI$scaleparam$`scaled:center`,
                   scale=trainDI$scaleparam$`scaled:scale`)

  if(!inherits(trainDI$weight, "error")){
    newdata <- weight_columns(newdata, trainDI$weight)
  }


  # rescale and reweight train data
  train_scaled <- scale(trainDI$train,
                        center = trainDI$scaleparam$`scaled:center`,
                        scale = trainDI$scaleparam$`scaled:scale`)

  train_scaled <- weight_columns(train_scaled, trainDI$weight)


  # Distance Calculation ---------
  # only rows without any NA can be processed
  okrows <- which(apply(newdata, 1, function(x)
    all(!is.na(x))))
  newdataCC <- newdata[okrows, ,drop=FALSE]

  if (method == "MD") {
    if (dim(train_scaled)[2] == 1) {
      S <- matrix(stats::var(train_scaled), 1, 1)
      newdataCC <- as.matrix(newdataCC, ncol = 1)
    } else {
      S <- stats::cov(train_scaled)
    }
    S_inv <- MASS::ginv(S)
  } else {
    S_inv <- NULL # S_inv dummy variable to not crash on parallization
  }

  if (calc_LPD == FALSE) {
    if (verbose) {
      message("Computing DI of new data...")
    }
    mindist <- rep(NA, nrow(newdata))
    mindist[okrows] <-
      .mindistfun(newdataCC, train_scaled, method, S_inv,algorithm=algorithm)
    DI_out <- mindist / trainDI$trainDist_avrgmean
  }

  if (calc_LPD == TRUE) {
    if (verbose) {
      message("Computing DI and LPD of new data...")
    }

    DI_out <- rep(NA, nrow(newdata))
    LPD_out <- rep(NA, nrow(newdata))
    if (indices) {
        Indices_out <- matrix(NA, nrow = nrow(newdataCC), ncol = maxLPD)
    }

    if (!parallel) {

      if (verbose) {
        pb <- txtProgressBar(min = 0,
                             max = nrow(newdataCC),
                             style = 3)
      }

      for (i in seq_len(nrow(newdataCC))) {
        knnDist  <- .knndistfun(t(matrix(newdataCC[i,])), train_scaled, method, S_inv, maxLPD = maxLPD, algorithm=algorithm)
        knnDI <- knnDist / trainDI$trainDist_avrgmean
        knnDI <- c(knnDI)

        # DI: distance to the nearest training point;
        # LPD: number of neighbours with a DI below the threshold
        DI_out[okrows[i]] <- knnDI[1]
        LPD_out[okrows[i]] <- sum(knnDI < trainDI$threshold)

        if (indices) {
          if (LPD_out[okrows[i]] > 0) {
            knnIndex  <- .knnindexfun(t(matrix(newdataCC[i,])), train_scaled, method, S_inv, maxLPD = LPD_out[okrows[i]],algorithm=algorithm)
            Indices_out[i,1:LPD_out[okrows[i]]] <- as.numeric(knnIndex)
          }
        }

        if (verbose) {
          setTxtProgressBar(pb, i)
        }
      }

      # end progress bar
      if (verbose) {
        close(pb)
      }
    }

    # parallelized computation using parLapply
    if (parallel) {
      message("Progress cannot be visualized for parallel computation.")

      trainDIdat <- trainDI # store trainDI in different variable to avoid environment conflict with function trainDI()

      if (cores == "auto") {
        # never fall below one worker, also if detectCores() returns NA
        cores <- max(1, floor(detectCores()/2), na.rm = TRUE)
      }



      # Create a cluster
      cl <- makeForkCluster(cores, useXDR = FALSE, methods = FALSE)
      # shut the workers down in any case, also if an error occurs below
      on.exit(stopCluster(cl), add = TRUE)

      # Export the necessary data and functions to the cluster
      clusterExport(cl, c("train_scaled",
                          "method",
                          "S_inv",
                          "trainDIdat",
                          "indices",
                          "maxLPD",
                          "algorithm",
                          ".process_row",
                          ".knndistfun",
                          ".knnindexfun"), envir = environment())

      # Split newdataCC into chunks for each core (important for large datasets)
      size_chunks <- ceiling(nrow(newdataCC) / cores)
      indices_chunks <- split(seq_len(nrow(newdataCC)), rep(1:cores, each = size_chunks, length.out = nrow(newdataCC)))
      # drop = FALSE: a chunk with a single row has to stay a matrix for apply()
      chunks <- lapply(indices_chunks, function(idx) newdataCC[idx, , drop = FALSE])

      # Apply parLapply over chunks
      results_chunks <- parLapply(cl, chunks, function(chunk) {
        apply(chunk, MARGIN = 1, .process_row)
      })

      # Combine the results from the computation of the data chunks
      results <- unlist(results_chunks, recursive = FALSE)

      # Process the results and put them in the original output variables
      for (i in seq_along(results)) {
        DI_out[okrows[i]] <- results[[i]]$DI_out_i
        LPD_out[okrows[i]] <- results[[i]]$LPD_out_i
        if (indices && results[[i]]$LPD_out_i > 0) {
          Indices_out[i,1:LPD_out[okrows[i]]] <- as.numeric(results[[i]]$Indices_out_i)
        }
      }
    }

    # set maxLPD to the largest LPD that actually occurs in the new data
    # (0 if there was no complete row at all)
    realMaxLPD <- if (all(is.na(LPD_out))) 0L else max(LPD_out, na.rm = TRUE)
    if (maxLPD > realMaxLPD) {
      if (inherits(maxLPD, c("numeric", "integer")) && verbose) {
        message("Your specified maxLPD is bigger than the real maxLPD of you predictor data.")
      }
      if (verbose) {
        message(paste("maxLPD is set to", realMaxLPD))
      }
      trainDI$maxLPD <- realMaxLPD
    }

    if (indices) {
      # drop = FALSE keeps the matrix shape for a single row or column, which
      # is required to set the row names
      Indices_out <- Indices_out[,1:trainDI$maxLPD, drop = FALSE]
      rownames(Indices_out) <- okrows
    }
  }

  if (verbose) {
    message("Computing AOA...")
  }

  #### Create Mask for AOA and return statistics
  if (inherits(out, "SpatRaster")) {
    terra::values(out) <- DI_out

    # 1 = inside the AOA, 0 = DI above the threshold; cells without DI stay NA
    AOA <- out
    terra::values(AOA) <- 1
    AOA[out > trainDI$threshold] <- 0
    AOA <- terra::mask(AOA, out)
    names(AOA) = "AOA"

    if (calc_LPD == TRUE) {
      LPD <- out
      terra::values(LPD) <- LPD_out
      names(LPD) = "LPD"
    }


    # handling of different raster formats.
    if (as_stars) {
      out <- stars::st_as_stars(out)
      AOA <- stars::st_as_stars(AOA)

      if (calc_LPD == TRUE) {
        LPD <- stars::st_as_stars(LPD)
      }
    }

  } else{
    out <- DI_out
    AOA <- rep(1, length(out))
    AOA[out > trainDI$threshold] <- 0

    if (calc_LPD == TRUE) {
      LPD <- LPD_out
    }
  }


  result <- list(
    parameters = trainDI,
    DI = out,
    AOA = AOA
  )

  if (calc_LPD == TRUE) {
    result$LPD <- LPD
    if (indices) {
      result$indices <- Indices_out
    }
  }

  if (verbose) {
    message("Finished!")
  }

  class(result) <- "aoa"
  return(result)
}


# Distances from every row of 'point' to its 'maxLPD' nearest rows of
# 'reference', sorted increasingly.
#   point, reference: numeric matrices with the same number of columns
#   method: "L2" (euclidean, via FNN::knnx.dist) or "MD" (Mahalanobis)
#   S_inv: inverse covariance matrix, only used for method "MD"
#   maxLPD: number of nearest neighbours (has to be supplied by the caller)
#   algorithm: passed on to FNN::knnx.dist, only used for method "L2"
# Returns a matrix with one row per point and 'maxLPD' columns. Any other
# 'method' yields NULL.
.knndistfun <-
  function (point,
            reference,
            method,
            S_inv = NULL,
            maxLPD = maxLPD,
            algorithm) {
    if (method == "L2") {
      # Euclidean Distance
      return(FNN::knnx.dist(reference, point, k = maxLPD, algorithm = algorithm))
    } else if (method == "MD") {
      n_ref <- dim(reference)[1]
      n_point <- dim(point)[1]
      nearest <- lapply(seq_len(n_point), function(y) {
        d <- vapply(seq_len(n_ref), function(x) {
          delta <- point[y, ] - reference[x, ]
          sqrt(as.numeric(t(delta) %*% S_inv %*% delta))
        }, numeric(1))
        # positions beyond the number of reference rows are filled with NA
        sort(d)[1:maxLPD]
      })
      # build the result row by row so that it is always n_point x maxLPD,
      # also for maxLPD == 1 (same layout as the L2 branch)
      return(matrix(unlist(nearest, use.names = FALSE),
                    nrow = n_point, ncol = maxLPD, byrow = TRUE))
    }
  }

# Row indices of the 'maxLPD' rows of 'reference' that are closest to each row
# of 'point'. Only the euclidean case ("L2") is available and is delegated to
# FNN::knnx.index; "MD" raises an error, any other 'method' yields NULL.
.knnindexfun <-
  function (point,
            reference,
            method,
            S_inv = NULL,
            maxLPD = maxLPD,
            algorithm) {
    if (method != "L2") {
      if (method == "MD") {
        stop("MD currently not implemented for LPD")
      }
      return(invisible(NULL))
    }
    # Euclidean Distance
    nn_index <- FNN::knnx.index(reference, point, k = maxLPD, algorithm = algorithm)
    return(nn_index)
  }

# Worker used by the parallel branch of aoa(): computes DI and LPD (and
# optionally the neighbour indices) for one row of the new data.
# Besides 'row' it reads train_scaled, method, S_inv, maxLPD, algorithm,
# indices and trainDIdat as free variables; aoa() exports them to the workers.
# Returns a list with DI_out_i and LPD_out_i, plus Indices_out_i if 'indices'
# is TRUE (NA when no neighbour lies below the threshold).
.process_row <- function(row) {
  knnDist <- .knndistfun(t(matrix(row)), train_scaled, method, S_inv, maxLPD = maxLPD, algorithm=algorithm)
  knnDI <- c(knnDist / trainDIdat$trainDist_avrgmean)

  # DI: distance to the nearest training point;
  # LPD: number of neighbours with a DI below the threshold
  result <- list(DI_out_i = knnDI[1],
                 LPD_out_i = sum(knnDI < trainDIdat$threshold))

  if (indices) {
    # only search for neighbours if there is at least one below the threshold,
    # as in the sequential branch of aoa(); a search with k = 0 is avoided
    result$Indices_out_i <- if (result$LPD_out_i > 0) {
      .knnindexfun(t(matrix(row)), train_scaled, method, S_inv, maxLPD = result$LPD_out_i, algorithm=algorithm)
    } else {
      NA
    }
  }

  result
}

# .process_row() reads these names as free variables; declare them so that
# R CMD check does not report them as undefined globals
utils::globalVariables(c("train_scaled", "method", "S_inv", "trainDIdat",
                         "maxLPD", "algorithm", "indices"))




#' Best subset feature selection
#' @description Evaluate all combinations of predictors during model training
#' @param predictors see \code{\link[caret]{train}}
#' @param response see \code{\link[caret]{train}}
#' @param method see \code{\link[caret]{train}}
#' @param metric see \code{\link[caret]{train}}
#' @param maximize see \code{\link[caret]{train}}
#' @param globalval Logical. Should models be evaluated based on 'global' performance? See \code{\link{global_validation}}
#' @param trControl see \code{\link[caret]{train}}
#' @param tuneLength see \code{\link[caret]{train}}
#' @param tuneGrid see \code{\link[caret]{train}}
#' @param seed A random number
#' @param verbose Logical. Should information about the progress be printed?
#' @param ... arguments passed to the classification or regression routine
#' (such as randomForest).
#' @return A list of class train. Beside of the usual train content
#' the object contains the vector "selectedvars" and "selectedvars_perf"
#' that give the best variables selected as well as their corresponding
#' performance. It also contains "perf_all" that gives the performance of all model runs.
#' @details bss is an alternative to \code{\link{ffs}} and ideal if the training
#' set is small. Models are iteratively fitted using all different combinations
#' of at least two predictor variables. Hence, 2^X - X - 1 models are calculated. Don't try running bss
#' on very large datasets because the computation time is much higher compared to
#' \code{\link{ffs}}.
#'
#' The internal cross validation can be run in parallel. See information
#' on parallel processing of carets train functions for details.
#'
#'
#' @note This variable selection is particularly suitable for spatial
#' cross validations where variable selection
#' MUST be based on the performance of the model for predicting new spatial units.
#' Note that bss is very slow since all combinations of variables are tested.
#' A more time efficient alternative is the forward feature selection (\code{\link{ffs}}).
#' @author Hanna Meyer
#' @seealso \code{\link[caret]{train}},\code{\link{ffs}},
#' \code{\link[caret]{trainControl}},\code{\link{CreateSpacetimeFolds}},
#' \code{\link{nndm}}
#' @examples
#' \dontrun{
#' data(iris)
#' bssmodel <- bss(iris[,1:4],iris$Species)
#' bssmodel$perf_all
#' plot(bssmodel)
#' }
#' @export bss
#' @aliases bss

bss <- function (predictors,
                 response,
                 method = "rf",
                 metric = ifelse(is.factor(response), "Accuracy", "RMSE"),
                 maximize = ifelse(metric == "RMSE", FALSE, TRUE),
                 globalval=FALSE,
                 trControl = caret::trainControl(),
                 tuneLength = 3,
                 tuneGrid = NULL,
                 seed = 100,
                 verbose=TRUE,
                 ...){
  # resamples and predictions of the final tuning setting are needed for the
  # standard error and for the global validation
  trControl$returnResamp <- "final"
  trControl$savePredictions <- "final"
  # a character response is treated as classification. This has to happen
  # before 'metric' is used for the first time because its default depends on
  # the response.
  if(inherits(response,"character")){
    response <- factor(response)
    if(metric=="RMSE"){
      metric <- "Accuracy"
      maximize <- TRUE
    }
  }
  # standard error of the values that are not NA
  se <- function(x){sd(x, na.rm = TRUE)/sqrt(length(na.exclude(x)))}
  n <- length(names(predictors))
  if (n < 2) {
    stop("predictors must be a data.frame with at least two columns")
  }
  # best value of the metric across the tuning settings of one model
  evalfunc <- if (maximize) {
    function(x){max(x,na.rm=TRUE)}
  } else {
    function(x){min(x,na.rm=TRUE)}
  }
  isBetter <- function (actmodelperf,bestmodelperf,maximization=maximize){
    result <- ifelse (!maximization, actmodelperf < bestmodelperf,
                      actmodelperf > bestmodelperf)
    return(result)
  }
  # one row per variable combination: column j holds j if variable j is used
  # and 0 otherwise; combinations with fewer than two variables are removed
  testgrid <- expand.grid(lapply(seq_along(names(predictors)), c, 0))
  testgrid <- testgrid[rowSums(testgrid==0) < (n-1), , drop = FALSE]
  perf_all <- data.frame(matrix(ncol=length(predictors)+3,nrow=nrow(testgrid)))
  names(perf_all) <- c(paste0("var",seq_along(predictors)),metric,"SE","nvar")
  for (i in seq_len(nrow(testgrid))){
    # same seed for every model so that the models are comparable
    set.seed(seed)
    model <- caret::train(predictors[,unlist(testgrid[i,])],
                   response,method=method,trControl=trControl,
                   tuneLength=tuneLength,
                   tuneGrid=tuneGrid,...)


    if (globalval){
      globalperf <- global_validation(model)
      perf_stats <- globalperf[names(globalperf)==metric]
    }else{
      perf_stats <- model$results[,names(model$results)==metric]
    }
    actmodelperf <- evalfunc(perf_stats)

    # standard error of the metric across the resamples
    actmodelperfSE <- se(
      vapply(unique(model$resample$Resample),
             FUN=function(x){mean(model$resample[model$resample$Resample==x,
                                                 metric],na.rm=TRUE)},
             FUN.VALUE=numeric(1)))

    # the first model is the reference, later ones have to beat the best so far
    if (i == 1 || isBetter(actmodelperf,bestmodelperf,maximization=maximize)){
      bestmodelperf <- actmodelperf
      bestmodel <- model
    }

    usedvars <- model$finalModel$xNames
    perf_all[i,seq_along(usedvars)] <- usedvars
    perf_all[i,(length(predictors)+1):ncol(perf_all)] <- c(actmodelperf,actmodelperfSE,length(usedvars))
    if (verbose){
    print(paste0("models that still need to be trained: ",
                 2^n-(n+1) - i))
    }
  }



  # performance of the selected model; NA values are ignored here in the same
  # way as during the model comparison
  if (globalval){
    globalperf <- global_validation(bestmodel)
    selectedvars_perf <- globalperf[names(globalperf)==metric]
  }else{
    if (maximize){
      selectedvars_perf <- max(bestmodel$results[,metric],na.rm=TRUE)
    }else{
      selectedvars_perf <- min(bestmodel$results[,metric],na.rm=TRUE)
    }
  }


  bestmodel$selectedvars <- bestmodel$finalModel$xNames
  bestmodel$selectedvars_perf <- selectedvars_perf
  bestmodel$perf_all <- perf_all
  bestmodel$perf_all <- bestmodel$perf_all[!apply(is.na(bestmodel$perf_all), 1, all),]
  bestmodel$perf_all <- bestmodel$perf_all[order(bestmodel$perf_all$nvar),]
  bestmodel$type <- "bss"
  class(bestmodel) <- c("ffs", "bss", "train")
  return(bestmodel)
}


#' 'caret' Applications for Spatio-Temporal models
#' @description Supporting functionality to run 'caret' with spatial or spatial-temporal data.
#' 'caret' is a frequently used package for model training and prediction using machine learning.
#' CAST includes functions to improve spatial-temporal modelling tasks using 'caret'.
#' It includes the newly suggested 'Nearest neighbor distance matching' cross-validation to estimate the performance
#' of spatial prediction models and allows for spatial variable selection to select suitable predictor variables
#' in view to their contribution to the spatial model performance.
#' CAST further includes functionality to estimate the (spatial) area of applicability of prediction models
#' by analysing the similarity between new data and training data.
#' Methods are described in Meyer et al. (2018); Meyer et al. (2019); Meyer and Pebesma (2021); Milà et al. (2022); Meyer and Pebesma (2022); Linnenbrink et al. (2023).
#' The package is described in detail in Meyer et al. (2024).
#' @name CAST
#' @title 'caret' Applications for Spatial-Temporal Models
#' @author Hanna Meyer, Carles Milà, Marvin Ludwig, Jan Linnenbrink, Fabian Schumacher
#' @references
#' \itemize{
#' \item Meyer, H., Ludwig, M., Milà, C., Linnenbrink, J., Schumacher, F. (2024): The CAST package for training and assessment of spatial prediction models in R. arXiv, https://doi.org/10.48550/arXiv.2404.06978.
#' \item Linnenbrink, J., Milà, C., Ludwig, M., and Meyer, H.: kNNDM: k-fold Nearest Neighbour Distance Matching Cross-Validation for map accuracy estimation, EGUsphere [preprint], https://doi.org/10.5194/egusphere-2023-1308, 2023.
#' \item Milà, C., Mateu, J., Pebesma, E., Meyer, H. (2022): Nearest Neighbour Distance Matching Leave-One-Out Cross-Validation for map validation. Methods in Ecology and Evolution 00, 1– 13.
#' \item Meyer, H., Pebesma, E. (2022): Machine learning-based global maps of ecological variables and the challenge of assessing them. Nature Communications. 13.
#' \item Meyer, H., Pebesma, E. (2021): Predicting into unknown space? Estimating the area of applicability of spatial prediction models. Methods in Ecology and Evolution. 12, 1620– 1633.
#' \item Meyer, H., Reudenbach, C., Wöllauer, S., Nauss, T. (2019): Importance of spatial predictor variable selection in machine learning applications - Moving from data reproduction to spatial prediction. Ecological Modelling. 411, 108815.
#' \item Meyer, H., Reudenbach, C., Hengl, T., Katurji, M., Nauß, T. (2018): Improving performance of spatio-temporal machine learning models using forward feature selection and target-oriented validation. Environmental Modelling & Software 101: 1-9.
#' }
#'
#' @import caret
#' @importFrom stats sd dist na.omit lm predict quantile na.exclude complete.cases median
#' @importFrom utils combn txtProgressBar setTxtProgressBar
#' @importFrom grDevices rainbow
#' @importFrom graphics axis plot segments
#' @keywords package
#' @aliases CAST-package
#'
"_PACKAGE"


#' Clustered samples simulation
#'
#' @description A simple procedure to simulate clustered points based on a two-step sampling.
#' @param sarea polygon. Area where samples should be simulated.
#' @param nsamples integer. Number of samples to be simulated.
#' @param nparents integer. Number of parents.
#' @param radius integer. Radius of the buffer around each parent for offspring simulation.
#'
#' @return sf object with the simulated points and the parent to which each point belongs to.
#' @details A simple procedure to simulate clustered points based on a two-step sampling.
#' First, a pre-specified number of parents are simulated using random sampling.
#' For each parent, `(nsamples-nparents)/nparents` are simulated within a radius of the parent point using random sampling.
#'
#' @examples
#' # Simulate 100 points in a 100x100 square with 5 parents and a radius of 10.
#' library(sf)
#' library(ggplot2)
#'
#' set.seed(1234)
#' simarea <- list(matrix(c(0,0,0,100,100,100,100,0,0,0), ncol=2, byrow=TRUE))
#' simarea <- sf::st_polygon(simarea)
#' simpoints <- clustered_sample(simarea, 100, 5, 10)
#' simpoints$parent <- as.factor(simpoints$parent)
#' ggplot() +
#'     geom_sf(data = simarea, alpha = 0) +
#'     geom_sf(data = simpoints, aes(col = parent))
#'
#' @author Carles Milà
#' @export
clustered_sample <- function(sarea, nsamples, nparents, radius){

  # Validate the counts up front: with nparents < 1 the offspring count below
  # is a division by zero and a loop over 1:nrow(parents) would visit c(1, 0).
  if(!is.numeric(nparents) || length(nparents) != 1 || is.na(nparents) || nparents < 1){
    stop("nparents must be a single number >= 1", call. = FALSE)
  }
  if(!is.numeric(nsamples) || length(nsamples) != 1 || is.na(nsamples) || nsamples < nparents){
    stop("nsamples must be a single number >= nparents", call. = FALSE)
  }

  # Number of offspring per parent (rounded, so the total number of returned
  # points can deviate from nsamples if it is not a multiple of nparents)
  nchildren <- round((nsamples-nparents)/nparents, 0)

  # Simulate parents
  parents <- sf::st_sf(geometry=sf::st_sample(sarea, nparents, type="random"))
  res <- parents
  res$parent <- seq_len(nrow(parents))

  # Simulate offspring
  for(i in seq_len(nrow(parents))){

    # Generate buffer and cut parts outside of the area of study
    buf <- sf::st_buffer(parents[i,], dist=radius)
    buf <- sf::st_intersection(buf, sarea)

    # Simulate children and tag them with the index of their parent
    children <- sf::st_sf(geometry=sf::st_sample(buf, nchildren, type="random"))
    children$parent <- i
    res <- rbind(res, children)
  }

  return(res)
}


#' Cookfarm soil logger data
#'
#' spatio-temporal data of soil properties and associated predictors for the Cookfarm in Washington, USA.
#' The data are a subset of the cookfarm dataset provided with the \href{https://CRAN.R-project.org/package=GSIF}{GSIF package}.
#' @format
#' A sf data.frame with 128545 rows and 17 columns:
#' \describe{
#'   \item{SOURCEID}{ID of the logger}
#'   \item{VW}{Response Variable - Soil Moisture}
#'   \item{altitude}{Measurement depth of VW}
#'   \item{Date, cdata}{Measurement Date, Cumulative Date}
#'   \item{Easting, Northing}{Location Coordinates (EPSG:26911)}
#'   \item{DEM, TWI, NDRE.M, NDRE.Sd, Precip_wrcc, MaxT_wrcc, MinT_wrcc, Precip_cum}{Predictor Variables}
#' }
#'
#' @references \itemize{
#' \item{Gasch et al. 2015 - Spatio-temporal interpolation of soil water, temperature, and electrical conductivity in 3D + T: The Cook Agronomy Farm data set \doi{10.1016/j.spasta.2015.04.001}}
#' \item{Meyer et al. 2018 - Improving performance of spatio-temporal machine learning models using forward feature selection and target-oriented validation \doi{10.1016/j.envsoft.2017.12.001}}
#' }
#' @usage data(cookfarm)
#'
"cookfarm"



#' Create Space-time Folds
#' @description Create spatial, temporal or spatio-temporal Folds for cross validation based on pre-defined groups
#' @param x data.frame containing spatio-temporal data
#' @param spacevar Character indicating which column of x identifies the
#' spatial units (e.g. ID of weather stations)
#' @param timevar Character indicating which column of x identifies the
#' temporal units (e.g. the day of the year)
#' @param k numeric. Number of folds. If spacevar or timevar is NA and a
#' leave one location out or leave one time step out cv should be performed,
#' set k to the number of unique spatial or temporal units.
#' @param class Character indicating which column of x identifies a class unit (e.g. land cover)
#' @param seed numeric. See ?set.seed
#' @return A list that contains a list for model training and a list for
#' model validation that can directly be used as "index" and "indexOut" in
#' caret's trainControl function. "cluster" gives us the information to which validation fold a sample belongs.
#' @details The function creates train and test sets by taking (spatial and/or temporal) groups into account.
#' In contrast to \code{\link{nndm}}, it requires that the groups are already defined (e.g. spatial clusters or blocks or temporal units).
#' Using "class" is helpful in the case that data are clustered in space
#' and are categorical. This is the case, e.g., for land cover classifications when
#' training data come as training polygons. In this case the data should be split in a way
#' that entire polygons are held back (spacevar="polygonID") but at the same time the distribution of classes
#' should be similar in each fold (class="LUC").
#' @note Standard k-fold cross-validation can lead to considerable misinterpretation in spatial-temporal modelling tasks.
#' This function can be used to prepare a Leave-Location-Out, Leave-Time-Out or Leave-Location-and-Time-Out cross-validation
#' as target-oriented validation strategies for spatial-temporal prediction tasks.
#' See Meyer et al. (2018) for further information. CreateSpaceTimeFolds is just a very simple approach and the suitability depends on the choice of the groups.
#' You may check the suitability with \code{\link{geodist}}. Consider \code{\link{nndm}} or \code{\link{knndm}} as alternatives or other approaches such as Spatial Blocks.
#' For spatial visualization of fold affiliation see examples.
#' @author Hanna Meyer
#' @seealso \code{\link[caret]{trainControl}},\code{\link{ffs}}, \code{\link{nndm}}, \code{\link{geodist}}
#' @references
#' Meyer, H., Reudenbach, C., Hengl, T., Katurji, M., Nauß, T. (2018): Improving performance of spatio-temporal machine learning models using forward feature selection and target-oriented validation. Environmental Modelling & Software 101: 1-9.
#' @examples
#' \dontrun{
#' data(cookfarm)
#' ### Prepare for 10-fold Leave-Location-and-Time-Out cross validation
#' indices <- CreateSpacetimeFolds(cookfarm,"SOURCEID","Date")
#' str(indices)
#' ### Prepare for 10-fold Leave-Location-Out cross validation
#' indices <- CreateSpacetimeFolds(cookfarm,spacevar="SOURCEID")
#' str(indices)
#' ### Prepare for leave-One-Location-Out cross validation
#' indices <- CreateSpacetimeFolds(cookfarm,spacevar="SOURCEID",
#'     k=length(unique(cookfarm$SOURCEID)))
#' str(indices)
#'
#' ### example from splotopen and visualization
#' data(splotdata)
#' indices <- CreateSpacetimeFolds(splotdata,spacevar="Country")
#' ggplot() +
#' geom_sf(data = splotdata, aes(col = factor(indices$cluster)))
#' ## is this representative?
#' data(splotdata)
#' studyArea <- rnaturalearth::ne_countries(continent = "South America", returnclass = "sf")
#' dist <- geodist(splotdata, studyArea,cvfolds=indices$cluster)
#' plot(dist)+ scale_x_log10(labels=round)
#'
#' }
#' @export CreateSpacetimeFolds
#' @aliases CreateSpacetimeFolds

CreateSpacetimeFolds <- function(x,spacevar=NA,timevar=NA,
                                 k=10,class=NA,seed=sample(1:1000, 1)){
  x <- data.frame(x)

  # Without any grouping variable no fold can be defined. Fail early with a
  # clear message instead of an obscure error while summarizing the folds.
  if(is.na(spacevar) && is.na(timevar)){
    stop("at least one of spacevar or timevar must be specified", call. = FALSE)
  }

  ### if classification is used, make sure that classes are equally distributed across folds
  if(!is.na(class)){
    # the stratification is done on unique (spacevar, class) units
    if(is.na(spacevar)){
      stop("argument class requires spacevar to be specified", call. = FALSE)
    }
    if(is.numeric(x[,class])){
      stop("argument class only works for categorical data")
    }
    unit <- unique(x[,c(spacevar,class)])
    # seed the stratified assignment too, so that `seed` makes the whole
    # result reproducible and not only the fold split further below
    set.seed(seed)
    unit$CAST_fold <- caret::createFolds(unit[,which(names(unit)==class)],k = k,list=FALSE)
    x <- plyr::join(x,unit,by=c(spacevar,class),match="first")
    # from here on the stratified fold id acts as the spatial unit
    spacevar <- "CAST_fold"
  }

  # k cannot exceed the number of available spatial / temporal units
  if(!is.na(spacevar)){
    if(k>length(unique(x[,spacevar]))){
      k <- length(unique(x[,spacevar]))
      print(paste0("warning: k is higher than number of unique locations. k is set to ",k))
    }
  }
  if(!is.na(timevar)){
    if(k>length(unique(x[,timevar]))){
      k <- length(unique(x[,timevar]))
      print(paste0("warning: k is higher than number of unique points in time. k is set to ",k))
    }
  }

  #split space into k folds
  if(!is.na(spacevar)){
    set.seed(seed)
    space_units <- unique(x[,spacevar])
    spacefolds <- lapply(caret::createFolds(seq_along(space_units),k),function(y){
      space_units[y]})
  }
  #split time into k folds
  if(!is.na(timevar)){
    set.seed(seed)
    time_units <- unique(x[,timevar])
    timefolds <- lapply(caret::createFolds(seq_along(time_units),k),function(y){
      time_units[y]})
  }

  # combine space and time folds
  cvindices_train <- vector("list", k)
  cvindices_test <- vector("list", k)
  for (i in seq_len(k)){
    if(!is.na(timevar) && !is.na(spacevar)){
      # test: held-out locations at held-out times;
      # train: neither the held-out locations nor the held-out times
      in_space <- x[,spacevar]%in%spacefolds[[i]]
      in_time <- x[,timevar]%in%timefolds[[i]]
      cvindices_test[[i]] <- which(in_space & in_time)
      cvindices_train[[i]] <- which(!in_space & !in_time)
    } else if(!is.na(spacevar)){
      in_space <- x[,spacevar]%in%spacefolds[[i]]
      cvindices_test[[i]] <- which(in_space)
      cvindices_train[[i]] <- which(!in_space)
    } else {
      in_time <- x[,timevar]%in%timefolds[[i]]
      cvindices_test[[i]] <- which(in_time)
      cvindices_train[[i]] <- which(!in_time)
    }
  }

  ## summarize folds:
  result <- list("index"=cvindices_train,"indexOut"=cvindices_test)
  # validation fold of every row of x (NA if a row is never held out).
  # Direct indexing tolerates empty test folds and does not depend on the
  # column names of x.
  cluster <- rep(NA_integer_, nrow(x))
  for (i in seq_along(cvindices_test)){
    cluster[cvindices_test[[i]]] <- i
  }
  result$cluster <- cluster
  return(result)
}


#' Model and inspect the relationship between the prediction error and measures of dissimilarities and distances
#' @description Performance metrics are calculated for moving windows of dissimilarity values based on cross-validated training data
#' @param model the model used to get the AOA
#' @param trainDI the result of \code{\link{trainDI}} or aoa object \code{\link{aoa}}
#' @param locations Optional. sf object for the training data used in model. Only used if variable=="geodist". Note that they must be in the same order as model$trainingData.
#' @param variable Character. Which dissimilarity or distance measure to use for the error metric. Current options are "DI", "LPD" or "geodist"
#' @param multiCV Logical. Re-run model fitting and validation with different CV strategies. See details.
#' @param window.size Numeric. Size of the moving window. See \code{\link[zoo]{rollapply}}.
#' @param calib Character. Function to model the DI/LPD~performance relationship. Currently "lm" and "scam" are supported, as well as "exp" for variable="LPD"
#' @param length.out Numeric. Only used if multiCV=TRUE. Number of cross-validation folds. See details.
#' @param method Character. Method used for distance calculation. Currently euclidean distance (L2) and Mahalanobis distance (MD) are implemented but only L2 is tested. Note that MD takes considerably longer. See ?aoa for further explanation
#' @param useWeight Logical. Only if a model is given. Weight variables according to importance in the model?
#' @param k Numeric. See mgcv::s
#' @param m Numeric. See mgcv::s
#' @details If multiCV=TRUE the model is re-fitted and validated by length.out new cross-validations where the cross-validation folds are defined by clusters in the predictor space,
#' ranging from three clusters to LOOCV. Hence, a large range of dissimilarity values is created during cross-validation.
#' If the AOA threshold based on the calibration data from multiple CV is larger than the original AOA threshold (which is likely if extrapolation situations are created during CV),
#' the AOA threshold changes accordingly. See Meyer and Pebesma (2021) for the full documentation of the methodology.
#' @return A scam, linear model or exponential model
#' @author
#' Hanna Meyer, Marvin Ludwig, Fabian Schumacher
#' @references Meyer, H., Pebesma, E. (2021): Predicting into unknown space?
#' Estimating the area of applicability of spatial prediction models.
#' \doi{10.1111/2041-210X.13650}
#' @seealso \code{\link{aoa}}
#' @examples
#' \dontrun{
#' library(CAST)
#' library(sf)
#' library(terra)
#' library(caret)
#'
#' data(splotdata)
#' predictors <- terra::rast(system.file("extdata","predictors_chile.tif", package="CAST"))
#'
#' model <- caret::train(st_drop_geometry(splotdata)[,6:16], splotdata$Species_richness,
#'    ntree = 10, trControl = trainControl(method = "cv", savePredictions = TRUE))
#'
#' AOA <- aoa(predictors, model, LPD = TRUE, maxLPD = 1)
#'
#' ### DI ~ error
#' errormodel_DI <- errorProfiles(model, AOA, variable = "DI")
#' plot(errormodel_DI)
#' summary(errormodel_DI)
#'
#' expected_error_DI = terra::predict(AOA$DI, errormodel_DI)
#' plot(expected_error_DI)
#'
#' ### LPD ~ error
#' errormodel_LPD <- errorProfiles(model, AOA, variable = "LPD")
#' plot(errormodel_LPD)
#' summary(errormodel_LPD)
#'
#' expected_error_LPD = terra::predict(AOA$LPD, errormodel_LPD)
#' plot(expected_error_LPD)
#'
#' ### geodist ~ error
#' errormodel_geodist = errorProfiles(model, locations=splotdata, variable = "geodist")
#' plot(errormodel_geodist)
#' summary(errormodel_geodist)
#'
#' dist <- terra::distance(predictors[[1]],vect(splotdata))
#' names(dist) <- "geodist"
#' expected_error_DI <- terra::predict(dist, errormodel_geodist)
#' plot(expected_error_DI)
#'
#'
#' ### with multiCV = TRUE (for DI ~ error)
#' errormodel_DI = errorProfiles(model, AOA, multiCV = TRUE, length.out = 3, variable = "DI")
#' plot(errormodel_DI)
#'
#' expected_error_DI = terra::predict(AOA$DI, errormodel_DI)
#' plot(expected_error_DI)
#'
#' # mask AOA based on new threshold from multiCV
#' mask_aoa = terra::mask(expected_error_DI, AOA$DI > attr(errormodel_DI, 'AOA_threshold'),
#'   maskvalues = 1)
#' plot(mask_aoa)
#' }
#'
#'
#' @export errorProfiles
#' @aliases errorProfiles DItoErrormetric



errorProfiles <- function(model,
                          trainDI=NULL,
                          locations=NULL,
                          variable = "DI",
                          multiCV=FALSE,
                          length.out = 10,
                          window.size = 5,
                          calib = "scam",
                          method= "L2",
                          useWeight=TRUE,
                          k = 6,
                          m = 2){

  # Validate the enumerated arguments early; otherwise unsupported values only
  # surface as obscure errors deep inside the helper functions.
  if(length(variable) != 1 || !variable %in% c("DI","LPD","geodist")){
    stop("variable must be one of \"DI\", \"LPD\" or \"geodist\"", call. = FALSE)
  }
  if(length(calib) != 1 || !calib %in% c("lm","scam","exp")){
    stop("calib must be one of \"lm\", \"scam\" or \"exp\"", call. = FALSE)
  }
  if(variable=="geodist" && is.null(locations)){
    stop("locations must be provided if variable is \"geodist\"", call. = FALSE)
  }

  # an aoa object carries the trainDI object in its "parameters" element
  if(inherits(trainDI,"aoa")){
    trainDI <- trainDI$parameters
  }

  # Without multiCV the DI/LPD values are taken from trainDI, so it is required.
  if(!multiCV && variable %in% c("DI","LPD") && is.null(trainDI)){
    stop("trainDI (or an aoa object) must be provided if multiCV is FALSE",
         call. = FALSE)
  }

  if(!is.null(locations) && variable=="geodist"){
    message("warning: Please ensure that the order of the locations matches to model$trainingData")
  }


  # get DIs and Errormetrics OR calculate new ones from multiCV
  # (the logical argument `multiCV` does not mask the helper function of the
  # same name: R skips non-function bindings when resolving a call)
  if(!multiCV){
    preds_all <- get_preds_all(model, trainDI, locations, variable)
  } else {
    preds_all <- multiCV(model, locations, length.out, method, useWeight, variable)
  }

  # train model between DI and Errormetric
  error_model <- errorModel(preds_all, model, window.size, calib,  k, m, variable)

  # save AOA threshold and raw data
  attr(error_model, "AOA_threshold") <- attr(preds_all, "AOA_threshold")
  attr(error_model, "variable") <- attr(preds_all, "variable")
  attr(error_model, "metric") <- attr(preds_all, "metric")
  class(error_model) <- c("errorModel", class(error_model))
  return(error_model)
}






# Model expected error between Metric and DI/LPD/geodist
#
# preds_all:   data.frame whose first two columns are the cross-validated
#              predictions and observations and which has a column named
#              like `variable`
# model:       list-like model object; `model$metric` selects the performance
#              metric and `model$maximize` the direction of the scam smooth
# window.size: width of the moving window passed to zoo::rollapply
# calib:       "lm", "scam" or "exp" ("exp" only for variable == "LPD")
# k, m:        passed to the smooth term s() of the scam model
# variable:    "DI", "LPD" or "geodist"
#
# Returns the fitted model with the moving-window data attached as attribute
# "performance".
errorModel <- function(preds_all, model, window.size, calib, k, m, variable){

  # fail early on unsupported options; otherwise no model is fitted and the
  # function ends with "object 'errormodel' not found"
  if(length(calib) != 1 || !calib %in% c("lm","scam","exp")){
    stop("calib must be one of \"lm\", \"scam\" or \"exp\"", call. = FALSE)
  }
  if(length(variable) != 1 || !variable %in% c("DI","LPD","geodist")){
    stop("variable must be one of \"DI\", \"LPD\" or \"geodist\"", call. = FALSE)
  }

  # shared implementation for the two confusion-matrix based metrics
  class_stat <- function(pred, obs, stat){
    pred <- factor(pred)
    obs <- factor(obs)
    # both factors need the same levels for the confusion matrix
    lev <- unique(c(levels(pred), levels(obs)))
    pred <- factor(pred, levels = lev)
    obs <- factor(obs, levels = lev)
    result <- tryCatch(caret::confusionMatrix(pred, obs)$overall[stat],
                       error = function(e)e)
    # TODO: 0 is only a placeholder when the statistic cannot be computed
    # (flagged as "not right value" by the original author; kept unchanged)
    if(inherits(result, "error")){result <- 0}
    unname(result)
  }

  ## use performance metric from the model:
  metric_funs <- list(
    rmse = function(pred,obs){sqrt( mean((pred - obs)^2, na.rm = TRUE) )},
    rsquared = function(pred,obs){summary(lm(pred~obs))$r.squared},
    mae = function(pred,obs){caret::MAE(pred,obs)},
    kappa = function(pred,obs){class_stat(pred, obs, "Kappa")},
    accuracy = function(pred,obs){class_stat(pred, obs, "Accuracy")}
  )

  metric_name <- tolower(model$metric)
  if(length(metric_name) != 1 || !metric_name %in% names(metric_funs)){
    stop("Model metric not yet included in this function")
  }
  # plain lookup instead of eval(parse(...))
  evalfunc <- metric_funs[[metric_name]]


  # order data according to DI/LPD:
  performance <- preds_all[order(preds_all[,variable]),]
  # calculate performance for moving window (windows that are incomplete at
  # the edges are filled with NA and removed afterwards):
  performance$metric <- zoo::rollapply(performance[,1:2], window.size,
                                       FUN=function(x){evalfunc(x[,1],x[,2])},
                                       by.column=FALSE,align = "center",fill=NA)
  performance <- performance[!is.na(performance$metric),]

  performance <-  performance[,c(variable,"metric")]

  ### Estimate Error:
  if(calib=="lm"){
    errormodel <- lm(metric ~ ., data = performance)
  } else if(calib=="scam"){
    if (!requireNamespace("scam", quietly = TRUE)) {
      stop("Package \"scam\" needed for this function to work. Please install it.",
           call. = FALSE)
    }
    # shape constraint of the smooth: "mpd" = monotone decreasing,
    # "mpi" = monotone increasing. The direction is reversed for LPD
    # compared to DI/geodist.
    if (variable %in% c("DI","geodist")) {
      bs <- if (model$maximize) "mpd" else "mpi" # maximize: e.g. accuracy, kappa, r2
    } else {
      bs <- if (model$maximize) "mpi" else "mpd"
    }
    if(variable=="DI"){
      errormodel <- scam::scam(metric~s(DI, k=k, bs=bs, m=m),
                               data=performance,
                               family=stats::gaussian(link="identity"))
    } else if (variable=="geodist"){
      errormodel <- scam::scam(metric~s(geodist, k=k, bs=bs, m=m),
                               data=performance,
                               family=stats::gaussian(link="identity"))
    } else {
      errormodel <- scam::scam(metric~s(LPD, k=k, bs=bs, m=m),
                               data=performance,
                               family=stats::gaussian(link="identity"))
    }
  } else {
    # calib == "exp"
    if (variable %in% c("DI","geodist")) {
      stop("Exponential model currently only implemented for LPD")
    }
    errormodel <- lm(metric ~ log(LPD), data = performance)
  }

  attr(errormodel, "performance") <- performance

  return(errormodel)
}


# MultiCV
#
# Re-fits `model` with a series of cross-validations whose folds are clusters
# (k-means) in the predictor space, and collects the cross-validated
# predictions together with the DI, LPD or geodist values of each run.
#
# model:      caret train object (uses $trainingData, $call, $bestTune,
#             $method and $metric)
# locations:  training locations; only used if variable == "geodist"
# length.out: number of cluster counts tried between 3 and nrow(training data)
# method, useWeight: passed on to trainDI
# variable:   "DI", "LPD" or "geodist"
#
# Returns a data.frame with columns pred, obs and `variable`; for DI/LPD the
# threshold of the last run is attached as attribute "AOA_threshold".
multiCV <- function(model, locations, length.out, method, useWeight, variable,...){

  preds_all <- data.frame()
  # drop = FALSE keeps a data.frame if the model has a single predictor
  train_predictors <- model$trainingData[,names(model$trainingData)!=".outcome",drop=FALSE]
  train_response <- model$trainingData$.outcome
  trainDI_new <- NULL

  for (nclst in round(seq(3,nrow(train_predictors), length.out = length.out))){
    # define clusters in predictor space used for CV; cluster counts for which
    # kmeans fails are skipped:
    clstrID <- tryCatch({stats::kmeans(train_predictors,nclst)$cluster},
                        error=function(e)e)
    if(inherits(clstrID,"error")){next}
    folds <- CreateSpacetimeFolds(data.frame("clstrID"=clstrID), spacevar="clstrID",k=nclst)

    # update model call with new CV strategy: drop the data arguments, the
    # function name and positional arguments (empty names) and the old control
    mcall <- as.list(model$call)
    mcall <- mcall[!(names(mcall)%in%c("form","data","","x","y","trControl"))]
    # symbols are resolved in this function's frame when do.call evaluates them
    mcall$x <- quote(train_predictors)
    mcall$y <- quote(train_response)
    mcall$trControl <- caret::trainControl(method="cv",index=folds$index,savePredictions = TRUE)
    mcall$tuneGrid <- model$bestTune
    mcall$method <- model$method
    mcall$metric <- model$metric
    mcall$cl <- NULL # fix option for parallel later

    # retrain model and calculate AOA
    model_new <- do.call(caret::train,mcall)
    if (variable == "DI") {
      trainDI_new <- trainDI(model_new, method=method, useWeight=useWeight, verbose =FALSE)
    } else if (variable == "LPD") {
      trainDI_new <- trainDI(model_new, method=method, useWeight=useWeight, LPD = TRUE, verbose =FALSE)
    } else if (variable=="geodist"){
      # The distances must refer to the folds of THIS cross-validation (the
      # ones model_new was validated with), not to those of the original model.
      tmp_gd_new <- CAST::geodist(locations,modeldomain=locations,cvfolds = folds$indexOut)
      geodist_new <- tmp_gd_new[tmp_gd_new$what=="CV-distances","dist"]
    }


    preds <- model_new$pred
    preds <- preds[order(preds$rowIndex),c("pred","obs")]
    # get cross-validated predictions, order them  and use only those located in the AOA
    if (variable == "DI"){
      preds_dat_tmp <- data.frame(preds,"DI"=trainDI_new$trainDI)
      preds_dat_tmp <-  preds_dat_tmp[preds_dat_tmp$DI <= trainDI_new$threshold,]
    } else if (variable == "LPD"){
      preds_dat_tmp <- data.frame(preds,"LPD"=trainDI_new$trainLPD)
      preds_dat_tmp <-  preds_dat_tmp[preds_dat_tmp$LPD > 0,]
    } else if (variable == "geodist"){
      # NO AOA used here
      preds_dat_tmp <- data.frame(preds,"geodist"=geodist_new)
    }
    preds_all <- rbind(preds_all,preds_dat_tmp)
  }

  # every clustering attempt failed or no prediction was kept
  if(nrow(preds_all)==0){
    stop("multiCV did not yield any cross-validated predictions", call. = FALSE)
  }

  if(variable=="DI"||variable=="LPD"){
    # the threshold of the last cross-validation run is reported
    attr(preds_all, "AOA_threshold") <- trainDI_new$threshold
    message(paste0("Note: multiCV=TRUE calculated new AOA threshold of ", round(trainDI_new$threshold, 5),
                   "\nThreshold is stored in the attributes, access with attr(error_model, 'AOA_threshold').",
                   "\nPlease refere to examples and details for further information."))
  }
  attr(preds_all, "variable") <- variable
  attr(preds_all, "metric") <- model$metric
  return(preds_all)
}


# Get Preds all
#
# Collects the cross-validated predictions of the final tuning setting of
# `model` and adds the requested dissimilarity / distance measure.
#
# model:     list-like model object; uses $pred (columns pred, obs, rowIndex
#            and one column per tuning parameter), $bestTune, $metric and,
#            for geodist, $control$indexOut
# trainDI:   object providing $trainDI, $trainLPD and $threshold
#            (required for variable "DI" and "LPD")
# locations: training locations (required for variable "geodist")
# variable:  "DI", "LPD" or "geodist"
#
# Returns a data.frame with columns pred, obs and `variable`, restricted to
# DI <= threshold (DI) or LPD > 0 (LPD), with attributes "AOA_threshold",
# "variable" and "metric".
get_preds_all <- function(model, trainDI, locations, variable){

  if(is.null(model$pred)){
    stop("no cross-predictions can be retrieved from the model. Train with savePredictions=TRUE or provide calibration data")
  }
  # Missing inputs would otherwise silently produce an empty result (DI/LPD)
  # or fail inside geodist.
  if(variable %in% c("DI","LPD") && is.null(trainDI)){
    stop("trainDI is required to retrieve DI or LPD values", call. = FALSE)
  }
  if(variable=="LPD" && is.null(trainDI[["trainLPD"]])){
    stop("trainDI does not contain LPD values; calculate it with LPD = TRUE",
         call. = FALSE)
  }
  if(variable=="geodist" && is.null(locations)){
    stop("locations must be provided if variable is \"geodist\"", call. = FALSE)
  }

  ## extract cv predictions from model: keep only the rows that belong to the
  ## final tuning parameter setting
  preds_all <- model$pred
  for (tunevar in names(model$bestTune)){
    preds_all <- preds_all[preds_all[,tunevar]==model$bestTune[[tunevar]],]
  }
  # bring predictions into the order of the training data
  preds_all <- preds_all[order(preds_all$rowIndex),c("pred","obs")]


  if (variable == "DI") {
    ## add DI from trainDI
    preds_all$DI <- trainDI$trainDI[!is.na(trainDI$trainDI)]
    ## only take predictions from inside the AOA:
    preds_all <-  preds_all[preds_all$DI<=trainDI$threshold,]
  } else if (variable == "LPD") {
    ## add LPD from trainLPD
    preds_all$LPD <- trainDI$trainLPD[!is.na(trainDI$trainLPD)]
    ## only take predictions from inside the AOA:
    preds_all <-  preds_all[preds_all$LPD>0,]
  } else if(variable=="geodist"){
    tmp_gd <- CAST::geodist(locations,modeldomain=locations,cvfolds = model$control$indexOut)
    preds_all$geodist <- tmp_gd[tmp_gd$what=="CV-distances","dist"]
  }

  # NULL (i.e. no attribute) if no trainDI was given
  attr(preds_all, "AOA_threshold") <- trainDI$threshold
  attr(preds_all, "variable") <- variable
  attr(preds_all, "metric") <- model$metric
  return(preds_all)
}



#' Forward feature selection
#' @description A simple forward feature selection algorithm
#' @param predictors see \code{\link[caret]{train}}
#' @param response see \code{\link[caret]{train}}
#' @param method see \code{\link[caret]{train}}
#' @param metric see \code{\link[caret]{train}}
#' @param maximize see \code{\link[caret]{train}}
#' @param globalval Logical. Should models be evaluated based on 'global' performance? See \code{\link{global_validation}}
#' @param withinSE Logical. Models are only selected if they are better than the
#' currently best model's standard error
#' @param minVar Numeric. Number of variables to combine for the first selection.
#' See Details.
#' @param trControl see \code{\link[caret]{train}}
#' @param tuneLength see \code{\link[caret]{train}}
#' @param tuneGrid see \code{\link[caret]{train}}
#' @param seed A random number used for model training
#' @param cores Numeric. If > 2, mclapply will be used. see \code{\link{mclapply}}
#' @param verbose Logical. Should information about the progress be printed?
#' @param ... arguments passed to the classification or regression routine
#' (such as randomForest).
#' @return A list of class train. Beside of the usual train content
#' the object contains the vector "selectedvars" and "selectedvars_perf"
#' that give the order of the best variables selected as well as their corresponding
#' performance (starting from the first two variables). It also contains "perf_all"
#' that gives the performance of all model runs.
#' @details Models with two predictors are first trained using all possible
#' pairs of predictor variables. The best model of these initial models is kept.
#' On the basis of this best model the predictor variables are iteratively
#' increased and each of the remaining variables is tested for its improvement
#' of the currently best model. The process stops if none of the remaining
#' variables increases the model performance when added to the current best model.
#'
#' The forward feature selection can be run in parallel with forking on Linux systems (mclapply).
#' Each fork computes a model, which drastically speeds up the runtime -
#' especially of the initial predictor search.
#' The internal cross validation can be run in parallel on all systems. See information
#' on parallel processing of carets train functions for details.
#'
#' Using withinSE will favour models with fewer variables and
#' probably shorten the calculation time
#'
#' Per Default, the ffs starts with all possible 2-pair combinations.
#' minVar allows to start the selection with more than 2 variables, e.g.
#' minVar=3 starts the ffs testing all combinations of 3 (instead of 2) variables
#' first and then increasing the number. This is important for e.g. neural networks
#' that often cannot make sense of only two variables. It is also relevant if
#' it is assumed that the optimal variables can only be found if more than 2
#' are considered at the same time.
#'
#' @note This variable selection is particularly suitable for spatial
#' cross validations where variable selection
#' MUST be based on the performance of the model for predicting new spatial units.
#' See Meyer et al. (2018) and Meyer et al. (2019) for further details.
#'
#' @author Hanna Meyer
#' @seealso \code{\link[caret]{train}},\code{\link{bss}},
#' \code{\link[caret]{trainControl}},\code{\link{CreateSpacetimeFolds}},\code{\link{nndm}}
#' @references
#' \itemize{
#' \item Gasch, C.K., Hengl, T., Gräler, B., Meyer, H., Magney, T., Brown, D.J. (2015): Spatio-temporal interpolation of soil water, temperature, and electrical conductivity in 3D+T: the Cook Agronomy Farm data set. Spatial Statistics 14: 70-90.
#' \item Meyer, H., Reudenbach, C., Hengl, T., Katurji, M., Nauß, T. (2018): Improving performance of spatio-temporal machine learning models using forward feature selection and target-oriented validation. Environmental Modelling & Software 101: 1-9.  \doi{10.1016/j.envsoft.2017.12.001}
#' \item Meyer, H., Reudenbach, C., Wöllauer, S., Nauss, T. (2019): Importance of spatial predictor variable selection in machine learning applications - Moving from data reproduction to spatial prediction. Ecological Modelling. 411, 108815. \doi{10.1016/j.ecolmodel.2019.108815}.
#' \item Ludwig, M., Moreno-Martinez, A., Hölzel, N., Pebesma, E., Meyer, H. (2023): Assessing and improving the transferability of current global spatial prediction models. Global Ecology and Biogeography. \doi{10.1111/geb.13635}.
#' }
#' @examples
#' \dontrun{
#' data(splotdata)
#' ffsmodel <- ffs(splotdata[,6:12], splotdata$Species_richness, ntree = 20)
#'
#' ffsmodel$selectedvars
#' ffsmodel$selectedvars_perf
#' plot(ffsmodel)
#' #or only selected variables:
#' plot(ffsmodel,plotType="selected")
#'}
#'
#' # or perform model with target-oriented validation (LLO CV)
#' #the example is described in Gasch et al. (2015). The ffs approach for this dataset is described in
#' #Meyer et al. (2018). Due to high computation time needed, only a small and thus not robust example
#' #is shown here.
#'
#' \dontrun{
#' # run the model on three cores (see vignette for details):
#' library(doParallel)
#' library(lubridate)
#' cl <- makeCluster(3)
#' registerDoParallel(cl)
#'
#' #load and prepare dataset:
#' data(cookfarm)
#' trainDat <- cookfarm[cookfarm$altitude==-0.3&
#'   year(cookfarm$Date)==2012&week(cookfarm$Date)%in%c(13:14),]
#'
#' #visualize dataset:
#' ggplot(data = trainDat, aes(x=Date, y=VW)) + geom_line(aes(colour=SOURCEID))
#'
#' #create folds for Leave Location Out Cross Validation:
#' set.seed(10)
#' indices <- CreateSpacetimeFolds(trainDat,spacevar = "SOURCEID",k=3)
#' ctrl <- trainControl(method="cv",index = indices$index)
#'
#' #define potential predictors:
#' predictors <- c("DEM","TWI","BLD","Precip_cum","cday","MaxT_wrcc",
#' "Precip_wrcc","NDRE.M","Bt","MinT_wrcc","Northing","Easting")
#'
#' #run ffs model with Leave Location out CV
#' set.seed(10)
#' ffsmodel <- ffs(trainDat[,predictors],trainDat$VW,method="rf",
#' tuneLength=1,trControl=ctrl)
#' ffsmodel
#' plot(ffsmodel)
#' #or only selected variables:
#' plot(ffsmodel,plotType="selected")
#'
#' #compare to model without ffs:
#' model <- train(trainDat[,predictors],trainDat$VW,method="rf",
#' tuneLength=1, trControl=ctrl)
#' model
#' stopCluster(cl)
#'}
#'
#'\dontrun{
#'## on linux machines, you can also run the ffs in parallel with forks:
#' data("splotdata")
#' spatial_cv = CreateSpacetimeFolds(splotdata, spacevar = "Biome", k = 5)
#' ctrl <- trainControl(method="cv",index = spatial_cv$index)
#'
#'ffsmodel <- ffs(predictors = splotdata[,6:16],
#'                response = splotdata$Species_richness,
#'                tuneLength = 1,
#'                method = "rf",
#'                trControl = ctrl,
#'                ntree = 20,
#'                seed = 1,
#'                cores = 4)
#'}
#'
#'
#' @export ffs
#' @aliases ffs

ffs <- function (predictors,
                 response,
                 method = "rf",
                 metric = ifelse(is.factor(response), "Accuracy", "RMSE"),
                 maximize = ifelse(metric == "RMSE", FALSE, TRUE),
                 globalval=FALSE,
                 withinSE = FALSE,
                 minVar = 2,
                 trControl = caret::trainControl(),
                 tuneLength = 3,
                 tuneGrid = NULL,
                 seed = sample(1:1000, 1),
                 verbose=TRUE,
                 cores = 1,
                 ...){

  # Implementation overview:
  #  Step 1: train one model for every combination of `minVar` predictors and
  #          keep the best one.
  #  Step 2: repeatedly try to append each remaining predictor to the current
  #          best set. The best candidate of a round is kept if it improves the
  #          performance (optionally by more than one SE); otherwise the search
  #          stops.
  # With cores > 1 the candidate models of a round are fitted through
  # parallel::mclapply(). Workers only return performance values, so the
  # winning model of each round is refitted in the main process.

  # Init ----------------

  # The default for `seed` is a lazily evaluated sample() call. Force it once
  # here so that forked workers do not evaluate the promise on their own and
  # all models are trained with the same seed.
  force(seed)

  ## Input Checks ------------------------------

  if(inherits(predictors, "sf")){
    predictors <- sf::st_drop_geometry(predictors)
  }

  if(cores > 1 && .Platform$OS.type != "unix"){
    warning("Parallel computations of ffs only implemented on unix systems. cores is set to 1")
    cores <- 1
  }


  # character responses are treated as classification tasks
  if(inherits(response,"character")){
    response <- factor(response)
    if(metric=="RMSE"){
      metric <- "Accuracy"
      maximize <- TRUE
    }
  }

  if (trControl$method=="LOOCV" && withinSE){
    warning("withinSE is set to FALSE as no SE can be calculated using method LOOCV")
    withinSE <- FALSE
  }

  if(globalval && withinSE){
    warning("withinSE is set to FALSE as no SE can be calculated using global validation")
    withinSE <- FALSE
  }

  ## Define helper functions ---------------

  # standard error, ignoring NA
  se <- function(x){sd(x, na.rm = TRUE)/sqrt(length(na.exclude(x)))}

  # picks the best value out of the tuning results of a single model
  evalfunc <- if (maximize) {
    function(x){max(x,na.rm=TRUE)}
  } else {
    function(x){min(x,na.rm=TRUE)}
  }

  # TRUE if actmodelperf beats bestmodelperf (by more than one SE if withinSE)
  isBetter <- function (actmodelperf,bestmodelperf,
                        bestmodelperfSE=NULL,
                        maximization=FALSE,
                        withinSE=FALSE){
    if(withinSE){
      result <- ifelse (!maximization, actmodelperf < bestmodelperf-bestmodelperfSE,
                        actmodelperf > bestmodelperf+bestmodelperfSE)
    }else{
      result <- ifelse (!maximization, actmodelperf < bestmodelperf,
                        actmodelperf > bestmodelperf)
    }
    return(result)
  }

  # Train a model of the initial search (step 1) on the predictors `vars`.
  # Tuning settings are adapted on local copies only, so the user supplied
  # tuneGrid/tuneLength stay untouched for later models.
  train_initial <- function(vars){
    set.seed(seed)
    grid <- tuneGrid
    len <- tuneLength
    #adaptations for pls:
    if(method=="pls" && !is.null(grid) && any(grid$ncomp>minVar)){
      grid <- data.frame(ncomp=grid[grid$ncomp<=minVar,])
      if(verbose){
        print(paste0("note: maximum ncomp is ", minVar))
      }
    }
    #adaptations for tuning of ranger:
    if(method=="ranger" && !is.null(grid) && any(grid$mtry>minVar)){
      grid$mtry <- minVar
      if(verbose){
        print("invalid value for mtry. Reset to valid range.")
      }
    }
    # adaptations for RF and minVar == 1 - tuneLength must be 1, only one mtry possible
    if(minVar==1 && method%in%c("ranger", "rf") && is.null(grid)){
      len <- minVar
    }
    # list-style indexing keeps a data.frame even for a single predictor
    caret::train(predictors[vars],
                 response,
                 method=method,
                 metric=metric,
                 trControl=trControl,
                 tuneLength = len,
                 tuneGrid = grid,
                 ...)
  }

  # Train a model of the forward search (step 2) on the predictors `vars`.
  train_extended <- function(vars){
    set.seed(seed)
    nvars <- ncol(predictors[,vars,drop=FALSE])
    grid <- tuneGrid
    #adaptation for pls:
    if(method=="pls" && !is.null(grid) && any(grid$ncomp>nvars)){
      grid <- data.frame(ncomp=grid[grid$ncomp<=nvars,])
      if(verbose){
        print(paste0("note: maximum ncomp is ", nvars))
      }
    }
    #adaptation for ranger:
    if(method=="ranger" && !is.null(grid) && any(grid$mtry>nvars)){
      grid$mtry[grid$mtry>nvars] <- nvars
      if(verbose){
        print("invalid value for mtry. Reset to valid range.")
      }
    }
    caret::train(predictors[,vars,drop=FALSE],
                 response,
                 method = method,
                 metric=metric,
                 trControl = trControl,
                 tuneLength = tuneLength,
                 tuneGrid = grid,
                 ...)
  }

  # Performance of a trained model (best value of `metric`) and the standard
  # error of the per-resample means of `metric`.
  model_perf <- function(model){
    if (globalval){
      gv <- global_validation(model)
      perf_stats <- gv[names(gv)==metric]
    }else{
      perf_stats <- model$results[,names(model$results)==metric]
    }
    resample_means <- sapply(unique(model$resample$Resample),
                             FUN=function(x){mean(model$resample[model$resample$Resample==x,
                                                                 metric],na.rm=TRUE)})
    list(perf = evalfunc(perf_stats), se = se(resample_means))
  }


  ## Initialize Variables --------------------------

  trControl$returnResamp <- "final"
  trControl$savePredictions <- "final"
  n <- length(names(predictors))
  acc <- 0   # number of rows of perf_all filled so far
  # one row per model that can be trained at most: all minVar-combinations
  # plus the candidates of every forward step
  perf_all <- data.frame(matrix(ncol=length(predictors)+3,
                                nrow=choose(n, minVar)+(n-minVar)*(n-minVar+1)/2))
  names(perf_all) <- c(paste0("var",1:length(predictors)),metric,"SE","nvar")
  # one row per combination of minVar predictors
  minGrid <- t(data.frame(combn(names(predictors),minVar)))


  # Computation -----------------------------------------------
  ## Step 1: Search best initial variables -----
  ## parallel ----------
  if(cores > 1){

    initial_models <- parallel::mclapply(X = seq_len(nrow(minGrid)), mc.cores = cores, FUN = function(i){

      model <- train_initial(minGrid[i,])
      perf <- model_perf(model)

      result <- as.data.frame(t(minGrid[i,]))
      result$actmodelperf <- perf$perf
      result$actmodelperfSE <- perf$se

      return(result)

    })
    initial_models <- do.call(rbind, initial_models)


    ## save best model from initial models

    best_rowindex <- if (maximize) {
      which.max(initial_models$actmodelperf)
    } else {
      which.min(initial_models$actmodelperf)
    }
    bestmodelperf <- initial_models$actmodelperf[best_rowindex]
    bestmodelperfSE <- initial_models$actmodelperfSE[best_rowindex]
    best_predictors <- as.character(initial_models[best_rowindex, 1:minVar])

    # The workers only return performance values, so the best minVar model has
    # to be retrained. Same seed and tuning adaptations as in the workers, so
    # the stored model corresponds to the evaluated one.
    bestmodel <- train_initial(best_predictors)


    acc <- nrow(minGrid)

    # patching perf_all
    perf_all[1:acc, 1:minVar] <- initial_models[,1:minVar]
    perf_all[1:acc, (ncol(perf_all)-2):(ncol(perf_all)-1)] <- initial_models[,(ncol(initial_models)-1):ncol(initial_models)]
    perf_all$nvar[1:nrow(minGrid)] <- minVar



  }else{
    ## unparallel  -------------

    for (i in seq_len(nrow(minGrid))){
      if (verbose){
        print(paste0("model using ",paste0(minGrid[i,],collapse=","), " will be trained now..." ))
      }

      #train model:
      model <- train_initial(minGrid[i,])

      ### compare the model with the currently best model
      perf <- model_perf(model)
      actmodelperf <- perf$perf
      actmodelperfSE <- perf$se
      if (i == 1){
        bestmodelperf <- actmodelperf
        bestmodelperfSE <- actmodelperfSE
        bestmodel <- model
      } else{
        if (isBetter(actmodelperf,bestmodelperf,maximization=maximize,withinSE=FALSE)){
          bestmodelperf <- actmodelperf
          bestmodelperfSE <- actmodelperfSE
          bestmodel <- model
        }
      }
      acc <- acc+1

      variablenames <- names(model$trainingData)[-length(names(model$trainingData))]
      perf_all[acc,1:length(variablenames)] <- variablenames
      perf_all[acc,(length(predictors)+1):ncol(perf_all)] <- c(actmodelperf,actmodelperfSE,length(variablenames))
      if(verbose){
        print(paste0("maximum number of models that still need to be trained: ",
                     round(choose(n, minVar)+(n-minVar)*(n-minVar+1)/2-acc,0)))
      }
    }


  }


  ## both --------

  selectedvars <- names(bestmodel$trainingData)[-which(
    names(bestmodel$trainingData)==".outcome")]


  if (globalval){
    selectedvars_perf <- global_validation(bestmodel)[names(global_validation(bestmodel))==metric]
  }else{
    if (maximize){
      selectedvars_perf <-max(bestmodel$results[,metric])
    }else{
      selectedvars_perf <- min(bestmodel$results[,metric])
    }
  }
  selectedvars_SE <- bestmodelperfSE
  if(verbose){
    print(paste0(paste0("vars selected: ",paste(selectedvars, collapse = ',')),
                 " with ",metric," ",round(selectedvars_perf,3)))
  }

  ## Step 2: Append more variables ------
  # increase the number of predictors by one (try all combinations)
  # and test if model performance increases
  # k: amount of "additional variables" left after initial search
  # for each k: search best additional predictor
  # seq_len() makes both loops a no-op if minVar already equals the number
  # of predictors.

  ## parallel -----

  if(cores > 1){
    for(k in seq_len(length(names(predictors))-minVar)){
      startvars <- names(bestmodel$trainingData)[-which(
        names(bestmodel$trainingData)==".outcome")]
      nextvars <- names(predictors)[-which(
        names(predictors)%in%startvars)]

      if(verbose){
        print(paste0("Searching for additional variable ", minVar + k, " now. ",
                     length(nextvars), " potential predictors are available:"))
        print(nextvars)
      }


      # search best additional variable in parallel
      next_models <- parallel::mclapply(seq_along(nextvars), mc.cores = cores, FUN = function(i){

        model <- train_extended(c(startvars,nextvars[i]))
        perf <- model_perf(model)

        result <- as.data.frame(t(startvars))
        result$nextvar <- nextvars[i]
        result$actmodelperf <- perf$perf
        result$actmodelperfSE <- perf$se

        return(result)

      })

      next_models <- do.call(rbind, next_models)


      ## best next_model (second to last column holds actmodelperf)
      best_next_rowindex <- if (maximize) {
        which.max(next_models[,(ncol(next_models)-1)])
      } else {
        which.min(next_models[,(ncol(next_models)-1)])
      }

      better <- isBetter(actmodelperf = next_models$actmodelperf[best_next_rowindex],
                         bestmodelperf = bestmodelperf,
                         bestmodelperfSE = bestmodelperfSE,
                         maximization = maximize, withinSE = withinSE)



      # patching perf_all
      perf_all[(acc+1):(acc+length(nextvars)), 1:(minVar+k)] <- next_models[,1:(minVar+k)]
      perf_all[(acc+1):(acc+length(nextvars)), (ncol(perf_all)-2):(ncol(perf_all)-1)] <- next_models[,(ncol(next_models)-1):ncol(next_models)]
      perf_all$nvar[(acc+1):(acc+length(nextvars))] <- minVar+k

      if(better){
        # update best model stats
        bestmodelperf <- next_models$actmodelperf[best_next_rowindex]
        bestmodelperfSE <- next_models$actmodelperfSE[best_next_rowindex]
        best_predictors <- as.character(next_models[best_next_rowindex, 1:(minVar+k)])

        # keep selectedvars in sync; it is what gets returned if every round
        # brings an improvement
        selectedvars <- best_predictors
        selectedvars_perf <- c(selectedvars_perf, bestmodelperf)
        selectedvars_SE <- c(selectedvars_SE, bestmodelperfSE)

        # refit the winning model in the main process (same seed and tuning
        # adaptations as in the workers)
        bestmodel <- train_extended(best_predictors)


        acc <- acc+nrow(next_models)


      }else{
        # not better: return model and stats
        message(paste0("Note: No increase in performance found using more than ",
                       length(startvars), " variables"))
        bestmodel$selectedvars <- best_predictors
        bestmodel$selectedvars_perf <- selectedvars_perf
        bestmodel$selectedvars_perf_SE <- selectedvars_SE
        bestmodel$perf_all <- perf_all
        bestmodel$perf_all <- bestmodel$perf_all[!apply(is.na(bestmodel$perf_all), 1, all),]
        bestmodel$perf_all <- bestmodel$perf_all[colSums(!is.na(bestmodel$perf_all)) > 0]
        bestmodel$minVar <- minVar
        bestmodel$type <- "ffs"
        class(bestmodel) <- c("ffs", "train")
        return(bestmodel)

      }

    }# end of k loop

  }else{
    ## unparallel -----
    for (k in seq_len(length(names(predictors))-minVar)){
      startvars <- names(bestmodel$trainingData)[-which(
        names(bestmodel$trainingData)==".outcome")]
      nextvars <- names(predictors)[-which(
        names(predictors)%in%startvars)]
      # The previous round did not add a variable: stop and drop the
      # performance/SE entry that was appended for that round.
      if (length(startvars)<(k+(minVar-1))){
        message(paste0("Note: No increase in performance found using more than ",
                       length(startvars), " variables"))
        bestmodel$selectedvars <- selectedvars
        bestmodel$selectedvars_perf <- selectedvars_perf[-length(selectedvars_perf)]
        bestmodel$selectedvars_perf_SE <- selectedvars_SE[-length(selectedvars_SE)] #!!!
        bestmodel$perf_all <- perf_all
        bestmodel$perf_all <- bestmodel$perf_all[!apply(is.na(bestmodel$perf_all), 1, all),]
        bestmodel$perf_all <- bestmodel$perf_all[colSums(!is.na(bestmodel$perf_all)) > 0]
        bestmodel$minVar <- minVar
        bestmodel$type <- "ffs"
        class(bestmodel) <- c("ffs", "train")
        return(bestmodel)
      }
      for (i in seq_along(nextvars)){
        if(verbose){
          print(paste0("model using additional variable ",nextvars[i], " will be trained now..." ))
        }

        model <- train_extended(c(startvars,nextvars[i]))

        perf <- model_perf(model)
        actmodelperf <- perf$perf
        actmodelperfSE <- perf$se
        if(isBetter(actmodelperf,bestmodelperf,
                    selectedvars_SE[length(selectedvars_SE)], #SE from model with nvar-1
                    maximization=maximize,withinSE=withinSE)){
          bestmodelperf <- actmodelperf
          bestmodelperfSE <- actmodelperfSE
          bestmodel <- model
        }
        acc <- acc+1

        variablenames <- names(model$trainingData)[-length(names(model$trainingData))]
        perf_all[acc,1:length(variablenames)] <- variablenames
        perf_all[acc,(length(predictors)+1):ncol(
          perf_all)] <- c(actmodelperf,actmodelperfSE,length(variablenames))
        if(verbose){
          print(paste0("maximum number of models that still need to be trained: ",
                       round(choose(n, minVar)+(n-minVar)*(n-minVar+1)/2-acc,0)))
        }
      }
      # append the newly selected variable (nothing is appended if no
      # candidate was better than the current best model)
      selectedvars <- c(selectedvars,names(bestmodel$trainingData)[-which(
        names(bestmodel$trainingData)%in%c(".outcome",selectedvars))])
      selectedvars_SE <- c(selectedvars_SE,bestmodelperfSE)

      if(globalval){
        gv_best <- global_validation(bestmodel)
        selectedvars_perf <- c(selectedvars_perf,gv_best[names(gv_best)==metric])
      }else if (maximize){
        selectedvars_perf <- c(selectedvars_perf,max(bestmodel$results[,metric]))
      }else{
        selectedvars_perf <- c(selectedvars_perf,min(bestmodel$results[,metric]))
      }
      if(verbose){
        print(paste0(paste0("vars selected: ",paste(selectedvars, collapse = ',')),
                     " with ",metric," ",round(selectedvars_perf[length(selectedvars_perf)],3)))
      }
    }

  }
  ## return best model --------

  # If the very last round did not add a variable, an entry for that round was
  # still appended to the performance/SE vectors. Trim it so that there is one
  # entry for the initial model plus one per variable added afterwards.
  n_steps <- length(selectedvars)-minVar+1
  if (length(selectedvars_perf) > n_steps){
    message(paste0("Note: No increase in performance found using more than ",
                   length(selectedvars), " variables"))
    selectedvars_perf <- selectedvars_perf[seq_len(n_steps)]
    selectedvars_SE <- selectedvars_SE[seq_len(n_steps)]
  }

  bestmodel$selectedvars <- selectedvars
  bestmodel$selectedvars_perf <- selectedvars_perf
  bestmodel$selectedvars_perf_SE <- selectedvars_SE
  if(globalval){
    bestmodel$selectedvars_perf_SE <- NA
  }
  bestmodel$perf_all <- perf_all
  # drop the rows and columns of perf_all that were never filled
  bestmodel$perf_all <- bestmodel$perf_all[!apply(is.na(bestmodel$perf_all), 1, all),]
  bestmodel$minVar <- minVar
  bestmodel$type <- "ffs"
  bestmodel$perf_all <- bestmodel$perf_all[colSums(!is.na(bestmodel$perf_all)) > 0]
  class(bestmodel) <- c("ffs", "train")
  return(bestmodel)
}




#' Calculate euclidean nearest neighbor distances in geographic space or feature space
#'
#' @description Calculates nearest neighbor distances in geographic space or feature space between training data as well as between training data and prediction locations.
#' Optionally, the nearest neighbor distances between training data and test data or between training data and CV iterations are computed.
#' @param x object of class sf, training data locations
#' @param modeldomain SpatRaster, stars or sf object defining the prediction area (see Details)
#' @param type "geo", "feature" or "time". Should the distance be computed in geographic space, in the normalized multivariate predictor space or in time (see Details)
#' @param cvfolds optional. list or vector. Either a list where each element contains the data points used for testing during the cross validation iteration (i.e. held back data).
#' Or a vector that contains the ID of the fold for each training point. See e.g. ?createFolds or ?CreateSpacetimeFolds or ?nndm
#' @param cvtrain optional. List of row indices of x to fit the model to in each CV iteration. If cvtrain is null but cvfolds is not, all samples but those included in cvfolds are used as training data
#' @param testdata optional. object of class sf: Point data used for independent validation
#' @param preddata optional. object of class sf: Point data indicating the locations within the modeldomain to be used as target prediction points. Useful when the prediction objective is a subset of
#' locations within the modeldomain rather than the whole area.
#' @param samplesize numeric. How many prediction samples should be used?
#' @param sampling character. How to draw prediction samples? See \link[sp]{spsample}. Use sampling = "Fibonacci" for global applications.
#' @param variables character vector defining the predictor variables used if type="feature". If not provided all variables included in modeldomain are used.
#' @param timevar optional. character. Column that indicates the date. Only used if type="time".
#' @param time_unit optional. Character. Unit for temporal distances. See ?difftime. Only used if type="time".
#' @param algorithm see \code{\link[FNN]{knnx.dist}} and \code{\link[FNN]{knnx.index}}
#' @return A data.frame containing the distances. Unit of returned geographic distances is meters. attributes contain W statistic between prediction area and either sample data, CV folds or test data. See details.
#' @details The modeldomain is a sf polygon or a raster that defines the prediction area. The function takes a regular point sample (amount defined by samplesize) from the spatial extent.
#'     If type = "feature", the argument modeldomain (and if provided then also the testdata and/or preddata) has to include predictors. Predictor values for x, testdata and preddata are optional if modeldomain is a raster.
#'     If not provided they are extracted from the modeldomain rasterStack. If some predictors are categorical (i.e., of class factor or character), gower distances will be used.
#'     W statistic describes the match between the distributions. See Linnenbrink et al (2023) for further details.
#' @note See Meyer and Pebesma (2022) for an application of this plotting function
#' @seealso \code{\link{nndm}} \code{\link{knndm}}
#' @import ggplot2
#' @author Hanna Meyer, Edzer Pebesma, Marvin Ludwig, Jan Linnenbrink
#' @examples
#' \dontrun{
#' library(CAST)
#' library(sf)
#' library(terra)
#' library(caret)
#' library(rnaturalearth)
#' library(ggplot2)
#'
#' data(splotdata)
#' studyArea <- rnaturalearth::ne_countries(continent = "South America", returnclass = "sf")
#'
#' ########### Distance between training data and new data:
#' dist <- geodist(splotdata, studyArea)
#' # With density functions
#' plot(dist)
#' # Or ECDFs (relevant for nndm and knndm methods)
#' plot(dist, stat="ecdf")
#'
#' ########### Distance between training data, new data and test data (here Chile):
#' plot(splotdata[,"Country"])
#' dist <- geodist(splotdata[splotdata$Country != "Chile",], studyArea,
#'                 testdata = splotdata[splotdata$Country == "Chile",])
#' plot(dist)
#'
#' ########### Distance between training data, new data and CV folds:
#' folds <- createFolds(1:nrow(splotdata), k=3, returnTrain=FALSE)
#' dist <- geodist(x=splotdata, modeldomain=studyArea, cvfolds=folds)
#' # Using density functions
#' plot(dist)
#' # Using ECDFs (relevant for nndm and knndm methods)
#' plot(dist, stat="ecdf")
#'
#' ########### Distances in the feature space:
#' predictors <- terra::rast(system.file("extdata","predictors_chile.tif", package="CAST"))
#' dist <- geodist(x = splotdata,
#'                 modeldomain = predictors,
#'                 type = "feature",
#'                 variables = c("bio_1","bio_12", "elev"))
#' plot(dist)
#'
#' dist <- geodist(x = splotdata[splotdata$Country != "Chile",],
#'                 modeldomain = predictors, cvfolds = folds,
#'                 testdata = splotdata[splotdata$Country == "Chile",],
#'                 type = "feature",
#'                 variables=c("bio_1","bio_12", "elev"))
#' plot(dist)
#'
#'############Distances in temporal space
#' library(lubridate)
#' library(ggplot2)
#' data(cookfarm)
#' dat <- st_as_sf(cookfarm,coords=c("Easting","Northing"))
#' st_crs(dat) <- 26911
#' trainDat <- dat[dat$altitude==-0.3&lubridate::year(dat$Date)==2010,]
#' predictionDat <- dat[dat$altitude==-0.3&lubridate::year(dat$Date)==2011,]
#' trainDat$week <- lubridate::week(trainDat$Date)
#' cvfolds <- CreateSpacetimeFolds(trainDat,timevar = "week")
#'
#' dist <- geodist(trainDat,preddata = predictionDat,cvfolds = cvfolds$indexOut,
#'    type="time",time_unit="days")
#' plot(dist)+ xlim(0,10)
#'
#'
#' ############ Example for a random global dataset
#' ############ (refer to figure in Meyer and Pebesma 2022)
#'
#' ### Define prediction area (here: global):
#' ee <- st_crs("+proj=eqearth")
#' co <- ne_countries(returnclass = "sf")
#' co.ee <- st_transform(co, ee)
#'
#' ### Simulate a spatial random sample
#' ### (alternatively replace pts_random by a real sampling dataset, see Meyer and Pebesma 2022):
#' sf_use_s2(FALSE)
#' pts_random <- st_sample(co.ee, 2000, exact=FALSE)
#'
#' ### See points on the map:
#' ggplot() + geom_sf(data = co.ee, fill="#00BFC4",col="#00BFC4") +
#'   geom_sf(data = pts_random, color = "#F8766D",size=0.5, shape=3) +
#'   guides(fill = "none", col = "none") +
#'   labs(x = NULL, y = NULL)
#'
#' ### plot distances:
#' dist <- geodist(pts_random,co.ee)
#' plot(dist) + scale_x_log10(labels=round)
#'
#'
#'
#'
#'}
#' @export

geodist <- function(x,
                    modeldomain=NULL,
                    type = "geo",
                    cvfolds=NULL,
                    cvtrain=NULL,
                    testdata=NULL,
                    preddata=NULL,
                    samplesize=2000,
                    sampling = "regular",
                    variables=NULL,
                    timevar=NULL,
                    time_unit="auto",
                    algorithm="brute"){

  # input formatting ------------
  # without a modeldomain, fall back to the bounding box of preddata
  if(is.null(modeldomain) && !is.null(preddata)){
    modeldomain <- sf::st_bbox(preddata)
  }
  # Raster* and stars inputs are converted to SpatRaster
  if (inherits(modeldomain, "Raster")) {
    modeldomain <- methods::as(modeldomain,"SpatRaster")
  }
  if (inherits(modeldomain, "stars")) {
    if (!requireNamespace("stars", quietly = TRUE))
      stop("package stars required: install that first")
    modeldomain <- methods::as(modeldomain, "SpatRaster")
  }




  if(type == "feature"){
    if(is.null(variables)){
      variables <- names(modeldomain)
    }
    if(any(!variables%in%names(x))){ # extract variable values of raster:
      message("features are extracted from the modeldomain")
      x <- sf::st_transform(x,sf::st_crs(modeldomain))

      # bare point geometries are wrapped into an sf object before extraction
      if(inherits(x, "sfc_POINT")){
        x <- sf::st_as_sf(x)
      }
      x <- sf::st_as_sf(terra::extract(modeldomain, terra::vect(x), na.rm=FALSE,bind=TRUE))
    }
    x <- sf::st_transform(x,4326)
    if(!is.null(testdata)){
      if(any(!variables%in%names(testdata))){# extract variable values of raster:
        testdata <- sf::st_transform(testdata,sf::st_crs(modeldomain))
        testdata <- sf::st_as_sf(terra::extract(modeldomain, terra::vect(testdata), na.rm=FALSE,bind=TRUE))

        if(any(is.na(testdata))){
          testdata <- na.omit(testdata)
          message("some test data were removed because of NA in extracted predictor values")
        }

        testdata <- sf::st_transform(testdata,4326)
      }
    }
    if(!is.null(preddata)){
      if(any(!variables%in%names(preddata))){# extract variable values of raster:
        preddata <- sf::st_transform(preddata,sf::st_crs(modeldomain))
        preddata <- sf::st_as_sf(terra::extract(modeldomain, terra::vect(preddata), na.rm=FALSE,bind=TRUE))

        if(any(is.na(preddata))){
          preddata <- na.omit(preddata)
          message("some prediction data were removed because of NA in extracted predictor values")
        }

        preddata <- sf::st_transform(preddata,4326)
      }
    }
    # get names of categorical variables
    catVars <- names(x[,variables])[which(sapply(x[,variables], class)%in%c("factor","character"))]
    if(length(catVars)==0) {
      catVars <- NULL
    }
    if(!is.null(catVars)) {
      # collapse the names first: message() would otherwise glue one full
      # sentence per variable together without any separator
      message(paste0("variable(s) '", paste(catVars, collapse = "', '"),
                     "' is (are) treated as categorical variables"))
    }
  }
  if(type != "feature") {
    x <- sf::st_transform(x,4326)
    catVars <- NULL
  }
  # if no time variable is given, use the Date column(s) of x
  if (type=="time" && is.null(timevar)){
    timevar <- names(which(sapply(x, lubridate::is.Date)))
    message("time variable that has been selected: ",timevar)
  }
  # let difftime() choose the unit
  if (type=="time" && time_unit=="auto"){
    time_unit <- units(difftime(sf::st_drop_geometry(x)[,timevar],
                                sf::st_drop_geometry(x)[,timevar]))
  }



  # required steps ----

  ## Sample prediction location from the study area if preddata not available:
  if(is.null(preddata)){
    modeldomain <- sampleFromArea(modeldomain, samplesize, type,variables,sampling, catVars)
  } else{
    modeldomain <- preddata
  }

  # always do sample-to-sample and sample-to-prediction
  s2s <- sample2sample(x, type,variables,time_unit,timevar, catVars, algorithm=algorithm)
  s2p <- sample2prediction(x, modeldomain, type, samplesize,variables,time_unit,timevar, catVars, algorithm=algorithm)

  dists <- rbind(s2s, s2p)

  # optional steps ----
  ##### Distance to test data:
  if(!is.null(testdata)){
    s2t <- sample2test(x, testdata, type,variables,time_unit,timevar, catVars, algorithm=algorithm)
    dists <- rbind(dists, s2t)
  }

  ##### Distance to CV data:
  if(!is.null(cvfolds)){

    cvd <- cvdistance(x, cvfolds, cvtrain, type, variables,time_unit,timevar, catVars, algorithm=algorithm)
    dists <- rbind(dists, cvd)
  }
  class(dists) <- c("geodist", class(dists))
  attr(dists, "type") <- type

  if(type=="time"){
    attr(dists, "unit") <- time_unit
  }


  ##### Compute W statistics
  # each statistic compares one set of distances with the
  # prediction-to-sample distances and is stored as an attribute
  W_sample <- twosamples::wass_stat(dists[dists$what == "sample-to-sample", "dist"],
                                    dists[dists$what == "prediction-to-sample", "dist"])
  attr(dists, "W_sample") <- W_sample
  if(!is.null(testdata)){
    W_test <- twosamples::wass_stat(dists[dists$what == "test-to-sample", "dist"],
                                    dists[dists$what == "prediction-to-sample", "dist"])
    attr(dists, "W_test") <- W_test
  }
  if(!is.null(cvfolds)){
    W_CV <- twosamples::wass_stat(dists[dists$what == "CV-distances", "dist"],
                                  dists[dists$what == "prediction-to-sample", "dist"])
    attr(dists, "W_CV") <- W_CV
  }

  return(dists)
}




# Sample to Sample Distance

# Distance of every training sample in `x` to its nearest *other* training
# sample, in geographic ("geo"), predictor ("feature") or temporal ("time")
# space. Returns a data.frame with the columns `dist`, `what`
# ("sample-to-sample") and `dist_type`.
sample2sample <- function(x, type,variables,time_unit,timevar, catVars, algorithm){
  if(type == "geo"){
    sf::sf_use_s2(TRUE)
    d <- sf::st_distance(x)
    # a sample must not be matched with itself
    diag(d) <- Inf
    min_d <- apply(d, 1, min)
    sampletosample <- data.frame(dist = min_d,
                                 what = factor("sample-to-sample"),
                                 dist_type = "geo")
  }else if(type == "feature"){
    x <- sf::st_drop_geometry(x[,variables])

    if(!is.null(catVars)) {
      # numeric predictors are standardised, categorical ones become factors
      x_cat <- x[,catVars,drop=FALSE]
      x_num <- data.frame(scale(x[,-which(names(x)%in%catVars),drop=FALSE]))
      x <- as.data.frame(cbind(x_num, lapply(x_cat, as.factor)))
    } else {
      x <- data.frame(scale(x))
    }
    # drop = FALSE keeps a data.frame even if only one predictor is used
    x_clean <- x[complete.cases(x), , drop=FALSE]

    # sample to sample feature distance: Euclidean distance on the scaled
    # predictors, or Gower distance as soon as categorical predictors exist
    d <- vapply(seq_len(nrow(x_clean)), function(i){
      if(is.null(catVars)) {
        trainDist <- FNN::knnx.dist(x_clean[i,],x_clean,k=1, algorithm=algorithm)
      } else {
        trainDist <- gower::gower_dist(x_clean[i,],x_clean)
      }
      # mask the distance of sample i to itself before taking the minimum
      trainDist[i] <- NA
      min(trainDist,na.rm=TRUE)
    }, numeric(1))

    sampletosample <- data.frame(dist = d,
                                 what = factor("sample-to-sample"),
                                 dist_type = "feature")

  }else if(type == "time"){ # calculate temporal distance matrix
    # extract the time stamps once instead of once per row
    times <- sf::st_drop_geometry(x)[,timevar]
    d <- matrix(ncol=nrow(x),nrow=nrow(x))
    for (i in seq_len(nrow(x))){
      d[i,] <- abs(difftime(times, times[i], units=time_unit))
    }
    diag(d) <- Inf
    min_d <- apply(d, 1, min)
    sampletosample <- data.frame(dist = min_d,
                                 what = factor("sample-to-sample"),
                                 dist_type = "time")
  }
  return(sampletosample)
}


# Sample to Prediction
# Distance of every prediction location in `modeldomain` to its nearest
# training sample in `x`, in geographic ("geo"), predictor ("feature") or
# temporal ("time") space. Returns a data.frame with the columns `dist`,
# `what` ("prediction-to-sample") and `dist_type`.
sample2prediction <- function(x, modeldomain, type, samplesize,variables,time_unit,timevar, catVars, algorithm){

  if(type == "geo"){
    modeldomain <- sf::st_transform(modeldomain, sf::st_crs(x))
    sf::sf_use_s2(TRUE)
    d0 <- sf::st_distance(modeldomain, x)
    min_d0 <- apply(d0, 1, min)
    sampletoprediction <- data.frame(dist = min_d0,
                                     what = factor("prediction-to-sample"),
                                     dist_type = "geo")

  }else if(type == "feature"){
    x <- sf::st_drop_geometry(x[,variables])
    modeldomain <- sf::st_drop_geometry(modeldomain[,variables])

    if(!is.null(catVars)) {

      # scale the numeric training predictors once and reuse centre/scale so
      # the prediction data are standardised with the training parameters
      x_cat <- x[,catVars,drop=FALSE]
      x_num_scaled <- scale(x[,-which(names(x)%in%catVars),drop=FALSE])

      modeldomain_cat <- modeldomain[,catVars,drop=FALSE]
      modeldomain_num <- modeldomain[,-which(names(modeldomain)%in%catVars),drop=FALSE]
      modeldomain_num <- data.frame(scale(modeldomain_num,
                                          center=attr(x_num_scaled, "scaled:center"),
                                          scale=attr(x_num_scaled, "scaled:scale")))

      x <- as.data.frame(cbind(data.frame(x_num_scaled), lapply(x_cat, as.factor)))
      modeldomain <- as.data.frame(cbind(modeldomain_num, lapply(modeldomain_cat, as.factor)))

    } else {
      x_scaled <- scale(x)
      x <- data.frame(x_scaled)

      modeldomain <- data.frame(scale(modeldomain,
                                      center=attr(x_scaled, "scaled:center"),
                                      scale=attr(x_scaled, "scaled:scale")))
    }
    # drop = FALSE keeps a data.frame even if only one predictor is used
    x_clean <- x[complete.cases(x), , drop=FALSE]

    # Euclidean distance on the scaled predictors, or Gower distance as soon
    # as categorical predictors exist
    target_dist_feature <- vapply(seq_len(nrow(modeldomain)), function(i){
      if(is.null(catVars)) {
        trainDist <- FNN::knnx.dist(modeldomain[i,],x_clean,k=1, algorithm=algorithm)
      } else {
        trainDist <- gower::gower_dist(modeldomain[i,], x_clean)
      }
      min(trainDist,na.rm=TRUE)
    }, numeric(1))

    sampletoprediction <- data.frame(dist = target_dist_feature,
                                     what = "prediction-to-sample",
                                     dist_type = "feature")
  }else if(type == "time"){

    # extract the time stamps once instead of once per prediction location
    times_train <- sf::st_drop_geometry(x)[,timevar]
    times_pred <- sf::st_drop_geometry(modeldomain)[,timevar]
    min_d0 <- vapply(seq_len(nrow(modeldomain)), function(i){
      # as.numeric() strips the difftime class; the value stays in `time_unit`
      as.numeric(min(abs(difftime(times_pred[i], times_train, units=time_unit))))
    }, numeric(1))

    sampletoprediction <- data.frame(dist = min_d0,
                                     what = factor("prediction-to-sample"),
                                     dist_type = "time")

  }

  return(sampletoprediction)
}


# sample to test


# Distance of every test sample in `testdata` to its nearest training sample
# in `x`, in geographic ("geo"), predictor ("feature") or temporal ("time")
# space. Returns a data.frame with the columns `dist`, `what`
# ("test-to-sample") and `dist_type`.
sample2test <- function(x, testdata, type,variables,time_unit,timevar, catVars, algorithm){

  if(type == "geo"){
    testdata <- sf::st_transform(testdata,4326)
    d_test <- sf::st_distance(testdata, x)
    min_d_test <- apply(d_test, 1, min)

    dists_test <- data.frame(dist = min_d_test,
                             what = factor("test-to-sample"),
                             dist_type = "geo")


  }else if(type == "feature"){

    x <- sf::st_drop_geometry(x[,variables])
    testdata <- sf::st_drop_geometry(testdata[,variables])

    if(!is.null(catVars)) {

      # scale the numeric training predictors once and reuse centre/scale so
      # the test data are standardised with the training parameters
      x_cat <- x[,catVars,drop=FALSE]
      x_num_scaled <- scale(x[,-which(names(x)%in%catVars),drop=FALSE])

      testdata_cat <- testdata[,catVars,drop=FALSE]
      testdata_num <- testdata[,-which(names(testdata)%in%catVars),drop=FALSE]
      testdata_num <- data.frame(scale(testdata_num,
                                       center=attr(x_num_scaled, "scaled:center"),
                                       scale=attr(x_num_scaled, "scaled:scale")))

      x <- as.data.frame(cbind(data.frame(x_num_scaled), lapply(x_cat, as.factor)))
      testdata <- as.data.frame(cbind(testdata_num, lapply(testdata_cat, as.factor)))

    } else {
      x_scaled <- scale(x)
      x <- data.frame(x_scaled)

      testdata <- data.frame(scale(testdata,
                                   center=attr(x_scaled, "scaled:center"),
                                   scale=attr(x_scaled, "scaled:scale")))
    }
    # drop = FALSE keeps a data.frame even if only one predictor is used
    x_clean <- x[complete.cases(x), , drop=FALSE]

    # Euclidean distance on the scaled predictors, or Gower distance as soon
    # as categorical predictors exist
    test_dist_feature <- vapply(seq_len(nrow(testdata)), function(i){
      if(is.null(catVars)) {
        testDist <- FNN::knnx.dist(testdata[i,],x_clean,k=1, algorithm=algorithm)
      } else {
        testDist <- gower::gower_dist(testdata[i,], x_clean)
      }
      min(testDist,na.rm=TRUE)
    }, numeric(1))

    dists_test <- data.frame(dist = test_dist_feature,
                             what = "test-to-sample",
                             dist_type = "feature")
  }else if (type=="time"){
    # extract the time stamps once instead of once per test sample
    times_train <- sf::st_drop_geometry(x)[,timevar]
    times_test <- sf::st_drop_geometry(testdata)[,timevar]
    min_d0 <- vapply(seq_len(nrow(testdata)), function(i){
      # as.numeric() strips the difftime class; the value stays in `time_unit`
      as.numeric(min(abs(difftime(times_test[i], times_train, units=time_unit))))
    }, numeric(1))

    dists_test <- data.frame(dist = min_d0,
                             what = factor("test-to-sample"),
                             dist_type = "time")

  }
  return(dists_test)
}



# between folds

# Distance of every held-out sample to its nearest training sample within each
# cross-validation fold, in geographic ("geo"), predictor ("feature") or
# temporal ("time") space.
#
# `cvfolds` is either a list with the held-out row indices per fold or a vector
# of fold IDs (one per row of `x`). `cvtrain` optionally lists the training
# rows per fold; if NULL, all rows outside the held-out fold are used.
# Returns a data.frame with the columns `dist`, `what` ("CV-distances") and
# `dist_type`.
cvdistance <- function(x, cvfolds, cvtrain, type, variables,time_unit,timevar, catVars, algorithm){

  if(!is.null(cvfolds) && !is.list(cvfolds)){ # restructure input if cvfolds only contains the fold ID
    tmp <- list()
    for (i in unique(cvfolds)){
      tmp[[i]] <- which(cvfolds==i)
    }
    cvfolds <- tmp
  }


  if(type == "geo"){
    d_cv <- c()
    for (i in seq_along(cvfolds)){

      if(!is.null(cvtrain)){
        d_cv_tmp <- sf::st_distance(x[cvfolds[[i]],], x[cvtrain[[i]],])
      }else{
        d_cv_tmp <- sf::st_distance(x[cvfolds[[i]],], x[-cvfolds[[i]],])
      }
      d_cv <- c(d_cv,apply(d_cv_tmp, 1, min))
    }

    dists_cv <- data.frame(dist = d_cv,
                           what = factor("CV-distances"),
                           dist_type = "geo")


  }else if(type == "feature"){
    x <- sf::st_drop_geometry(x[,variables])

    if(is.null(catVars)) {
      x <- data.frame(scale(x))
    } else {
      # numeric predictors are standardised, categorical ones become factors
      x_cat <- x[,catVars,drop=FALSE]
      x_num <- data.frame(scale(x[,-which(names(x)%in%catVars),drop=FALSE]))
      x <- as.data.frame(cbind(x_num, lapply(x_cat, as.factor)))
    }

    d_cv <- c()
    for(i in seq_along(cvfolds)){

      # drop = FALSE keeps data.frames even if only one predictor is used
      testdata_i <- x[cvfolds[[i]], , drop=FALSE]
      if(!is.null(cvtrain)){
        traindata_i <- x[cvtrain[[i]], , drop=FALSE]
      }else{
        traindata_i <- x[-cvfolds[[i]], , drop=FALSE]
      }

      testdata_i <- testdata_i[complete.cases(testdata_i), , drop=FALSE]
      traindata_i <- traindata_i[complete.cases(traindata_i), , drop=FALSE]

      # seq_len() skips folds that have no complete held-out row
      d_fold <- vapply(seq_len(nrow(testdata_i)), function(row){
        # the held-out row is addressed by its own index (not the fold index)
        query <- testdata_i[row, , drop=FALSE]
        trainDist <- tryCatch(
          if(is.null(catVars)) {
            FNN::knnx.dist(query,traindata_i,k=1, algorithm=algorithm)
          } else {
            gower::gower_dist(query, traindata_i)
          },
          error = function(e)e)
        if(inherits(trainDist, "error")){
          trainDist <- NA
          message("warning: no distance could be calculated for a fold.
                  Possibly because predictor values are NA")
        }
        # held-out and training rows are different sets, so there is no
        # self-distance that would have to be masked here
        min(trainDist,na.rm=TRUE)
      }, numeric(1))

      d_cv <- c(d_cv, d_fold)
    }

    dists_cv <- data.frame(dist = d_cv,
                           what = factor("CV-distances"),
                           dist_type = "feature")

  }else if(type == "time"){
    # extract the time stamps once instead of once per held-out sample
    times <- sf::st_drop_geometry(x)[,timevar]
    d_cv <- c()
    for (i in seq_along(cvfolds)){
      test_idx <- cvfolds[[i]]
      if(!is.null(cvtrain)){
        train_times <- times[cvtrain[[i]]]
      }else{
        train_times <- times[-test_idx]
      }
      # a fresh vector per fold: distances of a longer previous fold must not
      # leak into the result of a shorter one
      d_fold <- vapply(seq_along(test_idx), function(pos){
        as.numeric(min(abs(difftime(times[test_idx[pos]], train_times,
                                    units=time_unit))))
      }, numeric(1))
      d_cv <- c(d_cv,d_fold)
    }


    dists_cv <- data.frame(dist = d_cv,
                           what = factor("CV-distances"),
                           dist_type = "time")

  }

  return(dists_cv)
}





# Draw prediction locations from the model domain (see
# https://edzer.github.io/OGH21/). For type "feature" the predictor values of
# `modeldomain` are additionally extracted at the sampled locations and
# locations with NA values are removed. Returns an sf point object in
# EPSG:4326.
sampleFromArea <- function(modeldomain, samplesize, type,variables,sampling, catVars){

  # objects of class "Raster" are converted to a terra SpatRaster first
  if(inherits(modeldomain, "Raster")){
    modeldomain <- terra::rast(modeldomain)
  }

  if(!inherits(modeldomain, "SpatRaster")) {
    # anything else is used directly as the area to sample from
    modeldomainextent <- modeldomain
  }else{
    n_cells <- terra::ncell(modeldomain)
    if(samplesize>n_cells){
      samplesize <- n_cells
      message(paste0("samplesize for new data shouldn't be larger than number of pixels.
              Samplesize was reduced to ",n_cells))
    }
    # mask to sample from: all non-NA cells of the first layer are set to 1
    # and the result is turned into polygons
    template <- terra::classify(modeldomain[[1]], cbind(-Inf, Inf, 1), right=FALSE)
    modeldomainextent <- sf::st_geometry(sf::st_as_sf(terra::as.polygons(template)))
  }

  sf::sf_use_s2(FALSE)
  bb <- sf::st_transform(sf::st_as_sf(modeldomainextent), 4326)

  # sampling is done with sp::spsample on the sp representation of the area
  sampled <- sp::spsample(methods::as(bb, "Spatial"), n = samplesize, type = sampling)
  predictionloc <- sf::st_as_sf(sf::st_set_crs(sf::st_as_sfc(sampled), 4326))

  if(type == "feature"){

    # categorical predictors are projected with nearest-neighbour resampling
    if(is.null(catVars)) {
      modeldomain <- terra::project(modeldomain, "epsg:4326")
    } else {
      modeldomain <- terra::project(modeldomain, "epsg:4326", method="near")
    }

    extracted <- terra::extract(modeldomain,terra::vect(predictionloc),bind=TRUE)
    predictionloc <- na.omit(sf::st_as_sf(extracted))
  }

  return(predictionloc)

}






#' Evaluate 'global' cross-validation
#' @description Calculate validation metric using all held back predictions at once
#' @param model an object of class \code{\link[caret]{train}}
#' @return regression (\code{\link[caret]{postResample}}) or classification  (\code{\link[caret]{confusionMatrix}}) statistics
#' @details Relevant when folds are not representative for the entire area of interest.
#' In this case, metrics like R2 are not meaningful since it doesn't reflect the general ability of
#' the model to explain the entire gradient of the response.
#' Comparable to LOOCV, predictions from all held back folds are used here together to calculate validation statistics.
#' @author Hanna Meyer
#' @seealso \code{\link{CreateSpacetimeFolds}}
#' @examples
#' \dontrun{
#' library(caret)
#' data(cookfarm)
#' dat <- cookfarm[sample(1:nrow(cookfarm),500),]
#' indices <- CreateSpacetimeFolds(dat,"SOURCEID","Date")
#' ctrl <- caret::trainControl(method="cv",index = indices$index,savePredictions="final")
#' model <- caret::train(dat[,c("DEM","TWI","BLD")],dat$VW, method="rf", trControl=ctrl, ntree=10)
#' global_validation(model)
#' }
#' @export global_validation
#' @aliases global_validation

global_validation <- function(model){
  predictions <- model$pred
  if(is.null(predictions)){stop("Global performance could not be estimated because predictions were not saved.
                                Train model with savePredictions='final'")}

  ### only use predictions of best tune:
  # keep the rows whose tuning parameter values all equal those in bestTune;
  # seq_along() avoids the 1:0 trap of 1:length()
  for (i in seq_along(model$bestTune)){
    param <- names(model$bestTune)[i]
    predictions <- predictions[predictions[,param]==model$bestTune[,i],]
  }

  obs <- predictions$obs
  pred <- predictions$pred

  # all held-back predictions are evaluated together in a single call
  if(model$modelType=="Regression"){
    out <- caret::postResample(pred = pred, obs = obs)
  }else{
    out <- caret::confusionMatrix(pred, obs)$overall
  }
  return(out)
}


#' K-fold Nearest Neighbour Distance Matching
#' @description
#' This function implements the kNNDM algorithm and returns the necessary
#' indices to perform a k-fold NNDM CV for map validation.
#'
#' @author Carles Milà and Jan Linnenbrink
#' @param tpoints sf or sfc point object, or data.frame if space = "feature". Contains the training points samples.
#' @param modeldomain sf polygon object or SpatRaster defining the prediction area. Optional; alternative to predpoints (see Details).
#' @param predpoints sf or sfc point object, or data.frame if space = "feature". Contains the target prediction points. Optional; alternative to modeldomain (see Details).
#' @param space character. Either "geographical" or "feature".
#' @param k integer. Number of folds desired for CV. Defaults to 10.
#' @param maxp numeric. Maximum fold size allowed, defaults to 0.5, i.e. a single fold can hold a maximum of half of the training points.
#' @param clustering character. Possible values include "hierarchical" and "kmeans". See details.
#' @param linkf character. Only relevant if clustering = "hierarchical". Link function for agglomerative hierarchical clustering.
#' Defaults to "ward.D2". Check `stats::hclust` for other options.
#' @param samplesize numeric. How many points in the modeldomain should be sampled as prediction points?
#' Only required if modeldomain is used instead of predpoints.
#' @param sampling character. How to draw prediction points from the modeldomain? See `sf::st_sample`.
#' Only required if modeldomain is used instead of predpoints.
#' @param useMD boolean. Only for `space`=feature: shall the Mahalanobis distance be calculated instead of Euclidean?
#' Only works with numerical variables.
#' @param algorithm see \code{\link[FNN]{knnx.dist}} and \code{\link[FNN]{knnx.index}}
#' @return An object of class \emph{knndm} consisting of a list of eight elements:
#' indx_train, indx_test (indices of the observations to use as
#' training/test data in each kNNDM CV iteration), Gij (distances for
#' G function construction between prediction and target points), Gj
#' (distances for G function construction during LOO CV), Gjstar (distances
#' for modified G function during kNNDM CV), clusters (list of cluster IDs),
#' W (Wasserstein statistic), and space (stated by the user in the function call).
#'
#' @details
#' knndm is a k-fold version of NNDM LOO CV for medium and large datasets. Briefly, the algorithm tries to
#' find a k-fold configuration such that the integral of the absolute differences (Wasserstein W statistic)
#' between the empirical nearest neighbour distance distribution function between the test and training data during CV (Gj*),
#' and the empirical nearest neighbour distance distribution function between the prediction and training points (Gij),
#' is minimised. It does so by performing clustering of the training points' coordinates for different numbers of
#' clusters that range from k to N (number of observations), merging them into k final folds,
#' and selecting the configuration with the lowest W.
#'
#' Using a projected CRS in `knndm` has large computational advantages since fast nearest neighbour search can be
#' done via the `FNN` package, while working with geographic coordinates requires computing the full
#' spherical distance matrices. As a clustering algorithm, `kmeans` can only be used for
#' projected CRS while `hierarchical` can work with both projected and geographical coordinates, though it requires
#' calculating the full distance matrix of the training points even for a projected CRS.
#'
#' In order to select between clustering algorithms and number of folds `k`, different `knndm` configurations can be run
#' and compared; the configuration with the lowest W statistic offers the best match. W statistics between `knndm`
#' runs are comparable as long as `tpoints` and `predpoints` or `modeldomain` stay the same.
#'
#' Map validation using `knndm` should be performed using `CAST::global_validation`, i.e. by stacking all out-of-sample
#' predictions and evaluating them all at once. The reasons behind this are 1) The resulting folds can be
#' unbalanced and 2) nearest neighbour functions are constructed and matched using all CV folds simultaneously.
#'
#' If training data points are very clustered with respect to the prediction area and the presented `knndm`
#' configuration still show signs of Gj* > Gij, there are several things that can be tried. First, increase
#' the `maxp` parameter; this may help to control for strong clustering (at the cost of having unbalanced folds).
#' Secondly, decrease the number of final folds `k`, which may help to have larger clusters.
#'
#' The `modeldomain` is either a sf polygon that defines the prediction area, or alternatively a SpatRaster out of which a polygon,
#' transformed into the CRS of the training points, is defined as the outline of all non-NA cells.
#' Then, the function takes a regular point sample (amount defined by `samplesize`) from the spatial extent.
#' As an alternative use `predpoints` instead of `modeldomain`, if you have already defined the prediction locations (e.g. raster pixel centroids).
#' When using either `modeldomain` or `predpoints`, we advise to plot the study area polygon and the training/prediction points as a previous step to ensure they are aligned.
#'
#' `knndm` can also be performed in the feature space by setting `space` to "feature".
#' Euclidean distances or Mahalanobis distances can be used for distance calculation, but only Euclidean are tested.
#' In this case, nearest neighbour distances are calculated in n-dimensional feature space rather than in geographical space.
#' `tpoints` and `predpoints` can be data frames or sf objects containing the values of the features. Note that the names of `tpoints` and `predpoints` must be the same.
#' `predpoints` can also be missing, if `modeldomain` is of class SpatRaster. In this case, the values of the SpatRaster will be extracted to the `predpoints`.
#' In the case of any categorical features, Gower distances will be used to calculate the Nearest Neighbour distances [Experimental]. If categorical
#' features are present, and `clustering` = "kmeans", K-Prototype clustering will be performed instead.
#'
#' @note
#' For spatial visualization of fold affiliation see examples.
#' @references
#' \itemize{
#' \item Linnenbrink, J., Milà, C., Ludwig, M., and Meyer, H.: kNNDM: k-fold Nearest Neighbour Distance Matching Cross-Validation for map accuracy estimation, EGUsphere [preprint], https://doi.org/10.5194/egusphere-2023-1308, 2023.
#' \item Milà, C., Mateu, J., Pebesma, E., Meyer, H. (2022): Nearest Neighbour Distance Matching Leave-One-Out Cross-Validation for map validation. Methods in Ecology and Evolution 00, 1– 13.
#' }
#' @seealso \code{\link{geodist}}, \code{\link{nndm}}
#'
#' @export
#' @examples
#' ########################################################################
#' # Example 1: Simulated data - Randomly-distributed training points
#' ########################################################################
#'
#' library(sf)
#' library(ggplot2)
#'
#' # Simulate 1000 random training points in a 100x100 square
#' set.seed(1234)
#' simarea <- list(matrix(c(0,0,0,100,100,100,100,0,0,0), ncol=2, byrow=TRUE))
#' simarea <- sf::st_polygon(simarea)
#' train_points <- sf::st_sample(simarea, 1000, type = "random")
#' pred_points <- sf::st_sample(simarea, 1000, type = "regular")
#' plot(simarea)
#' plot(pred_points, add = TRUE, col = "blue")
#' plot(train_points, add = TRUE, col = "red")
#'
#' # Run kNNDM for the whole domain, here the prediction points are known.
#' knndm_folds <- knndm(train_points, predpoints = pred_points, k = 5)
#' knndm_folds
#' plot(knndm_folds)
#' plot(knndm_folds, type = "simple") # For more accessible legend labels
#' plot(knndm_folds, type = "simple", stat = "density") # To visualize densities rather than ECDFs
#' folds <- as.character(knndm_folds$clusters)
#' ggplot() +
#'   geom_sf(data = simarea, alpha = 0) +
#'   geom_sf(data = train_points, aes(col = folds))
#'
#' ########################################################################
#' # Example 2: Simulated data - Clustered training points
#' ########################################################################
#' \dontrun{
#' library(sf)
#' library(ggplot2)
#'
#' # Simulate 1000 clustered training points in a 100x100 square
#' set.seed(1234)
#' simarea <- list(matrix(c(0,0,0,100,100,100,100,0,0,0), ncol=2, byrow=TRUE))
#' simarea <- sf::st_polygon(simarea)
#' train_points <- clustered_sample(simarea, 1000, 50, 5)
#' pred_points <- sf::st_sample(simarea, 1000, type = "regular")
#' plot(simarea)
#' plot(pred_points, add = TRUE, col = "blue")
#' plot(train_points, add = TRUE, col = "red")
#'
#' # Run kNNDM for the whole domain, here the prediction points are known.
#' knndm_folds <- knndm(train_points, predpoints = pred_points, k = 5)
#' knndm_folds
#' plot(knndm_folds)
#' plot(knndm_folds, type = "simple") # For more accessible legend labels
#' plot(knndm_folds, type = "simple", stat = "density") # To visualize densities rather than ECDFs
#' folds <- as.character(knndm_folds$clusters)
#' ggplot() +
#'   geom_sf(data = simarea, alpha = 0) +
#'   geom_sf(data = train_points, aes(col = folds))
#'}
#' ########################################################################
#' # Example 3: Real- world example; using a modeldomain instead of previously
#' # sampled prediction locations
#' ########################################################################
#' \dontrun{
#' library(sf)
#' library(terra)
#' library(ggplot2)
#'
#' ### prepare sample data:
#' data(cookfarm)
#' dat <- aggregate(cookfarm[,c("DEM","TWI", "NDRE.M", "Easting", "Northing","VW")],
#'    by=list(as.character(cookfarm$SOURCEID)),mean)
#' pts <- dat[,-1]
#' pts <- st_as_sf(pts,coords=c("Easting","Northing"))
#' st_crs(pts) <- 26911
#' studyArea <- rast(system.file("extdata","predictors_2012-03-25.tif",package="CAST"))
#' pts <- st_transform(pts, crs = st_crs(studyArea))
#' terra::plot(studyArea[["DEM"]])
#' terra::plot(vect(pts), add = T)
#'
#' knndm_folds <- knndm(pts, modeldomain=studyArea, k = 5)
#' knndm_folds
#' plot(knndm_folds)
#' folds <- as.character(knndm_folds$clusters)
#' ggplot() +
#'   geom_sf(data = pts, aes(col = folds))
#'
#' #use for cross-validation:
#' library(caret)
#' ctrl <- trainControl(method="cv",
#'    index=knndm_folds$indx_train,
#'    savePredictions='final')
#' model_knndm <- train(dat[,c("DEM","TWI", "NDRE.M")],
#'    dat$VW,
#'    method="rf",
#'    trControl = ctrl)
#' global_validation(model_knndm)
#'}
#' ########################################################################
#' # Example 4: Real- world example; kNNDM in feature space
#' ########################################################################
#' \dontrun{
#' library(sf)
#' library(terra)
#' library(ggplot2)
#'
#'data(splotdata)
#'splotdata <- splotdata[splotdata$Country == "Chile",]
#'
#'predictors <- c("bio_1", "bio_4", "bio_5", "bio_6",
#'                "bio_8", "bio_9", "bio_12", "bio_13",
#'                "bio_14", "bio_15", "elev")
#'
#'trainDat <- sf::st_drop_geometry(splotdata)
#'predictors_sp <- terra::rast(system.file("extdata", "predictors_chile.tif",package="CAST"))
#'
#'
#' terra::plot(predictors_sp[["bio_1"]])
#' terra::plot(vect(splotdata), add = T)
#'
#'knndm_folds <- knndm(trainDat[,predictors], modeldomain = predictors_sp, space = "feature",
#'                     clustering="kmeans", k=4, maxp=0.8)
#'plot(knndm_folds)
#'
#'}
knndm <- function(tpoints, modeldomain = NULL, predpoints = NULL,
                  space = "geographical",
                  k = 10, maxp = 0.5,
                  clustering = "hierarchical", linkf = "ward.D2",
                  samplesize = 1000, sampling = "regular", useMD=FALSE,
                  algorithm="brute"){

  # Reject unsupported spaces up front; otherwise neither branch below runs
  # and the function fails with "object 'knndm_res' not found".
  if(!is.character(space) || length(space) != 1 ||
     !(space %in% c("geographical", "feature"))){
    stop("space must be either 'geographical' or 'feature'")
  }

  # Each of these is only computed for one of the two spaces; initialise them
  # so that they are defined whichever branch is taken.
  islonglat <- NULL
  catVars <- NULL

  # create sample points from modeldomain
  if(is.null(predpoints) && !is.null(modeldomain)){

    # Check modeldomain is indeed a sf/SpatRaster
    if(!inherits(modeldomain, c("sfc", "sf", "SpatRaster"))){
      stop("modeldomain must be a sf/sfc object or a 'SpatRaster' object.")
    }

    # If modeldomain is a SpatRaster, transform into polygon
    if(inherits(modeldomain, "SpatRaster")){

      # save predictor stack for extraction if space = "feature"
      if(space == "feature") {
        predictor_stack <- modeldomain
      }
      modeldomain[!is.na(modeldomain)] <- 1
      modeldomain <- terra::as.polygons(modeldomain, values = FALSE, na.all = TRUE) |>
        sf::st_as_sf() |>
        sf::st_union()
      if(inherits(tpoints, c("sfc", "sf"))) {
        modeldomain <- sf::st_transform(modeldomain, crs = sf::st_crs(tpoints))
      }
    } else if(space == "feature") {
      # a polygon carries no predictor values that could be extracted below
      stop("for space = 'feature', modeldomain must be a 'SpatRaster' containing the predictors (or use predpoints).")
    }

    # Check modeldomain is indeed a polygon sf
    if(!inherits(sf::st_geometry(modeldomain), c("sfc_POLYGON", "sfc_MULTIPOLYGON"))){
      stop("modeldomain must be a sf/sfc polygon object.")
    }

    # Check whether modeldomain has the same crs as tpoints
    if(space == "geographical" && !identical(sf::st_crs(tpoints), sf::st_crs(modeldomain))){
      stop("tpoints and modeldomain must have the same CRS")
    }

    # We sample
    message(paste0(samplesize, " prediction points are sampled from the modeldomain"))
    predpoints <- sf::st_sample(x = modeldomain, size = samplesize, type = sampling)
    sf::st_crs(predpoints) <- sf::st_crs(modeldomain)

    if(space == "feature") {
      message("predictor values are extracted for prediction points")
      predpoints <- terra::extract(predictor_stack, terra::vect(predpoints), ID=FALSE)
    }

  }else if(!is.null(predpoints) && space == "geographical"){
    if(!identical(sf::st_crs(tpoints), sf::st_crs(predpoints))){
      stop("tpoints and predpoints must have the same CRS")
    }
  }


  # Conditional preprocessing actions
  if(space == "geographical") {
    if(is.null(predpoints)){
      stop("either modeldomain or predpoints must be provided")
    }
    # bare geometry columns are wrapped into sf objects
    if (inherits(tpoints, "sfc")) {
      tpoints <- sf::st_sf(geom = tpoints)
    }
    if (inherits(predpoints, "sfc")) {
      predpoints <- sf::st_sf(geom = predpoints)
    }
    if(is.na(sf::st_crs(tpoints))){
      warning("Missing CRS in training or prediction points. Assuming projected CRS.")
      islonglat <- FALSE
    }else{
      islonglat <- sf::st_is_longlat(tpoints)
    }
  } else {
    # space == "feature"
    # drop geometry if tpoints / predpoints are of class sf
    if(inherits(tpoints, c("sf","sfc"))) {
      tpoints <- sf::st_set_geometry(tpoints, NULL)
    }
    if(inherits(predpoints, c("sf","sfc"))) {
      predpoints <- sf::st_set_geometry(predpoints, NULL)
    }
    # get names of categorical variables
    catVars <- names(tpoints)[which(sapply(tpoints, class)%in%c("factor","character"))]
    if(length(catVars)==0) {
      catVars <- NULL
    }
    if(!is.null(catVars)) {
      message(paste0("variable(s) '", catVars, "' is (are) treated as categorical variables"))
    }
    # omit NAs
    if(any(is.na(predpoints))) {
      message("some prediction points contain NAs, which will be removed")
      predpoints <- stats::na.omit(predpoints)
    }
    if(any(is.na(tpoints))) {
      message("some training points contain NAs, which will be removed")
      tpoints <- stats::na.omit(tpoints)
    }
  }


  # kNNDM in the geographical / feature space
  if(space == "geographical"){

    # prior checks
    check_knndm_geo(tpoints, predpoints, space, k, maxp, clustering, islonglat)
    # kNNDM in geographical space
    knndm_res <- knndm_geo(tpoints, predpoints, k, maxp, clustering, linkf, islonglat, algorithm=algorithm)

  } else {

    # prior checks
    check_knndm_feature(tpoints, predpoints, space, k, maxp, clustering, islonglat, catVars,useMD)
    # The check warns that Mahalanobis distances are not supported with
    # categorical features, but it cannot change `useMD` in this frame;
    # apply the announced fallback here.
    if(!is.null(catVars) && isTRUE(useMD)) {
      useMD <- FALSE
    }
    # kNNDM in feature space
    knndm_res <- knndm_feature(tpoints, predpoints, k, maxp, clustering, linkf, catVars, useMD, algorithm=algorithm)

  }

  # Output
  knndm_res
}


# kNNDM checks
# Input validation for kNNDM in geographical space. Stops with an informative
# error at the first violated requirement; the checks run in the order below.
check_knndm_geo <- function(tpoints, predpoints, space, k, maxp, clustering, islonglat){

  # training and prediction points have to share one CRS
  same_crs <- identical(sf::st_crs(tpoints), sf::st_crs(predpoints))
  if(!same_crs){
    stop("tpoints and predpoints must have the same CRS")
  }

  # only these two clustering algorithms are accepted
  supported_clustering <- c("kmeans", "hierarchical")
  if (!(clustering %in% supported_clustering)) {
    stop("clustering must be one of `kmeans` or `hierarchical`")
  }

  if (space != "geographical") {
    stop("Only kNNDM in the geographical space is currently implemented.")
  }

  # maxp has to lie in the open interval (1/k, 1)
  if (maxp >= 1 || maxp <= 1/k) {
    stop("maxp must be strictly between 1/k and 1")
  }

  # kmeans is rejected for geographic (longitude/latitude) coordinates
  if(isTRUE(islonglat) && clustering == "kmeans"){
    stop("kmeans works in the Euclidean space and therefore can only handle
         projected coordinates. Please use hierarchical clustering or project your data.")
  }
}

# Input validation for kNNDM in the feature space.
#
# tpoints, predpoints: data.frames holding the predictor values of the
#   training / prediction points.
# k, maxp: number of folds and maximum fraction of points per fold; maxp must
#   lie strictly between 1/k and 1.
# catVars: names of the categorical predictors (NULL if there are none).
# useMD: whether Mahalanobis distances were requested.
# space, clustering, islonglat: accepted for signature parity with
#   check_knndm_geo, not inspected here.
#
# Emits a warning if Mahalanobis distances are requested together with
# categorical predictors, and stops on invalid maxp, missing predpoints,
# predictor columns missing from predpoints, or unmatched factor values.
# Returns NULL invisibly.
check_knndm_feature <- function(tpoints, predpoints, space, k, maxp, clustering, islonglat, catVars, useMD){

  # Only warn here: reassigning useMD inside this function would never reach
  # the caller. knndm_feature uses Gower distances whenever catVars is set.
  if(!is.null(catVars) && isTRUE(useMD)) {
    warning("Mahalanobis distances not supported for categorical features, Gower distances will be used")
  }

  if (!(maxp < 1 & maxp > 1/k)) {
    stop("maxp must be strictly between 1/k and 1")
  }

  if(is.null(predpoints)) {
    stop("predpoints with predictor data missing")
  }

  if(length(setdiff(names(tpoints), names(predpoints)))>0) {
    stop("tpoints and predpoints need to contain the predictor data and have the same colnames.")
  }

  for (catvar in catVars) {
    # `[[` yields a plain vector for data.frames and tibbles alike
    train_values <- unique(tpoints[[catvar]])
    pred_values <- unique(predpoints[[catvar]])
    # NOTE(review): this tests that every training value occurs in the
    # prediction points, while the second sentence of the message states the
    # reverse requirement - confirm the intended direction.
    if (!all(train_values %in% pred_values)) {
      stop(paste0("Some values of factor ", catvar, " are only present in training / prediction points.\n",
                  "All factor values in the prediction points must be present in the training points."))
    }
  }

  invisible(NULL)
}


# kNNDM in the geographical space
#
# Searches for a k-fold assignment of the training points whose out-of-fold
# nearest neighbour distances (Gjstar) match the prediction-to-training
# nearest neighbour distances (Gij), measured with the Wasserstein statistic.
#
# tpoints, predpoints: sf point objects (training / prediction locations).
# k: number of folds.
# maxp: maximum fraction of training points allowed in a single fold.
# clustering: "hierarchical" or "kmeans".
# linkf: linkage method handed to stats::hclust.
# islonglat: if TRUE, distances come from sf::st_distance instead of FNN.
# algorithm: nearest neighbour search algorithm handed to FNN.
#
# Returns a list of class "knndm" with elements clusters, indx_train,
# indx_test, Gij, Gj, Gjstar, W, method, q (number of clusters of the selected
# configuration, or "random CV") and space.
knndm_geo <- function(tpoints, predpoints, k, maxp, clustering, linkf, islonglat, algorithm){

  # Gj and Gij calculation
  tcoords <- sf::st_coordinates(tpoints)[,1:2]
  if(isTRUE(islonglat)){
    distmat <- sf::st_distance(tpoints)
    units(distmat) <- NULL
    diag(distmat) <- NA
    Gj <- apply(distmat, 1, function(x) min(x, na.rm=TRUE))
    Gij <- sf::st_distance(predpoints, tpoints)
    units(Gij) <- NULL
    Gij <- apply(Gij, 1, min)
  }else{
    Gj <- c(FNN::knn.dist(tcoords, k = 1, algorithm=algorithm))
    Gij <- c(FNN::knnx.dist(query = sf::st_coordinates(predpoints)[,1:2],
                            data = tcoords, k = 1, algorithm=algorithm))
  }

  # Check if Gj > Gij (warning suppressed regarding ties)
  testks <- suppressWarnings(stats::ks.test(Gj, Gij, alternative = "greater"))
  if(testks$p.value >= 0.05){

    clust <- sample(rep(1:k, ceiling(nrow(tpoints)/k)), size = nrow(tpoints), replace=FALSE)

    if(isTRUE(islonglat)){
      Gjstar <- distclust_distmat(distmat, clust)
    }else{
      Gjstar <- distclust_euclidean(tcoords, clust, algorithm=algorithm)
    }
    k_final <- "random CV"
    W_final <- twosamples::wass_stat(Gjstar, Gij)
    message("Gij <= Gj; a random CV assignment is returned")

  }else{

    if(clustering == "hierarchical"){
      # For hierarchical clustering we need to compute the full distance matrix,
      # but we can integrate geographical distances
      if(!isTRUE(islonglat)){
        distmat <- sf::st_distance(tpoints)
      }
      hc <- stats::hclust(d = stats::as.dist(distmat), method = linkf)
    }

    # Build grid of number of clusters to try - we sample low numbers more intensively
    clustgrid <- data.frame(nk = as.integer(round(exp(seq(log(k), log(nrow(tpoints)-2),
                                                          length.out = 100)))))
    clustgrid$W <- NA
    clustgrid <- clustgrid[!duplicated(clustgrid$nk),]
    clustgroups <- list()

    # Compute 1st PC for ordering clusters
    pcacoords <- stats::prcomp(tcoords, center = TRUE, scale. = FALSE, rank = 1)

    # We test each number of clusters
    for(nk in clustgrid$nk){

      # Create nk clusters
      if(clustering == "hierarchical"){
        clust_nk <- stats::cutree(hc, k=nk)
      }else if(clustering == "kmeans"){
        clust_nk <- stats::kmeans(tcoords, nk)$cluster
      }

      tabclust <- as.data.frame(table(clust_nk))
      tabclust$clust_k <- NA

      # compute cluster centroids and apply PC loadings to shuffle along the 1st dimension
      centr_tpoints <- sapply(tabclust$clust_nk, function(x){
        centrpca <- matrix(apply(tcoords[clust_nk %in% x, , drop=FALSE], 2, mean), nrow = 1)
        colnames(centrpca) <- colnames(tcoords)
        return(predict(pcacoords, centrpca))
      })
      tabclust$centrpca <- centr_tpoints
      tabclust <- tabclust[order(tabclust$centrpca),]

      # We don't merge big clusters: each one gets a fold of its own
      next_fold <- 1
      for(i in seq_len(nrow(tabclust))){
        if(tabclust$Freq[i] >= nrow(tpoints)/k){
          tabclust$clust_k[i] <- next_fold
          next_fold <- next_fold + 1
        }
      }

      # And we merge the remaining into k groups.
      # If every cluster already owns a fold (k perfectly balanced clusters)
      # there is nothing left to merge and no free fold; skipping the step
      # avoids calling rep() with an infinite `times`.
      unassigned <- is.na(tabclust$clust_k)
      if(any(unassigned)){
        free_folds <- setdiff(1:k, unique(tabclust$clust_k))
        tabclust$clust_k[unassigned] <- rep(free_folds, ceiling(nk/length(free_folds)))[1:sum(unassigned)]
      }
      tabclust2 <- data.frame(ID = 1:length(clust_nk), clust_nk = clust_nk)
      tabclust2 <- merge(tabclust2, tabclust, by = "clust_nk")
      tabclust2 <- tabclust2[order(tabclust2$ID),]
      clust_k <- tabclust2$clust_k

      # Compute W statistic if not exceeding maxp
      if(!any(table(clust_k)/length(clust_k)>maxp)){

        if(isTRUE(islonglat)){
          Gjstar_i <- distclust_distmat(distmat, clust_k)
        }else{
          Gjstar_i <- distclust_euclidean(tcoords, clust_k,algorithm=algorithm)
        }
        clustgrid$W[clustgrid$nk==nk] <- twosamples::wass_stat(Gjstar_i, Gij)
        clustgroups[[paste0("nk", nk)]] <- clust_k
      }
    }

    # Fail early with a clear message instead of carrying an empty fold
    # assignment into the output step
    if(all(is.na(clustgrid$W))){
      stop("No clustering configuration satisfies maxp; consider increasing maxp.")
    }

    # Final configuration
    k_final <- clustgrid$nk[which.min(clustgrid$W)]
    W_final <- min(clustgrid$W, na.rm=TRUE)
    clust <- clustgroups[[paste0("nk", k_final)]]
    if(isTRUE(islonglat)){
      Gjstar <- distclust_distmat(distmat, clust)
    }else{
      Gjstar <- distclust_euclidean(tcoords, clust,algorithm=algorithm)
    }
  }

  # Output
  cfolds <- CAST::CreateSpacetimeFolds(data.frame(clust=clust), spacevar = "clust", k = k)
  res <- list(clusters = clust,
              indx_train = cfolds$index, indx_test = cfolds$indexOut,
              Gij = Gij, Gj = Gj, Gjstar = Gjstar,
              W = W_final, method = clustering, q = k_final, space = "geographical")
  class(res) <- c("knndm", "list")
  res
}


# kNNDM in the feature space
#
# Same search as knndm_geo, but distances are computed between predictor
# values: Euclidean distances on scaled numeric predictors, Mahalanobis
# distances if useMD is TRUE, or Gower distances as soon as categorical
# predictors (catVars) are present.
#
# tpoints, predpoints: data.frames with the predictor values of the training /
#   prediction points (same column names).
# k: number of folds.
# maxp: maximum fraction of training points allowed in a single fold.
# clustering: "hierarchical" or "kmeans" (k-prototypes for mixed data).
# linkf: linkage method handed to stats::hclust.
# catVars: names of categorical predictors, or NULL.
# useMD: use Mahalanobis distances (only honoured without catVars).
# algorithm: nearest neighbour search algorithm handed to FNN.
#
# Returns a list of class "knndm" with elements clusters, indx_train,
# indx_test, Gij, Gj, Gjstar, W, method, q (number of clusters of the selected
# configuration, or "random CV") and space.
knndm_feature <- function(tpoints, predpoints, k, maxp, clustering, linkf, catVars, useMD, algorithm) {

  # rescale data: prediction points are scaled with the training centre / scale
  if(is.null(catVars)) {

    tpoints_scaled <- scale(tpoints)
    scale_attr <- attributes(tpoints_scaled)
    tpoints <- as.data.frame(tpoints_scaled)
    predpoints <- scale(predpoints,center=scale_attr$`scaled:center`,
                        scale=scale_attr$`scaled:scale`) |>
      as.data.frame()

  } else {
    tpoints_cat <- tpoints[,catVars,drop=FALSE]
    predpoints_cat <- predpoints[,catVars,drop=FALSE]

    tpoints_num <- tpoints[,-which(names(tpoints)%in%catVars),drop=FALSE]
    predpoints_num <- predpoints[,-which(names(predpoints)%in%catVars),drop=FALSE]

    tpoints_scaled <- scale(tpoints_num)
    scale_attr <- attributes(tpoints_scaled)
    tpoints <- as.data.frame(tpoints_scaled)
    predpoints <- scale(predpoints_num,center=scale_attr$`scaled:center`,
                        scale=scale_attr$`scaled:scale`) |>
      as.data.frame()
    tpoints <- as.data.frame(cbind(tpoints, lapply(tpoints_cat, as.factor)))
    predpoints <- as.data.frame(cbind(predpoints, lapply(predpoints_cat, as.factor)))

  }


  # Gj and Gij calculation
  if(is.null(catVars)) {


    if(isTRUE(useMD)) {

      tpoints_mat <- as.matrix(tpoints)
      predpoints_mat <- as.matrix(predpoints)

      # use Mahalanobis distances
      if (dim(tpoints_mat)[2] == 1) {
        S <- matrix(stats::var(tpoints_mat), 1, 1)
      } else {
        S <- stats::cov(tpoints_mat)
      }
      S_inv <- MASS::ginv(S)

      # calculate distance matrix
      n_train <- nrow(tpoints)
      distmat <- sapply(seq_len(n_train), function(i) {
        sapply(seq_len(n_train), function(j) {
          sqrt(t(tpoints_mat[i,] - tpoints_mat[j,]) %*% S_inv %*% (tpoints_mat[i,] - tpoints_mat[j,]))
        })
      })
      diag(distmat) <- NA

      Gj <- apply(distmat, 1, min, na.rm=TRUE)

      Gij <- sapply(seq_len(dim(predpoints_mat)[1]), function(y) {
        min(sapply(seq_len(dim(tpoints_mat)[1]), function(x) {
          sqrt(t(predpoints_mat[y,] - tpoints_mat[x,]) %*% S_inv %*% (predpoints_mat[y,] - tpoints_mat[x,]))
        }))
      })


    } else {
      # use FNN with Euclidean distances if no categorical variables are present
      Gj <- c(FNN::knn.dist(tpoints, k = 1, algorithm=algorithm))
      Gij <- c(FNN::knnx.dist(query = predpoints, data = tpoints, k = 1, algorithm=algorithm))
    }


  } else {

    # use Gower distances if categorical variables are present
    Gj <- sapply(seq_len(nrow(tpoints)), function(i) gower::gower_topn(tpoints[i,], tpoints[-i,], n=1)$distance[[1]])
    Gij <- c(gower::gower_topn(predpoints, tpoints, n = 1)$distance)

  }


  # Check if Gj > Gij (warning suppressed regarding ties)
  testks <- suppressWarnings(stats::ks.test(Gj, Gij, alternative = "greater"))
  if(testks$p.value >= 0.05){

    clust <- sample(rep(1:k, ceiling(nrow(tpoints)/k)), size = nrow(tpoints), replace=FALSE)

    if(is.null(catVars)) {
      if(isTRUE(useMD)) {
        Gjstar <- distclust_MD(tpoints, clust)
      } else {
        Gjstar <- distclust_euclidean(tpoints, clust,algorithm=algorithm)
      }

    } else {
      Gjstar <- distclust_gower(tpoints, clust)
    }

    k_final <- "random CV"
    W_final <- twosamples::wass_stat(Gjstar, Gij)
    message("Gij <= Gj; a random CV assignment is returned")

  }else{

    if(clustering == "hierarchical"){

      # calculate distance matrix which is needed for hierarchical clustering
      if(is.null(catVars)) {

        if(isFALSE(useMD)) {
          # calculate distance matrix with Euclidean distances if no categorical variables are present
          # for MD: distance matrix was already calculated
          distmat <- stats::dist(tpoints, upper=TRUE, diag=TRUE) |> as.matrix()
          diag(distmat) <- NA
        }

      } else {

        # calculate distance matrix with Gower distances if categorical variables are present
        distmat <- matrix(nrow=nrow(tpoints), ncol=nrow(tpoints))
        for (i in seq_len(nrow(tpoints))){

          trainDist <-  gower::gower_dist(tpoints[i,], tpoints)

          trainDist[i] <- NA
          distmat[i,] <- trainDist
        }
      }
      hc <- stats::hclust(d = stats::as.dist(distmat), method = linkf)
    }

    # Build grid of number of clusters to try - we sample low numbers more intensively
    clustgrid <- data.frame(nk = as.integer(round(exp(seq(log(k), log(nrow(tpoints)-2),
                                                          length.out = 100)))))
    clustgrid$W <- NA
    clustgrid <- clustgrid[!duplicated(clustgrid$nk),]
    clustgroups <- list()

    # Compute 1st PC for ordering clusters
    if(is.null(catVars)) {
      pcacoords <- stats::prcomp(tpoints, center = TRUE, scale. = FALSE, rank = 1)
    } else {
      pcacoords <- PCAmixdata::PCAmix(X.quanti = tpoints[,!(names(tpoints) %in% catVars), drop=FALSE],
                                      X.quali = tpoints[,names(tpoints) %in% catVars, drop=FALSE],
                                      graph = FALSE)
    }

    # We test each number of clusters
    for(nk in clustgrid$nk) {

      # Create nk clusters; a failing clustering call is kept as an error
      # object so that this nk can be skipped below
      if(clustering == "hierarchical"){
        clust_nk <- stats::cutree(hc, k=nk)
      } else if(clustering == "kmeans"){
        if(is.null(catVars)) {
          clust_nk <- tryCatch(stats::kmeans(tpoints, nk)$cluster,
                               error=function(e) e)

        } else {
          # prototype clustering for mixed data sets
          clust_nk <- tryCatch(clustMixType::kproto(tpoints, nk,verbose=FALSE)$cluster,
                               error=function(e) e)
        }
      }

      if (!inherits(clust_nk,"error")){
        tabclust <- as.data.frame(table(clust_nk))
        tabclust$clust_k <- NA

        # compute cluster centroids and apply PC loadings to shuffle along the 1st dimension
        if(is.null(catVars)) {
          centr_tpoints <- sapply(tabclust$clust_nk, function(x){
            centrpca <- matrix(apply(tpoints[clust_nk %in% x, , drop=FALSE], 2, mean), nrow = 1)
            colnames(centrpca) <- colnames(tpoints)
            return(predict(pcacoords, centrpca))
          })
        } else {
          centr_tpoints <- sapply(tabclust$clust_nk, function(x){
            centrpca_num <- matrix(apply(tpoints[clust_nk %in% x, !(names(tpoints) %in% catVars), drop=FALSE], 2, mean), nrow = 1)
            centrpca_cat <- matrix(apply(tpoints[clust_nk %in% x, names(tpoints) %in% catVars, drop=FALSE], 2,
                                         function(y) names(which.max(table(y)))), nrow = 1)
            colnames(centrpca_num) <- colnames(tpoints[,!(names(tpoints) %in% catVars), drop=FALSE])
            colnames(centrpca_cat) <- colnames(tpoints[,names(tpoints) %in% catVars, drop=FALSE])

            return(predict(pcacoords, centrpca_num, centrpca_cat)[,1])

          })
        }

        tabclust$centrpca <- centr_tpoints
        tabclust <- tabclust[order(tabclust$centrpca),]

        # We don't merge big clusters: each one gets a fold of its own
        next_fold <- 1
        for(i in seq_len(nrow(tabclust))){
          if(tabclust$Freq[i] >= nrow(tpoints)/k){
            tabclust$clust_k[i] <- next_fold
            next_fold <- next_fold + 1
          }
        }

        # And we merge the remaining into k groups.
        # If every cluster already owns a fold (k perfectly balanced clusters)
        # there is nothing left to merge and no free fold; skipping the step
        # avoids calling rep() with an infinite `times`.
        unassigned <- is.na(tabclust$clust_k)
        if(any(unassigned)){
          free_folds <- setdiff(1:k, unique(tabclust$clust_k))
          tabclust$clust_k[unassigned] <- rep(free_folds, ceiling(nk/length(free_folds)))[1:sum(unassigned)]
        }
        tabclust2 <- data.frame(ID = 1:length(clust_nk), clust_nk = clust_nk)
        tabclust2 <- merge(tabclust2, tabclust, by = "clust_nk")
        tabclust2 <- tabclust2[order(tabclust2$ID),]
        clust_k <- tabclust2$clust_k

        # Compute W statistic if not exceeding maxp
        if(!(any(table(clust_k)/length(clust_k)>maxp))){

          if(clustering == "kmeans") {
            if(is.null(catVars)) {
              if(isTRUE(useMD)){
                Gjstar_i <- distclust_MD(tpoints, clust_k)
              } else {
                Gjstar_i <- distclust_euclidean(tpoints, clust_k,algorithm=algorithm)
              }
            } else {
              Gjstar_i <- distclust_gower(tpoints, clust_k)
            }

          } else {
            Gjstar_i <- distclust_distmat(distmat, clust_k)
          }

          clustgrid$W[clustgrid$nk==nk] <- twosamples::wass_stat(Gjstar_i, Gij)
          clustgroups[[paste0("nk", nk)]] <- clust_k
        }
      } else {
        message(paste("skipped nk", nk))
      }
    }

    # Fail early with a clear message instead of carrying an empty fold
    # assignment into the output step
    if(all(is.na(clustgrid$W))){
      stop("No clustering configuration satisfies maxp; consider increasing maxp.")
    }

    # Final configuration
    k_final <- clustgrid$nk[which.min(clustgrid$W)]
    W_final <- min(clustgrid$W, na.rm=TRUE)
    clust <- clustgroups[[paste0("nk", k_final)]]

    if(clustering == "kmeans") {
      if(is.null(catVars)) {
        if(isTRUE(useMD)) {
          Gjstar <- distclust_MD(tpoints, clust)
        } else {
          Gjstar <- distclust_euclidean(tpoints, clust,algorithm=algorithm)
        }

      } else {
        Gjstar <- distclust_gower(tpoints, clust)
      }
    } else {
      Gjstar <- distclust_distmat(distmat, clust)
    }

  }


  # Output
  cfolds <- CAST::CreateSpacetimeFolds(data.frame(clust=clust), spacevar = "clust", k = k)
  res <- list(clusters = clust,
              indx_train = cfolds$index, indx_test = cfolds$indexOut,
              Gij = Gij, Gj = Gj, Gjstar = Gjstar,
              W = W_final, method = clustering, q = k_final, space = "feature")
  class(res) <- c("knndm", "list")
  res
}


# Helper function: out-of-fold nearest neighbour distance read from a
# precomputed distance matrix.
# distm: square matrix of pairwise distances between the training points.
# folds: fold label of each training point (same order as the rows of distm).
# Returns, for every point, the smallest distance to a point of another fold.
distclust_distmat <- function(distm, folds){
  nn_dist <- rep(NA, length(folds))
  for(fold_id in unique(folds)){
    in_fold <- folds == fold_id
    # rows: members of the fold; columns: everything outside the fold
    to_other_folds <- distm[in_fold, !in_fold, drop=FALSE]
    nn_dist[in_fold] <- apply(to_other_folds, 1, min)
  }
  nn_dist
}

# Helper function: out-of-fold nearest neighbour distance computed with FNN
# (projected coordinates / numerical variables).
# tr_coords: matrix or data.frame with one row per training point.
# folds: fold label of each training point.
# algorithm: search algorithm handed to FNN::knnx.dist.
# Returns, for every point, the distance to its nearest neighbour among the
# points of the other folds.
distclust_euclidean <- function(tr_coords, folds, algorithm){
  nn_dist <- rep(NA, length(folds))
  for(fold_id in unique(folds)){
    in_fold <- folds == fold_id
    fold_pts <- tr_coords[in_fold,,drop=FALSE]
    other_pts <- tr_coords[!in_fold,,drop=FALSE]
    nn_dist[in_fold] <- c(FNN::knnx.dist(query = fold_pts, data = other_pts,
                                         k = 1, algorithm=algorithm))
  }
  nn_dist
}

# Helper function: out-of-fold nearest neighbour distance with Gower distances
# (data containing categorical variables).
# tr_coords: data.frame with one row per training point.
# folds: fold label of each training point.
# Returns, for every point, the Gower distance to its nearest neighbour among
# the points of the other folds.
distclust_gower <- function(tr_coords, folds){

  alldist <- rep(NA, length(folds))
  for(f in unique(folds)){
    in_fold <- folds == f
    nn <- gower::gower_topn(tr_coords[in_fold,,drop=FALSE],
                            tr_coords[!in_fold,,drop=FALSE], n=1)
    # Flatten the whole distance matrix, as done for Gij in knndm_feature, so
    # that every point of the fold gets its own distance. Taking `[[1]]` would
    # pick the first point's distance only and recycle it over the fold.
    alldist[in_fold] <- c(nn$distance)
  }
  alldist
}

# Helper function: out-of-fold nearest neighbour distance with Mahalanobis
# distances.
# tr_coords: numeric matrix or data.frame with one row per training point.
# folds: fold label of each training point.
# The covariance matrix is estimated from all training points and inverted
# with the generalized inverse (MASS::ginv).
# Returns, for every point, the Mahalanobis distance to its nearest neighbour
# among the points of the other folds.
distclust_MD <- function(tr_coords, folds){

  coords_mat <- as.matrix(tr_coords)
  cov_inv <- MASS::ginv(stats::cov(coords_mat))

  # Mahalanobis distance between two points given as plain vectors
  md <- function(a, b) {
    delta <- a - b
    sqrt(t(delta) %*% cov_inv %*% delta)
  }

  nn_dist <- rep(NA, length(folds))
  for(fold_id in unique(folds)) {
    in_fold <- folds == fold_id
    fold_mat <- coords_mat[in_fold,,drop=FALSE]
    rest_mat <- coords_mat[!in_fold,,drop=FALSE]

    nn_dist[in_fold] <- apply(fold_mat, 1, function(pt) {
      min(apply(rest_mat, 1, function(other) md(pt, other)))
    })
  }
  unlist(nn_dist)
}


#' Nearest Neighbour Distance Matching (NNDM) algorithm
#' @description
#' This function implements the NNDM algorithm and returns the necessary indices to perform a NNDM LOO CV for map validation.
#' @author Carles Milà
#' @param tpoints sf or sfc point object, or data.frame if space = "feature". Contains the training points samples.
#' @param modeldomain sf polygon object or SpatRaster defining the prediction area. Optional; alternative to predpoints (see Details).
#' @param predpoints sf or sfc point object, or data.frame if space = "feature". Contains the target prediction points. Optional; alternative to modeldomain (see Details).
#' @param space character. Either "geographical" or "feature". Feature space is still experimental, so use with caution.
#' @param samplesize numeric. How many points in the modeldomain should be sampled as prediction points?
#' Only required if modeldomain is used instead of predpoints.
#' @param sampling character. How to draw prediction points from the modeldomain? See `sf::st_sample`.
#' Only required if modeldomain is used instead of predpoints.
#' @param phi Numeric. Estimate of the landscape autocorrelation range in the
#' same units as the tpoints and predpoints for projected CRS, in meters for geographic CRS.
#' Per default (phi="max"), the maximum distance found in the training and prediction points is used. See Details.
#' @param min_train Numeric between 0 and 1. Minimum proportion of training
#' data that must be used in each CV fold. Defaults to 0.5 (i.e. half of the training points).
#' @param algorithm see \code{\link[FNN]{knnx.dist}} and \code{\link[FNN]{knnx.index}}
#' @return An object of class \emph{nndm} consisting of a list of six elements:
#' indx_train, indx_test, and indx_exclude (indices of the observations to use as
#' training/test/excluded data in each NNDM LOO CV iteration), Gij (distances for
#' G function construction between prediction and target points), Gj
#' (distances for G function construction during LOO CV), Gjstar (distances
#' for modified G function during NNDM LOO CV), phi (landscape autocorrelation range).
#' indx_train and indx_test can directly be used as "index" and "indexOut" in
#' caret's \code{\link[caret]{trainControl}} function or used to initiate a custom validation strategy in mlr3.
#'
#' @details NNDM proposes a LOO CV scheme such that the nearest neighbour distance distribution function between the test and training data during the CV process is matched to the nearest neighbour
#' distance distribution function between the prediction and training points. Details of the method can be found in Milà et al. (2022).
#'
#' Specifying \emph{phi} allows limiting distance matching to the area where this is assumed to be relevant due to spatial autocorrelation.
#' Distances are only matched up to \emph{phi}. Beyond that range, all data points are used for training, without exclusions.
#' When \emph{phi} is set to "max", nearest neighbor distance matching is performed for the entire prediction area. Euclidean distances are used for projected
#' and non-defined CRS, great circle distances are used for geographic CRS (units in meters).
#'
#' The \emph{modeldomain} is either a sf polygon that defines the prediction area, or alternatively a SpatRaster out of which a polygon,
#' transformed into the CRS of the training points, is defined as the outline of all non-NA cells.
#' Then, the function takes a regular point sample (amount defined by \emph{samplesize}) from the spatial extent.
#' As an alternative use \emph{predpoints} instead of \emph{modeldomain}, if you have already defined the prediction locations (e.g. raster pixel centroids).
#' When using either \emph{modeldomain} or \emph{predpoints}, we advise to plot the study area polygon and the training/prediction points as a previous step to ensure they are aligned.
#'
#' @note NNDM is a variation of LOOCV and therefore may take a long time for large training data sets. See \code{\link{knndm}} for a more efficient k-fold variant of the method.
#' @seealso \code{\link{geodist}}, \code{\link{knndm}}
#' @references
#' \itemize{
#' \item Milà, C., Mateu, J., Pebesma, E., Meyer, H. (2022): Nearest Neighbour Distance Matching Leave-One-Out Cross-Validation for map validation. Methods in Ecology and Evolution 00, 1– 13.
#' \item Meyer, H., Pebesma, E. (2022): Machine learning-based global maps of ecological variables and the challenge of assessing them. Nature Communications. 13.
#' }
#' @export
#' @examples
#' ########################################################################
#' # Example 1: Simulated data - Randomly-distributed training points
#' ########################################################################
#'
#' library(sf)
#'
#' # Simulate 100 random training points in a 100x100 square
#' set.seed(123)
#' poly <- list(matrix(c(0,0,0,100,100,100,100,0,0,0), ncol=2, byrow=TRUE))
#' sample_poly <- sf::st_polygon(poly)
#' train_points <- sf::st_sample(sample_poly, 100, type = "random")
#' pred_points <- sf::st_sample(sample_poly, 100, type = "regular")
#' plot(sample_poly)
#' plot(pred_points, add = TRUE, col = "blue")
#' plot(train_points, add = TRUE, col = "red")
#'
#' # Run NNDM for the whole domain, here the prediction points are known
#' nndm_pred <- nndm(train_points, predpoints=pred_points)
#' nndm_pred
#' plot(nndm_pred)
#' plot(nndm_pred, type = "simple") # For more accessible legend labels
#'
#' # ...or run NNDM with a known autocorrelation range of 10
#' # to restrict the matching to distances lower than that.
#' nndm_pred <- nndm(train_points, predpoints=pred_points, phi = 10)
#' nndm_pred
#' plot(nndm_pred)
#'
#' ########################################################################
#' # Example 2: Simulated data - Clustered training points
#' ########################################################################
#'
#' library(sf)
#'
#' # Simulate 100 clustered training points in a 100x100 square
#' set.seed(123)
#' poly <- list(matrix(c(0,0,0,100,100,100,100,0,0,0), ncol=2, byrow=TRUE))
#' sample_poly <- sf::st_polygon(poly)
#' train_points <- clustered_sample(sample_poly, 100, 10, 5)
#' pred_points <- sf::st_sample(sample_poly, 100, type = "regular")
#' plot(sample_poly)
#' plot(pred_points, add = TRUE, col = "blue")
#' plot(train_points, add = TRUE, col = "red")
#'
#' # Run NNDM for the whole domain
#' nndm_pred <- nndm(train_points, predpoints=pred_points)
#' nndm_pred
#' plot(nndm_pred)
#' plot(nndm_pred, type = "simple") # For more accessible legend labels
#'
#' ########################################################################
#' # Example 3: Real-world example; using a SpatRaster modeldomain instead
#' # of previously sampled prediction locations
#' ########################################################################
#' \dontrun{
#' library(sf)
#' library(terra)
#'
#' ### prepare sample data:
#' data(cookfarm)
#' dat <- aggregate(cookfarm[,c("DEM","TWI", "NDRE.M", "Easting", "Northing","VW")],
#'    by=list(as.character(cookfarm$SOURCEID)),mean)
#' pts <- dat[,-1]
#' pts <- st_as_sf(pts,coords=c("Easting","Northing"))
#' st_crs(pts) <- 26911
#' studyArea <- rast(system.file("extdata","predictors_2012-03-25.tif",package="CAST"))
#' pts <- st_transform(pts, crs = st_crs(studyArea))
#' terra::plot(studyArea[["DEM"]])
#' terra::plot(vect(pts), add = T)
#'
#' nndm_folds <- nndm(pts, modeldomain = studyArea)
#' plot(nndm_folds)
#'
#' #use for cross-validation:
#' library(caret)
#' ctrl <- trainControl(method="cv",
#'    index=nndm_folds$indx_train,
#'    indexOut=nndm_folds$indx_test,
#'    savePredictions='final')
#' model_nndm <- train(dat[,c("DEM","TWI", "NDRE.M")],
#'    dat$VW,
#'    method="rf",
#'    trControl = ctrl)
#' global_validation(model_nndm)
#'}
#'
#' ########################################################################
#' # Example 4: Real-world example; nndm in feature space
#' ########################################################################
#' \dontrun{
#' library(sf)
#' library(terra)
#' library(ggplot2)
#'
#' # Prepare the splot dataset for Chile
#' data(splotdata)
#' splotdata <- splotdata[splotdata$Country == "Chile",]
#'
#' # Select a series of bioclimatic predictors
#' predictors <- c("bio_1", "bio_4", "bio_5", "bio_6",
#'                "bio_8", "bio_9", "bio_12", "bio_13",
#'                "bio_14", "bio_15", "elev")
#'
#' predictors_sp <- terra::rast(system.file("extdata", "predictors_chile.tif", package="CAST"))
#'
#' # Data visualization
#' terra::plot(predictors_sp[["bio_1"]])
#' terra::plot(vect(splotdata), add = T)
#'
#' # Run and visualise the nndm results
#' nndm_folds <- nndm(splotdata[,predictors], modeldomain = predictors_sp, space = "feature")
#' plot(nndm_folds)
#'
#'
#' #use for cross-validation:
#' library(caret)
#' ctrl <- trainControl(method="cv",
#'    index=nndm_folds$indx_train,
#'    indexOut=nndm_folds$indx_test,
#'    savePredictions='final')
#' model_nndm <- train(st_drop_geometry(splotdata[,predictors]),
#'    splotdata$Species_richness,
#'    method="rf",
#'    trControl = ctrl)
#' global_validation(model_nndm)
#'
#' }
nndm <- function(tpoints, modeldomain = NULL, predpoints = NULL,
                 space="geographical",
                 samplesize = 1000, sampling = "regular",
                 phi = "max", min_train = 0.5, algorithm="brute"){


  # 1. Preprocessing actions ----
  if(is.null(predpoints)&!is.null(modeldomain)){

    # Check modeldomain is indeed a sf/SpatRaster
    if(!any(c("sfc", "sf", "SpatRaster") %in% class(modeldomain))){
      stop("modeldomain must be a sf/sfc object or a 'SpatRaster' object.")
    }

    # If modeldomain is a SpatRaster, transform into polygon
    if(any(class(modeldomain) == "SpatRaster")){

      # save predictor stack for extraction if space = "feature"
      if(space == "feature") {
        predictor_stack <- modeldomain
      }
      modeldomain[!is.na(modeldomain)] <- 1
      modeldomain <- terra::as.polygons(modeldomain, values = FALSE, na.all = TRUE) |>
        sf::st_as_sf() |>
        sf::st_union()
      if(any(c("sfc", "sf") %in% class(tpoints))) {
        modeldomain <- sf::st_transform(modeldomain, crs = sf::st_crs(tpoints))
      }
    }

    # Check modeldomain is indeed a polygon sf
    if(!any(class(sf::st_geometry(modeldomain)) %in% c("sfc_POLYGON", "sfc_MULTIPOLYGON"))){
      stop("modeldomain must be a sf/sfc polygon object.")
    }

    # Check whether modeldomain has the same crs as tpoints
    if(!identical(sf::st_crs(tpoints), sf::st_crs(modeldomain)) & space == "geographical"){
      stop("tpoints and modeldomain must have the same CRS")
    }

    # We sample
    message(paste0(samplesize, " prediction points are sampled from the modeldomain"))
    predpoints <- sf::st_sample(x = modeldomain, size = samplesize, type = sampling)
    sf::st_crs(predpoints) <- sf::st_crs(modeldomain)

    if(space == "feature") {
      message("predictor values are extracted for prediction points")
      predpoints <- terra::extract(predictor_stack, terra::vect(predpoints), ID=FALSE)
    }

  }else if(!is.null(predpoints) & space == "geographical"){
    if(!identical(sf::st_crs(tpoints), sf::st_crs(predpoints))){
      stop("tpoints and predpoints must have the same CRS")
    }
  }


  # 2. Data formats, data checks, scaling and categorical variables ----
  if(isTRUE(space == "geographical")) {

    # If tpoints is sfc, coerce to sf.
    if(any(class(tpoints) %in% "sfc")){
      tpoints <- sf::st_sf(geom=tpoints)
    }
    # If predpoints is sfc, coerce to sf.
    if(any(class(predpoints) %in% "sfc")){
      predpoints <- sf::st_sf(geom=predpoints)
    }

    # Input data checks
    nndm_checks_geo(tpoints, predpoints, phi, min_train)

  }else if(isTRUE(space == "feature")){

    # drop geometry if tpoints / predpoints are of class sf
    if(any(class(tpoints) %in% c("sf","sfc"))) {
      tpoints <- sf::st_set_geometry(tpoints, NULL)
    }
    if(any(class(predpoints) %in% c("sf","sfc"))) {
      predpoints <- sf::st_set_geometry(predpoints, NULL)
    }

    # get names of categorical variables
    catVars <- names(tpoints)[which(sapply(tpoints, class)%in%c("factor","character"))]
    if(length(catVars)==0) {
      catVars <- NULL
    }
    if(!is.null(catVars)) {
      message(paste0("variable(s) '", catVars, "' is (are) treated as categorical variables"))
    }

    # omit NAs
    if(any(is.na(predpoints))) {
      message("some prediction points contain NAs, which will be removed")
      predpoints <- stats::na.omit(predpoints)
    }
    if(any(is.na(tpoints))) {
      message("some training points contain NAs, which will be removed")
      tpoints <- stats::na.omit(tpoints)
    }

    # Input data checks
    nndm_checks_feature(tpoints, predpoints, phi, min_train, catVars)

    # Scaling and dealing with categorical factors
    if(is.null(catVars)) {

      scale_attr <- attributes(scale(tpoints))
      tpoints <- scale(tpoints) |> as.data.frame()
      predpoints <- scale(predpoints,center=scale_attr$`scaled:center`,
                          scale=scale_attr$`scaled:scale`) |>
        as.data.frame()

    } else {
      tpoints_cat <- tpoints[,catVars,drop=FALSE]
      predpoints_cat <- predpoints[,catVars,drop=FALSE]

      tpoints_num <- tpoints[,-which(names(tpoints)%in%catVars),drop=FALSE]
      predpoints_num <- predpoints[,-which(names(predpoints)%in%catVars),drop=FALSE]

      scale_attr <- attributes(scale(tpoints_num))
      tpoints <- scale(tpoints_num) |> as.data.frame()
      predpoints <- scale(predpoints_num,center=scale_attr$`scaled:center`,
                          scale=scale_attr$`scaled:scale`) |>
        as.data.frame()
      tpoints <- as.data.frame(cbind(tpoints, lapply(tpoints_cat, as.factor)))
      predpoints <- as.data.frame(cbind(predpoints, lapply(predpoints_cat, as.factor)))

      # 0/1 encode categorical variables (as in R/trainDI.R)
      for (catvar in catVars){
        # mask all unknown levels in newdata as NA
        tpoints[,catvar]<-droplevels(tpoints[,catvar])
        predpoints[,catvar]<-droplevels(predpoints[,catvar])

        # then create dummy variables for the remaining levels in train:
        dvi_train <- predict(caret::dummyVars(paste0("~",catvar), data = tpoints),
                             tpoints)
        dvi_predpoints <- predict(caret::dummyVars(paste0("~",catvar), data = predpoints),
                                  predpoints)
        tpoints <- data.frame(tpoints,dvi_train)
        predpoints <- data.frame(predpoints,dvi_predpoints)

      }
      tpoints <- tpoints[,-which(names(tpoints)%in%catVars)]
      predpoints <- predpoints[,-which(names(predpoints)%in%catVars)]

    }
  }

  # 3. Distance and phi computation ----
  if(isTRUE(space=="geographical")){

    # Compute nearest neighbour distances between training and prediction points
    Gij <- sf::st_distance(predpoints, tpoints)
    units(Gij) <- NULL
    Gij <- apply(Gij, 1, min)

    # Compute distance matrix of training points
    tdist <- sf::st_distance(tpoints)
    units(tdist) <- NULL
    diag(tdist) <- NA
    Gj <- apply(tdist, 1, function(x) min(x, na.rm=TRUE))
    Gjstar <- Gj

    # if phi==max calculate the maximum relevant distance
    if(phi=="max"){
      phi <- max(c(Gij, c(tdist)), na.rm=TRUE) + 1e-9
    }


  }else if(isTRUE(space=="feature")){

    if(is.null(catVars)) {

      # Euclidean distances if no categorical variables are present
      Gij <- c(FNN::knnx.dist(query = predpoints, data = tpoints, k = 1, algorithm=algorithm))
      tdist <- as.matrix(stats::dist(tpoints, upper = TRUE))
      diag(tdist) <- NA
      Gj <- apply(tdist, 1, function(x) min(x, na.rm=TRUE))
      Gjstar <- Gj

    } else {

      # Gower distances if categorical variables are present
      Gj <- sapply(1:nrow(tpoints), function(i) gower::gower_topn(tpoints[i,], tpoints[-i,], n=1)$distance[[1]])
      tdist <- matrix(NA, nrow=nrow(tpoints), ncol=nrow(tpoints))
      for(r in 1:nrow(tdist)){
        tdist[r,] <- gower::gower_dist(tpoints[r,], tpoints)
      }
      diag(tdist) <- NA
      Gj <- apply(tdist, 1, function(x) min(x, na.rm=TRUE))
      Gjstar <- Gj

    }

    # if phi==max calculate the maximum relevant distance
    if(phi=="max"){
      phi <- max(c(Gij, c(tdist)), na.rm=TRUE) + 1e-9
    }
  }

  # Start algorithm
  rmin <- min(Gjstar)
  jmin <- which.min(Gjstar)[1]
  kmin <- which(tdist[jmin,]==rmin)

  while(rmin <= phi){

    # Check if removing the point improves the match. If yes, update
    if((sum(Gjstar<=rmin)-1)/length(Gjstar) >= (sum(Gij<=rmin)/length(Gij)) &
       sum(!is.na(tdist[jmin, ]))/ncol(tdist) > min_train){
      tdist[jmin, kmin] <- NA
      Gjstar <- apply(tdist, 1, function(x) min(x, na.rm=TRUE))
      rmin <- min(Gjstar[Gjstar>=rmin]) # Distances are the same for the same pair
      jmin <- which(Gjstar==rmin)[1]
      kmin <- which(tdist[jmin,]==rmin)

    }else if(sum(Gjstar>rmin)==0){
      break
    }else{ # Otherwise move on to the next distance
      rmin <- min(Gjstar[Gjstar>rmin])
      jmin <- which(Gjstar==rmin)[1]
      kmin <- which(tdist[jmin,]==rmin)
    }
  }

  # Derive indicators
  indx_train <- list()
  indx_test <- list()
  indx_exclude <- list()
  for(i in 1:nrow(tdist)){
    indx_train[[i]] <- which(!is.na(tdist[i,]))
    indx_test[[i]] <- i
    indx_exclude[[i]] <- setdiff(which(is.na(tdist[i,])), i)
  }

  # Return list of indices
  res <- list(indx_train=indx_train, indx_test=indx_test,
              indx_exclude=indx_exclude, Gij=Gij, Gj=Gj, Gjstar=Gjstar, phi=phi)
  class(res) <- c("nndm", "list")
  res

}


# Input data checks for NNDM
nndm_checks_geo <- function(tpoints, predpoints, phi, min_train){
  # Validates the inputs of a geographical-space NNDM run and stops with an
  # informative message on the first violation found.
  #   tpoints, predpoints: expected to be sf/sfc point objects
  #   phi: a single non-negative number or the string "max"
  #   min_train: a single number between 0 and 1

  # Check for valid range of phi. Scalar `&&`/`||` short-circuit, so the
  # numeric comparison is only reached for a single numeric value and the
  # condition handed to `if` always has length one.
  phi_is_max <- is.character(phi) && length(phi) == 1 && isTRUE(phi == "max")
  phi_is_number <- is.numeric(phi) && length(phi) == 1 && !is.na(phi) && phi >= 0
  if(!(phi_is_max || phi_is_number)){
    stop("phi must be positive or set to 'max'.")
  }

  # min_train must be a single numeric in [0, 1]. Type, length and NA are
  # tested before the range so that bad input produces this message rather
  # than a low-level error from `if`.
  if(!is.numeric(min_train) || length(min_train) != 1 || is.na(min_train) ||
     min_train < 0 || min_train > 1){
    stop("min_train must be a numeric between 0 and 1.")
  }

  # Check class and geometry type of tpoints
  if(!inherits(tpoints, c("sfc", "sf"))){
    stop("tpoints must be a sf/sfc object.")
  }else if(!inherits(sf::st_geometry(tpoints), "sfc_POINT")){
    stop("tpoints must be a sf/sfc point object.")
  }

  # Check class and geometry type of predpoints
  if(!inherits(predpoints, c("sfc", "sf"))){
    stop("predpoints must be a sf/sfc object.")
  }else if(!inherits(sf::st_geometry(predpoints), "sfc_POINT")){
    stop("predpoints must be a sf/sfc point object.")
  }

}

nndm_checks_feature <- function(tpoints, predpoints, phi, min_train, catVars){
  # Validates the inputs of a feature-space NNDM run and stops with an
  # informative message on the first violation found.
  #   tpoints, predpoints: tabular predictor data with matching column names
  #   phi: a single non-negative number or the string "max"
  #   min_train: a single number between 0 and 1
  #   catVars: names of the categorical columns, or NULL

  # Check for valid range of phi. Scalar `&&`/`||` short-circuit, so the
  # numeric comparison is only reached for a single numeric value and the
  # condition handed to `if` always has length one.
  phi_is_max <- is.character(phi) && length(phi) == 1 && isTRUE(phi == "max")
  phi_is_number <- is.numeric(phi) && length(phi) == 1 && !is.na(phi) && phi >= 0
  if(!(phi_is_max || phi_is_number)){
    stop("phi must be positive or set to 'max'.")
  }

  # min_train must be a single numeric in [0, 1]. Type, length and NA are
  # tested before the range so that bad input produces this message rather
  # than a low-level error from `if`.
  if(!is.numeric(min_train) || length(min_train) != 1 || is.na(min_train) ||
     min_train < 0 || min_train > 1){
    stop("min_train must be a numeric between 0 and 1.")
  }

  # Every column of tpoints must also exist in predpoints
  if(length(setdiff(names(tpoints), names(predpoints)))>0) {
    stop("tpoints and predpoints need to contain the predictor data and have the same colnames.")
  }

  # NOTE(review): the test below fires when a value of the training points is
  # missing from the prediction points, whereas the second sentence of the
  # message states the requirement the other way round - confirm which
  # direction is intended.
  for (catvar in catVars) {
    train_values <- unique(tpoints[,catvar])
    pred_values <- unique(predpoints[,catvar])
    if (!all(train_values %in% pred_values)) {
      stop(paste0("Some values of factor ", catvar,
                  " are only present in training / prediction points. ",
                  "All factor values in the prediction points must be present in the training points."))
    }
  }
}


#' Normalize DI values
#' @description
#' The DI is normalized by the DI threshold to allow for a more straightforward interpretation.
#' A value larger than 1 in the resulting DI means that the data are more dissimilar than what has been observed during cross-validation.
#' The returned threshold is adjusted accordingly and is, as a consequence, 1.
#' @param AOA An AOA object
#' @return An object of class \code{aoa}
#' @seealso \code{\link{aoa}}
#' @examples
#' \dontrun{
#' library(sf)
#' library(terra)
#' library(caret)
#'
#' # prepare sample data:
#' data(cookfarm)
#' dat <- aggregate(cookfarm[,c("VW","Easting","Northing")],
#'    by=list(as.character(cookfarm$SOURCEID)),mean)
#' pts <- st_as_sf(dat,coords=c("Easting","Northing"))
#' pts$ID <- 1:nrow(pts)
#' set.seed(100)
#' pts <- pts[1:30,]
#' studyArea <- rast(system.file("extdata","predictors_2012-03-25.tif",package="CAST"))[[1:8]]
#' trainDat <- extract(studyArea,pts,na.rm=FALSE)
#' trainDat <- merge(trainDat,pts,by.x="ID",by.y="ID")
#'
#' # train a model:
#' set.seed(100)
#' variables <- c("DEM","NDRE.Sd","TWI")
#' model <- train(trainDat[,which(names(trainDat)%in%variables)],
#' trainDat$VW, method="rf", importance=TRUE, tuneLength=1,
#' trControl=trainControl(method="cv",number=5,savePredictions=T))
#'
#' #...then calculate the AOA of the trained model for the study area:
#' AOA <- aoa(studyArea, model)
#' plot(AOA)
#' plot(AOA$DI)
#'
#' #... then normalize the DI
#' DI_norm <- normalize_DI(AOA)
#' plot(DI_norm)
#' plot(DI_norm$DI)
#'
#' }
#' @export normalize_DI
#' @aliases normalize_DI


normalize_DI <- function(AOA) {
  # Read the threshold once; the DI layer, the training DI and the threshold
  # itself are all divided by this same value.
  thr <- AOA$parameters$threshold

  AOA$DI <- AOA$DI / thr
  AOA$parameters$trainDI <- AOA$parameters$trainDI / thr
  # threshold / threshold, i.e. 1 for any finite non-zero threshold
  AOA$parameters$threshold <- thr / thr

  AOA
}



#' Plot CAST classes
#' @description Generic plot function for CAST Classes
#'
#' @name plot
#' @param x trainDI object
#' @param ... other params
#'
#'
#' @author Marvin Ludwig, Hanna Meyer
#' @export

plot.trainDI <- function(x, ...){
  # Density of the DI values stored in x$trainDI, with x$threshold drawn as a
  # dashed vertical line. Returns the ggplot object.
  # aes() with the .data pronoun replaces the deprecated aes_string().
  density_data <- data.frame(TrainDI = x$trainDI)

  ggplot(density_data, aes(x = .data[["TrainDI"]]))+
    geom_density()+
    geom_vline(aes(xintercept = x$threshold, linetype = "AOA_threshold"))+
    scale_linetype_manual(name = "", values = c(AOA_threshold = "dashed"))+
    theme_bw()+
    theme(legend.position="bottom")

}








#' @name plot
#'
#' @param x aoa object
#' @param samplesize numeric. How many prediction samples should be plotted?
#' @param variable character. Variable for which to generate the density plot. 'DI' or 'LPD'
#' @param ... other params
#'
#' @import ggplot2
#'
#' @author Marvin Ludwig, Hanna Meyer
#'
#' @export

plot.aoa <- function(x, samplesize = 1000, variable = "DI", ...){
  # Density plot comparing the training values of `variable` ("DI" or "LPD")
  # with a sample of the values predicted for the new data, plus a dashed
  # reference line. Returns the ggplot object.

  # Fail early: an unsupported `variable` used to skip both plotting branches,
  # after which the function returned the base `plot` function.
  if (!is.character(variable) || length(variable) != 1 ||
      !variable %in% c("DI", "LPD")) {
    stop("variable must be either 'DI' or 'LPD'.")
  }

  # Pick the layer, the training values and the reference line for `variable`
  if (variable == "DI") {
    layer <- x$DI
    train_values <- x$parameters$trainDI
    ref_value <- x$parameters$threshold
    ref_label <- "AOA_threshold"
  } else {
    layer <- x$LPD
    train_values <- x$parameters$trainLPD
    ref_value <- stats::median(x$parameters$trainLPD)
    ref_label <- "MtrainLPD"
  }

  # Draw the prediction sample; the class of x$AOA decides how to sample
  if (inherits(x$AOA, c("RasterLayer", "stars"))) {
    sampled <- terra::spatSample(methods::as(layer, "SpatRaster"),
                                 size = samplesize, method = "regular")
    pred_values <- as.numeric(sampled[, 1])
  } else if (inherits(x$AOA, "SpatRaster")) {
    sampled <- terra::spatSample(layer, size = samplesize, method = "regular")
    pred_values <- as.numeric(sampled[, 1])
  } else {
    # Index sampling avoids the sample(n) shortcut for length-one input, and
    # capping the size avoids an error when fewer values than `samplesize`
    # are available.
    n_draw <- min(samplesize, length(layer))
    pred_values <- layer[sample.int(length(layer), size = n_draw)]
  }

  plot_data <- rbind(
    data.frame(value = train_values, what = paste0("train", variable)),
    data.frame(value = pred_values, what = paste0("prediction", variable))
  )

  # aes() with the .data pronoun replaces the deprecated aes_string()
  ggplot(plot_data, aes(x = .data[["value"]], group = .data[["what"]],
                        fill = .data[["what"]]))+
    geom_density(adjust=1.5, alpha=.4)+
    scale_fill_discrete(name = "Set")+
    geom_vline(aes(xintercept = ref_value, linetype = ref_label))+
    scale_linetype_manual(name = "", values = stats::setNames("dashed", ref_label))+
    xlab(variable)+
    theme_bw()+
    theme(legend.position = "bottom")
}



#' @name plot
#' @param x An object of type \emph{nndm}.
#' @param type String, defaults to "strict" to show the original nearest neighbour distance definitions in the legend.
#' Alternatively, set to "simple" to have more intuitive labels.
#' @param ... other arguments.
#' @author Carles Milà
#'
#' @export
plot.nndm <- function(x, type="strict", stat = "ecdf", ...){
  # Plots the three nearest neighbour distance functions stored in an nndm
  # object (x$Gij, x$Gjstar, x$Gj), restricted to distances up to x$phi, as
  # ECDF or density curves. Returns the ggplot object.

  # Reject unsupported `stat` values up front; otherwise neither plotting
  # branch runs and the function fails with "object 'p' not found".
  if(length(stat) != 1 || !stat %in% c("ecdf", "density")){
    stop("stat must be either 'ecdf' or 'density'.")
  }

  # Plotting table of one distance function: sorted distances with their
  # cumulative proportion, truncated at phi and extended by the two anchor
  # rows (r=0, val=0) and (r=phi, proportion of distances <= phi).
  prep_distances <- function(d, label){
    df <- data.frame(r=d[order(d)])
    df$val <- seq_len(nrow(df))/nrow(df)
    df <- df[df$r <= x$phi,]
    df <- rbind(df, data.frame(r=0, val=0))
    df <- rbind(df, data.frame(r=x$phi, val=sum(d<=x$phi)/length(d)))
    df$Function <- label
    df
  }
  Gij_df <- prep_distances(x$Gij, "1_Gij(r)")
  Gjstar_df <- prep_distances(x$Gjstar, "2_Gjstar(r)")
  Gj_df <- prep_distances(x$Gj, "3_Gj(r)")

  # Merge data for plotting. When all three functions reach 1, cut the table
  # just after the largest distance whose proportion is below 1 and close each
  # curve with a row at that distance.
  Gplot <- rbind(Gij_df, Gjstar_df, Gj_df)
  if(any(Gj_df$val==1) && any(Gjstar_df$val==1) && any(Gij_df$val==1)){
    maxdist <- max(Gplot$r[Gplot$val!=1]) + 1e-9
    Gplot <- Gplot[Gplot$r <= maxdist,]
    Gplot <- rbind(Gplot, data.frame(r=maxdist, val=1,
                                     Function = c("1_Gij(r)", "2_Gjstar(r)", "3_Gj(r)")))
  }

  # Define colours matching those of geodist
  myColors <- RColorBrewer::brewer.pal(3, "Dark2")
  curve_colours <- c(myColors[2], myColors[3], myColors[1])

  # Legend labels; any other `type` yields NULL and leaves the default scale
  legend_labels <- switch(type,
    strict = c(expression(hat(G)[ij](r)),
               expression(hat(G)[j]^"*"*"(r,L)"),
               expression(hat(G)[j](r))),
    simple = c("prediction-to-sample",
               "CV-distances",
               "sample-to-sample"))

  # Plot. aes() with the .data pronoun replaces the deprecated aes_string().
  if(stat=="ecdf"){
    p <- ggplot2::ggplot(data=Gplot,
                         ggplot2::aes(x=.data[["r"]], group=.data[["Function"]],
                                      colour=.data[["Function"]])) +
      ggplot2::geom_vline(xintercept=0, lwd = 0.1) +
      ggplot2::geom_hline(yintercept=0, lwd = 0.1) +
      ggplot2::geom_hline(yintercept=1, lwd = 0.1) +
      ggplot2::stat_ecdf(geom = "step", lwd = 0.8) +
      ggplot2::theme_bw() +
      ggplot2::ylab("ECDF") +
      ggplot2::labs(group="Distance function", col="Distance function") +
      ggplot2::theme(legend.position = "bottom",
                     legend.text=ggplot2::element_text(size=10))

    if(!is.null(legend_labels)){
      p <- p + ggplot2::scale_colour_manual(values=curve_colours, labels=legend_labels)
    }

  }else{
    p <- ggplot2::ggplot(data=Gplot,
                         ggplot2::aes(x=.data[["r"]], group=.data[["Function"]],
                                      fill=.data[["Function"]])) +
      ggplot2::geom_density(adjust=1.5, alpha=.5, stat="density", lwd = 0.3) +
      ggplot2::theme_bw() +
      ggplot2::ylab("Density") +
      # the curves are mapped to `fill` here, so the legend title is set on fill
      ggplot2::labs(group="Distance function", fill="Distance function") +
      ggplot2::theme(legend.position = "bottom",
                     legend.text=ggplot2::element_text(size=10))

    if(!is.null(legend_labels)){
      p <- p + ggplot2::scale_fill_manual(values=curve_colours, labels=legend_labels)
    }
  }

  p

}


#' @name plot
#' @param x An object of type \emph{knndm}.
#' @param type String, defaults to "strict" to show the original nearest neighbour distance definitions in the legend.
#' Alternatively, set to "simple" to have more intuitive labels.
#' @param stat String, defaults to "ecdf" but can be set to "density" to estimate density functions.
#' @param ... other arguments.
#' @author Carles Milà
#'
#' @export
plot.knndm <- function(x, type="strict", stat = "ecdf", ...){
  # Plots the three nearest neighbour distance functions stored in a knndm
  # object (x$Gij, x$Gjstar, x$Gj) as ECDF or density curves.
  # Returns the ggplot object.

  # Reject unsupported `stat` values up front; otherwise neither plotting
  # branch runs and the function fails with "object 'p' not found".
  if(length(stat) != 1 || !stat %in% c("ecdf", "density")){
    stop("stat must be either 'ecdf' or 'density'.")
  }

  # One row per distance, sorted, tagged with its distance function
  tag_distances <- function(d, label){
    df <- data.frame(r=d[order(d)])
    df$Function <- label
    df
  }

  # Merge data for plotting
  Gplot <- rbind(tag_distances(x$Gij, "1_Gij(r)"),
                 tag_distances(x$Gjstar, "2_Gjstar(r)"),
                 tag_distances(x$Gj, "3_Gj(r)"))

  # Define colours matching those of geodist
  myColors <- RColorBrewer::brewer.pal(3, "Dark2")
  curve_colours <- c(myColors[2], myColors[3], myColors[1])

  # Legend labels; any other `type` yields NULL and leaves the default scale
  legend_labels <- switch(type,
    strict = c(expression(hat(G)[ij](r)),
               expression(hat(G)[j]^"*"*"(r,L)"),
               expression(hat(G)[j](r))),
    simple = c("prediction-to-sample",
               "CV-distances",
               "sample-to-sample"))

  # Plot. aes() with the .data pronoun replaces the deprecated aes_string().
  if(stat=="ecdf"){
    p <- ggplot2::ggplot(data=Gplot,
                         ggplot2::aes(x=.data[["r"]], group=.data[["Function"]],
                                      colour=.data[["Function"]])) +
      ggplot2::geom_vline(xintercept=0, lwd = 0.1) +
      ggplot2::geom_hline(yintercept=0, lwd = 0.1) +
      ggplot2::geom_hline(yintercept=1, lwd = 0.1) +
      ggplot2::stat_ecdf(geom = "step", lwd = 0.8) +
      ggplot2::theme_bw() +
      ggplot2::ylab("ECDF") +
      ggplot2::labs(group="Distance function", col="Distance function") +
      ggplot2::theme(legend.position = "bottom",
                     legend.text=ggplot2::element_text(size=10))

    if(!is.null(legend_labels)){
      p <- p + ggplot2::scale_colour_manual(values=curve_colours, labels=legend_labels)
    }

  }else{
    p <- ggplot2::ggplot(data=Gplot,
                         ggplot2::aes(x=.data[["r"]], group=.data[["Function"]],
                                      fill=.data[["Function"]])) +
      ggplot2::geom_density(adjust=1.5, alpha=.5, stat="density", lwd = 0.3) +
      ggplot2::theme_bw() +
      ggplot2::ylab("Density") +
      # the curves are mapped to `fill` here, so the legend title is set on fill
      ggplot2::labs(group="Distance function", fill="Distance function") +
      ggplot2::theme(legend.position = "bottom",
                     legend.text=ggplot2::element_text(size=10))

    if(!is.null(legend_labels)){
      p <- p + ggplot2::scale_fill_manual(values=curve_colours, labels=legend_labels)
    }
  }

  p
}

#' Plot results of a Forward feature selection or best subset selection
#' @description A plotting function for a forward feature selection result.
#' Each point is the mean performance of a model run. Error bars represent
#' the standard errors from cross validation.
#' Marked points show the best model from each number of variables until a further variable
#' could not improve the results.
#' If plotType=="selected", the contribution of the selected variables to the model
#' performance is shown.
#' @param x Result of a forward feature selection see \code{\link{ffs}}
#' @param plotType character. Either "all" or "selected"
#' @param palette A color palette
#' @param reverse Logical. Should the palette be reversed?
#' @param marker Character. Color to mark the best models
#' @param size Numeric. Size of the points
#' @param lwd Numeric. Width of the error bars
#' @param pch Numeric. Type of point marking the best models
#' @param ... Further arguments for base plot if type="selected"
#' @author Marvin Ludwig, Hanna Meyer
#' @examples
#' \dontrun{
#' data(splotdata)
#' splotdata <- st_drop_geometry(splotdata)
#' ffsmodel <- ffs(splotdata[,6:16], splotdata$Species_richness, ntree = 10)
#' plot(ffsmodel)
#' #plot performance of selected variables only:
#' plot(ffsmodel,plotType="selected")
#'}
#' @name plot
#' @importFrom forcats fct_rev fct_inorder
#' @export



plot.ffs <- function(x,plotType="all",palette=rainbow,reverse=FALSE,
                     marker="black",size=1.5,lwd=0.5,
                     pch=21,...){
  # Plots the result of a variable selection. plotType="selected" shows the
  # performance (+/- SE) after each selection step; any other value shows the
  # performance of all model runs with the best models marked.
  # Returns the ggplot object.
  metric <- x$metric
  if (is.null(x$type)){
    x$type <- "ffs"
  }
  if(is.null(x$minVar)){
    x$minVar <- 2
  }
  # Fall back to the "all" plot for bss models. The override used to be
  # written to an unused variable `type`, so it never took effect.
  if(x$type=="bss" && plotType=="selected"){
    plotType <- "all"
    warning("plotType must be 'all' for a bss model; using plotType = 'all'.",
            call. = FALSE)
  }

  # aes() with the .data pronoun replaces the deprecated aes_string() below

  if (plotType=="selected"){

    # First label: the initial minVar variables; then one "+ var" label per
    # variable added afterwards. Without the length guard, paste() would turn
    # an empty set of added variables into one spurious "+ " label.
    first_label <- paste(x$selectedvars[1:x$minVar], collapse = "\n + ")
    added_vars <- x$selectedvars[-1:-x$minVar]
    step_labels <- c(first_label,
                     if (length(added_vars) > 0) paste("+", added_vars, sep = " "))

    plot_df <- data.frame(labels = forcats::fct_rev(forcats::fct_inorder(step_labels)),
                          perf = x$selectedvars_perf,
                          perfse = x$selectedvars_perf_SE)

    p <- ggplot2::ggplot(plot_df, ggplot2::aes(x = .data[["perf"]], y = .data[["labels"]]))+
      ggplot2::geom_point()+
      ggplot2::geom_segment(ggplot2::aes(x = .data[["perf"]] - .data[["perfse"]],
                                         xend = .data[["perf"]] + .data[["perfse"]],
                                         y = .data[["labels"]], yend = .data[["labels"]]))+
      ggplot2::xlab(x$metric)+
      ggplot2::ylab(NULL)
    return(p)

  }

  output_df <- x$perf_all
  output_df$run <- seq(nrow(output_df))
  names(output_df)[which(names(output_df)==metric)] <- "value"

  if (x$type=="bss"){
    bestmodels <- output_df$run[which(output_df$value==x$selectedvars_perf)]
  }else{
    # Best run (first one on ties) for each number of variables
    pick_best <- if (x$maximize) max else min
    bestmodels <- c()
    for (i in unique(output_df$nvar)){
      in_step <- output_df$nvar==i
      step_values <- output_df$value[in_step]
      bestmodels <- c(bestmodels,
                      output_df$run[in_step][which(step_values==pick_best(step_values))][1])
    }
    bestmodels <- bestmodels[1:(length(x$selectedvars)-1)]
  }

  # One colour per number of variables
  n_cols <- max(output_df$nvar)-(min(output_df$nvar)-1)
  cols <- palette(n_cols)
  if (reverse){
    cols <- rev(cols)
  }
  ymin <- output_df$value - output_df$SE
  ymax <- output_df$value + output_df$SE

  # With more than 11 distinct variable counts a continuous colour bar is
  # used; otherwise nvar is shown as a discrete legend.
  continuous_legend <- max(output_df$nvar)>11
  plot_data <- output_df
  if (!continuous_legend){
    plot_data$nvar <- as.factor(plot_data$nvar)
  }

  p <- ggplot2::ggplot(plot_data, ggplot2::aes(x = .data[["run"]], y = .data[["value"]]))+
    ggplot2::geom_errorbar(ggplot2::aes(ymin = ymin, ymax = ymax),
                           color = cols[output_df$nvar-(min(output_df$nvar)-1)],lwd=lwd)+
    ggplot2::geom_point(ggplot2::aes(colour = .data[["nvar"]]),size=size)+
    ggplot2::geom_point(data=output_df[bestmodels, ],
                        ggplot2::aes(x = .data[["run"]], y = .data[["value"]]),
                        pch=pch,colour=marker,size=size)+
    ggplot2::scale_x_continuous(name = "Model run", breaks = pretty(output_df$run))+
    ggplot2::scale_y_continuous(name = metric)

  if (continuous_legend){
    p <- p +
      ggplot2::scale_colour_gradientn(breaks=seq(2,max(output_df$nvar),
                                                 by=ceiling(max(output_df$nvar)/5)),
                                      colours = cols, name = "variables",guide = "colourbar")
  }else{
    p <- p +
      ggplot2::scale_colour_manual(values = cols, name = "variables")
  }

  p
}


#' @name plot
#' @description Density plot of nearest neighbor distances in geographic space or feature space between training data as well as between training data and
#' prediction locations.
#' Optionally, the nearest neighbor distances between training data and test data or between training data and CV iterations are shown.
#' The plot can be used to check the suitability of a chosen CV method to be representative to estimate map accuracy.
#' @param x geodist, see \code{\link{geodist}}
#' @param unit character. Only if type=="geo" and only applied to the plot. Supported: "m" or "km".
#' @param stat "density" for density plot or "ecdf" for empirical cumulative distribution function plot.
#' @export
#' @return a ggplot
#'



plot.geodist <- function(x, unit = "m", stat = "density", ...){
  # Plots the nearest neighbour distances of a geodist object, grouped by the
  # `what` column, as density or ECDF curves. Returns the ggplot object.

  # Reject unsupported `stat` values up front; otherwise neither plotting
  # branch runs and the function fails with "object 'p' not found".
  if(length(stat) != 1 || !stat %in% c("density", "ecdf")){
    stop("stat must be either 'density' or 'ecdf'.")
  }

  # Define colours - they must match those of knndm and nndm
  labs <- c("sample-to-sample",
            "prediction-to-sample",
            "CV-distances",
            "test-to-sample")
  myColors <- RColorBrewer::brewer.pal(4, "Dark2")
  names(myColors) <- labs

  type <- attr(x, "type")

  # Axis label per distance type. The km conversion is restricted to
  # geographic distances; it used to divide feature space and temporal
  # distances by 1000 as well whenever unit="km" was passed.
  if(type=="feature"){
    xlabs <- "feature space distances"
  }else if(type=="time"){
    xlabs <- paste0("temporal distances (",attr(x,"unit"),")")
  }else if(unit=="km"){
    x$dist <- x$dist/1000
    xlabs <- "geographic distances (km)"
  }else{
    xlabs <- "geographic distances (m)"
  }

  # ggplot2::unit() is called explicitly because the character argument
  # `unit` shadows the unqualified name inside this function.
  plot_margin <- ggplot2::unit(c(0,0.5,0,0),"cm")

  if(stat=="density"){
    p <- ggplot2::ggplot(data=x, ggplot2::aes(x=.data[["dist"]], group=.data[["what"]],
                                              fill=.data[["what"]])) +
      ggplot2::geom_density(adjust=1.5, alpha=.5, stat="density", lwd = 0.3) +
      ggplot2::scale_fill_manual(name = "distance function", values = myColors) +
      ggplot2::theme_bw() +
      ggplot2::xlab(xlabs) +
      ggplot2::theme(legend.position="bottom",
                     plot.margin = plot_margin)
  }else{
    p <- ggplot2::ggplot(data=x, ggplot2::aes(x=.data[["dist"]], group=.data[["what"]],
                                              colour=.data[["what"]])) +
      ggplot2::geom_vline(xintercept=0, lwd = 0.1) +
      ggplot2::geom_hline(yintercept=0, lwd = 0.1) +
      ggplot2::geom_hline(yintercept=1, lwd = 0.1) +
      ggplot2::stat_ecdf(geom = "step", lwd = 1) +
      ggplot2::scale_color_manual(name = "distance function", values = myColors) +
      ggplot2::theme_bw() +
      ggplot2::xlab(xlabs) +
      ggplot2::ylab("ECDF") +
      ggplot2::theme(legend.position="bottom",
                     plot.margin = plot_margin)
  }
  p
}


#' @name plot
#' @description Plot the DI/LPD and errormetric from Cross-Validation with the modeled relationship
#' @param x errorModel, see \code{\link{DItoErrormetric}}
#' @param ... other params
#' @export
#' @return a ggplot
#'


plot.errorModel <- function(x, ...){
  # Plots the cross-validation performance stored in the "performance"
  # attribute of `x` as points and the values predicted by `x` for the same
  # data as a line. Returns the ggplot object.

  variable <- attr(x, "variable")
  metric <- attr(x, "metric")

  performance <- attr(x, "performance")[,c(variable, "metric")]
  performance$what <- "cross-validation"

  # Model predictions at the observed values of `variable`
  model_line <- data.frame(variable = performance[, variable],
                           metric = predict(x, performance),
                           what = "model")

  # aes() with the .data pronoun replaces the deprecated aes_string();
  # `variable` holds a column name, hence .data[[variable]]
  ggplot()+
    geom_point(data = performance,
               mapping = aes(x = .data[[variable]], y = .data[["metric"]],
                             shape = .data[["what"]]))+
    geom_line(data = model_line,
              mapping = aes(x = .data[["variable"]], y = .data[["metric"]],
                            linetype = .data[["what"]]), lwd = 1)+
    labs(x = variable, y = metric)+
    theme(legend.title = element_blank(), legend.position = "bottom")

}


#' Print CAST classes
#' @description Generic print function for trainDI and aoa
#' @name print
#' @param x trainDI object
#' @param ... other params
#' @export

print.trainDI <- function(x, ...){
  # Prints the number of rows of x$train, the predictor names in x$variables
  # and the threshold x$threshold.
  cat(paste0("DI of ", nrow(x$train), " observation \n"))
  cat(paste0("Predictors:"), x$variables, "\n\n")

  cat("AOA Threshold: ")
  cat(x$threshold)

  # print methods return their argument invisibly
  invisible(x)
}


#' @name print
#' @param x trainDI object
#' @param ... other params
#' @export

show.trainDI <- function(x, ...) {
  # Delegates to the print method; arguments in `...` are not forwarded
  print.trainDI(x)
}





#' @name print
#' @param x aoa object
#' @param ... other params
#' @export


print.aoa <- function(x, ...){
  # Prints the DI, the LPD (only if the object has an element of that name),
  # the AOA, the predictor weights and the threshold of an aoa object.
  cat("DI:\n")
  print(x$DI)

  if ("LPD" %in% names(x)) {
    cat("LPD:\n")
    print(x$LPD)
  }

  cat("AOA:\n")
  print(x$AOA)

  cat("\n\nPredictor Weights:\n")
  print(x$parameters$weight)

  cat("\n\nAOA Threshold: ")
  cat(x$parameters$threshold)

  # print methods return their argument invisibly
  invisible(x)
}


#' @name print
#' @param x aoa object
#' @param ... other params
#' @export


show.aoa <- function(x, ...) {
  # Delegates to the print method; arguments in `...` are not forwarded
  print.aoa(x)
}



#' @name print
#' @param x An object of type \emph{nndm}.
#' @param ... other arguments.
#' @export
print.nndm <- function(x, ...){
  # Prints the number of points (length of x$Gj) and the mean and minimum
  # number of training indices per element of x$indx_train.
  n_train <- lengths(x$indx_train)
  mean_train <- round(mean(n_train), 2)
  min_train <- round(min(n_train), 2)
  cat(paste0("nndm object\n",
             "Total number of points: ", length(x$Gj), "\n",
             "Mean number of training points: ", mean_train, "\n",
             "Minimum number of training points: ", min_train, "\n"))

  # print methods return their argument invisibly
  invisible(x)
}

#' @name print
#' @param x An object of type \emph{nndm}.
#' @param ... other arguments.
#' @export
show.nndm <- function(x, ...) {
  # Delegates to the print method; arguments in `...` are not forwarded
  print.nndm(x)
}


#' @name print
#' @param x An object of type \emph{knndm}.
#' @param ... other arguments.
#' @export
print.knndm <- function(x, ...){
  # Prints space, clustering method, q, the W statistic (rounded to 4 digits),
  # the number of distinct values in x$clusters and their frequencies.
  header <- paste0("knndm object\n",
                   "Space: ", x$space, "\n",
                   "Clustering algorithm: ", x$method, "\n",
                   "Intermediate clusters (q): ", x$q, "\n",
                   "W statistic: ", round(x$W, 4), "\n",
                   "Number of folds: ", length(unique(x$clusters)),  "\n",
                   "Observations in each fold: ")
  cat(header, table(x$clusters), "\n")

  # print methods return their argument invisibly
  invisible(x)
}

#' @name print
#' @param x An object of type \emph{knndm}.
#' @param ... other arguments.
#' @export
show.knndm <- function(x, ...) {
  # Delegates to the print method; arguments in `...` are not forwarded
  print.knndm(x)
}



#' @name print
#' @param x An object of type \emph{ffs}
#' @param ... other arguments.
#' @export


print.ffs <- function(x, ...) {
  # Header with the selected variables, a separator line, then the output of
  # print.train() for the same object
  cat("Selected Variables: \n")
  cat(x$selectedvars)
  cat("\n---\n")
  print.train(x)
}


#' @name print
#' @param x An object of type \emph{ffs}
#' @param ... other arguments.
#' @export

show.ffs <- function(x, ...) {
  # Delegates to the print method; arguments in `...` are not forwarded
  print.ffs(x)
}






#' sPlotOpen Data of Species Richness
#'
#' sPlotOpen Species Richness for South America with associated predictors
#' @format
#' A sf points / data.frame with 703 rows and 17 columns:
#' \describe{
#'   \item{PlotObservationID, GIVD_ID, Country, Biome}{sPlotOpen Metadata}
#'   \item{Species_richness}{Response Variable - Plant species richness from sPlotOpen}
#'   \item{bio_x, elev}{Predictor Variables - Worldclim and SRTM elevation}
#'   \item{geometry}{Lat/Lon}
#' }
#' @source \itemize{
#' \item{Plot with Species_richness from \href{https://onlinelibrary.wiley.com/doi/full/10.1111/geb.13346}{sPlotOpen}}
#' \item{predictors acquired via R package \href{https://github.com/rspatial/geodata}{geodata}}
#' }
#'
#' @references \itemize{
#' \item{Sabatini, F. M. et al. sPlotOpen – An environmentally balanced, open‐access, global dataset of vegetation plots. (2021). \doi{10.1111/geb.13346}}
#' \item{Lopez-Gonzalez, G. et al. ForestPlots.net: a web application and research tool to manage and analyse tropical forest plot data: ForestPlots.net.
#'  Journal of Vegetation Science (2011).}
#' \item{Pauchard, A. et al. Alien Plants Homogenise Protected Areas: Evidence from the Landscape and Regional Scales in South Central Chile. in Plant Invasions in Protected Areas (2013).}
#' \item{Peyre, G. et al. VegPáramo, a flora and vegetation database for the Andean páramo. phytocoenologia (2015).}
#' \item{Vibrans, A. C. et al. Insights from a large-scale inventory in the southern Brazilian Atlantic Forest. Scientia Agricola (2020).}
#' }
#' @usage data(splotdata)
#'
"splotdata"


#' Calculate Dissimilarity Index of training data
#' @description
#' This function estimates the Dissimilarity Index (DI)
#' within the training data set used for a prediction model.
#' Optionally, the local point density can also be calculated.
#' Predictors can be weighted based on the internal
#' variable importance of the machine learning algorithm used for model training.
#' @note
#' This function is called within \code{\link{aoa}} to estimate the DI and AOA of new data.
#' However, it may also be used on its own if only the DI of training data is of interest,
#' or to facilitate a parallelization of \code{\link{aoa}} by avoiding a repeated calculation of the DI within the training data.
#'
#' @param model A train object created with caret used to extract weights from (based on variable importance) as well as cross-validation folds
#' @param train A data.frame containing the data used for model training. Only required when no model is given
#' @param weight A data.frame containing weights for each variable. Only required if no model is given.
#' @param variables character vector of predictor variables. if "all" then all variables
#' of the model are used or if no model is given then of the train dataset.
#' @param CVtest list or vector. Either a list where each element contains the data points used for testing during the cross validation iteration (i.e. held back data).
#' Or a vector that contains the ID of the fold for each training point.
#' Only required if no model is given.
#' @param CVtrain list. Each element contains the data points used for training during the cross validation iteration.
#' Only required if no model is given and only required if CVtrain is not the opposite of CVtest (i.e. if a data point is not used for testing, it is used for training).
#' Relevant if some data points are excluded, e.g. when using \code{\link{nndm}}.
#' @param method Character. Method used for distance calculation. Currently euclidean distance (L2) and Mahalanobis distance (MD) are implemented but only L2 is tested. Note that MD takes considerably longer.
#' @param useWeight Logical. Only if a model is given. Weight variables according to importance in the model?
#' @param useCV Logical. Only if a model is given. Use the CV folds to calculate the DI threshold?
#' @param LPD Logical. Indicates whether the local point density should be calculated or not.
#' @param verbose Logical. Print progress or not?
#' @param algorithm see \code{\link[FNN]{knnx.dist}} and \code{\link[FNN]{knnx.index}}
#' @seealso \code{\link{aoa}}
#' @importFrom graphics boxplot
#' @import ggplot2
#'
#' @return A list of class \code{trainDI} containing:
#'  \item{train}{A data frame containing the training data}
#'  \item{weight}{A data frame with weights based on the variable importance.}
#'  \item{variables}{Names of the used variables}
#'  \item{catvars}{Which variables are categorial}
#'  \item{scaleparam}{Scaling parameters. Output from \code{scale}}
#'  \item{trainDist_avrg}{A data frame with the average distance of each training point to every other point}
#'  \item{trainDist_avrgmean}{The mean of trainDist_avrg. Used for normalizing the DI}
#'  \item{trainDI}{Dissimilarity Index of the training data}
#'  \item{threshold}{The DI threshold used for inside/outside AOA}
#'  \item{trainLPD}{LPD of the training data}
#'  \item{avrgLPD}{Average LPD of the training data}
#'  \item{method}{The distance method that was used ("L2" or "MD")}
#'
#'
#'
#' @export trainDI
#'
#' @author
#' Hanna Meyer, Marvin Ludwig, Fabian Schumacher
#'
#' @references Meyer, H., Pebesma, E. (2021): Predicting into unknown space?
#' Estimating the area of applicability of spatial prediction models.
#' \doi{10.1111/2041-210X.13650}
#'
#'
#' @examples
#' \dontrun{
#' library(sf)
#' library(terra)
#' library(caret)
#' library(CAST)
#'
#' # prepare sample data:
#' data("splotdata")
#' splotdata = st_drop_geometry(splotdata)
#'
#' # train a model:
#' set.seed(100)
#' model <- caret::train(splotdata[,6:16],
#'                       splotdata$Species_richness,
#'                       importance=TRUE, tuneLength=1, ntree = 15, method = "rf",
#'                       trControl = trainControl(method="cv", number=5, savePredictions=T))
#' # variable importance is used for scaling predictors
#' plot(varImp(model,scale=FALSE))
#'
#' # calculate the DI of the trained model:
#' DI = trainDI(model=model)
#' plot(DI)
#'
#' #...or calculate the DI and LPD of the trained model:
#' # DI = trainDI(model=model, LPD = TRUE)
#'
#' # the DI can now be used to compute the AOA (here with LPD):
#' studyArea = rast(system.file("extdata/predictors_chile.tif", package = "CAST"))
#' AOA = aoa(studyArea, model = model, trainDI = DI, LPD = TRUE, maxLPD = 1)
#' print(AOA)
#' plot(AOA)
#' plot(AOA$AOA)
#' plot(AOA$LPD)
#' }
#'


trainDI <- function(model = NA,
                    train = NULL,
                    variables = "all",
                    weight = NA,
                    CVtest = NULL,
                    CVtrain = NULL,
                    method="L2",
                    useWeight = TRUE,
                    useCV =TRUE,
                    LPD = FALSE,
                    verbose = TRUE,
                    algorithm = "brute"){

  # get parameters if they are not provided in function call-----
  if(is.null(train)){train <- aoa_get_train(model)}
  if(length(variables) == 1){
    if(variables == "all"){
      variables <- aoa_get_variables(variables, model, train)
    }
  }
  if(is.na(weight)[1]){
    if(useWeight){
      weight <- aoa_get_weights(model, variables = variables)
    }else{
      message("variable are not weighted. see ?aoa")
      # unit weights as a one-row data.frame named after the variables (the
      # same layout aoa_get_weights() produces), so that the weights can be
      # looked up by column name in aoa_categorial_train()
      weight <- as.data.frame(t(rep(1, length(variables))))
      names(weight) <- variables
    }
  }else{
    weight <- user_weights(weight, variables)
  }

  # get CV folds from model or from parameters
  folds <- aoa_get_folds(model,CVtrain,CVtest,useCV)
  CVtest <- folds[[2]]
  CVtrain <- folds[[1]]

  # check for input errors -----
  if(nrow(train)<=1){stop("at least two training points need to be specified")}

  # reduce train to specified variables
  train <- train[,na.omit(match(variables, names(train)))]

  # unscaled, unweighted copy that is returned to the caller
  train_backup <- train

  # convert categorial variables
  catupdate <- aoa_categorial_train(train, variables, weight)

  train <- catupdate$train
  weight <- catupdate$weight

  # scale train
  train <- scale(train)

  # make sure all variables have variance
  # (scale() yields an all-NA column for a constant variable)
  if (any(apply(train, 2, FUN=function(x){all(is.na(x))}))){
    stop("some variables in train seem to have no variance")
  }

  # save scale param for later
  scaleparam <- attributes(train)


  # multiply train data with variable weights (from variable importance)
  if(!inherits(weight, "error") && !is.null(unlist(weight))){
    train <- sapply(seq_len(ncol(train)),function(x){train[,x]*unlist(weight[x])})
  }

  n_train <- nrow(train)

  # inverse covariance matrix; only needed (and only computed) for the
  # Mahalanobis distance
  S_inv <- NULL
  if(method=="MD"){
    if(dim(train)[2] == 1){
      S <- matrix(stats::var(train), 1, 1)
    } else {
      S <- stats::cov(train)
    }
    S_inv <- MASS::ginv(S)
  }

  # Distances of training point `idx` to all training points as a plain
  # vector. .alldistfun() returns an n x 1 matrix for "L2" but a 1 x n matrix
  # for "MD"; flattening makes both methods index the same way.
  dist_to_train <- function(idx){
    as.vector(.alldistfun(t(matrix(train[idx,])), train, method,
                          sorted = FALSE, S_inv, algorithm=algorithm))
  }

  # Index of the CV fold in which training point `idx` is held back.
  # Returns NA if no CV folds are used and a zero-length vector if folds are
  # used but the point is never held back.
  heldback_fold <- function(idx){
    if(is.null(CVtrain) || is.null(CVtest)){
      return(NA)
    }
    fold <- as.numeric(which(lapply(CVtest,function(x){any(x==idx)})==TRUE))
    if(length(fold)>1){stop("a datapoint is used for testing in more than one fold. currently this option is not implemented")}
    fold
  }


  # calculate average mean distance between training data

  trainDist_avrg <- numeric(n_train)
  trainDist_min <- rep(NA_real_, n_train)

  if (verbose) {
    message("Computing DI of training data...")
    pb <- txtProgressBar(min = 0,
                         max = n_train,
                         style = 3)
  }

  for(i in seq(n_train)){

    # calculate distance to other training data (the point itself is ignored):
    trainDist      <- dist_to_train(i)
    trainDist[i]   <- NA
    trainDist_avrg[i] <- mean(trainDist, na.rm = TRUE)

    # mask of any data that are not used for training for the respective data point (using CV)
    whichfold <- heldback_fold(i)

    # a point that is never used for testing keeps NA as its minimum distance
    if (length(whichfold)!=0){
      if (!is.na(whichfold)){
        # everything that is not in the training data for i is ignored
        trainDist[!seq(n_train)%in%CVtrain[[whichfold]]] <- NA
      }
      trainDist_min[i] <- min(trainDist, na.rm = TRUE)
    }
    if (verbose) {
      setTxtProgressBar(pb, i)
    }
  }

  if (verbose) {
    close(pb)
  }
  trainDist_avrgmean <- mean(trainDist_avrg,na.rm=TRUE)



  # Dissimilarity Index of training data -----
  TrainDI <- trainDist_min/trainDist_avrgmean


  # AOA Threshold ----
  threshold_quantile <- stats::quantile(TrainDI, 0.75,na.rm=TRUE)
  threshold_iqr <- (1.5 * stats::IQR(TrainDI,na.rm=TRUE))
  thres <- threshold_quantile + threshold_iqr
  # account for case that threshold_quantile + threshold_iqr is larger than maximum DI.
  if (thres>max(TrainDI,na.rm=TRUE)){
    thres <- max(TrainDI,na.rm=TRUE)
  }

  # note: previous versions of CAST derived the threshold this way:
  # thres <- grDevices::boxplot.stats(TrainDI)$stats[5]


  # calculate trainLPD and avrgLPD according to the CV folds
  if (LPD == TRUE) {
    if (verbose) {
      message("Computing LPD of training data...")
      pb <- txtProgressBar(min = 0,
                           max = n_train,
                           style = 3)
    }

    trainLPD <- rep(NA_integer_, n_train)
    for (j in seq(n_train)) {

      # DI of point j with respect to every other training point:
      DItrainDist <- dist_to_train(j)/trainDist_avrgmean
      DItrainDist[j]   <- NA

      # mask of any data that are not used for training for the respective data point (using CV)
      whichfold <- heldback_fold(j)

      # a point that is never used for testing keeps NA as its LPD
      if (length(whichfold)!=0){
        if (!is.na(whichfold)){
          # everything that is not in the training data for j is ignored
          DItrainDist[!seq(n_train)%in%CVtrain[[whichfold]]] <- NA
        }
        # LPD: number of training points closer than the DI threshold
        trainLPD[j] <- sum(DItrainDist < thres, na.rm = TRUE)
      }
      if (verbose) {
        setTxtProgressBar(pb, j)
      }
    }

    if (verbose) {
      close(pb)
    }

    # Average LPD in trainData; points without an LPD (never held back) are skipped
    avrgLPD <- round(mean(trainLPD, na.rm = TRUE))
  }


  # Return: trainDI Object -------

  aoa_results <- list(
    train = train_backup,
    weight = weight,
    variables = variables,
    catvars = catupdate$catvars,
    scaleparam = scaleparam,
    trainDist_avrg = trainDist_avrg,
    trainDist_avrgmean = trainDist_avrgmean,
    trainDI = TrainDI,
    threshold = thres,
    method = method
  )

  if (LPD == TRUE) {
    aoa_results$trainLPD <- trainLPD
    aoa_results$avrgLPD <- avrgLPD
  }

  class(aoa_results) <- "trainDI"

  return(aoa_results)
}


################################################################################
# Helper functions
################################################################################
# Encode categorial variables

aoa_categorial_train <- function(train, variables, weight){
  # Replaces factor/character predictors in `train` by dummy variables and
  # gives every dummy column the weight of the variable it was derived from.
  # Returns a list with the updated `train` and `weight` and with `catvars`,
  # the names of the categorical variables (or the caught error object if
  # `train` cannot be indexed by `variables`).

  # get all categorial variables
  catvars <- tryCatch({
    col_classes <- sapply(train[,variables], class)
    names(train)[which(col_classes%in%c("factor","character"))]
  }, error=function(e) e)

  has_catvars <- !inherits(catvars,"error")&length(catvars)>0
  if (!has_catvars){
    # nothing to encode: hand the inputs back untouched
    return(list(train = train, weight = weight, catvars = catvars))
  }

  message("warning: predictors contain categorical variables. The integration is currently still under development. Please check results carefully!")

  weight_usable <- !inherits(weight, "error")

  for (catvar in catvars){
    # remove unused factor levels before encoding
    train[,catvar] <- droplevels(train[,catvar])

    # one dummy column per remaining level, appended to train:
    dummies <- predict(caret::dummyVars(paste0("~",catvar), data = train),
                       train)
    train <- data.frame(train,dummies)

    if (weight_usable){
      # every dummy column inherits the weight of its source variable
      source_weight <- weight[,which(names(weight)==catvar)]
      dummy_weights <- data.frame(t(rep(source_weight, ncol(dummies))))
      names(dummy_weights) <- colnames(dummies)
      weight <- data.frame(weight,dummy_weights)
    }
  }

  # drop the original categorical columns, only the dummies remain
  if (weight_usable){
    weight <- weight[,-which(names(weight)%in%catvars)]
  }
  train <- train[,-which(names(train)%in%catvars)]

  return(list(train = train, weight = weight, catvars = catvars))
}



# Get weights from train object


aoa_get_weights <- function(model, variables){
  # Derives one weight per variable from the caret variable importance of
  # `model`. Falls back to unit weights when no importance can be retrieved
  # or when only a single variable is used. Negative weights are set to 0.

  weight <- tryCatch({
    if (model$modelType=="Classification"){
      # classification: average the class-wise importances per variable
      importance <- caret::varImp(model,scale=FALSE)$importance
      as.data.frame(t(apply(importance,1,mean)))
    } else {
      importance <- caret::varImp(model,scale=FALSE)$importance
      as.data.frame(t(importance[,"Overall"]))
    }
  }, error=function(e) e)

  importance_usable <- !inherits(weight, "error") & length(variables)>1
  if (importance_usable){
    names(weight) <- rownames(caret::varImp(model,scale=FALSE)$importance)
  } else {
    # set all weights to 1
    weight <- as.data.frame(t(rep(1, length(variables))))
    names(weight) <- variables
    message("note: variables were not weighted either because no weights or model were given,
    no variable importance could be retrieved from the given model, or the model has a single feature.
    Check caret::varImp(model)")
  }

  #set negative weights to 0
  if (!inherits(weight, "error")){
    keep <- na.omit(match(variables, names(weight)))
    weight <- weight[,keep]
    if (any(weight<0)){
      weight[weight<0] <- 0
      message("negative weights were set to 0")
    }
  }

  if (sum(weight)==0){
    stop("all weights are <=0, hence no variable is used. Check variable importance of the model, define weights manually or set useWeight=FALSE")
  }

  return(weight)
}



# check user weight input
# make sure this function outputs a data.frame with
# one row and columns named after the variables

user_weights <- function(weight, variables){
  # Validates user-supplied variable weights.
  #   weight:    one-row data.frame (or named list) with one entry per variable
  #   variables: character vector of the variables that are used
  # Returns the weights of `variables` in the order of `variables`. Input that
  # is not a one-row data.frame covering all variables is ignored with a
  # message and replaced by unit weights. Negative weights are set to 0.

  # list input support
  if(inherits(weight, "list")){
    weight <- as.data.frame(weight)
  }

  #check if manually given weights are correct. otherwise ignore (set to 1):
  # is.data.frame() guards nrow(), which is NULL for plain vectors and would
  # otherwise make the condition fail instead of reaching the fallback
  weights_valid <- is.data.frame(weight) &&
    nrow(weight) == 1 &&
    all(variables %in% names(weight))
  if(!weights_valid){
    message("variable weights are not correctly specified and will be ignored. See ?aoa")
    # one-row data.frame named after the variables
    weight <- as.data.frame(t(rep(1, length(variables))))
    names(weight) <- variables
  }

  weight <- weight[,na.omit(match(variables, names(weight)))]
  if (any(weight<0)){
    weight[weight<0] <- 0
    message("negative weights were set to 0")
  }

  return(weight)
}




# Get trainingdata from train object

aoa_get_train <- function(model){
  # The training data stored in the model object, coerced to a data.frame.
  as.data.frame(model$trainingData)
}


# Get folds from train object


aoa_get_folds <- function(model, CVtrain, CVtest, useCV){
  # Determines the cross-validation folds, either from the control object of
  # `model` or from the manually supplied CVtrain / CVtest.
  # Returns list(CVtrain, CVtest); both are NULL if no folds are used.

  # FALSE when the NA placeholder was passed instead of a model
  model_given <- !is.na(model)[1]

  # For each fold: positions within the sorted set of all indices occurring
  # in `folds` that are not part of that fold.
  opposite_folds <- function(folds){
    all_idx <- sort(unique(unlist(folds)))
    lapply(folds, function(x){which(!all_idx%in%x)})
  }

  ### if folds are to be extracted from the model:
  if (useCV&model_given){
    if (tolower(model$control$method)=="cv"){
      CVtest <- model$control$indexOut
      CVtrain <- model$control$index
    } else {
      message("note: Either no model was given or no CV was used for model training. The DI threshold is therefore based on all training data")
    }
  }

  ### if folds are specified manually:
  if (!model_given){

    # restructure input if CVtest only contains the fold ID of each point
    if (!is.null(CVtest)&!is.list(CVtest)){
      fold_members <- list()
      for (fold_id in unique(CVtest)){
        fold_members[[fold_id]] <- which(CVtest==fold_id)
      }
      CVtest <- fold_members
    }

    if (is.null(CVtest)&is.null(CVtrain)){
      message("note: No model and no CV folds were given. The DI threshold is therefore based on all training data")
    } else if (is.null(CVtest)){
      # if CVtest is not given, then use the opposite of CVtrain
      CVtest <- opposite_folds(CVtrain)
    } else if (is.null(CVtrain)){
      # if CVtrain is not given, then use the opposite of CVtest
      CVtrain <- opposite_folds(CVtest)
    }
  }

  ### a model is given but its folds shall not be used:
  if (model_given&useCV==FALSE){
    message("note: useCV is set to FALSE. The DI threshold is therefore based on all training data")
    CVtrain <- NULL
    CVtest <- NULL
  }

  list(CVtrain,CVtest)
}






# Get variables from train object

aoa_get_variables <- function(variables, model, train){
  # Resolves the placeholder "all" to actual variable names: all columns of
  # the model's training data except ".outcome" if a model is given,
  # otherwise all columns of `train`. Any other input is returned unchanged.

  if(length(variables) == 1 && variables == "all"){
    if(!is.na(model)[1]){
      predictors <- names(model$trainingData)
      # logical subsetting instead of x[-which(...)]: the latter returns an
      # empty vector when there is no ".outcome" column
      variables <- predictors[predictors != ".outcome"]
    }else{
      variables <- names(train)
    }
  }
  return(variables)
}



.mindistfun <- function(point, reference, method, S_inv=NULL,algorithm){
  # For every row of `point`: distance to its nearest row in `reference`.
  #   method:    "L2" (euclidean, via FNN) or "MD" (Mahalanobis)
  #   S_inv:     inverse covariance matrix, only used for "MD"
  #   algorithm: passed on to FNN::knnx.dist, only used for "L2"
  # Returns a numeric vector with one entry per row of `point`.

  if (method == "L2"){ # Euclidean Distance
    return(c(FNN::knnx.dist(reference, point, k = 1, algorithm = algorithm)))
  } else if (method == "MD"){ # Mahalanobis Distance
    # Mahalanobis distance between row p of point and row r of reference
    pair_dist <- function(p, r){
      d <- point[p,] - reference[r,]
      sqrt(as.numeric(t(d) %*% S_inv %*% d))
    }
    # seq_len/vapply keep the result a numeric vector even for zero rows
    return(vapply(seq_len(nrow(point)),
                  function(p) min(vapply(seq_len(nrow(reference)),
                                         function(r) pair_dist(p, r),
                                         numeric(1))),
                  numeric(1)))
  }
}

.alldistfun <- function(point, reference, method, sorted = TRUE,S_inv=NULL,algorithm){
  # Distances between the rows of `point` and the rows of `reference`.
  #   method:    "L2" (euclidean, via FNN) or "MD" (Mahalanobis)
  #   sorted:    sort the distances of each point in increasing order?
  #   S_inv:     inverse covariance matrix, only used for "MD"
  #   algorithm: passed on to FNN::knnx.dist, only used for "L2"

  if (method == "L2"){ # Euclidean Distance
    if (!sorted){
      # note the swapped roles: nearest row of `point` for every reference row
      return(FNN::knnx.dist(point,reference,k=1, algorithm=algorithm))
    }
    return(FNN::knnx.dist(reference, point, k = dim(reference)[1], algorithm = algorithm))
  } else if (method == "MD"){ # Mahalanobis Distance
    # distances of row y of point to every row of reference
    dists_of_point <- function(y){
      sapply(1:dim(reference)[1], function(x){
        delta <- point[y,] - reference[x,]
        sqrt( t(delta) %*% S_inv %*% delta )
      })
    }
    point_ids <- 1:dim(point)[1]
    if (!sorted){
      return(t(sapply(point_ids, dists_of_point)))
    }
    return(t(sapply(point_ids, function(y) sort(dists_of_point(y)))))
  }
}



# Test entry script: attaches testthat and the CAST package and then
# runs the testthat test suite of CAST via test_check().
library(testthat)
library(CAST)

test_check("CAST")


loaddata <- function() {
  # Builds the shared fixture for the AOA tests: a predictor raster, the
  # training data extracted at 30 cookfarm locations and a random forest
  # trained on three of the predictors.
  # Returns a list with elements studyArea, trainDat, variables and model.

  # prepare sample data: mean of VW and coordinates per SOURCEID
  data(cookfarm)
  dat <- aggregate(cookfarm[,c("VW","Easting","Northing")],by=list(as.character(cookfarm$SOURCEID)),mean)
  pts <- sf::st_as_sf(dat,coords=c("Easting","Northing"))
  pts$ID <- seq_len(nrow(pts))
  set.seed(100)
  pts <- pts[1:30,]
  studyArea <- terra::rast(system.file("extdata","predictors_2012-03-25.tif",package="CAST"))[[1:8]]
  trainDat <- terra::extract(studyArea,pts,na.rm=FALSE)
  trainDat <- merge(trainDat,pts,by.x="ID",by.y="ID")

  # train a model (seed fixed so that the expected values in the tests hold):
  set.seed(100)
  variables <- c("DEM","NDRE.Sd","TWI")
  ctrl <- caret::trainControl(method="cv",number=5,savePredictions=TRUE)
  model <- caret::train(trainDat[,which(names(trainDat)%in%variables)],
                 trainDat$VW, method="rf", importance=TRUE, tuneLength=1,
                 trControl=ctrl)


  # named `fixture` rather than `data` to avoid shadowing utils::data()
  fixture <- list(
    studyArea = studyArea,
    trainDat = trainDat,
    variables = variables,
    model = model
  )

  return(fixture)
}


test_that("AOA works in default: used with raster data and a trained model", {
  skip_if_not_installed("randomForest")
  dat <- loaddata()
  # calculate the AOA of the trained model for the study area:
  AOA <- aoa(dat$studyArea, dat$model, verbose = FALSE)

  #test threshold:
  expect_equal(as.numeric(round(AOA$parameters$threshold,5)), 0.38986)
  #test number of pixels within AOA:
  expect_equal(sum(terra::values(AOA$AOA)==1,na.rm=TRUE), 2936)
  # test trainDI (one value per training point)
  expect_equal(AOA$parameters$trainDI, c(0.09043580, 0.14046341, 0.16584582, 0.57617177, 0.26840303,
                                         0.14353894, 0.19768329, 0.24022059, 0.06832037, 0.29150668,
                                         0.18471625, 0.57617177, 0.12344463, 0.09043580, 0.14353894,
                                         0.26896008, 0.22713731, 0.24022059, 0.20388725, 0.06832037,
                                         0.23604264, 0.20388725, 0.91513568, 0.09558666, 0.14046341,
                                         0.16214832, 0.37107762, 0.16214832, 0.18471625, 0.12344463))
  # test summary statistics of the DI
  expect_equal(as.vector(summary(terra::values(AOA$DI))),
               c("Min.   :0.0000  ", "1st Qu.:0.1329  ", "Median :0.2052  ",
                 "Mean   :0.2858  ", "3rd Qu.:0.3815  ",
                 "Max.   :4.4485  ", "NA's   :1993  "))
})


test_that("AOA works without a trained model", {
  skip_if_not_installed("randomForest")
  dat <- loaddata()
  # no model: training data and variable names are passed directly
  AOA <- aoa(dat$studyArea,train=dat$trainDat,variables=dat$variables, verbose = FALSE)

  #test threshold:
  expect_equal(as.numeric(round(AOA$parameters$threshold,5)), 0.52872)
  #test number of pixels within AOA:
  expect_equal(sum(terra::values(AOA$AOA)==1,na.rm=TRUE), 3377)
  # test summary statistics of the DI
  expect_equal(as.vector(summary(terra::values(AOA$DI))),
               c("Min.   :0.0000  ", "1st Qu.:0.1759  ", "Median :0.2642  ",
                 "Mean   :0.3109  ", "3rd Qu.:0.4051  ",
                 "Max.   :2.6631  ", "NA's   :1993  "))
})

test_that("AOA (including LPD) works with raster data and a trained model", {
  skip_if_not_installed("randomForest")
  dat <- loaddata()
  # calculate the AOA (and the LPD) of the trained model for the study area:
  AOA <- aoa(dat$studyArea, dat$model, LPD = TRUE, maxLPD = 1, verbose = FALSE)

  #test threshold:
  expect_equal(as.numeric(round(AOA$parameters$threshold,5)), 0.38986)
  #test number of pixels within AOA:
  expect_equal(sum(terra::values(AOA$AOA)==1,na.rm=TRUE), 2936)
  #test trainLPD (one value per training point)
  expect_equal(AOA$parameters$trainLPD, c(3, 4, 6, 0, 7,
                                          6, 2, 1, 5, 3,
                                          4, 0, 1, 2, 6,
                                          5, 4, 4, 5, 7,
                                          3, 4, 0, 2, 3,
                                          6, 1, 7, 3, 2))
  # test summary statistics of the DI
  expect_equal(as.vector(summary(terra::values(AOA$DI))),
               c("Min.   :0.0000  ", "1st Qu.:0.1329  ", "Median :0.2052  ",
                 "Mean   :0.2858  ", "3rd Qu.:0.3815  ",
                 "Max.   :4.4485  ", "NA's   :1993  "))
})


test_that("AOA (inluding LPD) works without a trained model", {
  skip_if_not_installed("randomForest")
  dat <- loaddata()
  # no model: training data and variable names are passed directly
  AOA <- aoa(dat$studyArea,train=dat$trainDat,variables=dat$variables, LPD = TRUE, maxLPD = 1, verbose = FALSE)

  #test threshold:
  expect_equal(as.numeric(round(AOA$parameters$threshold,5)), 0.52872)
  #test number of pixels within AOA:
  expect_equal(sum(terra::values(AOA$AOA)==1,na.rm=TRUE), 3377)
  # test trainLPD (one value per training point)
  expect_equal(AOA$parameters$trainLPD, c(7, 9, 12, 1, 12,
                                          12, 4, 2, 8, 10,
                                          6, 1, 3,4, 11,
                                          9, 9, 7, 5, 5,
                                          6, 5, 0, 5, 9,
                                          8, 4, 11, 3,2))
  # test summary statistics of the DI
  expect_equal(as.vector(summary(terra::values(AOA$DI))),
               c("Min.   :0.0000  ", "1st Qu.:0.1759  ", "Median :0.2642  ",
                 "Mean   :0.3109  ", "3rd Qu.:0.4051  ",
                 "Max.   :2.6631  ", "NA's   :1993  "))
})


test_that("AOA (including LPD) works in parallel with raster data and a trained model", {
  skip_if_not_installed("randomForest")
  dat <- loaddata()
  # calculate the AOA of the trained model for the study area:
  AOA <- aoa(dat$studyArea, dat$model, LPD = TRUE, maxLPD = 1, verbose = FALSE, parallel = TRUE, cores = 2) # limit to 2 cores

  # expected values are the same as in the corresponding non-parallel test
  #test threshold:
  expect_equal(as.numeric(round(AOA$parameters$threshold,5)), 0.38986)
  #test number of pixels within AOA:
  expect_equal(sum(terra::values(AOA$AOA)==1,na.rm=TRUE), 2936)
  #test trainLPD
  expect_equal(AOA$parameters$trainLPD, c(3, 4, 6, 0, 7,
                                          6, 2, 1, 5, 3,
                                          4, 0, 1, 2, 6,
                                          5, 4, 4, 5, 7,
                                          3, 4, 0, 2, 3,
                                          6, 1, 7, 3, 2))
  # test summary statistics of the DI
  expect_equal(as.vector(summary(terra::values(AOA$DI))),
               c("Min.   :0.0000  ", "1st Qu.:0.1329  ", "Median :0.2052  ",
                 "Mean   :0.2858  ", "3rd Qu.:0.3815  ",
                 "Max.   :4.4485  ", "NA's   :1993  "))
})


test_that("AOA (inluding LPD) works in parallel without a trained model", {
  skip_if_not_installed("randomForest")
  dat <- loaddata()
  # no model: training data and variable names are passed directly
  AOA <- aoa(dat$studyArea,train=dat$trainDat,variables=dat$variables, LPD = TRUE, maxLPD = 1, verbose = FALSE, parallel = TRUE, cores = 2) # limit to 2 cores

  # expected values are the same as in the corresponding non-parallel test
  #test threshold:
  expect_equal(as.numeric(round(AOA$parameters$threshold,5)), 0.52872)
  #test number of pixels within AOA:
  expect_equal(sum(terra::values(AOA$AOA)==1,na.rm=TRUE), 3377)
  # test trainLPD
  expect_equal(AOA$parameters$trainLPD, c(7, 9, 12, 1, 12,
                                          12, 4, 2, 8, 10,
                                          6, 1, 3,4, 11,
                                          9, 9, 7, 5, 5,
                                          6, 5, 0, 5, 9,
                                          8, 4, 11, 3,2))
  # test summary statistics of the DI
  expect_equal(as.vector(summary(terra::values(AOA$DI))),
               c("Min.   :0.0000  ", "1st Qu.:0.1759  ", "Median :0.2642  ",
                 "Mean   :0.3109  ", "3rd Qu.:0.4051  ",
                 "Max.   :2.6631  ", "NA's   :1993  "))
})



test_that("errorProfiles works in default settings", {
  skip_on_cran()
  skip_on_os("mac", arch = "aarch64")
  skip_if_not_installed("randomForest")
  skip_if_not_installed("scam")
  data(splotdata)
  splotdata <- sf::st_drop_geometry(splotdata)
  predictors <- terra::rast(system.file("extdata","predictors_chile.tif", package="CAST"))

  # seed fixed so that the expected values below hold
  set.seed(100)
  model <- caret::train(splotdata[,6:16], splotdata$Species_richness, ntree = 10,
                        trControl = caret::trainControl(method = "cv", savePredictions = TRUE))

  AOA <- CAST::aoa(predictors, model,verbose=FALSE)

  # DI ~ error
  errormodel_DI <- CAST::errorProfiles(model, AOA, variable = "DI")

  # apply the error model to the DI raster
  expected_error_DI <- terra::predict(AOA$DI, errormodel_DI)



  #test model fit:
  expect_equal(round(as.numeric(summary(errormodel_DI$fitted.values)),2),
               c(14.25, 14.34, 15.21, 17.23, 18.70, 27.46))
  # test model predictions
  expect_equal(as.vector(summary(terra::values(expected_error_DI))),
               c("Min.   :14.26  ", "1st Qu.:27.46  ", "Median :27.46  ",
                 "Mean   :26.81  ", "3rd Qu.:27.46  ","Max.   :27.47  ",
                 "NA's   :17678  "))
})


test_that("errorProfiles works in with LPD", {
  skip_on_cran()
  skip_on_os("mac", arch = "aarch64")
  skip_if_not_installed("randomForest")
  skip_if_not_installed("scam")
  data(splotdata)
  splotdata <- sf::st_drop_geometry(splotdata)
  predictors <- terra::rast(system.file("extdata","predictors_chile.tif", package="CAST"))

  # seed fixed so that the expected values below hold
  set.seed(100)
  model <- caret::train(splotdata[,6:16], splotdata$Species_richness, ntree = 10,
                        trControl = caret::trainControl(method = "cv", savePredictions = TRUE))

  # LPD ~ error; the error model is then applied to the LPD raster
  AOA <- CAST::aoa(predictors, model, LPD = TRUE, maxLPD = 1,verbose=FALSE)
  errormodel_LPD <- CAST::errorProfiles(model, AOA, variable = "LPD")
  expected_error_LPD <- terra::predict(AOA$LPD, errormodel_LPD)


  #test model fit:
  expect_equal(round(as.numeric(summary(errormodel_LPD$fitted.values)),2),
               c(16.36, 16.36, 16.36, 16.36, 16.36, 16.36))
  # test model predictions
  expect_equal(as.vector(summary(terra::values(expected_error_LPD))),
               c("Min.   :16.36  ", "1st Qu.:16.36  ", "Median :16.36  ",
                 "Mean   :16.36  ", "3rd Qu.:16.36  ",
                 "Max.   :16.36  ", "NA's   :17678  "))

})



test_that("errorProfiles works for multiCV", {
  skip_on_cran()
  skip_on_os("mac", arch = "aarch64")
  skip_if_not_installed("randomForest")
  skip_if_not_installed("scam")
  data(splotdata)
  splotdata <- sf::st_drop_geometry(splotdata)
  predictors <- terra::rast(system.file("extdata","predictors_chile.tif", package="CAST"))

  # seed fixed so that the expected values below hold
  set.seed(100)
  model <- caret::train(splotdata[,6:16], splotdata$Species_richness, ntree = 10,
                        trControl = caret::trainControl(method = "cv", savePredictions = TRUE))

  AOA <- CAST::aoa(predictors, model,verbose=FALSE)
  # seed reset before the multiCV error profile is computed
  set.seed(100)
  errormodel_DI <- suppressWarnings(errorProfiles(model, AOA, multiCV = TRUE, length.out = 3))
  expected_error_DI <- terra::predict(AOA$DI, errormodel_DI)

  #test model fit:
  expect_equal(round(as.numeric(summary(errormodel_DI$fitted.values)),2),
               c(12.53, 17.21, 26.80, 26.19, 35.28, 35.30))
  # test model predictions
  expect_equal(as.vector( summary(terra::values(expected_error_DI))),
               c("Min.   :13.11  ", "1st Qu.:32.58  ", "Median :35.05  ",
                 "Mean   :32.54  ", "3rd Qu.:35.30  ",
                 "Max.   :35.30  ", "NA's   :17678  "))


})



test_that("ffs works with default arguments and the splotopen dataset (numerical only)",{
  skip_on_cran()
  skip_on_os("mac", arch = "aarch64")
  skip_if_not_installed("randomForest")

  # example plots without the sf geometry column
  data("splotdata")
  plots <- sf::st_drop_geometry(splotdata)

  # forward feature selection on columns 6 to 12
  set.seed(1)
  ffs_fit <- ffs(predictors = plots[,6:12],
                 response = plots$Species_richness,
                 seed = 1,
                 verbose = FALSE,
                 ntree = 5,
                 tuneLength = 1)

  expect_identical(ffs_fit$selectedvars, c("bio_6", "bio_12", "bio_5", "bio_4"))
  expect_identical(ffs_fit$metric, "RMSE")
  expect_identical(ffs_fit$maximize, FALSE)
})




test_that("ffs works with default arguments and the splotopen dataset (include categorial)",{
  skip_on_cran()
  skip_on_os("mac", arch = "aarch64")
  skip_if_not_installed("randomForest")

  # example plots without the sf geometry column
  data("splotdata")
  plots <- sf::st_drop_geometry(splotdata)

  # column 4 is added to the numeric columns 6 to 12
  set.seed(1)
  ffs_fit <- ffs(predictors = plots[,c(4,6:12)],
                 response = plots$Species_richness,
                 verbose = FALSE,
                 seed = 1,
                 ntree = 5,
                 tuneLength = 1)

  expect_identical(ffs_fit$selectedvars, c("bio_6", "bio_12", "Biome","bio_1" , "bio_5"))
  expect_identical(ffs_fit$metric, "RMSE")
  expect_identical(ffs_fit$maximize, FALSE)
})


test_that("ffs works for classification with default arguments",{
  skip_on_cran()
  skip_on_os("mac", arch = "aarch64")
  skip_if_not_installed("randomForest")

  # example plots without the sf geometry column; unused Biome levels removed
  data("splotdata")
  plots <- sf::st_drop_geometry(splotdata)
  plots$Biome <- droplevels(plots$Biome)

  # Biome is the response here, so the selection is a classification task
  set.seed(1)
  ffs_fit <- ffs(predictors = plots[,c(6:12)],
                 response = plots$Biome,
                 verbose = FALSE,
                 seed = 1,
                 ntree = 5,
                 tuneLength = 1)

  expect_identical(ffs_fit$selectedvars, c("bio_4", "bio_8",  "bio_12",
                                           "bio_9"))
  expect_identical(ffs_fit$metric, "Accuracy")
  expect_identical(ffs_fit$maximize, TRUE)
})


test_that("ffs works for withinSE = TRUE",{
  skip_on_cran()
  skip_on_os("mac", arch = "aarch64")
  skip_if_not_installed("randomForest")

  # example plots without the sf geometry column; unused Biome levels removed
  data("splotdata")
  plots <- sf::st_drop_geometry(splotdata)
  plots$Biome <- droplevels(plots$Biome)

  # selection over columns 6 to 16 with withinSE switched on
  set.seed(1)
  ffs_fit <- ffs(predictors = plots[,c(6:16)],
                 response = plots$Biome,
                 seed = 1,
                 verbose = FALSE,
                 ntree = 5,
                 withinSE = TRUE,
                 tuneLength = 1)

  expect_identical(ffs_fit$selectedvars, c("bio_4", "bio_8",  "bio_12",
                                           "bio_13","bio_14", "bio_5"))
})











  ## Iris regression tests: these are expected to fail if ffs is re-implemented with different behaviour


  # Baseline ffs run on iris; the selected variables and their performance
  # values are pinned.
  test_that("ffs works with default arguments and the iris dataset",{
    skip_if_not_installed("randomForest")
    data(iris)

    set.seed(1)
    iris_fit <- ffs(
      predictors = iris[, 1:4],
      response = iris$Species,
      seed = 1
    )

    expected_vars <- c("Petal.Length", "Petal.Width", "Sepal.Width")
    expected_perf <- c(0.9530141, 0.9544820, 0.9544820)

    expect_identical(iris_fit$selectedvars, expected_vars)
    expect_equal(iris_fit$selectedvars_perf, expected_perf, tolerance = 0.05)
  })



  # ffs on iris with globalval = TRUE; the performance vector is expected to
  # carry "Accuracy" names.
  test_that("ffs works with globalVal = TRUE", {
    skip_on_cran()
    skip_if_not_installed("randomForest")
    data(iris)

    set.seed(1)
    iris_fit <- ffs(
      predictors = iris[, 1:4],
      response = iris$Species,
      seed = 1,
      globalval = TRUE
    )

    expected_vars <- c("Petal.Length", "Petal.Width", "Sepal.Width")
    expected_perf <- c("Accuracy" = 0.9530792,
                       "Accuracy" = 0.9545455,
                       "Accuracy" = 0.9545455)

    expect_identical(iris_fit$selectedvars, expected_vars)
    expect_equal(iris_fit$selectedvars_perf, expected_perf, tolerance = 0.005)
  })

  # ffs on iris with withinSE = TRUE; only two variables are expected to be
  # selected.
  test_that("ffs works with withinSE = TRUE", {
    skip_on_cran()
    skip_if_not_installed("randomForest")
    data(iris)

    set.seed(1)
    iris_fit <- ffs(
      predictors = iris[, 1:4],
      response = iris$Species,
      seed = 1,
      withinSE = TRUE
    )

    expected_vars <- c("Petal.Length", "Petal.Width")
    expect_identical(iris_fit$selectedvars, expected_vars)
    expect_equal(iris_fit$selectedvars_perf, c(0.9530141), tolerance = 0.005)
  })


  # With four predictors and minVar = 4 the call is expected to stop with an
  # "undefined columns selected" error.
  test_that("ffs fails with minvar set to maximum", {
    skip_on_cran()
    skip_if_not_installed("randomForest")
    data(iris)

    set.seed(1)
    expect_error(
      ffs(predictors = iris[, 1:4],
          response = iris$Species,
          seed = 1,
          minVar = 4),
      regexp = ".*undefined columns selected"
    )
  })










# geodist in geographic space: splotdata points, South America polygons as
# model domain and a random three-fold assignment for the CV distances.
test_that("geodist works with points and polygon in geographic space", {
  skip_if_not_installed("rnaturalearth")
  data(splotdata)
  south_america <- rnaturalearth::ne_countries(continent = "South America", returnclass = "sf")

  set.seed(1)
  fold_ids <- data.frame("folds" = sample(1:3, nrow(splotdata), replace = TRUE))
  cv_indices <- CreateSpacetimeFolds(fold_ids, spacevar = "folds", k = 3)

  geo_dists <- geodist(x = splotdata,
                       modeldomain = south_america,
                       cvfolds = cv_indices$indexOut,
                       type = "geo")

  # Rounded mean distance of one distance category
  rounded_mean <- function(label) {
    round(mean(geo_dists[geo_dists$what == label, "dist"]))
  }

  expect_equal(rounded_mean("sample-to-sample"), 20321)
  expect_equal(rounded_mean("CV-distances"), 25616)
  expect_equal(nrow(geo_dists), 3410)
})

# geodist in feature space: splotdata points against the Chile predictor
# raster shipped with the package, restricted to three variables.
test_that("geodist works with points and polygon in feature space", {
  skip_if_not_installed("rnaturalearth")
  data(splotdata)
  studyArea <- rnaturalearth::ne_countries(continent = "South America", returnclass = "sf")

  set.seed(1)
  fold_ids <- data.frame("folds" = sample(1:3, nrow(splotdata), replace = TRUE))
  cv_indices <- CreateSpacetimeFolds(fold_ids, spacevar = "folds", k = 3)

  chile_file <- system.file("extdata", "predictors_chile.tif", package = "CAST")
  chile_predictors <- terra::rast(chile_file)

  feature_dists <- geodist(x = splotdata,
                           modeldomain = chile_predictors,
                           cvfolds = cv_indices$indexOut,
                           type = "feature",
                           variables = c("bio_1", "bio_12", "elev"))

  # Mean distance of one distance category, rounded to four digits
  rounded_mean <- function(label) {
    round(mean(feature_dists[feature_dists$what == label, "dist"]), 4)
  }

  expect_equal(rounded_mean("sample-to-sample"), 0.0843)
  expect_equal(rounded_mean("CV-distances"), 0.1036)
})


# geodist in geographic space with explicit prediction points (preddata).
test_that("geodist works space with points and preddata in geographic space", {
  aoi <- sf::st_as_sfc("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))", crs="epsg:25832")
  train_pts <- sf::st_cast(
    sf::st_as_sfc("MULTIPOINT ((1 1), (1 2), (2 2), (2 3), (1 4), (5 4))", crs="epsg:25832"),
    "POINT"
  )

  set.seed(1)
  pred_pts <- sf::st_set_crs(
    suppressWarnings(sf::st_sample(aoi, 20, type="regular")),
    "epsg:25832"
  )

  # NOTE(review): `folds` is not passed to geodist below. The sampling is kept
  # so that the random number stream seen by geodist stays unchanged.
  set.seed(1)
  folds <- data.frame("folds" = sample(1:3, length(train_pts), replace = TRUE))
  folds <- CreateSpacetimeFolds(folds, spacevar = "folds", k = 3)

  geo_dists <- geodist(x = train_pts,
                       modeldomain = aoi,
                       preddata = pred_pts,
                       type = "geo")

  # Mean distance of one distance category, rounded to four digits
  rounded_mean <- function(label) {
    round(mean(geo_dists[geo_dists$what == label, "dist"]), 4)
  }

  expect_equal(rounded_mean("sample-to-sample"), 1.4274)
  expect_equal(rounded_mean("prediction-to-sample"), 2.9402)
})


# geodist in feature space with explicit prediction points; the feature values
# come from a 10 x 10 single-layer raster holding 1 to 100.
test_that("geodist works with points and preddata in feature space", {
  aoi <- sf::st_as_sfc("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))", crs="epsg:25832")
  train_pts <- sf::st_cast(
    sf::st_as_sfc("MULTIPOINT ((1.5 1.5), (1.5 2.5), (2.5 2.5), (2.5 3.5), (1.5 4.5), (5.5 4.5))", crs="epsg:25832"),
    "POINT"
  )

  set.seed(1)
  pred_pts <- sf::st_set_crs(
    suppressWarnings(sf::st_sample(aoi, 20, type="regular")),
    "epsg:25832"
  )

  grid <- terra::rast(nrows = 10, ncols = 10, nlyrs = 1,
                      xmin = 0, xmax = 10, ymin = 0, ymax = 10,
                      crs = "epsg:25832", vals = 1:100)

  feature_dists <- geodist(x = train_pts,
                           modeldomain = grid,
                           preddata = pred_pts,
                           type = "feature")

  # Mean distance of one distance category, rounded to four digits
  rounded_mean <- function(label) {
    round(mean(feature_dists[feature_dists$what == label, "dist"]), 4)
  }

  expect_equal(rounded_mean("sample-to-sample"), 0.3814)
  expect_equal(rounded_mean("prediction-to-sample"), 1.0783)
})


# geodist in geographic space with a SpatRaster as model domain.
test_that("geodist works with points and raster in geographic space", {
  train_pts <- sf::st_cast(
    sf::st_as_sfc("MULTIPOINT ((1 1), (1 2), (2 2), (2 3), (1 4), (5 4))", crs="epsg:25832"),
    "POINT"
  )

  grid <- terra::rast(nrows = 10, ncols = 10, nlyrs = 1,
                      xmin = 0, xmax = 10, ymin = 0, ymax = 10,
                      crs = "epsg:25832", vals = 1:100)

  geo_dists <- geodist(x = train_pts,
                       modeldomain = grid,
                       type = "geo")

  s2s <- geo_dists[geo_dists$what == "sample-to-sample", "dist"]
  expect_equal(round(mean(s2s), 4), 1.4274)
})


# geodist in feature space with a SpatRaster as model domain.
test_that("geodist works with points and raster in feature space", {
  train_pts <- sf::st_cast(
    sf::st_as_sfc("MULTIPOINT ((1.5 1.5), (1.5 2.5), (2.5 2.5), (2.5 3.5), (1.5 4.5), (5.5 4.5))", crs="epsg:25832"),
    "POINT"
  )

  grid <- terra::rast(nrows = 10, ncols = 10, nlyrs = 1,
                      xmin = 0, xmax = 10, ymin = 0, ymax = 10,
                      crs = "epsg:25832", vals = 1:100)

  feature_dists <- geodist(x = train_pts,
                           modeldomain = grid,
                           type = "feature")

  s2s <- feature_dists[feature_dists$what == "sample-to-sample", "dist"]
  expect_equal(round(mean(s2s), 4), 0.3814)
})


# geodist with a stars object as model domain. The call uses type = "feature"
# and the expected value is the one asserted by the SpatRaster feature-space
# test, so the description names the feature space.
test_that("geodist works with points and stars raster in feature space", {
  skip_if_not_installed("stars")
  tpoints <- sf::st_as_sfc("MULTIPOINT ((1.5 1.5), (1.5 2.5), (2.5 2.5), (2.5 3.5), (1.5 4.5), (5.5 4.5))", crs="epsg:25832") |>
    sf::st_cast("POINT")

  # Same 10 x 10 raster as in the SpatRaster tests, converted to stars
  raster <- terra::rast(nrows=10, ncols=10, nlyrs=1, xmin=0, xmax=10,
                        ymin=0, ymax=10, crs="epsg:25832", vals=1:100) |>
    stars::st_as_stars()

  dist <- geodist(x=tpoints,
                  modeldomain=raster,
                  type = "feature")

  mean_sample2sample <- round(mean(dist[dist$what=="sample-to-sample","dist"]), 4)
  expect_equal(mean_sample2sample, 0.3814)
})



# geodist in geographic space with separate test points (testdata).
test_that("geodist works with points and test data in geographic space", {
  aoi <- sf::st_as_sfc("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))", crs="epsg:25832")
  train_pts <- sf::st_cast(
    sf::st_as_sfc("MULTIPOINT ((1 1), (1 2), (2 2), (2 3), (1 4), (5 4))", crs="epsg:25832"),
    "POINT"
  )

  set.seed(1)
  test_pts <- sf::st_set_crs(
    suppressWarnings(sf::st_sample(aoi, 20, type="regular")),
    "epsg:25832"
  )

  geo_dists <- geodist(x = train_pts,
                       modeldomain = aoi,
                       testdata = test_pts,
                       type = "geo")

  # Mean distance of one distance category, rounded to four digits
  rounded_mean <- function(label) {
    round(mean(geo_dists[geo_dists$what == label, "dist"]), 4)
  }

  expect_equal(rounded_mean("sample-to-sample"), 1.4274)
  expect_equal(rounded_mean("test-to-sample"), 2.9402)
})


# geodist in feature space with randomly sampled test points.
test_that("geodist works with points and test data in feature space", {
  aoi <- sf::st_as_sfc("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))", crs="epsg:25832")

  grid <- terra::rast(nrows = 10, ncols = 10, nlyrs = 1,
                      xmin = 0, xmax = 10, ymin = 0, ymax = 10,
                      crs = "epsg:25832", vals = 1:100)

  train_pts <- sf::st_cast(
    sf::st_as_sfc("MULTIPOINT ((1 1), (1 2), (2 2), (2 3), (1 4), (5 4))", crs="epsg:25832"),
    "POINT"
  )

  set.seed(1)
  test_pts <- sf::st_set_crs(
    suppressWarnings(sf::st_sample(aoi, 20, type="random")),
    "epsg:25832"
  )

  feature_dists <- geodist(x = train_pts,
                           modeldomain = grid,
                           testdata = test_pts,
                           type = "feature")

  # Mean distance of one distance category, rounded to four digits
  rounded_mean <- function(label) {
    round(mean(feature_dists[feature_dists$what == label, "dist"]), 4)
  }

  expect_equal(rounded_mean("sample-to-sample"), 0.3814)
  expect_equal(rounded_mean("test-to-sample"), 1.4524)
})


# geodist in feature space when one predictor layer is a factor. Training and
# test points are drawn with clustered_sample(); a random three-fold
# assignment provides the CV distances.
test_that("geodist works with categorical variables in feature space", {
  set.seed(1234)
  predictor_stack <- terra::rast(system.file("extdata","predictors_2012-03-25.tif",package="CAST"))
  predictors <- c("DEM","TWI", "NDRE.M", "Easting", "Northing", "fct")
  # Artificial two-level factor layer: first half of the cells "A", second half "B"
  predictor_stack$fct <- factor(c(rep(LETTERS[1], terra::ncell(predictor_stack)/2),
                                  rep(LETTERS[2], terra::ncell(predictor_stack)/2)))

  predictor_stack <- predictor_stack[[predictors]]

  # Turn the non-NA raster cells into a single polygon used as sampling area
  studyArea <- predictor_stack
  studyArea[!is.na(studyArea)] <- 1
  studyArea <- terra::as.polygons(studyArea, values = FALSE, na.all = TRUE) |>
    sf::st_as_sf() |>
    sf::st_union()

  pts <- clustered_sample(studyArea, 30, 5, 60)
  pts <- sf::st_transform(pts, crs = sf::st_crs(studyArea))
  pts <- terra::extract(predictor_stack, terra::vect(pts), ID=FALSE, bind=TRUE) |>
    sf::st_as_sf()

  test_pts <- clustered_sample(studyArea, 50, 5, 20)

  folds <- data.frame("folds"=sample(1:3, nrow(pts), replace=TRUE))
  folds <- CreateSpacetimeFolds(folds, spacevar="folds", k=3)

  # A second extract() call whose result was discarded used to sit here; it had
  # no effect on the assertions and was removed.

  dist <- geodist(x=pts,
                  modeldomain=predictor_stack,
                  type = "feature",
                  testdata = test_pts,
                  cvfolds = folds$indexOut)

  mean_sample2sample <- round(mean(dist[dist$what=="sample-to-sample","dist"]), 4)
  mean_prediction2sample <- round(mean(dist[dist$what=="prediction-to-sample","dist"]), 4)
  mean_test2sample <- round(mean(dist[dist$what=="test-to-sample","dist"]), 4)
  mean_CV_distance <- round(mean(dist[dist$what=="CV-distances","dist"]), 4)

  expect_equal(mean_sample2sample, 0.0459)
  expect_equal(mean_prediction2sample, 0.1625) #0.1625
  expect_equal(mean_test2sample, 0.2358)
  expect_equal(mean_CV_distance, 0.0663)
})


# geodist with type = "time": cookfarm records at altitude -0.3, 2010 as
# training data and 2011 as prediction data, in days and in hours.
test_that("geodist works in temporal space", {
  data(cookfarm)
  dat <- sf::st_as_sf(cookfarm, coords = c("Easting", "Northing"))
  sf::st_crs(dat) <- 26911

  at_depth <- dat$altitude == -0.3
  obs_year <- lubridate::year(dat$Date)
  train_2010 <- dat[at_depth & obs_year == 2010, ]
  pred_2011 <- dat[at_depth & obs_year == 2011, ]

  # Distances expressed in days
  day_dists <- CAST::geodist(train_2010, preddata = pred_2011, type = "time", time_unit = "days")
  s2s_days <- day_dists[day_dists$what == "sample-to-sample", "dist"]
  p2s_days <- day_dists[day_dists$what == "prediction-to-sample", "dist"]

  expect_equal(round(mean(s2s_days), 4), 0.02)
  expect_equal(round(mean(p2s_days), 4), 194.7656)

  # Distances expressed in hours
  hour_dists <- CAST::geodist(train_2010, preddata = pred_2011, type = "time", time_unit = "hours")
  p2s_hours <- hour_dists[hour_dists$what == "prediction-to-sample", "dist"]
  expect_equal(round(mean(p2s_hours), 2), 4674.37)
})

# geodist with type = "time" and CV folds built from the calendar week of the
# training records.
test_that("geodist works in temporal space and with CV", {
  data(cookfarm)
  dat <- sf::st_as_sf(cookfarm, coords = c("Easting", "Northing"))
  sf::st_crs(dat) <- 26911

  at_depth <- dat$altitude == -0.3
  obs_year <- lubridate::year(dat$Date)
  train_2010 <- dat[at_depth & obs_year == 2010, ]
  pred_2011 <- dat[at_depth & obs_year == 2011, ]

  train_2010$week <- lubridate::week(train_2010$Date)
  set.seed(100)
  week_folds <- CreateSpacetimeFolds(train_2010, timevar = "week")

  time_dists <- CAST::geodist(train_2010,
                              preddata = pred_2011,
                              cvfolds = week_folds$indexOut,
                              type = "time",
                              time_unit = "days")

  cv_dists <- time_dists[time_dists$what == "CV-distances", "dist"]
  expect_equal(round(mean(cv_dists), 4), 2.4048)
})




# The trainControl here does not set savePredictions, and global_validation is
# expected to raise an error for the resulting model.
test_that("global_validation correctly handles missing predictions", {
  skip_if_not_installed("randomForest")
  data("iris")
  set.seed(123)

  cv_ctrl <- caret::trainControl(method = "cv")
  rf_fit <- caret::train(
    iris[, c("Sepal.Width", "Petal.Length", "Petal.Width")],
    iris[, "Sepal.Length"],
    method = "rf",
    trControl = cv_ctrl,
    ntree = 10
  )

  expect_error(global_validation(rf_fit))
})

# global_validation on a caret regression model with the final predictions
# saved; RMSE, Rsquared and MAE are compared with a 2 % tolerance.
test_that("global_validation works with caret regression", {
  skip_if_not_installed("randomForest")
  data("iris")
  set.seed(123)

  cv_ctrl <- caret::trainControl(method = "cv", savePredictions = "final")
  rf_fit <- caret::train(
    iris[, c("Sepal.Width", "Petal.Length", "Petal.Width")],
    iris[, "Sepal.Length"],
    method = "rf",
    trControl = cv_ctrl,
    ntree = 10
  )

  expected_stats <- c("RMSE" = 0.3307870, "Rsquared" = 0.8400544, "MAE" = 0.2621827)
  expect_equal(global_validation(rf_fit), expected_stats, tolerance = 0.02)
})

# global_validation on a caret classification model; only the first two
# returned statistics (Accuracy, Kappa) are compared.
test_that("global_validation works with caret classification", {
  skip_if_not_installed("randomForest")
  data("iris")
  set.seed(123)

  cv_ctrl <- caret::trainControl(method = "cv", savePredictions = "final")
  feature_cols <- c("Sepal.Width", "Petal.Length", "Petal.Width", "Sepal.Length")
  rf_fit <- caret::train(
    iris[, feature_cols],
    iris[, "Species"],
    method = "rf",
    trControl = cv_ctrl,
    ntree = 10
  )

  expected_stats <- c("Accuracy" = 0.96, "Kappa" = 0.94)
  expect_equal(global_validation(rf_fit)[1:2], expected_stats, tolerance = 0.02)
})

# global_validation on a model whose resampling indices come from
# CreateSpacetimeFolds on a shuffled fold column.
test_that("global_validation works with CreateSpacetimeFolds", {
  skip_if_not_installed("randomForest")
  data("iris")
  set.seed(123)

  # Shuffle the repeated fold labels 1..10 over the rows
  fold_pool <- rep(1:10, ceiling(nrow(iris) / 10))
  iris$folds <- sample(fold_pool, nrow(iris))
  fold_indices <- CreateSpacetimeFolds(iris, "folds")

  cv_ctrl <- caret::trainControl(method = "cv",
                                 savePredictions = "final",
                                 index = fold_indices$index)
  feature_cols <- c("Sepal.Width", "Petal.Length", "Petal.Width", "Sepal.Length")
  rf_fit <- caret::train(
    iris[, feature_cols],
    iris[, "Species"],
    method = "rf",
    trControl = cv_ctrl,
    ntree = 10
  )

  expected_stats <- c("Accuracy" = 0.96, "Kappa" = 0.94)
  expect_equal(global_validation(rf_fit)[1:2], expected_stats, tolerance = 0.02)
})


# knndm with training and prediction points in EPSG:4326.
test_that("kNNDM works with geographical coordinates and prediction points", {
  sf::sf_use_s2(TRUE)
  aoi <- sf::st_as_sfc("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))", crs="epsg:4326")
  train_pts <- sf::st_cast(
    sf::st_as_sfc("MULTIPOINT ((1 1), (1 2), (2 2), (2 3), (1 4), (5 4))", crs="epsg:4326"),
    "POINT"
  )

  set.seed(1)
  pred_pts <- sf::st_set_crs(
    suppressWarnings(sf::st_sample(aoi, 20, type="regular")),
    "epsg:4326"
  )

  set.seed(1)
  knndm_res <- knndm(train_pts, predpoints = pred_pts, k = 2, maxp = 0.8)

  expect_identical(round(knndm_res$W, 1), 121095.2)
  expect_identical(knndm_res$method, "hierarchical")
  expect_identical(knndm_res$q, 3L)
})

# knndm with projected coordinates (EPSG:25832) and kmeans clustering.
test_that("kNNDM works with projected coordinates and prediction points", {
  aoi <- sf::st_as_sfc("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))", crs="epsg:25832")
  train_pts <- sf::st_cast(
    sf::st_as_sfc("MULTIPOINT ((1 1), (1 2), (2 2), (2 3), (1 4), (5 4))", crs="epsg:25832"),
    "POINT"
  )

  set.seed(1)
  pred_pts <- sf::st_set_crs(sf::st_sample(aoi, 20, type="regular"), "epsg:25832")

  set.seed(1)
  knndm_res <- knndm(train_pts, predpoints = pred_pts, k = 2, maxp = 0.8,
                     clustering = "kmeans")

  expect_identical(round(knndm_res$W, 4), 1.0919)
  expect_identical(knndm_res$method, "kmeans")
  expect_identical(knndm_res$q, 4L)
})

# knndm with inputs that carry no CRS: results are checked with warnings
# suppressed, then the warning itself is checked.
test_that("kNNDM works without crs and prediction points", {
  aoi <- sf::st_as_sfc("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))")
  train_pts <- sf::st_as_sfc("MULTIPOINT ((1 1), (1 2), (2 2), (2 3), (1 4), (5 4))") |>
    sf::st_cast("POINT")

  set.seed(1)
  pred_pts <- sf::st_sample(aoi, 20, type = "regular")

  set.seed(1)
  knndm_res <- suppressWarnings(
    knndm(train_pts, predpoints = pred_pts, k = 2, maxp = 0.8)
  )

  expect_identical(round(knndm_res$W, 6), 1.091896)
  expect_identical(knndm_res$q, 3L)

  expect_warning(
    knndm(train_pts, predpoints = pred_pts, k = 2, maxp = 0.8),
    "Missing CRS in training or prediction points. Assuming projected CRS."
  )
})


# knndm with a polygon modeldomain in projected coordinates; also checks the
# message about sampling prediction points.
test_that("kNNDM works with modeldomain and projected coordinates", {
  aoi <- sf::st_as_sfc("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))", crs="epsg:25832")
  train_pts <- sf::st_cast(
    sf::st_as_sfc("MULTIPOINT ((1 1), (1 2), (2 2), (2 3), (1 4), (5 4))", crs="epsg:25832"),
    "POINT"
  )

  set.seed(1)
  knndm_res <- suppressMessages(
    knndm(train_pts, modeldomain = aoi, k = 2, maxp = 0.8, clustering = "kmeans")
  )

  expect_identical(round(knndm_res$W, 4), 1.2004)
  expect_identical(knndm_res$method, "kmeans")
  expect_identical(knndm_res$q, 4L)

  expect_message(
    knndm(train_pts, modeldomain = aoi, k = 2, maxp = 0.8, clustering = "kmeans"),
    "1000 prediction points are sampled from the modeldomain"
  )
})

# knndm with a polygon modeldomain in EPSG:4326 and hierarchical clustering.
test_that("kNNDM works with modeldomain and geographical coordinates", {
  sf::sf_use_s2(TRUE)
  aoi <- sf::st_as_sfc("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))", crs="epsg:4326")
  train_pts <- sf::st_cast(
    sf::st_as_sfc("MULTIPOINT ((1 1), (1 2), (2 2), (2 3), (1 4), (5 4))", crs="epsg:4326"),
    "POINT"
  )

  set.seed(1)
  knndm_res <- suppressMessages(
    knndm(train_pts, modeldomain = aoi, k = 2, maxp = 0.8, clustering = "hierarchical")
  )

  expect_identical(round(knndm_res$W, 4), 133187.4275)
  expect_identical(knndm_res$method, "hierarchical")
  expect_identical(knndm_res$q, 3L)

  expect_message(
    knndm(train_pts, modeldomain = aoi, k = 2, maxp = 0.8, clustering = "hierarchical"),
    "1000 prediction points are sampled from the modeldomain"
  )
})

# knndm with a polygon modeldomain and no CRS on either input.
test_that("kNNDM works with modeldomain and no crs", {
  aoi <- sf::st_as_sfc("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))")
  train_pts <- sf::st_cast(
    sf::st_as_sfc("MULTIPOINT ((1 1), (1 2), (2 2), (2 3), (1 4), (5 4))"),
    "POINT"
  )

  set.seed(1)
  knndm_res <- suppressWarnings(suppressMessages(
    knndm(train_pts, modeldomain = aoi, k = 2, maxp = 0.8)
  ))

  expect_identical(round(knndm_res$W, 4), 1.2004)
  expect_identical(knndm_res$method, "hierarchical")
  expect_identical(knndm_res$q, 3L)

  expect_message(
    suppressWarnings(knndm(train_pts, modeldomain = aoi, k = 2, maxp = 0.8)),
    "1000 prediction points are sampled from the modeldomain"
  )
})

# With randomly placed training points knndm is expected to report
# q = "random CV", both for projected and for geographical coordinates.
test_that("kNNDM works when no clustering is present", {
  sf::sf_use_s2(TRUE)
  aoi <- sf::st_as_sfc("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))", crs="epsg:25832")

  set.seed(1)
  train_pts <- sf::st_sample(aoi, 10)

  set.seed(1)
  pred_pts <- sf::st_sample(aoi, 20, type = "regular")

  # projected coordinates
  set.seed(1)
  projected_res <- suppressMessages(
    knndm(train_pts, predpoints = pred_pts, k = 2, maxp = 0.8, clustering = "kmeans")
  )
  expect_equal(projected_res$q, "random CV")

  # for geographical coordinates
  set.seed(1)
  geographic_res <- suppressMessages(
    knndm(sf::st_transform(train_pts, "epsg:4326"),
          predpoints = sf::st_transform(pred_pts, "epsg:4326"),
          k = 2, maxp = 0.8, clustering = "hierarchical")
  )
  expect_equal(geographic_res$q, "random CV")
})


# Runs knndm over a grid of k / maxp combinations on 100 training points that
# are sampled from a sub-area of the prediction area, and pins W, the first Gij
# value and the first Gjstar value of every run.
test_that("kNNDM works with many points and different configurations", {
  aoi <- sf::st_as_sfc("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))", crs="epsg:25832")
  sample_area <- sf::st_as_sfc("POLYGON ((0 0, 4 0, 4 4, 0 4, 0 0))", crs="epsg:25832")

  set.seed(1)
  tpoints <- sf::st_sample(sample_area, 100)

  set.seed(1)
  predpoints <- sf::st_sample(aoi, 1000)

  # One row per configuration: k folds and the matching maximum fold proportion
  ks <- 2:10
  ps <- (1/ks)+0.1
  tune_grid <- data.frame(ks=ks, ps=ps)

  set.seed(1)
  kout <- apply(tune_grid, 1, function(j) {
    knndm(tpoints, predpoints=predpoints, k=j[[1]], maxp=j[[2]], clustering = "kmeans")
  })

  kout_W <- sapply(kout, function(x) round(x$W,3))
  kout_Gij <- sapply(kout, function(x) round(x$Gij[1],4))
  kout_Gjstar <- sapply(kout, function(x) round(x$Gjstar[1],4))

  w_expected <- c(2.184, 2.286, 2.468, 2.554, 2.570, 2.634, 2.678, 2.694, 2.688)
  Gij_expected <- rep(1.3886, length(w_expected))
  Gjstar_expected <- c(1.0981, 1.0981, 0.5400, 0.3812, 0.2505, 0.3812, 0.3099, 0.3099, 0.3812)

  expect_identical(round(kout_W,3), w_expected)
  expect_identical(round(kout_Gij,4), Gij_expected)
  # Compare the computed Gjstar values with the expectation. The previous
  # assertion compared Gjstar_expected with itself and could never fail.
  # as.numeric() drops any names or attributes carried over from knndm.
  expect_identical(as.numeric(round(kout_Gjstar,4)), Gjstar_expected)

})


# Each call below is expected to be rejected by knndm with an error.
test_that("kNNDM recognizes erroneous input", {
  sf::sf_use_s2(TRUE)
  aoi <- sf::st_as_sfc("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))", crs="epsg:25832")
  train_pts <- sf::st_cast(
    sf::st_as_sfc("MULTIPOINT ((1 1), (1 2), (2 2), (2 3), (1 4), (5 4))", crs="epsg:25832"),
    "POINT"
  )

  set.seed(1)
  pred_pts <- sf::st_sample(aoi, 20)

  # maxp too small
  expect_error(knndm(train_pts, predpoints = pred_pts, k = 2, maxp = 0.4))

  # k larger than number of training points
  expect_error(knndm(train_pts, predpoints = pred_pts, k = 20, maxp = 0.8))

  # training and prediction points in different CRS
  expect_error(
    knndm(train_pts, predpoints = sf::st_transform(pred_pts, "epsg:25833"),
          k = 2, maxp = 0.8)
  )

  # training points and modeldomain in different CRS
  expect_error(
    knndm(train_pts, modeldomain = sf::st_transform(aoi, "epsg:25833"),
          k = 2, maxp = 0.8)
  )

  # kmeans clustering requested for geographical coordinates
  expect_error(
    knndm(sf::st_transform(train_pts, "epsg:4326"),
          predpoints = sf::st_transform(pred_pts, "epsg:4326"),
          clustering = "kmeans")
  )
})

# knndm with a SpatRaster as modeldomain, using the cookfarm data averaged per
# SOURCEID as training points.
test_that("kNNDM yields the expected results with SpatRast modeldomain", {
  set.seed(1234)

  # prepare sample data
  data(cookfarm)
  covariate_cols <- c("DEM","TWI", "NDRE.M", "Easting", "Northing","VW")
  site_means <- terra::aggregate(cookfarm[, covariate_cols],
                                 by = list(as.character(cookfarm$SOURCEID)), mean)
  # Drop the grouping column before building the point geometry
  pts <- sf::st_as_sf(site_means[, -1], coords = c("Easting", "Northing"))
  sf::st_crs(pts) <- 26911
  studyArea <- terra::rast(system.file("extdata","predictors_2012-03-25.tif",package="CAST"))
  pts <- sf::st_transform(pts, crs = sf::st_crs(studyArea))

  # NOTE(review): knndm is run twice and the asserted value comes from the
  # second run; the first result is not used. Both calls are kept so the
  # sequence of calls after set.seed() is unchanged.
  knndm_folds <- knndm(pts, modeldomain = studyArea)
  second_run <- knndm(pts, modeldomain = studyArea)
  expect_equal(as.numeric(second_run$Gjstar[40]), 61.935505)
})


# knndm in feature space with kmeans clustering; raster layers and point
# columns are reduced to the names they share.
test_that("kNNDM works in feature space with kmeans clustering and raster as modeldomain", {
  set.seed(1234)

  # prepare sample data
  data(cookfarm)
  covariate_cols <- c("DEM","TWI", "NDRE.M", "Easting", "Northing","VW")
  site_means <- terra::aggregate(cookfarm[, covariate_cols],
                                 by = list(as.character(cookfarm$SOURCEID)), mean)
  pts <- sf::st_as_sf(site_means[, -1], coords = c("Easting", "Northing"))
  sf::st_crs(pts) <- 26911
  studyArea <- terra::rast(system.file("extdata","predictors_2012-03-25.tif",package="CAST"))
  pts <- sf::st_transform(pts, crs = sf::st_crs(studyArea))

  # Keep only the layers / columns present in both objects
  shared_layers <- names(studyArea) %in% names(pts)
  studyArea <- studyArea[[shared_layers]]
  shared_cols <- names(pts) %in% names(studyArea)
  train_pts <- pts[, shared_cols]

  knndm_res <- knndm(train_pts, modeldomain = studyArea,
                     space = "feature", clustering = "kmeans")

  expect_equal(round(as.numeric(knndm_res$Gjstar[40]), 4), 0.2132)
})


# knndm in feature space with hierarchical clustering; raster layers and point
# columns are reduced to the names they share.
test_that("kNNDM works in feature space with hierarchical clustering and raster as modeldomain", {
  set.seed(1234)

  # prepare sample data
  data(cookfarm)
  covariate_cols <- c("DEM","TWI", "NDRE.M", "Easting", "Northing","VW")
  site_means <- terra::aggregate(cookfarm[, covariate_cols],
                                 by = list(as.character(cookfarm$SOURCEID)), mean)
  pts <- sf::st_as_sf(site_means[, -1], coords = c("Easting", "Northing"))
  sf::st_crs(pts) <- 26911
  studyArea <- terra::rast(system.file("extdata","predictors_2012-03-25.tif",package="CAST"))
  pts <- sf::st_transform(pts, crs = sf::st_crs(studyArea))

  # Keep only the layers / columns present in both objects
  shared_layers <- names(studyArea) %in% names(pts)
  studyArea <- studyArea[[shared_layers]]
  shared_cols <- names(pts) %in% names(studyArea)
  train_pts <- pts[, shared_cols]

  knndm_res <- knndm(train_pts, modeldomain = studyArea,
                     space = "feature", clustering = "hierarchical")

  expect_equal(round(as.numeric(knndm_res$Gjstar[40]), 4), 0.2132)
})

# knndm in feature space on the Chilean subset of splotdata, with the Chile
# predictor raster as modeldomain.
test_that("kNNDM works in feature space with clustered training points", {
  skip_if_not_installed("PCAmixdata")
  set.seed(1234)

  data(splotdata)
  chile_plots <- splotdata[splotdata$Country == "Chile", ]

  predictor_names <- c("bio_1", "bio_4", "bio_5", "bio_6",
                       "bio_8", "bio_9", "bio_12", "bio_13",
                       "bio_14", "bio_15", "elev")
  train_df <- sf::st_drop_geometry(chile_plots)
  chile_raster <- terra::rast(system.file("extdata", "predictors_chile.tif",package="CAST"))

  knndm_res <- knndm(train_df[, predictor_names],
                     modeldomain = chile_raster,
                     space = "feature",
                     clustering = "kmeans",
                     k = 4,
                     maxp = 0.8)

  expect_equal(round(as.numeric(knndm_res$Gjstar[40]), 4), 0.8287)
})


# knndm in feature space with explicit prediction points and a randomly
# assigned four-level factor added to both point sets.
test_that("kNNDM works in feature space with categorical variables and predpoints", {
  set.seed(1234)

  # prepare sample data
  data(cookfarm)
  covariate_cols <- c("DEM","TWI", "NDRE.M", "Easting", "Northing","VW")
  site_means <- terra::aggregate(cookfarm[, covariate_cols],
                                 by = list(as.character(cookfarm$SOURCEID)), mean)
  pts <- sf::st_as_sf(site_means[, -1], coords = c("Easting", "Northing"))
  sf::st_crs(pts) <- 26911
  studyArea <- terra::rast(system.file("extdata","predictors_2012-03-25.tif",package="CAST"))
  pts <- sf::st_transform(pts, crs = sf::st_crs(studyArea))

  shared_layers <- names(studyArea) %in% names(pts)
  studyArea <- studyArea[[shared_layers]]

  pred_pts <- terra::spatSample(studyArea, 1000, "regular")
  train_pts <- pts[, names(pts) %in% names(studyArea)]

  # Random factor column: prediction points are drawn first, then training points
  factor_levels <- LETTERS[1:4]
  pred_pts$fct <- factor(sample(factor_levels, nrow(pred_pts), replace = TRUE))
  train_pts$fct <- factor(sample(factor_levels, nrow(pts), replace = TRUE))

  knndm_res <- knndm(tpoints = train_pts,
                     predpoints = pred_pts,
                     space = "feature",
                     clustering = "hierarchical")

  expect_equal(round(as.numeric(knndm_res$Gjstar[40]), 3), 0.057)
})


# knndm in feature space with clustered training points and a factor layer in
# the predictor stack; kmeans and hierarchical clustering are both checked.
test_that("kNNDM works in feature space with clustered training points, categorical features ", {
  skip_if_not_installed("PCAmixdata")
  set.seed(1234)
  predictor_stack <- terra::rast(system.file("extdata","predictors_2012-03-25.tif",package="CAST"))
  layer_names <- c("DEM","TWI", "NDRE.M", "Easting", "Northing", "fct")

  # Artificial two-level factor layer: first half of the cells, then second half
  half_cells <- terra::ncell(predictor_stack) / 2
  predictor_stack$fct <- factor(c(rep(LETTERS[1], half_cells),
                                  rep(LETTERS[2], half_cells)))
  predictor_stack <- predictor_stack[[layer_names]]

  # Turn the non-NA raster cells into a single polygon used as sampling area
  area_mask <- predictor_stack
  area_mask[!is.na(area_mask)] <- 1
  area_polygons <- terra::as.polygons(area_mask, values = FALSE, na.all = TRUE)
  studyArea <- sf::st_union(sf::st_as_sf(area_polygons))

  pts <- clustered_sample(studyArea, 30, 5, 60)
  pts <- sf::st_transform(pts, crs = sf::st_crs(studyArea))
  train_df <- terra::extract(predictor_stack, terra::vect(pts), ID = FALSE)

  res_kmeans <- knndm(tpoints = train_df, modeldomain = predictor_stack,
                      space = "feature", clustering = "kmeans")
  res_hclust <- knndm(tpoints = train_df, modeldomain = predictor_stack,
                      space = "feature", clustering = "hierarchical")

  expect_equal(round(as.numeric(res_kmeans$Gjstar[20]), 3), 0.077)
  expect_equal(round(as.numeric(res_hclust$Gjstar[20]), 3), 0.078)
})


# knndm in feature space with useMD = TRUE on the Chilean subset of splotdata.
test_that("kNNDM works in feature space with Mahalanobis distance", {
  data(splotdata)
  chile_plots <- splotdata[splotdata$Country == "Chile", ]

  predictor_names <- c("bio_1", "bio_4", "bio_5", "bio_6",
                       "bio_8", "bio_9", "bio_12", "bio_13",
                       "bio_14", "bio_15", "elev")
  train_df <- sf::st_drop_geometry(chile_plots)
  chile_raster <- terra::rast(system.file("extdata", "predictors_chile.tif",package="CAST"))

  set.seed(1234)
  knndm_res <- knndm(train_df[, predictor_names],
                     modeldomain = chile_raster,
                     space = "feature",
                     clustering = "kmeans",
                     k = 4,
                     maxp = 0.8,
                     useMD = TRUE)

  expect_equal(round(as.numeric(knndm_res$Gjstar[40]), 4), 1.1258)
})


# With useMD = TRUE on the cookfarm data, both clustering options are expected
# to emit the message about a random CV assignment.
test_that("kNNDM works in feature space with Mahalanobis distance without clustering", {
  set.seed(1234)

  # prepare sample data
  data(cookfarm)
  covariate_cols <- c("DEM","TWI", "NDRE.M", "Easting", "Northing","VW")
  site_means <- terra::aggregate(cookfarm[, covariate_cols],
                                 by = list(as.character(cookfarm$SOURCEID)), mean)
  pts <- sf::st_as_sf(site_means[, -1], coords = c("Easting", "Northing"))
  sf::st_crs(pts) <- 26911
  studyArea <- terra::rast(system.file("extdata","predictors_2012-03-25.tif",package="CAST"))
  pts <- sf::st_transform(pts, crs = sf::st_crs(studyArea))

  # Keep only the layers / columns present in both objects
  shared_layers <- names(studyArea) %in% names(pts)
  studyArea <- studyArea[[shared_layers]]
  shared_cols <- names(pts) %in% names(studyArea)
  train_pts <- pts[, shared_cols]

  expect_message(
    knndm(train_pts, modeldomain = studyArea, space = "feature",
          clustering = "kmeans", useMD = TRUE),
    "Gij <= Gj; a random CV assignment is returned"
  )

  expect_message(
    knndm(train_pts, modeldomain = studyArea, space = "feature",
          clustering = "hierarchical", useMD = TRUE),
    "Gij <= Gj; a random CV assignment is returned"
  )
})


# nndm is expected to reject a negative phi with a specific message.
test_that("Valid range of phi", {
  set.seed(1234)
  # 50 x 50 square without CRS
  corners <- matrix(c(0,0,0,50,50,50,50,0,0,0), ncol = 2, byrow = TRUE)
  square_sfc <- sf::st_sfc(sf::st_polygon(list(corners)))
  train_pts <- sf::st_sample(square_sfc, 50, type = "random")
  pred_pts <- sf::st_sample(square_sfc, 50, type = "regular")

  expect_error(
    nndm(train_pts, predpoints = pred_pts, phi = -1),
    "phi must be positive or set to 'max'."
  )
})

# nndm input validation: wrong classes and wrong geometry types for tpoints,
# predpoints and modeldomain must produce the matching error message.
test_that("NNDM detects wrong data and geometry types", {
  set.seed(1234)
  corners <- matrix(c(0,0,0,50,50,50,50,0,0,0), ncol = 2, byrow = TRUE)
  square <- sf::st_polygon(list(corners))
  square_sfc <- sf::st_sfc(square)
  train_pts <- sf::st_sample(square_sfc, 50, type = "random")
  pred_pts <- sf::st_sample(square_sfc, 50, type = "regular")

  # tpoints
  expect_error(
    suppressWarnings(nndm(1, predpoints = pred_pts)),
    "tpoints must be a sf/sfc object."
  )
  expect_error(
    nndm(square, predpoints = pred_pts),
    "tpoints must be a sf/sfc object."
  )
  expect_error(
    nndm(sf::st_sfc(square), predpoints = pred_pts),
    "tpoints must be a sf/sfc point object."
  )

  # predpoints
  expect_error(
    suppressWarnings(nndm(train_pts, predpoints = 1)),
    "predpoints must be a sf/sfc object."
  )
  expect_error(
    nndm(train_pts, predpoints = square),
    "predpoints must be a sf/sfc object."
  )
  expect_error(
    nndm(train_pts, predpoints = square_sfc),
    "predpoints must be a sf/sfc point object."
  )

  # model domain
  expect_error(
    suppressWarnings(nndm(train_pts, modeldomain = 1)),
    "modeldomain must be a sf/sfc object or a 'SpatRaster' object."
  )
  expect_error(
    nndm(train_pts, modeldomain = pred_pts),
    "modeldomain must be a sf/sfc polygon object."
  )
})

# nndm must stop when tpoints and predpoints, or tpoints and modeldomain,
# do not share the same CRS.
test_that("NNDM detects different CRS in inputs", {
  sf::sf_use_s2(TRUE)
  set.seed(1234)
  poly <- sf::st_polygon(list(matrix(c(0,0,0,50,50,50,50,0,0,0), ncol=2,
                                     byrow=TRUE)))
  poly_sfc <- sf::st_sfc(poly)
  tpoints_sfc <- sf::st_sample(poly_sfc, 50, type = "random")
  predpoints_sfc <- sf::st_sample(poly_sfc, 50, type = "regular")

  # Only the objects used by the assertions are built: training points in
  # EPSG:3857, prediction points without CRS and a polygon in EPSG:4326.
  tpoints_sfc_3857 <- sf::st_set_crs(tpoints_sfc, 3857)
  poly_sfc_4326 <- sf::st_set_crs(poly_sfc, 4326)

  # tests
  expect_error(nndm(tpoints_sfc_3857, predpoints = predpoints_sfc),
               "tpoints and predpoints must have the same CRS")
  expect_error(nndm(tpoints_sfc_3857, modeldomain = poly_sfc_4326),
               "tpoints and modeldomain must have the same CRS")
})



# Regression test: pins selected Gjstar values and output lengths of nndm()
# for the predpoints and modeldomain interfaces.
# The nndm() calls are kept in their original number and order so that the
# random number stream seen by each call stays the same.
test_that("NNDM yields the expected results for all data types", {
  set.seed(1234)
  square_coords <- matrix(c(0, 0, 0, 50, 50, 50, 50, 0, 0, 0),
                          ncol = 2, byrow = TRUE)
  square_sfc <- sf::st_sfc(sf::st_polygon(list(square_coords)))
  train_pts <- sf::st_sample(square_sfc, 50, type = "random")
  predpoints_sfc <- sf::st_sample(square_sfc, 50, type = "regular")

  # tpoints, predpoints
  res_pred <- nndm(train_pts, predpoints = train_pts)
  expect_equal(as.numeric(res_pred$Gjstar[1]), 3.7265881)

  # tpoints, modeldomain
  res_domain <- nndm(train_pts, modeldomain = square_sfc)
  expect_equal(as.numeric(res_domain$Gjstar[5]), 4.9417614)

  # change phi
  res_phi <- nndm(train_pts, predpoints = train_pts, phi = 10)
  expect_equal(as.numeric(res_phi$Gjstar[10]), 4.8651321)

  # change min_train
  res_min_train <- nndm(train_pts, predpoints = train_pts, phi = 20,
                        min_train = 0.2)
  expect_equal(as.numeric(res_min_train$Gjstar[15]), 3.466861)

  # length checks (each on a fresh nndm() run)
  n_train <- length(train_pts)
  gjstar_again <- nndm(train_pts, predpoints = train_pts)$Gjstar
  expect_length(gjstar_again, n_train)
  gi_again <- nndm(train_pts, predpoints = train_pts)$Gi
  expect_length(gi_again, n_train)
  gij_domain <- nndm(train_pts, modeldomain = square_sfc)$Gij
  expect_gt(length(gij_domain), 900)
})

# Regression test: pins one Gjstar value each for inputs labelled with a
# projected CRS (EPSG:3857) and a geographic CRS (EPSG:4326).
test_that("NNDM yields the expected results for all CRS", {
  sf::sf_use_s2(TRUE)
  set.seed(1234)
  square_coords <- matrix(c(0, 0, 0, 50, 50, 50, 50, 0, 0, 0),
                          ncol = 2, byrow = TRUE)
  square_sfc <- sf::st_sfc(sf::st_polygon(list(square_coords)))
  train_pts <- sf::st_sample(square_sfc, 50, type = "random")
  pred_pts <- sf::st_sample(square_sfc, 50, type = "regular")

  # Projected
  train_projected <- sf::st_set_crs(train_pts, 3857)
  pred_projected <- sf::st_set_crs(pred_pts, 3857)
  res_projected <- nndm(train_projected, predpoints = pred_projected,
                        phi = 10)
  expect_equal(as.numeric(res_projected$Gjstar[20]), 3.2921498)

  # Geographic
  train_geographic <- sf::st_set_crs(train_pts, 4326)
  pred_geographic <- sf::st_set_crs(pred_pts, 4326)
  res_geographic <- nndm(train_geographic, predpoints = pred_geographic,
                         phi = 1000000)
  expect_equal(as.numeric(res_geographic$Gjstar[20]), 355614.94)
})

# Regression test: nndm() with a SpatRaster passed as modeldomain, using the
# cookfarm example data and the predictor raster shipped with CAST.
test_that("NNDM yields the expected results with SpatRast modeldomain", {
  set.seed(1234)

  # Sample data: column means of the cookfarm records, grouped by SOURCEID
  data(cookfarm)
  value_cols <- c("DEM","TWI", "NDRE.M", "Easting", "Northing","VW")
  source_groups <- list(as.character(cookfarm$SOURCEID))
  site_means <- terra::aggregate(cookfarm[, value_cols],
                                 by = source_groups, mean)

  # Drop the grouping column and turn the means into points (EPSG:26911)
  train_pts <- site_means[, -1]
  train_pts <- sf::st_as_sf(train_pts, coords = c("Easting", "Northing"))
  sf::st_crs(train_pts) <- 26911

  # Bring the points into the CRS of the raster used as model domain
  raster_file <- system.file("extdata", "predictors_2012-03-25.tif",
                             package = "CAST")
  domain_rast <- terra::rast(raster_file)
  train_pts <- sf::st_transform(train_pts, crs = sf::st_crs(domain_rast))

  # NOTE(review): the result of this first run is never inspected. It is
  # kept because it may advance the random number stream that the expected
  # value of the second run depends on -- confirm before removing.
  first_run <- nndm(train_pts, modeldomain = domain_rast, phi = 150)

  second_run <- nndm(train_pts, modeldomain = domain_rast, phi = 150)
  expect_equal(as.numeric(second_run$Gjstar[5]), 63.828663)

})


# Build the fixture shared by the trainDI tests.
#
# Aggregates the cookfarm data (means per SOURCEID), keeps the first 30
# points, extracts the first eight layers of the CAST example raster at these
# points and trains a caret model (method "rf", 5-fold CV) on three of the
# extracted variables.
#
# Returns a named list with the elements `studyArea` (SpatRaster),
# `trainDat` (extracted values merged with the points), `variables`
# (names of the predictors used) and `model` (the caret model).
loaddata <- function() {
  # prepare sample data:
  data("cookfarm")
  dat <- aggregate(cookfarm[, c("VW", "Easting", "Northing")],
                   by = list(as.character(cookfarm$SOURCEID)), mean)
  pts <- sf::st_as_sf(dat, coords = c("Easting", "Northing"))
  pts$ID <- seq_len(nrow(pts))
  set.seed(100)
  pts <- pts[1:30, ]
  studyArea <- terra::rast(system.file("extdata", "predictors_2012-03-25.tif",
                                       package = "CAST"))[[1:8]]
  trainDat <- terra::extract(studyArea, pts, na.rm = FALSE)
  # join the extracted values back to the point attributes via the ID column
  trainDat <- merge(trainDat, pts, by.x = "ID", by.y = "ID")

  # train a model (seed fixed so the resampling and the forest are
  # reproducible):
  set.seed(100)
  variables <- c("DEM", "NDRE.Sd", "TWI")
  predictor_cols <- names(trainDat) %in% variables
  model <- caret::train(trainDat[, predictor_cols],
                        trainDat$VW, method = "rf", importance = TRUE,
                        tuneLength = 1,
                        trControl = caret::trainControl(method = "cv",
                                                        number = 5,
                                                        savePredictions = TRUE))

  list(
    studyArea = studyArea,
    trainDat = trainDat,
    variables = variables,
    model = model
  )
}

# Regression test: trainDI() on the model from loaddata() must reproduce the
# stored threshold, per-sample DI values and column means of the training
# data.
test_that("trainDI works in default for a trained model", {
  skip_if_not_installed("randomForest")
  dat <- loaddata()
  #...then calculate the DI of the trained model:
  DI <- trainDI(model = dat$model, verbose = FALSE)

  #test threshold:
  expect_equal(as.numeric(round(DI$threshold, 5)), 0.38986)
  # test trainDI
  expect_equal(DI$trainDI, c(0.09043580, 0.14046341, 0.16584582, 0.57617177, 0.26840303,
                             0.14353894, 0.19768329, 0.24022059, 0.06832037, 0.29150668,
                             0.18471625, 0.57617177, 0.12344463, 0.09043580, 0.14353894,
                             0.26896008, 0.22713731, 0.24022059, 0.20388725, 0.06832037,
                             0.23604264, 0.20388725, 0.91513568, 0.09558666, 0.14046341,
                             0.16214832, 0.37107762, 0.16214832, 0.18471625, 0.12344463))
  # test summary statistics of the DI
  expect_equal(as.numeric(colMeans(DI$train)),
               c(795.4426351, 4.0277978, 0.2577245))
})

# Regression test: trainDI() with LPD = TRUE on the model from loaddata()
# must reproduce the stored threshold, the integer trainLPD values and the
# column means of the training data.
test_that("trainDI (with LPD = TRUE) works in default for a trained model", {
  skip_if_not_installed("randomForest")
  dat <- loaddata()
  #...then calculate the DI of the trained model:
  DI <- trainDI(model = dat$model, LPD = TRUE, verbose = FALSE)

  #test threshold:
  expect_equal(as.numeric(round(DI$threshold, 5)), 0.38986)
  #test trainLPD (compared with expect_identical, so the type must be integer)
  expect_identical(DI$trainLPD, as.integer(c(3, 4, 6, 0, 7,
                                             6, 2, 1, 5, 3,
                                             4, 0, 1, 2, 6,
                                             5, 4, 4, 5, 7,
                                             3, 4, 0, 2, 3,
                                             6, 1, 7, 3, 2)))
  # test summary statistics of the DI
  expect_equal(as.numeric(colMeans(DI$train)),
               c(795.4426351, 4.0277978, 0.2577245))
})
