library(auk)
library(rARPACK)
library(softImpute)
library(foreach)
library(doParallel)
library(mvtnorm)

# Input: raw eBird Basic Dataset extract; output: filtered text file.
f_in <- "~/Dropbox/citizen_science/citizen_science/ebd_FR_relSep-2020/ebd_FR_relSep-2020.txt"
f_out <- "ebd_filtered_mallard.txt"

# Reference object on the raw extract (not referenced again in the visible
# script; kept because it is part of the script's global state).
fr_data <- auk_ebd(f_in)

# Filtering pipeline, one step per statement:
# 1. reference file
ebd_ref <- auk_ebd(f_in)
# 2. define filters (species, then country)
ebd_query <- auk_species(ebd_ref, species = "Mallard")
ebd_query <- auk_country(ebd_query, country = "France")
# 3. run filtering, writing the result to `f_out`
ebd_filtered <- auk_filter(ebd_query, file = f_out)
# 4. read text file into r data frame
ebird_data <- read_ebd(ebd_filtered)

# Collapse observation dates to "YYYY-MM" strings.
ebird_data$observation_date <- format(ebird_data$observation_date, "%Y-%m")

# Counts recorded as the literal string "X" are treated as missing.
count_is_x <- ebird_data$observation_count %in% "X"
ebird_data$observation_count[count_is_x] <- NA

## Original data base
obs_id <- unique(ebird_data$observer_id)
n <- length(obs_id) # number of users
site_id <- unique(ebird_data$locality)
I <- length(site_id) # number of ecological sites
date_id <- unique(ebird_data$observation_date)
J <- length(date_id) # number of observation dates

# assign users to 100 groups (servers)
# Group sizes are derived from the number of observers instead of being
# hard-coded: groups 1..99 get floor(n_obs / 100) observers each and group 100
# takes the remainder. For 2465 observers this is exactly the original
# 99 x 24 + 89 layout; for any other extract the vector still has one entry
# per observer, so `groups == g` lines up with `obs_id`.
n_obs <- length(obs_id)
n <- 100
per_group <- n_obs %/% n
groups <- c(
  rep(seq_len(n - 1), each = per_group),
  rep(n, n_obs - (n - 1) * per_group)
)
# Shuffle by index so that a length-one vector is not reinterpreted by
# sample() as "1:x".
groups <- groups[sample.int(length(groups))]

# Columns kept for every group, selected by name rather than by position so
# the script does not depend on the column order of the eBird extract. These
# are the four columns the rest of the script reads.
keep_cols <- c("observer_id", "locality", "observation_date", "observation_count")

observations <- vector("list", n)
for (g in seq_len(n)) {
  in_group <- ebird_data$observer_id %in% obs_id[groups == g]
  obs_g <- ebird_data[in_group, keep_cols]
  obs_g$observation_count <- as.numeric(obs_g$observation_count)

  # Keep exactly one record, chosen uniformly at random, per
  # (locality, observation_date) pair: visit the rows in random order and keep
  # the first occurrence of each key, then restore the original row order.
  keys <- paste0(obs_g$locality, obs_g$observation_date)
  visit_order <- sample.int(nrow(obs_g))
  kept_rows <- sort(visit_order[!duplicated(keys[visit_order])])
  observations[[g]] <- obs_g[kept_rows, ]
}
# Number of records held by each group.
Nc <- vapply(observations, nrow, integer(1))

# Select observations made after "2000-01"
# (string comparison on "YYYY-MM" keys; "2000-01" itself is excluded).
dat <- do.call(rbind.data.frame, observations)
dat <- dat[which(dat$observation_date > "2000-01"), ]
obs_id <- unique(dat$observer_id)
n <- length(obs_id)
site_id <- unique(dat$locality)
I <- length(site_id)
date_id <- unique(dat$observation_date)
J <- length(date_id)

# Number of records per site, in the order of `site_id`. A single
# match()/tabulate() pass replaces one full scan of `dat` per site.
site_obs_nb <- setNames(
  tabulate(match(dat$locality, site_id), nbins = length(site_id)),
  site_id
)
# Keep sites with more than 3 records (an NA site is never kept).
site_id <- site_id[!is.na(site_id) & site_obs_nb > 3]

# Number of records per month, in the order of `date_id`.
time_obs_nb <- setNames(
  tabulate(match(dat$observation_date, date_id), nbins = length(date_id)),
  date_id
)
# Keep months with more than 20 records (an NA month is never kept).
date_id <- date_id[!is.na(date_id) & time_obs_nb > 20]

# Restrict the data to the retained sites and months, then recompute the
# observer / site / month indices and their sizes.
dat <- dat[which(dat$locality %in% site_id & dat$observation_date %in% date_id), ]
obs_id <- unique(dat$observer_id)
n <- length(obs_id)
site_id <- unique(dat$locality)
I <- length(site_id)
date_id <- unique(dat$observation_date)
J <- length(date_id)

