setwd("D:/experiment/Conference Paper/ICLR/ICLR 2025/code/table4")
rm(list = ls())
library(MASS)

library(gtools)

# Directory holding the <name>.train files. No trailing slash, so that
# file.path() below does not produce a doubled separator.
dpath <- "D:/experiment/Conference Paper/ICLR/ICLR 2025/Dataset"

# Which entry of `Dataset` to run.
d_index <- 4

Dataset <- c("ailerons_all", "bank_all", "elevators_all", "parkinsons",
             "cpusmall", "calhousing")
if (!d_index %in% seq_along(Dataset)) {
  stop("d_index must be between 1 and ", length(Dataset), call. = FALSE)
}

savepath <- file.path("D:/experiment/Conference Paper/ICLR/ICLR 2025/Result",
                      paste0("FedAMRO-", Dataset[d_index], ".txt"))

# Column 1 of the training file is the target (ylabel); the remaining
# columns are the features. drop = FALSE keeps trdata a matrix even when
# there is a single feature column.
traindatapath   <- file.path(dpath, paste0(Dataset[d_index], ".train"))
traindatamatrix <- as.matrix(read.table(traindatapath))
trdata          <- traindatamatrix[, -1, drop = FALSE]
ylabel          <- traindatamatrix[, 1]

length_tr  <- nrow(trdata)
feature_tr <- ncol(trdata)
# Number of clients: each round hands one sample to each of the M clients.
M          <- floor(feature_tr^1.5 / 2)
if (M < 1) {
  stop("need at least 2 features so that M = floor(d^1.5 / 2) >= 1",
       call. = FALSE)
}
# Truncate the sample count to a multiple of M so every round is complete.
length_tr  <- floor(length_tr / M) * M
if (length_tr < M) {
  stop("need at least M = ", M, " training samples", call. = FALSE)
}

# Experiment configuration ----------------------------------------------------
reptimes <- 10   # number of independent repetitions of the whole run

# Every expert works on a subset of `b` features. b_ = b + 2 enters only the
# step-size constants (presumably b plus the two extra coordinates J1/J2
# sampled per client in the main loop -- confirm against the paper).
b  <- 2
b_ <- b + 2

# One row per expert: all b-element subsets of the feature indices.
comb <- combinations(feature_tr, b)
N    <- nrow(comb)

U     <- 0.5   # radius of the L2 ball each expert's weight vector is kept in
X     <- 1     # presumably a bound on feature magnitude (constants only)
Y     <- 1     # presumably a bound on |label| (constants only)
# C   <- (U*X+Y)^2   # alternative choice of C, kept for reference
C     <- 0.5
coe   <- 4     # multiplier in the expert learning rate eta_t
sigma <- 0.1   # Bernoulli probability of observing an in-subset feature

# Per-repetition results, filled in by the main loop.
runtime   <- numeric(reptimes)
errorrate <- numeric(reptimes)

