########################################
########################################
#
#       Clustering toy3 example
#
########################################
########################################

library(Canopy)
data(toy3)
R <- toy3$R
X <- toy3$X
dim(R); dim(X)
num_cluster <- 2:9 # Range of number of clusters to run
num_run <- 10 # How many EM runs per clustering step for each mutation cluster wave
# Stored under a name that does not shadow the canopy.cluster() function.
cluster_fit <- canopy.cluster(R = R,
                              X = X,
                              num_cluster = num_cluster,
                              num_run = num_run)

# BIC to determine the optimal number of mutation clusters; the dashed line
# marks the number of clusters with the largest BIC.
bic_output <- cluster_fit$bic_output
plot(num_cluster, bic_output, xlab = 'Number of mutation clusters', ylab = 'BIC',
     type = 'b', main = 'BIC for model selection')
abline(v = num_cluster[which.max(bic_output)], lty = 2)

# Visualization of clustering result
Mu <- cluster_fit$Mu # VAF centroid for each cluster
Tau <- cluster_fit$Tau # Prior for mutation cluster, with a K+1 component
sna_cluster <- cluster_fit$sna_cluster # cluster identity for each mutation
# One colour and one plotting symbol per cluster label.
colc <- c('green4', 'red3', 'royalblue1', 'darkorange1', 'royalblue4',
          'mediumvioletred', 'seagreen4', 'olivedrab4', 'steelblue4', 'lavenderblush4')
pchc <- c(17, 0, 1, 15, 3, 16, 4, 8, 2, 16)
# Variant allele frequencies (R / X), computed once and reused for both
# scatter plots and their shared axis limits.
vaf <- R / X
vaf_lim <- c(0, max(vaf))
plot(vaf[, 1], vaf[, 2], xlab = 'Sample1 VAF', ylab = 'Sample2 VAF',
     col = colc[sna_cluster], pch = pchc[sna_cluster],
     ylim = vaf_lim, xlim = vaf_lim)
library(scatterplot3d)
scatterplot3d(vaf[, 1], vaf[, 2], vaf[, 3],
              xlim = vaf_lim, ylim = vaf_lim, zlim = vaf_lim,
              color = colc[sna_cluster], pch = pchc[sna_cluster],
              xlab = 'Sample1 VAF', ylab = 'Sample2 VAF', zlab = 'Sample3 VAF')


########################################
########################################
#
#       Clustering AML43
#
########################################
########################################

library(Canopy)
library(scatterplot3d)
data(AML43)
R <- AML43$R
X <- AML43$X
dim(R); dim(X)
# A single number of clusters here (not a range); it must match the number of
# rows of Mu.init below.
num_cluster <- 4
num_run <- 10 # How many EM runs per clustering step for each mutation cluster wave
# NOTE(review): presumably the prior weight of the extra (K+1-th) noise
# component -- confirm in ?canopy.cluster.
Tau_Kplus1 <- 0.05
# Initial VAF centroids: one row per cluster, one column per sample
# (columns presumably samples -- confirm).
Mu.init <- cbind(c(0.01, 0.15, 0.25, 0.45), c(0.2, 0.2, 0.01, 0.2))
# Stored under a name that does not shadow the canopy.cluster() function.
cluster_fit <- canopy.cluster(R = R,
                              X = X,
                              num_cluster = num_cluster,
                              num_run = num_run,
                              Mu.init = Mu.init,
                              Tau_Kplus1 = Tau_Kplus1)

# Visualization of clustering result
Mu <- cluster_fit$Mu # VAF centroid for each cluster
Tau <- cluster_fit$Tau # Prior for mutation cluster, with a K+1 component
sna_cluster <- cluster_fit$sna_cluster # cluster identity for each mutation
colc <- c('green4', 'red3', 'royalblue1', 'darkorange1', 'royalblue4',
          'mediumvioletred', 'seagreen4', 'olivedrab4', 'steelblue4', 'lavenderblush4')
pchc <- c(17, 0, 1, 15, 3, 16, 4, 8, 2, 16)
plot((R / X)[, 1], (R / X)[, 2], xlab = 'Sample1 VAF', ylab = 'Sample2 VAF',
     col = colc[sna_cluster], pch = pchc[sna_cluster],
     ylim = c(0, max(R / X)), xlim = c(0, max(R / X)))

table(sna_cluster) # cluster num_cluster + 1 (here the 5th) is the noise component

# Exclude mutations in the noise cluster; drop = FALSE keeps R and X matrices
# even if only one mutation survives.
not_noise <- sna_cluster <= num_cluster
R <- R[not_noise, , drop = FALSE]
X <- X[not_noise, , drop = FALSE]
sna_cluster <- sna_cluster[not_noise]

# Generate pseudo-SNAs corresponding to each cluster.
R.cluster <- round(Mu * 100)
# Total depth is set at 100 here but can be obtained as median across mutations
# in the cluster; pmax keeps depth >= mutant reads.
X.cluster <- pmax(R.cluster, 100)
rownames(R.cluster) <- rownames(X.cluster) <- paste0('SNA.cluster', seq_len(num_cluster))


#######################################################
#######################################################
#######                                         #######
#######             CNA and SNA input           #######
#######                                         #######
#######################################################
#######################################################
library(Canopy)
data("MDA231")
# Unpack the MDA231 inputs. Wrapping an assignment in parentheses makes it
# visible, so those objects are also printed.
projectname <- MDA231$projectname # name of project
(R <- MDA231$R)   # mutant allele read depth (for SNAs)
(X <- MDA231$X)   # total depth (for SNAs)
(WM <- MDA231$WM) # observed major copy number (for CNA regions)
(Wm <- MDA231$Wm) # observed minor copy number (for CNA regions)
epsilonM <- MDA231$epsilonM # standard deviation of WM, pre-fixed here
epsilonm <- MDA231$epsilonm # standard deviation of Wm, pre-fixed here
# whether CNA regions harbor specific CNAs (only needed for overlapping CNAs)
(C <- MDA231$C)
(Y <- MDA231$Y)   # whether SNAs are affected by CNAs


#######################################################
#######################################################
#######                                         #######
#######               MCMC sampling             #######
#######                                         #######
#######################################################
#######################################################
# Candidate numbers of subclones, and how many randomly initiated chains to
# run for each candidate.
K <- 3:5
numchain <- 15
sampchain <- canopy.sample(
  R = R, X = X,
  WM = WM, Wm = Wm,
  epsilonM = epsilonM, epsilonm = epsilonm,
  C = C, Y = Y,
  K = K, numchain = numchain,
  max.simrun = 100000, min.simrun = 20000, writeskip = 200,
  projectname = projectname,
  cell.line = TRUE,
  plot.likelihood = TRUE
)
# Save the whole workspace (sampchain included) so the post-processing steps
# can reload it instead of resampling.
save.image(file = paste0(projectname, '_postmcmc_image.rda'), compress = 'xz')


#######################################################
#######################################################
#######                                         #######
#######   BIC to determine number of subclones  #######
#######                                         #######
#######################################################
#######################################################
library(Canopy)
projectname <- 'MDA231'
# Restore the workspace saved right after MCMC sampling.
load(paste0(projectname, '_postmcmc_image.rda'))
burnin <- 100
thin <- 5
# With pdf = TRUE, canopy.BIC also generates a pdf.
bic <- canopy.BIC(
  sampchain = sampchain, projectname = projectname, K = K,
  numchain = numchain, burnin = burnin, thin = thin,
  pdf = TRUE
)
# Carry forward the number of subclones with the largest BIC.
optK <- K[which.max(bic)]


#######################################################
#######################################################
#######                                         #######
#######         posterior tree evaluation       #######
#######                                         #######
#######################################################
#######################################################
post <- canopy.post(
  sampchain = sampchain, projectname = projectname, K = K,
  numchain = numchain, burnin = burnin, thin = thin,
  optK = optK, C = C, post.config.cutoff = 0.05
)
# Unpack the four components returned by canopy.post.
samptreethin     <- post[[1]] # every tree kept after burn-in and thinning
samptreethin.lik <- post[[2]] # likelihood of each tree in samptreethin
config           <- post[[3]]
config.summary   <- post[[4]]
print(config.summary)
# Columns of config.summary:
#   1: tree configuration
#   2: posterior probability of the configuration over the entire tree space
#   3: posterior likelihood of the configuration within its subtree space
# If the posterior probabilities show no obvious mode, sample for longer.


#######################################################
#######################################################
#######                                         #######
#######          Tree output and plot           #######
#######                                         #######
#######################################################
#######################################################
# Configuration whose posterior likelihood (column 3 of config.summary) is
# largest; its label sits in column 1.
config.i <- config.summary[which.max(config.summary[, 3]), 1]
cat('Configuration', config.i, 'has the highest posterior likelihood.\n')
output.tree <- canopy.output(post, config.i, C)
pdf.name <- paste0(projectname, '_config_highest_likelihood.pdf')
canopy.plottree(output.tree, pdf = TRUE, pdf.name = pdf.name)

# Also plot the posterior tree under configuration 3.
output.tree <- canopy.output(post, 3, C)
canopy.plottree(output.tree, pdf = TRUE,
                pdf.name = paste0(projectname, '_third_config.pdf'))



#######################################################
#######################################################
#######                                         #######
#######            Toy: try it yourself         #######
#######                                         #######
#######################################################
#######################################################
# Load the toy dataset and unpack its inputs one object per line.
library(Canopy)
data(toy)
projectname <- 'toy'
R <- toy$R
X <- toy$X
WM <- toy$WM
Wm <- toy$Wm
epsilonM <- toy$epsilonM
epsilonm <- toy$epsilonm
Y <- toy$Y

# Candidate subclone numbers and chains per candidate; C is not supplied
# (NULL) for this dataset.
K <- 3:5
numchain <- 15
sampchain <- canopy.sample(
  R = R, X = X,
  WM = WM, Wm = Wm,
  epsilonM = epsilonM, epsilonm = epsilonm,
  C = NULL, Y = Y,
  K = K, numchain = numchain,
  max.simrun = 100000, min.simrun = 10000, writeskip = 200,
  projectname = projectname,
  cell.line = FALSE,
  plot.likelihood = TRUE
)
save.image(file = paste0(projectname, '_postmcmc_image.rda'), compress = 'xz')


#######################################################
#######################################################
#######                                         #######
#######   BIC to determine number of subclones  #######
#######                                         #######
#######################################################
#######################################################
library(Canopy)
data(toy)
projectname <- 'toy'
# Restore the workspace saved right after MCMC sampling.
load(paste0(projectname, '_postmcmc_image.rda'))
burnin <- 10
thin <- 5
# With pdf = TRUE, canopy.BIC also generates a pdf.
bic <- canopy.BIC(
  sampchain = sampchain, projectname = projectname, K = K,
  numchain = numchain, burnin = burnin, thin = thin,
  pdf = TRUE
)
# Carry forward the number of subclones with the largest BIC.
optK <- K[which.max(bic)]


#######################################################
#######################################################
#######                                         #######
#######         posterior tree evaluation       #######
#######                                         #######
#######################################################
#######################################################
post <- canopy.post(
  sampchain = sampchain, projectname = projectname, K = K,
  numchain = numchain, burnin = burnin, thin = thin,
  optK = optK, post.config.cutoff = 0.05
)
# Unpack the four components returned by canopy.post.
samptreethin     <- post[[1]] # every tree kept after burn-in and thinning
samptreethin.lik <- post[[2]] # likelihood of each tree in samptreethin
config           <- post[[3]]
config.summary   <- post[[4]]
print(config.summary)
# Columns of config.summary:
#   1: tree configuration
#   2: posterior probability of the configuration over the entire tree space
#   3: posterior likelihood of the configuration within its subtree space
# If the posterior probabilities show no obvious mode, sample for longer.


#######################################################
#######################################################
#######                                         #######
#######          Tree output and plot           #######
#######                                         #######
#######################################################
#######################################################
# Configuration whose posterior likelihood (column 3 of config.summary) is
# largest; its label sits in column 1.
config.i <- config.summary[which.max(config.summary[, 3]), 1]
cat('Configuration', config.i, 'has the highest posterior likelihood.\n')
output.tree <- canopy.output(post, config.i, C = NULL)
pdf.name <- paste0(projectname, '_config_highest_likelihood.pdf')
# Write the tree to a pdf first, then plot it again without a pdf.
canopy.plottree(output.tree, pdf = TRUE, pdf.name = pdf.name)
canopy.plottree(output.tree, pdf = FALSE)



#######################################################
#######################################################
#######                                         #######
#######      Toy2 dataset: try it yourself      #######
#######                                         #######
#######################################################
#######################################################
# Load the toy2 dataset and unpack its inputs one object per line.
library(Canopy)
data(toy2)
projectname <- 'toy2'
R <- toy2$R
X <- toy2$X
WM <- toy2$WM
Wm <- toy2$Wm
epsilonM <- toy2$epsilonM
epsilonm <- toy2$epsilonm
Y <- toy2$Y
true.tree <- toy2$true.tree # true underlying tree

# Candidate subclone numbers and chains per candidate; C is not supplied
# (NULL) for this dataset.
K <- 3:6
numchain <- 15
sampchain <- canopy.sample(
  R = R, X = X,
  WM = WM, Wm = Wm,
  epsilonM = epsilonM, epsilonm = epsilonm,
  C = NULL, Y = Y,
  K = K, numchain = numchain,
  max.simrun = 100000, min.simrun = 10000, writeskip = 200,
  projectname = projectname,
  cell.line = FALSE,
  plot.likelihood = TRUE
)
save.image(file = paste0(projectname, '_postmcmc_image.rda'), compress = 'xz')


#######################################################
#######################################################
#######                                         #######
#######   BIC to determine number of subclones  #######
#######                                         #######
#######################################################
#######################################################
library(Canopy)
data(toy2)
projectname <- 'toy2'
# Restore the workspace saved right after MCMC sampling.
load(paste0(projectname, '_postmcmc_image.rda'))
burnin <- 20
thin <- 5
# With pdf = TRUE, canopy.BIC also generates a pdf.
bic <- canopy.BIC(
  sampchain = sampchain, projectname = projectname, K = K,
  numchain = numchain, burnin = burnin, thin = thin,
  pdf = TRUE
)
# Carry forward the number of subclones with the largest BIC.
optK <- K[which.max(bic)]


#######################################################
#######################################################
#######                                         #######
#######         posterior tree evaluation       #######
#######                                         #######
#######################################################
#######################################################
post <- canopy.post(
  sampchain = sampchain, projectname = projectname, K = K,
  numchain = numchain, burnin = burnin, thin = thin,
  optK = optK, post.config.cutoff = 0.05
)
# Unpack the four components returned by canopy.post.
samptreethin     <- post[[1]] # every tree kept after burn-in and thinning
samptreethin.lik <- post[[2]] # likelihood of each tree in samptreethin
config           <- post[[3]]
config.summary   <- post[[4]]
print(config.summary)
# Columns of config.summary:
#   1: tree configuration
#   2: posterior probability of the configuration over the entire tree space
#   3: posterior likelihood of the configuration within its subtree space
# If the posterior probabilities show no obvious mode, sample for longer.


#######################################################
#######################################################
#######                                         #######
#######          Tree output and plot           #######
#######                                         #######
#######################################################
#######################################################
# Configuration whose posterior likelihood (column 3 of config.summary) is
# largest; its label sits in column 1.
config.i <- config.summary[which.max(config.summary[, 3]), 1]
cat('Configuration', config.i, 'has the highest posterior likelihood.\n')
output.tree <- canopy.output(post, config.i, C = NULL)
pdf.name <- paste0(projectname, '_config_highest_likelihood.pdf')
# Write the tree to a pdf first, then plot it again without a pdf.
canopy.plottree(output.tree, pdf = TRUE, pdf.name = pdf.name)
canopy.plottree(output.tree, pdf = FALSE)



#######################################################
#######################################################
#######                                         #######
#######               SNA clustering            #######
#######                                         #######
#######################################################
#######################################################
library(Canopy)
data(toy3)
R <- toy3$R
X <- toy3$X
dim(R); dim(X)
num_cluster <- 2:9 # Range of number of clusters to run
num_run <- 10 # How many EM runs per clustering step for each mutation cluster wave
# Stored under a name that does not shadow the canopy.cluster() function.
cluster_fit <- canopy.cluster(R = R,
                              X = X,
                              num_cluster = num_cluster,
                              num_run = num_run)

# BIC to determine the optimal number of mutation clusters; the dashed line
# marks the number of clusters with the largest BIC.
bic_output <- cluster_fit$bic_output
plot(num_cluster, bic_output, xlab = 'Number of mutation clusters', ylab = 'BIC',
     type = 'b', main = 'BIC for model selection')
abline(v = num_cluster[which.max(bic_output)], lty = 2)
# Visualization of clustering result
Mu <- cluster_fit$Mu # VAF centroid for each cluster
Tau <- cluster_fit$Tau # Prior for mutation cluster, with a K+1 component
sna_cluster <- cluster_fit$sna_cluster # cluster identity for each mutation
# One colour and one plotting symbol per cluster label.
colc <- c('green4', 'red3', 'royalblue1', 'darkorange1', 'royalblue4',
          'mediumvioletred', 'seagreen4', 'olivedrab4', 'steelblue4', 'lavenderblush4')
pchc <- c(17, 0, 1, 15, 3, 16, 4, 8, 2, 16)
# Variant allele frequencies (R / X), computed once and reused for both
# scatter plots and their shared axis limits.
vaf <- R / X
vaf_lim <- c(0, max(vaf))
plot(vaf[, 1], vaf[, 2], xlab = 'Sample1 VAF', ylab = 'Sample2 VAF',
     col = colc[sna_cluster], pch = pchc[sna_cluster],
     ylim = vaf_lim, xlim = vaf_lim)
library(scatterplot3d)
scatterplot3d(vaf[, 1], vaf[, 2], vaf[, 3],
              xlim = vaf_lim, ylim = vaf_lim, zlim = vaf_lim,
              color = colc[sna_cluster], pch = pchc[sna_cluster],
              xlab = 'Sample1 VAF', ylab = 'Sample2 VAF', zlab = 'Sample3 VAF')


#######################################################
#######################################################
#######                                         #######
#######               MCMC sampling             #######
#######                                         #######
#######################################################
#######################################################
projectname <- 'toy3'
# Candidate numbers of subclones, and how many randomly initiated chains to
# run for each candidate.
K <- 3:5
numchain <- 15
# Sampling driven by the SNA cluster labels from the clustering step; no CNA
# inputs are passed.
sampchain <- canopy.sample.cluster.nocna(
  R = R, X = X, sna_cluster = sna_cluster,
  K = K, numchain = numchain,
  max.simrun = 100000, min.simrun = 20000, writeskip = 200,
  projectname = projectname,
  cell.line = FALSE,
  plot.likelihood = TRUE
)
save.image(file = paste0(projectname, '_postmcmc_image.rda'), compress = 'xz')


#######################################################
#######################################################
#######                                         #######
#######   BIC to determine number of subclones  #######
#######                                         #######
#######################################################
#######################################################
library(Canopy)
projectname <- 'toy3'
# Restore the workspace saved right after MCMC sampling.
load(paste0(projectname, '_postmcmc_image.rda'))
burnin <- 50
thin <- 10
# With pdf = TRUE, canopy.BIC also generates a pdf.
bic <- canopy.BIC(
  sampchain = sampchain, projectname = projectname, K = K,
  numchain = numchain, burnin = burnin, thin = thin,
  pdf = TRUE
)
# Carry forward the number of subclones with the largest BIC.
optK <- K[which.max(bic)]


#######################################################
#######################################################
#######                                         #######
#######         posterior tree evaluation       #######
#######                                         #######
#######################################################
#######################################################
post <- canopy.post(
  sampchain = sampchain, projectname = projectname, K = K,
  numchain = numchain, burnin = burnin, thin = thin,
  optK = optK, C = NULL, post.config.cutoff = 0.05
)
# Unpack the four components returned by canopy.post.
samptreethin     <- post[[1]] # every tree kept after burn-in and thinning
samptreethin.lik <- post[[2]] # likelihood of each tree in samptreethin
config           <- post[[3]]
config.summary   <- post[[4]]
print(config.summary)
# Columns of config.summary:
#   1: tree configuration
#   2: posterior probability of the configuration over the entire tree space
#   3: posterior likelihood of the configuration within its subtree space
# If the posterior probabilities show no obvious mode, sample for longer.


