# P1 is the number of channels of Y (Py in the code below)
# P2 is the number of channels of X (Px in the code below)
# T is the length of the time series
# J is the total number of scales/levels
# Spec is the multivariate time-evolving wavelet spectrum array (mvEWS)

library(mvLSW)
library(zoo)
library(stats)
library(rtrend)
library(boot)
library(ggplot2)
library(dplyr)
library(pracma) 
library(signal)
library(astsa)
library(plot.matrix)
#install.packages("remotes")   # if you don't have the package already
#remotes::install_github("nickpoison/astsa/astsa_build")


# Partition a joint spectrum S = [Sxx, Sxy; Syx, Syy] into its four blocks.
#
# spectrum: array of dimension P x P x J x T (channel x channel x scale x
#   time) with the Px channels of X first and the channels of Y after them.
# size_XY:  c(Px, Py), the number of channels in X and in Y. Y is taken to be
#   channels (Px + 1):P.
#
# Returns list(Sxx = Px x Px x J x T, Sxy = Px x Py x J x T,
#              Syx = Py x Px x J x T, Syy = Py x Py x J x T).
# Every block keeps all four dimensions, even for a single-channel group or
# when J or T equals 1, so callers can always index blocks as [, , j, t].
extract_spectrum <- function(spectrum, size_XY) {
  dim_spec <- dim(spectrum)
  if (length(dim_spec) != 4L) {
    stop("`spectrum` must be a 4-dimensional array (P x P x J x T).",
         call. = FALSE)
  }
  nrows <- dim_spec[1]
  ncols <- dim_spec[2]
  Px <- size_XY[1]
  if (is.null(Px) || is.na(Px) || Px < 1 || Px >= nrows || Px >= ncols) {
    stop("`size_XY[1]` (Px) must be between 1 and P - 1.", call. = FALSE)
  }
  if (length(size_XY) >= 2L && Px + size_XY[2] != nrows) {
    warning("size_XY does not sum to the number of channels; Y is taken to ",
            "be channels ", Px + 1, " to ", nrows, ".", call. = FALSE)
  }

  x_idx <- seq_len(Px)
  y_row <- (Px + 1):nrows
  y_col <- (Px + 1):ncols
  # drop = FALSE: without it a 1-channel group (or J = 1 / T = 1) collapses
  # to fewer dimensions and WaveCan()'s dim(.)[3:4] lookups break.
  list(
    Sxx = spectrum[x_idx, x_idx, , , drop = FALSE],
    Sxy = spectrum[x_idx, y_col, , , drop = FALSE],
    Syx = spectrum[y_row, x_idx, , , drop = FALSE],
    Syy = spectrum[y_row, y_col, , , drop = FALSE]
  )
}


# Scale- and time-specific wavelet canonical coherence from a partitioned
# spectrum.
#
# For every scale j and time t this forms
#   A = inv(Sxx) Sxy inv(Syy) Syx   (Px x Px)
#   B = inv(Syy) Syx inv(Sxx) Sxy   (Py x Py)
# and keeps the first eigen-pair returned by eigen() (eigen() orders
# eigenvalues by decreasing modulus).
#
# Sxx, Sxy, Syx, Syy: the blocks produced by extract_spectrum(), of dimension
#   Px x Px x J x T, Px x Py x J x T, Py x Px x J x T and Py x Py x J x T.
# align_sign: eigenvectors are only defined up to sign, so replicates or
#   neighbouring time points can come back with opposite signs and cancel
#   when averaged. When TRUE (default) each returned vector is flipped so its
#   largest-magnitude entry is positive; FALSE keeps eigen()'s raw sign.
#
# Returns a list:
#   rho[J, T]    real part of the leading eigenvalue of A,
#   a[J, Px, T]  real part of the matching eigenvector of A (weights for X),
#   b[J, Py, T]  real part of the leading eigenvector of B (weights for Y).
WaveCan <- function(Sxx, Sxy, Syx, Syy, align_sign = TRUE) {
  if (length(dim(Sxx)) != 4L || length(dim(Syy)) != 4L) {
    stop("Spectral blocks must be 4-dimensional arrays (P x P x J x T).",
         call. = FALSE)
  }
  n_time <- dim(Sxx)[4]
  n_scale <- dim(Sxx)[3]
  Px <- dim(Sxx)[1] # dim of group X
  Py <- dim(Syy)[1] # dim of group Y
  a <- array(0, dim = c(n_scale, Px, n_time))
  b <- array(0, dim = c(n_scale, Py, n_time))
  rho <- array(0, dim = c(n_scale, n_time))

  # Flip v so that its largest-magnitude entry is positive.
  orient <- function(v) {
    k <- which.max(abs(v))
    if (align_sign && length(k) == 1L && v[k] < 0) -v else v
  }

  for (j in seq_len(n_scale)) {
    for (t in seq_len(n_time)) {
      # matrix() restores the 2-D shape that `[` drops for 1-channel groups.
      sxx <- matrix(Sxx[, , j, t], nrow = Px, ncol = Px)
      sxy <- matrix(Sxy[, , j, t], nrow = Px, ncol = Py)
      syx <- matrix(Syx[, , j, t], nrow = Py, ncol = Px)
      syy <- matrix(Syy[, , j, t], nrow = Py, ncol = Py)
      inv_Sxx <- solve(sxx)
      inv_Syy <- solve(syy)
      A <- inv_Sxx %*% sxy %*% inv_Syy %*% syx
      B <- inv_Syy %*% syx %*% inv_Sxx %*% sxy
      eig_A <- eigen(A)
      eig_B <- eigen(B)
      # Take the real part to strip numerical imaginary noise from eigen()
      # on the non-symmetric products.
      rho[j, t] <- Re(eig_A$values[1])
      a[j, , t] <- orient(Re(eig_A$vectors[, 1]))
      b[j, , t] <- orient(Re(eig_B$vectors[, 1]))
    }
  }
  list(rho = rho, a = a, b = b)
}



