# Start from an empty global workspace. Note that rm() only removes objects;
# attached packages and options() from an earlier session are left in place.
rm(list=ls())
# Move to the project root: the relative paths used by source() and
# sourceCpp() below ("R/...", "src/models.cpp") are resolved against it.
# NOTE(review): hard-coded, user-specific path — adjust per machine.
setwd("~/two sample test")
# Project helpers. Together with the packages below, these are expected to
# provide sample_rdpg(), spec_clust(), sbm_tst_old(), ase(), fast_mmd() and
# get_roc() used later in this script — which file defines which is not
# visible here (TODO confirm).
source("R/matching.R")
source("R/hypot.R")
source("R/ASE.R")
source("R/roc_revisions.R")
library(nett)
library(Rcpp)
library(Matrix)
library(RcppArmadillo)
library(parallel)
library(ggplot2)

# Compile and load the C++ routines in src/models.cpp.
Rcpp::sourceCpp("src/models.cpp")
# Parallel settings ----
# detectCores() may return NA, and on a single-core machine
# detectCores() - 1 is 0; mclapply() needs mc.cores >= 1, so clamp to 1.
core_count <- max(1L, detectCores() - 1L, na.rm = TRUE)
num_reps <- 100

# Simulation parameters ----
n <- 10000       # rows (latent positions) in X, Y and Y_alt
lambda <- 0.015  # passed to sample_rdpg(); presumably an edge-density scale
eigenvectors <- matrix(c(1, 1, 1, -1), nrow = 2) / sqrt(2)
eigenvalues <- diag(c(5, 1))

# Covariance matrix V diag(5, 1) t(V); works out to [[3, 2], [2, 3]].
cov_matrix <- eigenvectors %*% eigenvalues %*% t(eigenvectors)
# Rotation matrix for 90 degrees counterclockwise: [[0, -1], [1, 0]]
# (matrix() fills column-major).
rotation_matrix <- matrix(c(0, 1, -1, 0), nrow = 2)

set.seed(123)
# Rows of Z %*% chol(S), with Z iid standard normal, have covariance S
# because chol() returns the upper factor R with t(R) %*% R == S.
X <- matrix(rnorm(n * 2), ncol = 2) %*% chol(cov_matrix)
# Same construction with the covariance rotated by rotation_matrix, so the
# rows of Y are distributed like rotated rows of X.
Y <- matrix(rnorm(n * 2), ncol = 2) %*%
  chol(rotation_matrix %*% cov_matrix %*% t(rotation_matrix))
# Independent standard-normal positions (identity covariance).
Y_alt <- matrix(rnorm(n * 2), ncol = 2)
K <- 2      # passed to spec_clust(); presumably the number of communities
d_ase <- 2  # passed to ase(); presumably the embedding dimension


# Null ----
# Each replicate pairs a graph drawn from X with a second graph drawn, with
# probability 1/2 each, from the rotated positions Y or from X itself.
# mclapply() (default mc.set.seed = TRUE, default RNG kind) reseeds every
# forked worker, so set.seed(123) above does not make these replicates
# reproducible. Each replicate is therefore seeded explicitly; seeds
# 1..num_reps are used here. This assumes sample_rdpg()/spec_clust() draw
# from R's RNG (TODO confirm for the C++ code).
res <- do.call(rbind, mclapply(seq_len(num_reps), function(rep_id) {
  set.seed(rep_id)
  A1 <- sample_rdpg(X, lambda)
  # Same draw order as before: A1, then the coin flip, then A2.
  second_from_y <- runif(1) > 0.5
  A2 <- if (second_from_y) sample_rdpg(Y, lambda) else sample_rdpg(X, lambda)

  z1hat <- spec_clust(A1, K, niter = 50)
  z2hat <- spec_clust(A2, K, niter = 50)

  # Presumably the two-sample SBM test statistic on the estimated labels.
  That <- sbm_tst_old(A1, A2, z1hat, z2hat)
  #OracleThat <- sbm_tst_old(A1, A2, z1, z2)
  #g1 <- log_moment(A1, d_nlcm)
  #g2 <- log_moment(A2, d_nlcm)
  #nlcm_dist <- l2norm_squared(g1 - g2)
  # Presumably an MMD distance between the two adjacency spectral embeddings.
  ase_dist <- fast_mmd(
    ase(A1, d_ase), ase(A2, d_ase),
    nBasis = 2^9, sigvec = 1
  )$biased

  #log_progress()
  data.frame(rep = rep_id, sbmts_stat = That, ase_stat = ase_dist)
}, mc.cores = core_count))