#######################################################
#######################################################
#######                                         #######
#######          Tree output and plot           #######
#######                                         #######
#######################################################
#######################################################
# Configuration whose posterior likelihood (column 3 of config.summary) is
# largest; its label sits in column 1.
config.i <- config.summary[which.max(config.summary[, 3]), 1]
cat('Configuration', config.i, 'has the highest posterior likelihood.\n')
output.tree <- canopy.output(post, config.i, C = NULL)
pdf.name <- paste0(projectname, '_config_highest_likelihood.pdf')
# Write the tree to a pdf and also write a text file named <project>_mut.txt.
canopy.plottree(output.tree,
                pdf = TRUE, pdf.name = pdf.name,
                txt = TRUE, txt.name = paste0(projectname, '_mut.txt'))
# Compare the estimated P and Z with the truth stored in toy3. Z's columns
# are permuted because clone order can differ between estimate and truth.
output.tree$P
toy3$realP
output.tree$Z[, c(1, 2, 4, 3)] == toy3$realZ



# Summarise a FALCON segmentation as a table, with bootstrap standard
# deviations for segments that carry a haplotype estimate.
#
# Args:
#   readMatrix: allelic read counts, one row per SNP, with columns 'AN', 'BN'
#     (normal) and 'AT', 'BT' (tumor). A data.frame or a numeric matrix both
#     work, because columns are accessed by name with `[`.
#   tauhat: change-point SNP indices; segment s spans SNPs
#     c(1, tauhat)[s] .. c(tauhat, nrow(readMatrix))[s].
#   cn: list with `ascn` (one column per segment; its first two rows become
#     'Minor_copy' and 'Major_copy') and `Haplotype` (per-segment vector with
#     one label per SNP; 'A' means the A allele is counted on chromosome 1).
#     Empty entries mark segments that are not bootstrapped.
#   st_bp, end_bp: per-SNP base-pair positions, indexed by the segment start /
#     end SNP.
#   nboot: bootstrap resamples per segment (NULL means 10000).
#
# Returns:
#   Numeric matrix with columns st_snp, end_snp, st_bp, end_bp, Minor_copy,
#   Major_copy, Minor.sd, Major.sd. The sd columns stay NA for segments without
#   a haplotype or without any informative SNP.
falcon.output <- function(readMatrix, tauhat, cn, st_bp, end_bp, nboot = NULL) {
  if (is.null(nboot)) {
    nboot <- 10000
  }

  st_snp <- c(1, tauhat)
  end_snp <- c(tauhat, nrow(readMatrix))
  st_bp <- st_bp[st_snp]
  end_bp <- end_bp[end_snp]
  output <- cbind(st_snp, end_snp, st_bp, end_bp, round(t(cn$ascn), 3))
  colnames(output)[5:6] <- c('Minor_copy', 'Major_copy')

  Major.sd <- Minor.sd <- rep(NA, nrow(output))
  output <- cbind(output, Minor.sd, Major.sd)

  # Tumor/normal depth ratio. It depends only on readMatrix, so it is computed
  # once instead of once per segment.
  rdep <- median(readMatrix[, 'AT'] + readMatrix[, 'BT']) /
    median(readMatrix[, 'AN'] + readMatrix[, 'BN'])

  # Only segments that have a Haplotype entry can be bootstrapped.
  n_seg <- min(nrow(output), length(cn$Haplotype))
  for (seg in seq_len(n_seg)) {
    haplo.temp <- cn$Haplotype[[seg]]
    if (length(haplo.temp) == 0) next
    cat('Running bootstrap for segment', seg, '...\n')
    temp <- readMatrix[output[seg, 1]:output[seg, 2], , drop = FALSE]

    # Orient the counts so that chromosome 1 carries the allele named by the
    # haplotype: A for 'A', otherwise B.
    rows <- seq_along(haplo.temp)
    on_a <- haplo.temp == 'A'
    t.cn1 <- t.cn2 <- n.cn1 <- n.cn2 <- rep(NA, nrow(temp))
    t.cn1[rows] <- ifelse(on_a, temp[rows, 'AT'], temp[rows, 'BT'])
    t.cn2[rows] <- ifelse(on_a, temp[rows, 'BT'], temp[rows, 'AT'])
    n.cn1[rows] <- ifelse(on_a, temp[rows, 'AN'], temp[rows, 'BN'])
    n.cn2[rows] <- ifelse(on_a, temp[rows, 'BN'], temp[rows, 'AN'])

    t.cn1 <- t.cn1 / rdep
    t.cn2 <- t.cn2 / rdep

    informative <- !(is.na(t.cn1) | is.na(t.cn2) | is.na(n.cn1) | is.na(n.cn2))
    t.cn1 <- t.cn1[informative]
    t.cn2 <- t.cn2[informative]
    n.cn1 <- n.cn1[informative]
    n.cn2 <- n.cn2[informative]
    n_snp <- length(t.cn1)
    # Nothing to resample; leave this segment's sd columns as NA.
    if (n_snp == 0) next

    # Resample SNPs with replacement and record the pooled tumor/normal ratio
    # of each chromosome.
    cn1.boot <- cn2.boot <- numeric(nboot)
    for (b in seq_len(nboot)) {
      draw <- sample.int(n_snp, replace = TRUE)
      cn1.boot[b] <- sum(t.cn1[draw]) / sum(n.cn1[draw])
      cn2.boot[b] <- sum(t.cn2[draw]) / sum(n.cn2[draw])
    }
    output[seg, 'Major.sd'] <- round(sd(cn1.boot), 4)
    output[seg, 'Minor.sd'] <- round(sd(cn2.boot), 4)
  }
  output
}


# Curate a FALCON segmentation: drop weak change-points and short segments,
# re-estimating copy numbers with falcon's getASCN after each pass.
#
# Args:
#   readMatrix: allelic read counts with columns 'AN', 'BN', 'AT', 'BT'.
#   tauhat: change-point SNP indices.
#   cn: getASCN result for `tauhat`; `cn$ascn` has one column per segment.
#   st_bp, end_bp: per-SNP base-pair positions.
#   rdep: tumor/normal depth ratio; NULL computes it from readMatrix medians.
#   length.thres: minimum segment length in bp (NULL means 1e6).
#   delta.cn.thres: a change-point is kept only if the adjacent segments
#     differ by more than this in some copy-number row (NULL means 0.3).
#     This threshold is used in both filtering passes.
#
# Returns:
#   list(tauhat = curated change-points, cn = getASCN result for them).
falcon.qc <- function(readMatrix, tauhat, cn, st_bp, end_bp, rdep = NULL,
                      length.thres = NULL, delta.cn.thres = NULL) {
  if (is.null(length.thres)) {
    length.thres <- 10^6
  }
  if (is.null(delta.cn.thres)) {
    delta.cn.thres <- 0.3
  }
  if (is.null(rdep)) {
    rdep <- median(readMatrix[, 'AT'] + readMatrix[, 'BT']) /
      median(readMatrix[, 'AN'] + readMatrix[, 'BN'])
  }

  # TRUE for each change-point i whose neighbouring segments (columns i and
  # i + 1 of ascn) differ by more than `thres`. It is empty for empty tauhat.
  keep_changepoints <- function(tauhat, ascn, thres) {
    vapply(seq_along(tauhat), function(i) {
      max(abs(ascn[, i + 1] - ascn[, i])) > thres
    }, logical(1))
  }

  # Pass 1: drop change-points with negligible copy-number change.
  tauhat <- tauhat[keep_changepoints(tauhat, cn$ascn, delta.cn.thres)]
  cn <- getASCN(readMatrix, tauhat = tauhat, rdep = rdep)

  # Pass 2: drop segments shorter than length.thres base pairs.
  st_snp <- c(1, tauhat)
  end_snp <- c(tauhat, nrow(readMatrix))
  st_bp <- st_bp[st_snp]
  end_bp <- end_bp[end_snp]
  output <- cbind(st_snp, end_snp, st_bp, end_bp, t(cn$ascn))
  seg_len <- output[, 'end_bp'] - output[, 'st_bp'] + 1
  output <- output[seg_len >= length.thres, , drop = FALSE]
  # The new change-points are the start SNPs of the retained segments, minus
  # the chromosome ends. A dropped segment is therefore absorbed by the
  # retained segment on its left.
  tauhat <- setdiff(output[, 'st_snp'], c(1, nrow(readMatrix)))
  cn <- getASCN(readMatrix, tauhat = tauhat, rdep = rdep)

  # Pass 3: merging may have made neighbours similar again, so re-filter.
  if (nrow(output) > 1) {
    tauhat <- tauhat[keep_changepoints(tauhat, cn$ascn, delta.cn.thres)]
    # NOTE(review): getASCN's own `threshold` argument is kept at 0.3; it is
    # presumably distinct from delta.cn.thres -- confirm in ?getASCN.
    cn <- getASCN(readMatrix, tauhat = tauhat, rdep = rdep, threshold = 0.3)
  }
  list(tauhat = tauhat, cn = cn)
}


# Curate a FALCON-X segmentation: drop weak change-points and short segments,
# re-estimating copy numbers with falconx's getASCN.x after each pass.
#
# Args:
#   readMatrix: allelic read counts (columns 'AN', 'BN', 'AT', 'BT').
#   biasMatrix: coverage-bias matrix passed through to getASCN.x.
#   tauhat: change-point SNP indices.
#   cn: getASCN.x result for `tauhat`; `cn$ascn` has one column per segment.
#   st_bp, end_bp: per-SNP base-pair positions.
#   length.thres: minimum segment length in bp (NULL means 1e6).
#   delta.cn.thres: a change-point is kept only if the adjacent segments
#     differ by more than this in some copy-number row (NULL means 0.3).
#     This threshold is used in both filtering passes.
#
# Returns:
#   list(tauhat = curated change-points, cn = getASCN.x result for them).
falconx.qc <- function(readMatrix, biasMatrix, tauhat, cn, st_bp, end_bp,
                       length.thres = NULL, delta.cn.thres = NULL) {
  if (is.null(length.thres)) {
    length.thres <- 10^6
  }
  if (is.null(delta.cn.thres)) {
    delta.cn.thres <- 0.3
  }

  # TRUE for each change-point i whose neighbouring segments (columns i and
  # i + 1 of ascn) differ by more than `thres`. It is empty for empty tauhat.
  keep_changepoints <- function(tauhat, ascn, thres) {
    vapply(seq_along(tauhat), function(i) {
      max(abs(ascn[, i + 1] - ascn[, i])) > thres
    }, logical(1))
  }

  # Pass 1: drop change-points with negligible copy-number change.
  tauhat <- tauhat[keep_changepoints(tauhat, cn$ascn, delta.cn.thres)]
  cn <- getASCN.x(readMatrix, biasMatrix, tauhat = tauhat)

  # Pass 2: drop segments shorter than length.thres base pairs.
  st_snp <- c(1, tauhat)
  end_snp <- c(tauhat, nrow(readMatrix))
  st_bp <- st_bp[st_snp]
  end_bp <- end_bp[end_snp]
  output <- cbind(st_snp, end_snp, st_bp, end_bp, t(cn$ascn))
  seg_len <- output[, 'end_bp'] - output[, 'st_bp'] + 1
  output <- output[seg_len >= length.thres, , drop = FALSE]
  # The new change-points are the start SNPs of the retained segments, minus
  # the chromosome ends. A dropped segment is therefore absorbed by the
  # retained segment on its left.
  tauhat <- setdiff(output[, 'st_snp'], c(1, nrow(readMatrix)))
  cn <- getASCN.x(readMatrix, biasMatrix, tauhat = tauhat)

  # Pass 3: merging may have made neighbours similar again, so re-filter.
  if (nrow(output) > 1) {
    tauhat <- tauhat[keep_changepoints(tauhat, cn$ascn, delta.cn.thres)]
    # NOTE(review): getASCN.x's own `threshold` argument is kept at 0.3; it is
    # presumably distinct from delta.cn.thres -- confirm in ?getASCN.x.
    cn <- getASCN.x(readMatrix, biasMatrix, tauhat = tauhat, threshold = 0.3)
  }
  list(tauhat = tauhat, cn = cn)
}


library(CODEX)
library(falconx)
#####################################################################################
###  Below is a demo dataset of 39 tumor-normal pairs profiled by whole-exome
###  sequencing, published in Maxwell et al. (Nature Communications, 2017).
###  https://www.nature.com/articles/s41467-017-00388-9 
###  We focus on chr17, where copy-neutral loss-of-heterozygosity has been reported.
#####################################################################################


#####################################################################################
# Below are allelic reads, genotype, and genomic locations, which can be extracted
# from vcf files.
# rda files available for download at
# https://github.com/yuchaojiang/Canopy/tree/master/instruction
#####################################################################################

chr <- 17
load('mymatrix.demo.rda')
load('genotype.demo.rda')
load('reads.demo.rda')

head(mymatrix) # genomic locations and SNP info across all loci from chr17
head(genotype[, 1:12]) # genotype in blood (bloodGT1 and bloodGT2) and tumor (tumorGT1 and tumorGT2)
# across all samples
head(reads[, 1:12]) # allelic reads in blood (AN and BN) and tumor (AT and BT) across all samples

#####################################################################################
# Apply CODEX/CODEX2 to get total coverage bias
#####################################################################################

# Get GC content from a 50bp window centered at the SNP
pos <- as.numeric(mymatrix[, 'POS'])
ref <- IRanges(start = pos - 25, end = pos + 25)
gc <- getgc(chr, ref)

# Total read depth: reads stores allele counts in adjacent column pairs
# (AN/BN, AT/BT, ...), so column j of Y sums reads columns 2j-1 and 2j.
# Odd columns of Y are therefore blood (normal) totals, even columns tumor.
Y <- matrix(nrow = nrow(reads), ncol = ncol(reads) / 2)
for (j in seq_len(ncol(Y))) {
  Y[, j] <- reads[, 2 * j - 1] + reads[, 2 * j]
}

# QC procedure: keep loci whose median total depth across all columns is >= 20.
# drop = FALSE keeps matrices/data frames two-dimensional even if one row remains.
pos.filter <- apply(Y, 1, median) >= 20
pos <- pos[pos.filter]
ref <- ref[pos.filter]
gc <- gc[pos.filter]
genotype <- genotype[pos.filter, , drop = FALSE]
reads <- reads[pos.filter, , drop = FALSE]
mymatrix <- mymatrix[pos.filter, , drop = FALSE]
Y <- Y[pos.filter, , drop = FALSE]

# normalization; normal samples are the odd columns of Y (1, 3, ..., 77 for
# the 39-sample demo), derived from Y instead of hard-coded.
normal_index <- seq(1, ncol(Y), by = 2)
normObj <- normalize2(Y, gc, K = 1:3, normal_index = normal_index)
choiceofK(normObj$AIC, normObj$BIC, normObj$RSS, K = 1:3,
          filename = paste0('choiceofK.', chr, '.pdf'))
cat(paste0('BIC is maximized at ', which.max(normObj$BIC), '.\n'))
# NOTE(review): the K = 3 fit is used regardless of the BIC maximiser printed above.
Yhat <- round(normObj$Yhat[[3]], 0)

dim(mymatrix) # raw vcf read.table, including genomic locations
dim(reads)    # allelic reads
dim(genotype) # genotype
dim(Yhat)     # total coverage bias returned by CODEX



#################################################################################
# Generate input for FALCON-X: allelic read depth and genotype across all loci
#################################################################################

n <- 39 # total number of samples
for (i in seq_len(n)) {
  cat('Generating input for sample',i,'...\n')
  # Sample i owns 4 genotype columns and 4 read columns (AN, BN, AT, BT) at the
  # same indices, and 2 columns (normal, tumor) of the CODEX bias Yhat.
  col_ids <- (4 * i - 3):(4 * i)
  bias_ids <- (2 * i - 1):(2 * i)
  mydata <- as.data.frame(cbind(mymatrix[, 1:2], genotype[, col_ids],
                                reads[, col_ids], Yhat[, bias_ids]))
  colnames(mydata) <- c("chr", "pos", "bloodGT1", "bloodGT2", "tumorGT1",
                        "tumorGT2", "AN", "BN", "AT", "BT", 'sN', 'sT')
  # Germline heterozygous loci: the two blood alleles differ.
  het_rows <- which(as.numeric(mydata[, 3]) != as.numeric(mydata[, 4]))
  newdata0 <- mydata[het_rows, ]
  # Drop loci with any missing field.
  index.na <- apply(is.na(newdata0), 1, any)
  newdata <- newdata0[!index.na, ]

  # Remove loci with multiple alternative alleles: keep a locus only if both
  # tumor alleles are among its two blood alleles. seq_len() makes this a
  # no-op, rather than an error, when no heterozygous loci remain.
  blood1 <- as.numeric(newdata[, 'bloodGT1'])
  blood2 <- as.numeric(newdata[, 'bloodGT2'])
  tumor1 <- as.numeric(newdata[, 'tumorGT1'])
  tumor2 <- as.numeric(newdata[, 'tumorGT2'])
  mul.alt.filter <- vapply(seq_len(nrow(newdata)), function(s) {
    blood <- c(blood1[s], blood2[s])
    is.element(tumor1[s], blood) && is.element(tumor2[s], blood)
  }, logical(1))
  newdata <- newdata[mul.alt.filter, ]

  # write text at germline heterozygous loci, which is used as input for Falcon-X
  write.table(newdata, file = paste0("sample", i, "_het.txt"),
              quote = FALSE, row.names = FALSE)
}


#################################################################################
# Apply FALCON-X to generate allele-specific copy number profiles
#################################################################################

# CODEX normalizes total read depth across samples;
# FALCON-X profiles ASCN in each sample separately.
k <- 10 # calling ASCN for the 10th sample
# `header` is spelled out rather than relying on partial matching of `head`.
ascn.input <- read.table(paste0("sample", k, "_het.txt"), header = TRUE)
readMatrix <- ascn.input[, c('AN', 'BN', 'AT', 'BT')]
biasMatrix <- ascn.input[, c('sN', 'sT')]

tauhat <- getChangepoints.x(readMatrix, biasMatrix, pos = ascn.input$pos)
cn <- getASCN.x(readMatrix, biasMatrix, tauhat = tauhat, pos = ascn.input$pos, threshold = 0.3)
# cn$tauhat would give the indices of change-points.
# cn$ascn would give the estimated allele-specific copy numbers for each segment.
# cn$Haplotype[[i]] would give the estimated haplotype for the major chromosome in segment i
# if this segment has different copy numbers on the two homologous chromosomes.
view(cn, pos = ascn.input$pos)

# Further curate Falcon-X's segmentation:
# Remove small segments based on genomic locations and combine consecutive segments with similar ASCN profiles
if (length(tauhat) > 0) {
  length.thres <- 10^6 # Threshold for length of segments, in base pair.
  delta.cn.thres <- 0.3 # Threshold of absolute copy number difference between consecutive segments.
  # NOTE(review): sourcing this file presumably (re)defines falconx.qc,
  # replacing any falconx.qc already defined in this session.
  source('falconx.qc.R') # Can be downloaded from
  # https://github.com/yuchaojiang/Canopy/tree/master/instruction
  falcon.qc.list <- falconx.qc(readMatrix = readMatrix,
                               biasMatrix = biasMatrix,
                               tauhat = tauhat,
                               cn = cn,
                               st_bp = ascn.input$pos,
                               end_bp = ascn.input$pos,
                               length.thres = length.thres,
                               delta.cn.thres = delta.cn.thres)

  tauhat <- falcon.qc.list$tauhat
  cn <- falcon.qc.list$cn
}

view(cn, pos = ascn.input$pos)





library(falcon)

# This is a demo dataset from relapse genome of neuroblastoma with matched normal
# from Eleveld et al. (Nature Genetics 2015).

# Falcon takes as input germline heterozygous variants, which can be called by 
# GATK or VarScan2.

# The rda file can be downloaded at:
# https://www.dropbox.com/s/jbgjy4ne82hw5np/preprocessed.rda?dl=0
# It provides the objects `primary` and `relapse` used below.
load('preprocessed.rda')

# calculate depth ratio (total read counts of tumor versus normal)
rdep.relapse=sum(relapse$Tumor_ReadCount_Total)/sum(relapse$Normal_ReadCount_Total)
rdep.primary=sum(primary$Tumor_ReadCount_Total)/sum(primary$Normal_ReadCount_Total)