# Wavelet canonical coherence between two groups of multivariate time series.
#
# X, Y: data with one row per time point and one column per channel
#   (T x Px and T x Py). A single-channel group may be given as a plain
#   vector. X and Y must have the same number of rows.
# filter_number, wavelet_family: wavelet passed to mvEWS().
#
# The joint spectrum S = [Sxx, Sxy; Syx, Syy] of cbind(X, Y) is estimated with
# mvEWS() (Daniell kernel, optimize = TRUE, bias correction). It is split with
# extract_spectrum(), and the list(rho, a, b) from WaveCan() is returned.
Calculate_WaveCan <- function(X, Y, filter_number = 1,
                              wavelet_family = "DaubExPhase") {
  # NCOL/NROW treat a plain vector as one column; ncol() would return NULL
  # and break the block partition.
  Px <- NCOL(X)
  Py <- NCOL(Y)
  if (NROW(X) != NROW(Y)) {
    stop("X and Y must have the same number of time points (rows).",
         call. = FALSE)
  }
  size_XY <- c(Px, Py)
  data_XY <- cbind(X, Y) # channels of X first, then Y

  # Estimate the group spectrum S = [Sxx, Sxy; Syx, Syy].
  group_spectrum <- mvEWS(
    X = data_XY, filter.number = filter_number, family = wavelet_family,
    kernel.name = "daniell", optimize = TRUE, bias.correct = TRUE,
    tol = 1e-10
  )
  S <- extract_spectrum(group_spectrum$spectrum, size_XY)
  WaveCan(S$Sxx, S$Sxy, S$Syx, S$Syy)
}



#set.seed(1126)

# Simulate `group_size` independent realisations of the mvLSW process whose
# true spectrum is `spec`, and re-estimate the spectrum from each.
#
# spec: true spectrum array, P x P x J x T.
# group_size: number of realisations.
# filter_number, wavelet_family: wavelet used both to build the true process
#   (as.mvLSW) and to estimate the spectrum (mvEWS).
# seed: optional; when given, realisation i is drawn after set.seed(seed + i),
#   so results are reproducible. This changes the global RNG state.
#
# Returns an array of dimension group_size x P x P x J x T whose i-th slice
# is the spectrum estimated from the i-th realisation.
estimate_spectrum <- function(spec, group_size = 10, filter_number = 1,
                              wavelet_family = "DaubExPhase", seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  dimen <- dim(spec)
  n_time <- dimen[4]
  spec_results <- array(0, dim = c(group_size, dimen))
  # True mvEWS object, built with the same wavelet as the estimator.
  True_mvEWS <- as.mvLSW(x = spec, filter.number = filter_number,
                         family = wavelet_family)
  # Take the length from the spectrum itself: a bare `T` here resolves to
  # TRUE (sqrt(TRUE) = 1, i.e. almost no smoothing) unless a global T exists.
  # floor() gives the Daniell kernel an integer half-width.
  kernel_width <- floor(sqrt(n_time))
  for (i in seq_len(group_size)) {
    # Per-realisation seed derived from the initial seed.
    if (!is.null(seed)) set.seed(seed + i)
    data_XY <- rmvLSW(Spectrum = True_mvEWS, noiseFN = rnorm)
    group_spectrum <- mvEWS(
      X = data_XY, filter.number = filter_number, family = wavelet_family,
      kernel.name = "daniell", kernel.param = kernel_width,
      bias.correct = TRUE, tol = 1e-10
    )
    spec_results[i, , , , ] <- group_spectrum$spectrum
  }
  spec_results
}