# Alternative ----
# The second graph is always drawn from the independent standard-normal
# positions Y_alt. mclapply() reseeds its forked workers by default, so each
# replicate is seeded explicitly; seeds num_reps + 1 .. 2 * num_reps keep
# these runs reproducible and disjoint from seeds 1..num_reps. This assumes
# sample_rdpg()/spec_clust() draw from R's RNG (TODO confirm).
res_alt <- do.call(rbind, mclapply(seq_len(num_reps), function(rep_id) {
  set.seed(num_reps + rep_id)
  A1 <- sample_rdpg(X, lambda)
  A2 <- sample_rdpg(Y_alt, lambda)

  z1hat <- spec_clust(A1, K, niter = 50)
  z2hat <- spec_clust(A2, K, niter = 50)

  # Presumably the two-sample SBM test statistic on the estimated labels.
  That <- sbm_tst_old(A1, A2, z1hat, z2hat)
  #OracleThat <- sbm_tst_old(A1, A2, z1, z2)
  #g1 <- log_moment(A1, d_nlcm)
  #g2 <- log_moment(A2, d_nlcm)
  #nlcm_dist <- l2norm_squared(g1 - g2)
  # The ROC compares this statistic against the null runs, so it must be
  # computed with the same settings there (nBasis = 2^9, sigvec = 1).
  ase_dist <- fast_mmd(
    ase(A1, d_ase), ase(A2, d_ase),
    nBasis = 2^9, sigvec = 1
  )$biased

  data.frame(rep = rep_id, sbmts_stat = That, ase_stat = ase_dist)
}, mc.cores = core_count))

# ROC curves: statistics from the null runs (res) versus the alternative
# runs (res_alt), for both the SBM test and the ASE distance.
sbm_roc <- get_roc(res$sbmts_stat, res_alt$sbmts_stat)
ase_roc <- get_roc(res$ase_stat, res_alt$ase_stat)

# Stack both curves into one long data frame so ggplot maps real columns
# (instead of free-floating `$` vectors) and the legend comes from `method`.
roc_df <- rbind(
  data.frame(FPR = sbm_roc$FPR, TPR = sbm_roc$TPR, method = "SBM Test"),
  data.frame(FPR = ase_roc$FPR, TPR = ase_roc$TPR, method = "ASE")
)

ggplot(roc_df, aes(x = FPR, y = TPR, colour = method)) +
  geom_line() +
  xlim(0, 1) +
  ylim(0, 1) +
  geom_abline() +
  labs(x = "FPR", y = "TPR", colour = "Method") +
  # Same null/alternative naming as the histogram titles.
  ggtitle("(X,Y_mixture) against (X,Y_alt)")


# Histogram of one statistic column.
#   df    - a simulation result data frame (res or res_alt)
#   stat  - the column to plot, unquoted (embraced via {{ }} in aes())
#   title - plot title
# bins = 30 is ggplot2's default, stated explicitly to silence stat_bin()'s
# "pick a better value" message.
plot_stat_hist <- function(df, stat, title) {
  ggplot(df, aes(x = {{ stat }})) +
    geom_histogram(color = "black", fill = "white", bins = 30) +
    ggtitle(title)
}

# Each call returns a ggplot visibly, so it is printed at top level.
plot_stat_hist(res, sbmts_stat,
               "Distribution of T under null, (X,Y_mixture)")
plot_stat_hist(res_alt, sbmts_stat,
               "Distribution of T under alternative")
plot_stat_hist(res, ase_stat,
               "Distribution of d_ASE under null, (X,Y_mixture)")
plot_stat_hist(res_alt, ase_stat,
               "Distribution of d_ASE under alternative")
################################################################################



