# Two-sample testing simulation for random dot product graphs (RDPG).
# Three statistics -- SBM-TS (sbm_tst_old), ASE (fast_mmd on ase() embeddings)
# and NCLM (distance between log_moment() vectors) -- are computed under a null
# and an alternative scenario and compared with ROC curves.
#
# NOTE(review): rm(list = ls()) and the hard-coded setwd() make the script
# non-portable; they are kept because every source()/sourceCpp() path below is
# relative to the project root "~/two sample test".
rm(list=ls())
setwd("~/two sample test")
source("R/accuracy.R")
source("R/matching.R")
source("R/generate.R")
source("R/checks.R")
source("R/hypot.R")
source("R/ASE.R")
source("R/roc_revisions.R")
source("R/get_labels.R")
source("R/nclm.R")
library(nett)
library(Rcpp)
library(Matrix)
library(RcppArmadillo)
library(parallel)
library(pbmcapply)
library(ggplot2)

Rcpp::sourceCpp("src/models.cpp")

# Leave one core free for the rest of the system. detectCores() may return NA
# (undetectable) or 1 (single core), and mclapply() rejects mc.cores < 1, so
# clamp to at least one worker.
core_count <- max(1L, parallel::detectCores() - 1L, na.rm = TRUE)
num_reps <- 100  # Monte Carlo replicates per scenario

# Values passed to sample_rdpg(): n is the number of latent-position rows drawn
# per graph (presumably one node per row -- TODO confirm in sample_rdpg());
# lambda's exact role is defined in the sourced code.
n <- 10000
lambda <- 0.15

# Latent-position covariance cov_matrix = V diag(5, 1) V^T: variance 5 along
# (1, 1)/sqrt(2) and variance 1 along (1, -1)/sqrt(2).
eigenvectors <- matrix(c(1, 1, 1, -1), nrow = 2) / sqrt(2)
eigenvalues <- diag(c(5, 1))
cov_matrix <- eigenvectors %*% eigenvalues %*% t(eigenvectors)

# Rotation matrix for 90 degrees counterclockwise: (x, y) -> (-y, x).
rotation_matrix <- matrix(c(0, 1, -1, 0), nrow = 2)

# Latent-position sampling recipes used inline below (reference only):
# X <- matrix(rnorm(n * 2), ncol = 2) %*% chol(cov_matrix)
# X2 <- matrix(rnorm(n * 2), ncol = 2) %*% chol(cov_matrix)
# XO <- matrix(rnorm(n * 2), ncol = 2) %*% chol(rotation_matrix %*% cov_matrix %*% t(rotation_matrix))
# Y_alt <- matrix(rnorm(n * 2), ncol = 2)
K <- 2        # number of clusters passed to spec_clust()
d_ase <- 2    # embedding dimension passed to ase()
d_nclm <- 10  # second argument passed to log_moment()

################################################################################
# Section 1 -- null: both graphs draw latent positions from N(0, cov_matrix)
# (rows of standard normals times chol(cov_matrix)). Per replicate we record:
#   sbmts_stat -- sbm_tst_old() on spec_clust() labels with K clusters,
#   ase_stat   -- biased fast_mmd() between d_ase-dimensional ase() embeddings,
#   nclm_stat  -- squared L2 norm of the difference of log_moment() vectors.
#
# Reproducibility: a master set.seed() does not make mclapply() reproducible
# under the default Mersenne-Twister RNG (each forked child re-initialises its
# seed), so every replicate seeds itself from 1000 + rep_id. This also makes
# the draws independent of core_count.
res <- do.call(rbind, mclapply(seq_len(num_reps), function(rep_id) {
  set.seed(1000 + rep_id)
  A1 <- sample_rdpg(matrix(rnorm(n * 2), ncol = 2) %*% chol(cov_matrix), lambda)
  A2 <- sample_rdpg(matrix(rnorm(n * 2), ncol = 2) %*% chol(cov_matrix), lambda)

  z1hat <- spec_clust(A1, K, niter = 50)
  z2hat <- spec_clust(A2, K, niter = 50)

  # Alternative labelling, kept for reference:
  #z1hat <- get_labels(A1, K, alpha, beta, niter)
  #z2hat <- get_labels(A2, K, alpha, beta, niter)

  That <- sbm_tst_old(A1, A2, z1hat, z2hat)
  ase_dist <- fast_mmd(ase(A1, d_ase), ase(A2, d_ase), nBasis = 2^9, sigvec = 0.5)$biased
  g1 <- log_moment(A1, d_nclm)
  g2 <- log_moment(A2, d_nclm)
  nclm <- l2norm_squared(g1 - g2)

  #log_progress()
  data.frame(rep = rep_id,
             sbmts_stat = That,
             ase_stat = ase_dist,
             nclm_stat = nclm)
}, mc.cores = core_count))

