library(reticulate)
library(waveslim)
library(fields)


### import data and preprocessing
# Recording session to analyse; both .npz exports below share this prefix.
session_id <- "080718_mitt"
session_dir <- file.path("data", "epoched_data", session_id)

np <- import("numpy")

# Behavioural data. np$load() on an .npz file returns an NpzFile archive.
# With numpy imported using reticulate's default conversion, the arrays are
# converted to R arrays on access, so the archive is closed right afterwards
# instead of leaking its file handle.
behav <- np$load(
  file.path(session_dir, paste0(session_id, "_tsr_w-4000_bvr.npz"))
)
behav_keys <- behav$f[["keys"]]
behav_data <- behav$f[["data"]]
behav$close()

# LFP data, loaded and released the same way.
lfp <- np$load(
  file.path(session_dir, paste0(session_id, "_tsr_w-4000_lfp.npz"))
)
lfp_keys <- lfp$f[["keys"]]
lfp_data <- lfp$f[["data"]]
lfp$close()

# Logical trial masks.
# NOTE(review): row 2 of behav_data is taken to be an "in sequence" flag and
# row 1 a correct (1) / incorrect (0) flag, as the variable names suggest.
# Confirm against behav_keys.
index_inseq <- behav_data[2, ] == 1
index_correct_inseq <- (behav_data[1, ] == 1) & index_inseq
index_wrong_inseq <- (behav_data[1, ] == 0) & index_inseq



# trial x T x P
# drop = FALSE keeps the trial dimension even if only one trial matches.
lfp_data_correct <- lfp_data[index_correct_inseq, , , drop = FALSE]
lfp_data_wrong <- lfp_data[index_wrong_inseq, , , drop = FALSE]

# trial index x time x channel
#lfp_data_correct_pred <- lfp_data_correct[1:40, , , drop = FALSE] # cut a time period and only use first 40 trials in correct
#lfp_data_wrong_pred <- lfp_data_wrong[, , , drop = FALSE]

# Pad every trial of a trial x time x channel array to `target_len` samples
# along the time axis by prepending a copy of its first
# n_pad = target_len - n_time samples, i.e. each series becomes
# c(x[1:n_pad], x). The prepended samples are not real data; they are cut
# off again after the wavelet analysis.
# NOTE(review): 4096 = 2^12 is presumably needed because the scale count is
# taken as log2(series length) -- confirm against Calculate_WaveCan.
pad_trials <- function(x, target_len = 4096L) {
  if (length(dim(x)) != 3L) {
    stop("expected a trial x time x channel array", call. = FALSE)
  }
  n_time <- dim(x)[2]
  n_pad <- target_len - n_time
  if (n_pad < 0L || n_pad > n_time) {
    stop("cannot pad ", n_time, " samples to ", target_len,
         " by repeating the leading samples", call. = FALSE)
  }
  out <- array(NA, dim = c(dim(x)[1], target_len, dim(x)[3]))
  out[, seq_len(n_pad), ] <- x[, seq_len(n_pad), , drop = FALSE]
  out[, n_pad + seq_len(n_time), ] <- x
  out
}

# All correct trials, padded (trial/channel counts come from the data
# instead of being hard-coded).
lfp_data_correct_4096 <- pad_trials(lfp_data_correct)

# For lfp_data_wrong: all incorrect trials, padded.
num_trials_wrong <- dim(lfp_data_wrong)[1]
lfp_data_wrong_pred <- pad_trials(lfp_data_wrong)







### formal analysis
# Only the first `n_correct_used` correct trials (in recording order) are
# analysed.
n_correct_used <- 40L
if (dim(lfp_data_correct_4096)[1] < n_correct_used) {
  stop("need at least ", n_correct_used, " correct trials, found ",
       dim(lfp_data_correct_4096)[1], call. = FALSE)
}
lfp_data_correct_pred <-
  lfp_data_correct_4096[seq_len(n_correct_used), , , drop = FALSE]

dim_correct <- dim(lfp_data_correct_pred)
dim_wrong <- dim(lfp_data_wrong_pred)
num_trial_correct <- dim_correct[1]
num_trial_wrong <- dim_wrong[1]
# Series length. NOTE: `T` masks the TRUE shorthand; the name is kept because
# other code reads the global `T`.
T <- dim_correct[2]
# Number of wavelet scales; used as an array dimension, so it must be whole.
J <- log2(T)
if (J != round(J)) {
  stop("series length ", T, " is not a power of two", call. = FALSE)
}
P <- dim_correct[3]

# Channel indices (third array dimension) of the two regions whose coupling
# is estimated.
Region1 <- c(1, 2, 4, 5)
Region2 <- c(12, 13, 14, 15, 16)
#Region3 <- c(13,14,15,16,17)

