library(dplyr)
library(tidyr)
library(parallel)
library(rstudioapi)

# Run from the folder that contains this script so that the relative paths used
# later ("results/", "../images/") resolve. The script's location comes from
# RStudio's API, which only works inside a running RStudio session; elsewhere
# (e.g. Rscript) the current working directory is kept instead of erroring.
if (rstudioapi::isAvailable()) {
  current_path <- getActiveDocumentContext()$path
  # An unsaved document reports an empty path, and setwd("") would fail.
  if (nzchar(current_path)) {
    setwd(dirname(current_path))
  }
}
# Show which directory the relative paths will resolve against.
print(getwd())

# Simulation settings
seed <- 123             # base seed; replicate i calls set.seed(seed * i)
num_experiments <- 100  # number of simulation replicates
# Cores for parallel work: one fewer than available, never below 1
# (detectCores() can return NA, or 1 on a single-core machine).
ncores <- max(1, detectCores() - 1, na.rm = TRUE)
num_producers <- 1000
num_sessions <- 1000
num_positions <- 100    # slots per simulated ranked session

viet_break_ties <- function(x, max_value){
  # Turn a vector of desired slot positions into strictly increasing
  # positions whose last element does not exceed max_value.
  #
  # x:         numeric positions (the caller in this script passes sort(...)).
  # max_value: largest allowed position, i.e. the number of slots.
  # Returns x where every collision is pushed to the next slot. If that
  # overflows max_value, the tail is pinned at max_value and shifted left
  # until it no longer collides with the element before it.
  #
  # Uses seq_len() rather than 1:(n - 1) / n:2. Those ranges misbehave for
  # length-0 or length-1 inputs (1:0 is c(1, 0)) and made the function stop
  # with "missing value where TRUE/FALSE needed".
  n <- length(x)
  if (n == 0) {
    return(x)
  }

  # Forward pass: each position must be strictly greater than the previous one.
  for (i in seq_len(n - 1)) {
    if (x[i + 1] <= x[i]) {
      x[i + 1] <- x[i] + 1
    }
  }

  # Backward pass: if we ran past the last slot, pull the tail back in.
  if (x[n] > max_value) {
    x[n] <- max_value
    for (i in rev(seq_len(n)[-1])) {  # n, n - 1, ..., 2
      if (x[i - 1] < x[i]) {
        break
      }
      x[i - 1] <- x[i] - 1
    }
  }
  x
}


rank_to_feedback <- function(rank) {
  # Feedback received at a given (vectorised) rank: 100 / log(rank + 10)^2,
  # which decreases as rank grows.
  discount <- log(rank + 10)^2
  100 / discount
}

feedback_to_response <- function(feedback, quality, num_sessions, function_type){
  # Producer response: an aggregate of the feedback it received, scaled by
  # its quality. "mean_function" aggregates with mean(); any other value
  # of function_type aggregates with max().
  # num_sessions is part of the signature but is not used here.
  summarise_feedback <- if (function_type == "mean_function") mean else max
  summarise_feedback(feedback) * quality
}