# One Monte Carlo replicate: estimate the spectrum from `group_size` simulated
# realisations (estimate_spectrum), average those estimates, and compute the
# wavelet canonical coherence of the averaged spectrum (WaveCan).
#
# Spec: true spectrum array, P x P x J x T, with the X channels first.
# group_size: number of realisations averaged.
# Px, Py: number of channels in X and in Y.
# seed: optional seed forwarded to estimate_spectrum() for reproducibility.
# filter_number, wavelet_family: wavelet forwarded to estimate_spectrum().
#
# Returns list(rho = J x T, a = J x Px x T, b = J x Py x T).
do_one_LSW_simulation <- function(Spec, group_size, Px, Py, seed = NULL,
                                  filter_number = 1,
                                  wavelet_family = "DaubExPhase") {
  if (!is.null(seed)) set.seed(seed)

  Spec_results <- estimate_spectrum(Spec, group_size = group_size,
                                    filter_number = filter_number,
                                    wavelet_family = wavelet_family,
                                    seed = seed)

  # Mean over realisations (dimension 1); same result as
  # apply(Spec_results, c(2, 3, 4, 5), mean) without a per-cell R call.
  Spec_mean <- colMeans(Spec_results, dims = 1)
  S <- extract_spectrum(Spec_mean, c(Px, Py))
  wave_can <- WaveCan(S$Sxx, S$Sxy, S$Syx, S$Syy)
  list(rho = wave_can$rho, a = wave_can$a, b = wave_can$b)
}

# Run do_one_LSW_simulation() `num_rep` times and stack the replicates.
#
# Spec: true spectrum array, P x P x J x T, with the X channels first.
# group_size: realisations averaged inside each replicate.
# num_rep: number of replicates (0 gives empty result arrays).
# Px, Py: number of channels in X and in Y.
# seed: optional; replicate i uses seed + i * 1000. Inside a replicate,
#   estimate_spectrum() uses that seed plus 0..group_size, so the seeds of
#   different replicates are disjoint as long as group_size < 1000.
#
# Returns list(rho = num_rep x J x T, a = num_rep x J x Px x T,
#              b = num_rep x J x Py x T), the input expected by
# calculate_bootstrap_CI().
do_multiple_LSW_simulation <- function(Spec, group_size, num_rep, Px, Py,
                                       seed = NULL) {
  if (!is.null(seed)) set.seed(seed)

  n_scale <- dim(Spec)[3]
  n_time <- dim(Spec)[4]
  multiple_rho <- array(0, dim = c(num_rep, n_scale, n_time))
  multiple_a <- array(0, dim = c(num_rep, n_scale, Px, n_time))
  multiple_b <- array(0, dim = c(num_rep, n_scale, Py, n_time))

  # seq_len(): 1:num_rep would iterate over c(1, 0) when num_rep is 0.
  for (i in seq_len(num_rep)) {
    iteration_seed <- if (!is.null(seed)) seed + i * 1000 else NULL
    one_result <- do_one_LSW_simulation(Spec, group_size, Px, Py,
                                        seed = iteration_seed)
    multiple_rho[i, , ] <- one_result$rho
    multiple_a[i, , , ] <- one_result$a
    multiple_b[i, , , ] <- one_result$b
  }
  list(rho = multiple_rho, a = multiple_a, b = multiple_b)
}





