library(dplyr)
library(tidyr)
library(parallel)
library(rstudioapi)
library(mvtnorm)
library(RColorBrewer)

# Locate this script so the relative paths used below ("results/", "../images/")
# resolve against the script's own directory.
# Inside RStudio the path comes from the active document; outside RStudio
# (e.g. Rscript) getActiveDocumentContext() would abort, so fall back to the
# --file= argument that Rscript passes on the command line.
current_path <- if (rstudioapi::isAvailable()) {
  getActiveDocumentContext()$path
} else {
  sub("^--file=", "",
      grep("^--file=", commandArgs(trailingOnly = FALSE), value = TRUE)[1])
}
# An unsaved RStudio document reports "" and a missing --file= gives NA; in
# both cases keep the current working directory instead of failing in setwd().
if (!is.na(current_path) && nzchar(current_path)) {
  setwd(dirname(current_path))
}
# Print the working directory so you can confirm it is the expected one.
print(getwd())

# Simulation settings ----
seed <- 123
# Leave one core free, but never ask for fewer than one worker: detectCores()
# may return 1 (giving 0 here) or NA, and mclapply() rejects mc.cores < 1.
ncores <- max(1L, detectCores() - 1L, na.rm = TRUE)
num_producers <- 1000
num_queries <- 50000
num_positions <- 100

# With the default generator, set.seed() does not carry over to the forked
# mclapply() workers. Selecting L'Ecuyer-CMRG gives each worker its own
# stream derived from this seed, so parallel runs can be reproduced.
RNGkind("L'Ecuyer-CMRG")
set.seed(seed)

#' Simulate one query and measure the cost/inaccuracy trade-off of partial
#' re-ranking.
#'
#' Draws correlated control and treatment scores for `num_positions` items and
#' randomly assigns each item to treatment or control. The target value of an
#' item (`unicorn_reverse_score`) is its treatment rank if it is treated and its
#' control rank otherwise. For every re-ranking rate on the fixed grid
#' `seq(0, 1, 0.1)`, the items start in control order; all treated items plus a
#' random share (the rate) of the remaining items are then re-sorted among the
#' positions they occupy, and the mean squared gap between each item's final
#' position and its target value is recorded.
#'
#' @param num_positions Number of items (ranking positions) in the query.
#' @param treatment_proportion Probability that an item is assigned to
#'   treatment.
#' @param rho Correlation between the control and the treatment score.
#' @param alpha Ignored. The rate grid is fixed inside the function; the
#'   argument is kept only so existing calls keep working.
#' @return A tibble with one row per rate and the columns
#'   `treatment_proportion`, `alpha` (the rate), `cost` (expected number of
#'   re-ranked items) and `inaccuracy` (mean squared difference between final
#'   position and target value).
run_experiment <- function(num_positions, treatment_proportion, rho,
                           alpha = NULL) {
  # The random draws happen in a fixed order (scores, assignment, tie-breaking,
  # then one rbinom() per rate) so that a given seed reproduces the result.
  scores <- rmvnorm(num_positions, sigma = rbind(c(1, rho), c(rho, 1)))
  treatment_assignment <- sample(
    c(0, 1), num_positions,
    replace = TRUE,
    prob = c(1 - treatment_proportion, treatment_proportion)
  )

  # Rank 1 is the highest score.
  control_rank <- rank(-scores[, 1])
  treatment_rank <- rank(-scores[, 2])
  unicorn_reverse_score <- treatment_assignment * treatment_rank +
    (1 - treatment_assignment) * control_rank
  # Treated and control items can share a target value; break ties at random
  # so that unicorn_rank is a strict ordering.
  unicorn_rank <- rank(unicorn_reverse_score, ties.method = "random")

  # Item indices listed by control position, and which of those positions
  # hold a treated item. Both are the same for every rate.
  control_order <- order(control_rank)
  treated_in_control_order <- treatment_assignment[control_order] == 1
  positions <- seq_len(num_positions)

  alphas <- seq(0, 1, 0.1)
  inaccuracy <- vapply(alphas, function(rate) {
    # Positions that get re-ranked: every treated item plus a Bernoulli(rate)
    # sample of all positions.
    reranked <- rbinom(num_positions, 1, rate) | treated_in_control_order
    # The selected items are re-sorted by unicorn_rank within the positions
    # they already occupy; all other items keep their control position.
    final_order <- control_order
    selected <- control_order[reranked]
    final_order[reranked] <- selected[order(unicorn_rank[selected])]
    mean((positions - unicorn_reverse_score[final_order])^2)
  }, numeric(1))

  tibble(
    treatment_proportion = treatment_proportion,
    alpha = alphas,
    cost = num_positions *
      (treatment_proportion + alphas * (1 - treatment_proportion)),
    inaccuracy = inaccuracy
  )
}