run_experiment <- function(producer, num_positions){
  # Simulate one ranked session. `num_positions` slots are filled with
  # producers drawn with replacement from `producer`, and the slots are ranked
  # under every policy. Returns one row per slot with each policy's rank and
  # its rank_to_feedback() value.
  #
  # producer: table with producer_id, quality, treatment_assignment (0/1) and
  #           viet_assignment (0/1/2) columns, as built by the driver loop.
  # nrow(producer) replaces the global num_producers (same value in this
  # script), so the draw depends only on the arguments.
  results <- producer[sample(nrow(producer), num_positions, replace = TRUE), ] %>%
    # Control adds U(0, 1) noise to quality; treatment adds U(0, quality).
    mutate(control_score = quality + runif(num_positions, 0, 1),
           treatment_score = quality + runif(num_positions, 0, quality)) %>%
    mutate(control_rank = rank(-control_score),
           treatment_rank = rank(-treatment_score)) %>%
    mutate(control_feedback = rank_to_feedback(control_rank),
           treatment_feedback = rank_to_feedback(treatment_rank)) %>%
    # OASIS: rank on each arm's score normalised by that arm's total score.
    mutate(oasis_score = treatment_assignment * treatment_score / sum(treatment_score) +
             (1 - treatment_assignment) * control_score / sum(control_score)) %>%
    mutate(oasis_rank = rank(-oasis_score)) %>%
    mutate(oasis_feedback = rank_to_feedback(oasis_rank)) %>%
    # UniCoRn(1): each producer keeps its own arm's rank; ties broken at random.
    mutate(unicorn_reverse_score = treatment_assignment * treatment_rank +
             (1 - treatment_assignment) * control_rank) %>%
    mutate(unicorn_rank = rank(unicorn_reverse_score, ties.method = "random")) %>%
    mutate(unicorn_feedback = rank_to_feedback(unicorn_rank))
  n_slots <- nrow(results)

  # UniCoRn(0): start from the control order, then re-sort only the treated
  # producers among the slots they already occupy, by treatment rank.
  results <- results %>% arrange(control_rank)
  treated <- results$treatment_assignment == 1
  results[treated, ] <- results[treated, ] %>% arrange(treatment_rank)
  results <- results %>%
    mutate(unicorn_lite_rank = seq_len(n_slots)) %>%
    mutate(unicorn_lite_feedback = rank_to_feedback(unicorn_lite_rank))

  # UniCoRn(0.2): as above, but the re-sorted set is every treated producer
  # plus each other slot with probability 0.2, ordered by unicorn_rank.
  results <- results %>% arrange(control_rank)
  rerankingIndices <- (rbinom(nrow(results), 1, 0.2) | results$treatment_assignment == 1)
  results[rerankingIndices, ] <- results[rerankingIndices, ] %>% arrange(unicorn_rank)
  results <- results %>%
    mutate(unicorn_lite2_rank = seq_len(n_slots)) %>%
    mutate(unicorn_lite2_feedback = rank_to_feedback(unicorn_lite2_rank))

  # HaThucEtAl: viet-assigned producers go to the slots given by their own
  # arm's rank. Group 1 is drawn from control and uses control_rank; group 2
  # is drawn from treatment and uses treatment_rank. Collisions are pushed
  # apart by viet_break_ties(). Everyone else fills the remaining slots in
  # control order.
  in_viet <- results$viet_assignment %in% c(1, 2)
  viet_positions <- viet_break_ties(sort(c(results$control_rank[results$viet_assignment == 1],
                                           results$treatment_rank[results$viet_assignment == 2])),
                                    max_value = n_slots)
  # setdiff() instead of -viet_positions: negating an empty index selects no
  # rows at all rather than every row.
  other_positions <- setdiff(seq_len(n_slots), viet_positions)

  results2 <- results
  # Sort each subset by its OWN column. The previous keys were
  # `unicorn_rank[results$viet_assignment %in% c(1, 2)]` and the control_rank
  # analogue. They indexed the subset's column with the full-table mask, which
  # produced a misaligned, NA-padded sort key.
  results2[viet_positions, ] <- results[in_viet, ] %>% arrange(unicorn_rank)
  results2[other_positions, ] <- results[!in_viet, ] %>% arrange(control_rank)

  results2 %>%
    mutate(viet_rank = seq_len(n_slots)) %>%
    mutate(viet_feedback = rank_to_feedback(viet_rank))
}