# Main experiment: `reptimes` independent runs over the truncated training
# set. In round t each of the M clients receives one fresh sample, predicts
# with the expert (feature subset) the server drew for it, and builds
# coordinate-wise estimates of the sample from a few observed features.
# The server then averages the client statistics to
#   (1) take a projected gradient step on every expert's weights w[i, ], and
#   (2) reweight the expert distribution p by exp(-eta_t * bar_c), then raise
#       every probability below beta_t / N to that floor and rescale the rest.
for (re in seq_len(reptimes)) {
  # Random order in which training samples are handed to clients.
  perm     <- sample.int(length_tr)
  tilde_c  <- numeric(N)                           # per-expert loss, one client
  p        <- rep(1 / N, N)                        # distribution over experts
  q        <- rep(1 / (feature_tr - b), feature_tr - b)  # uniform, off-subset
  delta_t  <- matrix(0, nrow = M, ncol = feature_tr)  # running feature means
  z_x      <- matrix(0, nrow = M, ncol = feature_tr)
  inst_X   <- array(0, dim = c(M, feature_tr, feature_tr))
  Selt_num <- matrix(0, nrow = M, ncol = feature_tr)  # observation counts
  w        <- matrix(0, nrow = N, ncol = b)        # one weight row per expert

  # Step-size constants.
  mu       <- (4 * feature_tr - 3 * b - b_) / (b_ - b)
  mu_2     <- (2 * feature_tr - b_ - b) / (b_ - b)
  # xi_1 is not referenced elsewhere in this script.
  xi_1     <- 2 * (C - Y^2 + (mu^2 * U^2 * X^2 + 2 * mu * U * X * Y))^(0.5)
  alpha    <- (C - Y^2)^2 / 4 +
    (2 * mu_2^2 * U^4 * X^4 + mu_2 * U^2 * X^2 * C + (C - Y^2)^2 / 8) / M
  max_loss <- 0.01
  xi_2     <- 2 * mu_2^2 * U^2 * X^4 / M + (2 * mu_2 + 1) * X^2 * C / 16 / M +
    C * X^2 / 4
  error    <- 0
  # seq_len() gives zero rounds (instead of t = 1, 0) if no full round exists.
  n_rounds <- length_tr %/% M

  t1 <- proc.time()

  for (t in seq_len(n_rounds)) {
    bar_c  <- numeric(N)
    eta_t  <- coe * sqrt(log(N)) / sqrt(max_loss * log(N) + alpha * t)
    lambda <- 1 / sqrt(xi_2 * t)

    # Server: draw one expert per client from p.
    It <- sample.int(N, M, replace = TRUE, prob = p)

    # Clients ---------------------------------------------------------------
    for (j in seq_len(M)) {
      expert <- It[j]
      sel    <- comb[expert, ]
      ind    <- perm[M * (t - 1) + j]
      xt     <- trdata[ind, ]

      # Predict with the selected expert and accumulate squared error.
      Y_j   <- crossprod(w[expert, ], xt[sel])[1, 1]
      error <- error + (Y_j - ylabel[ind])^2

      # Two extra coordinates from outside the subset. Index via
      # sample.int(): sample(x, 1) with length(x) == 1 would draw from 1:x.
      subset_ <- setdiff(seq_len(feature_tr), sel)
      J1  <- subset_[sample.int(length(subset_), 1, replace = TRUE, prob = q)]
      J2  <- subset_[sample.int(length(subset_), 1, replace = TRUE, prob = q)]
      # Independent Bernoulli(sigma) flags for each in-subset feature.
      nu1 <- rbinom(b, 1, sigma)
      nu2 <- rbinom(b, 1, sigma)

      # Two independent coordinate-wise estimates of xt, using this client's
      # running means as the baseline and inverse-probability weights on the
      # coordinates actually observed.
      d_j <- delta_t[j, ]
      tilde_x     <- d_j
      tilde_x[J1] <- (feature_tr - b) * (xt[J1] - d_j[J1]) + d_j[J1]
      hat_x       <- d_j
      hat_x[J2]   <- (feature_tr - b) * (xt[J2] - d_j[J2]) + d_j[J2]
      on1 <- sel[nu1 == 1]
      tilde_x[on1] <- (xt[on1] - d_j[on1]) / sigma + d_j[on1]
      on2 <- sel[nu2 == 1]
      hat_x[on2]   <- (xt[on2] - d_j[on2]) / sigma + d_j[on2]

      z_x[j, ]      <- ylabel[ind] * (tilde_x + hat_x)
      inst_X[j, , ] <- outer(tilde_x, hat_x)

      # Update this client's running mean of every feature it observed.
      obs <- unique(c(sel, J1, J2))
      Selt_num[j, obs] <- Selt_num[j, obs] + 1
      cnt <- Selt_num[j, obs]
      delta_t[j, obs] <- delta_t[j, obs] * (cnt - 1) / cnt + xt[obs] / cnt
    }

    # Server: aggregate client statistics ----------------------------------
    sum_X <- matrix(0, nrow = feature_tr, ncol = feature_tr)
    sum_z <- numeric(feature_tr)

    for (j in seq_len(M)) {
      X_j   <- inst_X[j, , ]
      z_j   <- z_x[j, ]
      sum_X <- sum_X + X_j / M
      sum_z <- sum_z + z_j / M

      # Estimated (shifted) squared loss of every expert on client j.
      for (i in seq_len(N)) {
        idx        <- comb[i, ]
        Xw_ji      <- X_j[idx, idx] %*% w[i, ]
        tilde_c[i] <- crossprod(w[i, ], Xw_ji)[1, 1] -
          crossprod(w[i, ], z_j[idx])[1, 1]
      }
      bar_c <- bar_c + tilde_c / M
    }

    # Projected gradient step on each expert, kept inside the ball of radius U.
    for (i in seq_len(N)) {
      idx    <- comb[i, ]
      Xw     <- sum_X[idx, idx] %*% w[i, ]
      nabla  <- 2 * Xw - sum_z[idx]
      w[i, ] <- w[i, ] - lambda * nabla
      w_norm <- sqrt(crossprod(w[i, ], w[i, ])[1, 1])
      if (w_norm > U) {
        w[i, ] <- w[i, ] * U / w_norm
      }
    }

    # Update p: exponential weights, then enforce the floor beta_t / N ------
    beta_t  <- 1 / t
    p_floor <- beta_t / N
    tilde_p <- p * exp(-bar_c * eta_t)
    p_t     <- tilde_p / sum(tilde_p)
    A       <- which(p_t < p_floor)
    # Each pass pins the set A to the floor and rescales the others so p_t
    # still sums to 1; A only grows, so the loop ends within N passes.
    while (length(A) > 0) {
      A_      <- setdiff(seq_len(N), A)
      tem_sum <- sum(tilde_p[A_])
      z_t     <- beta_t * tem_sum / (N - length(A) * beta_t)
      p_t[A]  <- p_floor
      p_t[A_] <- tilde_p[A_] / (tem_sum + length(A) * z_t)
      new_A   <- which(p_t < p_floor)
      if (length(new_A) == 0) {
        break
      }
      A <- union(A, new_A)
    }
    p        <- p_t
    max_loss <- max(0, max(bar_c^2))
  }

  t2 <- proc.time()
  runtime[re]   <- (t2 - t1)[["elapsed"]]
  errorrate[re] <- error / length_tr   # mean squared prediction error
}

