"""Plotting utils for CIFAR-10/CIFAR-10.1 accuracy plots."""
import matplotlib
import matplotlib.pyplot
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from matplotlib.ticker import MultipleLocator
import numpy as np
import scipy.stats
import statsmodels.api as sm
import json
import os
import pickle

def clopper_pearson(k, n, alpha=0.05):
  """Clopper-Pearson confidence interval for a binomial proportion.

  http://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval
  Computes the exact two-sided interval with 1 - alpha coverage for the
  success probability, given k successes in n trials. Clopper-Pearson
  intervals are a conservative estimate.

  Args:
    k: number of successes (scalar or array; need not be an integer, e.g.
      accuracy * dataset size).
    n: number of trials.
    alpha: significance level; the interval has 1 - alpha coverage.

  Returns:
    lo: lower bound of the interval (0 when k == 0).
    hi: upper bound of the interval (1 when k >= n).
  """
  k = np.asarray(k, dtype=float)
  n = np.asarray(n, dtype=float)
  # beta.ppf yields nan when a shape parameter is 0, i.e. for the lower bound
  # at k == 0 and the upper bound at k == n. The exact limits there are 0 and
  # 1. Using `<` also covers k slightly exceeding n through float rounding.
  lo = np.where(k > 0, scipy.stats.beta.ppf(alpha / 2, k, n - k + 1), 0.0)
  hi = np.where(k < n, scipy.stats.beta.ppf(1 - alpha / 2, k + 1, n - k), 1.0)
  # [()] turns 0-d results back into numpy scalars for scalar inputs.
  return lo[()], hi[()]


def get_model_upper_and_lower_cis(model_accuracies, dataset_size):
  """Compute Clopper-Pearson error-bar widths for every model's accuracy.

  Args:
      model_accuracies (dict): Maps model name to accuracy as a fraction
        (it is multiplied by dataset_size to obtain the success count).
      dataset_size (int): Size of dataset model accuracy was computed on.
  Returns:
      A dict mapping each model name to the tuple
      (acc - lower_bound, upper_bound - acc), i.e. the distances from the
      accuracy down to and up to the ends of its confidence interval.
  """
  def _interval_widths(acc):
    lower_bound, upper_bound = clopper_pearson(acc * dataset_size, dataset_size)
    return acc - lower_bound, upper_bound - acc

  return {name: _interval_widths(acc) for name, acc in model_accuracies.items()}


def get_ordered_plot_data(original_accuracies, original_ci, new_accuracies,
                          new_ci, models_to_ignore=None):
  """Order data alphabetically by model name for plotting functions.

  Args:
      original_accuracies (dict): Maps model name to original accuracies.
      original_ci (dict): Maps model name to tuple containing lower and upper
        ci for orig acc.
      new_accuracies (dict): Maps model name to new accuracies. Its keys
        determine which models are returned.
      new_ci (dict): Maps model name to tuple containing lower and upper
        ci for new acc.
      models_to_ignore (list): List of model names to ignore.

  Returns:
      Tuple (orig_acc, new_acc, orig_ci_arr, new_ci_arr) of numpy arrays.
      Accuracies and CIs are converted to percent (multiplied by 100).
      The orig_ci_arr and new_ci_arr are (2, num_models) numpy arrays containing
      separate - and + values for each model. The first row contains
      the lower values, the second row contains the upper values.

  Raises:
      AssertionError: if a model in new_accuracies is missing from
        original_accuracies, original_ci or new_ci.
  """
  orig_acc = []
  new_acc = []
  orig_ci_lower = []
  orig_ci_upper = []
  new_ci_lower = []
  new_ci_upper = []
  for model in sorted(new_accuracies.keys()):
    if models_to_ignore and model in models_to_ignore:
      continue
    assert model in original_accuracies, (
        f'{model} missing from original_accuracies')
    assert model in original_ci, f'{model} missing from original_ci'
    assert model in new_ci, f'{model} missing from new_ci'

    orig_acc.append(original_accuracies[model] * 100.0)
    new_acc.append(new_accuracies[model] * 100.0)
    orig_ci_lower.append(original_ci[model][0] * 100.0)
    orig_ci_upper.append(original_ci[model][1] * 100.0)
    new_ci_lower.append(new_ci[model][0] * 100.0)
    # Index 1 is the upper delta (previously index 0 was used by mistake).
    new_ci_upper.append(new_ci[model][1] * 100.0)
  # Row 0 holds the lower deltas, row 1 the upper deltas: shape (2, n).
  orig_ci_arr = np.stack([np.array(orig_ci_lower), np.array(orig_ci_upper)],
                         axis=0)
  new_ci_arr = np.stack([np.array(new_ci_lower), np.array(new_ci_upper)],
                        axis=0)

  return np.array(orig_acc), np.array(new_acc), orig_ci_arr, new_ci_arr