# Falcon processes each chromosome separately and here we only show demonstration
# on a few chromosomes, for example, chr 14 where a copy-neutral loss of heterozygosity 
# has been previously reported.
#
# For each chromosome, the loop below handles the relapse and then the primary tumor
# in the same way. It filters to germline heterozygous SNVs, runs FALCON, plots and
# saves the raw segmentation, applies falcon.qc(), and writes a bootstrap-annotated
# segment table.

for(chr in c(4,7,11,14,17,20)){
  cat(chr)
  # The full tables are reloaded every iteration because they are removed right
  # after subsetting below. This is presumably so they are not carried into the
  # save.image() snapshots taken later in the loop.
  load('preprocessed.rda')
  library(falcon)
  # Keep only this chromosome's variants.
  primary.chr=primary[which(primary[,'Chromosome']==chr),]
  relapse.chr=relapse[which(relapse[,'Chromosome']==chr),]
  rm(primary);rm(relapse)
  
  
  ###########################################
  ###########################################
  #
  #        Relapse genome
  #
  ###########################################
  ###########################################  
  
  
  ###########################################
  # Focus on germline heterozygous variants.
  ###########################################
  
  # remove variants with missing genotype (allele fields equal to a single space ' ')
  relapse.chr=relapse.chr[relapse.chr[,'Match_Norm_Seq_Allele1']!=' ',]
  relapse.chr=relapse.chr[relapse.chr[,'Match_Norm_Seq_Allele2']!=' ',]
  relapse.chr=relapse.chr[relapse.chr[,'Reference_Allele']!=' ',]
  relapse.chr=relapse.chr[relapse.chr[,'TumorSeq_Allele1']!=' ',]
  relapse.chr=relapse.chr[relapse.chr[,'TumorSeq_Allele2']!=' ',]
  
  # get germline heterozygous loci (normal allele1 != normal allele2)
  relapse.chr=relapse.chr[(as.matrix(relapse.chr[,'Match_Norm_Seq_Allele1'])!=as.matrix(relapse.chr[,'Match_Norm_Seq_Allele2'])),]
  
  
  ############################################################
  # QC procedures to remove false neg and false pos variants.
  # The thresholds can be adjusted.
  ############################################################
  
  # remove indels (this can be relaxed but we think indels are harder to call than SNPs):
  # keep only loci whose allele fields are all at most one character long
  indel.filter1=nchar(as.matrix(relapse.chr[,'Reference_Allele']))<=1
  indel.filter2=nchar(as.matrix(relapse.chr[,'Match_Norm_Seq_Allele1']))<=1
  indel.filter3=nchar(as.matrix(relapse.chr[,'Match_Norm_Seq_Allele2']))<=1
  indel.filter4=nchar(as.matrix(relapse.chr[,'TumorSeq_Allele1']))<=1
  indel.filter5=nchar(as.matrix(relapse.chr[,'TumorSeq_Allele2']))<=1
  relapse.chr=relapse.chr[indel.filter1 & indel.filter2 & indel.filter3 & indel.filter4 & indel.filter5,]
  
  # total number of reads (ref + alt) at least 30 in both tumor and normal
  depth.filter1=(relapse.chr[,"Normal_ReadCount_Ref"]+relapse.chr[,"Normal_ReadCount_Alt"])>=30
  depth.filter2=(relapse.chr[,"Tumor_ReadCount_Ref"]+relapse.chr[,"Tumor_ReadCount_Alt"])>=30
  relapse.chr=relapse.chr[depth.filter1 & depth.filter2,]
  
  
  #########################
  # Generate FALCON input.
  #########################
  
  # Data frame with four columns: tumor ref, tumor alt, normal ref, normal alt,
  # renamed to FALCON's AT, BT, AN, BN.
  readMatrix.relapse=as.data.frame(relapse.chr[,c('Tumor_ReadCount_Ref',
                                                  'Tumor_ReadCount_Alt',
                                                  'Normal_ReadCount_Ref',
                                                  'Normal_ReadCount_Alt')])
  colnames(readMatrix.relapse)=c('AT','BT','AN','BN')
  # NOTE(review): inside a for loop these values are not auto-printed, so this line has no effect.
  dim(readMatrix.relapse); dim(relapse.chr)
  
  
  ###############################
  # Run FALCON and view results.
  ###############################
  
  # Change-point detection, then allele-specific copy numbers per segment;
  # rdep is the tumor/normal total depth ratio computed before the loop.
  tauhat.relapse=getChangepoints(readMatrix.relapse)
  cn.relapse = getASCN(readMatrix.relapse, tauhat=tauhat.relapse, rdep = rdep.relapse, threshold = 0.3)
  
  # Chromosomal view of segmentation results.
  pdf(file=paste('falcon.relapse.',chr,'.pdf',sep=''),width=5,height=8)
  view(cn.relapse,pos=relapse.chr[,'Start_position'], rdep = rdep.relapse)
  dev.off()
  
  # save image file: snapshot of the whole workspace, taken before the QC step below.
  save.image(file=paste('falcon_relapse_',chr,'.rda',sep=''))
  
  
  ########################################
  # Further curate FALCON's segmentation.
  ########################################
  
  # From the pdf above, we see that:
  # (1) There are small segments that need to be removed;
  # (2) Consecutive segments with similar allelic copy number states need to be combined.
  if(length(tauhat.relapse)>0){
    length.thres=10^6  # Threshold for length of segments, in base pair.
    delta.cn.thres=0.3  # Threshold of absolute copy number difference between consecutive segments.
    source('falcon_demo/falcon.qc.R') # Can be downloaded from
    # https://github.com/yuchaojiang/Canopy/blob/master/instruction/falcon.qc.R
    falcon.qc.list = falcon.qc(readMatrix = readMatrix.relapse,
                               tauhat = tauhat.relapse,
                               cn = cn.relapse,
                               st_bp = relapse.chr[,"Start_position"],
                               end_bp = relapse.chr[,"End_position"],
                               rdep = rdep.relapse,
                               length.thres = length.thres,
                               delta.cn.thres = delta.cn.thres)
    
    tauhat.relapse=falcon.qc.list$tauhat
    cn.relapse=falcon.qc.list$cn
  }
  
  # Chromosomal view of QC'ed segmentation results.
  pdf(file=paste('falcon.relapse.qc.',chr,'.pdf',sep=''),width=5,height=8)
  view(cn.relapse,pos=relapse.chr[,'Start_position'], rdep = rdep.relapse)
  dev.off()
  
  
  #################################################
  # Generate Canopy's input with s.d. measurement.
  #################################################
  
  # This is to generate table output including genomic locations for 
  # segment boundaries.
  # For Canopy's input, we use Bootstrap-based method to estimate the
  # standard deviations for the allele-specific copy numbers.
  
  source('falcon_demo/falcon.output.R') # Can be downloaded from
  # https://github.com/yuchaojiang/Canopy/blob/master/instruction/falcon.output.R
  # NOTE(review): the result is assigned to the same name as the sourced
  # function; the function is re-sourced before each later call.
  falcon.output=falcon.output(readMatrix = readMatrix.relapse,
                              tauhat = tauhat.relapse,
                              cn = cn.relapse,
                              st_bp = relapse.chr[,"Start_position"],
                              end_bp = relapse.chr[,"End_position"],
                              nboot = 5000)
  # Prepend the chromosome as the first column.
  falcon.output = cbind(chr=rep(chr,nrow(falcon.output)), falcon.output)
  # NOTE(review): the output file name is spelled 'faclon' (sic); it is kept as-is
  # in case downstream steps read these files by name.
  write.table(falcon.output, file=paste('faclon.relapse.output.',chr,'.txt',sep=''), col.names =T, row.names = F, sep='\t', quote = F)
  

  
  ###########################################
  ###########################################
  #
  #        Primary tumor
  #
  ###########################################
  ###########################################  
  
  
  ###########################################
  # Focus on germline heterozygous variants.
  ###########################################
  
  # remove variants with missing genotype (allele fields equal to a single space ' ')
  primary.chr=primary.chr[primary.chr[,'Match_Norm_Seq_Allele1']!=' ',]
  primary.chr=primary.chr[primary.chr[,'Match_Norm_Seq_Allele2']!=' ',]
  primary.chr=primary.chr[primary.chr[,'Reference_Allele']!=' ',]
  primary.chr=primary.chr[primary.chr[,'TumorSeq_Allele1']!=' ',]
  primary.chr=primary.chr[primary.chr[,'TumorSeq_Allele2']!=' ',]
  
  # get germline heterozygous loci (normal allele1 != normal allele2)
  primary.chr=primary.chr[(as.matrix(primary.chr[,'Match_Norm_Seq_Allele1'])!=as.matrix(primary.chr[,'Match_Norm_Seq_Allele2'])),]
  
  
  ############################################################
  # QC procedures to remove false neg and false pos variants.
  # The thresholds can be adjusted.
  ############################################################
  
  # remove indels (this can be relaxed but we think indels are harder to call than SNPs):
  # keep only loci whose allele fields are all at most one character long
  indel.filter1=nchar(as.matrix(primary.chr[,'Reference_Allele']))<=1
  indel.filter2=nchar(as.matrix(primary.chr[,'Match_Norm_Seq_Allele1']))<=1
  indel.filter3=nchar(as.matrix(primary.chr[,'Match_Norm_Seq_Allele2']))<=1
  indel.filter4=nchar(as.matrix(primary.chr[,'TumorSeq_Allele1']))<=1
  indel.filter5=nchar(as.matrix(primary.chr[,'TumorSeq_Allele2']))<=1
  primary.chr=primary.chr[indel.filter1 & indel.filter2 & indel.filter3 & indel.filter4 & indel.filter5,]
  
  # total number of reads (ref + alt) at least 30 in both tumor and normal
  depth.filter1=(primary.chr[,"Normal_ReadCount_Ref"]+primary.chr[,"Normal_ReadCount_Alt"])>=30
  depth.filter2=(primary.chr[,"Tumor_ReadCount_Ref"]+primary.chr[,"Tumor_ReadCount_Alt"])>=30
  primary.chr=primary.chr[depth.filter1 & depth.filter2,]
  
  
  #########################
  # Generate FALCON input.
  #########################
  
  # Data frame with four columns: tumor ref, tumor alt, normal ref, normal alt,
  # renamed to FALCON's AT, BT, AN, BN.
  readMatrix.primary=as.data.frame(primary.chr[,c('Tumor_ReadCount_Ref',
                                                  'Tumor_ReadCount_Alt',
                                                  'Normal_ReadCount_Ref',
                                                  'Normal_ReadCount_Alt')])
  colnames(readMatrix.primary)=c('AT','BT','AN','BN')
  # NOTE(review): inside a for loop these values are not auto-printed, so this line has no effect.
  dim(readMatrix.primary); dim(primary.chr)
  
  
  ###############################
  # Run FALCON and view results.
  ###############################
  
  # Change-point detection, then allele-specific copy numbers per segment;
  # rdep is the tumor/normal total depth ratio computed before the loop.
  tauhat.primary=getChangepoints(readMatrix.primary)
  cn.primary = getASCN(readMatrix.primary, tauhat=tauhat.primary, rdep = rdep.primary, threshold = 0.3)
  
  # Chromosomal view of segmentation results.
  pdf(file=paste('falcon.primary.',chr,'.pdf',sep=''),width=5,height=8)
  view(cn.primary,pos=primary.chr[,'Start_position'], rdep = rdep.primary)
  dev.off()
  
  # save image file: snapshot of the whole workspace, taken before the QC step below.
  # The manual chr7 correction further down reloads falcon_primary_7.rda.
  save.image(file=paste('falcon_primary_',chr,'.rda',sep=''))
  
  
  ########################################
  # Further curate FALCON's segmentation.
  ########################################
  
  # From the pdf above, we see that:
  # (1) There are small segments that need to be removed;
  # (2) Consecutive segments with similar allelic copy number states need to be combined.
  if(length(tauhat.primary)>0){
    length.thres=10^6  # Threshold for length of segments, in base pair.
    delta.cn.thres=0.3  # Threshold of absolute copy number difference between consecutive segments.
    source('falcon_demo/falcon.qc.R') # Can be downloaded from
    # https://github.com/yuchaojiang/Canopy/blob/master/instruction/falcon.qc.R
    falcon.qc.list = falcon.qc(readMatrix = readMatrix.primary,
                               tauhat = tauhat.primary,
                               cn = cn.primary,
                               st_bp = primary.chr[,"Start_position"],
                               end_bp = primary.chr[,"End_position"],
                               rdep = rdep.primary,
                               length.thres = length.thres,
                               delta.cn.thres = delta.cn.thres)
    
    tauhat.primary=falcon.qc.list$tauhat
    cn.primary=falcon.qc.list$cn
  }
 
  # Chromosomal view of QC'ed segmentation results.
  pdf(file=paste('falcon.primary.qc.',chr,'.pdf',sep=''),width=5,height=8)
  view(cn.primary,pos=primary.chr[,'Start_position'], rdep = rdep.primary)
  dev.off()
  
  
  #################################################
  # Generate Canopy's input with s.d. measurement.
  #################################################
  
  # This is to generate table output including genomic locations for 
  # segment boundaries.
  # For Canopy's input, we use Bootstrap-based method to estimate the
  # standard deviations for the allele-specific copy numbers.
  
  source('falcon_demo/falcon.output.R') # Can be downloaded from
  # https://github.com/yuchaojiang/Canopy/blob/master/instruction/falcon.output.R
  # NOTE(review): the result is assigned to the same name as the sourced
  # function; the function is re-sourced before each later call.
  falcon.output=falcon.output(readMatrix = readMatrix.primary,
                              tauhat = tauhat.primary,
                              cn = cn.primary,
                              st_bp = primary.chr[,"Start_position"],
                              end_bp = primary.chr[,"End_position"],
                              nboot = 5000)
  # Prepend the chromosome as the first column.
  falcon.output = cbind(chr=rep(chr,nrow(falcon.output)), falcon.output)
  # NOTE(review): the output file name is spelled 'faclon' (sic); it is kept as-is
  # in case downstream steps read these files by name.
  write.table(falcon.output, file=paste('faclon.primary.output.',chr,'.txt',sep=''), col.names =T, row.names = F, sep='\t', quote = F)  
}




# Above we automated the QC procedure after FALCON's initial call.
# However, sometimes further manual correction / curation is needed.
# Visual inspection of the plots is thus strongly recommended.
# Below is a manual correction for chr7.

chr=7
# Restore the chr7 primary-tumor workspace saved by the loop above. The snapshot
# was taken before the automated QC, so the QC is re-applied here.
load("falcon_primary_7.rda")

tauhat.primary

if(length(tauhat.primary)>0){
  length.thres=10^6  # Threshold for length of segments, in base pair.
  delta.cn.thres=0.3  # Threshold of absolute copy number difference between consecutive segments.
  source('falcon_demo/falcon.qc.R') # Can be downloaded from
  # https://github.com/yuchaojiang/Canopy/blob/master/instruction/falcon.qc.R
  falcon.qc.list = falcon.qc(readMatrix = readMatrix.primary,
                             tauhat = tauhat.primary,
                             cn = cn.primary,
                             st_bp = primary.chr[,"Start_position"],
                             end_bp = primary.chr[,"End_position"],
                             rdep = rdep.primary,
                             length.thres = length.thres,
                             delta.cn.thres = delta.cn.thres)
  
  tauhat.primary=falcon.qc.list$tauhat
  cn.primary=falcon.qc.list$cn
}

tauhat.primary
# Add a manually chosen change-point, picked by visual inspection. It is presumably
# a row index into readMatrix.primary; confirm before reusing. sort(unique()) keeps
# the change-points increasing and free of duplicates wherever the new one falls.
tauhat.primary=sort(unique(c(tauhat.primary,37821)))
cn.primary = getASCN(readMatrix.primary, tauhat=tauhat.primary, rdep = rdep.primary, threshold = 0.3)


# Chromosomal view of QC'ed segmentation results.
pdf(file=paste('falcon.primary.qc.',chr,'.pdf',sep=''),width=5,height=8)
view(cn.primary,pos=primary.chr[,'Start_position'], rdep = rdep.primary)
dev.off()

source('falcon_demo/falcon.output.R') # Can be downloaded from
# https://github.com/yuchaojiang/Canopy/blob/master/instruction/falcon.output.R
falcon.output=falcon.output(readMatrix = readMatrix.primary,
                            tauhat = tauhat.primary,
                            cn = cn.primary,
                            st_bp = primary.chr[,"Start_position"],
                            end_bp = primary.chr[,"End_position"],
                            nboot = 5000)
falcon.output = cbind(chr=rep(chr,nrow(falcon.output)), falcon.output)
# Overwrites the chr7 table written by the loop ('faclon' spelling kept on purpose).
write.table(falcon.output, file=paste('faclon.primary.output.',chr,'.txt',sep=''), col.names =T, row.names = F, sep='\t', quote = F)



# Metropolis-Hastings style acceptance step between the current tree and a
# proposal. Both `$likelihood` fields are used on the log scale: the proposal
# is kept when exp(new - current) >= u, with u ~ Uniform(0, 1).
# Returns an unnamed list: the retained tree, then 1 if the proposal was
# accepted or 0 if the current tree was kept.
addsamptree <- function(tree, tree.new) {
    acceptance <- exp(tree.new$likelihood - tree$likelihood)
    u <- runif(1, 0, 1)
    accepted <- acceptance >= u
    if (accepted) {
        list(tree.new, 1)
    } else {
        list(tree, 0)
    }
}


# BIC-based model selection over the number of subclones.
#
# For each K:
#   - pool the post-burn-in, thinned trees of all MCMC chains;
#   - keep the best 1/(number of chains) fraction of them by likelihood;
#   - score 2 * mean(likelihood) - K.data * log(N), where K.data and N come
#     from K and the dimensions of the first tree's VAF and Q matrices.
# BIC is plotted against K (to <projectname>_BIC.pdf when pdf = TRUE).
#
# Args:
#   sampchain: list over K; each element is a list of chains (lists of trees).
#   projectname: used in the plot title and the pdf file name.
#   K: numbers of subclones, in the same order as sampchain.
#   numchain: ignored; the number of chains is taken from sampchain.
#   burnin: number of trees discarded at the start of each chain.
#   thin: keep every thin-th tree after burn-in.
#   pdf: write the plot to a pdf file instead of the current device (default FALSE).
# Returns: numeric vector of BIC values, one per element of K.
# Raises: an error when burn-in plus thinning leave no trees in some chain.
canopy.BIC = function(sampchain, projectname, K, numchain, burnin, thin, 
                      pdf = NULL) {
  if (is.null(pdf)) {
    pdf = FALSE
  }
  # Post-burn-in trees of one chain, keeping every thin-th one. Stops instead of
  # silently producing NULL "trees" when nothing would be kept.
  thin.chain = function(chain) {
    numpostburn = length(chain) - burnin
    nkeep = if (numpostburn > 0) floor(numpostburn / thin) else 0
    if (nkeep < 1) {
      stop("not enough trees after burn-in and thin; adjust parameters") 
    }
    chain[burnin + thin * seq_len(nkeep)]
  }
  BIC = rep(NA, length(K))
  for (ki in seq_along(K)) {
    k = K[ki]
    sampchaink = sampchain[[ki]]
    temp.tree = sampchaink[[1]][[1]]
    s = nrow(temp.tree$VAF)
    n = ncol(temp.tree$VAF)
    t = ncol(temp.tree$Q)
    if (is.null(t)) {
      t = 0
    }
    numchain = length(sampchaink)
    # Pool the thinned trees of every chain (a single chain is handled too).
    samptreethin = do.call(c, unname(lapply(sampchaink, thin.chain)))
    samptreelik = vapply(samptreethin, function(tree) tree$likelihood,
                         numeric(1))
    # Keep the best 1/numchain fraction of the pooled trees by likelihood.
    top = rank(-1 * samptreelik, ties.method = "first") <=
      length(samptreethin) / numchain
    lik.temp = mean(samptreelik[top])
    K.data = 2 * (2 * k - 3) + 2 * t + s + (k - 1) * n
    N = s * n * 2 + t * n * 4 + s
    BIC.temp = 2 * lik.temp - K.data * log(N)
    BIC[ki] = BIC.temp
    cat("k =", k, ": mean likelihood", lik.temp,'; BIC',BIC.temp, ".\n")
  }
  if (pdf) {
    pdf(file = paste(projectname, "_BIC.pdf", sep = ""), height = 5, 
        width = 5)
    # Close the device even if plotting fails.
    on.exit(dev.off(), add = TRUE)
  }
  plot(K, BIC, xlab = "Number of subclones", ylab = "BIC", type = "b", 
       xaxt = "n")
  axis(1, at = K)
  abline(v = K[which.max(BIC)], lty = 2)
  title(paste("BIC for model selection", projectname))
  BIC
}


