source("R/signs.R")
source("R/helpers.R")
source("R/generate.R")

# Load required packages
library(purrr)
library(ggplot2)
library(dplyr)

# Load the scales package
library(scales)

# Reusable x-axis scale applying the log1p transformation from scales
log1p_scale <- ggplot2::scale_x_continuous(trans = scales::log1p_trans())

# Simulation parameters ----
# K is passed to the matrix helpers and the sign sampler in single_rep().
K <- 5
# Number of independent replicates per (alpha, noise_sd) combination.
n_reps <- 50
# seq_len() rather than 1:n_reps so that n_reps = 0 gives an empty grid
# instead of c(1, 0).
reps_grid <- seq_len(n_reps)
# Three evenly spaced scaling factors between 0.01 and 1.
alpha_grid <- seq(1e-2, 1, length.out = 3)
# Noise levels: the noiseless case plus 50 log-spaced values in [1e-3, 0.05].
noise_sd_grid <- c(0, exp(seq(log(1e-3), log(0.05), length.out = 50)))

#' Run one replicate of the sign-recovery simulation
#'
#' Builds a base matrix scaled by `alpha`, adds a noise matrix, takes the
#' eigenvectors of both, flips the columns of the noisy eigenvectors by random
#' signs, and reports the fraction of those signs that `recover_sign()` fails
#' to reproduce.
#'
#' The matrix size `K` is read from the enclosing environment. The helpers
#' `runif_symmetric_matrix()`, `rnorm_symmetric_matrix()`, `recover_sign()`
#' and `count_vector_mismatches()` come from the sourced R/ files; going by
#' their names the matrices are presumably symmetric (confirm in the helpers).
#'
#' @param alpha Numeric scalar multiplying the base matrix.
#' @param rep_id Replicate identifier; only copied into the output row.
#' @param noise_sd Standard deviation handed to `rnorm_symmetric_matrix()`.
#' @return A one-row data.frame with columns `alpha`, `rep_id`, `noise_sd`
#'   and `fraction_mismatch`.
single_rep <- function(alpha, rep_id, noise_sd) {
  # Random draws happen in the order base matrix, noise matrix, signs, so
  # results for a given seed stay reproducible.
  B_base <- alpha * runif_symmetric_matrix(K, 0.2, 0.7)
  noise_matrix <- rnorm_symmetric_matrix(K, mean = 0, sd = noise_sd)

  # Eigenvectors of the noiseless matrix (reference) and of the noisy one
  Q <- eigen(B_base)$vectors
  Q_noisy <- eigen(B_base + noise_matrix)$vectors

  # One random sign per eigenvector column
  sign_vec <- sample(c(-1, 1), K, replace = TRUE)

  # Flip the columns of the noisy eigenvectors. nrow is given explicitly
  # because diag(x) with a length-one x treats x as a matrix dimension, so
  # diag(-1) would error when K is 1.
  Qh <- Q_noisy %*% diag(sign_vec, nrow = length(sign_vec))

  n_mismatch <- count_vector_mismatches(sign_vec, recover_sign(Q, Qh))
  fraction_mismatch <- n_mismatch / length(sign_vec)

  data.frame(
    alpha = alpha,
    rep_id = rep_id,
    noise_sd = noise_sd,
    fraction_mismatch = fraction_mismatch
  )
}

# Evaluate single_rep() once per row of the full alpha x rep_id x noise_sd
# grid and stack the one-row results into a single data frame
simulation_results <- pmap_dfr(
  expand.grid(
    alpha = alpha_grid,
    rep_id = reps_grid,
    noise_sd = noise_sd_grid
  ),
  single_rep
)

# Mean mismatch fraction for each (alpha, noise_sd) pair, averaged over the
# replicates
averages <- aggregate(
  fraction_mismatch ~ alpha + noise_sd,
  data = simulation_results,
  FUN = mean
)


# Line plot of the mean mismatch fraction against the noise level, one line
# per alpha, stored in plot_with_K
plot_with_K <- ggplot(
  averages,
  aes(
    x = noise_sd,
    y = fraction_mismatch,
    color = as.factor(alpha),
    group = alpha
  )
) +
  geom_line() +
  log1p_scale +
  labs(
    x = "Noise Standard Deviation",
    y = "Fraction of Mismatches",
    color = "Alpha",
    title = paste0(
      "Fraction of Mismatches vs. Noise Standard Deviation (K = ", K, ")"
    )
  ) +
  theme_minimal()

# Save the plot with a filename that reflects the K used
ggsave(
  paste0("fraction_mismatches_vs_noise_sd_K_", K, ".png"),
  plot = plot_with_K,
  width = 8,
  height = 6
)

# Print plot_with_K to display the plot
print(plot_with_K)