def run_bootstrap_linreg(xs, ys, num_bootstrap_samples, x_eval_grid, seed):
  """Fit y = intercept + slope * x on bootstrap resamples of (xs, ys).

  Args:
    xs: 1-D numpy array of regressor values.
    ys: numpy array of responses aligned with xs (assumed 1-D).
    num_bootstrap_samples: number of resamples (drawn with replacement).
    x_eval_grid: 1-D numpy array of x positions at which every fitted line
      is evaluated.
    seed: seed for np.random.RandomState, making the resamples reproducible.

  Returns:
    (coeffs, grid_values): coeffs stacks one (intercept, slope) row per
    resample; grid_values stacks one row of fitted values on x_eval_grid per
    resample.
  """
  random_state = np.random.RandomState(seed)
  sample_count = xs.shape[0]
  # Design matrix [1, x] for the evaluation grid; it does not depend on the
  # resample, so it is built once.
  grid_design = np.stack(
      [np.ones(x_eval_grid.shape[0]), x_eval_grid], axis=1)
  fitted_coeffs = []
  grid_predictions = []
  for _ in range(num_bootstrap_samples):
    picks = random_state.choice(sample_count, sample_count)
    design = np.stack([np.ones(sample_count), xs[picks]], axis=1)
    solution = np.linalg.lstsq(design, ys[picks], rcond=None)[0]
    fitted_coeffs.append(solution)
    grid_predictions.append(np.dot(grid_design, solution))
  return np.vstack(fitted_coeffs), np.vstack(grid_predictions)


def get_bootstrap_cis(xs,
                      ys,
                      num_bootstrap_samples,
                      x_eval_grid,
                      seed,
                      significance_level_coeffs=95,
                      significance_level_grid=95):
  """Get percentile bootstrap confidence intervals for a linear fit.

  Args:
    xs, ys, num_bootstrap_samples, x_eval_grid, seed: forwarded to
      run_bootstrap_linreg.
    significance_level_coeffs: confidence level in percent (e.g. 95) of the
      intervals for the (intercept, slope) coefficients.
    significance_level_grid: confidence level in percent of the pointwise
      band for the fitted line on x_eval_grid.

  Returns:
    (result_coeffs, result_grid_lower, result_grid_upper):
      result_coeffs is a list of (lower, upper) tuples, one per coefficient;
      result_grid_lower / result_grid_upper are lists with the lower / upper
      band value at each point of x_eval_grid.
  """
  coeffs, y_grid_vals = run_bootstrap_linreg(xs, ys, num_bootstrap_samples,
                                             x_eval_grid, seed)
  percentile_lower_coeffs = (100.0 - significance_level_coeffs) / 2
  percentile_upper_coeffs = 100.0 - percentile_lower_coeffs
  percentile_lower_grid = (100.0 - significance_level_grid) / 2
  percentile_upper_grid = 100.0 - percentile_lower_grid
  # Column-wise percentiles over the bootstrap samples (axis 0). 'lower' and
  # 'higher' pick actual sample values so the interval is never narrowed by
  # interpolation. `method=` replaces the `interpolation=` keyword that was
  # deprecated in NumPy 1.22.
  coeffs_lower = np.percentile(
      coeffs, percentile_lower_coeffs, axis=0, method='lower')
  coeffs_upper = np.percentile(
      coeffs, percentile_upper_coeffs, axis=0, method='higher')
  result_coeffs = list(zip(coeffs_lower, coeffs_upper))
  result_grid_lower = list(np.percentile(
      y_grid_vals, percentile_lower_grid, axis=0, method='lower'))
  result_grid_upper = list(np.percentile(
      y_grid_vals, percentile_upper_grid, axis=0, method='higher'))
  return result_coeffs, result_grid_lower, result_grid_upper


# Embed fonts as TrueType (Type 42) rather than Type 3 in PDF / PostScript
# output.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Number of points on the x grid used for the y = x line, the linear fit and
# its bootstrap band in generate_accuracy_scatter_plot.
x_plotting_resolution = 200

# Shared styling for the accuracy scatter plots.
grid_linewidth = 1.5
main_linewidth = 3
label_fontsize = 18
tick_fontsize = 16
markersize = 60


def _assert_within_limits(values, limits, name):
  """Assert that every entry of `values` lies in [limits[0], limits[1]].

  Raises:
    AssertionError: naming the first offending value and the limits.
  """
  low, high = limits[0], limits[1]
  for value in values:
    assert low <= value <= high, (
        f'{name} value {value} outside axis limits [{low}, {high}]')