get_treatment_effect_estimates <- function(results){
  # Turn the stacked per-slot session results into a named vector:
  #   treatment_effect       - mean treatment response minus mean control
  #                            response over all producers that appeared
  #   oasis/unicorn/unicorn_lite/unicorn_lite2 _estimate
  #                          - treated-arm mean minus control-arm mean of
  #                            each policy's response
  #   viet_estimate          - viet group 2 mean minus viet group 1 mean
  # num_sessions and function_type are not arguments; they are looked up in
  # the enclosing (global) environment, where the driver loop sets them.
  respond <- function(feedback, producer_quality) {
    feedback_to_response(feedback, producer_quality, num_sessions, function_type)
  }

  per_producer <- results %>%
    group_by(producer_id) %>%
    summarise(
      treatment_assignment = first(treatment_assignment),
      viet_assignment = first(viet_assignment),
      control_response = respond(control_feedback, first(quality)),
      treatment_response = respond(treatment_feedback, first(quality)),
      oasis_response = respond(oasis_feedback, first(quality)),
      unicorn_response = respond(unicorn_feedback, first(quality)),
      unicorn_lite_response = respond(unicorn_lite_feedback, first(quality)),
      unicorn_lite2_response = respond(unicorn_lite2_feedback, first(quality)),
      viet_response = respond(viet_feedback, first(quality))
    )

  true_effect <- mean(per_producer$treatment_response) -
    mean(per_producer$control_response)

  viet_group <- per_producer$viet_assignment
  viet_effect <- mean(per_producer$viet_response[viet_group == 2]) -
    mean(per_producer$viet_response[viet_group == 1])

  arm_means <- per_producer %>%
    group_by(treatment_assignment) %>%
    summarise(oasis_average = mean(oasis_response),
              unicorn_average = mean(unicorn_response),
              unicorn_lite_average = mean(unicorn_lite_response),
              unicorn_lite2_average = mean(unicorn_lite2_response))

  # Signed sum over arms: +mean for the treated arm (1), -mean for control (0).
  arm <- arm_means$treatment_assignment
  contrast <- function(average) sum(arm * average - (1 - arm) * average)

  policy_estimates <- c(
    oasis_estimate = contrast(arm_means$oasis_average),
    unicorn_estimate = contrast(arm_means$unicorn_average),
    unicorn_lite_estimate = contrast(arm_means$unicorn_lite_average),
    unicorn_lite2_estimate = contrast(arm_means$unicorn_lite2_average)
  )

  c(treatment_effect = true_effect, policy_estimates, viet_estimate = viet_effect)
}

# Driver: for every replicate, response function and treatment proportion,
# draw producers and assignments, simulate num_sessions sessions, and record
# one row of estimates. `function_type` is deliberately a top-level loop
# variable, because get_treatment_effect_estimates() reads it from the global
# environment.
function_types <- c("mean_function", "max_function")
treatment_proportions <- c(0.1, 0.5)
# Rows are collected in a preallocated list and bound once at the end,
# instead of growing a matrix with rbind() on every iteration.
estimate_rows <- vector("list", num_experiments * length(function_types) * length(treatment_proportions))
run_index <- 0

for (i in seq_len(num_experiments)){
  for (function_type in function_types){
    for (treatment_proportion in treatment_proportions){
      cat("Running experiment", i, "for treatment_proportion =", treatment_proportion, "and function_type =", function_type, "\n")
      ptm <- proc.time()[3]
      # Every setting within replicate i starts from the same RNG state.
      set.seed(seed * i)
      producer <- tibble(producer_id = 1:num_producers,
                         quality = rbeta(num_producers, 2, 5),
                         treatment_assignment = 0,
                         viet_assignment = 0)
      producer$treatment_assignment[sample(1:num_producers, ceiling(num_producers * treatment_proportion))] <- 1
      # Viet groups: 1 is drawn from control producers, 2 from treated ones.
      idx1 <- sample((1:num_producers)[producer$treatment_assignment == 0], ceiling(num_producers * 0.1))
      idx2 <- sample((1:num_producers)[producer$treatment_assignment == 1], ceiling(num_producers * 0.1))
      producer$viet_assignment[idx1] <- 1
      producer$viet_assignment[idx2] <- 2
      # mc.cores = 1 runs the sessions serially; ncores is not used here.
      results <- do.call("rbind", mclapply(seq_len(num_sessions), function(q) run_experiment(producer, num_positions),
                                           mc.cores = 1))
      run_index <- run_index + 1
      # Column 2 stores the response function as an integer code
      # (1 = mean_function, 2 = max_function). The old code got the same value
      # implicitly, because c() coerces a non-leading factor to its codes.
      estimate_rows[[run_index]] <- c(treatment_proportion,
                                      match(function_type, function_types),
                                      get_treatment_effect_estimates(results))
      cat("Done. Time taken (in seconds) =", proc.time()[3] - ptm, "\n")
    }
  }
}

