# IndAlg1 experiment driver: configuration, data loading and hyperparameters.
#
# Every path used by this script is absolute, so the working directory is left
# untouched and the caller's workspace is not cleared.
library(MASS)

library(gtools)

# Root of the experiment tree; the dataset and result folders live under it.
root_dir <- "D:/experiment/Conference Paper/ICLR/ICLR 2025"
dpath    <- file.path(root_dir, "Dataset")

# Position in `Dataset` of the benchmark to run.
d_index <- 5

Dataset <- c("ailerons_all", "bank_all", "elevators_all", "parkinsons",
             "cpusmall", "calhousing")
if (!(d_index %in% seq_along(Dataset))) {
  stop("d_index must be an index between 1 and ", length(Dataset),
       call. = FALSE)
}

savepath <- file.path(root_dir, "Result",
                      paste0("IndAlg1-", Dataset[d_index], ".txt"))

# Column 1 of the training file is used as the regression target; all other
# columns are the features. `drop = FALSE` keeps `trdata` a matrix even when
# only one feature column remains, so nrow()/ncol() below stay well defined.
traindatapath   <- file.path(dpath, paste0(Dataset[d_index], ".train"))
traindatamatrix <- as.matrix(read.table(traindatapath))
trdata          <- traindatamatrix[, -1, drop = FALSE]
ylabel          <- traindatamatrix[, 1]

length_tr  <- nrow(trdata)
feature_tr <- ncol(trdata)

# b: size of each feature subset enumerated in `comb`.
b <- 2
if (feature_tr <= b) {
  stop("the data must have more than b = ", b, " feature columns, found ",
       feature_tr, call. = FALSE)
}

# M: number of equally sized blocks each pass over the data is split into.
# `length_tr` is rounded down to a multiple of M; trailing rows are not used.
M <- floor(feature_tr^1.5 / 2)
if (length_tr < M) {
  stop("need at least M = ", M, " training rows, found ", length_tr,
       call. = FALSE)
}
length_tr <- floor(length_tr / M) * M

reptimes <- 10                         # number of independent repetitions
comb     <- combinations(feature_tr, b) # one row per size-b feature subset
N        <- nrow(comb)
C        <- 1   # numerator of beta_t = C / step in the main loop
U        <- 1   # 0.1, 0.5, 1; step-size scale and norm bound for the weights
coe      <- 4   # multiplier in the learning rate eta_t

# Per-repetition elapsed time and mean squared error, filled by the main loop.
runtime   <- numeric(reptimes)
errorrate <- numeric(reptimes)

