library(dplyr)
library(tidyr)
library(parallel)
library(rstudioapi)
library(mvtnorm)
library(RColorBrewer)

# Run from the directory that contains this script, so that the relative
# output paths used further down ("results/...", "../images/...") resolve
# next to the source file.
# getActiveDocumentContext() needs a running RStudio session, and its `path`
# can be empty (e.g. the active document has never been saved). In either
# case keep the current working directory instead of aborting the script.
current_path <- if (rstudioapi::isAvailable()) {
  getActiveDocumentContext()$path
} else {
  ""
}
if (nzchar(current_path)) {
  setwd(dirname(current_path))
} else {
  message("Could not determine the script path; keeping the current working directory.")
}
# Show the directory that relative paths will be resolved against.
print(getwd())

# Simulation settings ----
seed <- 122
# Leave one core free, but never ask for fewer than one worker:
# detectCores() - 1 is 0 on a single-core machine and NA when the core count
# cannot be determined, and mclapply() requires mc.cores >= 1.
ncores <- max(1L, detectCores() - 1L, na.rm = TRUE)
# NOTE(review): num_producers is not referenced anywhere else in the visible
# script; kept in case other code relies on it.
num_producers <- 1000
num_queries <- 50000
num_positions <- 100

# With the default generator, set.seed() in the parent process does not make
# random draws inside forked mclapply() workers reproducible. Selecting
# "L'Ecuyer-CMRG" before seeding gives each forked worker its own RNG stream
# derived from this seed (reproducible for a given number of cores).
RNGkind("L'Ecuyer-CMRG")
set.seed(seed)

# Simulate a single query and rank its items with both methods.
#
# Arguments:
#   num_positions        number of items to rank (a whole number >= 1).
#   treatment_proportion probability, in [0, 1], that an item is assigned to
#                        treatment (treatment_assignment == 1).
#   rho                  correlation, in [-1, 1], between an item's control
#                        score and its treatment score.
#
# Returns a tibble with one row per item and the columns
#   treatment_assignment  0 (control) or 1 (treatment),
#   unicorn_rank          rank of the item's own-arm rank, ties broken at random,
#   unicorn_lite_rank     row position after the in-place re-ordering below,
#   expected_rank         the item's rank under its own arm
#                         (control_rank or treatment_rank).
# Rows are returned in the order that defines unicorn_lite_rank.
run_experiment <- function(num_positions, treatment_proportion, rho) {
  # Fail early with a clear message instead of an obscure downstream error.
  stopifnot(
    is.numeric(num_positions), length(num_positions) == 1L,
    !is.na(num_positions), num_positions >= 1,
    num_positions == floor(num_positions),
    is.numeric(treatment_proportion), length(treatment_proportion) == 1L,
    !is.na(treatment_proportion),
    treatment_proportion >= 0, treatment_proportion <= 1,
    is.numeric(rho), length(rho) == 1L, !is.na(rho), abs(rho) <= 1
  )

  # One (control, treatment) score pair per item: unit variances, correlation
  # rho, and rmvnorm()'s default mean.
  scores <- rmvnorm(num_positions, sigma = rbind(c(1, rho), c(rho, 1)))
  prob <- c(1 - treatment_proportion, treatment_proportion)

  results <- tibble(
    control_score = scores[, 1],
    treatment_score = scores[, 2],
    treatment_assignment = sample(c(0, 1), num_positions, replace = TRUE, prob = prob)
  ) %>%
    mutate(
      # Rank 1 is the highest score within each arm.
      control_rank = rank(-control_score),
      treatment_rank = rank(-treatment_score),
      # Each item keeps the rank from the arm it was assigned to ...
      unicorn_reverse_score = treatment_assignment * treatment_rank +
        (1 - treatment_assignment) * control_rank,
      # ... and those ranks are merged into one ordering, breaking ties
      # (a control and a treated item sharing a rank) at random.
      unicorn_rank = rank(unicorn_reverse_score, ties.method = "random")
    ) %>%
    arrange(control_rank)

  # Start from the control ordering, then re-order only the treated items
  # among the slots they already occupy so that they follow treatment_rank.
  treated <- results$treatment_assignment == 1
  results[treated, ] <- results[treated, ] %>% arrange(treatment_rank)

  results %>%
    mutate(
      unicorn_lite_rank = seq_len(num_positions),
      expected_rank = unicorn_reverse_score
    ) %>%
    select(-control_score, -treatment_score, -control_rank, -treatment_rank, -unicorn_reverse_score)
}


