import numpy as np
import matplotlib.pyplot as plt
from metrics import CI


# Uncomment for md table overview
# print(f"| Dataset | Model      | Test Accuracy | Kernel Target Alignment |")
# print(f"| ------- | ---------- | ------------- | ----------------------- |")
# prnt = lambda mean, ci: f" | {mean[-1]:.3f} ± {(ci[2][-1]-ci[1][-1])/2:.3f}"
def prnt(mean, ci):
  """Format the last value of ``mean`` and its CI half-width as a LaTeX table cell.

  ``ci[1]`` and ``ci[2]`` are read as the lower/upper bounds (presumably the
  output of ``CI`` -- TODO confirm); only their last elements are used.
  Example result: `` & $0.50 \\pm 0.10$``.
  """
  half_width = (ci[2][-1] - ci[1][-1]) / 2
  # "\\pm" is escaped explicitly: "\p" is an invalid escape sequence in Python.
  return f" & ${mean[-1]:.2f} \\pm {half_width:.2f}$"

def plot(label, color, kta, const, data, steps, results):
  """Draw one run's mean curve and CI band onto the module-level ``axes``.

  'Test/Accuracy' goes on ``axes[0]``. When ``kta`` is truthy, the negated
  'Train/Loss' goes on ``axes[1]``.

  Args:
    label: legend label for the curve.
    color: line and band colour.
    kta: whether to also plot the negated 'Train/Loss' metric.
    const: if truthy, each run holds a single value. Each value is duplicated
      into a 2-point series and drawn as a dashed line over ``(0, steps)``.
    data: run configuration. Accepted for call compatibility, unused here.
    steps: x extent used for constant runs.
    results: mapping of metric name -> array of per-run values.
  """
  def expand(values):
    # Constant runs: (runs,) -> (runs, 2) so both x endpoints share the value.
    return np.repeat(np.expand_dims(values, axis=1), 2, axis=1) if const else values

  transforms = [(expand, 'Test/Accuracy')]
  if kta:
    transforms.append((lambda x: expand(-x), 'Train/Loss'))

  for i, (transform, key) in enumerate(transforms):
    series = transform(results[key])  # computed once, reused for mean and CI
    mean, ci = series.mean(axis=0), CI(series)
    ax = axes[i]
    if const:
      ax.plot((0, steps), mean, label=label, color=color, linestyle='--')
      # assumes CI returns (x, lower, upper); only the bounds are used here -- TODO confirm
      ax.fill_between((0, steps), *ci[1:], color=color, alpha=.1)
    else:
      ax.plot(mean, label=label, color=color)
      ax.fill_between(*ci, color=color, alpha=.1)


# def write(label, color, kta, const, data, steps, results):
def write(label, *args, suffix=''):
  """Serialise one run to ``logs/<dataset>[<n_classes>]/<label><suffix>.json``.

  ``args`` are positional: color, kta, const, data, steps, results (the same
  order as ``plot``'s parameters after ``label``). Each value in ``results``
  must provide ``.tolist()`` (e.g. a NumPy array). Values are stored as nested
  lists so that the JSON can be reloaded.

  The directory name gets the class count appended when
  ``data['data_kwargs']['n_classes']`` is present and not 2.
  """
  import json, os
  record = dict(zip(('color', 'kta', 'const', 'data', 'steps', 'results'), args))
  record['results'] = {metric: values.tolist() for metric, values in record['results'].items()}
  config = record['data']
  classes = config['data_kwargs'].get('n_classes', 2)
  path = f'logs/{config["dataset"]}{classes if classes != 2 else ""}'
  # exist_ok avoids the race of a separate exists-then-create check.
  os.makedirs(path, exist_ok=True)
  with open(f'{path}/{label}{suffix}.json', 'w') as f:
    json.dump(record, f)

  
def plot_results(data, exclude_labels=None):
  """Plot every run in ``data`` onto the module-level figure and save it as a PDF.

  Args:
    data: mapping of run label -> keyword arguments for ``plot`` (color, kta,
      const, data, steps, results).
    exclude_labels: labels to skip. ``None`` (the default) skips nothing.

  Reads module globals ``fig``, ``axes``, ``dataset`` and ``portrait``, which
  the ``__main__`` block sets up. Writes
  ``plots/training/<dataset>[-p].pdf``. For datasets whose name contains
  'ablation', only the first axis is saved, with a legend.
  """
  import os  # os is otherwise only imported under __main__
  excluded = [] if exclude_labels is None else exclude_labels

  axes[0].set_title('Test Accuracy')
  axes[1].set_title('KTA')
  filtered_data = {f: d for f, d in data.items() if f not in excluded}
  # Reverse order: presumably so the first (preferred) entries are drawn last, i.e. on top.
  for label, kwargs in reversed(list(filtered_data.items())):
    plot(label, **kwargs)

  for ax in axes:
    ax.title.set_fontsize(16)
    ax.tick_params(axis='both', which='major', labelsize=12)
    # Snap y-limits outward to the nearest 0.1, clamped to [0, 1].
    ymin, ymax = ax.get_ylim()
    ax.set_ylim([max(np.floor(ymin * 10) / 10, 0), min(np.ceil(ymax * 10) / 10, 1)])
    ax.set_xlim([0, 100])
  fig.tight_layout()
  os.makedirs('plots/training', exist_ok=True)

  path = f"plots/training/{dataset}{'-p' if portrait else ''}.pdf"
  if 'ablation' in dataset:
    axes[0].legend(loc="lower right")
    # Crop the saved PDF to the first axis (slightly widened).
    extent = axes[0].get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    fig.savefig(path, bbox_inches=extent.expanded(1.15, 1))
  else:
    fig.savefig(path)