# Divide observers into 10 groups
# Each observer is assigned independently and uniformly to one of the groups
# (same distribution as drawing one multinomial trial per observer with equal
# probabilities, without building the indicator matrix).
n_groups <- 10
groups <- sample.int(n_groups, n, replace = TRUE)
n <- n_groups

observations <- vector("list", n)
for (g in seq_len(n)) {
  obs_g <- dat[dat$observer_id %in% obs_id[groups == g], ]
  obs_g$observation_count <- as.numeric(obs_g$observation_count)

  # Keep exactly one record, chosen uniformly at random, per
  # (locality, observation_date) pair: visit the rows in random order and keep
  # the first occurrence of each key, then restore the original row order.
  keys <- paste0(obs_g$locality, obs_g$observation_date)
  visit_order <- sample.int(nrow(obs_g))
  kept_rows <- sort(visit_order[!duplicated(keys[visit_order])])
  observations[[g]] <- obs_g[kept_rows, ]
}
# Number of records held by each group.
Nc <- vapply(observations, nrow, integer(1))


# Initialize parameters of FedEM
Vt_central <- matrix(0, I, J)

# hatS is an I x J (site x month) matrix filled with the recorded counts.
# The fill is done in one matrix-indexed assignment; when several rows of
# `dat` hit the same cell the last one wins, as in a row-by-row loop.
hatS <- matrix(0, I, J)
cell <- cbind(
  match(dat$locality, site_id),
  match(dat$observation_date, date_id)
)
hatS[cell] <- dat$observation_count
hatS <- matrix(as.numeric(hatS), ncol = ncol(hatS))
# Cells with no record stay at 0; cells whose recorded count is NA are set to
# the mean of all non-NA cells (zeros included).
hatS[is.na(hatS)] <- mean(hatS, na.rm = TRUE)
Ht <- matrix(0, I, J)

# Initialize parameters of low-rank model
Theta <- matrix(0, I, J)
# U and V are all-zero matrices of size I x 2I and J x 2J.
U <- matrix(0, nrow = I, ncol = 2 * I)
V <- matrix(0, nrow = J, ncol = 2 * J)
gamma <- 1e-4
tmax <- 2 * 1e3
nbatch <- 100
alpha <- 1e-3
beta <- 0
r <- 2 # rank of the matrix
obj <- NULL

# Main FedEM loop.
#
# Each iteration: (1) every group draws a minibatch of its records and
# computes local statistics S and differences Delta against hatS, in
# parallel; (2) the differences are aggregated into Delta_ref; (3) Ht, hatS
# and Vt_central are updated; (4) Theta is refreshed as the rank-r truncated
# SVD of hatS. The squared Frobenius norm of Ht is appended to `obj`.
#
# The worker pool is created once for the whole run, instead of once per
# iteration, and is always shut down, even if an iteration fails.
cores <- detectCores()
# Leave one core free, but always ask for at least one worker
# (detectCores() may return 1 or NA).
cl <- makeCluster(max(1L, cores[1] - 1L, na.rm = TRUE)) # not to overload your computer
registerDoParallel(cl)

tryCatch({
  for (t in seq_len(tmax)) {
    cat("\r",
        "iteration ",
        t,
        "/",
        tmax,
        " - ",
        round(100 * t / tmax),
        "%")
    # initialize workers: each group starts with a zero memory term V laid out
    # like its own records.
    # NOTE(review): V is reset to 0 at every iteration, so the per-worker
    # memory never carries over between iterations - confirm this is intended.
    wrk_list <- lapply(seq_len(n), function(k) {
      mem_k <- observations[[k]]
      mem_k$observation_count <- 0
      list(V = mem_k)
    })

    # sample minibatches and compute stochastic approximations S_ij
    # Records within a group are unique per (locality, observation_date), so
    # the sampled row indices identify the records directly and the whole
    # minibatch can be processed with vector operations.
    wrk_list <- foreach(k = seq_len(n)) %dopar% {
      obs_k <- observations[[k]]
      mem_k <- wrk_list[[k]]$V
      idx <- sample.int(nrow(obs_k), min(nbatch, Nc[k]))
      Ick <- obs_k$locality[idx]
      Jck <- obs_k$observation_date[idx]
      # (row, column) positions of the sampled records in the I x J matrices
      cell <- cbind(match(Ick, site_id), match(Jck, date_id))

      # S: the recorded count, or the current Theta entry when the count is NA
      y <- obs_k$observation_count[idx]
      s_val <- ifelse(is.na(y), Theta[cell], y)

      # Compute local differences Delta, then update the local memory term
      d_val <- s_val - hatS[cell] - mem_k$observation_count[idx]
      mem_k$observation_count[idx] <- mem_k$observation_count[idx] + alpha * d_val

      list(
        St = data.frame(locality = Ick, observation_date = Jck, observation_count = s_val),
        Delta = data.frame(locality = Ick, observation_date = Jck, observation_count = d_val),
        V = mem_k
      )
    }

    # Aggregate the local differences.
    # The worker data frames carry the site in the `locality` column and are
    # stored under the `Delta` element; reading `$site_id` / `$Deltat` yields
    # NULL and silently drops every local difference.
    # NOTE(review): the accumulation form (Delta_ref[cell] + delta / n written
    # into Delta, then added to Delta_ref again) and the sign of
    # (hatS - Theta) are kept as written - confirm against the algorithm.
    Delta_ref <- matrix(0, I, J)
    for (k in seq_len(n)) {
      Delta <- (hatS - Theta) / n
      delta_k <- wrk_list[[k]]$Delta
      cell <- cbind(
        match(delta_k$locality, site_id),
        match(delta_k$observation_date, date_id)
      )
      Delta[cell] <- Delta_ref[cell] + delta_k$observation_count / n
      Delta_ref <- Delta_ref + Delta
    }

    # update mean field, sufficient statistics and memory term
    Ht <- beta * Ht + Delta_ref + Vt_central # No need of Vt in distrib imput: all gradients go to 0
    hatS <- hatS + gamma * Ht
    Vt_central <- Vt_central + alpha * Delta_ref

    # Update the parameters of the low rank model: M step
    suv <- svds(hatS, r)
    # diag(x, nrow = length(x)) stays a diagonal matrix even when r == 1
    Theta <- suv$u %*% diag(suv$d, nrow = length(suv$d)) %*% t(suv$v)
    val <- norm(Ht, type = "F") ^ 2
    obj <- c(obj, val)
  }
}, finally = {
  #stop cluster
  stopCluster(cl)
})

