import os
import jax.numpy as jnp
import pickle

def load_object(filename, directory='results'):
    """Unpickle and return the object stored at ``directory/filename``.

    Args:
        filename: Name (or relative path) of the pickle file inside
            ``directory``.
        directory: Folder containing the file. Defaults to ``'results'``,
            the location this function has always read from.

    Returns:
        The object that was pickled into the file.

    Raises:
        FileNotFoundError: If the resolved path does not exist.

    Warning:
        ``pickle.load`` can execute arbitrary code while deserializing.
        Only load files you produced yourself or otherwise trust.
    """
    filepath = os.path.join(directory, filename)
    with open(filepath, 'rb') as file:
        # SECURITY: never point this at untrusted/downloaded pickle files.
        return pickle.load(file)

def _summarize(a):
    """Reduce an array of per-run values to a dict of rounded statistics.

    Runs are expected along axis 0 and each run's samples along axis 1.
    Every statistic is computed per run (``axis=1``) and then aggregated
    across runs.
    """
    per_run_mean = jnp.mean(a, axis=1)
    per_run_median = jnp.median(a, axis=1)
    per_run_var = jnp.var(a, axis=1)  # population variance (jnp default ddof=0)
    per_run_std = jnp.std(a, axis=1)
    per_run_min = jnp.min(a, axis=1)
    per_run_max = jnp.max(a, axis=1)
    # Key order and rounding are kept stable because the dict is printed.
    return {
        'mean': round(float(jnp.mean(per_run_mean)), 4),
        'median': round(float(jnp.mean(per_run_median)), 4),
        'variance': round(float(jnp.mean(per_run_var)), 6),
        'std_dev': round(float(jnp.mean(per_run_std)), 6),
        'avg_min': round(float(jnp.mean(per_run_min)), 4),
        'avg_max': round(float(jnp.mean(per_run_max)), 4),
        'min': round(float(jnp.min(per_run_min)), 4),
        'max': round(float(jnp.max(per_run_max)), 4),
    }


def calculate_vals(filename):
    """Load pickled per-run results and print summary statistics.

    The data is read with ``load_object(filename)``. Each row is treated as
    one run. The printed dict holds the mean of the per-run mean, median,
    variance, std dev, min and max, plus the overall min and max across
    all runs.

    Args:
        filename: Pickle file name forwarded to ``load_object``.

    Returns:
        The same dict that is printed. Returning it lets callers use the
        values directly; callers that ignore the return value are unaffected.
    """
    # jax.numpy reductions raise TypeError on plain Python lists, so coerce
    # whatever was pickled (list of lists, NumPy or JAX array) up front.
    a = jnp.asarray(load_object(filename))
    stats = _summarize(a)
    print(stats)
    return stats
filename = './test_accuracies_waveform_2_2.pkl'

if __name__ == '__main__':
    # Run the analysis only when executed as a script, not when imported.
    calculate_vals(filename)