#' Windowed permutation test at selected time points (Algorithm 2 in the paper)
#'
#' For every scale and every requested time point, the test statistic is the
#' sum, over a window centred on that time point, of the squared difference
#' between the medians of `rho_correct` and `rho_wrong`. The null distribution
#' is approximated by pooling both groups and re-splitting them at random
#' `n_perm` times.
#'
#' @param rho_correct,rho_wrong 3-d arrays. The first dimension holds the
#'   observations that are pooled and permuted, the second the scales, the
#'   third the time samples. `rho_wrong` must cover every scale and time
#'   sample of `rho_correct`.
#' @param time_points Times in seconds on the axis where -2 s maps to index 0
#'   and 2 s maps to the last time index. Give this OR `time_indices`; when
#'   both are `NULL` the default `c(-1, -0.5, 0.5, 1)` is used.
#' @param time_indices Window centres given directly as time indices.
#' @param window_size Window width in samples; the window spans
#'   `floor(window_size / 2)` samples on each side of the centre and is
#'   clipped to the available time axis.
#' @param n_perm Number of random re-splits per scale and time point.
#' @param alpha Significance threshold applied to the p-values.
#' @return A list with `p_values`, `significant` and `T_obs` (scale x time
#'   point matrices), `T_perm` (permutation x scale x time point array),
#'   `time_points`, `time_indices` and `window_size`. When `time_indices` was
#'   supplied, `time_points` holds the corresponding times in seconds.
perform_permutation_test_timepoints <- function(rho_correct, rho_wrong,
                                                time_points = NULL,
                                                time_indices = NULL,
                                                window_size = 200,
                                                n_perm = 1000,
                                                alpha = 0.05) {
  if (!is.null(time_points) && !is.null(time_indices)) {
    stop("Please provide either time_points OR time_indices, not both")
  }
  if (length(dim(rho_correct)) != 3L || length(dim(rho_wrong)) != 3L) {
    stop("rho_correct and rho_wrong must be 3-d arrays", call. = FALSE)
  }

  n_correct <- dim(rho_correct)[1]
  n_wrong <- dim(rho_wrong)[1]
  n_scales <- dim(rho_correct)[2]
  n_time <- dim(rho_correct)[3]

  if (n_correct < 1L || n_wrong < 1L) {
    stop("Both groups need at least one observation", call. = FALSE)
  }
  if (dim(rho_wrong)[2] < n_scales || dim(rho_wrong)[3] < n_time) {
    stop("rho_wrong must cover every scale and time sample of rho_correct",
         call. = FALSE)
  }
  if (length(n_perm) != 1L || is.na(n_perm) || n_perm < 1) {
    stop("n_perm must be a single number >= 1", call. = FALSE)
  }
  if (length(window_size) != 1L || is.na(window_size) || window_size < 0) {
    stop("window_size must be a single non-negative number", call. = FALSE)
  }

  if (is.null(time_indices)) {
    if (is.null(time_points)) {
      time_points <- c(-1, -0.5, 0.5, 1)
    }
    # Map the [-2 s, 2 s] axis onto the time index range.
    time_indices <- round((time_points + 2) * (n_time / 4))
  } else {
    # Invert the mapping so the result always carries times in seconds;
    # the plotting and summary helpers label their output with them.
    time_points <- time_indices * (4 / n_time) - 2
  }

  n_times <- length(time_indices)
  half_window <- floor(window_size / 2)
  window_start <- pmax(1, time_indices - half_window)
  window_end <- pmin(n_time, time_indices + half_window)
  if (anyNA(time_indices) || any(window_start > window_end)) {
    stop("Every window must overlap the time axis 1..", n_time, call. = FALSE)
  }

  correct_rows <- seq_len(n_correct)
  wrong_rows <- n_correct + seq_len(n_wrong)
  n_total <- n_correct + n_wrong

  # Squared median difference between the two groups for one pooled column.
  split_statistic <- function(values) {
    (median(values[correct_rows]) - median(values[wrong_rows]))^2
  }

  T_obs <- matrix(NA, nrow = n_scales, ncol = n_times)
  T_perm <- array(NA, dim = c(n_perm, n_scales, n_times))
  p_values <- matrix(NA, nrow = n_scales, ncol = n_times)
  significant <- matrix(FALSE, nrow = n_scales, ncol = n_times)

  for (j in seq_len(n_scales)) {
    for (tp in seq_len(n_times)) {
      window <- window_start[tp]:window_end[tp]
      n_window <- length(window)
      window_cols <- seq_len(n_window)

      # Pool both groups once per (scale, window): rows are observations
      # (correct first, then wrong), columns are the time samples in the
      # window. matrix() keeps the shape even for a single row or column.
      pooled <- rbind(
        matrix(rho_correct[, j, window], nrow = n_correct, ncol = n_window),
        matrix(rho_wrong[, j, window], nrow = n_wrong, ncol = n_window)
      )

      T_obs[j, tp] <- sum(vapply(window_cols, function(k) {
        split_statistic(pooled[, k])
      }, numeric(1)))

      # NOTE(review): labels are reshuffled independently at every time
      # sample of the window, as in the original implementation; confirm
      # against Algorithm 2 that a single shuffle per permutation is not
      # intended.
      for (i in seq_len(n_perm)) {
        T_perm[i, j, tp] <- sum(vapply(window_cols, function(k) {
          split_statistic(pooled[sample.int(n_total), k])
        }, numeric(1)))
      }

      # The observed split counts as one permutation and ties count as "at
      # least as extreme" (>=); with a strict > fully tied data would get
      # the smallest possible p-value.
      exceed <- sum(T_perm[, j, tp] >= T_obs[j, tp])
      p_values[j, tp] <- (1 + exceed) / (1 + n_perm)
      significant[j, tp] <- p_values[j, tp] < alpha
    }
  }

  list(p_values = p_values,
       significant = significant,
       T_perm = T_perm,
       T_obs = T_obs,
       time_points = time_points,
       time_indices = time_indices,
       window_size = window_size)
}


