import numpy as np
import matplotlib.pyplot as plt
# tkinter's file dialog is used below to pick the .npy file to load
from tkinter import filedialog

# Load the data from a .npy file chosen through a native file dialog.
# An explicit, hidden Tk root is created so no empty Tk window is left on
# screen, and it is destroyed as soon as the path is known.
import tkinter as tk

_dialog_root = tk.Tk()
_dialog_root.withdraw()
try:
    file_path = filedialog.askopenfilename(
        parent=_dialog_root,
        title="Select a npy file",
        filetypes=[("npy Files", "*.npy")],
    )
finally:
    _dialog_root.destroy()

# askopenfilename returns an empty value when the dialog is cancelled; np.load
# would then fail with a confusing error, so stop with a clear message instead.
if not file_path:
    raise SystemExit("No .npy file selected; nothing to do.")

data = np.load(file_path)
# Each row must hold three configuration values followed by at least one
# result value; anything else would fail later with an opaque IndexError.
if data.ndim != 2 or data.shape[1] < 4:
    raise SystemExit(
        f"Expected a 2-D array with at least 4 columns, got shape {data.shape}."
    )

# Extract the configuration (first three columns, used below as n, ep, dac)
# and the results (all remaining columns).
# NOTE(review): astype(int) truncates toward zero, so any non-integer
# configuration value would be silently collapsed; confirm the stored
# configuration values are whole numbers.
configs = data[:, :3].astype(int)
results = data[:, 3:]

# Quick visual check of the first configuration column (n) in row order,
# drawn as dots only.
plt.plot(configs[:, 0], 'o')
plt.show()
# Get the unique full (n, ep, dac) configurations.
unique_configs = np.unique(configs, axis=0)
# Distinct values of each configuration column, sorted ascending; these form
# the grid used for the table and plot below.
ns = np.unique(configs[:, 0])
eps = np.unique(configs[:, 1])
dacs = np.unique(configs[:, 2])

# Per-configuration mean and standard deviation of every result column,
# keyed by the configuration as a tuple of plain Python ints.
means_dict = {}
stds_dict = {}

for config in unique_configs:
    # Rows whose three configuration values all match this configuration.
    mask = np.all(configs == config, axis=1)
    filtered_results = results[mask]
    k = tuple(config.tolist())
    means_dict[k] = np.mean(filtered_results, axis=0)
    # NOTE(review): np.std defaults to ddof=0 (population std). If each group
    # holds repeated experiment runs, ddof=1 (sample std) may be intended.
    stds_dict[k] = np.std(filtered_results, axis=0)

# Now you can access the means and stds as dictionaries
print("Means dictionary:", means_dict)
print("Stds dictionary:", stds_dict)



# Print a "mean(std)" table and plot the mean against n, with one line per
# (ep, dac) pair.
plt.figure(figsize=(8, 6))

# Line style is chosen per ep value and color per dac value; both lists are
# cycled when there are more values than entries.
line_styles = [':', '-.', '-', ':']
colors = ['#000000', '#00ffff', '#ff0000', '#ff00ff', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

# Indices into the result columns (results[:, column]) to report.
columns_to_report = [1]

for column in columns_to_report:
    print("column", column)
    for i, ep in enumerate(eps):
        for j, dac in enumerate(dacs):
            keys = [(int(n), int(ep), int(dac)) for n in ns]
            # Not every (n, ep, dac) combination has to exist in the data;
            # missing cells become NaN (printed as "nan", drawn as a gap)
            # instead of raising KeyError.
            mean_values = [
                means_dict[key][column].item() if key in means_dict else np.nan
                for key in keys
            ]
            std_values = [
                stds_dict[key][column].item() if key in stds_dict else np.nan
                for key in keys
            ]
            print(" ".join(
                f"{m:.3f}({s:.3f}),"
                for m, s in zip(mean_values, std_values)
            ))

            plt.plot(
                ns,
                mean_values,
                label=f'ep={ep}, dac={dac}',
                linestyle=line_styles[i % len(line_styles)],
                color=colors[j % len(colors)],
            )

# Pin the y-axis at 0 after plotting so the autoscaled upper limit is kept.
plt.ylim(bottom=0)

plt.xlabel('n')
plt.ylabel('y (mean value)')
# NOTE(review): the title assumes the reported result column is the L_inf
# error; confirm against the layout of the .npy file.
plt.title('Plot of L_inf error')
# Only draw a legend when something was plotted, avoiding matplotlib's
# "No artists with labels" warning on an empty grid.
if plt.gca().get_legend_handles_labels()[0]:
    plt.legend()
plt.grid(True)
plt.show()