# ---------------------------------------------------------------------------
# Experiment setup for IndAlg3: pick a dataset, load its training file and fix
# the hyper-parameters used by the training loop further down.
#
# The former `setwd(...)` / `rm(list = ls())` preamble has been removed on
# purpose: every path in this script is absolute and every variable is assigned
# before it is read, so neither call was needed, while both had side effects on
# the caller's session (wiping the workspace, and aborting on machines where
# the hard-coded code directory does not exist).
# ---------------------------------------------------------------------------
library(MASS)

library(gtools)

# Directory that holds the "<dataset>.train" files.
dpath          <- "D:/experiment/Conference Paper/ICLR/ICLR 2025/Dataset/"

# Index into `Dataset` selecting which dataset this run uses.
d_index <- 6

Dataset       <- c("ailerons_all","bank_all","elevators_all","parkinsons","cpusmall","calhousing")

# Fail early with a clear message instead of building an "NA" file name.
if (!(d_index %in% seq_along(Dataset))) {
  stop("d_index must be between 1 and ", length(Dataset), call. = FALSE)
}

# File the summary is written to at the end of the script.
savepath      <- paste0("D:/experiment/Conference Paper/ICLR/ICLR 2025/Result/",
                        "IndAlg3-", Dataset[d_index], ".txt")

# The training file is read as a numeric table: column 1 is the label,
# every remaining column is a feature.
traindatapath    <- file.path(dpath, paste0(Dataset[d_index], ".train"))
traindatamatrix  <- as.matrix(read.table(traindatapath))
trdata           <- traindatamatrix[, -1]
ylabel           <- traindatamatrix[, 1]

length_tr        <- nrow(trdata)                 # number of samples read
feature_tr       <- ncol(trdata)                 # number of features
M                <- floor(feature_tr^1.5 / 2)    # number of blocks the data is split into
length_tr2       <- floor(length_tr / M) * M     # largest multiple of M not exceeding length_tr

reptimes <- 10          # number of independent repetitions
k1       <- 2           # not referenced in the visible part of the script
k0       <- 4           # not referenced in the visible part of the script
alpha    <- 0.1         # initial diagonal of the online ridge matrix; tried 10^{-4}, 10^{-3}, ..., 1
alpha1   <- 50          # multiplier used when deriving the round length B
c        <- 50          # multiplier used when deriving the step size eta
lambda   <- 10          # ridge term of the two-feature cost; tried 0.1, 1, 10

# Per-repetition results, filled in by the training loop.
runtime   <- numeric(reptimes)
errorrate <- numeric(reptimes)

# ---------------------------------------------------------------------------
# Training / evaluation loop.
#
# For each of `reptimes` repetitions the first `length_tr2` samples are
# shuffled and cut into M consecutive blocks of `Batch` samples. Every block
# is processed independently:
#   * N rounds of B samples each. In a round two features are drawn (Jb1 from
#     p, Jb2 from q), an online ridge predictor on those features is run over
#     the round's samples (its squared loss is accumulated in `error`), and
#     p and q receive a multiplicative update  w <- w * exp(-eta * cost),
#     renormalised to sum to one.
#   * the samples of the block left over after the N rounds are predicted with
#     features drawn once more from the final p and q.
# Results: runtime[re] (elapsed seconds) and errorrate[re] (error / length_tr2).
#
# Fixes relative to the previous version:
#   * the leftover range no longer counts backwards when N * B == Batch
#     (`a:b` with a > b iterates downwards and touched rows outside the block);
#   * the two-feature cost now uses the full quadratic form
#     XY (lambda * I + X'X)^-1 XY' instead of only the product of the first
#     components (crossprod() of two 1 x 2 matrices is a 2 x 2 outer product);
#   * loops use seq_len(), so a zero count means zero iterations;
#   * an odd feature count or a zero round length is rejected up front.
# ---------------------------------------------------------------------------

# Index of the feature paired with `j`: even j pairs with j - 1, odd j with j + 1.
.partner_feature <- function(j) {
  if (j %% 2 == 0) j - 1 else j + 1
}

# Mean squared residual of regressing y on the single column x (no intercept);
# falls back to mean(y^2) when x is identically zero. `y_sq` is sum(y^2).
.single_feature_cost <- function(x, y, y_sq) {
  x_sq <- crossprod(x, x)[1, 1]
  if (x_sq > 0) {
    (y_sq - crossprod(y, x)[1, 1]^2 / x_sq) / length(y)
  } else {
    y_sq / length(y)
  }
}