#' Plot the permutation densities for one scale (Figure 11)
#'
#' Draws one facet per tested time point: the density of the permuted test
#' statistics, a red vertical line at the observed statistic, and the p-value
#' as text in the top-right corner.
#'
#' @param scale_j Index of the scale (row of `T_obs` / `p_values`) to plot.
#' @param perm_results List returned by
#'   `perform_permutation_test_timepoints()`; the elements `T_perm`, `T_obs`,
#'   `p_values` and `time_points` (or `time_indices`) are read.
#' @return A ggplot object.
plot_timepoint_density <- function(scale_j, perm_results) {
  n_perm <- dim(perm_results$T_perm)[1]
  stopifnot(
    length(scale_j) == 1L,
    scale_j >= 1,
    scale_j <= nrow(perm_results$T_obs)
  )

  times  <- perm_results[["time_points"]]
  p_vals <- perm_results$p_values[scale_j, ]

  # Facet labels are plotmath strings (parsed by label_parsed below).
  if (is.null(times)) {
    # The test was run on raw indices: label facets by index instead of
    # letting a single empty label be silently recycled over all facets.
    expr_labels <- paste0("index == ", perm_results[["time_indices"]])
  } else {
    expr_labels <- paste0('t^"*"', ' == ', times, ' ~ "s"')
  }
  label_factor <- factor(expr_labels, levels = expr_labels)

  # Permuted statistics in long format: T_perm[, scale_j, ] is flattened
  # column by column, so each label is repeated n_perm times.
  df_perm <- data.frame(
    statistic  = as.vector(perm_results$T_perm[, scale_j, ]),
    time_label = factor(rep(expr_labels, each = n_perm),
                        levels = expr_labels)
  )

  # Observed statistic, one row per facet
  df_obs <- data.frame(
    statistic  = perm_results$T_obs[scale_j, ],
    time_label = label_factor
  )

  # p-value annotation, one row per facet
  df_p <- data.frame(
    p          = p_vals,
    time_label = label_factor
  )

  ggplot(df_perm, aes(x = statistic)) +
    geom_density(fill = "lightblue", alpha = 0.5) +
    geom_vline(data = df_obs,
               aes(xintercept = statistic),
               color = "red", linewidth = 1) +
    # x = Inf / y = Inf pins the text to the top-right corner of each facet
    geom_text(data = df_p,
              aes(x = Inf, y = Inf, label = paste0("p = ", round(p, 3))),
              hjust = 1.1, vjust = 1.1, size = 3) +
    facet_wrap(~ time_label,
               nrow     = 1,
               scales   = "free_x",
               labeller = label_parsed) +
    labs(
      title = paste0("Permutation densities at j=", scale_j),
      x     = "Test statistic",
      y     = "Density"
    ) +
    theme_minimal(base_size = 14) +
    theme(
      plot.title      = element_text(size = 12, hjust = 0.5,
                                     margin = margin(b = 8)),
      strip.text      = element_text(size = 12),
      panel.spacing.x = unit(0.8, "lines"),
      plot.margin     = unit(c(1, 1, 1, 1), "cm")
    )
}