# Expectation function
# E-step of the canopy.cluster EM: posterior probability that each mutation
# was generated by each of the K mutation clusters or by the noise
# component (row K+1).
#
# Args:
#   Tau: mixing proportions, length K+1; the last entry is the noise weight.
#   Mu: K x (number of samples) matrix of cluster VAF centroids.
#   R, X: (number of mutations) x (number of samples) matrices of
#     variant-read counts and total depths.
# Returns: (K+1) x (number of mutations) matrix whose columns sum to 1.
canopy.cluster.Estep = function(Tau,Mu,R,X){
    K=nrow(Mu) # number of mutation clusters
    # Keep centroids strictly inside (0, 1). A centroid of exactly 0 or 1 makes
    # log(Mu) or log(1-Mu) equal -Inf. Multiplied by a zero count, -Inf * 0
    # gives NaN, which would poison that mutation's whole column.
    Mu=pmin(pmax(Mu, 0.001), 0.999)
    # Log of (mixing weight x likelihood) per component (rows) and mutation (columns).
    pG=matrix(nrow=K+1,ncol=nrow(R),data=log(Tau))
    # Binomial log-likelihood (without the binomial coefficient), summed over samples.
    pG[1:K,]=pG[1:K,,drop=FALSE]+log(Mu)%*%t(R)+log(1-Mu)%*%(t(X-R))
    # Noise component: the binomial likelihood integrated over a uniform VAF,
    # which is the Beta function B(R+1, X-R+1), summed over samples on the log scale.
    for(j in seq_len(ncol(R))){
        pG[K+1,]=pG[K+1,]+lbeta(R[,j]+1,X[,j]-R[,j]+1)
    }
    if(Tau[length(Tau)]!=0){
        # Normalized posteriors, used only to rank mutations by their noise
        # probability. The lowest (1 - Tau[K+1]) fraction are barred from the
        # noise component; ties are broken at random.
        pGtemp=pG-matrix(ncol=ncol(pG),nrow=nrow(pG),
                         data=apply(pG,2,max),byrow = TRUE) # prevent underflow
        pGtemp=exp(pGtemp)
        pGtemp=pGtemp/matrix(nrow=nrow(pG),ncol=ncol(pG),
                             data=colSums(pGtemp),byrow=TRUE)
        pG[K+1,(rank(pGtemp[K+1,],ties.method='random'))<=(ncol(pG)*(1-Tau[K+1]))]=-Inf 
    }
    # Normalize each column (softmax), shifting by the column max to avoid underflow.
    pG=pG-matrix(ncol=ncol(pG),nrow=nrow(pG),data=apply(pG,2,max),byrow = TRUE) # prevent underflow
    pG=exp(pG)
    pG=pG/matrix(nrow=nrow(pG),ncol=ncol(pG),data=colSums(pG),byrow = TRUE)
    return(pG)
}

# Maximization function
# M-step of the canopy.cluster EM: update the mixing proportions and the
# cluster VAF centroids from the E-step membership probabilities.
#
# Args:
#   pG: (K+1) x (number of mutations) membership matrix from the E-step;
#     row K+1 is the noise component.
#   R, X: (number of mutations) x (number of samples) matrices of
#     variant-read counts and total depths.
#   Tau_Kplus1: fixed noise proportion, kept as the last mixing weight.
# Returns: list(Mu = K x (number of samples) centroid matrix, floored at
#   1e-4 and rounded to 4 decimals; Tau = mixing proportions of length K+1).
canopy.cluster.Mstep = function (pG,R,X,Tau_Kplus1){
    s=nrow(R) # number of mutations
    K=nrow(pG)-1 # number of mutation clusters
    # drop = FALSE keeps a K x s matrix when K == 1 or s == 1. Otherwise the
    # subset collapses to a vector and apply() over rows fails.
    pGtemp=pG[1:K,,drop=FALSE]
    Tau=rep(NA,K+1)
    # Clusters share the non-noise mass (1 - Tau_Kplus1) in proportion to their
    # expected membership counts.
    Tau[1:K]=(1-Tau_Kplus1)*apply(pGtemp,1,sum)/(s-sum(pG[K+1,]))
    Tau[K+1]=Tau_Kplus1
    # Weighted VAF per cluster and sample: weighted variant reads / weighted depth.
    Mu=(pGtemp%*%R)/(pGtemp%*%X)
    Mu=round(pmax(Mu,0.0001),4)
    return(list(Mu=Mu,Tau=Tau))
}

# Cluster mutations by their VAFs across samples with a binomial mixture fit by EM.
# Each mutation cluster has one VAF centroid per sample. An optional noise
# component (the last of K+1 components) has its VAF uniformly distributed
# between 0 and 1.
#
# For each K in num_cluster, EM is restarted num_run times. Initial centroids
# come from Mu.init (run 1 only, when given) or from the k-means centers
# computed by pheatmap(kmeans_k = K). The first half of the runs use Euclidean
# row distance; the rest use correlation distance when there is more than one
# sample. Converged runs are scored by BIC, and the best run per K is kept.
# The K with the largest BIC is returned.
#
# Args:
#   R: matrix of variant-read counts (mutations x samples).
#   X: matrix of total depths, same shape as R. Rows with a zero or NA depth
#     in any sample are removed.
#   num_cluster: vector of numbers of clusters to try.
#   num_run: number of EM restarts per K.
#   Mu.init: optional K x samples matrix of initial centroids for run 1 (K > 1).
#   Tau_Kplus1: proportion of noise (default 0, no noise component).
# Returns: list(bic_output = BIC per K (NA if EM never converged),
#   Mu = centroids, Tau = mixing proportions, sna_cluster = cluster index of
#   each retained mutation, with K+1 meaning noise) for the best K.
# Raises: an error when EM did not converge for any K.
canopy.cluster=function(R, X, num_cluster, num_run, Mu.init = NULL,
                        Tau_Kplus1 = NULL){
  if(is.null(Tau_Kplus1)){
    Tau_Kplus1=0    # proportion of noise, uniformly distributed between 0 and 1
  }
  
  # Remove rows with zero or NA total depth in any sample. drop = FALSE keeps
  # single-sample (one-column) input a matrix; otherwise nrow(R) would be NULL below.
  zeroNARef = apply(X, 1, function(x) any(x==0 | is.na(x)))
  if (any(zeroNARef)){
    cat("Removing variants with NA or zero total allele depth for any sample \n")
    R = R[!zeroNARef, , drop = FALSE]
    X = X[!zeroNARef, , drop = FALSE]
  }
  
  VAF=R/X
  
  s=nrow(R)
  nsamp=ncol(R)
  r=pmax(R,1);x=pmax(X,1) # for log()
  nK=length(num_cluster)
  Mu_output=Tau_output=pGrank_output=bic_output=vector('list',nK)
  for(ki in seq_len(nK)){
    K=num_cluster[ki]
    cat('Running EM with',K,'clusters...\t')
    Mu_run=Tau_run=pGrank_run=bic_run=vector('list',num_run)
    for(run in seq_len(num_run)){
      cat(run,'  ')
      bic.temp=0
      # Equal weights for the K clusters; the noise weight is fixed.
      Tau=rep(NA,K+1)
      Tau[K+1]=Tau_Kplus1 
      Tau[1:K]=(1-Tau_Kplus1)/K
      
      if(K==1){
        Mu=t(as.matrix(apply(R/X,2,mean)))
      } else if (run==1 && !is.null(Mu.init)){
        Mu=Mu.init
      } else {
        # k-means centers from pheatmap serve as initial centroids.
        dist.rows = if (run > num_run/2 && nsamp > 1) "correlation" else "euclidean"
        VAF.pheat=pheatmap(VAF,cluster_rows = TRUE,cluster_cols = FALSE,kmeans_k=K,silent=TRUE, clustering_distance_rows = dist.rows)
        Mu=pmax(VAF.pheat$kmeans$centers,0.001)
      }
      # Iterate at least 30 and at most 300 times, stopping early once the
      # largest parameter change falls to 0.001 or becomes NA.
      diff=1
      numiters=1
      while( (!is.na(diff) && diff>0.001 && numiters <= 300) || numiters <= 30){ 
        numiters=numiters+1
        pG=canopy.cluster.Estep(Tau,Mu,r,x)
        curM=canopy.cluster.Mstep(pG,R,X,Tau_Kplus1)
        diff=max(max(abs(Tau-curM$Tau)),max(abs(Mu-curM$Mu)))
        Mu=curM$Mu
        Tau=curM$Tau
      }
      if (!is.na(diff) && numiters < 300){
        # Hard-assign each mutation and accumulate the complete-data log-likelihood.
        pGrank=apply(pG,2,which.max)
        for (i in seq_len(s)){
          k=pGrank[i]
          if(k<=K){
            muk=Mu[k,]
            for(j in seq_len(nsamp)){
              bic.temp=bic.temp+log(Tau[k])+r[i,j]*log(muk[j])+(x[i,j]-r[i,j])*log(1-muk[j])
            }
          } else {
            for(j in seq_len(nsamp)){
              bic.temp=bic.temp+log(Tau[k])+lbeta(r[i,j]+1,x[i,j]-r[i,j]+1)
            }
          }
        }
        bic.temp=2*bic.temp-3*(length(Tau)-2+length(Mu))*log(length(R)+length(X))
        Mu_run[[run]]=Mu
        Tau_run[[run]]=Tau
        pGrank_run[[run]]=pGrank
        bic_run[[run]]=bic.temp
      } else {
        bic_run[[run]] = NA
      }
    }
    cat('\n')
    bic.vec=as.numeric(bic_run)
    if (all(is.na(bic.vec))) {
      Mu_output[[ki]]=NA
      Tau_output[[ki]]=NA
      pGrank_output[[ki]]=NA
      bic_output[[ki]]=NA
      cat ( "EM did not converge in any runs with", K , "clusters \n" )
    } else {
      best.run=which.max(bic.vec)
      Mu_output[[ki]]=Mu_run[[best.run]]
      Tau_output[[ki]]=Tau_run[[best.run]]
      pGrank_output[[ki]]=pGrank_run[[best.run]]
      bic_output[[ki]]=bic.vec[best.run]
      nConverged = sum(!is.na(bic.vec))
      cat ( "EM converged in", nConverged,  "out of", num_run  ,"runs with", K , "clusters \n" )
    }
  }
  bic_output=as.numeric(bic_output)
  if (all(is.na(bic_output))) {
    stop("EM did not converge for any number of clusters in num_cluster!")
  }
  best=which.max(bic_output)
  Mu=round(Mu_output[[best]],3)
  Tau=round(Tau_output[[best]],3)
  sna_cluster=pGrank_output[[best]]
  return(list(bic_output=bic_output,Mu=Mu,Tau=Tau,sna_cluster=sna_cluster))
}


# Extract the highest-likelihood posterior tree of one configuration.
# When the tree carries CNA copy numbers, rename its CNAs by event type.
#
# Args:
#   post: output of canopy.post: list(trees, tree likelihoods,
#     configuration of each tree, configuration summary).
#   config.i: the configuration to extract.
#   C: CNA mapping matrix; defaults to an identity matrix named after the first
#     tree's CNAs. The names used are rownames(C) where C[, j] == 1.
# Returns: the selected tree with $clonalmut set. If $cna.copy is present, the
#   CNAs are renamed with suffix "_del" (row 2 == 0 and row 1 <= 1), "_LOH"
#   (row 2 == 0 and row 1 > 1) or "_dup" (row 2 >= 1) of $cna.copy.
#   The rows are presumably major and minor copy numbers.
# Raises: an error if config.i does not occur in post.
canopy.output = function(post, config.i, C = NULL) {
    samptreethin = post[[1]]
    samptreethin.lik = post[[2]]
    config = post[[3]]
    if (is.null(C)) {
      C = diag(nrow(samptreethin[[1]]$cna))
      colnames(C) = rownames(C) = rownames(samptreethin[[1]]$cna)
    }
    tree.loc = which(config == config.i)
    if (length(tree.loc) == 0) {
        stop("config.i = ", config.i, " is not one of the posterior configurations!")
    }
    # Highest-likelihood tree within the requested configuration.
    output.tree = samptreethin[[tree.loc[which.max(samptreethin.lik[tree.loc])]]]
    # change cna names based on their inferred major and minor copy numbers
    cnacopy.temp = output.tree$cna.copy
    if (!is.null(cnacopy.temp)) {
        cna.newname = rep(NA, ncol(cnacopy.temp))
        for (j in seq_len(ncol(cnacopy.temp))) {
            copy1 = cnacopy.temp[1, j]
            copy2 = cnacopy.temp[2, j]
            cna.base = (rownames(C))[which(C[, j] == 1)]
            if (copy2 == 0 && copy1 <= 1) {
                cna.newname[j] = paste(cna.base, "_del", sep = "")
            } else if (copy2 == 0 && copy1 > 1) {
                cna.newname[j] = paste(cna.base, "_LOH", sep = "")
            } else if (copy2 >= 1) {
                cna.newname[j] = paste(cna.base, "_dup", sep = "")
            }
        }
        rownames(output.tree$cna) = colnames(output.tree$cna.copy) = colnames(output.tree$Q) = colnames(output.tree$H) = cna.newname
    }
    output.tree$clonalmut = getclonalcomposition(output.tree)
    return(output.tree)
}


# Plot an inferred clonal tree in three stacked panels:
#   1. the tree, with node, tip and mutation-edge labels;
#   2. a heatmap of tree$P with its values printed in each cell (columns
#      labelled on the right axis);
#   3. the SNAs/CNAs placed on each labelled edge.
#
# Args:
#   tree: tree object with $edge, $Z, $sna, $cna (may be NULL) and $P; the
#     number of clones is taken as ncol(tree$Z).
#   pdf: draw into pdf.name instead of the current device (default FALSE).
#   pdf.name: output file; must end in ".pdf". Whenever given, its name
#     without the extension is printed as a caption.
#   txt: also write the per-edge mutation listing to txt.name (default FALSE).
#   txt.name: tab-delimited output file for that listing.
# Graphical parameters changed here are restored on exit; an opened pdf device
# is closed even if plotting fails.
canopy.plottree = function(tree, pdf = NULL, pdf.name = NULL, txt = NULL, 
                           txt.name = NULL) {
    if (is.null(pdf)) {
        pdf = FALSE
    }
    if (is.null(txt)) {
        txt = FALSE
    }
    if (pdf && is.null(pdf.name)) {
        stop("pdf.name has to be provided if pdf = TRUE!")
    }
    if (txt && is.null(txt.name)) {
        stop("txt.name has to be provided if txt = TRUE")
    }
    if (!is.null(pdf.name)) {
        # Check only the extension, so names with other dots such as
        # "my.tree.pdf" or "./out/tree.pdf" are accepted.
        if (!grepl("\\.pdf$", pdf.name)) {
            stop("pdf.name has to end with .pdf!")
        }
        plot.caption = sub("\\.pdf$", "", pdf.name)
    }
    if (pdf) {
        pdf(file = pdf.name, height = 6, width = 6)
        on.exit(dev.off(), add = TRUE)
    } else {
        old.par = par(no.readonly = TRUE)
        on.exit(par(old.par), add = TRUE)
    }
    layout(matrix(c(1, 2, 3), 3, 1, byrow = TRUE), widths = c(3, 
        3, 3), heights = c(1.3, 1, 1), respect = TRUE)
    par(mar = c(1, 7, 1, 10))
    # plot tree
    K = ncol(tree$Z)
    plot(tree, label.offset = 0.1, type = "cladogram", direction = "d", 
        show.tip.label = FALSE)
    nodelabels()
    tiplabels()
    # Edge carrying each SNA / CNA: the edge from the node in column 2 to the
    # node in column 3 of tree$sna / tree$cna.
    snaedge = rep(NA, nrow(tree$sna))
    for (k in seq_len(nrow(tree$sna))) {
        snaedge[k] = intersect(which(tree$edge[, 1] == tree$sna[k, 2]), 
            which(tree$edge[, 2] == tree$sna[k, 3]))
    }
    if (!is.null(tree$cna)) {
        cnaedge = rep(NA, nrow(tree$cna))
        for (k in seq_len(nrow(tree$cna))) {
            cnaedge[k] = intersect(which(tree$edge[, 1] == tree$cna[k, 2]), 
                                   which(tree$edge[, 2] == tree$cna[k, 3]))
        }
    } else {
        cnaedge = NULL
    }
    edge.label = sort(unique(c(snaedge, cnaedge)))
    edgelabels(paste("mut", seq_along(edge.label), sep = ""), edge.label, 
        frame = "n", col = 2, cex = 1.2)
    tiplabels("Normal", 1, adj = c(0.2, 1.5), frame = "n", cex = 1.2, 
        col = 4)
    # Middle tips are Clone1..Clone(K-2). With K == 2 there are none; 1:(K - 2)
    # would count down and put "Clone0" on the Normal tip.
    if (K > 2) {
        tiplabels(paste("Clone", 1:(K - 2), sep = ""), 2:(K - 1), adj = c(0.5, 
            1.5), frame = "n", cex = 1.2, col = 4)
    }
    tiplabels(paste("Clone", (K - 1), sep = ""), K, adj = c(0.8, 1.5), 
        frame = "n", cex = 1.2, col = 4)
    # plot clonal frequencies
    par(mar = c(1, 7, 0.5, 9.5))
    P = tree$P
    image(1:nrow(P), 1:ncol(P), axes = FALSE, ylab = "", xlab = "", 
        P, breaks = 0:100/100, col = tim.colors(100))
    axis(4, at = 1:ncol(P), colnames(P), cex.axis = 1.2, las = 1, tick = FALSE)
    abline(h = seq(0.5, ncol(P) + 0.5, 1), v = seq(0.5, nrow(P) + 0.5, 
        1), col = "grey")
    for (i in seq_len(nrow(P))) {
        for (j in seq_len(ncol(P))) {
            txt.temp <- sprintf("%0.3f", P[i, j])
            # White text on the extreme (darkest) colours for readability.
            if (P[i, j] <= 0.05 || P[i, j] >= 0.95) {
                text(i, j, txt.temp, cex = 0.7, col = "white")
            } else {
                text(i, j, txt.temp, cex = 0.7)
            }
        }
    }
    sna.name = rownames(tree$sna)
    cna.name = rownames(tree$cna)
    # plot mutations
    plot(c(0, 1), c(0, 1), ann = FALSE, bty = "n", type = "n", xaxt = "n", 
        yaxt = "n")
    txt.output = matrix(nrow = length(edge.label), ncol = 1)
    for (i in seq_along(edge.label)) {
        txt.temp = paste("mut", i, ": ", paste(c(sna.name[which(snaedge == 
            edge.label[i])], cna.name[which(cnaedge == edge.label[i])]), 
            collapse = ", "), sep = "")
        text(x = 0, y = 0.95 - 0.1 * (i - 1), txt.temp, pos = 4, cex = 1.2)
        txt.output[i, 1] = txt.temp
    }
    
    if (txt) {
        write.table(txt.output, file = txt.name, col.names = FALSE,
                    row.names = FALSE, quote = FALSE, sep = '\t')
    }
    
    if (!is.null(pdf.name)) {
        text(x = 0.5, y = 0.1, plot.caption, font = 2, cex = 1.2)
    }
    invisible(NULL)
}