def generate_accuracy_scatter_plot(orig_acc, new_acc, orig_ci, new_ci,
                                   extra_orig_acc, extra_new_acc, extra_orig_ci,
                                   extra_new_ci,
                                   title, xlim, ylim,
                                   return_separate_legend=False,
                                   include_legend=False,
                                   num_bootstrap_samples=None,
                                   unit='%', num_legend_columns=1, title_y=1.0,
                                   grid_on='both'):
  """Generate a scatter plot of new vs. original accuracy with a linear fit.

  The main models (blue) are fit with a least-squares line (red; optionally
  with a bootstrap confidence band). The extra models (orange) are drawn but
  do not enter the fit. Fit statistics (scipy linregress, bootstrap CIs and
  a statsmodels OLS summary) are printed to stdout.

  Args:
    orig_acc: numpy array of original accuracies (x-axis) of the fitted models.
    new_acc: numpy array of new accuracies (y-axis), aligned with orig_acc.
    orig_ci, new_ci: error bars passed to ``ax.errorbar`` as xerr / yerr.
    extra_orig_acc, extra_new_acc, extra_orig_ci, extra_new_ci: the same
      four quantities for the extra models.
    title: axes title.
    xlim, ylim: (min, max) axis limits. Every orig_acc / new_acc value must
      lie inside them.
    return_separate_legend: if True, also build a standalone legend figure.
    include_legend: if True, draw the legend on the main axes.
    num_bootstrap_samples: if not None, draw a bootstrap CI band for the fit.
    unit: unit string shown in the axis labels.
    num_legend_columns: number of legend columns.
    title_y: vertical position of the title.
    grid_on: which grid lines to draw (passed to ``ax.grid(which=...)``).

  Returns:
    (fig, ax), or (fig, ax, fig_legend) when return_separate_legend is True.

  Raises:
    AssertionError: if a main-model accuracy lies outside xlim / ylim.
  """
  _assert_within_limits(orig_acc, xlim, 'orig_acc')
  _assert_within_limits(new_acc, ylim, 'new_acc')

  fig, ax = matplotlib.pyplot.subplots(1)
  major_locator = MultipleLocator(10)
  major_formatter = FormatStrFormatter('%d')
  minor_locator = MultipleLocator(5)
  ax.grid(which=grid_on, color='lightgray', linestyle='-',
          linewidth=grid_linewidth)
  ax.xaxis.set_major_locator(major_locator)
  ax.xaxis.set_major_formatter(major_formatter)
  ax.xaxis.set_minor_locator(minor_locator)
  ax.yaxis.set_major_locator(major_locator)
  ax.yaxis.set_major_formatter(major_formatter)
  ax.yaxis.set_minor_locator(minor_locator)
  ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)

  # The fit is computed on the main models only.
  xs = np.linspace(xlim[0], xlim[1], x_plotting_resolution)
  lin_fit = scipy.stats.linregress(orig_acc, new_acc)
  slope = lin_fit[0]
  intercept = lin_fit[1]
  lin_fit_ys = xs * slope + intercept
  print(f'Slope {slope}, intercept {intercept}, r {lin_fit[2]},'
        f' pvalue {lin_fit[3]}, stderr {lin_fit[4]}')
  if num_bootstrap_samples is not None:
    coeffs_ci, fit_lower, fit_upper = get_bootstrap_cis(
        orig_acc,
        new_acc,
        num_bootstrap_samples,
        xs,
        720257663,  # fixed seed so the bootstrap band is reproducible
        significance_level_coeffs=95,
        significance_level_grid=95)
    print(f'Bootstrap CIs: {coeffs_ci}')
  sm_model = sm.OLS(new_acc,
                    np.stack([np.ones(orig_acc.shape[0]), orig_acc], axis=1))
  sm_results = sm_model.fit()
  print(sm_results.summary())

  ax.set_xlim(xlim)
  ax.set_ylim(ylim)
  ideal_repro_line = ax.plot(
      xs,
      xs,
      linestyle='dashed',
      color='black',
      linewidth=main_linewidth,
      label='Same accuracy (y = x)')
  if num_bootstrap_samples is not None:
    ax.fill_between(
        xs,
        fit_upper,
        fit_lower,
        color='tab:red',
        alpha=0.4,
        zorder=6,
        edgecolor='none',
        linewidth=0.0)
  ax.errorbar(
      orig_acc,
      new_acc,
      xerr=orig_ci,
      yerr=new_ci,
      capsize=2,
      linewidth=2,
      ls='none',
      color='tab:blue',
      alpha=0.5,
      zorder=8)
  ax.errorbar(
      extra_orig_acc,
      extra_new_acc,
      xerr=extra_orig_ci,
      yerr=extra_new_ci,
      capsize=2,
      linewidth=2,
      ls='none',
      color='tab:orange',
      alpha=0.5,
      zorder=2)
  model_points = ax.scatter(
      orig_acc,
      new_acc,
      zorder=9,
      color='tab:blue',
      s=markersize,
      label='Model',
      alpha=0.5,
      linewidths=0)
  extra_points = ax.scatter(
      extra_orig_acc,
      extra_new_acc,
      zorder=1,
      color='tab:orange',
      s=markersize,
      label='Extra models',
      alpha=0.5,
      linewidths=0)
  fit_line = ax.plot(
      xs,
      lin_fit_ys,
      color='tab:red',
      zorder=7,
      linewidth=main_linewidth,
      label='Linear fit')
  ax.set_xlabel(f'Original test accuracy ({unit})', fontsize=label_fontsize)
  ax.set_ylabel(f'New test accuracy ({unit})', fontsize=label_fontsize)
  ax.set_title(title, fontsize=label_fontsize, y=title_y)

  # The inline and the separate legend show the same entries, including the
  # extra models that are always plotted.
  legend_handles = [ideal_repro_line[0], model_points, extra_points,
                    fit_line[0]]
  legend_labels = ['Same accuracy (y = x)', 'Model accuracy', 'Extra models',
                   'Linear fit']
  if include_legend:
    ax.legend(legend_handles, legend_labels,
              fontsize=label_fontsize,
              ncol=num_legend_columns,
              markerscale=1.5,
              frameon=False)
  fig.tight_layout()
  if return_separate_legend:
    fig_legend = matplotlib.pyplot.figure()
    fig_legend.legend(legend_handles, legend_labels,
                      fontsize=label_fontsize,
                      ncol=num_legend_columns,
                      markerscale=1.5,
                      frameon=False)
    fig_legend.tight_layout(pad=1.0)
    return fig, ax, fig_legend
  return fig, ax


