# =========================================================================================
# This script contains wrapper functions for glmnet that accept (formula, data) instead of (X, y),
# plus the "my.make.formula" helper for building model formulas.
# =========================================================================================


library(glmnet)

# wrapper for glmnet
glmnet.f <- function(formula, data, ...)
{
  # Make sure the outcome variable is the first column
  mf = model.frame(formula = formula, data = data)
  t = terms.formula(formula, data=data)
  
  # get the outcome
  y = as.matrix(mf[,1])
  
  # get the predictors and remove the intercept column
  X = model.matrix(t, data=mf)
  X = X[,-1]   
  
  fit <- glmnet(X, y, ...)
  fit$formula <- formula
  
  return(fit)
}

# wrapper for cv.glmnet
cv.glmnet.f <- function(formula, data, ...)
{
  # Make sure the outcome variable is the first column
  mf = model.frame(formula = formula, data = data)
  t = terms.formula(formula, data=data)
  
  # get the outcome
  y = as.matrix(mf[,1])
  
  # get the predictors and remove the intercept column
  X = model.matrix(t, data=mf)
  X = X[,-1]   
  
  fit <- cv.glmnet(X, y, ...)
  fit$formula <- formula
  
  return(fit)
}

# wrapper for cv.glmnet
cv.glmnet.rep.f <- function(formula, data, rep = 10, ...)
{
  # Make sure the outcome variable is the first column
  mf = model.frame(formula = formula, data = data)
  t = terms.formula(formula, data=data)
  
  # get the outcome
  y = as.matrix(mf[,1])
  
  # get the predictors and remove the intercept column
  X = model.matrix(t, data=mf)
  X = X[,-1]
  
  rv = list()
  
  for (i in 1:rep) {
    
    rv[[i]] = cv.glmnet(X, y, ...)
    
    gc()
    
  }
  
  rv$rep = rep
  rv$formula = formula
  
  return(rv)
}


predict.glmnet.rep.f <- function(fit, data, ...){
  
  # Make sure the outcome variable is the first column
  mf = model.frame(formula = fit$formula, data = data)
  t = terms.formula(fit$formula,data=data)
  
  # get the predictors and remove the intercept column
  X = model.matrix(t, data=mf)
  X = X[,-1]
  
  yhat = matrix(nrow = nrow(X), ncol = fit$rep)
  
  for (i in 1:fit$rep) {
    
    yhat[,i] <- predict(fit[[i]], X,  ...)
    
  }
  
  return(yhat)
  
}

# wrapper for predict.glmnet
predict.glmnet.f <- function(fit, data, ...) {
  
  # Make sure the outcome variable is the first column
  mf = model.frame(formula = fit$formula, data = data)
  t = terms.formula(fit$formula,data=data)
  
  # get the predictors and remove the intercept column
  X = model.matrix(t, data=mf)
  X = X[,-1]   
  
  yhat <- predict(fit, X,  ...)
  
  return(yhat)
}

# Display only important coefficients
glmnet.tidy.coef <- function(x){
  x <- coef(x)
  df = data.frame(term=rownames(x),
                  estimate=matrix(x)[,1],
                  stringsAsFactors = FALSE)
  df = df[df[,2]!=0,]
  print(df)
}

# Complex formula
# log transformation & square transformation only applies to numerical variables
my.make.formula <- function(target,df,use.interactions2=F, use.interactions3 = F,use.logs=F,use.squares=F,use.cubics=F)
{
  # extract variables
  v = names(df)
  v = v[v!=target] 
  nv = length(v)
  
  if (use.interactions2)
  {
    f = paste0(target, " ~ . + .*. ")
  }else if(use.interactions3){
    f = paste0(target, " ~ .^3 ")
  }
  else 
  {
    f = paste(target, " ~ . ")
  }
  
  # use logs?
  if (use.logs)
  {
    for (i in 1:nv)  
    {
      if (is.numeric(df[,v[i]])){
        
        if (min(df[,v[i]]) > 0)
        {
          f = paste0(f, "+ log(", v[i], ") ")
        }
        
      }
    }
  }
  
  # use squares?
  if (use.squares)
  {
    for (i in 1:nv)  
    {
      if (is.numeric(df[,v[i]])){
        
        f = paste0(f, "+ I(", v[i], "^2) ")
        
      }
    }
  }
  
  # use cubics?
  if (use.cubics)
  {
    for (i in 1:nv)  
    {
      if (is.numeric(df[,v[i]])){
        f = paste0(f, "+ I(", v[i], "^3) ")
      }
    }
  }  
  
  return(as.formula(f))
}


