library(tictoc)
source("KCD.R")
source("weighted_graph.R")
source("KCE.R")

# ---- Sample layout ----
# n   : number of observations per simulated sample
# n0, n1 : passed to all_stats_general() as `n0` / `n1`
# sd  : standard deviation of the multiplicative noise in the 'noise' scenario
# cp  : index at which the data-generating mechanism changes
n <- 1000
n0 <- 50
n1 <- 950
sd <- 1
cp <- 700

# ---- Monte Carlo sizes ----
simulations <- 20   # independent simulated samples
bootstrap_n <- 500  # random re-orderings per simulated sample

# ---- Per-simulation result holders (filled inside the simulation loop) ----
# One slot per simulation; NA marks "not computed yet".
KCD_power <- RKCD_power <- DY_power <- DXY_power <- fixed_power <-
  rep(NA, simulations)

# Statistic value (`*_val`) and estimated location (`*_loc`) for each method.
KCD_val <- KCD_loc <- rep(NA, simulations)
DXY_val <- DXY_loc <- rep(NA, simulations)
RKCD_val <- RKCD_loc <- rep(NA, simulations)
DY_val <- DY_loc <- rep(NA, simulations)
fixed_val <- fixed_loc <- rep(NA, simulations)

# ---- Kernel bandwidths ----
# Used as exp(-squared_distance / h) when the kernel matrices are built.
hx_RKCD <- hy_RKCD <- 0.1
hx_KCD <- hy_KCD <- 0.1
hx_DXY <- hy_DXY <- 0.1
hy_DY <- 0.1

# ---- Scenario selector ----
# One of 'noise', 'vector', 'graph', 'distribution'.
type <- 'noise'