# Run Calculate_WaveCan on every trial of a trial x time x channel array and
# stack the per-trial results along a new leading trial dimension.
# Returns list(rho = trial x scale x time,
#              a   = trial x scale x length(region_x) x time,
#              b   = trial x scale x length(region_y) x time).
# NOTE(review): Calculate_WaveCan is neither defined nor sourced in this
# script and must be loaded beforehand; its rho/a/b outputs are assumed to
# have the per-trial shapes above, as implied by how they are stored.
compute_wavecan_trials <- function(lfp, region_x, region_y,
                                   filter_number = 4,
                                   wavelet_family = "DaubExPhase") {
  n_trials <- dim(lfp)[1]
  n_time <- dim(lfp)[2]
  n_scales <- log2(n_time)
  rho <- array(NA, dim = c(n_trials, n_scales, n_time))
  a <- array(NA, dim = c(n_trials, n_scales, length(region_x), n_time))
  b <- array(NA, dim = c(n_trials, n_scales, length(region_y), n_time))
  for (i in seq_len(n_trials)) {
    res <- Calculate_WaveCan(lfp[i, , region_x], lfp[i, , region_y],
                             filter_number = filter_number,
                             wavelet_family = wavelet_family)
    rho[i, , ] <- res$rho
    a[i, , , ] <- res$a
    b[i, , , ] <- res$b
  }
  list(rho = rho, a = a, b = b)
}

## calculate the WaveCanCoh with LFP data for correct trials and incorrect trials
wavecan_correct <- compute_wavecan_trials(lfp_data_correct_pred,
                                          Region1, Region2)
rho_correct <- wavecan_correct$rho
a_correct <- wavecan_correct$a
b_correct <- wavecan_correct$b

# Trial averages: colMeans(dims = 1) averages over the leading (trial) axis
# in compiled code, matching apply(x, -1, mean) up to rounding.
rho_correct_mean <- colMeans(rho_correct, dims = 1)
a_correct_mean <- colMeans(a_correct, dims = 1)
b_correct_mean <- colMeans(b_correct, dims = 1)

wavecan_wrong <- compute_wavecan_trials(lfp_data_wrong_pred,
                                        Region1, Region2)
rho_wrong <- wavecan_wrong$rho
a_wrong <- wavecan_wrong$a
b_wrong <- wavecan_wrong$b

### Figure 5 is obtained based on the results below
rho_wrong_mean <- colMeans(rho_wrong, dims = 1)
a_wrong_mean <- colMeans(a_wrong, dims = 1)
b_wrong_mean <- colMeans(b_wrong, dims = 1)




# cut the time
# Drop the samples that were prepended by the padding step, so every series
# lines up with the original (unpadded) recording window again. The padding
# length is derived from the data instead of the hard-coded 97:4096.
n_pad <- dim(rho_correct_mean)[2] - dim(lfp_data_correct)[2]
stopifnot(n_pad >= 0)
keep_time <- seq.int(n_pad + 1L, dim(rho_correct_mean)[2])

rho_correct_mean_mitt <- rho_correct_mean[, keep_time]
a_correct_mean_mitt <- a_correct_mean[, , keep_time]
b_correct_mean_mitt <- b_correct_mean[, , keep_time]
rho_correct_mitt <- rho_correct[, , keep_time]
a_correct_mitt <- a_correct[, , , keep_time]
b_correct_mitt <- b_correct[, , , keep_time]

rho_wrong_mean_mitt <- rho_wrong_mean[, keep_time]
a_wrong_mean_mitt <- a_wrong_mean[, , keep_time]
b_wrong_mean_mitt <- b_wrong_mean[, , keep_time]
rho_wrong_mitt <- rho_wrong[, , keep_time]
a_wrong_mitt <- a_wrong[, , , keep_time]
b_wrong_mitt <- b_wrong[, , , keep_time]



