# Script setup: reset the workspace, move to the project directory, load the
# project helper files and the packages used below.

# NOTE(review): rm(list = ls()) wipes every object in the calling workspace and
# setwd() hard-codes a machine-specific path; both make the script unsafe to
# source() from another session. Prefer running in a fresh R process from the
# project root.
rm(list=ls())
# The relative "R/..." paths below are resolved against this directory.
setwd("~/two sample test")
# Project-local helpers. get_fpr_tpr() and get_roc(), which are called later in
# this script, are not defined here -- presumably they come from these files
# (TODO confirm which one).
source("R/accuracy.R")
source("R/matching.R")
source("R/generate.R")
source("R/checks.R")
source("R/hypot.R")
source("R/agg_functions.R")
source("R/ASE.R")
source("R/FPR_TPR.R")
source("R/roc_revisions.R")
# Third-party packages; ggplot2 is used for the ROC figure at the end.
library(nett)
library(ggplot2)
library(igraph)

#' Area under a curve by the trapezoidal rule.
#'
#' The curve is given as paired points `(x[i], y[i])` in any order. Points are
#' ordered jointly by `x` (ties broken by `y`) so every `y` stays attached to
#' its own `x`. For a monotone non-decreasing curve such as a ROC curve this
#' gives the same result as sorting both vectors independently; for any other
#' curve it is the only ordering that yields the correct area.
#'
#' @param x Numeric vector of abscissae (e.g. false positive rates).
#' @param y Numeric vector of ordinates (e.g. true positive rates), the same
#'   length as `x`.
#' @return A single number: the trapezoidal area. Returns 0 when fewer than
#'   two complete points are supplied.
calculate_auc <- function(x, y) {
  stopifnot(length(x) == length(y))

  # Drop incomplete pairs together; dropping NAs per vector (as sort() does)
  # would silently shift the pairing.
  complete <- !(is.na(x) | is.na(y))
  x <- x[complete]
  y <- y[complete]

  ord <- order(x, y)
  x <- x[ord]
  y <- y[ord]

  # Sum of trapezoids: width * (left height + right height) / 2.
  sum(diff(x) * (head(y, -1) + tail(y, -1))) / 2
}
# K selects which results sub-directory (".../res/new/K<K>/") is read.
K <- 15

# Null and alternative results live side by side in the K-specific directory;
# neither file carries a header row.
res <- read.csv(
  paste0("~/two sample test/res/new/K", K, "/null.csv"),
  header = FALSE
)
res_alt <- read.csv(
  paste0("~/two sample test/res/new/K", K, "/alt.csv"),
  header = FALSE
)

# Both tables share one column layout, so the names are written once and
# applied to `res` first, then to `res_alt`.
colnames(res_alt) <- colnames(res) <- c(
  "seed", "iter", "tstat",
  "ase_bw001", "ase_bw01", "ase_bw1", "ase_bw10", "ase_bw100",
  "j2", "j3", "j4", "j5", "j10", "j15", "j20"
)
################################################################################
# test
# Quick visual sanity check on a single seed (1001): compute the ROC of the
# "tstat" column with both helper implementations and plot each result.
cur_null <-res[res$seed == 1001,]
cur_alt <- res_alt[res_alt$seed == 1001,]