# Monte Carlo study. Each replicate
#   (1) simulates a sample (x, y) whose mechanism for y changes around index
#       `cp` (the scenario is chosen by `type`),
#   (2) evaluates the statistics returned by all_stats_general() on the
#       observed ordering, and
#   (3) re-evaluates them on `bootstrap_n` random re-orderings of the sample to
#       build a reference distribution for each statistic.
# Results are written into the preallocated `*_val`, `*_loc` and `*_power`
# vectors at position j.
for (j in seq_len(simulations)){
  
  print(paste("this is the ", j, '-th simulation'))
  
  
  # ------------------- simulate one sample ---------------------
  if (type == 'noise'){ # y is scalar: noise change
    
    x = rnorm(n)
    noise = rnorm(n, sd = sd)
    truth = x
    truth[cp:n] = truth[cp:n] + 0.05 / (truth[cp:n] + 3)
    # truth[cp:n] = 2 * truth[cp:n]
    # truth[cp:n] = truth[cp:n]^2
    # truth[cp:n] = 0.01 * exp(truth[cp:n]) + truth[cp:n]
    # truth[cp:n] = 0.01 * max(1 - truth[cp:n], 0) + truth[cp:n]
    y = truth * noise
    
    dy = as.matrix(dist(y)^2)
    
  }else if (type == 'vector'){ # y is vector
    
    # NOTE(review): `lambda` is not defined in this script; it has to exist
    # already (e.g. set by a sourced file) for this branch to run.
    x = matrix(rnorm(5 * n), nrow = n, ncol = 5)
    y = matrix(rnorm(5 * n), nrow = n, ncol = 5)
    
    tmp = apply(x, 1, sum)
    for (i in 1:5){
      y[cp:n, i] = y[cp:n, i] + 0.1 * i * tmp[cp:n] * lambda
    }
    
    dy = as.matrix(dist(y)^2)
    
  }else if (type == 'graph'){ # y is graph
    x = matrix(rnorm(5 * n), nrow = n, ncol = 5)
    # The random fill is overwritten below; the draw is kept so the random
    # number stream is unchanged.
    y = matrix(rnorm(5 * 5 * n), nrow = n, ncol = 25)
    
    # Row i of y stores the 5 x 5 matrix x_i x_i^T (row-major); after the
    # change point a cubic interaction term is added to every entry.
    for (i in seq_len(n)){
      for (rowId in 0:4){
        for (colId in 1:5){
          entry = x[i, rowId + 1] * x[i, colId]
          if (i > cp){
            entry = entry + 0.1 * (x[i, rowId + 1])^3 * (x[i, colId])^3
          }
          y[i, rowId * 5 + colId] = entry
        }
      }
    }
    
    dy = as.matrix(dist(y)^2)
    
  }else if (type == 'distribution'){ # y is distribution
    
    library(transport)
    
    # NOTE(review): `lambda` is not defined in this script; it has to exist
    # already (e.g. set by a sourced file) for this branch to run.
    x = matrix(rnorm(5 * n), nrow = n, ncol = 5)
    # The random fill is overwritten below; the draw is kept so the random
    # number stream is unchanged.
    y = matrix(rnorm(10 * n), nrow = n, ncol = 10)
    
    # Each row of y is a size-10 sample; after the change point its mean
    # depends on the row mean of x.
    for (i in seq_len(n)){
      if (i <= cp){
        y[i,] = rnorm(n = 10, mean = 0, sd = 1)
      }
      if (i > cp){
        y[i,] = rnorm(n = 10, mean = mean(x[i,]) * lambda, sd = 1)
      }
    }  
    
    # Squared 2-Wasserstein distance between every pair of rows. The distance
    # is symmetric and zero on the diagonal, so only the upper triangle is
    # computed and then mirrored.
    dy = matrix(0, nrow = n, ncol = n)
    for (m in seq_len(n - 1)){
      for (l in (m + 1):n){
        dy[m, l] = wasserstein1d(y[m,], y[l,], p = 2)^2
        dy[l, m] = dy[m, l]
      }
    }
    
  }else{
    stop("unknown simulation type: ", type)
  }
  
  # ------------------- observed statistics ---------------------
  dx = as.matrix(dist(x)^2)
  Kx_RKCD = exp( - dx / hx_RKCD)
  Ky_RKCD = exp( - dy / hy_RKCD)
  Kx_KCD = exp( - dx / hx_KCD)
  Ky_KCD = exp( - dy / hy_KCD)
  K_DXY = exp( - dx / hx_DXY) * exp( - dy / hy_DXY)
  K_DY = exp( - dy / hy_DY)
  
  # NOTE(review): `h_fixed` is not defined in this script, and
  # all_stats_general() is not defined here either (presumably both come from
  # the sourced files) -- confirm before running.
  this_res = all_stats_general(Kx_RKCD, Ky_RKCD, Kx_KCD, Ky_KCD, K_DY, K_DXY, n0, n1, h_fixed = h_fixed, Clive = TRUE)
  
  KCD_val[j] = this_res$KCD_val
  KCD_loc[j] = this_res$KCD_loc
  
  DXY_val[j] = this_res$DXY_val
  DXY_loc[j] = this_res$DXY_loc
  
  RKCD_val[j] = this_res$RKCD_val
  RKCD_loc[j] = this_res$RKCD_loc
  
  DY_val[j] = this_res$DY_val
  DY_loc[j] = this_res$DY_loc
  
  fixed_val[j] = this_res$fixed_val
  fixed_loc[j] = this_res$fixed_loc
  
  # ------------------- calculate power ---------------------
  KCD_power_seq = rep(NA, bootstrap_n)
  RKCD_power_seq = rep(NA, bootstrap_n)
  DY_power_seq = rep(NA, bootstrap_n)
  DXY_power_seq = rep(NA, bootstrap_n)
  # NOTE(review): never filled below (the permutation call does not pass
  # `h_fixed`), so `fixed_power` stays NA.
  fixed_power_seq = rep(NA, bootstrap_n)
  
  for (r in seq_len(bootstrap_n)){
    tic(paste("this is the", r, "-th bootstrap repetition"))
    
    # Re-order the paired observations (x_i, y_i) jointly. Pairwise distances
    # of the re-ordered sample are exactly the original distance matrices with
    # rows and columns permuted, so they are indexed rather than recomputed.
    # This works for scalar and matrix-valued x / y alike, and avoids redoing
    # the n^2 Wasserstein computations on every repetition.
    new_index = sample(1:n, size = n)
    
    new_dx = dx[new_index, new_index]
    dimnames(new_dx) = dimnames(dx)  # keep labels in positional order
    new_dy = dy[new_index, new_index]
    dimnames(new_dy) = dimnames(dy)
    
    new_Kx_RKCD = exp( - new_dx / hx_RKCD)
    new_Ky_RKCD = exp( - new_dy / hy_RKCD)
    new_Kx_KCD = exp( - new_dx / hx_KCD)
    new_Ky_KCD = exp( - new_dy / hy_KCD)
    new_Kx1 = exp( - new_dx / hx_DXY)
    new_Ky1 = exp( - new_dy / hy_DXY)
    new_K_DXY = new_Kx1 * new_Ky1
    new_K_DY = exp( - new_dy / hy_DY)
    
    permute_res = all_stats_general(new_Kx_RKCD, new_Ky_RKCD, new_Kx_KCD, new_Ky_KCD, new_K_DY, new_K_DXY, n0 = n0, n1 = n1)
    
    KCD_power_seq[r] = permute_res$KCD_val
    RKCD_power_seq[r] = permute_res$RKCD_val
    DY_power_seq[r] = permute_res$DY_val
    DXY_power_seq[r] = permute_res$DXY_val
    
    toc()
  }
  
  # Share of re-ordered statistics on one side of the observed statistic.
  # NOTE(review): KCD compares with `>` while the other three compare with
  # `<`; confirm that this matches the direction in which each statistic
  # signals a change.
  KCD_power[j] = mean(KCD_val[j] > KCD_power_seq)
  RKCD_power[j] = mean(RKCD_val[j] < RKCD_power_seq)
  DY_power[j] = mean(DY_val[j] < DY_power_seq)
  DXY_power[j] = mean(DXY_val[j] < DXY_power_seq)
  
}


# Share of simulations in which 1 - <method>_power falls below 0.05.
# NOTE(review): inside the simulation loop KCD_power is computed with `>`
# while DY/DXY/RKCD use `<`, so these four lines do not threshold the same
# tail for every method -- confirm this is intended. `fixed_power` is not
# reported here.
mean(1 - DY_power < 0.05)
mean(1 - DXY_power < 0.05)
mean(1 - KCD_power < 0.05)
mean(1 - RKCD_power < 0.05)

# Mean absolute difference between each method's estimated location and `cp`,
# averaged over simulations.
mean(abs(DY_loc - cp))
mean(abs(DXY_loc - cp))
mean(abs(RKCD_loc - cp))
mean(abs(KCD_loc - cp))
mean(abs(fixed_loc - cp))

# Standard error of the mean absolute differences above
# (sample sd divided by sqrt of the number of simulations).
sd(abs(DY_loc - cp)) / sqrt(simulations)
sd(abs(DXY_loc - cp)) / sqrt(simulations)
sd(abs(RKCD_loc - cp)) / sqrt(simulations)
sd(abs(KCD_loc - cp)) / sqrt(simulations)
sd(abs(fixed_loc - cp)) / sqrt(simulations)