estimates <- do.call(rbind, estimate_rows)
colnames(estimates)[c(1,2)] <- c("treatment_proportion", "function_type")

# save() fails if the output directory does not exist yet.
dir.create("results", showWarnings = FALSE, recursive = TRUE)
save(estimates, file = "results/treatment_effect_estimation.Rdata")


library(ggplot2)


# Boxplots of each method's estimation error, faceted by response function and
# treatment proportion. The error is the estimate minus the true treatment
# effect, averaged over all runs of that response function.
# The plot is built first and then print()ed explicitly. A bare ggplot
# expression is not auto-printed when the script is source()d, which would
# leave the PDF empty.
dir.create("../images", showWarnings = FALSE, recursive = TRUE)

estimation_error_plot <- as_tibble(estimates) %>%
  pivot_longer(-c(treatment_proportion, function_type, treatment_effect),
               names_to = "method", values_to = "estimate") %>%
  mutate(treatment_proportion = factor(treatment_proportion),
         # function_type holds the integer codes 1 (mean) and 2 (max).
         response_func = factor(function_type, labels = c("avg_fn", "max_fn"))) %>%
  group_by(function_type) %>%
  mutate(treatment_effect = mean(treatment_effect)) %>%
  ungroup() %>%
  mutate(error = estimate - treatment_effect) %>%
  ggplot(aes(x = factor(method, levels = c("unicorn_lite_estimate", "unicorn_lite2_estimate", "unicorn_estimate","viet_estimate", "oasis_estimate"),
                        labels = c("UniCoRn(0)", "UniCoRn(0.2)", "UniCoRn(1)", "HaThucEtAl", "OASIS")), y = error)) +
  geom_boxplot() +
  facet_grid(response_func ~ treatment_proportion,
             labeller = labeller(response_func = label_value, treatment_proportion = label_both)) +
  xlab("method") + ylab("error") + theme_bw() + scale_y_continuous(breaks = c(-1, 0, 1)) +
  theme(plot.margin = margin(0, 0, 0, 0, "cm"),
        axis.text.x = element_text(size = 25, angle = 30, hjust = 1),
        axis.text.y = element_text(size = 25),
        axis.title.y = element_text(size = 25),
        axis.title.x = element_blank(),
        strip.text.x = element_text(size = 25),
        strip.text.y = element_text(size = 25),
        legend.position = "bottom",
        legend.title = element_text(size = 25),
        legend.text = element_text(size = 25)) +
  geom_hline(yintercept = 0, linetype = "dashed", color = "grey")

pdf("../images/estimation_errors.pdf", width = 14, height = 6)
print(estimation_error_plot)
dev.off()


# Bar chart of the hard-coded cost value per method, one bar per treatment
# proportion (TP). The plot is print()ed explicitly, because a bare ggplot
# expression is not auto-printed when the script is source()d.
cost <- tibble(method = factor(rep(c("UniCoRn(0)", "UniCoRn(0.2)", "UniCoRn(1)", "HaThucEtAl", "OASIS"), 2),
                               levels = c("UniCoRn(0)", "UniCoRn(0.2)", "UniCoRn(1)", "HaThucEtAl", "OASIS")),
               TP = factor(rep(c(0.1, 0.5), each = 5)),
               cost = c(110, 128, 200, 200, 200, 150, 160, 200, 200, 200))

dir.create("../images", showWarnings = FALSE, recursive = TRUE)

# `method` is already a factor with the desired level order.
cost_plot <- cost %>%
  ggplot(aes(x = method, y = cost, fill = TP)) +
  geom_bar(stat = "identity", position = position_dodge()) +
  theme(plot.margin = margin(0, 0, 0, 0, "cm"),
        axis.text.x = element_text(size = 25, angle = 30, hjust = 1, vjust = 1),
        axis.text.y = element_text(size = 25),
        axis.title.y = element_text(size = 25),
        axis.title.x = element_blank(),
        strip.text.x = element_text(size = 25),
        strip.text.y = element_text(size = 25),
        legend.title = element_text(size = 25),
        legend.text = element_text(size = 25))

pdf("../images/cost.pdf", width = 7, height = 6)
print(cost_plot)
dev.off()
