# Create time-shifted versions for causal wavelet coherence analysis
# Define time lags (in samples) to analyze
h_values <- c(10, 20, 30, 40, 50)  # Time lags to analyze
num_h <- length(h_values)

# Get dimensions for correct and wrong trials. Both arrays are indexed
# [trial, time, channel] by the code below.
dim_correct <- dim(lfp_data_correct_pred)
dim_wrong <- dim(lfp_data_wrong_pred)
if (length(dim_correct) != 3L || length(dim_wrong) != 3L) {
  stop("LFP data must be 3-d arrays indexed [trial, time, channel]",
       call. = FALSE)
}

num_trials_correct <- dim_correct[1]  # Should be 40
# Number of time samples (should be 4096). NOTE: this global shadows the
# built-in `T` (TRUE) shorthand; the name is kept because the code below
# (and possibly code elsewhere) refers to it.
T <- dim_correct[2]
num_channels <- dim_correct[3]

num_trials_wrong <- dim_wrong[1]  # Should be 32

# All result and shifted arrays are sized with the correct-trial time and
# channel counts, so the wrong-trial data must agree on those extents.
if (dim_wrong[2] != T || dim_wrong[3] != num_channels) {
  stop("lfp_data_wrong_pred must have the same number of time points and ",
       "channels as lfp_data_correct_pred", call. = FALSE)
}

# Set the regions we're interested in (channel indices)
Region1 <- c(1, 2, 4, 5)  # X region
Region2 <- c(12, 13, 14, 15, 16)  # Y region
if (max(Region1, Region2) > num_channels) {
  stop("Region channel indices exceed the number of channels (",
       num_channels, ")", call. = FALSE)
}

# A lag must leave at least one sample to shift into.
if (any(h_values < 0) || max(h_values) >= T) {
  stop("All time lags must satisfy 0 <= h < ", T, call. = FALSE)
}

# Number of wavelet scales. J is used as an array extent below, so it must
# be a whole number, i.e. T must be a power of two.
J <- log2(T)
if (J != round(J)) {
  stop("Number of time points (", T, ") must be a power of two",
       call. = FALSE)
}

# Result arrays are allocated with double storage (NA_real_) so that the
# first numeric fill does not coerce, and copy, a whole logical array.

# Arrays for X(t) → Y(t+h) direction
rho_correct_XY <- array(NA_real_, dim = c(num_h, num_trials_correct, J, T))
rho_wrong_XY <- array(NA_real_, dim = c(num_h, num_trials_wrong, J, T))
a_correct_XY <- array(NA_real_,
                      dim = c(num_h, num_trials_correct, J, length(Region1), T))
a_wrong_XY <- array(NA_real_,
                    dim = c(num_h, num_trials_wrong, J, length(Region1), T))
b_correct_XY <- array(NA_real_,
                      dim = c(num_h, num_trials_correct, J, length(Region2), T))
b_wrong_XY <- array(NA_real_,
                    dim = c(num_h, num_trials_wrong, J, length(Region2), T))

# Arrays for Y(t) → X(t+h) direction
rho_correct_YX <- array(NA_real_, dim = c(num_h, num_trials_correct, J, T))
rho_wrong_YX <- array(NA_real_, dim = c(num_h, num_trials_wrong, J, T))
a_correct_YX <- array(NA_real_,
                      dim = c(num_h, num_trials_correct, J, length(Region2), T))
a_wrong_YX <- array(NA_real_,
                    dim = c(num_h, num_trials_wrong, J, length(Region2), T))
b_correct_YX <- array(NA_real_,
                      dim = c(num_h, num_trials_correct, J, length(Region1), T))
b_wrong_YX <- array(NA_real_,
                    dim = c(num_h, num_trials_wrong, J, length(Region1), T))