# Summarize the MCMC posterior for the chosen number of subclones optK.
#
# Steps:
#   1. Pool the post-burn-in, thinned trees of all chains and keep the best
#      5/(number of chains) fraction by likelihood.
#   2. If the trees carry CNAs, sort them with sortcna().
#   3. Group the trees into configurations with identical clonal composition.
#   4. Drop trees whose likelihood z-score is outside [-1.5, 1.5].
#   5. Discard configurations whose posterior probability is below
#      post.config.cutoff, then renumber and renormalize the rest.
#   6. Attach the SNA cancer cell fractions (CCF) to each retained tree.
#
# Args:
#   sampchain: list over K; each element is a list of chains (lists of trees).
#   projectname: unused here (kept for interface compatibility).
#   K: numbers of subclones, in the same order as sampchain.
#   numchain: ignored; the number of chains is taken from sampchain.
#   burnin: number of trees discarded at the start of each chain.
#   thin: keep every thin-th tree after burn-in.
#   optK: selected number of subclones; must be one of K.
#   C: CNA mapping matrix passed to sortcna(); defaults to an identity matrix
#     named after the first tree's CNAs.
#   post.config.cutoff: minimum posterior probability of a configuration
#     (default 0.05, must be in (0, 1]).
# Returns: list(trees, tree likelihoods, configuration of each tree,
#   configuration summary matrix with columns Configuration / Post_prob /
#   Mean_post_lik).
# Raises: an error for invalid arguments, when burn-in plus thinning leave no
#   trees, or when no configuration reaches the cutoff.
canopy.post = function (sampchain, projectname, K, numchain, burnin, thin, 
          optK, C = NULL, post.config.cutoff = NULL) 
{
    if (is.null(C)) {
        C = diag(nrow(sampchain[[1]][[1]][[1]]$cna))
        colnames(C) = rownames(C) = rownames(sampchain[[1]][[1]][[1]]$cna)
    }
    if (is.null(post.config.cutoff)) {
        post.config.cutoff = 0.05
    }
    if (post.config.cutoff > 1 || post.config.cutoff <= 0) {
        stop("post.config.cutoff has to be between 0 and 1!")
    }
    if (!any(K == optK)) {
        stop("optK has to be one of K!")
    }
    # Post-burn-in trees of one chain, keeping every thin-th one. Stops instead
    # of silently producing NULL "trees" when nothing would be kept.
    thin.chain = function(chain) {
        numpostburn = length(chain) - burnin
        nkeep = if (numpostburn > 0) floor(numpostburn / thin) else 0
        if (nkeep < 1) {
            stop("not enough trees after burn-in and thin; adjust parameters")
        }
        chain[burnin + thin * seq_len(nkeep)]
    }
    sampchaink = sampchain[[which(K == optK)]]
    numchain = length(sampchaink)
    # Pool the thinned trees of every chain (a single chain is handled too).
    samptreethin = do.call(c, unname(lapply(sampchaink, thin.chain)))
    samptreethin.lik = vapply(samptreethin, function(tree) tree$likelihood,
                              numeric(1))
    # Keep the best 5/numchain fraction of the pooled trees by likelihood.
    top = rank(-1 * samptreethin.lik, ties.method = "first") <=
        5 * (length(samptreethin) / numchain)
    samptreethin = samptreethin[top]
    samptreethin.lik = samptreethin.lik[top]
    if (!is.null(sampchain[[1]][[1]][[1]]$cna)) {
        for (i in seq_along(samptreethin)) {
            samptreethin[[i]] = sortcna(samptreethin[[i]], C)
        }
    }
    for (i in seq_along(samptreethin)) {
        samptreethin[[i]]$clonalmut = getclonalcomposition(samptreethin[[i]])
    }
    # Two trees share a configuration when their clonal compositions contain
    # each other's optK elements. Each tree is compared with the first tree of
    # every existing configuration; a new configuration is opened if none match.
    config = rep(NA, length(samptreethin))
    config[1] = 1
    categ = 1
    for (i in seq_along(samptreethin)[-1]) {
        list.a = samptreethin[[i]]$clonalmut
        for (categi in 1:categ) {
            list.b = samptreethin[[which(config == categi)[1]]]$clonalmut
            if ((sum(is.element(list.a, list.b)) == optK) && 
                (sum(is.element(list.b, list.a)) == optK)) {
                config[i] = categi
            }
        }
        if (is.na(config[i])) {
            config[i] = categ + 1
            categ = categ + 1
        }
    }
    # Drop likelihood outliers (|z| > 1.5). When sd() is NA (a single tree) or 0
    # (all likelihoods tied), z-scores are undefined and every tree is kept.
    lik.sd = sd(samptreethin.lik)
    if (is.na(lik.sd) || lik.sd == 0) {
        keep = rep(TRUE, length(samptreethin.lik))
    } else {
        z.temp = (samptreethin.lik - mean(samptreethin.lik)) / lik.sd
        keep = z.temp <= 1.5 & z.temp >= -1.5
    }
    samptreethin = samptreethin[keep]
    samptreethin.lik = samptreethin.lik[keep]
    config = config[keep]
    # Per configuration: id, share of trees (posterior probability) and, in the
    # "Mean_post_lik" column, the maximum likelihood of its trees.
    config.summary = matrix(nrow = length(unique(config)), ncol = 3)
    colnames(config.summary) = c("Configuration", "Post_prob", 
                                 "Mean_post_lik")
    config.summary[, 1] = unique(config)
    for (i in seq_len(nrow(config.summary))) {
        configi = config.summary[i, 1]
        configi.temp = which(config == configi)
        config.summary[i, 2] = round(length(configi.temp)/length(config), 
                                     3)
        config.summary[i, 3] = round(max(samptreethin.lik[which(config == 
                                                                    configi)]), 2)
    }
    minor.config = which(config.summary[, 2] < post.config.cutoff)
    if (length(minor.config) == nrow(config.summary)) {
        stop('No configuration has posterior probablity greater than the threshold!\nRun sampling longer or reduce the threshold.', call. = FALSE)
    }
    if (length(minor.config) > 0) {
        config.sel = rep(TRUE, length(config))
        for (i in minor.config) {
            config.sel[which(config == config.summary[i,1])] = FALSE
        }
        samptreethin = samptreethin[config.sel]
        samptreethin.lik = samptreethin.lik[config.sel]
        config = config[config.sel]
        config.summary = config.summary[-minor.config, , drop = FALSE]
        # Renumber the surviving configurations 1..n and renormalize their probabilities.
        for (i in seq_len(nrow(config.summary))) {
            config[which(config == config.summary[i, 1])] = i
        }
        config.summary[, 1] = seq_len(nrow(config.summary))
        config.summary[, 2] = round(config.summary[, 2]/sum(config.summary[, 2]), 3)
    }
    # below is added to computed CCFs for point mutations / SNAs:
    # Z without its first column times P without its first row, with each
    # column of that P block rescaled to sum to 1.
    for (treei in seq_along(samptreethin)) {
        output.tree = samptreethin[[treei]]
        output.tree.Z = output.tree$Z[, 2:ncol(output.tree$Z), drop = FALSE]
        output.tree.P = apply(output.tree$P[2:nrow(output.tree$P), , drop = FALSE], 2, function(x){x/sum(x)})
        output.tree$CCF = output.tree.Z %*% output.tree.P 
        samptreethin[[treei]] = output.tree
    }
    return(list(samptreethin, samptreethin.lik, config, config.summary))
}


# Run MCMC sampling of subclonal trees for SNA-only data (no CNA input) in
# which SNAs have been grouped into mutation clusters beforehand.
#
# For every number of subclones k in K, `numchain` independent chains are run.
# During the first `min.simrun` iterations whole mutation clusters are moved on
# the tree (via sampsna.cluster); afterwards individual SNA positions are
# proposed (via sampsna). Each SNA move is followed by a proposal for the
# clonal proportions P. Every `writeskip`-th iteration the current tree is
# stored.
#
# Arguments:
#   R, X            read-count matrices, one row per SNA and one column per
#                   sample (R/X is treated as the variant allele frequency)
#   sna_cluster     mutation-cluster label of every SNA (length nrow(R))
#   K               candidate numbers of subclones (each >= 2)
#   numchain        number of chains per value of k (>= 1)
#   max.simrun      maximum number of iterations per chain
#   min.simrun      number of initial iterations that move whole clusters
#   writeskip       store the current tree every `writeskip` iterations
#   projectname     prefix of the "<projectname>_likelihood.pdf" trace plot
#   cell.line       forwarded to initialP/sampP; NULL means FALSE
#   plot.likelihood write likelihood / acceptance-rate traces; NULL means TRUE
#
# Returns a list with one element per value in K; each element is a list with
# one entry per chain holding that chain's stored trees.
canopy.sample.cluster.nocna = function(R, X, sna_cluster, K, numchain, 
                                 max.simrun, min.simrun, writeskip, 
                                 projectname, cell.line = NULL,
                                 plot.likelihood = NULL) {
    # Check matrix-ness first so nrow(R) below is well defined.
    if (!is.matrix(R)) {
        stop("R should be a matrix!")
    }
    if (!is.matrix(X)) {
        stop("X should be a matrix!")
    }
    if (length(sna_cluster) != nrow(R)) {
        stop('Length of sna_cluster should be the same as row numbers of R and X!')
    }
    if (min(K) < 2) {
        stop("Smallest number of subclones should be >= 2!\n")
    }
    if (numchain < 1) {
        stop("Number of chains should be >= 1!\n")
    }
    if (is.null(cell.line)) {
        cell.line = FALSE
    }
    if (is.null(plot.likelihood)) {
        plot.likelihood = TRUE
    }
    if (plot.likelihood) {
        pdf(file = paste(projectname, "_likelihood.pdf", sep = ""), width = 10, height = 5)
        # Close the PDF device even if sampling fails part-way.
        on.exit(dev.off(), add = TRUE)
    }
    
    sampname = colnames(R)
    sna.name = rownames(R)
    # Map each cluster label to a row of tree$sna.cluster. Using match() keeps
    # row names consistent with the labels and allows non-contiguous labels.
    cluster.ids = sort(unique(sna_cluster))
    cluster.index = match(sna_cluster, cluster.ids)
    cluster.name = paste('cluster', cluster.ids, sep = '')
    sampchain = vector("list", length(K))
    ki = 1
    for (k in K) {
        cat("Sample in tree space with", k, "subclones\n")
        sampchaink = vector("list", numchain)
        sampchaink.lik = vector('list', numchain)
        sampchaink.accept.rate = vector('list', numchain)
        for (numi in seq_len(numchain)) {  # numi: index of the chain
            cat("\tRunning chain", numi, "out of", numchain, "...\n")
            ###################################### Tree initialization #####
            # Default topology is the ladder "(1,(2,(...,k)))"; for k = 5, 6, 7
            # one of a few alternative topologies may be drawn at random.
            text = paste(paste(paste(paste("(", 1:(k - 1), ",", sep = ""), 
                                     collapse = ""), k, sep = ""), paste(rep(")", (k - 1)), 
                                                                         collapse = ""), ";", sep = "")
            runif.temp = runif(1)
            if (k == 5 & runif.temp < 0.5) {
                text = c('(1,((2,3),(4,5)));')
            } else if (k == 6 & runif.temp < 1/3) {
                text = c('(1,((2,3),(4,(5,6))));')
            } else if (k == 6 & runif.temp > 2/3) {
                text = c('(1,(2,((3,4),(5,6))));')
            } else if (k == 7 & runif.temp > 1/4 & runif.temp <= 2/4) {
                text = c('(1,((2,3),(4,(5,(6,7)))));')
            } else if (k == 7 & runif.temp > 2/4 & runif.temp <= 3/4) {
                text = c('(1,((2,3),((4,5),(6,7))));')
            } else if (k == 7 & runif.temp > 3/4) {
                text = c('(1,((2,(3,4)),(5,(6,7))));')
            }
            tree <- read.tree(text = text)
            # Place whole clusters, then expand to one row per SNA.
            tree$sna.cluster = initialsna(tree, cluster.name)
            sna.mat = cbind(sna = seq_len(nrow(R)), (tree$sna.cluster)[cluster.index, 2:3])
            colnames(sna.mat) = c("sna", "sna.st.node", "sna.ed.node")
            rownames(sna.mat) = sna.name
            tree$sna = sna.mat
            tree$Z = getZ(tree, sna.name)
            tree$P = initialP(tree, sampname, cell.line)
            tree$VAF = tree$Z %*% tree$P / 2
            tree$likelihood = getlikelihood.sna(tree, R, X)
            ###################################### Sample in tree space #####
            sampi = 1
            writei = 1
            samptree = vector("list", max.simrun)
            samptree.lik = rep(NA, max.simrun)
            samptree.accept = rep(NA, max.simrun)
            samptree.accept.rate = rep(NA, max.simrun)
            
            # Phase 1: move whole mutation clusters, alternating with P.
            while (sampi <= min.simrun) {
                ######### sample sna mutation cluster positions
                tree.new = tree
                tree.new$sna.cluster = sampsna.cluster(tree)
                sna.mat = cbind(sna = seq_len(nrow(R)), (tree.new$sna.cluster)[cluster.index, 2:3])
                colnames(sna.mat) = c("sna", "sna.st.node", "sna.ed.node")
                rownames(sna.mat) = sna.name
                tree.new$sna = sna.mat
                tree.new$Z = getZ(tree.new, sna.name)
                tree.new$VAF = tree.new$Z %*% tree.new$P / 2
                tree.new$likelihood = getlikelihood.sna(tree.new, R, X)
                # addsamptree returns list(current tree, acceptance indicator).
                tree.temp = addsamptree(tree, tree.new)
                tree = tree.temp[[1]]
                samptree.accept[sampi] = tree.temp[[2]]
                if (sampi %% writeskip == 0) {
                    samptree[[writei]] = tree
                    writei = writei + 1
                }
                samptree.lik[sampi] = tree$likelihood
                # Acceptance rate over (at most) the last 1000 iterations.
                samptree.accept.rate[sampi] = mean(samptree.accept[max(1, sampi - 999):sampi])
                sampi = sampi + 1
                ######## sample P (clonal proportions)
                tree.new = tree
                tree.new$P = sampP(tree.new, cell.line)
                tree.new$VAF = tree.new$Z %*% tree.new$P / 2
                tree.new$likelihood = getlikelihood.sna(tree.new, R, X)
                tree.temp = addsamptree(tree, tree.new)
                tree = tree.temp[[1]]
                samptree.accept[sampi] = tree.temp[[2]]
                if (sampi %% writeskip == 0) {
                    samptree[[writei]] = tree
                    writei = writei + 1
                }
                samptree.lik[sampi] = tree$likelihood
                samptree.accept.rate[sampi] = mean(samptree.accept[max(1, sampi - 999):sampi])
                sampi = sampi + 1
            }
            # Phase 2: move individual SNAs, alternating with P. No early
            # stopping rule is applied in this function.
            while (sampi <= max.simrun) {
                ######### sample sna positions
                tree.new = tree
                tree.new$sna = sampsna(tree)
                tree.new$Z = getZ(tree.new, sna.name)
                tree.new$VAF = tree.new$Z %*% tree.new$P / 2
                tree.new$likelihood = getlikelihood.sna(tree.new, R, X)
                tree.temp = addsamptree(tree, tree.new)
                tree = tree.temp[[1]]
                samptree.accept[sampi] = tree.temp[[2]]
                if (sampi %% writeskip == 0) {
                    samptree[[writei]] = tree
                    writei = writei + 1
                }
                samptree.lik[sampi] = tree$likelihood
                samptree.accept.rate[sampi] = mean(samptree.accept[max(1, sampi - 999):sampi])
                sampi = sampi + 1
                ######## sample P (clonal proportions)
                tree.new = tree
                tree.new$P = sampP(tree.new, cell.line)
                tree.new$VAF = tree.new$Z %*% tree.new$P / 2
                tree.new$likelihood = getlikelihood.sna(tree.new, R, X)
                tree.temp = addsamptree(tree, tree.new)
                tree = tree.temp[[1]]
                samptree.accept[sampi] = tree.temp[[2]]
                if (sampi %% writeskip == 0) {
                    samptree[[writei]] = tree
                    writei = writei + 1
                }
                samptree.lik[sampi] = tree$likelihood
                samptree.accept.rate[sampi] = mean(samptree.accept[max(1, sampi - 999):sampi])
                sampi = sampi + 1
            }
            # seq_len() yields an empty list when no tree was stored.
            sampchaink[[numi]] = samptree[seq_len(writei - 1)]
            sampchaink.lik[[numi]] = samptree.lik
            sampchaink.accept.rate[[numi]] = samptree.accept.rate
        }
        ###################################### plotting and saving #####
        if (plot.likelihood) {
            par(mfrow = c(1, 2))
            # Last recorded iteration of each chain.
            xmax = rep(NA, numchain)
            for (i in seq_len(numchain)) {
                xmax[i] = max(which(!is.na(sampchaink.lik[[i]])))
            }
            # y-range over every finite value of every chain, so no part of
            # any trace is clipped whatever its shape.
            lik.all = unlist(sampchaink.lik)
            lik.range = range(lik.all[is.finite(lik.all)])
            plot(sampchaink.lik[[1]], xlim = c(1, max(xmax)), ylim = lik.range,
                 xlab = 'Iteration', ylab = 'Log-likelihood', type = 'l',
                 main = paste('Post. likelihood:', k, 'branches'))
            # seq_len(numchain)[-1] is empty for a single chain (2:1 is not).
            for (numi in seq_len(numchain)[-1]) {
                points(sampchaink.lik[[numi]], col = numi, type = 'l')
            }
            plot(sampchaink.accept.rate[[1]], ylim = c(0, 1), xlim = c(1, max(xmax)),
                 xlab = 'Iteration', ylab = 'Acceptance rate', type = 'l',
                 main = paste('Acceptance rate:', k, 'branches'))
            for (numi in seq_len(numchain)[-1]) {
                points(sampchaink.accept.rate[[numi]], col = numi, type = 'l')
            }
            par(mfrow = c(1, 1))
        }
        sampchain[[ki]] = sampchaink
        
        ki = ki + 1
    }
    return(sampchain)
}


