# Parts of this script are adapted from the network meta-analysis example at
# https://opensource.nibr.com/bamdd/src/02j_network_meta_analysis.html

library(brms)
library(netmeta)
library(rstan)
library(tidyverse)

set.seed(0)

# Data ----
data("smokingcessation", package = "netmeta")

# Single-letter treatment codes -> descriptive arm labels. The order of this
# vector also fixes the level order of the `trtc` factor below.
recode_trt <- c(
  "A" = "No_intervention",
  "B" = "Self_help",
  "C" = "Individual_counselling",
  "D" = "Group_counselling"
)

# Reshape the wide table (one row per study, arm columns suffixed with a
# digit) into one row per study arm, dropping arm slots whose `n` is NA.
smoking <- smokingcessation %>%
  mutate(studyn = seq_len(n())) %>%
  pivot_longer(
    -studyn,
    names_to = c(".value", "trtid"),
    names_pattern = "(.*)([1-9])"
  ) %>%
  filter(!is.na(n)) %>%
  transmute(
    study = factor(studyn),
    trtc = factor(recode_trt[treat], levels = unname(recode_trt)),
    trtn = as.numeric(trtc),
    r = event,
    n
  )

# Sampler control settings passed to every brm()/standata() call below.
control_args <- list(adapt_delta = 0.95)

# Indicator design matrices without an intercept column: one column per
# treatment level (B) and one per study (S). Their column names are the
# regression terms assembled by full.model.f() / partial.model.f().
B <- model.matrix(~ 0 + trtc, data = smoking)
S <- model.matrix(~ 0 + study, data = smoking)

# Outcome/grouping columns followed by every treatment and study dummy.
smoking_with_dummies <- smoking %>%
  dplyr::select(r, n, study, trtc) %>%
  bind_cols(as_tibble(B), as_tibble(S))

# Wrapper around brms::brm() for the binomial models in this script.
#
# `data`, `f` (the model formula) and `prior` are forwarded unchanged, the
# shared `control_args` sampler settings are applied, and quiet output is
# requested (silent = 2, refresh = 0). Anything in `...` (e.g.
# `empty = TRUE`) is passed straight through to brm().
fit.brm <- function(data, f, prior, ...) {
  brm(
    formula = f,
    data = data,
    family = binomial(),
    prior = prior,
    control = control_args,
    refresh = 0,
    silent = 2,
    ...
  )
}

# Formula for the full model: no global intercept, one coefficient per study
# dummy and one per treatment dummy, with a binomial response
# `r | trials(n)`.
#
# Args:
#   study_cols: names of the study dummy columns (default: colnames(S)).
#   trt_cols: names of the treatment dummy columns (default: colnames(B)).
# Returns:
#   A formula. Study terms come before treatment terms, so the study
#   coefficients are listed first in the fitted model's fixed effects.
full.model.f <- function(study_cols = colnames(S), trt_cols = colnames(B)) {
  # Build from a term vector so that an empty column set cannot produce a
  # dangling "+" in the formula text.
  rhs <- paste(c("0", study_cols, trt_cols), collapse = " + ")
  as.formula(paste("r | trials(n) ~", rhs))
}

# Formula for the partial model: a single shared intercept written as
# `0 + Intercept` (brms syntax for an intercept fitted as an ordinary
# coefficient) plus one coefficient per treatment dummy. There are no
# per-study terms.
#
# Args:
#   trt_cols: names of the treatment dummy columns (default: colnames(B)).
# Returns:
#   A formula with binomial response `r | trials(n)`.
partial.model.f <- function(trt_cols = colnames(B)) {
  rhs <- paste(c("0", "Intercept", trt_cols), collapse = " + ")
  as.formula(paste("r | trials(n) ~", rhs))
}

# Reference fit on ALL studies. estimate.with.proxy() centres the simulated
# proxy for a held-out study on that study's posterior-mean coefficient here.
true_effects <- fit.brm(
  data = smoking_with_dummies,
  f = full.model.f(),
  prior = prior(class = b, normal(0, 3))
)