# Pointwise confidence bands for the replicated canonical coherence.
#
# wavecan_results: output of do_multiple_LSW_simulation(), a list with
#   rho (num_rep x J x T), a (num_rep x J x Px x T), b (num_rep x J x Py x T).
# alpha: significance level in (0, 1); nominal coverage is 1 - alpha.
# method: "normal"     -> mean +/- qnorm(1 - alpha/2) * sd over replicates;
#         "percentile" -> empirical alpha/2 and 1 - alpha/2 quantiles
#                         (quantile() default type 7).
#
# Returns, in the format used by plot.wavecan():
#   list(rho = list(mean = J x T, ci = 2 x J x T),
#        a   = list(mean = J x Px x T, ci = NULL),
#        b   = list(mean = J x Py x T, ci = NULL))
# with rho$ci[1, , ] the UPPER bound and rho$ci[2, , ] the LOWER bound.
calculate_bootstrap_CI <- function(wavecan_results, alpha = 0.05,
                                   method = "normal") {
  if (!is.numeric(alpha) || length(alpha) != 1L || is.na(alpha) ||
      alpha <= 0 || alpha >= 1) {
    stop("`alpha` must be a single number in (0, 1).", call. = FALSE)
  }
  if (length(method) != 1L || !(method %in% c("normal", "percentile"))) {
    stop("Method must be either 'normal' or 'percentile'")
  }

  rho_reps <- wavecan_results$rho
  n_scale <- dim(rho_reps)[2]
  n_time <- dim(rho_reps)[3]

  # Means over the replicate dimension (dimension 1).
  rho_mean <- colMeans(rho_reps, dims = 1)
  a_mean <- colMeans(wavecan_results$a, dims = 1)
  b_mean <- colMeans(wavecan_results$b, dims = 1)

  rho_ci <- array(0, dim = c(2, n_scale, n_time))
  if (method == "normal") {
    z_score <- qnorm(1 - alpha / 2) # 1.96 for alpha = 0.05
    rho_sd <- apply(rho_reps, c(2, 3), sd)
    rho_ci[1, , ] <- rho_mean + z_score * rho_sd # upper
    rho_ci[2, , ] <- rho_mean - z_score * rho_sd # lower
  } else {
    # Quantile of the replicates at every (scale, time) cell -> J x T.
    pointwise_quantile <- function(p) {
      apply(rho_reps, c(2, 3), quantile, probs = p, names = FALSE)
    }
    rho_ci[1, , ] <- pointwise_quantile(1 - alpha / 2) # upper
    rho_ci[2, , ] <- pointwise_quantile(alpha / 2)     # lower
  }

  list(
    rho = list(mean = rho_mean, ci = rho_ci),
    a = list(mean = a_mean, ci = NULL),
    b = list(mean = b_mean, ci = NULL)
  )
}



# Centred moving average via pracma::movmean(); window_size <= 1 returns x
# unchanged.
.wavecan_smooth <- function(x, window_size) {
  if (window_size > 1) movmean(x, window_size) else x
}

# Plot rho at scale j with its confidence ribbon (helper for plot.wavecan).
.plot_wavecan_rho <- function(wavecan_results, j, window_size, with_true,
                              true_results) {
  rho_est <- wavecan_results$rho$mean[j, ]
  time_scaled <- seq(0, 1, length.out = length(rho_est))

  df <- data.frame(
    time = time_scaled,
    value = .wavecan_smooth(rho_est, window_size),
    type = "estimated"
  )
  if (with_true && !is.null(true_results)) {
    df <- rbind(df, data.frame(
      time = time_scaled,
      value = .wavecan_smooth(true_results[["rho"]][j, ], window_size),
      type = "true"
    ))
  }

  # calculate_bootstrap_CI() stores the upper bound in ci[1, , ] and the lower
  # bound in ci[2, , ]; pmin()/pmax() keep the ribbon right either way.
  bound_1 <- .wavecan_smooth(wavecan_results$rho$ci[1, j, ], window_size)
  bound_2 <- .wavecan_smooth(wavecan_results$rho$ci[2, j, ], window_size)
  ci_df <- data.frame(
    time = time_scaled,
    lower = pmin(bound_1, bound_2),
    upper = pmax(bound_1, bound_2)
  )

  ggplot() +
    geom_ribbon(data = ci_df, aes(x = time, ymin = lower, ymax = upper),
                fill = "skyblue", alpha = 0.3) +
    geom_line(data = df, aes(x = time, y = value, color = type,
                             linetype = type), size = 1.2) +
    scale_color_manual(
      values = c("estimated" = "#2F4F4F", "true" = "#8B0000"),
      name = "type",
      guide = if (!with_true) "none" else "legend"
    ) +
    scale_linetype_manual(
      values = c("estimated" = "solid", "true" = "dashed"),
      name = "type",
      guide = if (!with_true) "none" else "legend"
    ) +
    scale_x_continuous(breaks = seq(0, 1, by = 0.2)) +
    labs(
      title = bquote(paste("Estimated ", hat(rho), " at scale j = ", .(j))),
      y = expression(hat(rho)),
      x = "Time"
    ) +
    theme_minimal(base_size = 14)
}