# (y'y - y'X (reg * I + X'X)^-1 X'y) / length(y) for the columns of `x_pair`.
.pair_ridge_cost <- function(x_pair, y, y_sq, reg) {
  xty  <- crossprod(x_pair, y)
  gram <- crossprod(x_pair) + reg * diag(ncol(x_pair))
  (y_sq - sum(xty * solve(gram, xty))) / length(y)
}

# Runs the online ridge predictor over the rows of `x_rows` (in order) and
# returns the summed squared prediction error. For every row the matrix is
# updated with the current input first, the weights are then solved from the
# label-weighted inputs seen so far (current row excluded), and only after
# predicting is the current label folded in. Zero rows give zero loss.
.online_ridge_loss <- function(x_rows, y, reg) {
  d      <- ncol(x_rows)
  gram   <- diag(reg, d, d)
  xy_sum <- numeric(d)
  loss   <- 0
  for (i in seq_len(nrow(x_rows))) {
    x      <- x_rows[i, ]
    gram   <- gram + x %*% t(x)
    w      <- solve(gram, xy_sum)
    loss   <- loss + (sum(w * x) - y[i])^2
    xy_sum <- xy_sum + y[i] * x
  }
  loss
}

# The pairing j <-> j +/- 1 indexes column feature_tr + 1 when feature_tr is odd.
if (feature_tr < 2 || feature_tr %% 2 != 0) {
  stop("feature pairing needs an even, positive number of features, got ",
       feature_tr, call. = FALSE)
}

# Block / round geometry. These values do not change between repetitions.
Batch     <- length_tr2 / M          # samples per block
# NOTE(review): `length_tr` is deliberately still overwritten with the block
# size, as before, because the reporting code at the end of the script reads
# it afterwards - confirm whether the total sample count was intended there.
length_tr <- Batch
B         <- alpha1 * floor((length_tr / feature_tr)^(1 / 3))   # samples per round
if (!isTRUE(B >= 1)) {
  stop("round length B is ", B, "; each block needs at least feature_tr samples",
       call. = FALSE)
}
N         <- floor(length_tr / B)                        # rounds per block
eta       <- c * sqrt(2 * log(feature_tr) / N)           # step size; c tried: 50, 10, 5, 1, 0.5, 0.1
tail_len  <- Batch - N * B                               # samples left after the N rounds (may be 0)