def scatter_plot(acc_orig,
                 acc_new,
                 n_orig,
                 n_new,
                 extra_acc_orig,
                 extra_acc_new,
                 title,
                 xlim,
                 ylim,
                 return_separate_legend=False,
                 include_legend=False,
                 num_bootstrap_samples=None,
                 unit='%',
                 num_legend_columns=1,
                 title_y=1.0,
                 grid_on='both',
                 scale='linear'):
  """Plot new vs. original accuracy for the main and the extra models.

  Clopper-Pearson error bars are computed for every accuracy (using the
  test-set sizes n_orig and n_new). Models are ordered by name, and the data
  is drawn with the linear, probit or logit variant of the scatter plot.

  Args:
    acc_orig, acc_new: dicts mapping model name to original / new accuracy
      as a fraction. These models are the ones the linear fit uses.
    n_orig, n_new: sizes of the original and new test sets.
    extra_acc_orig, extra_acc_new: the same dicts for the extra models.
    title, xlim, ylim, return_separate_legend, include_legend,
    num_bootstrap_samples, unit, num_legend_columns, title_y, grid_on:
      forwarded unchanged to the selected plotting function.
    scale: 'linear', 'probit' or 'logit'.

  Returns:
    Whatever the selected plotting function returns.

  Raises:
    NotImplementedError: if scale is not one of the supported values.
  """
  main_data = get_ordered_plot_data(
      acc_orig, get_model_upper_and_lower_cis(acc_orig, n_orig),
      acc_new, get_model_upper_and_lower_cis(acc_new, n_new),
      models_to_ignore=[])
  extra_data = get_ordered_plot_data(
      extra_acc_orig, get_model_upper_and_lower_cis(extra_acc_orig, n_orig),
      extra_acc_new, get_model_upper_and_lower_cis(extra_acc_new, n_new),
      models_to_ignore=None)

  if scale == 'linear':
    plot_fn = generate_accuracy_scatter_plot
  elif scale == 'probit':
    plot_fn = generate_accuracy_scatter_plot_probit
  elif scale == 'logit':
    plot_fn = generate_accuracy_scatter_plot_logit
  else:
    raise NotImplementedError(f"scale '{scale}' unknown")

  # Positional order: main (acc, acc, ci, ci) then extra (acc, acc, ci, ci).
  return plot_fn(
      *main_data,
      *extra_data,
      title=title,
      xlim=xlim,
      ylim=ylim,
      return_separate_legend=return_separate_legend,
      include_legend=include_legend,
      num_bootstrap_samples=num_bootstrap_samples,
      unit=unit,
      num_legend_columns=num_legend_columns,
      title_y=title_y,
      grid_on=grid_on)


def probit_from_acc(acc):
  """Map an accuracy given in percent to the probit (inverse normal CDF) scale."""
  as_fraction = acc / 100.0
  return scipy.stats.norm.ppf(as_fraction)


def probit_cis(acc, ci):
  """Convert percent-scale error bars to probit-scale error bars.

  Args:
    acc: accuracies in percent.
    ci: 2-D array; row 0 holds the lower deltas and row 1 the upper deltas
      (in percent) around acc.

  Returns:
    Array stacking the lower (row 0) and upper (row 1) deltas measured on the
    probit scale.
  """
  center = probit_from_acc(acc)
  lower_end = probit_from_acc(acc - ci[0, :])
  upper_end = probit_from_acc(acc + ci[1, :])
  return np.stack((center - lower_end, upper_end - center), axis=0)

def logit_from_acc(acc):
  """Map an accuracy given in percent to the logit (log-odds) scale."""
  odds = acc / (100 - acc)
  return np.log(odds)


def logit_cis(acc, ci):
  """Convert percent-scale error bars to logit-scale error bars.

  Args:
    acc: accuracies in percent.
    ci: 2-D array; row 0 holds the lower deltas and row 1 the upper deltas
      (in percent) around acc.

  Returns:
    Array stacking the lower (row 0) and upper (row 1) deltas measured on the
    logit scale.
  """
  center = logit_from_acc(acc)
  lower_end = logit_from_acc(acc - ci[0, :])
  upper_end = logit_from_acc(acc + ci[1, :])
  return np.stack((center - lower_end, upper_end - center), axis=0)