# Section 1 -- alternative: A1 is drawn as under the null, but A2's latent
# positions are i.i.d. N(0, I) (no chol() factor), so the two latent
# distributions differ (covariance eigenvalues 5, 1 vs 1, 1).
# Each replicate seeds itself from 1000 + num_reps + rep_id: a master
# set.seed() does not make mclapply() reproducible under the default RNG, and
# this range never overlaps the null replicates' 1000 + rep_id seeds.
res_alt <- do.call(rbind, mclapply(seq_len(num_reps), function(rep_id) {
  set.seed(1000 + num_reps + rep_id)
  A1 <- sample_rdpg(matrix(rnorm(n * 2), ncol = 2) %*% chol(cov_matrix), lambda)
  A2 <- sample_rdpg(matrix(rnorm(n * 2), ncol = 2), lambda)

  z1hat <- spec_clust(A1, K, niter = 50)
  z2hat <- spec_clust(A2, K, niter = 50)

  That <- sbm_tst_old(A1, A2, z1hat, z2hat)
  # nBasis = 2^9 matches every other fast_mmd() call in this script; without
  # it the null and alternative ASE statistics would be computed with
  # different settings and the ROC comparison would be meaningless.
  ase_dist <- fast_mmd(ase(A1, d_ase), ase(A2, d_ase), nBasis = 2^9, sigvec = 0.5)$biased
  g1 <- log_moment(A1, d_nclm)
  g2 <- log_moment(A2, d_nclm)
  nclm <- l2norm_squared(g1 - g2)

  data.frame(rep = rep_id,
             sbmts_stat = That,
             ase_stat = ase_dist,
             nclm_stat = nclm)
}, mc.cores = core_count))

# ROC curves comparing each statistic under the null (res) and the
# alternative (res_alt); get_roc() comes from the sourced R/ files.
sbm_roc <- get_roc(res$sbmts_stat, res_alt$sbmts_stat)
ase_roc <- get_roc(res$ase_stat, res_alt$ase_stat)
nclm_roc <- get_roc(res$nclm_stat, res_alt$nclm_stat)