# Calculate causal wavelet coherence for each time lag
#
# For every lag h the LFP arrays are advanced by h samples. Then
# Calculate_WaveCan() is run per trial in both directions:
#   X(t) -> Y(t+h): Region1 of the original data vs Region2 of the shifted data
#   Y(t) -> X(t+h): Region2 of the original data vs Region1 of the shifted data
# Results fill slice [h_idx, ...] of the rho_*/a_*/b_* result arrays.

# Return a copy of a [trial, time, channel] array advanced by h samples
# along the time axis: out[, t, ] == x[, t + h, ] for t <= n_time - h.
# The last h samples have no future counterpart and keep their original
# (unshifted) values. h = 0 returns x unchanged.
shift_time_forward <- function(x, h) {
  n_time <- dim(x)[2]
  if (h < 0 || h >= n_time) {
    stop("time lag h must satisfy 0 <= h < ", n_time, call. = FALSE)
  }
  out <- x
  if (h > 0) {
    out[, seq_len(n_time - h), ] <- x[, (h + 1):n_time, , drop = FALSE]
  }
  out
}

# Run Calculate_WaveCan() on every trial, passing
# source_data[i, , source_region] as the first argument and
# target_data[i, , target_region] as the second. Per-trial results are
# stacked along a leading trial axis so that they fit one lag slice of the
# global result arrays (which are sized with J and T):
#   rho: trials x J x T
#   a:   trials x J x length(source_region) x T
#   b:   trials x J x length(target_region) x T
run_wavecan_trials <- function(source_data, target_data,
                               source_region, target_region) {
  n_trials <- dim(source_data)[1]
  rho <- array(NA_real_, dim = c(n_trials, J, T))
  a <- array(NA_real_, dim = c(n_trials, J, length(source_region), T))
  b <- array(NA_real_, dim = c(n_trials, J, length(target_region), T))
  for (i in seq_len(n_trials)) {
    res <- Calculate_WaveCan(source_data[i, , source_region],
                             target_data[i, , target_region],
                             filter_number = 4,
                             wavelet_family = "DaubExPhase")
    rho[i, , ] <- res$rho
    a[i, , , ] <- res$a
    b[i, , , ] <- res$b
  }
  list(rho = rho, a = a, b = b)
}

for (h_idx in seq_len(num_h)) {
  h <- h_values[h_idx]
  cat(sprintf("\nProcessing time lag h = %d...\n", h))

  # Shifted copies for this lag: *_h[i, t, ] holds the sample at t + h
  lfp_data_correct_pred_h <- shift_time_forward(lfp_data_correct_pred, h)
  lfp_data_wrong_pred_h <- shift_time_forward(lfp_data_wrong_pred, h)

  # Calculate X(t) → Y(t+h)
  cat("Calculating X(t) → Y(t+h)...\n")
  res <- run_wavecan_trials(lfp_data_correct_pred, lfp_data_correct_pred_h,
                            Region1, Region2)
  rho_correct_XY[h_idx, , , ] <- res$rho
  a_correct_XY[h_idx, , , , ] <- res$a
  b_correct_XY[h_idx, , , , ] <- res$b

  res <- run_wavecan_trials(lfp_data_wrong_pred, lfp_data_wrong_pred_h,
                            Region1, Region2)
  rho_wrong_XY[h_idx, , , ] <- res$rho
  a_wrong_XY[h_idx, , , , ] <- res$a
  b_wrong_XY[h_idx, , , , ] <- res$b

  # Calculate Y(t) → X(t+h)
  cat("Calculating Y(t) → X(t+h)...\n")
  res <- run_wavecan_trials(lfp_data_correct_pred, lfp_data_correct_pred_h,
                            Region2, Region1)
  rho_correct_YX[h_idx, , , ] <- res$rho
  a_correct_YX[h_idx, , , , ] <- res$a
  b_correct_YX[h_idx, , , , ] <- res$b

  res <- run_wavecan_trials(lfp_data_wrong_pred, lfp_data_wrong_pred_h,
                            Region2, Region1)
  rho_wrong_YX[h_idx, , , ] <- res$rho
  a_wrong_YX[h_idx, , , , ] <- res$a
  b_wrong_YX[h_idx, , , , ] <- res$b
}