def generate_accuracy_scatter_plot_probit(orig_acc,
                                          new_acc,
                                          orig_ci,
                                          new_ci,
                                          extra_orig_acc,
                                          extra_new_acc,
                                          extra_orig_ci,
                                          extra_new_ci,
                                          title,
                                          xlim,
                                          ylim,
                                          return_separate_legend=False,
                                          include_legend=False,
                                          num_bootstrap_samples=None,
                                          unit='%',
                                          num_legend_columns=1,
                                          title_y=1.0,
                                          grid_on='both'):
  """Generate the accuracy scatter plot on probit-transformed axes.

  Accuracies, their error bars and the axis limits (all in percent) are
  mapped to the probit scale with probit_from_acc / probit_cis, then plotted
  with generate_accuracy_scatter_plot (the linear fit is therefore done in
  probit space).

  Returns:
    The result of generate_accuracy_scatter_plot: (fig, ax), or
    (fig, ax, fig_legend) when return_separate_legend is True.
  """
  orig_acc_probit = probit_from_acc(orig_acc)
  new_acc_probit = probit_from_acc(new_acc)
  orig_ci_probit = probit_cis(orig_acc, orig_ci)
  new_ci_probit = probit_cis(new_acc, new_ci)

  extra_orig_acc_probit = probit_from_acc(extra_orig_acc)
  extra_new_acc_probit = probit_from_acc(extra_new_acc)
  extra_orig_ci_probit = probit_cis(extra_orig_acc, extra_orig_ci)
  extra_new_ci_probit = probit_cis(extra_new_acc, extra_new_ci)

  xlim = [probit_from_acc(xlim[0]), probit_from_acc(xlim[1])]
  ylim = [probit_from_acc(ylim[0]), probit_from_acc(ylim[1])]
  # Return the figure(s) so scatter_plot(scale='probit') does not yield None.
  return generate_accuracy_scatter_plot(
      orig_acc_probit,
      new_acc_probit,
      orig_ci_probit,
      new_ci_probit,
      extra_orig_acc_probit,
      extra_new_acc_probit,
      extra_orig_ci_probit,
      extra_new_ci_probit,
      title,
      xlim,
      ylim,
      return_separate_legend=return_separate_legend,
      include_legend=include_legend,
      num_bootstrap_samples=num_bootstrap_samples,
      unit=unit,
      num_legend_columns=num_legend_columns,
      title_y=title_y,
      grid_on=grid_on)

def generate_accuracy_scatter_plot_logit(orig_acc,
                                          new_acc,
                                          orig_ci,
                                          new_ci,
                                          extra_orig_acc,
                                          extra_new_acc,
                                          extra_orig_ci,
                                          extra_new_ci,
                                          title,
                                          xlim,
                                          ylim,
                                          return_separate_legend=False,
                                          include_legend=False,
                                          num_bootstrap_samples=None,
                                          unit='%',
                                          num_legend_columns=1,
                                          title_y=1.0,
                                          grid_on='both'):
  """Generate the accuracy scatter plot on logit-transformed axes.

  Accuracies, their error bars and the axis limits (all in percent) are
  mapped to the logit scale with logit_from_acc / logit_cis, then plotted
  with generate_accuracy_scatter_plot (the linear fit is therefore done in
  logit space).

  Returns:
    The result of generate_accuracy_scatter_plot: (fig, ax), or
    (fig, ax, fig_legend) when return_separate_legend is True.
  """
  orig_acc_logit = logit_from_acc(orig_acc)
  new_acc_logit = logit_from_acc(new_acc)
  orig_ci_logit = logit_cis(orig_acc, orig_ci)
  new_ci_logit = logit_cis(new_acc, new_ci)

  extra_orig_acc_logit = logit_from_acc(extra_orig_acc)
  extra_new_acc_logit = logit_from_acc(extra_new_acc)
  extra_orig_ci_logit = logit_cis(extra_orig_acc, extra_orig_ci)
  extra_new_ci_logit = logit_cis(extra_new_acc, extra_new_ci)

  xlim = [logit_from_acc(xlim[0]), logit_from_acc(xlim[1])]
  ylim = [logit_from_acc(ylim[0]), logit_from_acc(ylim[1])]

  # Return the figure(s) so scatter_plot(scale='logit') does not yield None.
  return generate_accuracy_scatter_plot(
      orig_acc_logit,
      new_acc_logit,
      orig_ci_logit,
      new_ci_logit,
      extra_orig_acc_logit,
      extra_new_acc_logit,
      extra_orig_ci_logit,
      extra_new_ci_logit,
      title,
      xlim,
      ylim,
      return_separate_legend=return_separate_legend,
      include_legend=include_legend,
      num_bootstrap_samples=num_bootstrap_samples,
      unit=unit,
      num_legend_columns=num_legend_columns,
      title_y=title_y,
      grid_on=grid_on)


def acc_and_ci(acc_orig_dict,
               acc_new_dict,
               n_orig,
               n_new,
               scale='linear'):
  """Compute ordered accuracies and Clopper-Pearson error bars, optionally rescaled.

  Args:
    acc_orig_dict: dict mapping model name to original test accuracy
      (a value between 0 and 1).
    acc_new_dict: dict mapping model name to new test accuracy
      (a value between 0 and 1).
    n_orig: size of the original test set.
    n_new: size of the new test set.
    scale: 'linear' (percent), 'probit' or 'logit'.

  Returns:
    (acc_orig, acc_new, ci_orig, ci_new) as numpy arrays with models sorted
    by name. Accuracies are in percent for 'linear', otherwise on the chosen
    scale. Each ci array holds lower deltas in row 0 and upper deltas in
    row 1.

  Raises:
    NotImplementedError: for an unknown scale.
  """
  ordered = get_ordered_plot_data(
      acc_orig_dict, get_model_upper_and_lower_cis(acc_orig_dict, n_orig),
      acc_new_dict, get_model_upper_and_lower_cis(acc_new_dict, n_new),
      models_to_ignore=[])
  if scale == 'linear':
    return ordered

  if scale == 'probit':
    to_scale, ci_to_scale = probit_from_acc, probit_cis
  elif scale == 'logit':
    to_scale, ci_to_scale = logit_from_acc, logit_cis
  else:
    raise NotImplementedError(f"Scale {scale} not implemented. Choose linear, probit or logit.")

  acc_orig, acc_new, ci_orig, ci_new = ordered
  # Error bars are converted around the untransformed (percent) accuracies.
  scaled_ci_orig = ci_to_scale(acc_orig, ci_orig)
  scaled_ci_new = ci_to_scale(acc_new, ci_new)
  return to_scale(acc_orig), to_scale(acc_new), scaled_ci_orig, scaled_ci_new