# Run MCMC sampling of subclonal trees from SNA and CNA data, where SNAs have
# been grouped into mutation clusters beforehand.
#
# For every number of subclones k in K, `numchain` independent chains are run.
# Each iteration first proposes an SNA move. During the first `min.simrun`
# iterations this moves whole mutation clusters (sampsna.cluster); afterwards
# it moves individual SNAs (sampsna). The SNA move is followed by proposals
# for:
#   - CNA positions;
#   - clonal proportions P;
#   - major/minor copy numbers;
#   - when any entry of Q equals 1, the major/minor allele (H) of one such SNA.
# Every `writeskip`-th iteration the current tree is stored. After
# 2 * min.simrun iterations a chain stops early according to should.stop()
# below.
#
# Arguments:
#   R, X            read-count matrices, one row per SNA and one column per
#                   sample (R/X is treated as the variant allele frequency)
#   sna_cluster     mutation-cluster label of every SNA (length nrow(R))
#   WM, Wm          matrices passed to getlikelihood (must be matrices)
#   epsilonM, epsilonm  passed to getlikelihood
#   C               CNA-region matrix with exactly one 1 per column; NULL
#                   means the identity over rownames(WM)
#   Y               passed to getQ / getVAF
#   K               candidate numbers of subclones (each >= 2)
#   numchain        number of chains per value of k (>= 1)
#   max.simrun      maximum number of iterations per chain
#   min.simrun      number of initial iterations that move whole clusters
#   writeskip       store the current tree every `writeskip` iterations
#   projectname     prefix of the "<projectname>_likelihood.pdf" trace plot
#   cell.line       forwarded to initialP/sampP; NULL means FALSE
#   plot.likelihood write likelihood / acceptance-rate traces; NULL means TRUE
#
# Returns a list with one element per value in K; each element is a list with
# one entry per chain holding that chain's stored trees.
canopy.sample.cluster = function(R, X, sna_cluster, WM, Wm, epsilonM, epsilonm, C = NULL, 
                         Y, K, numchain, max.simrun, min.simrun, writeskip, projectname, 
                         cell.line = NULL, plot.likelihood = NULL) {
    # Check matrix-ness first so nrow(R) below is well defined.
    if (!is.matrix(R)) {
        stop("R should be a matrix!")
    }
    if (!is.matrix(X)) {
        stop("X should be a matrix!")
    }
    if (length(sna_cluster) != nrow(R)) {
        stop('Length of sna_cluster should be the same as row numbers of R and X!')
    }
    if (!is.matrix(WM)) {
        stop("WM should be a matrix!")
    }
    if (!is.matrix(Wm)) {
        stop("Wm should be a matrix!")
    }
    if (min(K) < 2) {
        stop("Smallest number of subclones should be >= 2!\n")
    }
    if (numchain < 1) {
        stop("Number of chains should be >= 1!\n")
    }
    if (is.null(cell.line)) {
        cell.line = FALSE
    }
    if (is.null(plot.likelihood)) {
        plot.likelihood = TRUE
    }
    if (is.null(C)) {
        C = diag(nrow(WM))
        colnames(C) = rownames(C) = rownames(WM)
    }
    if (any(colSums(C) != 1)) {
        stop("Matrix C should have one and only one 1 for each column!")
    }
    if (plot.likelihood) {
        pdf(file = paste(projectname, "_likelihood.pdf", sep = ""), width = 10, height = 5)
        # Close the PDF device even if sampling fails part-way.
        on.exit(dev.off(), add = TRUE)
    }
    
    sampname = colnames(R)
    sna.name = rownames(R)
    cna.name = colnames(C)
    # Map each cluster label to a row of tree$sna.cluster. Using match() keeps
    # row names consistent with the labels and allows non-contiguous labels.
    cluster.ids = sort(unique(sna_cluster))
    cluster.index = match(sna_cluster, cluster.ids)
    cluster.name = paste('cluster', cluster.ids, sep = '')
    sampchain = vector("list", length(K))
    
    # Early-stopping rule. Once at least 2 * min.simrun iterations are done,
    # stop when all of the following hold:
    #   - the current likelihood is <= its mean over the previous (up to) 1000
    #     iterations;
    #   - the current acceptance rate is <= its mean over that same window;
    #   - the current acceptance rate is <= 0.1.
    # It reads the current chain's samptree.lik / samptree.accept.rate from
    # this function's frame at call time.
    should.stop = function(sampi) {
        window = max((sampi - 1000), 1):max((sampi - 1), 1)
        (sampi >= 2 * min.simrun) &
            (samptree.lik[sampi] <= mean(samptree.lik[window])) &
            (samptree.accept.rate[sampi] <= mean(samptree.accept.rate[window])) &
            (samptree.accept.rate[sampi] <= 0.1)
    }
    
    ki = 1
    for (k in K) {
        cat("Sample in tree space with", k, "subclones\n")
        sampchaink = vector("list", numchain)
        sampchaink.lik = vector('list', numchain)
        sampchaink.accept.rate = vector('list', numchain)
        for (numi in seq_len(numchain)) {  # numi: index of the chain
            cat("\tRunning chain", numi, "out of", numchain, "...\n")
            ###################################### Tree initialization #####
            # Default topology is the ladder "(1,(2,(...,k)))"; for k = 5, 6, 7
            # one of a few alternative topologies may be drawn at random.
            text = paste(paste(paste(paste("(", 1:(k - 1), ",", sep = ""), 
                                     collapse = ""), k, sep = ""), paste(rep(")", (k - 1)), 
                                                                         collapse = ""), ";", sep = "")
            runif.temp = runif(1)
            if (k == 5 & runif.temp < 0.5) {
                text = c('(1,((2,3),(4,5)));')
            } else if (k == 6 & runif.temp < 1/3) {
                text = c('(1,((2,3),(4,(5,6))));')
            } else if (k == 6 & runif.temp > 2/3) {
                text = c('(1,(2,((3,4),(5,6))));')
            } else if (k == 7 & runif.temp > 1/4 & runif.temp <= 2/4) {
                text = c('(1,((2,3),(4,(5,(6,7)))));')
            } else if (k == 7 & runif.temp > 2/4 & runif.temp <= 3/4) {
                text = c('(1,((2,3),((4,5),(6,7))));')
            } else if (k == 7 & runif.temp > 3/4) {
                text = c('(1,((2,(3,4)),(5,(6,7))));')
            }
            tree <- read.tree(text = text)
            # Place whole clusters, then expand to one row per SNA.
            tree$sna.cluster = initialsna(tree, cluster.name)
            sna.mat = cbind(sna = seq_len(nrow(R)), (tree$sna.cluster)[cluster.index, 2:3])
            colnames(sna.mat) = c("sna", "sna.st.node", "sna.ed.node")
            rownames(sna.mat) = sna.name
            tree$sna = sna.mat
            tree$Z = getZ(tree, sna.name)
            tree$P = initialP(tree, sampname, cell.line)
            tree$cna = initialcna(tree, cna.name)
            tree$cna.copy = initialcnacopy(tree)
            CMCm = getCMCm(tree, C)  # get major and minor copy per clone
            tree$CM = CMCm[[1]]
            tree$Cm = CMCm[[2]]  # major/minor copy per clone
            tree$Q = getQ(tree, Y, C)
            tree$H = tree$Q  # start as all SNAs that precede CNAs land 
            # on major copies
            tree$VAF = getVAF(tree, Y)
            tree$likelihood = getlikelihood(tree, R, X, WM, Wm, epsilonM, 
                                            epsilonm)
            ###################################### Sample in tree space #####
            sampi = 1
            writei = 1
            samptree = vector("list", max.simrun)
            samptree.lik = rep(NA, max.simrun)
            samptree.accept = rep(NA, max.simrun)
            samptree.accept.rate = rep(NA, max.simrun)
            while (sampi <= max.simrun) {
                if (sampi <= min.simrun) {
                    ######### sample sna mutation cluster positions
                    tree.new = tree
                    tree.new$sna.cluster = sampsna.cluster(tree)
                    sna.mat = cbind(sna = seq_len(nrow(R)), (tree.new$sna.cluster)[cluster.index, 2:3])
                    colnames(sna.mat) = c("sna", "sna.st.node", "sna.ed.node")
                    rownames(sna.mat) = sna.name
                    tree.new$sna = sna.mat
                    tree.new$Z = getZ(tree.new, sna.name)
                    tree.new$Q = getQ(tree.new, Y, C)
                    tree.new$H = tree.new$Q
                    tree.new$VAF = getVAF(tree.new, Y)
                    tree.new$likelihood = getlikelihood(tree.new, R, X, 
                                                        WM, Wm, epsilonM, epsilonm)
                    # addsamptree returns list(current tree, acceptance indicator).
                    tree.temp = addsamptree(tree, tree.new)
                    tree = tree.temp[[1]]
                    samptree.accept[sampi] = tree.temp[[2]]
                    if (sampi %% writeskip == 0) {
                        samptree[[writei]] = tree
                        writei = writei + 1
                    }
                    samptree.lik[sampi] = tree$likelihood
                    # Acceptance rate over (at most) the last 1000 iterations.
                    samptree.accept.rate[sampi] = mean(samptree.accept[max(1, sampi - 999):sampi])
                    sampi = sampi + 1
                } else {
                    ######### sample sna positions
                    tree.new = tree
                    tree.new$sna = sampsna(tree)
                    tree.new$Z = getZ(tree.new, sna.name)
                    tree.new$Q = getQ(tree.new, Y, C)
                    tree.new$H = tree.new$Q
                    tree.new$VAF = getVAF(tree.new, Y)
                    tree.new$likelihood = getlikelihood(tree.new, R, X, 
                                                        WM, Wm, epsilonM, epsilonm)
                    tree.temp = addsamptree(tree, tree.new)
                    tree = tree.temp[[1]]
                    samptree.accept[sampi] = tree.temp[[2]]
                    if (sampi %% writeskip == 0) {
                        samptree[[writei]] = tree
                        writei = writei + 1
                    }
                    samptree.lik[sampi] = tree$likelihood
                    samptree.accept.rate[sampi] = mean(samptree.accept[max(1, sampi - 999):sampi])
                    if (should.stop(sampi)) break
                    sampi = sampi + 1
                }
                ######## sample cna positions
                tree.new = tree
                tree.new$cna = sampcna(tree)
                CMCm = getCMCm(tree.new, C)
                tree.new$CM = CMCm[[1]]
                tree.new$Cm = CMCm[[2]]
                tree.new$Q = getQ(tree.new, Y, C)
                tree.new$H = tree.new$Q
                tree.new$VAF = getVAF(tree.new, Y)
                tree.new$likelihood = getlikelihood(tree.new, R, X, 
                                                    WM, Wm, epsilonM, epsilonm)
                tree.temp = addsamptree(tree, tree.new)
                tree = tree.temp[[1]]
                samptree.accept[sampi] = tree.temp[[2]]
                if (sampi %% writeskip == 0) {
                    samptree[[writei]] = tree
                    writei = writei + 1
                }
                samptree.lik[sampi] = tree$likelihood
                samptree.accept.rate[sampi] = mean(samptree.accept[max(1, sampi - 999):sampi])
                if (should.stop(sampi)) break
                sampi = sampi + 1
                ######## sample P (clonal proportions)
                tree.new = tree
                tree.new$P = sampP(tree.new, cell.line)
                tree.new$VAF = getVAF(tree.new, Y)
                tree.new$likelihood = getlikelihood(tree.new, R, X, 
                                                    WM, Wm, epsilonM, epsilonm)
                tree.temp = addsamptree(tree, tree.new)
                tree = tree.temp[[1]]
                samptree.accept[sampi] = tree.temp[[2]]
                if (sampi %% writeskip == 0) {
                    samptree[[writei]] = tree
                    writei = writei + 1
                }
                samptree.lik[sampi] = tree$likelihood
                samptree.accept.rate[sampi] = mean(samptree.accept[max(1, sampi - 999):sampi])
                if (should.stop(sampi)) break
                sampi = sampi + 1
                ######## sample major and minor copy number
                tree.new = tree
                tree.new$cna.copy = sampcnacopy(tree.new)
                CMCm = getCMCm(tree.new, C)
                tree.new$CM = CMCm[[1]]
                tree.new$Cm = CMCm[[2]]
                tree.new$VAF = getVAF(tree.new, Y)
                tree.new$likelihood = getlikelihood(tree.new, R, X, 
                                                    WM, Wm, epsilonM, epsilonm)
                tree.temp = addsamptree(tree, tree.new)
                tree = tree.temp[[1]]
                samptree.accept[sampi] = tree.temp[[2]]
                if (sampi %% writeskip == 0) {
                    samptree[[writei]] = tree
                    writei = writei + 1
                }
                samptree.lik[sampi] = tree$likelihood
                samptree.accept.rate[sampi] = mean(samptree.accept[max(1, sampi - 999):sampi])
                if (should.stop(sampi)) break
                sampi = sampi + 1
                ######## sample whether SNA falls in major or minor allele
                if (any(tree$Q == 1)) {
                    tree.new = tree
                    q.temp = which(tree.new$Q == 1)
                    # Flip H for one randomly chosen entry with Q == 1.
                    q.temp.change = q.temp[sample.int(length(q.temp), size = 1)]
                    tree.new$H[q.temp.change] = 1 - tree.new$H[q.temp.change]
                    tree.new$VAF = getVAF(tree.new, Y)
                    tree.new$likelihood = getlikelihood(tree.new, R, X, 
                                                        WM, Wm, epsilonM, epsilonm)
                    tree.temp = addsamptree(tree, tree.new)
                    tree = tree.temp[[1]]
                    samptree.accept[sampi] = tree.temp[[2]]
                    if (sampi %% writeskip == 0) {
                        samptree[[writei]] = tree
                        writei = writei + 1
                    }
                    samptree.lik[sampi] = tree$likelihood
                    samptree.accept.rate[sampi] = mean(samptree.accept[max(1, sampi - 999):sampi])
                    if (should.stop(sampi)) break
                    sampi = sampi + 1
                }
            }
            # seq_len() yields an empty list when no tree was stored.
            sampchaink[[numi]] = samptree[seq_len(writei - 1)]
            sampchaink.lik[[numi]] = samptree.lik
            sampchaink.accept.rate[[numi]] = samptree.accept.rate
        }
        ###################################### plotting and saving #####
        if (plot.likelihood) {
            par(mfrow = c(1, 2))
            # Last recorded iteration of each chain (chains may stop early).
            xmax = rep(NA, numchain)
            for (i in seq_len(numchain)) {
                xmax[i] = max(which(!is.na(sampchaink.lik[[i]])))
            }
            # y-range over every finite value of every chain, so no part of
            # any trace is clipped whatever its shape.
            lik.all = unlist(sampchaink.lik)
            lik.range = range(lik.all[is.finite(lik.all)])
            plot(sampchaink.lik[[1]], xlim = c(1, max(xmax)), ylim = lik.range,
                 xlab = 'Iteration', ylab = 'Log-likelihood', type = 'l',
                 main = paste('Post. likelihood:', k, 'branches'))
            # seq_len(numchain)[-1] is empty for a single chain (2:1 is not).
            for (numi in seq_len(numchain)[-1]) {
                points(sampchaink.lik[[numi]], col = numi, type = 'l')
            }
            
            plot(sampchaink.accept.rate[[1]], ylim = c(0, 1), xlim = c(1, max(xmax)),
                 xlab = 'Iteration', ylab = 'Acceptance rate', type = 'l',
                 main = paste('Acceptance rate:', k, 'branches'))
            for (numi in seq_len(numchain)[-1]) {
                points(sampchaink.accept.rate[[numi]], col = numi, type = 'l')
            }
            par(mfrow = c(1, 1))
        }
        sampchain[[ki]] = sampchaink
        
        ki = ki + 1
    }
    return(sampchain)
}


# Run MCMC sampling of subclonal trees for SNA-only data (no CNA input, no
# pre-computed mutation clusters).
#
# For every number of subclones k in K, `numchain` independent chains are run.
# Each iteration proposes new individual SNA positions (sampsna), then new
# clonal proportions P (sampP). Every `writeskip`-th iteration the current
# tree is stored. No early-stopping rule is applied, so each chain runs for
# max.simrun iterations.
#
# Arguments:
#   R, X            read-count matrices, one row per SNA and one column per
#                   sample (R/X is treated as the variant allele frequency)
#   K               candidate numbers of subclones (each >= 2)
#   numchain        number of chains per value of k (>= 1)
#   max.simrun      number of iterations per chain
#   min.simrun      accepted for interface compatibility; not used here
#   writeskip       store the current tree every `writeskip` iterations
#   projectname     prefix of the "<projectname>_likelihood.pdf" trace plot
#   cell.line       forwarded to initialP/sampP; NULL means FALSE
#   plot.likelihood write likelihood / acceptance-rate traces; NULL means TRUE
#
# Returns a list with one element per value in K; each element is a list with
# one entry per chain holding that chain's stored trees.
canopy.sample.nocna = function(R, X, K, numchain, 
                               max.simrun, min.simrun, writeskip, 
                               projectname, cell.line = NULL,
                               plot.likelihood = NULL) {
    if (!is.matrix(R)) {
        stop("R should be a matrix!")
    }
    if (!is.matrix(X)) {
        stop("X should be a matrix!")
    }
    if (min(K) < 2) {
        stop("Smallest number of subclones should be >= 2!\n")
    }
    if (numchain < 1) {
        stop("Number of chains should be >= 1!\n")
    }
    if (is.null(cell.line)) {
        cell.line = FALSE
    }
    if (is.null(plot.likelihood)) {
        plot.likelihood = TRUE
    }
    if (plot.likelihood) {
        pdf(file = paste(projectname, "_likelihood.pdf", sep = ""), width = 10, height = 5)
        # Close the PDF device even if sampling fails part-way.
        on.exit(dev.off(), add = TRUE)
    }
    sampname = colnames(R)
    sna.name = rownames(R)
    sampchain = vector("list", length(K))
    ki = 1
    for (k in K) {
        cat("Sample in tree space with", k, "subclones\n")
        sampchaink = vector("list", numchain)
        sampchaink.lik = vector('list', numchain)
        sampchaink.accept.rate = vector('list', numchain)
        for (numi in seq_len(numchain)) {  # numi: index of the chain
            cat("\tRunning chain", numi, "out of", numchain, "...\n")
            ###################################### Tree initialization #####
            # Default topology is the ladder "(1,(2,(...,k)))"; for k = 5, 6, 7
            # one of a few alternative topologies may be drawn at random.
            text = paste(paste(paste(paste("(", 1:(k - 1), ",", sep = ""), 
                                     collapse = ""), k, sep = ""), paste(rep(")", (k - 1)), 
                                                                         collapse = ""), ";", sep = "")
            runif.temp = runif(1)
            if (k == 5 & runif.temp < 0.5) {
                text = c('(1,((2,3),(4,5)));')
            } else if (k == 6 & runif.temp < 1/3) {
                text = c('(1,((2,3),(4,(5,6))));')
            } else if (k == 6 & runif.temp > 2/3) {
                text = c('(1,(2,((3,4),(5,6))));')
            } else if (k == 7 & runif.temp > 1/4 & runif.temp <= 2/4) {
                text = c('(1,((2,3),(4,(5,(6,7)))));')
            } else if (k == 7 & runif.temp > 2/4 & runif.temp <= 3/4) {
                text = c('(1,((2,3),((4,5),(6,7))));')
            } else if (k == 7 & runif.temp > 3/4) {
                text = c('(1,((2,(3,4)),(5,(6,7))));')
            }
            tree <- read.tree(text = text)
            tree$sna = initialsna(tree, sna.name)
            tree$Z = getZ(tree, sna.name)
            tree$P = initialP(tree, sampname, cell.line)
            tree$VAF = tree$Z %*% tree$P / 2
            tree$likelihood = getlikelihood.sna(tree, R, X)
            ###################################### Sample in tree space #####
            sampi = 1
            writei = 1
            samptree = vector("list", max.simrun)
            samptree.lik = rep(NA, max.simrun)
            samptree.accept = rep(NA, max.simrun)
            samptree.accept.rate = rep(NA, max.simrun)
            while (sampi <= max.simrun) {
                ######### sample sna positions
                tree.new = tree
                tree.new$sna = sampsna(tree)
                tree.new$Z = getZ(tree.new, sna.name)
                tree.new$VAF = tree.new$Z %*% tree.new$P / 2
                tree.new$likelihood = getlikelihood.sna(tree.new, R, X)
                # addsamptree returns list(current tree, acceptance indicator).
                tree.temp = addsamptree(tree, tree.new)
                tree = tree.temp[[1]]
                samptree.accept[sampi] = tree.temp[[2]]
                if (sampi %% writeskip == 0) {
                    samptree[[writei]] = tree
                    writei = writei + 1
                }
                samptree.lik[sampi] = tree$likelihood
                # Acceptance rate over (at most) the last 1000 iterations.
                samptree.accept.rate[sampi] = mean(samptree.accept[max(1, sampi - 999):sampi])
                sampi = sampi + 1
                ######## sample P (clonal proportions)
                tree.new = tree
                tree.new$P = sampP(tree.new, cell.line)
                tree.new$VAF = tree.new$Z %*% tree.new$P / 2
                tree.new$likelihood = getlikelihood.sna(tree.new, R, X)
                tree.temp = addsamptree(tree, tree.new)
                tree = tree.temp[[1]]
                samptree.accept[sampi] = tree.temp[[2]]
                if (sampi %% writeskip == 0) {
                    samptree[[writei]] = tree
                    writei = writei + 1
                }
                samptree.lik[sampi] = tree$likelihood
                samptree.accept.rate[sampi] = mean(samptree.accept[max(1, sampi - 999):sampi])
                sampi = sampi + 1
            }
            # seq_len() yields an empty list when no tree was stored.
            sampchaink[[numi]] = samptree[seq_len(writei - 1)]
            sampchaink.lik[[numi]] = samptree.lik
            sampchaink.accept.rate[[numi]] = samptree.accept.rate
        }
        ###################################### plotting and saving #####
        if (plot.likelihood) {
            par(mfrow = c(1, 2))
            # Last recorded iteration of each chain.
            xmax = rep(NA, numchain)
            for (i in seq_len(numchain)) {
                xmax[i] = max(which(!is.na(sampchaink.lik[[i]])))
            }
            # y-range over every finite value of every chain, so no part of
            # any trace is clipped whatever its shape.
            lik.all = unlist(sampchaink.lik)
            lik.range = range(lik.all[is.finite(lik.all)])
            plot(sampchaink.lik[[1]], xlim = c(1, max(xmax)), ylim = lik.range,
                 xlab = 'Iteration', ylab = 'Log-likelihood', type = 'l',
                 main = paste('Post. likelihood:', k, 'branches'))
            # seq_len(numchain)[-1] is empty for a single chain (2:1 is not).
            for (numi in seq_len(numchain)[-1]) {
                points(sampchaink.lik[[numi]], col = numi, type = 'l')
            }
            
            plot(sampchaink.accept.rate[[1]], ylim = c(0, 1), xlim = c(1, max(xmax)),
                 xlab = 'Iteration', ylab = 'Acceptance rate', type = 'l',
                 main = paste('Acceptance rate:', k, 'branches'))
            for (numi in seq_len(numchain)[-1]) {
                points(sampchaink.accept.rate[[numi]], col = numi, type = 'l')
            }
            par(mfrow = c(1, 1))
        }
        sampchain[[ki]] = sampchaink
        ki = ki + 1
    }
    return(sampchain)
}