# Observed-count matrix (site x month): NA where there is no record or where
# the recorded count is NA. Filled in one matrix-indexed assignment; when
# several rows of `dat` hit the same cell the last one wins, as in a
# row-by-row loop.
missing <- matrix(NA, I, J)
cell <- cbind(
  match(dat$locality, site_id),
  match(dat$observation_date, date_id)
)
missing[cell] <- dat$observation_count
missing <- matrix(as.numeric(missing), ncol = ncol(missing))

# Baseline: softImpute reconstruction of the same matrix (default settings).
tt <- softImpute(missing)
# diag(x, nrow = length(x)) stays a diagonal matrix even if d has length 1
Wsft <- tt$u %*% diag(tt$d, nrow = length(tt$d)) %*% t(tt$v)

# Fill the NA cells with the low-rank estimate Theta, then put the columns of
# both matrices in chronological order (date_id itself is left unsorted).
imputed <- missing
imputed[is.na(missing)] <- Theta[is.na(missing)]
Theta <- Theta[, order(date_id)]
imputed <- imputed[, order(date_id)]

# plot imputation results
# Left axis: column sums of the imputed matrix per month, with a loess trend.
# Right axis: number of NA cells per month, with a loess trend.
library(scales)
op <- par(no.readonly = TRUE)
par(mar = c(5, 6, 2, 2))

# Column sums of `imputed` (its columns are already in chronological order).
dd <- data.frame(count = colSums(imputed), date = seq_len(ncol(imputed)))
# scales::alpha is called explicitly: a numeric global named `alpha` also
# exists in this script.
plot(dd$date, dd$count, pch = "+",
     col = scales::alpha("blue", 0.3),
     xlab = "observation date",
     ylab = "Estimated bird count",
     cex.main = 2,
     cex.lab = 2,
     cex.axis = 2,
     xaxt = "n")
axis(1, dd$date, date_id[order(date_id)], cex.axis = 1)
mod <- loess(count ~ date, dd)
print(mod)
yfit <- predict(mod, newdata = dd)
lines(dd$date, yfit, col = "blue", lty = 2, lwd = 3)

# Overlay the number of missing cells per month on a secondary axis.
# `missing` is still in the original column order, so it is reordered here.
na_count <- colSums(is.na(missing[, order(date_id)]))
par(new = TRUE)
plot(dd$date, na_count, col = scales::alpha("red", 0.3), pch = 2,
     axes = FALSE, bty = "n", xlab = "", ylab = "")
dd <- data.frame(count = na_count, date = seq_len(ncol(imputed)))
mod <- loess(count ~ date, dd)
print(mod)
yfit <- predict(mod, newdata = dd)
lines(dd$date, yfit, col = "red", lty = 2, lwd = 3)
axis(side = 4, at = pretty(range(na_count)), cex.lab = 2,
     cex.axis = 2)
mtext("Number of missing values", side = 4, line = 3)
#points(1:339, colSums(missing[,order(date_id)], na.rm=T), pch=2)
legend("topright", legend = c("Estimated count", "Number of NA"),
       col = c("blue", "red"),
       lty = 2,
       lwd = 2,
       cex = 1,
       pch = c(3, 2))

# Restore the graphical parameters saved in `op`.
par(op)


