import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def log_func(x, a, b):
    """Logarithmic model ``y = a * ln(x) + b`` used as the ``curve_fit`` target.

    *x* may be a scalar or a NumPy array (``np.log`` is applied elementwise),
    so its values should be positive.
    """
    log_term = np.log(x)
    return a * log_term + b

def format_value(val):
    """Render *val* with three significant figures (the ``.3g`` format spec).

    Used for compact legend labels, e.g. ``0.98765 -> "0.988"``.
    """
    return format(val, ".3g")

# Per-modulus histograms for the "clock" series.
# Keys are the modulus n (the plots' "mod n" x-axis). process_data treats
# position i of each list as the value i and the entry as its count/weight,
# i.e. presumably "number of runs that found exactly i frequencies"
# (the y-axis is "Avg # of frequencies") -- TODO confirm against the data source.
clock_data = {
    3: [0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    7: [0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    17: [0, 0, 1, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    27: [0, 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    59: [0, 0, 0, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    64: [0, 0, 0, 2, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    97: [0, 0, 0, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    113: [0, 0, 0, 0, 1, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    128: [0, 0, 0, 1, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    303: [0, 0, 0, 0, 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    310: [0, 0, 0, 0, 0, 2, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    499: [0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    720: [0, 0, 0, 0, 0, 0, 1, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    977: [0, 0, 0, 0, 0, 0, 2, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    1977: [0, 0, 0, 0, 0, 0, 1, 2, 1, 4, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    4013: [0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
}
# Per-modulus histograms for the "pizza" series, laid out like clock_data:
# keys are the modulus n, and process_data treats position i of each list as
# the value i and the entry as its count/weight (presumably runs that found
# exactly i frequencies -- TODO confirm against the data source).
pizza_data = {
    3: [0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    7: [0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    17: [0, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    27: [0, 0, 0, 3, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    59: [0, 0, 0, 1, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    64: [0, 0, 0, 2, 5, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    97: [0, 0, 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    113: [0, 0, 0, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    128: [0, 0, 0, 0, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    303: [0, 0, 0, 0, 0, 1, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    310: [0, 0, 0, 0, 0, 2, 6, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    499: [0, 0, 0, 0, 0, 0, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    720: [0, 0, 0, 0, 0, 0, 2, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    977: [0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    1977: [0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    4013: [0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

}
# Pre-aggregated MLP results, one (n, mean, std) tuple per modulus:
# entry[0] is the x value (mod n), entry[1] the plotted mean, and entry[2] the
# standard deviation drawn as the error bar. These are not histograms, so
# unlike clock_data/pizza_data they are fitted directly rather than through
# process_data.
mlp_data = [
    (3, 1.02, 0.14),
    (5, 1.2449, 0.4531),
    (7, 1.7576, 0.73979),
    (13, 1.984, 0.931978),
    (31, 2.46, 0.7151119),
    (59, 3.1091, 0.45837),
    (64, 3.29, 0.45),
    (113, 3.6774, 0.857118),
    (128, 3.62, 0.70),
    (256, 4.14, 0.83),
    (310, 4.57, 1.05),
    (499, 4.95, 0.8047),
    (710, 4.81, 0.83),
    (997, 5.25, 0.925),
    (1409, 5.0, 1.0444),
    (1999, 5.7931, 1.12585),
    (2999, 6.0, 0.7071),
    (4999, 6.4545, 0.6556),
]

# Function to process data
def _histogram_mean_std(counts):
    """Return the weighted mean and population standard deviation of a histogram.

    Position ``i`` of *counts* is treated as the value ``i`` and ``counts[i]``
    as its weight.

    Raises:
        ValueError: if the weights do not sum to a positive value, in which
            case the mean is undefined (previously this silently became NaN).
    """
    weights = np.asarray(counts, dtype=float)
    total = weights.sum()
    if total <= 0:
        raise ValueError("histogram counts must sum to a positive value")
    values = np.arange(weights.size)
    mean = np.sum(values * weights) / total
    std = np.sqrt(np.sum((values - mean) ** 2 * weights) / total)
    return mean, std


def _fit_log_with_r2(x_values, y_values):
    """Fit ``log_func`` to (x, y); return ``(fitted_values, r_squared, a, b)``.

    ``r_squared`` is NaN when all y values are equal (zero total variance),
    because R^2 is undefined there.
    """
    x_arr = np.asarray(x_values, dtype=float)
    y_arr = np.asarray(y_values, dtype=float)
    (a, b), _ = curve_fit(log_func, x_arr, y_arr)
    fitted = log_func(x_arr, a, b)
    ss_res = np.sum((y_arr - fitted) ** 2)
    ss_tot = np.sum((y_arr - np.mean(y_arr)) ** 2)
    r_squared = 1 - ss_res / ss_tot if ss_tot > 0 else np.nan
    return fitted, r_squared, a, b


def process_data(data_dict):
    """Summarise per-key histograms and fit ``a * ln(x) + b`` to their means.

    Args:
        data_dict: mapping ``x -> counts`` where ``counts[i]`` is the weight of
            value ``i``. Keys should be positive, since they pass through ``np.log``.

    Returns:
        Tuple ``(x_values, averages, std_devs, fitted_values, r_squared, a, b)``:
        the keys in iteration order; each histogram's weighted mean and
        population standard deviation; the fitted curve evaluated at
        ``x_values``; the fit's coefficient of determination (NaN if all means
        are equal); and the fitted parameters ``a`` and ``b``.

    Raises:
        ValueError: if any histogram's counts do not sum to a positive value.
    """
    x_values = []
    averages = []
    std_devs = []

    for x, counts in data_dict.items():
        try:
            mean, std = _histogram_mean_std(counts)
        except ValueError as err:
            # Name the offending key; otherwise the failure only surfaces later
            # as an opaque NaN error inside curve_fit.
            raise ValueError(f"invalid histogram for x={x!r}: {err}") from err
        x_values.append(x)
        averages.append(mean)
        std_devs.append(std)

    fitted_values, r_squared, a, b = _fit_log_with_r2(x_values, averages)
    return x_values, averages, std_devs, fitted_values, r_squared, a, b

# Process all datasets
clock_x, clock_avg, clock_std, clock_fit, clock_r2, clock_a, clock_b = process_data(clock_data)
pizza_x, pizza_avg, pizza_std, pizza_fit, pizza_r2, pizza_a, pizza_b = process_data(pizza_data)

# The MLP results arrive pre-aggregated as (n, mean, std) tuples. They skip the
# histogram step but get the same log fit and R^2 computation as process_data.
mlp_x = [n for n, _, _ in mlp_data]
mlp_avg = [mean for _, mean, _ in mlp_data]
mlp_std = [std for _, _, std in mlp_data]

mlp_x_arr = np.asarray(mlp_x, dtype=float)
mlp_avg_arr = np.asarray(mlp_avg, dtype=float)
popt_mlp, _ = curve_fit(log_func, mlp_x_arr, mlp_avg_arr)
# Keep full precision here. format_value below is the single rounding step,
# exactly as for the clock/pizza fits; rounding to 3 decimals first would
# double-round only the MLP numbers.
mlp_a, mlp_b = popt_mlp
mlp_fit = log_func(mlp_x_arr, mlp_a, mlp_b)
ss_res_mlp = np.sum((mlp_avg_arr - mlp_fit) ** 2)
ss_tot_mlp = np.sum((mlp_avg_arr - np.mean(mlp_avg_arr)) ** 2)
# R^2 is undefined when every mean is identical; report NaN instead of dividing by zero.
mlp_r2 = 1 - ss_res_mlp / ss_tot_mlp if ss_tot_mlp > 0 else np.nan

# From here on, the fit statistics are 3-significant-figure strings for the legend labels.
clock_r2, clock_a, clock_b = format_value(clock_r2), format_value(clock_a), format_value(clock_b)
pizza_r2, pizza_a, pizza_b = format_value(pizza_r2), format_value(pizza_a), format_value(pizza_b)
mlp_r2, mlp_a, mlp_b = format_value(mlp_r2), format_value(mlp_a), format_value(mlp_b)

# --- New plotting ---
font_size_increment = 12

# One entry per series: (legend name, x, mean, std, fitted curve, R^2, a, b, colour).
# Both panels draw exactly these series in this order, so they cannot drift apart.
plot_series = [
    ("clock", clock_x, clock_avg, clock_std, clock_fit, clock_r2, clock_a, clock_b, 'blue'),
    ("pizza", pizza_x, pizza_avg, pizza_std, pizza_fit, pizza_r2, pizza_a, pizza_b, 'green'),
    ("MLP", mlp_x, mlp_avg, mlp_std, mlp_fit, mlp_r2, mlp_a, mlp_b, 'purple'),
]


def _draw_series(ax, series):
    """Draw each series on *ax*: mean +/- std error bars, then its dashed fitted curve."""
    for name, xs, avg, std, fit, r2, a, b, color in series:
        ax.errorbar(xs, avg, yerr=std, fmt='o', capsize=5,
                    label=f"{name} $R^2 = {r2}$\n a={a}, b={b}", color=color)
        ax.plot(xs, fit, color=color, linestyle='--')


# Two side-by-side panels sharing the y-axis; the figure is 16 x 3.8 inches.
fig, axs = plt.subplots(1, 2, figsize=(16, 3.8), sharey=True)
log_ax, lin_ax = axs

# First plot (log x-axis): the a*ln(n)+b fits appear as straight lines here.
_draw_series(log_ax, plot_series)
log_ax.set_xscale('log')
log_ax.set_xlabel("mod n (log scale)", fontsize=7 + font_size_increment)
log_ax.set_ylabel("Avg # of frequencies", fontsize=7 + font_size_increment)
log_ax.tick_params(axis='both', labelsize=5 + font_size_increment)
# x=0.6 (axes-relative) shifts the title to the right of this panel's centre.
log_ax.set_title("# frequencies found across moduli and architectures",
                 fontsize=10 + font_size_increment, loc='center', x=0.6)
log_ax.grid(True, linestyle='--', alpha=0.7)
# Only the left panel gets a legend; the right panel shows the same series.
log_ax.legend(fontsize=12)

# Second plot (linear x-axis).
_draw_series(lin_ax, plot_series)
lin_ax.set_xlabel("mod n", fontsize=7 + font_size_increment)
lin_ax.grid(True, linestyle='--', alpha=0.7)
lin_ax.tick_params(axis='both', labelsize=5 + font_size_increment)

plt.tight_layout()
plt.savefig("plots_side_by_side_adjusted.png", dpi=300, bbox_inches="tight")
plt.show()