import os
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np


def _make_dir(fn):
    folder = os.path.dirname(fn)
    if not os.path.isdir(folder):
        os.makedirs(folder)


def save_image(image, fn):
    """Write one image array to *fn* using ``matplotlib.pyplot.imsave``.

    Args:
        image: array-like of shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4).
            Single-channel images are saved with a gray colormap; since no
            vmin/vmax is passed, ``plt.imsave`` scales that colormap to the
            image's own min/max. Multi-channel (RGB/RGBA) images are clipped
            to [0, 1] before saving.
        fn: destination path; missing parent directories are created.

    Raises:
        ValueError: if the image is not 2-D/3-D or has a channel count other
            than 1, 3 or 4.
    """
    image = np.asarray(image)
    if image.ndim == 2:
        image = np.expand_dims(image, -1)
    # Validate explicitly (assert is stripped under -O) and before touching
    # the filesystem, so rejected input never leaves empty directories.
    if image.ndim != 3:
        raise ValueError(
            'expected a 2-D or 3-D image, got shape %r' % (image.shape,))
    channels = image.shape[2]
    if channels not in (1, 3, 4):
        raise ValueError(
            'expected 1, 3 or 4 channels, got %d' % channels)
    _make_dir(fn)
    if channels == 1:
        plt.imsave(fn, np.squeeze(image, -1), cmap=cm.gray)
    else:
        plt.imsave(fn, np.clip(image, 0, 1))


def save_images(experiment_id, data, folder, size=100):
    """Save images from *data* as ``logs/<experiment_id>/<folder>/<i>.png``.

    Only the first *size* images are written; a negative *size* (or one that
    is at least ``len(data)``) writes every image.
    """
    truncate = size >= 0 and size < len(data)
    images = data[:size] if truncate else data
    for index, img in enumerate(images):
        target = os.path.join('logs', experiment_id, folder,
                              '{}.png'.format(index))
        save_image(img, target)


def save_text(experiment_id, data, name):
    """Write column-oriented *data* to ``logs/<experiment_id>/<name>.txt``.

    *data* is a sequence of columns. Line ``i`` of the file holds
    ``str(column[i])`` for every column, each followed by a tab (so every
    line ends with a tab before the newline). The number of lines is taken
    from the first column; empty *data* yields an empty file. The parent
    directory is created if it is missing.
    """
    fn = os.path.join('logs', experiment_id, name + '.txt')
    # open(..., 'w') fails if logs/<experiment_id> does not exist yet.
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    # len() rather than truthiness so numpy-array containers work too.
    num_rows = len(data[0]) if len(data) > 0 else 0
    with open(fn, 'w') as f:
        for i in range(num_rows):
            f.write(''.join(str(column[i]) + '\t' for column in data) + '\n')


def save_hidden(experiment_id, data, name, num_scalar_columns=2):
    """Write mixed scalar/vector columns to ``logs/<experiment_id>/<name>.txt``.

    *data* is a sequence of columns. On line ``i``:

    - each of the first *num_scalar_columns* columns contributes
      ``str(column[i])`` followed by a tab;
    - each remaining column's entry ``column[i]`` is iterated, its elements
      written space-separated (each followed by a space), then a tab.

    The number of lines is taken from the first column; empty *data* yields
    an empty file. The parent directory is created if it is missing.
    """
    fn = os.path.join('logs', experiment_id, name + '.txt')
    # open(..., 'w') fails if logs/<experiment_id> does not exist yet.
    os.makedirs(os.path.dirname(fn), exist_ok=True)
    # len() rather than truthiness so numpy-array containers work too.
    num_rows = len(data[0]) if len(data) > 0 else 0
    with open(fn, 'w') as f:
        for i in range(num_rows):
            cells = []
            for j, column in enumerate(data):
                if j < num_scalar_columns:
                    cells.append(str(column[i]) + '\t')
                else:
                    cells.append(
                        ''.join(str(e) + ' ' for e in column[i]) + '\t')
            f.write(''.join(cells) + '\n')