# Posterior-mean treatment coefficients (unnamed numeric vector).
# full.model.f() lists the ncol(S) study terms before the treatment terms,
# so the treatment rows follow the study rows. Derive their positions
# rather than hard-coding 25:28.
theta <- unname(
  fixef(true_effects)[ncol(S) + seq_len(ncol(B)), "Estimate"]
)

# Score a fitted model on held-out data.
#
# Returns log(mean(exp(ll))), where `ll` is every entry of the pointwise
# log-likelihood from brms::log_lik(fit, newdata = data). The mean is taken
# jointly over posterior draws and target observations; it is not a sum over
# observations.
#
# The log-mean-exp is computed stably by factoring out max(ll). Very negative
# log-likelihoods (e.g. a weakly constrained held-out study coefficient)
# therefore do not all underflow exp() to 0 and turn the score into -Inf.
#
# Args:
#   fit: a brmsfit (or anything log_lik() accepts).
#   data: new data to evaluate the likelihood on.
# Returns:
#   A single numeric value.
target_log_lik <- function(fit, data) {
  ll <- log_lik(fit, newdata = data)
  ll_max <- max(ll)
  # All -Inf (or NA): nothing to rescale; the naive formula gives this too.
  if (!is.finite(ll_max)) {
    return(ll_max)
  }
  ll_max + log(mean(exp(ll - ll_max)))
}

# Fit a custom Stan program on the data/parameter layout brms builds for `f`.
#
# The function:
#   1. builds the brms Stan data for formula `f` on `data`;
#   2. adds every entry of `extra_data`;
#   3. samples the program in `stan_file` with rstan;
#   4. installs the draws into an un-sampled brmsfit (`empty = TRUE`) and
#      renames parameters, so brms post-processing such as log_lik() works
#      on the result.
# NOTE(review): assumes `stan_file` declares brms's generated data/parameter
# names plus the names in `extra_data`; confirm against the .stan files.
fit_custom_stan <- function(f, stan_file, data, extra_data) {
  stan_dat <- standata(
    f,
    data = data,
    family = binomial(),
    control = control_args,
    prior = prior(class = b, normal(0, 3)),
    silent = 2,
    refresh = 0
  )
  stan_dat[names(extra_data)] <- extra_data
  stan_code <- read_file(stan_file)
  skeleton <- fit.brm(
    data = data,
    f = f,
    prior = prior(class = b, normal(0, 3)),
    empty = TRUE
  )
  skeleton$fit <- stan(model_code = stan_code, data = stan_dat)
  rename_pars(skeleton)
}

# Leave-one-study-out comparison of the "classic" and "R-weighted" learners.
#
# Steps:
#   - Every study except `target_task` forms the source data.
#   - A proxy observation z ~ N(psi + proxy.bias, proxy.sd) is simulated,
#     where psi is the held-out study's posterior-mean coefficient in
#     `true_effects`.
#   - Both learners are fitted with their custom Stan programs
#     (full.model.f() + "classic.stan", partial.model.f() + "Rweighted.stan").
#   - Each learner is scored on the held-out study with target_log_lik().
#
# Args:
#   target_task: integer index of the held-out study. It is used both to
#     filter `study` and as a row index into fixef(true_effects). This relies
#     on the study coefficients being listed first, as full.model.f() does.
#   use.proxy: flag forwarded to both Stan programs as `use_proxy` (0/1).
#   proxy.bias: shift added to the proxy mean. The "misleading" scenario
#     draws it at random.
#   proxy.sd: standard deviation of the simulated proxy; also passed to Stan
#     as `proxy_sd`.
# Returns:
#   list(Rweighted = <scalar>, classic = <scalar>) target log-likelihoods.
estimate.with.proxy <- function(
    target_task, use.proxy = TRUE, proxy.bias = 0., proxy.sd = .1
) {
  source_data <- smoking_with_dummies %>%
    filter(study != target_task)
  target_data <- smoking_with_dummies %>%
    filter(study == target_task)

  psi.post.mu <- fixef(true_effects)[target_task, 1]
  z <- rnorm(1, mean = psi.post.mu + proxy.bias, sd = proxy.sd)

  # Stan has no boolean type, so pass the flag explicitly as 0/1.
  proxy_data <- list(
    z = z,
    use_proxy = as.integer(use.proxy),
    proxy_sd = proxy.sd
  )

  # Fit order (classic first) matches the original so RNG consumption is
  # unchanged.
  classic_fit <- fit_custom_stan(
    full.model.f(), "classic.stan", source_data,
    c(proxy_data, list(target_task_id = target_task))
  )
  Rweighted_fit <- fit_custom_stan(
    partial.model.f(), "Rweighted.stan", source_data, proxy_data
  )

  Rweighted_log_lik <- target_log_lik(Rweighted_fit, target_data)
  classic_log_lik <- target_log_lik(classic_fit, target_data)
  list(Rweighted = Rweighted_log_lik, classic = classic_log_lik)
}