### plot function for Figure 4, and Figure 10 can also be generated by changing the argument scale_j
# Plot the trial-averaged wavelet coherence at one scale for correct versus
# incorrect trials on a common time axis running from -2 s to 2 s.
#
# Reads the globals `J` (number of scales), `rho_correct_mean_mitt` and
# `rho_wrong_mean_mitt` (scale x time matrices).
#
# scale_j:       scale to plot; a single whole number in 1..J.
# smooth_window: width in samples of a centred moving average applied to both
#                curves before plotting; 0 (default) disables smoothing.
# Returns (invisibly) the value of the final legend() call. Graphics
# parameters changed here are restored on exit.
plot_rho_comparison <- function(scale_j, smooth_window = 0) {
  # Check if scale is valid; a fractional scale would otherwise be silently
  # truncated by the matrix indexing below.
  if (length(scale_j) != 1L || is.na(scale_j) ||
      scale_j != round(scale_j) || scale_j < 1 || scale_j > J) {
    stop("Scale must be between 1 and ", J)
  }
  # rep() truncates a fractional count, which would give moving-average
  # weights that no longer sum to one.
  if (length(smooth_window) != 1L || is.na(smooth_window) ||
      smooth_window < 0 || smooth_window != round(smooth_window)) {
    stop("smooth_window must be a single non-negative whole number")
  }

  # Set margins (bottom, left, top, right) and text sizes; restore on exit.
  old_par <- par(mar = c(3.5, 4.5, 3, 1), cex.lab = 1.5, cex.axis = 1.3,
                 cex.main = 1.6)
  on.exit(par(old_par), add = TRUE)

  # Get the data for the specified scale.
  rho_correct <- rho_correct_mean_mitt[scale_j, ]
  rho_wrong <- rho_wrong_mean_mitt[scale_j, ]
  time_points <- seq(-2, 2, length.out = length(rho_correct))

  # Centred moving average; stats::filter leaves NA at both ends.
  if (smooth_window > 0) {
    weights <- rep(1 / smooth_window, smooth_window)
    rho_correct <- stats::filter(rho_correct, weights, sides = 2)
    rho_wrong <- stats::filter(rho_wrong, weights, sides = 2)
  }

  plot(time_points, rho_correct,
       type = "l",
       col = "#2E86C1",  # Deep blue
       lwd = 2.5,
       xlab = "Time (s)",
       ylab = expression(hat(rho)),
       main = "",  # title is added separately below
       ylim = range(rho_correct, rho_wrong, na.rm = TRUE),
       cex.lab = 1.5,
       cex.axis = 1.3)

  title(main = bquote("Wavelet coherence at scale" ~ j == .(scale_j)),
        line = 2.5, cex.main = 1.6)

  lines(time_points, rho_wrong,
        col = "#C0392B",  # Deep red
        lwd = 2.5)

  # With bty = "n" no legend box is drawn, so box/background options are
  # unnecessary.
  legend_info <- legend(x = "bottomleft",
                        legend = c("Correct", "Incorrect"),
                        col = c("#2E86C1", "#C0392B"),
                        lwd = 2.5,
                        bty = "n",
                        cex = 1.2)
  invisible(legend_info)
}

# Example usage:

par(mfrow = c(1, 1))

plot_rho_comparison(scale_j = 4, smooth_window = 200)  # Plot for scale 4 with a 200-sample moving average




### plot function for generate LFP realizations, corresponding to Figure 1
# Draw the raw LFP traces of one trial as stacked panels, one per channel:
# the channels of `region_a` (dark) above those of `region_b` (red), each
# region sharing one y-range padded by 5% on both sides. Only the bottom
# panel gets x-axis ticks. All graphics parameters are restored on exit.
#
# lfp:        trial x time x channel array. Defaults to the UNPADDED correct
#             trials: the padded copy used for the wavelet analysis starts
#             with duplicated samples that are not real data and would be
#             drawn on the time axis.
# trial:      index of the trial to plot.
# region_a, region_b: channel indices of the two regions.
# time_range: start and end of the time axis in seconds.
plot_channels_first_trial <- function(lfp = lfp_data_correct,
                                      trial = 1,
                                      region_a = c(1, 2, 4, 5),
                                      region_b = c(13, 14, 15, 16, 17),
                                      time_range = c(-2, 2)) {
  old_par <- par(no.readonly = TRUE)
  on.exit(par(old_par), add = TRUE)

  par(mfrow = c(length(region_a) + length(region_b), 1),
      mar = c(0.5, 4, 0.2, 1),
      oma = c(2, 1, 1, 1),
      cex = 1.2,
      mgp = c(2, 0.7, 0))

  time_points <- seq(time_range[1], time_range[2], length.out = dim(lfp)[2])

  # Data range widened by 5% of its span on each side.
  padded_limits <- function(values) {
    rng <- range(values)
    rng + c(-0.05, 0.05) * diff(rng)
  }

  # One panel per channel of `channels`; when `draw_axis` is TRUE the last
  # panel receives the x-axis ticks.
  plot_region <- function(channels, colour, draw_axis) {
    limits <- padded_limits(lfp[trial, , channels])
    for (k in seq_along(channels)) {
      idx <- channels[k]
      plot(time_points, lfp[trial, , idx],
           type = "l",
           col = colour,
           lwd = 1.5,
           xlab = "",
           ylab = paste0("T", idx),
           xaxt = "n",
           ylim = limits,
           cex.lab = 1.2,
           cex.axis = 1.1)
      if (draw_axis && k == length(channels)) {
        ticks <- seq(time_range[1], time_range[2], by = 0.5)
        axis(1, at = ticks, labels = ticks, cex.axis = 1.1)
      }
    }
  }

  plot_region(region_a, "#2C3E50", draw_axis = FALSE)
  plot_region(region_b, "#C0392B", draw_axis = TRUE)
  mtext("Time(s)", side = 1, outer = TRUE, line = 1, cex = 1.2)
  invisible(NULL)
}

# Plot the channels: stacked per-channel LFP traces of the first trial (Figure 1)
plot_channels_first_trial()