# Collapse the trial axis (dimension 2) of each lag x trial x scale x time
# array by averaging, leaving arrays indexed [lag, scale, time]. mean() is
# called without na.rm, so a single missing trial makes that cell NA.
rho_correct_XY_mean <- apply(X = rho_correct_XY, MARGIN = c(1, 3, 4), FUN = mean)
rho_correct_YX_mean <- apply(X = rho_correct_YX, MARGIN = c(1, 3, 4), FUN = mean)
rho_wrong_XY_mean <- apply(X = rho_wrong_XY, MARGIN = c(1, 3, 4), FUN = mean)
rho_wrong_YX_mean <- apply(X = rho_wrong_YX, MARGIN = c(1, 3, 4), FUN = mean)

# Plotting time axis: T evenly spaced points spanning [-2, 2].
# NOTE(review): presumably seconds around a trial event -- confirm units.
time_points <- seq(from = -2, to = 2, length.out = T)

# Plot mean causal wavelet coherence against time lag h for one direction
# and one wavelet scale, overlaying correct and incorrect trials.
#
# Arguments:
#   direction        "XY" for X(t) -> Y(t+h), "YX" for Y(t) -> X(t+h).
#   scale_j          wavelet scale index, between 1 and J.
#   rho_correct_mitt, rho_wrong_mitt
#                    h = 0 coherence arrays indexed [, scale, time]. Every
#                    entry at scale_j inside time_window is averaged into a
#                    single h = 0 point. NOTE(review): presumably
#                    [trial, scale, time] -- confirm against the caller.
#   smooth_window    width of a centred moving average applied across lags
#                    (0 = no smoothing; the ends of a smoothed curve are NA).
#   results_dir      directory holding mean_results.rds and parameters.rds;
#                    only read when rho_correct_XY_mean does not exist yet.
#   legend_position  keyword passed to legend() (e.g. "topright").
#   time_window      time indices to average over for every lag.
#
# Reads the globals rho_{correct,wrong}_{XY,YX}_mean (indexed
# [lag, scale, time]), h_values and J. When they are missing, it loads them
# from results_dir and assigns them globally. Draws on the current device
# and invisibly returns list(h, correct, wrong) with the plotted values.
plot_mean_coherence_by_lag_window <- function(direction = "XY", scale_j,
                                              rho_correct_mitt, rho_wrong_mitt,
                                              smooth_window = 0,
                                              results_dir = "results/causal_wavecancoh",
                                              legend_position = "bottomleft",
                                              time_window = 2000:3200) {
  if (!exists("rho_correct_XY_mean")) {
    mean_results <- readRDS(file.path(results_dir, "mean_results.rds"))
    params <- readRDS(file.path(results_dir, "parameters.rds"))

    # Assign to global environment for other functions
    rho_correct_XY_mean <<- mean_results$rho_correct_XY_mean
    rho_wrong_XY_mean <<- mean_results$rho_wrong_XY_mean
    rho_correct_YX_mean <<- mean_results$rho_correct_YX_mean
    rho_wrong_YX_mean <<- mean_results$rho_wrong_YX_mean
    h_values <<- params$h_values
    J <<- params$J
  }

  # Check inputs (length/NA guards keep the if() conditions scalar)
  if (length(direction) != 1L || !direction %in% c("XY", "YX")) {
    stop("Direction must be either 'XY' or 'YX'")
  }
  if (length(scale_j) != 1L || is.na(scale_j) || scale_j < 1 || scale_j > J) {
    stop("Scale must be between 1 and ", J)
  }

  # Get the lag x scale x time arrays for the specified direction
  if (direction == "XY") {
    rho_correct_mean <- rho_correct_XY_mean
    rho_wrong_mean <- rho_wrong_XY_mean
    main_title <- bquote(X[t] %->% Y[t+h] ~ "at scale" ~ j == .(scale_j))
  } else {
    rho_correct_mean <- rho_correct_YX_mean
    rho_wrong_mean <- rho_wrong_YX_mean
    main_title <- bquote(Y[t] %->% X[t+h] ~ "at scale" ~ j == .(scale_j))
  }

  # Mean over the time window for each lag. drop = FALSE keeps the lag axis
  # even when h_values has a single element.
  window_mean <- function(rho) {
    apply(rho[, scale_j, time_window, drop = FALSE], 1, mean, na.rm = TRUE)
  }
  mean_correct <- window_mean(rho_correct_mean)
  mean_wrong <- window_mean(rho_wrong_mean)

  # Prepend the h = 0 values from the mitt data (same scale and time window)
  all_h_values <- c(0, h_values)
  mean_correct_all <- c(mean(rho_correct_mitt[, scale_j, time_window], na.rm = TRUE),
                        mean_correct)
  mean_wrong_all <- c(mean(rho_wrong_mitt[, scale_j, time_window], na.rm = TRUE),
                      mean_wrong)

  # Optional centred moving average across lags; as.numeric() drops the ts
  # class that stats::filter() returns.
  if (smooth_window > 0) {
    kernel <- rep(1 / smooth_window, smooth_window)
    mean_correct_all <- as.numeric(stats::filter(mean_correct_all, kernel, sides = 2))
    mean_wrong_all <- as.numeric(stats::filter(mean_wrong_all, kernel, sides = 2))
  }

  # plot() needs finite y limits; fail with a clear message instead.
  y_values <- c(mean_correct_all, mean_wrong_all)
  if (!any(is.finite(y_values))) {
    stop("No finite coherence values to plot for scale ", scale_j,
         call. = FALSE)
  }
  y_limits <- range(y_values, finite = TRUE)

  # Correct trials: blue line with solid circles
  plot(all_h_values, mean_correct_all,
       type = "b",
       col = "#2E86C1",
       pch = 16,
       lwd = 2,
       xlab = "Time lag (h)",
       ylab = "Mean Causal Coherence",
       main = main_title,
       ylim = y_limits,
       cex.main = 1.6,   # Title font size
       cex.lab = 1.4,    # Axis label font size
       cex.axis = 1.2    # Axis tick label font size
  )

  # Incorrect trials: red line with triangles
  lines(all_h_values, mean_wrong_all,
        type = "b",
        col = "#C0392B",
        pch = 17,
        lwd = 2)

  # Legend without trial counts (bty = "n": no box or background)
  legend(legend_position,
         legend = c("Correct", "Incorrect"),
         col = c("#2E86C1", "#C0392B"),
         pch = c(16, 17),
         lwd = 2,
         bty = "n",
         cex = 1,
         inset = 0.02)

  grid(nx = NULL, ny = NULL,
       lty = 2, col = "gray", lwd = 0.5)

  invisible(list(h = all_h_values,
                 correct = mean_correct_all,
                 wrong = mean_wrong_all))
}


### Figure 12 in appendix E
# One row of four panels, one per wavelet scale j = 4, ..., 7, all for the
# X(t) -> Y(t+h) direction. Outer margins are (bottom, left, top, right) in
# lines; the wider left margin holds the shared y-axis label.
par(mfrow = c(1, 4), oma = c(2, 4, 1, 1))
for (j in 4:7) {
  # time_window is left at the function's default
  plot_mean_coherence_by_lag_window(
    direction = "XY", scale_j = j,
    rho_correct_mitt = rho_correct_mitt, rho_wrong_mitt = rho_wrong_mitt,
    smooth_window = 0, legend_position = "bottomleft"
  )
}
# Shared x-axis label, currently disabled:
# mtext("Time lag (h)", side = 1, line = 2, outer = TRUE, cex = 1.4)
mtext("Mean Causal Coherence", side = 2, line = 2, outer = TRUE, cex = 1.2)



