import torch
import matplotlib.pyplot as plt
from os.path import join
import numpy as np


def _sparsity_histogram(feature, bins, feature_min, feature_max):
    """Return the fraction of ``feature``'s elements falling in each histogram bin.

    Elements below 1e-6 (zeros, negatives and tiny positives) are replaced by
    -0.01 before binning, so with ``feature_min=-0.02`` they all land in the
    first bin. Counts are divided by the total number of elements, so elements
    outside ``[feature_min, feature_max]`` (which ``torch.histc`` ignores) still
    count toward the denominator.

    Args:
        feature: Tensor of activations; flattened before binning.
        bins: Number of equal-width bins.
        feature_min: Lower edge of the histogram range.
        feature_max: Upper edge of the histogram range.

    Returns:
        1-D tensor of length ``bins`` with per-bin fractions.
    """
    # masked_fill returns a new tensor, so the loaded data is not mutated.
    shifted = feature.masked_fill(feature < 1e-6, -0.01)
    # reshape (unlike view) also works on non-contiguous tensors.
    hist = torch.histc(shifted.reshape(-1), bins, min=feature_min, max=feature_max)
    return hist / feature.numel()


def main(
    root=r"C:\Research\mixing2mixing\data\mnistSparsityHist\mnist_mlp_lr0.05_save",
    keys=("layers.0", "layers.1", "layers.2", "layers.3", "layers.4", "layers.5"),
    epochs=range(0, 200, 10),
):
    """Plot per-layer activation histograms across training epochs.

    For every ``epoch`` in ``epochs`` this loads the file
    ``feature_epoch_<epoch + 1, zero-padded to 3 digits>`` from ``root``,
    histograms ``features[key][0]`` for every layer key, and finally calls
    ``drawHist`` once per key with ``name=key``.

    Args:
        root: Directory holding the ``feature_epoch_XXX`` files. A raw string,
            so the Windows backslashes are not parsed as escape sequences.
        keys: Layer names used to index each loaded feature dict.
        epochs: Epoch indices to plot. File number ``epoch + 1`` is loaded,
            but each histogram is tagged with ``epoch`` itself.
    """
    bins = 51
    feature_min = -0.02
    feature_max = 1

    # x position of each bin = its UPPER edge. With these constants the first
    # bin [-0.02, 0.0), which holds the near-zero activations, is drawn at x~0.
    # Loop-invariant, so computed once (and defined even if epochs is empty).
    stride = (feature_max - feature_min) / bins
    bin_upper_edges = [feature_min + stride * (i + 1) for i in range(bins)]

    hists = {key: [] for key in keys}
    # Load each epoch file once and histogram every requested layer from it,
    # rather than re-loading the same file once per layer.
    for epoch in epochs:
        path = join(root, f"feature_epoch_{epoch + 1:03d}")
        features = torch.load(path, map_location="cpu")
        for key in keys:
            # [0] presumably selects the first sample of a batch -- TODO confirm.
            hist = _sparsity_histogram(features[key][0], bins, feature_min, feature_max)
            hists[key].append([hist, epoch])

    for key in keys:
        drawHist(bin_upper_edges, hists[key], name=key)


def drawHist(x, hists, name):
    """Draw grouped bar histograms (one bar series per entry) and save ``<name>.pdf``.

    Args:
        x: Sequence of x positions, one per histogram bin. Assumed evenly
            spaced, since the gap between the first two sets the bar width.
        hists: List of ``[hist_y, label]`` pairs. ``hist_y`` holds one height
            per entry of ``x``, and ``label`` becomes the series' legend entry.
        name: Output file stem; the figure is written to ``<name>.pdf``
            relative to the current working directory.

    Note:
        Updates the global matplotlib rc settings (size-15, bold fonts) as a
        side effect.
    """
    plt.rc('font', size=15)  # controls default text sizes
    plt.rcParams["font.weight"] = "bold"
    plt.rcParams["axes.labelweight"] = "bold"

    n_series = len(hists)
    # Fit all series side by side within one bin spacing, leaving a gap of
    # 5 bar widths between neighbouring groups. Fall back to a unit spacing
    # when there are fewer than two bins (x[1] would not exist).
    spacing = x[1] - x[0] if len(x) > 1 else 1.0
    width = spacing / (n_series + 5)

    fig, ax = plt.subplots(figsize=(20, 5))
    try:
        x_arr = np.asarray(x)
        for cnt_hist, (hist_y, label) in enumerate(hists):
            # Centre the group of n_series bars on each x position
            # (offsets run symmetrically from -(n-1)/2 to +(n-1)/2 widths).
            offset_x = width * (cnt_hist - (n_series - 1) / 2)
            ax.bar(x_arr + offset_x, hist_y, width, label=label)

        ax.set_ylabel('Density')
        ax.legend(ncol=3)
        fig.savefig(f'{name}.pdf')
    finally:
        # Release the figure; otherwise every call leaves one figure open.
        plt.close(fig)


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