# get_fpr_tpr() and get_roc() are defined outside this file. From the way the
# results are used below, the former exposes `fpr`/`tpr` and the latter
# `FPR`/`TPR` components.
o <- get_fpr_tpr(cur_null$tstat, cur_alt$tstat)
o2 <- get_roc(cur_null$tstat, cur_alt$tstat)
# Auto-print the get_roc() result for inspection.
o2
# The two scatter plots are presumably meant to look alike -- TODO confirm.
plot(o2$FPR, o2$TPR)
plot(o$fpr, o$tpr)
################################################################################
# Choosing best bandwidth
#' Average ROC curve of one result column across seeds.
#'
#' For every seed, the rows of that seed are taken from the null and the
#' alternative table, `get_roc()` is called on the chosen column, and the TPR
#' is read off at the FPR grid 0, 0.01, ..., 1. The per-seed TPR vectors are
#' then averaged.
#'
#' @param colname Name of the column (a single string) holding the statistic.
#' @param seeds Seeds to average over. Defaults to the original hard-coded
#'   range `1000:1049`.
#' @param null_data,alt_data Data frames with a `seed` column and a column
#'   named `colname`. Default to the script-level `res` and `res_alt`.
#' @param tol Absolute tolerance used when locating a grid value in the FPR
#'   vector returned by `get_roc()`.
#' @return A data frame with 101 rows and columns `fpr` (the grid) and `tpr`
#'   (the seed-averaged TPR). `tpr` is `NA` at grid points that some seed's
#'   ROC does not contain; a warning is raised in that case.
get_avg_roc <- function(colname, seeds = 1000:1049,
                        null_data = res, alt_data = res_alt,
                        tol = 1e-8) {
  stopifnot(length(seeds) > 0)

  fpr <- seq(0, 1, 0.01)
  # Grid values to look up, computed as i / 100 exactly like the original loop.
  targets <- (0:100) / 100
  method_tpr <- rep(0, length(targets))

  for (seed in seeds) {
    cur_null <- null_data[null_data$seed == seed, ]
    cur_alt <- alt_data[alt_data$seed == seed, ]

    method_roc <- get_roc(as.matrix(cur_null[colname]),
                          as.matrix(cur_alt[colname]))
    roc_fpr <- method_roc$FPR

    # First ROC point whose FPR matches each grid value. Exact `==` on doubles
    # can miss (e.g. 7 * 0.01 != 7 / 100), which used to yield min(integer(0))
    # = Inf and hence a silent NA in the average.
    method_idx <- vapply(targets, function(target) {
      hits <- which(abs(roc_fpr - target) <= tol)
      if (length(hits) > 0) hits[1L] else NA_integer_
    }, integer(1))

    if (anyNA(method_idx)) {
      warning("seed ", seed, ": ", sum(is.na(method_idx)),
              " FPR grid value(s) not found in the ROC for column '",
              colname, "'; the averaged TPR will be NA there.",
              call. = FALSE)
    }

    # Divide by the actual number of seeds rather than a hard-coded 50.
    method_tpr <- method_tpr + method_roc$TPR[method_idx] / length(seeds)
  }

  data.frame(fpr, tpr = method_tpr)
}

# Seed-averaged ROC of the raw test statistic (not used in the figure below).
a <- get_avg_roc("tstat")
# Seed-averaged ROC curves for result columns 9:15 (the "j*" columns).
df_list <- lapply(colnames(res)[9:15], get_avg_roc)

################################################################################
# Figure: one averaged ROC curve per "j" setting, with the chance diagonal.
cbbPalette <- c("#56B4E9",  "#E69F00", "#000000", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7")

# Legend labels, in the same order as the curves stored in df_list.
j_labels <- c("2", "3", "4", "5", "10", "15", "20")
stopifnot(length(df_list) == length(j_labels))

# Stack the curves into one long data frame so a single geom_line() layer
# draws them all.
roc_long <- do.call(rbind, Map(function(curve, label) {
  data.frame(fpr = curve$fpr, tpr = curve$tpr, j = label)
}, df_list, j_labels))
# An ordered factor keeps the legend (and palette assignment) in numeric order
# 2, 3, 4, 5, 10, 15, 20; plain strings would sort as "10", "15", "2", "20", ...
roc_long$j <- factor(roc_long$j, levels = j_labels)

ggplot(roc_long, aes(x = fpr, y = tpr, colour = j)) +
  geom_line(size = 2) +
  # Dashed chance line. A second, default geom_abline() used to be added at
  # the end of the chain and painted a solid line over this one.
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  scale_colour_manual(values = cbbPalette) +
  scale_x_continuous(limits = c(-0.01, 1.01), expand = c(0, 0)) +
  scale_y_continuous(limits = c(-0.01, 1.01), expand = c(0, 0)) +
  coord_fixed(ratio = 1) +
  theme_bw() +
  theme(
    # The original set text size 18 and then overrode it with 16; only the
    # effective value is kept.
    text = element_text(size = 16),
    legend.background = element_blank(),
    legend.title = element_blank(),
    legend.position = c(0.7, 0.2),
    legend.text = element_text(size = 15)
  ) +
  guides(colour = guide_legend(keywidth = 4, keyheight = 1)) +
  xlab("FPR") +
  ylab("TPR")
