setwd("D:/experiment/Conference Paper/NeurIPS2024/code")
rm(list = ls())
library(MASS)
 
d_index <- 8

dpath            <- file.path("D:/experiment/dataset/regression/") 

Dataset          <- c("elevators_all","bank_all","slice_all","Year_test",
                    "ailerons_all","calhousing","N-twitter","N-TomsHardware")

savepath         <- paste0("D:/experiment/Conference Paper/NeurIPS2024/Result/",
                        paste0("NCO-OMS-K-",Dataset[d_index],".txt"))

traindatapath    <- file.path(dpath, paste0(Dataset[d_index], ".train"))                
traindatamatrix  <- as.matrix(read.table(traindatapath))
trdata           <- traindatamatrix[ ,-1]
ylabel           <- traindatamatrix[ ,1]

length_tr        <- nrow(trdata) 
length_tr        <- floor(length_tr/10)*10
feature_tr       <- ncol(trdata)  

reptimes  <- 10

runtime   <- c(rep(0, reptimes))
errorrate <- c(rep(0, reptimes))

U         <- seq(0.1,1.0,by=0.1)
M         <- 10                   ################ the number of clients
K         <- length(U)            ################ the number of candidate hypothesis spaces
J         <- K                    ################ the number of selected hypothesis

Y         <- 1
C         <- (U+Y)^2
eta1      <- sqrt(log(K*length_tr))/sqrt(4*length_tr*(1+(K-J)/(J-1)))
eta_t     <- 1*min(eta1,(J-1)/(2*K-2*J))

for(re in 1:reptimes)
{
  W       <- matrix(0,nrow=K*M, ncol=feature_tr)         #### store the models of M clients
  order   <- sample(1:length_tr,length_tr,replace = F)   #dis
  Cum_loss<- 0
  flag    <- 0
  p       <- matrix(1/sqrt(K*length_tr),nrow=M, ncol=K) ################ the sampling probability
  p[,1]   <- 1-sqrt(K)/sqrt(length_tr)+1/(sqrt(K*length_tr)) 
  t1      <- proc.time()                                 #proc.time()
  
  for(t in 1:(length_tr/10))
  {
    for(j in 1:M)
    {
      ct    <- c(rep(0,K))
      #################### sampling hypotheses ###################################
      It    <- sample(1:K, 1, replace=T,prob=p[j,])

      #################### read the models of the j-th client and the data received by the j-th client 
      xt    <- trdata[order[M*(t-1)+j],]
      #################### prediction 
      for(i in 1:K)
      {
        y_i     <- crossprod(W[K*(j-1)+i,],xt)[1,1]
        #################### prediction       
        ins_loss   <- (y_i-ylabel[order[M*(t-1)+j]])^2
        if( i == It)
          Cum_loss <- Cum_loss + ins_loss
        
        ct[i]            <- ins_loss
        G_i              <- 2*(U[i]+Y)
        lambda_ti        <- U[i]/(2*G_i*sqrt(1+(K-J)/(J-1)))/sqrt(max((K-J)^2/(J-1)^2,t))
        
        W[K*(j-1)+i,]    <- W[K*(j-1)+i,] -lambda_ti*2*(y_i-ylabel[order[M*(t-1)+j]])*xt
        norm_i           <- sqrt(crossprod(W[K*(j-1)+i,],W[K*(j-1)+i,])[1,1])
        if(norm_i>U[i])
        {
          W[K*(j-1)+i,] <- W[K*(j-1)+i,]*U[i]/norm_i
        }        
      }
      ##################### update the sampling probability using binary search
      upper_lambda   <- 0    #### lambda^\ast < 0
      lambda_ast     <- upper_lambda
      tem            <- p[j,]*exp(-eta_t*(lambda_ast+ct)/C)
      sum_barpt      <- sum(tem)
      if(abs(sum_barpt - 1)>1e-5)
      {
        lower_lambda <- -max(ct)
        lambda_ast   <- lower_lambda
        tem          <- p[j,]*exp(-eta_t*(lambda_ast+ct)/C)
        sum_barpt    <- sum(tem)
        if(abs(sum_barpt - 1)>1e-5)
        {
          lambda_ast <- (lower_lambda+upper_lambda)/2 
          tem        <- p[j,]*exp(-eta_t*(lambda_ast+ct)/C)
          sum_barpt  <- sum(tem)
        }
        while(abs(sum_barpt - 1)>1e-5)
        {
          flag <- flag +1
          if(sum_barpt<1)      ############ lambda^ast should decrease
          {
            upper_lambda  <- lambda_ast
          }else{                        ############ lambda^ast should increase
            lower_lambda  <- lambda_ast
          }
          lambda_ast      <- (lower_lambda+upper_lambda)/2
          tem             <- p[j,]*exp(-eta_t*(lambda_ast+ct)/C)
          sum_barpt       <- sum(tem)
        }
      }
      p[j,] <- tem
    }
  }
  t2            <- proc.time()
  runtime[re]   <- (t2 - t1)[3]
  errorrate[re] <- Cum_loss/length_tr
}

save_result <- list(
  note     = c(" the next term are:alg_name--dataname--sam_num--run_time--tot_run_time--ave_run_time--err_num--all_err_rate--ave_err_rate--sd_time--sd_err"),
  alg_name = c("NCO-OMS-"),
  dataname = paste0(Dataset[d_index], ".train"),
  sam_num  = length_tr,
  run_time = as.character(runtime),
  tot_run_time = sum(runtime),
  ave_run_time = sum(runtime)/reptimes,
  err_num      = errorrate,
  ave_err_rate = sum(errorrate)/reptimes,
  sd_time      <- sd(runtime),
  sd_err       <-sd(errorrate)
)
write.table(save_result,file=savepath,row.names =TRUE, col.names =FALSE, quote = T)

sprintf("the number of sample is %d", length_tr)
sprintf("total running time is %.1f in dataset", sum(runtime))
sprintf("average running time is %.1f in dataset", sum(runtime)/reptimes)
sprintf("the average AMR is %f", sum(errorrate)/reptimes)
sprintf("standard deviation of run_time is %.5f in dataset", sd(runtime))
sprintf("standard deviation of AMR is %.5f in dataset", sd(errorrate))
