# NOTE: the former `rm(list = ls())` was removed on purpose. When this file is
# source()d it would silently delete every object in the caller's global
# environment; the script assigns all the globals it needs further down, so a
# workspace wipe is not required. Start a fresh R session for a clean run.


library(randRotation)
library(ggplot2)
library(gridExtra)
library(tidyr)
library(cowplot)

#' Simulate rejection rates (ROC curves) for five tests of a zero mean
#'
#' For every repeat and every level in `alphas` a fresh signal `f` is drawn
#' uniformly on the sphere of radius `rho_global` in R^d, and `m` columns of
#' observations `f + N(0, I_d / n)` are generated. Five tests are applied to
#' the same data set and their rejection indicators are averaged over the
#' repeats.
#'
#' @param m Number of columns (machines / trials) in each simulated data set.
#' @param d Dimension of the signal.
#' @param n Per-column sample size; the noise standard deviation is
#'   `1 / sqrt(n)`.
#' @param rho_global Euclidean norm of the simulated signal.
#' @param alphas Numeric vector of nominal levels in [0, 1].
#' @param repeats Number of Monte Carlo repetitions per level.
#' @return A long-format tibble with columns `FPR` (the nominal level),
#'   `Curve` and `Value` (rejection frequency). `Curve` takes the values
#'   `TPR.V1` (chi-square on all m * d coordinates), `TPR.V2` (public-coin
#'   directional test), `TPR.V3` (private-coin coordinate-wise test, fixed at
#'   0 when `d` does not divide `m`), `TPR.V4` (chi-square on the first column
#'   only), `TPR.V5` (chi-square on the pooled column sums) and `TPR.alphas`
#'   (the identity line, i.e. `Value == FPR`).
create_data <- function(m, d, n, rho_global, alphas, repeats) {
  stopifnot(
    m >= 1, d >= 1, n > 0, repeats >= 1,
    length(alphas) >= 1, all(alphas >= 0 & alphas <= 1)
  )

  # Draw a rotation matrix (determinant +1) from the Haar measure.
  haar_rotation_matrix <- function(d) {
    gauss <- matrix(rnorm(d^2), nrow = d)
    decomp <- qr(gauss)
    q <- qr.Q(decomp)

    # The raw Q factor is only Haar distributed once the QR decomposition is
    # made unique, i.e. once the diagonal of R is forced to be positive.
    # Rescale the columns of Q by the signs of diag(R) to achieve that.
    signs <- sign(diag(qr.R(decomp)))
    signs[signs == 0] <- 1
    q <- sweep(q, 2, signs, "*")

    # Flip one column when needed so that the determinant is +1.
    if (det(q) < 0) {
      q[, 1] <- -q[, 1]
    }
    q
  }

  # Split the indices 1..m into d consecutive groups of (near-)equal size.
  equal_subsets <- function(m, d) {
    idx <- seq_len(m)
    subset_size <- ceiling(length(idx) / d)
    split(idx, rep(seq_len(d), each = subset_size, length.out = length(idx)))
  }

  # Combined chi-square test: n times the sum of squares of all entries of x,
  # compared with the chi-square(d * m) quantile.
  t1 <- function(x, m, d, alpha) {
    stat <- n * sum(as.matrix(x)^2)
    stat > qchisq(1 - alpha, df = d * m)
  }

  # Public coin test: every column is projected on the same random direction
  # (first row of a random rotation); the normalised sum of projections is
  # squared and compared with the chi-square(1) quantile.
  t2 <- function(x, m, d, alpha) {
    u <- haar_rotation_matrix(d)
    # Only the first coordinate of u %*% x[, j] is used, so project on the
    # first row of u for all columns at once.
    proj <- drop(u[1, , drop = FALSE] %*% x)
    stat <- n * sum(proj)^2 / m
    stat > qchisq(1 - alpha, df = 1)
  }

  # Private coin test: coordinate i is summed over the i-th group of columns;
  # the rescaled squared sums are added and compared with chi-square(d).
  t3 <- function(x, m, d, alpha) {
    subsets <- equal_subsets(m, d)
    ss <- vapply(
      seq_len(d),
      function(i) sqrt(n) * sum(x[i, subsets[[i]]]),
      numeric(1)
    )
    stat <- sum((d / m) * ss^2)
    stat > qchisq(1 - alpha, df = d)
  }

  # The coordinate-wise private coin test is only run when d divides m.
  run_t3 <- (m %% d) == 0
  noise_sd <- 1 / sqrt(n)

  record2 <- array(NA_real_, dim = c(length(alphas), repeats, 5))
  for (r in seq_len(repeats)) {
    if (r %% 100 == 0) {
      print(r)
    }
    for (count in seq_along(alphas)) {
      alpha <- alphas[count]

      # Generate f uniformly on the sphere of radius rho_global.
      # Alternative (Rademacher signs):
      # f <- (2 * rbinom(n = d, size = 1, prob = 0.5) - 1) * rho_global / sqrt(d)
      V <- rnorm(n = d, mean = 0, sd = 1)
      f <- rho_global * V / sqrt(sum(V^2))

      # Generate data: column j is f plus independent N(0, 1 / n) noise.
      # f (length d) is recycled down each of the m columns.
      x <- matrix(
        f + rnorm(d * m, mean = 0, sd = noise_sd),
        nrow = d, ncol = m
      )

      record2[count, r, 1] <- t1(x, m, d, alpha)
      record2[count, r, 2] <- t2(x, m, d, alpha)
      record2[count, r, 3] <- if (run_t3) t3(x, m, d, alpha) else 0
      record2[count, r, 4] <- t1(x[, 1, drop = FALSE], 1, d, alpha)
      record2[count, r, 5] <- t1(as.matrix(rowSums(x) / sqrt(m)), 1, d, alpha)
    }
  }

  # Rejection frequency per level (rows) and test (columns).
  tpr <- apply(record2, c(1, 3), mean)

  # Appending `alphas` as a sixth column adds the identity line as its own
  # curve; the resulting columns are TPR.V1..TPR.V5 and TPR.alphas.
  dat <- data.frame(TPR = cbind(tpr, alphas), FPR = alphas)

  # Reshape data to long format
  tidyr::pivot_longer(
    dat,
    cols = tidyr::starts_with("TPR"),
    names_to = "Curve",
    values_to = "Value"
  )
}