# Run one proxy scenario over every held-out study.
#
# `estimate_one(i)` must return list(Rweighted = <scalar>, classic = <scalar>)
# for held-out study i (typically a call to estimate.with.proxy()). The
# results are collected as plain numeric vectors with one entry per study.
# The original loops indexed `[i, j]` with an undefined `j` and mixed
# one-column matrices with vectors.
run_proxy_scenario <- function(
    estimate_one, n_tasks = nlevels(smoking_with_dummies$study)
) {
  out <- list(Rweighted = numeric(n_tasks), classic = numeric(n_tasks))
  for (i in seq_len(n_tasks)) {
    task_res <- estimate_one(i)
    out$Rweighted[i] <- task_res$Rweighted
    out$classic[i] <- task_res$classic
  }
  out
}

# Misleading proxy information: a fresh N(0, 3) bias for each held-out study
# (drawn before that study's fit, as before).
misleading <- run_proxy_scenario(function(i) {
  proxy.bias <- rnorm(1, mean = 0., sd = 3.)
  estimate.with.proxy(
    i, use.proxy = TRUE, proxy.bias = proxy.bias, proxy.sd = 3
  )
})

# No proxy information
none <- run_proxy_scenario(function(i) {
  estimate.with.proxy(i, use.proxy = FALSE)
})

# Weakly informative proxy information (unbiased, sd = 3)
weak <- run_proxy_scenario(function(i) {
  estimate.with.proxy(i, use.proxy = TRUE, proxy.sd = 3)
})

# Highly informative proxy information (unbiased, default proxy.sd)
strong <- run_proxy_scenario(function(i) {
  estimate.with.proxy(i, use.proxy = TRUE)
})

# Plot ----
# Display label for each scenario; the order here is the x-axis order.
scenario_labels <- c(
  misleading = "Misleading",
  none = "None",
  weak = "Weakly informative",
  strong = "Highly informative"
)

# Stack the scenarios into one long table. as.vector() flattens each result
# whether it was stored as a plain vector or a one-column matrix, so
# bind_rows() sees consistent column types.
res <- list(
  misleading = misleading, none = none, weak = weak, strong = strong
) %>%
  lapply(function(scn) {
    tibble(
      Rweighted = as.vector(scn$Rweighted),
      classic = as.vector(scn$classic)
    )
  }) %>%
  bind_rows(.id = "proxy") %>%
  mutate(
    proxy = factor(
      unname(scenario_labels[proxy]),
      levels = unname(scenario_labels)
    )
  )

# Positive values: the R-weighted learner scores the held-out study higher.
ggplot(res, mapping = aes(x = proxy, y = Rweighted - classic)) +
  geom_violin() +
  xlab("Amount of proxy information") +
  ylab("Relative performance of r-weighted learner") +
  geom_hline(yintercept = 0.) +
  theme(
    axis.text = element_text(size = 12),
    axis.title = element_text(size = 15)
  )
