# -*- coding: utf-8 -*-
"""plane_plot.ipynb

Automatically generated by Colaboratory.

"""

# Commented out IPython magic to ensure Python compatibility.
### Visualizing grid plane
import argparse
import numpy as np
import os
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import seaborn as sns

# %matplotlib inline

# Command-line interface: where the experiment outputs live and how the two
# contour plots (training loss, test error) are color-normalized.
parser = argparse.ArgumentParser(description='Plane visualization')

# Components of the experiment output directory name.
parser.add_argument('--result_path', default='result', type=str)
parser.add_argument('--experiment_name', default='visualization', type=str)
parser.add_argument('--model_name', default='FC', type=str)
parser.add_argument('--dataset_name', default='MNIST', type=str)

# Color-normalization parameters for the training-loss plot.
parser.add_argument(
    '--tr_vmax', default=0.4, type=float,
    help='color normalization parameter vmax for training loss visualization',
)
parser.add_argument(
    '--tr_log_alpha', default=-5.0, type=float,
    help='color normalization parameter log_alpha for training loss visualization',
)

# Color-normalization parameters for the test-error plot.
parser.add_argument(
    '--te_vmax', default=8.0, type=float,
    help='color normalization parameter vmax for test error visualization',
)
parser.add_argument(
    '--te_log_alpha', default=-5.0, type=float,
    help='color normalization parameter log_alpha for test error visualization',
)

# Boolean switches (store_true already defaults to False).
parser.add_argument('--random_initialization_plot_1', action='store_true')
parser.add_argument('--random_initialization_plot_2', action='store_true')

args = parser.parse_args()

# Outputs are read from <result_path>/<experiment_name>_<model_name>_<dataset_name>,
# optionally descending into a random-initialization sub-directory
# (random_initialization_plot_1 takes precedence if both flags are set).
output_path = os.path.join(
    args.result_path,
    '_'.join([args.experiment_name, args.model_name, args.dataset_name]),
)
if args.random_initialization_plot_1:
    output_path = os.path.join(output_path, 'random_initialization_plot_1')
elif args.random_initialization_plot_2:
    output_path = os.path.join(output_path, 'random_initialization_plot_2')

# Archive whose arrays are plotted: grid, tr_loss, te_err,
# bend_coordinates and fused_model_coordinates.
file = np.load(os.path.join(output_path, 'plane.npz'))

# Global plot styling.
#
# seaborn's set_style() rewrites a fixed set of rc keys, including
# font.family and font.sans-serif.  It is therefore applied FIRST so that
# the explicit matplotlib settings below are not silently overwritten.
sns.set_style('whitegrid')

# Render all text through LaTeX, with sans-serif math via the sansmath package.
matplotlib.rc('text', usetex=True)
# text.latex.preamble must be a single string: passing a list was deprecated
# in Matplotlib 3.3 and later rejected.  Joining with '\n' matches how older
# versions concatenated the list entries.
matplotlib.rc('text.latex', preamble='\n'.join([r'\usepackage{sansmath}', r'\sansmath']))
matplotlib.rc('font', **{'family': 'sans-serif', 'sans-serif': ['DejaVu Sans']})

# Extra padding between tick marks and their labels, and thinner grid lines.
matplotlib.rc('xtick.major', pad=12)
matplotlib.rc('ytick.major', pad=12)
matplotlib.rc('grid', linewidth=0.8)

class LogNormalize(colors.Normalize):
    """Colormap normalization on a logarithmic scale measured from ``vmin``.

    A value ``v`` maps to
    ``0.9 * (max(log(v - vmin), log_alpha) - log_alpha) / (log(vmax - vmin) - log_alpha)``.
    So ``vmax`` maps to 0.9, and anything whose log-distance from ``vmin`` is
    below ``log_alpha`` maps to 0.  Values ``<= vmin`` are masked by
    ``np.ma.log``.  The ``clip`` argument of ``__call__`` is accepted but ignored.
    """

    def __init__(self, vmin=None, vmax=None, clip=None, log_alpha=None):
        self.log_alpha = log_alpha
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        floor = self.log_alpha
        # Log-distance from vmin, floored at log_alpha (masked where undefined).
        clamped = np.ma.maximum(np.ma.log(value - self.vmin), floor)
        # Log-distance range covered between the floor and vmax.
        span = np.log(self.vmax - self.vmin) - floor
        return 0.9 * (clamped - floor) / span