# Collect per-run and summary statistics and write them to `savepath`.
# write.table() coerces the list to a data frame, recycling the length-1
# fields against the length-`reptimes` vectors (run_time, err_num), so the
# file gets one row per repetition. Column names are not written
# (col.names = FALSE), so `note` records the column order.
# Fields use `=` (not `<-`) so each gets its name and no globals are created.
save_result <- list(
  note         = c("the next term are:alg_name--dataname--sam_num--run_time--U--C--ave_run_time--err_num--ave_err_rate--sd_time--sd_err"),
  alg_name     = c("FedAMRO"),
  dataname     = paste0(Dataset[d_index], ".train"),
  sam_num      = length_tr,
  run_time     = as.character(runtime),
  U            = U,
  C            = C,
  ave_run_time = sum(runtime) / reptimes,
  err_num      = errorrate,
  ave_err_rate = sum(errorrate) / reptimes,
  sd_time      = sd(runtime),
  sd_err       = sd(errorrate)
)
write.table(save_result, file = savepath, row.names = TRUE,
            col.names = FALSE, quote = TRUE)

# Console summary. Printed explicitly: source() does not auto-print
# top-level values unless echo/print.eval is enabled, so bare sprintf()
# calls would show nothing in that case.
writeLines(c(
  sprintf("the number of sample is %d", length_tr),
  sprintf("total running time is %.1f in dataset", sum(runtime)),
  sprintf("average running time is %.1f in dataset", sum(runtime)/reptimes),
  sprintf("the average MSE is %f", sum(errorrate)/reptimes),
  sprintf("standard deviation of run_time is %.5f in dataset", sd(runtime)),
  sprintf("standard deviation of MSE is %.5f in dataset", sd(errorrate))
))