## CIFAR 
def get_cifar_acc_orig(root='./data'):
  """Load the decoded JSON in ``<root>/accuracies_cifar-10.json`` (CIFAR-10 accuracies)."""
  json_path = os.path.join(root, 'accuracies_cifar-10.json')
  with open(json_path, 'r') as handle:
    accuracies = json.load(handle)
  return accuracies

def get_cifar_acc_new(root='./data'):
  """Load the decoded JSON in ``<root>/accuracies_cifar-10.1_v6.json`` (CIFAR-10.1 v6 accuracies)."""
  json_path = os.path.join(root, 'accuracies_cifar-10.1_v6.json')
  with open(json_path, 'r') as handle:
    accuracies = json.load(handle)
  return accuracies

def get_cifar_shallow_orig(root='./data'):
  """Load original-test accuracies of shallow models from ``shallow_models.json``.

  Each JSON record needs a 'hyperparameters' dict and a 'cifar10-test'
  value. The model name is the hyperparameters joined as 'k1-v1_k2-v2...'.

  Returns:
    dict mapping that name to the record's 'cifar10-test' value.
  """
  with open(os.path.join(root, 'shallow_models.json'), 'r') as data_file:
    records = json.load(data_file)

  def _name(record):
    return "_".join(f"{key}-{value}"
                    for key, value in record['hyperparameters'].items())

  return {_name(record): record['cifar10-test'] for record in records}

def get_cifar_shallow_new(root='./data'):
  """Load new-test accuracies of shallow models from ``shallow_models.json``.

  Each JSON record needs a 'hyperparameters' dict and a 'cifar10.1-v6'
  value. The model name is the hyperparameters joined as 'k1-v1_k2-v2...'.

  Returns:
    dict mapping that name to the record's 'cifar10.1-v6' value.
  """
  with open(os.path.join(root, 'shallow_models.json'), 'r') as data_file:
    records = json.load(data_file)

  def _name(record):
    return "_".join(f"{key}-{value}"
                    for key, value in record['hyperparameters'].items())

  return {_name(record): record['cifar10.1-v6'] for record in records}


def get_cifar_rf_new(root='./random_features'):
  """Load new-test accuracies of random-feature models, one per width.

  For each width, the `regs` value (presumably a regularization strength)
  with the highest 'test_10' score is selected (first one on ties), and that
  entry's 'test_101' value is returned.

  Returns:
    dict mapping 'rf_<width>' to the selected 'test_101' value.
  """
  # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
  pickle_path = os.path.join(root, 'random_features_results.pickle')
  with open(pickle_path, 'rb') as handle:
    results = pickle.load(handle)
  widths = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
  regs = [1e-08, 0.0001, 0.01, 1.0, 10.0, 100.0, 1000.0, 10000.0]

  def best_reg(width):
    return regs[np.argmax([results[width, reg]['test_10'] for reg in regs])]

  return {f"rf_{width}": results[width, best_reg(width)]['test_101']
          for width in widths}

def get_cifar_rf_orig(root='./random_features'):
  """Load original-test accuracies of random-feature models, one per width.

  For each width, the `regs` value (presumably a regularization strength)
  with the highest 'test_10' score is selected (first one on ties), and that
  best 'test_10' value is returned.

  Returns:
    dict mapping 'rf_<width>' to the selected 'test_10' value.
  """
  # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
  pickle_path = os.path.join(root, 'random_features_results.pickle')
  with open(pickle_path, 'rb') as handle:
    results = pickle.load(handle)
  widths = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
  regs = [1e-08, 0.0001, 0.01, 1.0, 10.0, 100.0, 1000.0, 10000.0]

  def best_reg(width):
    return regs[np.argmax([results[width, reg]['test_10'] for reg in regs])]

  return {f"rf_{width}": results[width, best_reg(width)]['test_10']
          for width in widths}

def cifar_fit_in_linear_space(x, scale, old=False):
  """Predict CIFAR-10.1 accuracy from CIFAR-10 accuracy `x` (fractions).

  The prediction is a line fitted in probit or logit space, mapped back to
  accuracy space.

  Args:
    x: CIFAR-10 accuracy (scalar or array).
    scale: 'probit' or 'logit'.
    old: probit only. If True, use the fit on just the high-performing
      models instead of the fit that also includes shallow and random-feature
      models. Ignored for 'logit'.

  Raises:
    NotImplementedError: for any other scale.
  """
  if scale == 'probit':
    # An alternative probit fit that adds shallow models (but not random
    # features) had slope 0.8722404500003685, intercept -0.2676642231021804.
    if old:
      slope, intercept = 0.9552800496430276, -0.3901688525827589
    else:
      slope, intercept = 0.8689117034240106, -0.2739049383107107
    return scipy.stats.norm.cdf(scipy.stats.norm.ppf(x) * slope + intercept)
  if scale == 'logit':
    slope, intercept = 0.8318250483081091, -0.4736061310131141
    return inverse_logit(logit(x) * slope + intercept)
  raise NotImplementedError(f"scale {scale} not implemented")