for (re in seq_len(reptimes)) {
  # Random processing order of the first length_tr2 samples.
  order <- sample(1:length_tr2, length_tr2, replace = FALSE)
  error <- 0

  t1 <- proc.time()

  for (rr in seq_len(M)) {
    # Both sampling distributions start uniform in every block.
    p <- rep(1 / feature_tr, feature_tr)
    q <- rep(1 / feature_tr, feature_tr)
    block_start <- (rr - 1) * Batch

    for (r in seq_len(N)) {
      tilde_c1 <- numeric(feature_tr)
      tilde_c2 <- numeric(feature_tr)

      Jb1  <- sample(1:feature_tr, 1, replace = TRUE, prob = p)
      Jb2  <- sample(1:feature_tr, 1, replace = TRUE, prob = q)
      Jb1_ <- .partner_feature(Jb1)
      Jb2_ <- .partner_feature(Jb2)
      Ub1  <- c(Jb1, Jb1_)
      Ub2  <- c(Jb2, Jb2_)
      # Features the predictor uses this round (one column if both draws agree).
      Vb2  <- if (Jb2 != Jb1) c(Jb1, Jb2) else Jb1

      # Rows of this round, in shuffled order. B >= 1, so the range ascends.
      rows <- order[(block_start + (r - 1) * B + 1):(block_start + r * B)]
      Y    <- ylabel[rows]

      error <- error + .online_ridge_loss(trdata[rows, Vb2, drop = FALSE], Y, alpha)

      sum_Y <- crossprod(Y, Y)[1, 1]

      # ---- update p: single-feature costs of Jb1 and its partner, each
      # divided by the total probability of the pair under the current p.
      pair_prob1      <- sum(p[Ub1])
      tilde_c1[Jb1]   <- .single_feature_cost(trdata[rows, Jb1], Y, sum_Y) / pair_prob1
      tilde_c1[Jb1_]  <- .single_feature_cost(trdata[rows, Jb1_], Y, sum_Y) / pair_prob1
      tilde_p         <- p * exp(-tilde_c1 * eta)
      p               <- tilde_p / sum(tilde_p)

      # ---- update q: two-feature ridge costs of (Jb1, Jb2) and (Jb1, partner
      # of Jb2), each divided by the total probability of Jb2's pair under q.
      pair_prob2      <- sum(q[Ub2])
      tilde_c2[Jb2]   <- .pair_ridge_cost(trdata[rows, c(Jb1, Jb2), drop = FALSE],
                                          Y, sum_Y, lambda) / pair_prob2
      tilde_c2[Jb2_]  <- .pair_ridge_cost(trdata[rows, c(Jb1, Jb2_), drop = FALSE],
                                          Y, sum_Y, lambda) / pair_prob2
      tilde_p         <- q * exp(-tilde_c2 * eta)
      q               <- tilde_p / sum(tilde_p)
    }

    # Leftover samples of the block: draw the features once more from the
    # final p and q (always drawn, even when no sample is left).
    Jb1 <- sample(1:feature_tr, 1, replace = TRUE, prob = p)
    Jb2 <- sample(1:feature_tr, 1, replace = TRUE, prob = q)
    Vb2 <- if (Jb2 != Jb1) c(Jb1, Jb2) else Jb1

    # seq_len() yields an empty index when N * B == Batch.
    tail_rows <- order[block_start + N * B + seq_len(tail_len)]
    # NOTE(review): this phase starts from the identity (ridge 1) whereas the
    # rounds above start from `alpha` - kept as it was; confirm it is intended.
    error <- error + .online_ridge_loss(trdata[tail_rows, Vb2, drop = FALSE],
                                        ylabel[tail_rows], 1)
  }

  t2 <- proc.time()
  runtime[re]   <- (t2 - t1)[3]          # elapsed seconds
  errorrate[re] <- error / length_tr2    # mean squared error over the samples used
}

# ---------------------------------------------------------------------------
# Collect the run summary, write it to `savepath` and print it.
#
# `sd_time` / `sd_err` used to be written with `<-` inside list(), which made
# them unnamed list elements and created two stray global variables; they are
# proper named elements now. The written file is unaffected because column
# names are not written (col.names = FALSE).
#
# NOTE(review): the `note` text does not list the fields actually stored below;
# it is written to the result file, so it is left byte-identical here.
# NOTE(review): `eta` stores the multiplier `c` and `B` stores the multiplier
# `alpha1`, not the derived step size / round length - confirm this is intended.
# NOTE(review): `length_tr` is overwritten with the per-block sample count by
# the training loop, so `sam_num` and the first printed line report that value
# rather than nrow(trdata) - confirm this is intended.
# ---------------------------------------------------------------------------
save_result <- list(
  note         = " the next term are:alg_name--dataname--run_time--tot_run_time--ave_run_time--err_num--all_err_rate--ave_err_rate--sd_time--sd_err",
  alg_name     = "IndAlg3",
  dataname     = paste0(Dataset[d_index], ".train"),
  sam_num      = length_tr,
  alpha        = alpha,
  eta          = c,
  B            = alpha1,
  ave_run_time = sum(runtime) / reptimes,
  err_num      = errorrate,                    # one value per repetition
  ave_err_rate = sum(errorrate) / reptimes,
  sd_time      = sd(runtime),
  sd_err       = sd(errorrate)
)
# write.table() turns the list into a data frame; the scalar fields are
# recycled along the per-repetition `err_num` vector (one row per repetition).
write.table(save_result, file = savepath, row.names = TRUE, col.names = FALSE, quote = TRUE)

# Explicit print() so the summary also appears when the script is run through
# source(), where bare top-level values are not auto-printed.
print(sprintf("the number of sample is %d", length_tr))
print(sprintf("total running time is %.1f in dataset", sum(runtime)))
print(sprintf("average running time is %.1f in dataset", sum(runtime) / reptimes))
print(sprintf("the average MSE is %f", sum(errorrate) / reptimes))
print(sprintf("standard deviation of run_time is %.5f in dataset", sd(runtime)))
print(sprintf("standard deviation of MSE is %.5f in dataset", sd(errorrate)))
