import os
import itertools
import numpy as np
import matplotlib.pyplot as plt

# Global matplotlib styling: a very large canvas with large fonts, applied to
# every figure this script creates.
plt.rcParams.update({
    "figure.figsize": (25, 12),
    "font.size": 25,
    "axes.labelsize": 27,
})

# Directory searched for the "{model}_{conf}.npy" result files.
BASEDIR = "/tmp"

# Model names whose result files are plotted when no explicit list is given.
_DEFAULT_MODELS = ("ResNet18", "ResNet50", "ResNet101")

# Configuration ids plotted for every model when no explicit list is given.
# Each id is an underscore-separated token list; the result file for a
# (model, conf) pair is "<basedir>/<model>_<conf>.npy".
_DEFAULT_CONFIGURATIONS = (
    "resnet_True_he_fan_in", "resnet_True_he_fan_out",
    "resnet_False_he_fan_in", "resnet_False_he_fan_out",
    "idconst_False_he_fan_out", "idconst_False_brock_fan_in",
    "shortmult_False_he_fan_out", "shortmult_False_brock_fan_in",
    "shortconv_False_he_fan_out", "shortconv_False_brock_fan_in",
)

# Ordered (old, new) substring substitutions applied to the capitalized,
# space-joined configuration tokens to produce a legend label. Order matters:
# "He" must already be "(Our)" when the two "(Our) Fan ..." rules run.
_LABEL_REPLACEMENTS = (
    ("Resnet", "ResNet"),
    ("True", "BN"),
    ("False", ""),
    ("Idconst", "IdShort"),
    ("Shortmult", "LearnScalar"),
    ("Shortconv", "ConvShort"),
    ("Brock", "(Brock et al.)"),
    ("He", "(Our)"),
    ("(Our) Fan In", "(He fan-in)"),
    ("(Our) Fan Out", "(He fan-out)"),
)


def _format_label(conf):
    """Return the human-readable legend label for configuration id *conf*.

    Tokens are capitalized and, for every variant other than "resnet", the
    last two tokens (the fan mode) are dropped. The table of substitutions is
    then applied, and runs of whitespace left by dropped words are collapsed.

    >>> _format_label("resnet_True_he_fan_in")
    'ResNet BN (He fan-in)'
    >>> _format_label("resnet_False_he_fan_out")
    'ResNet (He fan-out)'
    >>> _format_label("shortconv_False_brock_fan_in")
    'ConvShort (Brock et al.)'
    """
    tokens = [token.capitalize() for token in conf.split("_")]
    if tokens and tokens[0] != "Resnet":
        tokens = tokens[:-2]
    label = " ".join(tokens)
    for old, new in _LABEL_REPLACEMENTS:
        label = label.replace(old, new)
    # "False" maps to "", which would otherwise leave a double space.
    return " ".join(label.split())


def create_plots(models=None, configurations=None, basedir=None):
    """Plot forward and backward variance against depth, one figure per model.

    Each figure has two side-by-side, log-scaled panels. The left
    ("Forward") panel draws row 0 and the right ("Backward") panel row 1 of
    every ``<basedir>/<model>_<conf>.npy`` array, with one curve per
    configuration. Both panels share the same y-range. Each figure is shown
    with ``plt.show()`` and then closed.

    Args:
        models: Iterable of model names; defaults to ``_DEFAULT_MODELS``.
        configurations: Iterable of configuration ids plotted for every
            model; defaults to ``_DEFAULT_CONFIGURATIONS``.
        basedir: Directory containing the ``.npy`` files; defaults to the
            module-level ``BASEDIR``, read at call time.

    Raises:
        FileNotFoundError: If a result file does not exist (from ``np.load``).
        ValueError: If a loaded array is not 2-D with at least two rows.
    """
    if models is None:
        models = _DEFAULT_MODELS
    if configurations is None:
        configurations = _DEFAULT_CONFIGURATIONS
    if basedir is None:
        basedir = BASEDIR

    linestyles = ['--', '-.', ':']
    markers = ['^', 's', '>', 'd', 'v', '<', 'o', 'X', 'P', 'p']
    # Offsets for the periodic markers, so overlapping curves place their
    # markers at different depths.
    starts = [1, 2, 3]

    for model in models:
        print(f"[INFO] Processing {model}...")
        fig, (ax_fw, ax_bw) = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=False)

        # Fresh cyclers per figure so every figure uses the same style sequence.
        starts_cycler = itertools.cycle(starts)
        markers_cycler = itertools.cycle(markers)
        linestyle_cycler = itertools.cycle(linestyles)

        ymin, ymax = np.inf, -np.inf
        for conf in configurations:
            path = os.path.join(basedir, f"{model}_{conf}.npy")
            label = _format_label(conf)

            data = np.load(file=path)
            if data.ndim != 2 or data.shape[0] < 2:
                raise ValueError(
                    f"{path}: expected a 2-D array whose first two rows are the "
                    f"forward and backward variances, got shape {data.shape}"
                )
            fw_summation_var = data[0, :]
            bw_summation_var = data[1, :]

            ymin = min(ymin, fw_summation_var.min(), bw_summation_var.min())
            ymax = max(ymax, fw_summation_var.max(), bw_summation_var.max())

            # Always mark the first and last points, plus every 4th point
            # starting at this curve's offset.
            mark_idx = [0, -1] + list(range(next(starts_cycler), len(fw_summation_var), 4))
            style = dict(
                markersize=13,
                linewidth=3,
                marker=next(markers_cycler),
                linestyle=next(linestyle_cycler),
                label=label,
                markevery=mark_idx,
            )
            ax_fw.plot(fw_summation_var, **style)
            ax_bw.plot(bw_summation_var, **style)

        # set_ylim rejects NaN/inf, so only pin limits when data produced
        # finite bounds (e.g. not when `configurations` is empty).
        limits_ok = np.isfinite(ymin) and np.isfinite(ymax)
        for ax, title in ((ax_fw, "Forward"), (ax_bw, "Backward")):
            ax.grid()
            ax.set_xlabel("Depth")
            if limits_ok:
                ax.set_ylim([ymin, ymax])
            ax.set_yscale("log")
            ax.set_title(title)
        ax_fw.set_ylabel("Variance")
        ax_bw.legend(loc="upper right", handlelength=2)
        plt.show()
        # Release the figure; under non-interactive backends show() is a no-op
        # and figures would otherwise accumulate across models.
        plt.close(fig)


if __name__ == "__main__":
    # Script entry point: build and show one variance figure per model.
    create_plots()
