# cython: profile=True
# -*- coding: utf-8 -*-
import os
import json
import torch
from matplotlib import pyplot as plt
from folders import folders
from make_illustrations import show_weights
from plot_loss import plot_losses


def _save_and_close(fig, path):
    """Write ``fig`` to ``path`` with a tight bounding box, then close it."""
    fig.savefig(path, bbox_inches='tight')
    plt.close(fig)


def _plot_saved_weights(model, pars_file, folder_out):
    """Load the parameter file and save weight figures for known model names.

    Only 'linearbig' and 'predseg1' produce figures; for any other model name
    the parameters are loaded but nothing is plotted.
    """
    pars = torch.load(pars_file, map_location=torch.device('cpu'))

    if model == 'linearbig':
        weights = pars['raw_module.weight']
    elif model == 'predseg1':
        # Read the layer-2 weights before plotting anything so a missing key
        # fails before any figure for this model is written.
        weights = pars['2.raw_module.weight']
        first_layer = pars['0.raw_module.weight']
        _save_and_close(show_weights(first_layer),
                        os.path.join(folder_out, 'weights0.pdf'))
        _save_and_close(plot_mean_weights(first_layer),
                        os.path.join(folder_out, 'weights_mean.pdf'))
    else:
        return

    _save_and_close(show_weights(weights),
                    os.path.join(folder_out, 'weights.pdf'))


def main(out_path='../figures/models/'):
    """Render weight and loss figures for every entry in ``models.json``.

    ``models.json`` is read from the current working directory. Each entry
    must provide a ``'model'`` name and an integer-like ``'version'``.
    Figures for an entry go to ``<out_path>/<model>/version<version>/``,
    which is created if needed.

    If ``pars.pth`` exists in ``folders['models']/<model>/version<version>``,
    weight figures are written for the 'linearbig' and 'predseg1' models.
    ``losses.pdf`` is always written via ``plot_losses``.

    Args:
        out_path: root directory under which the figure folders are created.
    """
    with open('models.json', 'r') as handle:
        entries = json.load(handle)

    for entry in entries:
        name = entry['model']
        ver = entry['version']
        subdir = 'version%d' % ver
        model_dir = os.path.join(folders['models'], name, subdir)
        fig_dir = os.path.join(out_path, name, subdir)
        os.makedirs(fig_dir, exist_ok=True)

        params_path = os.path.join(model_dir, 'pars.pth')
        if os.path.isfile(params_path):
            _plot_saved_weights(name, params_path, fig_dir)

        # plot_losses is presumably drawing into the current pyplot figure,
        # so create the 5x5 figure first; TODO confirm against plot_loss.
        loss_fig = plt.figure(figsize=(5, 5))
        plot_losses(name, ver, n_conv=100)
        _save_and_close(loss_fig, os.path.join(fig_dir, 'losses.pdf'))

def plot_mean_weights(weights):
    """Plot the weights averaged over their last two dimensions.

    The last two dimensions of ``weights`` are averaged away. A 2-D result is
    transposed before plotting, so matplotlib draws one line per row of the
    averaged tensor, with the second dimension along the x axis. A dashed
    black line at y = 0 spans the plotted x range. The right, top and bottom
    spines and the x ticks are hidden.

    Args:
        weights: torch tensor with at least two trailing dimensions to
            average over. It is presumably a conv weight of shape
            (out, in, h, w); TODO confirm with callers.

    Returns:
        The matplotlib Figure containing the plot.
    """
    # Detach and move to CPU so the tensor converts to a numpy array even
    # when it requires grad or lives on a GPU.
    w_mean = weights.detach().cpu().mean(-1).mean(-1)
    # .t() transposes a 2-D tensor and returns 0-/1-D tensors unchanged.
    series = w_mean.t()
    # matplotlib places the points at x = 0..len-1. Span the zero line over
    # that range instead of a fixed width.
    x_last = series.shape[0] - 1 if series.dim() > 0 else 0

    fig, ax = plt.subplots()
    ax.plot(series.numpy())
    ax.plot([0, x_last], [0, 0], "k--")
    for side in ("right", "top", "bottom"):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])
    return fig

if __name__ == '__main__':
    # Command-line entry point: the only option is the output directory,
    # which is forwarded to main().
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-o",
        "--out_path",
        default='../figures/models/',
        help="where to save to",
    )
    main(out_path=cli.parse_args().out_path)