# def plot(label, color, kta, const, data, steps, results):

def print_execution(data, dataset, noisy=False):
  """Print a LaTeX table header and one row of final-step mean ± margin per model.

  Args:
    data: mapping of model label -> run dict with 'const' and 'results' keys.
    dataset: dataset name. It is looked up in a LaTeX display table and falls
      back to the raw name if it is not listed.
    noisy: if True, report 'Simulation/Accuracy' instead of 'Test/Accuracy'.
      The 'Linear    ' and RBF models are dropped, and each cell is wrapped in
      ``\\itf{...}`` with the mean 'Simulation/Depth' appended.

  The margin is the upper CI bound at the last step minus the mean there
  (``CI(...)[2][-1]``).
  """
  def fill(values, entry):
    # Constant runs hold one value per run; duplicate to (runs, 2) as plot() does.
    return np.repeat(np.expand_dims(values, axis=1), 2, axis=1) if entry['const'] else values

  def mean(entry, key):
    return fill(entry['results'][key], entry).mean(axis=0)[-1]

  def moe(entry, key):
    return CI(fill(entry['results'][key], entry))[2][-1] - mean(entry, key)

  key = 'Simulation/Accuracy' if noisy else 'Test/Accuracy'
  if noisy:
    data = {k: v for k, v in data.items() if 'Linear    ' != k and 'RBF' not in k}
  fmt = {
    'moons': '\\texttt{moons} ($2$)',
    'circles': '\\texttt{circles} ($2$)',
    'bank': '\\texttt{bank} ($16$)',
    'mnist10': '\\texttt{MNIST} ($784$)',
    'cifar10': '\\texttt{CIFAR10} ($3072$)'
  }
  # Unlisted datasets (e.g. ablations) fall back to their raw name instead of raising KeyError.
  name = fmt.get(dataset, dataset)
  # Backslashes are escaped explicitly: "\h" and "\i" are invalid escape sequences.
  print(f"Dataset ($d$) & {' & '.join(data)} \\\\ \\hline \n{name} ", end="& ")

  def end(entry):
    # Noisy cells append the mean 'Simulation/Depth' and close the \itf{ group.
    return f"({mean(entry, 'Simulation/Depth'):.0f})" + "}$" if noisy else '$'

  pre = '$\\itf{' if noisy else '$'
  cells = [f"{pre}{mean(d, key):.2f} \\pm {moe(d, key):.2f} {end(d)}" for d in data.values()]
  print(' & '.join(cells) + " \\\\ \\hline")

# Load data and generate plots 
if __name__ == "__main__":
  import json
  import os
  import sys

  # fig, axes, dataset and portrait are read as globals by plot_results.
  portrait = True
  if portrait:
    fig, axes = plt.subplots(2, 1, figsize=(4, 6))
  else:
    fig, axes = plt.subplots(1, 2, figsize=(8, 3))

  # Explicit check instead of assert, which is stripped under `python -O`.
  if len(sys.argv) < 2:
    sys.exit("Please provide a dataset name as an argument.")
  dataset = sys.argv[1]
  p = f"logs/{dataset}/"

  def load(path):
    # Context manager closes the file handle promptly.
    with open(path, 'r') as fh:
      return json.load(fh)

  data = {f[:-len(".json")]: dict(load(p + f)) for f in os.listdir(p) if f.endswith(".json")}
  # write() stores results as nested lists; convert each metric back to an array.
  data = {f: {**d, 'results': {k: np.array(v) for k, v in d['results'].items()}} for f, d in data.items()}

  # Display order. A label ranks at the first listed pattern it contains as a
  # substring (so 'HEE Linear' ranks before plain 'HEE'). Labels matching none
  # of the patterns get rank 7, tying with 'Linear'.
  order = ['QGK (ours)', 'QEK', 'HEE Linear', 'QGK Static', 'HEE PCA', 'HEE', 'RBF', 'Linear']

  def rank(item):
    return next((i for i, o in enumerate(order) if o.strip() in item[0]), 7)

  data = dict(sorted(data.items(), key=rank))

  plot_results(data, ['HEE-D     ', 'QEK-N     '])
  print_execution(data, dataset)
  print_execution(data, dataset, noisy=True)