# SIMULATION ----
# Shared settings for every run; d and rho_global are set per run below.
set.seed(2023)
m <- 20
n <- 30
alphas <- seq(from = 0.01, to = 0.99, by = 0.01)
repeats <- 10000

# Signal norm used for dimension d: 0.75 * d^(1/4) / sqrt(n).
signal_strength <- function(d) {
  0.75 * d^(1 / 4) / sqrt(n)
}

# SIMULATIONS FOR ARTICLE ----
dp1 <- create_data(m, d = 30, n, rho_global = signal_strength(30), alphas, repeats)
dp2 <- create_data(m, d = 60, n, rho_global = signal_strength(60), alphas, repeats)
dp3 <- create_data(m, d = 90, n, rho_global = signal_strength(90), alphas, repeats)
dp4 <- create_data(m, d = 120, n, rho_global = signal_strength(120), alphas, repeats)

# ggalt is only needed if geom_line() below is swapped for
# geom_xspline(spline_shape = -0.1) to smooth the curves.
library(ggalt)

# Legend labels, in the order in which the `Curve` levels are displayed.
curve_labels <- c('identity', 'Chi-square combined', "directional coordinated",
                  "directional uncoordinated", "single trial",
                  "Chi-square pooled")

# Build one ROC panel from the long-format output of create_data().
roc_panel <- function(data, show_legend = FALSE) {
  ggplot(data, aes(x = FPR, y = Value, color = Curve)) +
    geom_line(show.legend = show_legend) +
    labs(title = "", x = "False Positive Rate", y = "True Positive Rate") +
    scale_color_discrete(name = "Curve", labels = curve_labels) +
    theme_minimal()
}

p1 <- roc_panel(dp1)
p2 <- roc_panel(dp2)
p3 <- roc_panel(dp3)
p4 <- roc_panel(dp4)

# The panels hide their legends; take a single shared legend from a copy of
# one panel that is drawn with its legend enabled.
temp <- roc_panel(dp3, show_legend = TRUE)
legend <- cowplot::get_legend(temp)

# Write the figure next to the script when run from RStudio, otherwise into
# the current working directory. The working directory itself is not changed.
out_dir <- getwd()
if (requireNamespace("rstudioapi", quietly = TRUE) && rstudioapi::isAvailable()) {
  doc_path <- rstudioapi::getActiveDocumentContext()$path
  if (nzchar(doc_path)) {
    out_dir <- dirname(doc_path)
  }
}
out_file <- file.path(out_dir, "roc_curves_d_large.png")
message("Writing figure to ", out_file)

# arrangeGrob() builds the 2 x 2 layout without drawing it, so nothing is
# rendered on whatever device happens to be open before png() is called.
p <- gridExtra::arrangeGrob(p1, p2, p3, p4, ncol = 2)

png(filename = out_file, width = 580, height = 480)
invisible(tryCatch(
  grid.arrange(p, legend, ncol = 2, widths = c(4, 1)),
  # Close the device even when drawing fails, so it is not left open.
  finally = dev.off()
))