# Plot per-channel weights a (target = "a") or b (target = "b") at scale j,
# one colour per channel (helper for plot.wavecan).
.plot_wavecan_weights <- function(wavecan_results, target, j, window_size,
                                  with_true, true_results) {
  channel_colors <- c(
    "#E41A1C",  # Red
    "#377EB8",  # Blue
    "#4DAF4A",  # Green
    "#984EA3",  # Purple
    "#FF7F00",  # Orange
    "#A65628",  # Brown
    "#F781BF",  # Pink
    "#1B9E77",  # Teal
    "#D95F02"   # Burnt orange
  )

  means <- wavecan_results[[target]]$mean
  num_channels <- dim(means)[2]
  T_len <- dim(means)[3]
  time_scaled <- seq(0, 1, length.out = T_len)
  # Factor levels keep numeric channel order (channel_2 before channel_10).
  channel_levels <- paste0("channel_", seq_len(num_channels))

  # Long format for arr[j, , ] (channel x time), one row per (channel, time).
  to_long <- function(arr, type) {
    values <- vapply(
      seq_len(num_channels),
      function(i) as.numeric(.wavecan_smooth(arr[j, i, ], window_size)),
      numeric(T_len)
    )
    data.frame(
      time = rep(time_scaled, times = num_channels),
      channel = factor(rep(channel_levels, each = T_len),
                       levels = channel_levels),
      value = as.vector(values),
      type = type
    )
  }

  df_long <- to_long(means, "estimated")
  true_vals <- if (!is.null(true_results)) true_results[[target]] else NULL
  if (with_true && !is.null(true_vals)) {
    df_long <- rbind(df_long, to_long(true_vals, "true"))
  }

  # Fall back to an HCL palette when there are more channels than colours.
  palette <- if (num_channels <= length(channel_colors)) {
    channel_colors[seq_len(num_channels)]
  } else {
    grDevices::hcl.colors(num_channels, palette = "Dark 3")
  }

  if (target == "a") {
    title_text <- bquote(paste("Estimated ", hat(a),
                               " (channel contributions for X) at scale j = ",
                               .(j)))
    y_label <- expression(hat(a))
  } else {
    title_text <- bquote(paste("Estimated ", hat(b),
                               " (channel contributions for Y) at scale j = ",
                               .(j)))
    y_label <- expression(hat(b))
  }

  # Group by channel AND type so estimated and true curves of the same
  # channel are drawn as separate lines instead of one joined line.
  ggplot(df_long, aes(x = time, y = value, color = channel, linetype = type,
                      group = interaction(channel, type))) +
    geom_line(size = 0.8) +
    scale_color_manual(
      values = stats::setNames(palette, channel_levels),
      labels = paste("channel", seq_len(num_channels)),
      name = "channel"
    ) +
    scale_linetype_manual(
      values = c("estimated" = "solid", "true" = "dashed"),
      name = "type",
      guide = if (with_true) "legend" else "none"
    ) +
    scale_x_continuous(breaks = seq(0, 1, by = 0.2)) +
    labs(title = title_text, y = y_label, x = "Time") +
    theme_minimal(base_size = 14)
}

# Plot wavelet canonical coherence results at one scale.
#
# wavecan_results: output of calculate_bootstrap_CI(): rho$mean (J x T),
#   rho$ci (2 x J x T, [1, , ] upper / [2, , ] lower), a$mean (J x Px x T),
#   b$mean (J x Py x T).
# target: "rho" plots the coherence with its confidence ribbon; "a" / "b"
#   plot the per-channel weights for X / Y (no ribbon).
# j: scale index to plot.
# window_size: width of the moving average applied to every plotted series;
#   1 means no smoothing.
# with_true, true_results: when with_true is TRUE and true_results is given
#   (a list with rho J x T, a J x Px x T, b J x Py x T, as returned by
#   WaveCan()), the true curves are overlaid as dashed lines.
#
# Returns a ggplot object; time is rescaled to [0, 1].
plot.wavecan <- function(wavecan_results, target = "rho", j = 1,
                         window_size = 1, with_true = FALSE,
                         true_results = NULL) {
  switch(
    target,
    rho = .plot_wavecan_rho(wavecan_results, j, window_size, with_true,
                            true_results),
    a = ,
    b = .plot_wavecan_weights(wavecan_results, target, j, window_size,
                              with_true, true_results),
    stop("`target` must be one of \"rho\", \"a\" or \"b\".", call. = FALSE)
  )
}