# Parallel variant of the SNA + CNA tree sampler: for every number of
# subclones k in K, the `numchain` chains are run on a local socket cluster,
# one canopy.sample.single.numi() call per chain.
#
# Arguments are as for the serial sampler:
#   R, X, WM, Wm     must be matrices
#   C                one 1 per column; NULL means identity over rownames(WM)
#   K                candidate numbers of subclones (each >= 2)
#   numchain         number of chains per k (>= 2 here)
#   cell.line        NULL means FALSE
#   plot.likelihood  NULL means TRUE; traces go to
#                    "<projectname>_likelihood.pdf"
# The remaining arguments are forwarded unchanged to
# canopy.sample.single.numi().
#
# Returns a list with one element per value in K; each element is a list with
# one entry per chain holding the "sampchaink" component returned by
# canopy.sample.single.numi().
canopy.sample.parallel = function(R, X, WM, Wm, epsilonM, epsilonm, C = NULL, 
                                  Y, K, numchain, max.simrun, min.simrun, writeskip, projectname, 
                                  cell.line = NULL, plot.likelihood = NULL) {
  if (!is.matrix(R)) {
    stop("R should be a matrix!")
  }
  if (!is.matrix(X)) {
    stop("X should be a matrix!")
  }
  if (!is.matrix(WM)) {
    stop("WM should be a matrix!")
  }
  if (!is.matrix(Wm)) {
    stop("Wm should be a matrix!")
  }
  if (min(K) < 2) {
    stop("Smallest number of subclones should be >= 2!\n")
  }
  if (numchain < 2) {
    stop("Number of chains should be >= 2!\n")
  }
  if (is.null(cell.line)) {
    cell.line = FALSE
  }
  if (is.null(plot.likelihood)) {
    plot.likelihood = TRUE
  }
  if (is.null(C)) {
    C = diag(nrow(WM))
    colnames(C) = rownames(C) = rownames(WM)
  }
  if (any(colSums(C) != 1)) {
    stop("Matrix C should have one and only one 1 for each column!")
  }
  if (plot.likelihood) {
    pdf(file = paste(projectname, "_likelihood.pdf", sep = ""), width = 10, height = 5)
    # Close the PDF device even if sampling fails part-way.
    on.exit(dev.off(), add = TRUE)
  }
  
  sampname = colnames(R)
  sna.name = rownames(R)
  cna.region.name = rownames(C)
  cna.name = colnames(C)
  sampchain = vector("list", length(K))
  
  # Keep one core free and never start more workers than chains. detectCores()
  # may return NA, and detectCores() - 1 may be 0, so use at least one worker.
  nCores = max(1L, min(numchain, detectCores() - 1L), na.rm = TRUE)
  
  ki = 1
  for (k in K) {
    cat("Sample in tree space with", k, "subclones\n")
    
    #begin parallel
    cl <- makeCluster(nCores)
    export.env = environment()
    # Always shut the workers down, even if export or a chain fails.
    sampchainList = tryCatch({
      clusterExport(cl, list("k", "R", "X", "WM", "Wm", "epsilonM", "epsilonm", "C",
                             "Y",  "max.simrun", "min.simrun", "writeskip", "projectname",
                             "cell.line", "plot.likelihood",
                             "sampname", "sna.name", "cna.region.name", "cna.name",
                             "canopy.sample.single.numi"), envir = export.env)
      clusterEvalQ(cl, library(Canopy))
      parLapply(cl, seq_len(numchain), function(numi) {
        canopy.sample.single.numi(numi, k, R, X, WM, Wm, epsilonM, epsilonm, C, 
                                  Y,  max.simrun, min.simrun, writeskip, projectname, 
                                  cell.line, plot.likelihood, 
                                  sampname, sna.name, cna.region.name, cna.name)
      })
    }, finally = stopCluster(cl))
    #end parallel 
    
    sampchaink = lapply(sampchainList, function(x) x[["sampchaink"]])
    sampchaink.lik = lapply(sampchainList, function(x) x[["sampchaink.lik"]])
    sampchaink.accept.rate = lapply(sampchainList, function(x) x[["sampchaink.accept.rate"]])
    
    ###################################### plotting and saving #####
    if (plot.likelihood) {
      par(mfrow = c(1, 2))
      # Last recorded iteration of each chain.
      xmax = rep(NA, numchain)
      for (i in seq_len(numchain)) {
        xmax[i] = max(which(!is.na(sampchaink.lik[[i]])))
      }
      # y-range over every finite value of every chain, so no part of any
      # trace is clipped whatever its shape.
      lik.all = unlist(sampchaink.lik)
      lik.range = range(lik.all[is.finite(lik.all)])
      plot(sampchaink.lik[[1]], xlim = c(1, max(xmax)), ylim = lik.range,
           xlab = 'Iteration', ylab = 'Log-likelihood', type = 'l',
           main = paste('Post. likelihood:', k, 'branches'))
      for (numi in seq_len(numchain)[-1]) {
        points(sampchaink.lik[[numi]], col = numi, type = 'l')
      }
      
      plot(sampchaink.accept.rate[[1]], ylim = c(0, 1), xlim = c(1, max(xmax)),
           xlab = 'Iteration', ylab = 'Acceptance rate', type = 'l',
           main = paste('Acceptance rate:', k, 'branches'))
      for (numi in seq_len(numchain)[-1]) {
        points(sampchaink.accept.rate[[numi]], col = numi, type = 'l')
      }
      par(mfrow = c(1, 1))
    }
    sampchain[[ki]] = sampchaink
    
    ki = ki + 1
  }
  return(sampchain)
}



canopy.sample.single.numi = function(numi, k, R, X, WM, Wm, epsilonM, epsilonm, C, 
                                     Y, max.simrun, min.simrun, writeskip, projectname, 
                                     cell.line, plot.likelihood, 
                                     sampname, sna.name, cna.region.name, cna.name){
  # Run one MCMC chain in tree space for a fixed number of subclones `k`.
  #
  # Each sweep proposes, in this order:
  #   1. new SNA positions,
  #   2. new CNA positions,
  #   3. new clonal proportions P,
  #   4. new major/minor copy numbers,
  #   5. when some SNA precedes a CNA (any(tree$Q == 1)), flipping which
  #      allele one such SNA sits on.
  # Every proposal goes through addsamptree(), which returns the tree to keep
  # and an acceptance indicator.
  #
  # The arguments mirror canopy.sample(). `numi`, `min.simrun`, `projectname`,
  # `plot.likelihood` and `cna.region.name` are accepted for interface
  # compatibility but are not used here.
  #
  # Returns a list with:
  #   sampchaink             - trees recorded every `writeskip` iterations
  #   sampchaink.lik         - likelihood of the retained tree at each iteration
  #   sampchaink.accept.rate - acceptance rate over the last (up to) 1000 iterations
  #
  # Note: `sampi <= max.simrun` is only checked once per sweep, so a chain
  # can run up to four iterations past max.simrun.
  
  ###################################### Tree initialization #####
  # Default starting topology: the caterpillar tree (1,(2,(...(k-1,k)...)));
  text = paste(paste(paste(paste("(", 1:(k - 1), ",", sep = ""), 
                           collapse = ""), k, sep = ""), paste(rep(")", (k - 1)), 
                                                               collapse = ""), ";", sep = "")
  # For 5 to 7 subclones, sometimes start from a more balanced topology instead.
  runif.temp = runif(1)
  if (k == 5 && runif.temp < 0.5) {
    text = '(1,((2,3),(4,5)));'
  } else if (k == 6 && runif.temp < 1/3) {
    text = '(1,((2,3),(4,(5,6))));'
  } else if (k == 6 && runif.temp > 2/3) {
    text = '(1,(2,((3,4),(5,6))));'
  } else if (k == 7 && runif.temp > 1/4 && runif.temp <= 2/4) {
    text = '(1,((2,3),(4,(5,(6,7)))));'
  } else if (k == 7 && runif.temp > 2/4 && runif.temp <= 3/4) {
    text = '(1,((2,3),((4,5),(6,7))));'
  } else if (k == 7 && runif.temp > 3/4) {
    text = '(1,((2,(3,4)),(5,(6,7))));'
  }
  tree <- read.tree(text = text)
  tree$sna = initialsna(tree, sna.name)
  tree$Z = getZ(tree, sna.name)
  tree$P = initialP(tree, sampname, cell.line)
  tree$cna = initialcna(tree, cna.name)
  tree$cna.copy = initialcnacopy(tree)
  CMCm = getCMCm(tree, C)  # major / minor copy per clone
  tree$CM = CMCm[[1]]
  tree$Cm = CMCm[[2]]
  tree$Q = getQ(tree, Y, C)
  tree$H = tree$Q  # start with all SNAs that precede CNAs on the major copies
  tree$VAF = getVAF(tree, Y)
  tree$likelihood = getlikelihood(tree, R, X, WM, Wm, epsilonM, epsilonm)
  
  ###################################### Proposal moves #####
  # Each move takes the current tree and returns a proposal with all derived
  # quantities (except the likelihood) recomputed. A move returns NULL when it
  # does not apply to the current tree.
  propose.sna = function(cur) {
    # Move SNAs to new edges.
    prop = cur
    prop$sna = sampsna(cur)
    prop$Z = getZ(prop, sna.name)
    prop$Q = getQ(prop, Y, C)
    prop$H = prop$Q
    prop$VAF = getVAF(prop, Y)
    prop
  }
  propose.cna = function(cur) {
    # Move CNAs to new edges.
    prop = cur
    prop$cna = sampcna(cur)
    CMCm = getCMCm(prop, C)
    prop$CM = CMCm[[1]]
    prop$Cm = CMCm[[2]]
    prop$Q = getQ(prop, Y, C)
    prop$H = prop$Q
    prop$VAF = getVAF(prop, Y)
    prop
  }
  propose.P = function(cur) {
    # Resample clonal proportions.
    prop = cur
    prop$P = sampP(prop, cell.line)
    prop$VAF = getVAF(prop, Y)
    prop
  }
  propose.copy = function(cur) {
    # Resample major and minor copy number of a CNA.
    prop = cur
    prop$cna.copy = sampcnacopy(prop)
    CMCm = getCMCm(prop, C)
    prop$CM = CMCm[[1]]
    prop$Cm = CMCm[[2]]
    prop$VAF = getVAF(prop, Y)
    prop
  }
  propose.H = function(cur) {
    # Flip whether one SNA that precedes a CNA sits on the major or minor allele.
    if (!any(cur$Q == 1)) {
      return(NULL)
    }
    prop = cur
    q.temp = which(prop$Q == 1)
    q.temp.change = q.temp[sample.int(1, n = length(q.temp))]
    prop$H[q.temp.change] = 1 - prop$H[q.temp.change]
    prop$VAF = getVAF(prop, Y)
    prop
  }
  moves = list(propose.sna, propose.cna, propose.P, propose.copy, propose.H)
  
  ###################################### Sample in tree space #####
  sampi = 1   # MCMC iteration number
  writei = 1  # next free slot in samptree
  samptree = vector("list", max.simrun)
  samptree.lik = rep(NA, max.simrun)
  samptree.accept = rep(NA, max.simrun)
  samptree.accept.rate = rep(NA, max.simrun)
  while (sampi <= max.simrun) {
    for (propose in moves) {
      tree.new = propose(tree)
      if (is.null(tree.new)) next  # move not applicable to the current tree
      tree.new$likelihood = getlikelihood(tree.new, R, X, 
                                          WM, Wm, epsilonM, epsilonm)
      tree.temp = addsamptree(tree, tree.new)
      tree = tree.temp[[1]]
      samptree.accept[sampi] = tree.temp[[2]]
      if (sampi %% writeskip == 0) {
        samptree[[writei]] = tree
        writei = writei + 1
      }
      samptree.lik[sampi] = tree$likelihood
      samptree.accept.rate[sampi] = mean(samptree.accept[max(1, sampi - 999):sampi])
      # A convergence-based early stop (using min.simrun) is disabled.
      sampi = sampi + 1
    }
  }
  # seq_len() gives an empty list (not list(NULL)) when no tree was recorded.
  sampchaink = samptree[seq_len(writei - 1)]
  sampchaink.lik = samptree.lik
  sampchaink.accept.rate = samptree.accept.rate
  
  return(list(sampchaink = sampchaink, 
              sampchaink.lik = sampchaink.lik, 
              sampchaink.accept.rate = sampchaink.accept.rate))
}




canopy.sample = function(R, X, WM, Wm, epsilonM, epsilonm, C = NULL, 
    Y, K, numchain, max.simrun, min.simrun, writeskip, projectname, 
    cell.line = NULL, plot.likelihood = NULL) {
    # MCMC sampling in tree space, run sequentially.
    #
    # For every number of subclones in K, runs `numchain` chains, each produced
    # by canopy.sample.single.numi(). When plot.likelihood is TRUE (the
    # default), the log-likelihood and acceptance-rate traces are written to
    # "<projectname>_likelihood.pdf".
    #
    # Returns a list with one element per value of K. Each element is a list
    # (one entry per chain) of the trees recorded every `writeskip` iterations.
    if (!is.matrix(R)) {
        stop("R should be a matrix!")
    }
    if (!is.matrix(X)) {
        stop("X should be a matrix!")
    }
    if (!is.matrix(WM)) {
        stop("WM should be a matrix!")
    }
    if (!is.matrix(Wm)) {
        stop("Wm should be a matrix!")
    }
    if (min(K) < 2) {
        stop("Smallest number of subclones should be >= 2!\n")
    }
    if (numchain < 1) {
        stop("Number of chains should be >= 1!\n")
    }
    if (is.null(cell.line)) {
        cell.line = FALSE
    }
    if (is.null(plot.likelihood)) {
        plot.likelihood = TRUE
    }
    if (is.null(C)) {
        C = diag(nrow(WM))
        colnames(C) = rownames(C) = rownames(WM)
    }
    if (any(colSums(C) != 1)) {
        stop("Matrix C should have one and only one 1 for each column!")
    }
    if (plot.likelihood) {
        pdf(file = paste(projectname, "_likelihood.pdf", sep = ""), width = 10, height = 5)
        # Close exactly this device on exit, including when sampling fails.
        pdf.device = dev.cur()
        on.exit(dev.off(pdf.device), add = TRUE)
    }
    
    sampname = colnames(R)
    sna.name = rownames(R)
    cna.region.name = rownames(C)
    cna.name = colnames(C)
    sampchain = vector("list", length(K))
    ki = 1
    for (k in K) {
        cat("Sample in tree space with", k, "subclones\n")
        sampchaink = vector("list", numchain)
        sampchaink.lik = vector('list', numchain)
        sampchaink.accept.rate = vector('list', numchain)
        for (numi in seq_len(numchain)) {  # numi: number of chain
            cat("\tRunning chain", numi, "out of", numchain, "...\n")
            # The chain itself is run by the same routine canopy.sample.parallel() uses.
            chain = canopy.sample.single.numi(numi, k, R, X, WM, Wm, epsilonM, epsilonm, C,
                Y, max.simrun, min.simrun, writeskip, projectname,
                cell.line, plot.likelihood,
                sampname, sna.name, cna.region.name, cna.name)
            sampchaink[[numi]] = chain$sampchaink
            sampchaink.lik[[numi]] = chain$sampchaink.lik
            sampchaink.accept.rate[[numi]] = chain$sampchaink.accept.rate
        }
        ###################################### plotting and saving #####
        if (plot.likelihood) {
            par(mfrow=c(1,2))
            xmax=ymin=ymax=rep(NA,numchain)
            for(i in seq_len(numchain)){
                xmax[i]=max(which((!is.na(sampchaink.lik[[i]]))))
                ymin[i]=sampchaink.lik[[i]][1]
                ymax[i]=sampchaink.lik[[i]][xmax[i]]
            }
            
            plot(sampchaink.lik[[1]],xlim=c(1,max(xmax)),ylim=c(min(ymin),max(ymax)),
                 xlab='Iteration',ylab='Log-likelihood',type='l',
                 main=paste('Post. likelihood:',k,'branches'))
            # seq_len(numchain)[-1] is empty for a single chain (2:1 would not be).
            for(numi in seq_len(numchain)[-1]){
                points(sampchaink.lik[[numi]],xlim=c(1,max(xmax)),ylim=c(min(ymin),max(ymax)),col=numi,type='l')
            }
            
            plot(sampchaink.accept.rate[[1]],ylim=c(0,1),xlim=c(1,max(xmax)),
                 xlab='Iteration',ylab='Acceptance rate',type='l',
                 main=paste('Acceptance rate:',k,'branches'))
            for(numi in seq_len(numchain)[-1]){
                points(sampchaink.accept.rate[[numi]],ylim=c(0,1),xlim=c(1,max(xmax)),col=numi,type='l')
            }
            par(mfrow=c(1,1))
        }
        sampchain[[ki]] = sampchaink
        
        ki = ki + 1
    }
    return(sampchain)
} 


canopy.simrun.diagnostic = function(sampchain, optK, K, writeskip, yRange = 100){
  # Plot log-likelihood traces of the MCMC chains run with optK subclones.
  #
  # sampchain: sampler output, one element per value of K, each a list of
  #            chains, each a list of recorded trees with a $likelihood.
  # optK:      number of subclones to inspect; must match exactly one element of K.
  # K:         the vector of subclone numbers `sampchain` was run with.
  # writeskip: thinning interval. The j-th recorded tree is plotted at
  #            iteration j * writeskip.
  # yRange:    the right-hand panel only shows likelihoods within yRange of
  #            the overall maximum.
  #
  # Draws both panels side by side and returns the result of grid.arrange().
  kIndex = which(K == optK)
  if (length(kIndex) != 1) {
    stop("optK should match exactly one element of K!")
  }
  chains = sampchain[[kIndex]]
  
  chainLikeDF = rbindlist(
    lapply(seq_along(chains), function(chainIndex) {
      myChain = chains[[chainIndex]]
      # A chain that recorded no tree contributes no rows (rbindlist skips NULL).
      if (length(myChain) == 0) {
        return(NULL)
      }
      data.frame(
        chainIndex = chainIndex,
        treeIndex = seq_along(myChain) * writeskip,
        likelihood = vapply(myChain, function(myTree) myTree$likelihood, numeric(1))
      )
    })
  )
  
  p = list()
  p[[1]] = ggplot(data = chainLikeDF) + 
    geom_line( aes( x = treeIndex, y = likelihood, color = factor( chainIndex ) ) ) +
    guides(color=guide_legend(title="chainIndex"))
  p[[2]] = ggplot(data = chainLikeDF[likelihood > max(chainLikeDF$likelihood) - yRange]) + 
    geom_line( aes( x = treeIndex, y = likelihood, color = factor( chainIndex ) ) ) +
    guides(color=guide_legend(title="chainIndex"))
  return (grid.arrange(grobs = p, nrow = 1))
  
}


getclonalcomposition = function(tree) {
    # For every clone (tip) of `tree`, collect the sorted names of the SNAs
    # (rownames of tree$sna) and CNAs (rownames of tree$cna) placed on the
    # edges between that tip and the root.
    # Clone 1 is labelled "None". A clone with nothing on its path is left NULL.
    sna.labels = rownames(tree$sna)
    cna.labels = rownames(tree$cna)
    edges = tree$edge
    n.clone = (nrow(edges) + 2)/2
    root = n.clone + 1
    # Names of the placements (columns 2/3 = start/end node) lying on edge from -> to.
    on.edge = function(placement, labels, from, to) {
        labels[which(placement[, 2] == from & placement[, 3] == to)]
    }
    composition = vector("list", n.clone)
    for (tip in 2:n.clone) {
        hits = NULL
        child = tip
        parent = edges[which(edges[, 2] == child), 1]
        while (parent >= root) {
            hits = c(hits,
                     on.edge(tree$sna, sna.labels, parent, child),
                     on.edge(tree$cna, cna.labels, parent, child))
            if (parent == root) break
            child = parent
            parent = edges[which(edges[, 2] == child), 1]
        }
        if (length(hits) > 0) {
            composition[[tip]] = sort(hits)
        }
    }
    composition[[1]] = "None"
    composition
} 