# Run num_queries independent simulated queries for every
# (rho, treatment_proportion) configuration and stack the per-query results.
errors <- NULL
for (rho in c(0.8)) {
  for (treatment_proportion in seq(0.1, 0.5, 0.1)) {
    cat("Running experiment for treatment_proportion =", treatment_proportion, "and rho =", rho, "\n")
    ptm <- proc.time()[3]
    per_query <- mclapply(
      seq_len(num_queries),
      function(q) run_experiment(num_positions, treatment_proportion, rho),
      mc.cores = ncores
    )
    # mclapply() does not stop on worker failures: a failed job shows up as a
    # "try-error" (or NULL) element. Fail loudly instead of silently binding
    # an incomplete result.
    succeeded <- vapply(per_query, is.data.frame, logical(1))
    if (!all(succeeded)) {
      stop(sum(!succeeded), " of ", length(per_query),
           " simulated queries failed for treatment_proportion = ",
           treatment_proportion, " and rho = ", rho, call. = FALSE)
    }
    # bind_rows() stacks the whole list in one pass, which is far cheaper than
    # do.call("rbind", ...) on tens of thousands of small tibbles.
    errors <- bind_rows(errors, bind_rows(per_query))
    cat("Done. Time taken (in seconds) =", proc.time()[3] - ptm, "\n")
  }
}

# Average the per-query results into one point per
# (treatment_proportion, alpha) combination. ungroup() drops the grouping that
# summarise() would otherwise leave on treatment_proportion.
errors <- errors %>%
  group_by(treatment_proportion, alpha) %>%
  summarise(inaccuracy = mean(inaccuracy), cost = mean(cost)) %>%
  ungroup()

# save() does not create missing directories; make sure the target exists so
# the simulation output is not lost at the very last step.
dir.create("results", showWarnings = FALSE, recursive = TRUE)
save(errors, file = "results/trade-off.Rdata")

# Discrete version of the treatment proportion, used as the colour of the lines
# in the plot.
errors$TP <- factor(errors$treatment_proportion)


library(ggplot2)

# Plot inaccuracy against cost, one line per treatment proportion, and write
# the figure to ../images/trade-off.pdf.
# pdf() cannot open a file inside a directory that does not exist yet.
dir.create("../images", showWarnings = FALSE, recursive = TRUE)
pdf("../images/trade-off.pdf", width = 4, height = 4)
label_font <- element_text(size = 22)
trade_off_plot <- errors %>%
  ggplot(aes(x = cost, y = inaccuracy, color = TP)) +
  geom_line() +
  theme(plot.margin = margin(0.3, 0.3, 0.3, 0.3, "cm"),
        axis.text.x = label_font,
        axis.text.y = label_font,
        axis.title.y = label_font,
        axis.title.x = label_font,
        strip.text.x = label_font,
        strip.text.y = label_font,
        legend.title = label_font,
        legend.text = label_font)
# Print explicitly: a ggplot object is only drawn by auto-printing, which does
# not happen when this file is run through source(), leaving an empty PDF.
print(trade_off_plot)
dev.off()