def cifar_fit_in_probit_space(x):
  """CIFAR-10 -> CIFAR-10.1 linear fit in probit space (shallow + random-feature models)."""
  slope = 0.8689117034240106
  intercept = -0.2739049383107107
  return slope * x + intercept

def cifar_fit_in_logit_space(x):
  """CIFAR-10 -> CIFAR-10.1 linear fit in logit space."""
  slope = 0.8318250483081091
  intercept = -0.4736061310131141
  return slope * x + intercept

def plot_fit_in_linear_space(xmin=0, xmax=1, n=1000, *, scale='logit',
                             **kwargs):
  """Plot the CIFAR-10 -> CIFAR-10.1 fit mapped back to accuracy space.

  Both axes are on a 0-1 scale (not percent).

  Args:
    xmin, xmax: range of CIFAR-10 accuracies (fractions) to plot over.
    n: number of grid points.
    scale: space the fit was made in, 'logit' (the previous hard-coded
      behaviour) or 'probit'. Forwarded to cifar_fit_in_linear_space.
    **kwargs: forwarded to matplotlib.pyplot.plot.
  """
  xs = np.linspace(xmin, xmax, n)
  ys = cifar_fit_in_linear_space(xs, scale=scale)
  matplotlib.pyplot.plot(xs, ys, **kwargs)

def plot_fit_in_probit_space(xmin=0.01, xmax=0.99, n=1000, **kwargs):
  """Plot the CIFAR probit-space fit for x from probit(xmin) to probit(xmax).

  Extra keyword arguments are forwarded to matplotlib.pyplot.plot.
  """
  probit_grid = np.linspace(probit(xmin), probit(xmax), n)
  matplotlib.pyplot.plot(probit_grid, cifar_fit_in_probit_space(probit_grid),
                         **kwargs)

def plot_fit_in_logit_space(xmin=0.01, xmax=0.99, n=1000, **kwargs):
  """Plot the CIFAR logit-space fit for x from logit(xmin) to logit(xmax).

  Extra keyword arguments are forwarded to matplotlib.pyplot.plot.
  """
  logit_grid = np.linspace(logit(xmin), logit(xmax), n)
  matplotlib.pyplot.plot(logit_grid, cifar_fit_in_logit_space(logit_grid),
                         **kwargs)

def cifar_effective_robustness(x, y, scale, old=False):
  """Return y minus the CIFAR fit's predicted new accuracy at x.

  x and y are converted to numpy arrays first. `scale` and `old` are passed
  to cifar_fit_in_linear_space.
  """
  x_arr, y_arr = np.array(x), np.array(y)
  predicted = cifar_fit_in_linear_space(x_arr, scale=scale, old=old)
  return y_arr - predicted


## ImageNet 
def get_imagenet_acc_orig(root='./data'):
  """Load the decoded JSON in ``<root>/accuracies_imagenet_v1.json``."""
  json_path = os.path.join(root, 'accuracies_imagenet_v1.json')
  with open(json_path, 'r') as handle:
    accuracies = json.load(handle)
  return accuracies

def get_imagenet_new_orig(root='./data'):
  """Load the decoded JSON in ``<root>/accuracies_imagenet_v2.json``."""
  json_path = os.path.join(root, 'accuracies_imagenet_v2.json')
  with open(json_path, 'r') as handle:
    accuracies = json.load(handle)
  return accuracies

def imagenet_fit_in_linear_space(x, scale):
  """Predict ImageNet-v2 accuracy from ImageNet-v1 accuracy `x` (fractions).

  The prediction is a line fitted in probit or logit space, mapped back to
  accuracy space.

  Raises:
    NotImplementedError: if scale is neither 'probit' nor 'logit'.
  """
  if scale == 'probit':
    slope, intercept = 0.9470801297849037, -0.30906206102195954
    return scipy.stats.norm.cdf(scipy.stats.norm.ppf(x) * slope + intercept)
  if scale == 'logit':
    slope, intercept = 0.9225230091337429, -0.48963535536951863
    return inverse_logit(logit(x) * slope + intercept)
  raise NotImplementedError(f"scale {scale} not implemented")

def imagenet_fit_in_probit_space(x):
  """ImageNet-v1 -> ImageNet-v2 linear fit in probit space."""
  slope = 0.9470801297849037
  intercept = -0.30906206102195954
  return slope * x + intercept

def imagenet_fit_in_logit_space(x):
  """ImageNet-v1 -> ImageNet-v2 linear fit in logit space."""
  slope = 0.9225230091337429
  intercept = -0.48963535536951863
  return slope * x + intercept