# Main experiment loop.
#
# Reads  : trdata, ylabel, comb, N, M, b, C, U, coe, feature_tr, length_tr,
#          reptimes.
# Writes : runtime[re] (elapsed seconds) and errorrate[re] (accumulated squared
#          prediction error divided by length_tr) for every repetition.
#
# Each repetition visits a fresh random permutation of the first `length_tr`
# rows, split into M consecutive blocks. All learner state (weights, sampling
# distribution, running feature means, loss statistics) is reset at the start
# of every block.
for (re in seq_len(reptimes)) {
  # The permutation is drawn before the timer starts, as in the original.
  perm         <- sample.int(length_tr)
  t1           <- proc.time()
  Batch        <- length_tr / M
  sq_err_total <- 0
  # Uniform distribution over the feature_tr - b features outside a subset.
  q            <- rep(1 / (feature_tr - b), feature_tr - b)
  log_n        <- log(N)

  for (r in seq_len(M)) {
    tilde_c   <- numeric(N)                  # loss estimate per subset
    p         <- rep(1 / N, N)               # sampling distribution on subsets
    delta_t   <- numeric(feature_tr)         # running mean of observed values
    Selt_num  <- numeric(feature_tr)         # times each feature was observed
    w         <- matrix(0, nrow = N, ncol = b)
    Sum_grad  <- rep(1, N)                   # 1 + sum of squared gradient norms
    Sum_loss  <- 0
    max_loss  <- 0.01
    first_row <- (r - 1) * Batch + 1
    last_row  <- r * Batch
    step      <- 0

    for (t in first_row:last_row) {
      step   <- step + 1
      beta_t <- C / step

      # Draw the subset to predict with, then score its prediction.
      It     <- sample.int(N, 1L, replace = TRUE, prob = p)
      row_id <- perm[t]
      xt     <- trdata[row_id, ]
      yt     <- ylabel[row_id]
      chosen <- comb[It, ]
      pred   <- crossprod(w[It, ], xt[chosen])[1, 1]
      sq_err_total <- sq_err_total + (pred - yt)^2

      # Draw two extra features from those outside the chosen subset.
      # Sampling an index with sample.int() avoids sample()'s special case
      # that treats a length-one numeric `x` as 1:x.
      subset_ <- setdiff(1:feature_tr, chosen)
      J1 <- subset_[sample.int(length(subset_), 1L, replace = TRUE, prob = q)]
      J2 <- subset_[sample.int(length(subset_), 1L, replace = TRUE, prob = q)]

      # Two surrogate feature vectors: start from the running means, rescale
      # the deviation of one extra feature by (feature_tr - b), and use the
      # true values for the chosen subset.
      tilde_x         <- delta_t
      tilde_x[J1]     <- (feature_tr - b) * (xt[J1] - delta_t[J1]) + delta_t[J1]
      tilde_x[chosen] <- xt[chosen]
      hat_x           <- delta_t
      hat_x[J2]       <- (feature_tr - b) * (xt[J2] - delta_t[J2]) + delta_t[J2]
      hat_x[chosen]   <- xt[chosen]

      # ---- update w_{t,i} for every subset i --------------------------------
      for (i in seq_len(N)) {
        idx  <- comb[i, ]
        tem1 <- (w[i, ] %*% tilde_x[idx])[1, 1] - yt
        tem2 <- (w[i, ] %*% hat_x[idx])[1, 1] - yt
        tilde_c[i] <- tem1 * tem2 - yt^2

        nabla       <- tem1 * hat_x[idx] + tem2 * tilde_x[idx]
        Sum_grad[i] <- Sum_grad[i] + crossprod(nabla, nabla)[1, 1]
        # Gradient step with step size U / sqrt(Sum_grad[i]), then rescale the
        # weights back onto the Euclidean ball of radius U if they left it.
        w_new  <- w[i, ] - (U / sqrt(Sum_grad[i])) * nabla
        w_norm <- sqrt(crossprod(w_new, w_new)[1, 1])
        if (w_norm > U) {
          w_new <- w_new * U / w_norm
        }
        w[i, ] <- w_new
      }

      # ---- update delta_t: running mean over the features seen this round ---
      observed_ids <- c(chosen, J1)
      if (J1 != J2) {
        observed_ids <- c(observed_ids, J2)
      }
      Selt_num[observed_ids] <- Selt_num[observed_ids] + 1
      n_seen <- Selt_num[observed_ids]
      delta_t[observed_ids] <-
        delta_t[observed_ids] * (n_seen - 1) / n_seen +
        xt[observed_ids] / n_seen

      # ---- solve for p_t ----------------------------------------------------
      max_loss <- max(max_loss, max(tilde_c^2))
      Sum_loss <- Sum_loss + crossprod(p, tilde_c^2)[1, 1]
      eta_t    <- coe * sqrt(log_n) / sqrt(max_loss * log_n + Sum_loss)
      tilde_p  <- p * exp(-tilde_c * eta_t)
      p_t      <- tilde_p / sum(tilde_p)

      # Enforce p_t[i] >= beta_t / N: entries below the floor are pinned to it
      # and the remaining entries are rescaled so the total stays 1. Rescaling
      # can push further entries below the floor, hence the loop.
      p_floor <- beta_t / N
      clipped <- which(p_t < p_floor)
      while (length(clipped) > 0) {
        free      <- setdiff(1:N, clipped)
        free_mass <- sum(tilde_p[free])
        z_t       <- beta_t * free_mass / (N - length(clipped) * beta_t)
        p_t[clipped] <- p_floor
        p_t[free]    <- tilde_p[free] / (free_mass + length(clipped) * z_t)
        newly_clipped <- which(p_t < p_floor)
        if (length(newly_clipped) == 0) {
          break
        }
        clipped <- union(clipped, newly_clipped)
      }
      p <- p_t
    }
  }
  t2 <- proc.time()

  # Element 3 of a proc.time() difference is the elapsed (wall-clock) time.
  runtime[re]   <- (t2 - t1)[3]
  errorrate[re] <- sq_err_total / length_tr
}

# Collect the per-repetition measurements and their summary statistics.
# Every entry is a named argument (`=`); using `<-` inside list() would create
# unnamed elements and assign stray variables in the calling environment.
save_result <- list(
  note         = c("the next term are:alg_name--dataname--run_time--tot_run_time--ave_run_time--err_num--all_err_rate--ave_err_rate--sd_time--sd_err"),
  alg_name     = c("IndAlg1"),
  dataname     = paste0(Dataset[d_index], ".train"),
  sam_num      = length_tr,
  run_time     = as.character(runtime),
  U            = U,
  ave_run_time = sum(runtime) / reptimes,
  err_num      = errorrate,
  ave_err_rate = sum(errorrate) / reptimes,
  sd_time      = sd(runtime),
  sd_err       = sd(errorrate)
)

# write.table() coerces the list to a data frame: `run_time` and `err_num`
# contribute one row per repetition and the length-one entries are repeated on
# each row. Column names are not written.
write.table(save_result, file = savepath, row.names = TRUE,
            col.names = FALSE, quote = TRUE)

# Explicit print() so the summary also appears when the script is run through
# source(), which does not auto-print top-level values by default.
print(sprintf("the number of sample is %d", length_tr))
print(sprintf("total running time is %.1f in dataset", sum(runtime)))
print(sprintf("average running time is %.1f in dataset",
              sum(runtime) / reptimes))
print(sprintf("the average MSE is %f", sum(errorrate) / reptimes))
print(sprintf("standard deviation of run_time is %.5f in dataset",
              sd(runtime)))
print(sprintf("standard deviation of MSE is %.5f in dataset", sd(errorrate)))