# Summarise how far each ranking method lands from the expected rank.
#
# `results` must contain `treatment_assignment`, `expected_rank` and one
# column per ranking method; every other column is treated as a method.
# Returns one row per (method, expected_rank) with the mean and standard
# deviation of the signed error (observed - expected), plus RMSE and MAE.
get_ranking_error <- function(results) {
  # Long format: one row per (item, method) with that method's observed rank.
  long_ranks <- gather(results, "method", "observed_rank",
                       -treatment_assignment, -expected_rank)
  long_ranks <- mutate(long_ranks, error = observed_rank - expected_rank)

  by_method_and_rank <- group_by(long_ranks, method, expected_rank)
  error_summary <- summarise(
    by_method_and_rank,
    error_average = mean(error),
    error_sd = sd(error),
    RMSE = sqrt(mean(error^2)),
    MAE = mean(abs(error))
  )

  ungroup(error_summary)
}

# Accuracy simulation ----
# For every (rho, treatment_proportion) combination: simulate `num_queries`
# queries in parallel, summarise the per-rank error of both ranking methods,
# and append the summary (tagged with the settings that produced it) to
# `errors`.
errors <- NULL

for (rho in c(-1.0, -0.4, 0.2, 0.8)) {
  for (treatment_proportion in c(0.1, 0.5)) {
    cat("Running experiment for treatment_proportion =", treatment_proportion, "and rho =", rho, "\n")
    ptm <- proc.time()[3]
    runs <- mclapply(
      seq_len(num_queries),
      function(q) run_experiment(num_positions, treatment_proportion, rho),
      mc.cores = ncores
    )
    # mclapply() only warns when a worker fails; the affected elements are
    # then not data frames (e.g. "try-error" objects or NULL). Stop here
    # rather than folding them into the results.
    failed <- !vapply(runs, is.data.frame, logical(1))
    if (any(failed)) {
      stop(sum(failed), " of ", num_queries,
           " simulated queries failed (treatment_proportion = ",
           treatment_proportion, ", rho = ", rho, ").", call. = FALSE)
    }
    # bind_rows() stacks the per-query tibbles in a single pass.
    results <- bind_rows(runs)
    errors <- rbind(errors, cbind(rho, treatment_proportion, get_ranking_error(results)))
    cat("Done. Time taken (in seconds) =", proc.time()[3] - ptm, "\n")
  }
}

# Make sure the output directory exists so the finished simulation is not
# lost to a failing save().
dir.create("results", showWarnings = FALSE, recursive = TRUE)
save(errors, file = "results/experiment_accuracy.Rdata")

library(ggplot2)


# Ranking-error figure ----
# RMSE and MAE against expected rank: one colour per method, one line type
# per rho, one panel per (treatment_proportion, error type).
# The plot object is built before the PDF device is opened, so a failure
# while assembling it does not leave a device open.
error_plot <- errors %>%
  mutate(method = factor(method, levels = c("unicorn_rank", "unicorn_lite_rank"), labels = c("UniCoRn(1)", "UniCoRn(0)")),
         rho = factor(rho)) %>%
  gather("error_type", "error", RMSE, MAE) %>%
  ggplot(aes(x = expected_rank, y = error, color = method, linetype = rho)) +
  geom_line() +
  scale_color_manual(name = "method", values = brewer.pal(3, "Set1")) +
  scale_linetype_discrete(name = expression(rho)) +
  facet_grid(~treatment_proportion * error_type, scales="free_y", labeller = label_both) + theme_bw() +
  xlab("expected rank") +
  theme(plot.margin = margin(0.3, 0.3, 0.3, 0.3, "cm"),
        axis.text.x = element_text(size = 22),
        axis.text.y = element_text(size = 22),
        axis.title.y = element_text(size = 22),
        axis.title.x = element_text(size = 22),
        strip.text.x = element_text(size = 22),
        strip.text.y = element_text(size = 22),
        legend.title = element_text(size = 22),
        legend.position = "bottom",
        legend.text = element_text(size = 22))

# pdf() fails if the target directory is missing.
dir.create("../images", showWarnings = FALSE, recursive = TRUE)
pdf("../images/ranking_errors.pdf", width = 16, height = 5)
# ggplot objects are drawn only when printed. Top-level auto-printing does
# not happen when the script is run through source(), so print explicitly;
# otherwise the PDF is written without the figure.
print(error_plot)
dev.off()