#' Tabulate and print the permutation-test results
#'
#' Flattens the scale x time point p-value matrix into a long data frame
#' (one row per scale/time combination, time varying fastest), prints it and
#' reports how many entries are significant at level `alpha`.
#'
#' @param perm_results List returned by
#'   `perform_permutation_test_timepoints()`; `p_values` and `time_points`
#'   (or, when that is `NULL`, `time_indices`) are read.
#' @param alpha Significance threshold; an entry is significant when its
#'   p-value is strictly below `alpha`.
#' @return A data frame with columns `Scale`, `Time`, `PValue` and
#'   `Significant`.
summarize_timepoint_results <- function(perm_results, alpha = 0.05) {
  p_mat <- perm_results$p_values
  n_scales <- nrow(p_mat)
  n_times <- ncol(p_mat)

  # Results produced from raw indices may carry no times in seconds; fall
  # back to the indices so the table can still be built.
  time_axis <- perm_results[["time_points"]]
  if (is.null(time_axis)) {
    time_axis <- perm_results[["time_indices"]]
  }
  stopifnot(length(time_axis) == n_times)

  # t() makes the flattening row-major: all time points of scale 1 first,
  # matching the Scale / Time columns below.
  p_long <- as.vector(t(p_mat))

  results <- data.frame(
    Scale       = rep(seq_len(n_scales), each = n_times),
    Time        = rep(time_axis, times = n_scales),
    PValue      = p_long,
    # Recomputed from the p-values so that `alpha` actually takes effect
    Significant = p_long < alpha
  )

  cat("\nResults for all scales and time points:\n")
  print(results, row.names = FALSE)
  cat(sprintf("\nTotal significant results: %d\n", sum(results$Significant)))
  results
}



# Run the windowed permutation test with the default time points and window
perm_results <- perform_permutation_test_timepoints(rho_correct_mitt, rho_wrong_mitt)

# plot the test density in Figure 11
plot_timepoint_density(scale_j = 3, perm_results)


# View summary of results, results in table 1
results_summary <- summarize_timepoint_results(perm_results)

# Calculate median differences at specific time points (without windows)
# NOTE(review): assigning to `T` at top level masks the built-in shorthand
# for TRUE in everything that runs afterwards; the name is kept because later
# code may read it. Consider renaming once all uses are known.
T <- dim(rho_correct_mitt)[3]
time_points <- c(-1.5, -0.5, 0.5, 1.5)  # Time points of interest
time_indices <- round((time_points + 2) * (T/4))  # Convert to indices

j <- 7
# Calculate median differences at exact time points for scale j.
# vapply() guarantees one numeric value per time index.
exact_med_diff <- vapply(time_indices, function(idx) {
  median(rho_correct_mitt[, j, idx]) - median(rho_wrong_mitt[, j, idx])
}, numeric(1))

print(exact_med_diff)