def plot_imagenet_fit_in_linear_space(xmin=0, xmax=1, n=1000, *,
                                      scale='logit', **kwargs):
  """Plot the ImageNet fit mapped back to accuracy space, in percent.

  Args:
    xmin, xmax: range of ImageNet-v1 accuracies (fractions) to plot over.
      Both axes are drawn multiplied by 100.
    n: number of grid points.
    scale: space the fit was made in, 'logit' (the previous hard-coded
      behaviour) or 'probit'. Forwarded to imagenet_fit_in_linear_space.
    **kwargs: forwarded to matplotlib.pyplot.plot.
  """
  xs = np.linspace(xmin, xmax, n)
  ys = imagenet_fit_in_linear_space(xs, scale=scale)
  matplotlib.pyplot.plot(xs * 100, ys * 100, **kwargs)

def plot_imagenet_fit_in_probit_space(xmin=0.01, xmax=0.99, n=1000, **kwargs):
  """Plot the ImageNet probit-space fit for x from probit(xmin) to probit(xmax).

  Extra keyword arguments are forwarded to matplotlib.pyplot.plot.
  """
  probit_grid = np.linspace(probit(xmin), probit(xmax), n)
  matplotlib.pyplot.plot(probit_grid,
                         imagenet_fit_in_probit_space(probit_grid), **kwargs)

def plot_imagenet_fit_in_logit_space(xmin=0.001, xmax=0.99, n=1000, **kwargs):
  """Plot the ImageNet logit-space fit for x from logit(xmin) to logit(xmax).

  Extra keyword arguments are forwarded to matplotlib.pyplot.plot.
  """
  logit_grid = np.linspace(logit(xmin), logit(xmax), n)
  matplotlib.pyplot.plot(logit_grid,
                         imagenet_fit_in_logit_space(logit_grid), **kwargs)

def imagenet_effective_robustness(x, y, scale):
  """Return y minus the ImageNet fit's predicted new accuracy at x.

  x and y are converted to numpy arrays first. `scale` is passed to
  imagenet_fit_in_linear_space.
  """
  x_arr, y_arr = np.array(x), np.array(y)
  predicted = imagenet_fit_in_linear_space(x_arr, scale=scale)
  return y_arr - predicted


def probit(acc):
  """Inverse standard-normal CDF of `acc` (a fraction in (0, 1))."""
  quantile = scipy.stats.norm.ppf(acc)
  return quantile

def cdf(acc):
  """Standard-normal CDF; the inverse of `probit`."""
  probability = scipy.stats.norm.cdf(acc)
  return probability

def logit(x):
  """Log-odds log(x / (1 - x)) of a fraction (scalar or array-like)."""
  p = np.array(x)
  odds = p / (1 - p)
  return np.log(odds)

def sigmoid(x):
  """Logistic function 1 / (1 + exp(-x))."""
  denominator = 1 + np.exp(-x)
  return 1 / denominator

def inverse_logit(x):
  """Inverse of `logit`: maps log-odds back to a fraction via `sigmoid`."""
  fraction = sigmoid(x)
  return fraction

def add_probit_ticks(ticks=None):
  """Label the current axes' probit-scaled x and y axes with accuracy percentages.

  Args:
    ticks: accuracies as fractions in (0, 1) at which to place ticks on both
      axes. Defaults to 1, 10, 20, ..., 90, 95 and 99 percent.
  """
  if ticks is None:
    # The former x-axis default listed 0.1 twice and omitted 0.2.
    ticks = [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]
  ticks = list(ticks)
  labels = [f"{100*i:.0f}" for i in ticks]
  plt.xticks(probit(ticks), labels)
  plt.yticks(probit(ticks), labels)

def add_logit_ticks(ticks=None):
  """Label the current axes' logit-scaled x and y axes with accuracy percentages.

  Args:
    ticks: accuracies as fractions in (0, 1) at which to place ticks on both
      axes. Defaults to 1, 10, 20, ..., 90, 95 and 99 percent.
  """
  if ticks is None:
    # The former x-axis default listed 0.1 twice and omitted 0.2.
    ticks = [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]
  ticks = list(ticks)
  labels = [f"{100*i:.0f}" for i in ticks]
  plt.xticks(logit(ticks), labels)
  plt.yticks(logit(ticks), labels)


def binned_average(X, Y, num_bins=100):
  """Calculate the X-binned mean and standard deviation of Y.

  Args:
    - X: Array of values along the x-axis that will be binned.
    - Y: Array of same shape as X with corresponding y-axis values.
    - num_bins: Number of equal-width bins, or a sequence of bin edges
      (passed to scipy.stats.binned_statistic as `bins`).

  Returns:
    - bin_centers: Array with the midpoint of each bin.
    - bin_means: Array of means of all Y entries in each bin.
    - bin_std: Array of standard deviations of all Y entries in each bin.
  """
  x = np.asarray(X).flatten()
  y = np.asarray(Y).flatten()
  bin_means, bin_edges, _ = scipy.stats.binned_statistic(x, y, bins=num_bins)
  bin_std, _, _ = scipy.stats.binned_statistic(
      x, y, bins=num_bins, statistic='std')

  # Midpoint of each (left, right) edge pair. Unlike adding half of the first
  # bin's width, this is also correct for non-uniform bin edges.
  bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
  return bin_centers, bin_means, bin_std