# Colour-blind-friendly palette.
cbbPalette <- c("#56B4E9",  "#E69F00", "#000000", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7")
# The +/-0.01 vertical offsets keep otherwise-overlapping curves visible.
ggplot(NULL) +
  geom_line(aes(x = sbm_roc$FPR, y = sbm_roc$TPR + 0.01, colour = "SBM-TS"), size = 2) +
  geom_line(aes(x = ase_roc$FPR, y = ase_roc$TPR - 0.01, colour = "ASE"), size = 2) +
  geom_line(aes(x = nclm_roc$FPR, y = nclm_roc$TPR - 0.01, colour = "NCLM"), size = 2) +
  # Chance-level reference diagonal (drawn once, dashed).
  ggplot2::geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  ggplot2::scale_colour_manual(values = cbbPalette) +
  ggplot2::scale_x_continuous(limits = c(-0.01, 1.01), expand = c(0, 0)) +
  ggplot2::scale_y_continuous(limits = c(-0.01, 1.02), expand = c(0, 0)) +
  ggplot2::coord_fixed(ratio = 1) +
  ggplot2::guides(colour = ggplot2::guide_legend(keywidth = 4, keyheight = 1)) +
  xlab("FPR") +
  ylab("TPR") +
  ggplot2::theme_bw() +
  ggplot2::theme(
    text = ggplot2::element_text(size = 16),
    legend.background = ggplot2::element_blank(),
    legend.title = ggplot2::element_blank(),
    legend.position = c(0.7, 0.2),
    legend.text = ggplot2::element_text(size = 15)
  )


# Histograms of the SBM-TS and ASE statistics, null (res) then alternative
# (res_alt). Each plot is a separate top-level expression so it auto-prints.
ggplot(data = res) +
  geom_histogram(mapping = aes(x = sbmts_stat), bins = 20, colour = "black", fill = "white") +
  xlab("SBM Test") +
  ggplot2::theme(text = ggplot2::element_text(size = 18))

ggplot(data = res_alt) +
  geom_histogram(mapping = aes(x = sbmts_stat), bins = 20, colour = "black", fill = "white") +
  xlab("SBM Test") +
  ggplot2::theme(text = ggplot2::element_text(size = 18))

ggplot(data = res) +
  geom_histogram(mapping = aes(x = ase_stat), bins = 20, colour = "black", fill = "white") +
  xlab("d_ASE") +
  ggplot2::theme(text = ggplot2::element_text(size = 18))

ggplot(data = res_alt) +
  geom_histogram(mapping = aes(x = ase_stat), bins = 20, colour = "black", fill = "white") +
  xlab("d_ASE") +
  ggplot2::theme(text = ggplot2::element_text(size = 18))

################################################################################
# Section 2 -- null with a rotation nuisance. In a random half of the
# replicates A2's latent covariance is cov_matrix rotated by 90 degrees
# (rotation_matrix %*% cov_matrix %*% t(rotation_matrix)); otherwise both
# graphs use cov_matrix. NOTE(review): this is treated as a null case,
# presumably because RDPG latent positions are identifiable only up to an
# orthogonal transform -- confirm against sample_rdpg().
#
# Reproducibility: a master set.seed() does not make mclapply() reproducible
# under the default Mersenne-Twister RNG (forked children re-initialise their
# seed), so every replicate seeds itself from 100 + rep_id.
res <- do.call(rbind, mclapply(seq_len(num_reps), function(rep_id) {
  set.seed(100 + rep_id)
  # Draw order (coin, A1, A2) matches the original branching code.
  rotate_A2 <- runif(1, 0, 1) > 0.5
  cov_A2 <- if (rotate_A2) {
    rotation_matrix %*% cov_matrix %*% t(rotation_matrix)
  } else {
    cov_matrix
  }
  A1 <- sample_rdpg(matrix(rnorm(n * 2), ncol = 2) %*% chol(cov_matrix), lambda)
  A2 <- sample_rdpg(matrix(rnorm(n * 2), ncol = 2) %*% chol(cov_A2), lambda)

  z1hat <- spec_clust(A1, K, niter = 50)
  z2hat <- spec_clust(A2, K, niter = 50)

  # Alternative labelling, kept for reference:
  #z1hat <- get_labels(A1, K, alpha, beta, niter)
  #z2hat <- get_labels(A2, K, alpha, beta, niter)

  That <- sbm_tst_old(A1, A2, z1hat, z2hat)
  ase_dist <- fast_mmd(ase(A1, d_ase), ase(A2, d_ase), nBasis = 2^9, sigvec = 0.5)$biased
  g1 <- log_moment(A1, d_nclm)
  g2 <- log_moment(A2, d_nclm)
  nclm <- l2norm_squared(g1 - g2)

  #log_progress()
  data.frame(rep = rep_id,
             sbmts_stat = That,
             ase_stat = ase_dist,
             nclm_stat = nclm)
}, mc.cores = core_count))


# Section 2 -- alternative: A1's latent positions are N(0, cov_matrix) while
# A2's are i.i.d. N(0, I) (no chol() factor), so the latent distributions
# differ. Each replicate seeds itself from 100 + num_reps + rep_id, which is
# reproducible under mclapply() and never overlaps the null's 100 + rep_id.
res_alt <- do.call(rbind, mclapply(seq_len(num_reps), function(rep_id) {
  set.seed(100 + num_reps + rep_id)
  A1 <- sample_rdpg(matrix(rnorm(n * 2), ncol = 2) %*% chol(cov_matrix), lambda)
  A2 <- sample_rdpg(matrix(rnorm(n * 2), ncol = 2), lambda)

  z1hat <- spec_clust(A1, K, niter = 50)
  z2hat <- spec_clust(A2, K, niter = 50)

  # Alternative labelling, kept for reference:
  #z1hat <- get_labels(A1, K, alpha, beta, niter)
  #z2hat <- get_labels(A2, K, alpha, beta, niter)

  That <- sbm_tst_old(A1, A2, z1hat, z2hat)
  ase_dist <- fast_mmd(ase(A1, d_ase), ase(A2, d_ase), nBasis = 2^9, sigvec = 0.5)$biased
  g1 <- log_moment(A1, d_nclm)
  g2 <- log_moment(A2, d_nclm)
  nclm <- l2norm_squared(g1 - g2)

  #log_progress()
  data.frame(rep = rep_id,
             sbmts_stat = That,
             ase_stat = ase_dist,
             nclm_stat = nclm)
}, mc.cores = core_count))


# Section 2 ROC curves: each statistic under the rotation-nuisance null (res)
# versus the alternative (res_alt). cbbPalette is the colour palette defined
# earlier in the script for the section-1 ROC plot.
sbm_roc <- get_roc(res$sbmts_stat, res_alt$sbmts_stat)
ase_roc <- get_roc(res$ase_stat, res_alt$ase_stat)
nclm_roc <- get_roc(res$nclm_stat, res_alt$nclm_stat)

# The +/-0.01 vertical offsets keep otherwise-overlapping curves visible; the
# y limit of 1.02 leaves room for the +0.01 SBM-TS curve, as in section 1.
ggplot(NULL) +
  geom_line(aes(x = sbm_roc$FPR, y = sbm_roc$TPR + 0.01, colour = "SBM-TS"), size = 2) +
  geom_line(aes(x = ase_roc$FPR, y = ase_roc$TPR - 0.01, colour = "ASE"), size = 2) +
  geom_line(aes(x = nclm_roc$FPR, y = nclm_roc$TPR - 0.01, colour = "NCLM"), size = 2) +
  # Chance-level reference diagonal (drawn once, dashed).
  ggplot2::geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  ggplot2::scale_colour_manual(values = cbbPalette) +
  ggplot2::scale_x_continuous(limits = c(-0.01, 1.01), expand = c(0, 0)) +
  ggplot2::scale_y_continuous(limits = c(-0.01, 1.02), expand = c(0, 0)) +
  ggplot2::coord_fixed(ratio = 1) +
  ggplot2::guides(colour = ggplot2::guide_legend(keywidth = 4, keyheight = 1)) +
  xlab("FPR") +
  ylab("TPR") +
  ggplot2::theme_bw() +
  ggplot2::theme(
    text = ggplot2::element_text(size = 16),
    legend.background = ggplot2::element_blank(),
    legend.title = ggplot2::element_blank(),
    legend.position = c(0.7, 0.2),
    legend.text = ggplot2::element_text(size = 15)
  )
