#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np

import matplotlib.pyplot as plt
import os


def save_cost_maps(save_dir, name, weight, cmap=None):
    """Render each 2-D map in ``weight[1]`` as an annotated heat-map PNG.

    Negative entries are clipped to 0 and all values are multiplied by 10
    before plotting. Every cell is labelled with its (scaled) value converted
    with ``int()``. One file ``<name>_cost_map_<index>.png`` is written to
    ``save_dir`` per map.

    Args:
        save_dir: Directory the PNG files are written to. It is not created
            here, so it must already exist.
        name: Filename prefix for the saved images.
        weight: Indexable whose element ``[1]`` is an iterable of 2-D
            array-likes (presumably a ``(n_maps, H, W)`` array -- TODO confirm
            against callers).
        cmap: Colormap passed to ``imshow``. ``None`` keeps matplotlib's
            default colormap, which is what earlier versions actually
            rendered with: their ``cmap="YlGn"`` was given to ``colorbar``,
            where it is ignored because the colorbar takes its colormap from
            the image.
    """
    heat_maps = np.maximum(weight[1], 0) * 10
    for h_i, heat_map in enumerate(heat_maps):
        fig, ax = plt.subplots(1, 1, figsize=(15, 15))
        try:
            im = ax.imshow(heat_map, cmap=cmap)
            cbar = fig.colorbar(im, ax=ax)
            cbar.ax.set_ylabel('color bar', rotation=-90, va="bottom")
            for spine in ax.spines.values():
                spine.set_visible(False)
            # Minor ticks placed on cell boundaries let the minor grid draw
            # white separator lines between cells.
            ax.set_xticks(np.arange(heat_map.shape[1] + 1) - .5, minor=True)
            ax.set_yticks(np.arange(heat_map.shape[0] + 1) - .5, minor=True)
            ax.grid(which="minor", color="w", linestyle='-', linewidth=1)
            ax.tick_params(which="minor", bottom=False, left=False)

            for i in range(heat_map.shape[0]):
                for j in range(heat_map.shape[1]):
                    ax.text(j, i, int(heat_map[i, j]),
                            ha="center", va="center", color="w", fontsize=60)

            fig.tight_layout()
            save_path = os.path.join(save_dir, '{}_cost_map_{}.png'.format(name, h_i))
            # Save this specific figure rather than pyplot's "current" one.
            fig.savefig(save_path)
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # each map (and each call) leaks a full 15x15-inch figure.
            plt.close(fig)