def plane(grid, values, vmax=None, log_alpha=-5, N=7, cmap='jet_r'):
    """Draw line and filled contours of ``values`` over ``grid`` plus a colorbar.

    Args:
        grid: array indexed as ``grid[:, :, 0]`` (x) and ``grid[:, :, 1]`` (y).
        values: surface to plot, same leading shape as ``grid``.
        vmax: optional cap; values are clipped to it when choosing levels.
        log_alpha: log of the smallest level offset above the minimum.
        N: number of log-spaced bands between the minimum and the cap.
        cmap: matplotlib colormap name.

    Returns:
        ``(contour, contourf, colorbar)``.  The top colorbar label is
        rewritten as ``> <previous label>`` to mark the open-ended top band.
    """
    colormap = plt.get_cmap(cmap)
    capped = values.copy() if vmax is None else np.minimum(values, vmax)
    lo, hi = capped.min(), capped.max()

    # N + 1 level boundaries placed at lo + exp(log_alpha + k * step), so the
    # offsets above lo grow geometrically from exp(log_alpha) to hi - lo.
    step = (np.log(hi - lo) - log_alpha) / N
    levels = lo + np.exp(log_alpha + step * np.arange(N + 1))
    levels[0], levels[-1] = lo, hi
    # Sentinel upper level so values above the cap still fall into a band.
    levels = np.append(levels, 1e10)
    # Widen the range slightly so lo itself stays inside LogNormalize's log domain.
    norm = LogNormalize(lo - 1e-8, hi + 1e-8, log_alpha=log_alpha)

    xs, ys = grid[:, :, 0], grid[:, :, 1]
    contour = plt.contour(xs, ys, values, cmap=colormap, norm=norm,
                          levels=levels, linewidths=2.5, zorder=1)
    contourf = plt.contourf(xs, ys, values, cmap=colormap, norm=norm,
                            levels=levels, alpha=0.55, zorder=0)

    colorbar = plt.colorbar(format='%.2g')
    tick_labels = list(colorbar.ax.get_yticklabels())
    tick_labels[-1].set_text(r'$>\,$' + tick_labels[-2].get_text())
    colorbar.ax.set_yticklabels(tick_labels)
    return contour, contourf, colorbar


def _draw_plane_figure(values, vmax, log_alpha, pdf_name):
    """Plot one surface from the plane archive, annotate it, save it and show it.

    The figure contains the contours drawn by ``plane`` over ``file['grid']``,
    markers for the bend points and the fused model, and dashed segments
    between bend points.  It is saved as ``os.path.join(output_path, pdf_name)``.

    Relies on the module-level ``file``, ``args``, ``output_path``,
    ``bend_coordinates`` and ``fused_model_coordinates``.

    Args:
        values: surface array to contour (e.g. ``file['tr_loss']``).
        vmax: color-normalization cap forwarded to ``plane``.
        log_alpha: color-normalization log floor forwarded to ``plane``.
        pdf_name: file name of the saved PDF inside ``output_path``.

    Returns:
        The ``(contour, contourf, colorbar)`` triple returned by ``plane``.
    """
    plt.figure(figsize=(12.4, 7))
    contour, contourf, colorbar = plane(
        file['grid'],
        values,
        vmax=vmax,
        log_alpha=log_alpha,
        N=7,
    )

    # Bend-point markers, drawn in this order so the legend order is stable.
    for index, marker, label in (
        (0, 'o', 'Base model 1'),
        (2, 'D', 'Base model 2'),
        (1, 's', 'Permuted model 2'),
    ):
        plt.scatter(bend_coordinates[index, 0], bend_coordinates[index, 1],
                    marker=marker, c='k', s=120, zorder=2, label=label)
    plt.scatter(fused_model_coordinates[0], fused_model_coordinates[1],
                marker='*', c='r', s=200, zorder=2.5, label='Fused model')

    dashed_style = dict(c='k', linestyle='--', dashes=(3, 4), linewidth=3, zorder=2)
    # Segment between the two base models (bend points 0 and 2).
    plt.plot(bend_coordinates[[0, 2], 0], bend_coordinates[[0, 2], 1], **dashed_style)
    # Segment from base model 1 to the permuted model (bend points 0 and 1);
    # omitted for the random-initialization plots.
    if not (args.random_initialization_plot_1 or args.random_initialization_plot_2):
        plt.plot(bend_coordinates[[0, 1], 0], bend_coordinates[[0, 1], 1], **dashed_style)

    plt.margins(0.0)
    plt.yticks(fontsize=18)
    plt.xticks(fontsize=18)
    plt.legend(scatterpoints=1, loc='upper right', fontsize=16)
    colorbar.ax.tick_params(labelsize=18)
    plt.savefig(os.path.join(output_path, pdf_name), format='pdf', bbox_inches='tight')
    plt.show()
    return contour, contourf, colorbar


plt.rcParams['axes.grid'] = False

bend_coordinates = file['bend_coordinates']
fused_model_coordinates = file['fused_model_coordinates']

# Training-loss surface, then test-error surface (the latter's handles are
# kept at module level, as in the original script).
_draw_plane_figure(file['tr_loss'], args.tr_vmax, args.tr_log_alpha,
                   'train_loss_plane.pdf')
contour, contourf, colorbar = _draw_plane_figure(
    file['te_err'], args.te_vmax, args.te_log_alpha, 'test_error_plane.pdf')