getCMCm = function(tree, C) {
    # Major (CM) and minor (Cm) copy number of every CNA region (rows, named
    # after rownames(C)) in every clone (columns "clone1".."clonek").
    # C[i, s] == 1 means CNA s lies in region i. A region keeps copy number
    # 1/1 unless one of its CNAs sits on the clone's path to the root. In that
    # case it takes that CNA's copies from tree$cna.copy (row 1 major,
    # row 2 minor). Returns list(CM, Cm).
    k = (nrow(tree$edge) + 2)/2
    s.cna = nrow(C)
    CM = matrix(nrow = s.cna, ncol = k, data = 1)
    rownames(CM) = rownames(C)
    colnames(CM) = paste("clone", 1:k, sep = "")
    Cm = CM
    # clonal.cna[[j]]: indices of the CNAs on the path from tip j to the root.
    clonal.cna = vector("list", k)
    for (tip in 2:k) {
        child.node = tip
        parent.node = tree$edge[which(tree$edge[, 2] == child.node), 1]
        while (parent.node >= (k + 1)) {
            cnatemp = intersect(which(tree$cna[, 2] == parent.node), 
                which(tree$cna[, 3] == child.node))
            if (length(cnatemp) > 0) {
                clonal.cna[[tip]] = c(clonal.cna[[tip]], cnatemp)
            }
            child.node = parent.node
            if (child.node == (k + 1)) 
                break
            parent.node = tree$edge[which(tree$edge[, 2] == child.node), 1]
        }
    }
    
    if (all(rowSums(C) == 1)) {
        # Every region holds exactly one CNA: vectorised path.
        # Z[s, j] == 1 when CNA s is on clone j's path to the root.
        Z = matrix(nrow = nrow(tree$cna), ncol = k, data = 0)
        for (ki in 2:k) {
            Z[clonal.cna[[ki]], ki] = 1
        }
        # Map CNA-indexed rows onto regions through C rather than assuming
        # that CNA s belongs to region s.
        covered = (C %*% Z) == 1
        CM.temp = C %*% (tree$cna.copy[1, ] * Z)
        Cm.temp = C %*% (tree$cna.copy[2, ] * Z)
        CM[covered] = CM.temp[covered]
        Cm[covered] = Cm.temp[covered]
    } else {
        for (i in 1:nrow(C)) {
            cnai = which(C[i, ] == 1)
            # Apply this region's CNAs from root to leaves (ascending start
            # node) so the CNA nearest the clone is applied last and wins.
            # order() is required: indexing by rank() applies the inverse
            # permutation, which does not sort 3+ CNAs. Ties break at random.
            cnai = cnai[order(tree$cna[cnai, 2], sample.int(length(cnai)))]
            for (s in cnai) {
                for (t in 2:k) {
                    if (is.element(s, clonal.cna[[t]])) {
                        CM[i, t] = tree$cna.copy[1, s]
                        Cm[i, t] = tree$cna.copy[2, s]
                    }
                }
            }
        }
    }
    return(list(CM, Cm))
} 


getCZ = function(tree) {
    # CNA-by-clone indicator matrix: CZ[t, j] is 1 when CNA t sits on an edge
    # of the path from tip j (j >= 2) up to the root, and 0 otherwise.
    # Rows are named after rownames(tree$cna), columns "clone1".."clonek".
    edges = tree$edge
    n.clone = (nrow(edges) + 2)/2
    root = n.clone + 1
    CZ = matrix(nrow = nrow(tree$cna), ncol = n.clone, data = 0)
    rownames(CZ) = rownames(tree$cna)
    colnames(CZ) = paste("clone", 1:n.clone, sep = "")
    for (tip in 2:n.clone) {
        child = tip
        parent = edges[which(edges[, 2] == child), 1]
        while (parent >= root) {
            # Mark every CNA placed on the edge parent -> child.
            CZ[which(tree$cna[, 2] == parent & tree$cna[, 3] == child), tip] = 1
            if (parent == root) break
            child = parent
            parent = edges[which(edges[, 2] == child), 1]
        }
    }
    CZ
} 


getlikelihood = function(tree, R, X, WM, Wm, epsilonM, epsilonm) {
    # Log-likelihood of `tree` given the observed data.
    #   SNA part: binomial log-likelihood (without the combinatorial constant)
    #     of R alternative reads out of X total reads, with success probability
    #     tree$VAF clamped away from 0 and 1 so log() stays finite.
    #   CNA part: Gaussian log-density of the observed major (WM) and minor
    #     (Wm) copy numbers around tree$CM %*% tree$P and tree$Cm %*% tree$P,
    #     with standard deviations epsilonM and epsilonm.
    # A scalar epsilon is expanded to a matrix shaped like its observation
    # matrix. Matrices (or vectors, via dnorm's recycling) are used as given.
    if (!is.matrix(epsilonM) && length(epsilonM) == 1) {
        epsilonM = matrix(ncol = ncol(WM), nrow = nrow(WM), data = epsilonM)
        colnames(epsilonM) = colnames(WM)
        rownames(epsilonM) = rownames(WM)
    }
    if (!is.matrix(epsilonm) && length(epsilonm) == 1) {
        epsilonm = matrix(ncol = ncol(Wm), nrow = nrow(Wm), data = epsilonm)
        colnames(epsilonm) = colnames(Wm)
        rownames(epsilonm) = rownames(Wm)
    }
    # SNA
    sna.read = R
    reference.read = X - R
    sna.freq = tree$VAF
    sna.freq[sna.freq <= 0] = 1e-04
    sna.freq[sna.freq >= 1] = 0.9999
    l = sum(sna.read*log(sna.freq) + reference.read*log(1-sna.freq))
    # CNA
    CM.sample = tree$CM %*% tree$P
    Cm.sample = tree$Cm %*% tree$P
    l = l + sum(dnorm(WM, mean=CM.sample, sd = epsilonM, log = TRUE))
    l = l + sum(dnorm(Wm, mean=Cm.sample, sd = epsilonm, log = TRUE))
    return(l)
} 


getlikelihood.sna = function(tree, R, X) {
    # SNA-only part of the log-likelihood: binomial kernel of R alternative
    # reads among X total reads, with success probability tree$VAF.
    # VAFs <= 0 become 1e-04 and VAFs >= 1 become 0.9999 so log() is finite.
    vaf = tree$VAF
    vaf = replace(vaf, vaf <= 0, 1e-04)
    vaf = replace(vaf, vaf >= 1, 0.9999)
    alt.reads = R
    ref.reads = X - R
    sum(alt.reads * log(vaf) + ref.reads * log(1 - vaf))
} 


getQ = function(tree, Y, C) {
    # SNA-by-CNA indicator of "SNA precedes CNA". It starts from
    # Y[, -1] %*% C (the SNA's region matches the CNA's region through Y and C).
    # Entries are then reset to 0 whenever the SNA's start node number is
    # larger than the CNA's start node number.
    precede = Y[, -1] %*% C
    sna.after = outer(tree$sna[, 2], tree$cna[, 2], "-") > 0
    precede[precede == 1 & sna.after] = 0
    precede
} 

getVAF = function(tree, Y) {
    # Expected variant allele frequency of every SNA (rows) in every sample
    # (columns), rounded to 3 decimals.
    #   denominator: Y mapped onto per-clone copy numbers (a leading row of 1s
    #     is prepended to tree$CM / tree$Cm for Y's first column), major plus
    #     minor, weighted by the clonal proportions tree$P. It is floored at
    #     0.001 to avoid dividing by zero.
    #   numerator: tree$Z (1 for clones carrying the SNA). For clones that
    #     also carry a CNA the SNA precedes (tree$Q == 1, presence via getCZ),
    #     this is replaced by that CNA's major (tree$H == 1) or minor copy
    #     number. The result is then weighted by tree$P.
    k = ncol(tree$CM)
    temp1 = Y %*% (rbind(rep(1, k), tree$CM))
    temp2 = Y %*% (rbind(rep(1, k), tree$Cm))
    denominator = (temp1 + temp2) %*% tree$P
    denominator[denominator < 0.001] = 0.001
    CZ = getCZ(tree)
    temp = tree$Z
    # ith SNA, jth clone
    for (i in seq_len(nrow(temp))) {
        q.temp = which(tree$Q[i, ] == 1)
        if (length(q.temp) == 0) next
        # Process CNAs from root to leaves (ascending start node), so the CNA
        # nearest the clone is applied last and wins. order() is required:
        # indexing by rank() applies the inverse permutation, which does not
        # sort three or more CNAs. Ties are broken at random.
        q.temp = q.temp[order(tree$cna[q.temp, 2], sample.int(length(q.temp)))]
        for (j in which(temp[i, ] != 0)) {
            for (s in q.temp) {
                if (CZ[s, j] == 1) {
                    # Row 1 of cna.copy is the major copy, row 2 the minor copy.
                    temp[i, j] = tree$cna.copy[if (tree$H[i, s] == 1) 1 else 2, s]
                }
            }
        }
    }
    numerator = temp %*% tree$P
    VAF = round(numerator/denominator, 3)
    return(VAF)
} 


getZ = function(tree, sna.name) {
    # SNA-by-clone indicator matrix: Z[i, j] is 1 when SNA i sits on an edge
    # of the path from tip j (j >= 2) up to the root, and 0 otherwise.
    # Rows are named sna.name, columns "clone1".."clonek".
    edges = tree$edge
    n.clone = (nrow(edges) + 2)/2
    root = n.clone + 1
    Z = matrix(nrow = nrow(tree$sna), ncol = n.clone, data = 0)
    rownames(Z) = sna.name
    colnames(Z) = paste("clone", 1:n.clone, sep = "")
    for (tip in 2:n.clone) {
        child = tip
        parent = edges[which(edges[, 2] == child), 1]
        while (parent >= root) {
            # Mark every SNA placed on the edge parent -> child.
            Z[which(tree$sna[, 2] == parent & tree$sna[, 3] == child), tip] = 1
            if (parent == root) break
            child = parent
            parent = edges[which(edges[, 2] == child), 1]
        }
    }
    Z
}


initialcna = function(tree, cna.name) {
    # Randomly place every CNA on an edge of `tree`, never on edge 1.
    # Returns a length(cna.name) x 3 matrix with columns "cna" (index),
    # "cna.st.node" and "cna.ed.node" (nodes of the chosen edge) and rownames
    # cna.name.
    n.edge = nrow(tree$edge)
    n.cna = length(cna.name)
    # With only two edges everything goes on edge 2: sample() given a single
    # number would draw from 1:n instead.
    picked = if (n.edge > 2) {
        sample(2:n.edge, size = n.cna, replace = TRUE)
    } else {
        rep(2, n.cna)
    }
    # drop = FALSE keeps a one-row matrix when there is a single CNA.
    placement = cbind(1:n.cna, tree$edge[picked, , drop = FALSE])
    colnames(placement) = c("cna", "cna.st.node", "cna.ed.node")
    rownames(placement) = cna.name
    placement
}


initialcnacopy = function(tree) {
    # Draw initial copy numbers for every CNA (columns, named after
    # rownames(tree$cna)).
    # Row "major_copy" is drawn from 0:3 with probabilities 0.1/0.5/0.3/0.1.
    # Row "minor_copy" is 0 when the major copy is <= 1, and otherwise uniform
    # on 0..major.
    n.cna = nrow(tree$cna)
    copies = matrix(nrow = 2, ncol = n.cna)
    for (idx in 1:n.cna) {
        major = (0:3)[rmultinom(1, 1, c(0.1, 0.5, 0.3, 0.1))[, 1] == 1]
        copies[1, idx] = major
        copies[2, idx] = if (major > 1) {
            (0:major)[rmultinom(1, 1, rep(1/(major + 1), (major + 1)))[, 1] == 1]
        } else {
            0
        }
    }
    dimnames(copies) = list(c("major_copy", "minor_copy"), rownames(tree$cna))
    copies
} 


initialP = function(tree, sampname, cell.line) {
    # Random initial clonal proportions: one column per sample, one row per
    # clone ("clone1".."clonek"). Each column is multinomial(100, uniform)/100.
    # For cell lines, clone 1 is fixed at 0 and the remaining clones share
    # the mass.
    n.clone = (nrow(tree$edge) + 2)/2
    n.samp = length(sampname)
    if (cell.line == TRUE) {
        tumour = rmultinom(n.samp, 100, prob = rep(1/(n.clone - 1), (n.clone - 1)))/100
        P = rbind(rep(0, n.samp), tumour)
    } else if (cell.line == FALSE) {
        P = rmultinom(n.samp, 100, prob = rep(1/n.clone, n.clone))/100
    }
    dimnames(P) = list(paste("clone", 1:n.clone, sep = ""), sampname)
    P
} 


initialsna = function(tree, sna.name) {
    # Randomly place every SNA on an edge of `tree`, never on edge 1.
    # Returns a length(sna.name) x 3 matrix with columns "sna" (index),
    # "sna.st.node" and "sna.ed.node" (nodes of the chosen edge) and rownames
    # sna.name.
    sna.no = length(sna.name)
    if (nrow(tree$edge) > 2) {
        sna.edge = sample(2:nrow(tree$edge), size = sna.no, replace = TRUE)
    } else {
        # sample(2:2, ...) would draw from 1:2, so place everything on edge 2.
        sna.edge = rep(2, sna.no)
    }
    # drop = FALSE keeps a one-row matrix for a single SNA. Without it the
    # edge row collapses to a vector, cbind() builds a 2 x 2 matrix and
    # setting three column names fails.
    sna.mat = cbind(1:sna.no, tree$edge[sna.edge, , drop = FALSE])
    colnames(sna.mat) = c("sna", "sna.st.node", "sna.ed.node")
    rownames(sna.mat) = sna.name
    return(sna.mat)
} 


# Propose new edge placements for CNAs.
#
# Picks a random count m in 1..nrow(tree$cna), then m distinct CNAs, and
# reassigns each one's (start, end) node columns to an edge drawn with
# replacement from edge rows 2..nrow(tree$edge). Trees with two or fewer
# edges are returned unchanged (nothing to move to).
sampcna = function(tree) {
    proposal = tree$cna
    n.edge = nrow(tree$edge)
    if (n.edge <= 2) {
        return(proposal)
    }
    n.cna = nrow(tree$cna)
    n.move = sample.int(n.cna, size = 1)
    moved = sample.int(n.cna, size = n.move)
    dest.edge = sample(2:n.edge, size = n.move, replace = TRUE)
    proposal[moved, 2:3] = tree$edge[dest.edge, 1:2]
    proposal
}


# Propose a new copy-number state for one randomly chosen CNA.
#
# The chosen column of tree$cna.copy gets a fresh major copy number drawn
# from 0:3 with probabilities (0.1, 0.5, 0.3, 0.1) and a minor copy number
# that is 0 when the major is 0 or 1, otherwise uniform on 0..major. All
# other columns are returned unchanged.
sampcnacopy = function(tree) {
    target = sample.int(nrow(tree$cna), size = 1)
    proposal = tree$cna.copy
    major = (0:3)[rmultinom(1, 1, c(0.1, 0.5, 0.3, 0.1))[, 1] == 1]
    proposal[1, target] = major
    if (major > 1) {
        n.choice = major + 1
        proposal[2, target] = (0:major)[rmultinom(1, 1, rep(1/n.choice, n.choice))[, 1] == 1]
    } else {
        proposal[2, target] = 0
    }
    proposal
}


# Propose a new clonal-proportion matrix from tree$P (clones x samples).
#
# In one randomly chosen sample (column), the combined proportion of two
# randomly chosen clones is redistributed between them by a multinomial draw
# at 0.001 resolution. The matrix is then rounded to 3 decimals and the last
# clone is re-derived so every column sums to 1.
#
# cell.line  TRUE or FALSE. When TRUE, clone 1 (row 1) is held at 0 -- as
#            initialP() sets it up -- and is never selected for perturbation.
#
# Returns the perturbed matrix with the same dimensions as tree$P.
sampP = function(tree, cell.line) {
    P = tree$P
    if (cell.line == TRUE) {
        free.rows = seq_len(nrow(P))[-1]
    } else if (cell.line == FALSE) {
        free.rows = seq_len(nrow(P))
    } else {
        stop("cell.line must be TRUE or FALSE")
    }
    j = sample.int(1, n = ncol(P))
    # A swap needs two eligible clones. Previously sample(2:2, 2) permuted 1:2,
    # which moved mass into the cell-line normal row when only one tumour
    # clone existed. Indexing via sample.int keeps the RNG sequence of the
    # original sample(x, 2) call for the common case.
    if (length(free.rows) >= 2) {
        rowchange = free.rows[sample.int(length(free.rows), 2)]
        # rmultinom() truncates a non-integer size, so e.g. 288.99999999999997
        # would silently drop a unit; round to the intended integer first.
        temp = round(sum(P[rowchange, j]) * 1000)
        P[rowchange, j] = rmultinom(1, temp, prob = runif(length(rowchange), 
            0, 1))/1000
    }
    P = round(P, 3)
    # drop = FALSE keeps a matrix for any number of samples, so colSums works
    # without a single-column special case.
    P[nrow(P), ] = 1 - colSums(P[-nrow(P), , drop = FALSE])
    return(P)
}


# Propose moving one SNA cluster to a different edge.
#
# One row of tree$sna.cluster is chosen at random and its (start, end) node
# columns are redrawn from edge rows 2..nrow(tree$edge) until the proposal
# differs from the current assignment. Trees with two or fewer edges are
# returned unchanged.
sampsna.cluster = function(tree) {
  current = tree$sna.cluster
  proposal = current
  n.edge = nrow(tree$edge)
  if (n.edge <= 2) {
    return(proposal)
  }
  row.pick = sample.int(nrow(current), size = 1)
  # redraw until the move actually changes something
  while (all(proposal == current)) {
    edge.pick = sample(2:n.edge, size = 1, replace = TRUE)
    proposal[row.pick, 2:3] = tree$edge[edge.pick, 1:2]
  }
  proposal
}


# Propose new edge placements for a few SNAs.
#
# Draws a count c in 1..nrow(tree$sna), caps it at 4, picks that many
# distinct SNAs, and reassigns each one's (start, end) node columns to an edge
# drawn with replacement from edge rows 2..nrow(tree$edge). Trees with two or
# fewer edges are returned unchanged.
sampsna = function(tree) {
    proposal = tree$sna
    n.edge = nrow(tree$edge)
    if (n.edge <= 2) {
        return(proposal)
    }
    n.sna = nrow(tree$sna)
    n.move = min(4, sample.int(n.sna, size = 1))
    moved = sample.int(n.sna, size = n.move)
    dest.edge = sample(2:n.edge, size = n.move, replace = TRUE)
    proposal[moved, 2:3] = tree$edge[dest.edge, 1:2]
    proposal
}


# Put co-flagged CNAs into a canonical copy-number order.
#
# For every row of C with more than one entry equal to 1, the first two
# flagged CNA columns are compared. If the first has a lower major copy
# number, or an equal major and a lower minor copy number, the two CNAs are
# swapped consistently in tree$cna.copy (columns), tree$cna (node columns
# 2:3), tree$Q and tree$H (columns). tree$clonalmut is then recomputed with
# getclonalcomposition().
#
# Assumes columns of C index the CNAs in tree$cna.copy -- confirm against
# callers.
#
# Returns the updated tree.
sortcna = function(tree, C) {
    for (i in which(rowSums(C) > 1)) {
        flagged = which(C[i, ] == 1)
        if (length(flagged) < 2) {
            next
        }
        # Swap exactly the pair that is compared. The old code reversed every
        # flagged column, swapping columns 1 and last when 3+ were flagged.
        pair = flagged[1:2]
        swapped = rev(pair)
        # Read the current state rather than a snapshot taken before the loop,
        # so cna.copy stays in step with cna/Q/H when rows of C share a CNA.
        copy = tree$cna.copy
        major.lt = copy[1, pair[1]] < copy[1, pair[2]]
        major.eq = copy[1, pair[1]] == copy[1, pair[2]]
        minor.lt = copy[2, pair[1]] < copy[2, pair[2]]
        if (major.lt || (major.eq && minor.lt)) {
            tree$cna.copy[, pair] = copy[, swapped]
            tree$cna[pair, 2:3] = tree$cna[swapped, 2:3]
            tree$Q[, pair] = tree$Q[, swapped]
            tree$H[, pair] = tree$H[, swapped]
        }
    }
    tree$clonalmut = getclonalcomposition(tree)
    return(tree)
}
