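"""Plot training-loss curves for the MNIST learning-rate tuning runs.

Compares the RMSprop baseline against RMSprop-APO and writes the figure to
figures/mnist_lr/mnist_lr_plot.{pdf,png}.
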
Example
-------
python plot_mnist_lr.py
"""
import os
import pdb
import csv
from collections import defaultdict

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
sns.set_palette('muted')


def load_log(exp_dir, log_filename='iteration_log.csv',
             int_keys=('global_iteration', 'iteration', 'epoch')):
  """Load a CSV training log into a column-oriented dict.

  Each column of ``<exp_dir>/<log_filename>`` becomes a list with one parsed
  value per data row. Columns named in ``int_keys`` are parsed with ``int``;
  every other column is parsed with ``float``. Cells that cannot be parsed
  (empty strings, non-numeric text, or fields missing from a short row) are
  stored as ``None`` so that all columns stay aligned row-by-row.

  Args:
    exp_dir: Directory containing the log file.
    log_filename: Name of the CSV file inside ``exp_dir``.
    int_keys: Column names whose values are parsed as integers.

  Returns:
    A ``defaultdict(list)`` mapping column name -> list of int/float/None.
  """
  int_keys = frozenset(int_keys)
  result_dict = defaultdict(list)
  with open(os.path.join(exp_dir, log_filename), newline='') as csvfile:
    for row in csv.DictReader(csvfile):
      for key, value in row.items():
        convert = int if key in int_keys else float
        try:
          result_dict[key].append(convert(value))
        except (TypeError, ValueError):
          # ValueError: empty or non-numeric cell. TypeError: DictReader
          # fills fields missing from a short row with None (restval).
          result_dict[key].append(None)
  return result_dict


# Experiments to compare: (legend label, matplotlib line style, run directory).
# Each run directory is read for an ``epoch_log.csv`` providing ``epoch`` and
# ``train_loss`` columns.
plot_dirs = [
    ('RMSprop', '-', 'experiments/mnist_baseline/dset:mnist-model:mlp-nl:2-b:rmsprop-m:rmsprop-bs:100-ilr:0.0001-mlr:0.1-lam:0-mstp:0-mint:10-ag:0-wd-0-val:0-ep:100-fac:0.2-dat:60,120,160-seed:11'),
    ('RMSprop-APO', '--', 'experiments/mnist_apo/dset:mnist-model:mlp-nl:2-b:rmsprop-m:rmsprop-bs:100-ilr:0.0001-mlr:0.1-lam:1e-05-mstp:1-mint:1-ag:0-wd-0-val:0-ep:100-fac:0.2-dat:60,120,160-seed:11'),
]

fig = plt.figure()

for name, lstyle, exp_path in plot_dirs:
  stats = load_log(exp_path, log_filename='epoch_log.csv')
  # All runs share one colour; the line style is what tells them apart.
  plt.plot(stats['epoch'], stats['train_loss'], label=name,
           linewidth=3, linestyle=lstyle, color='r', alpha=0.6)

plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.xlabel('Epoch', fontsize=20)
plt.ylabel('Training Loss', fontsize=20)
plt.yscale('log')
plt.legend(fontsize=18, fancybox=True, framealpha=0.3)

# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
out_dir = 'figures/mnist_lr'
os.makedirs(out_dir, exist_ok=True)

# Save the same figure as both a vector (PDF) and a raster (PNG) image.
for ext in ('pdf', 'png'):
  plt.savefig(os.path.join(out_dir, 'mnist_lr_plot.' + ext),
              bbox_inches='tight', pad_inches=0)
plt.close(fig)
