import os
import numpy as np
import matplotlib.pyplot as plt

# Big canvas and fonts so the side-by-side Train/Test panels stay readable.
plt.rcParams.update({
    "figure.figsize": (25, 12),
    "font.size": 25,
    "axes.labelsize": 27,
})

# Per-curve styling, selected by the configuration's position in the list.
LINESTYLE = ["-", ":", "-.", "--", "-", ":", "-."]
# NOTE: (sic) "MARKES" holds marker symbols; the name is referenced by create_plots.
MARKES = ["P", "D", "X", "o", "s", "v", ">"]
COLORS = [
    "#407294", "#f01806", "#9911dd", "#91dd1a",
    "#9c0d06", "#eea5cd", "#168d14",
]
# Directory searched for log files named "<model>_<configuration>_<seed>".
BASEDIR = "/tmp"

# Legend text for each "_"-separated token of a configuration name; an empty
# string drops the token. Unknown tokens fall back to their capitalized form.
_LABEL_TOKENS = {
    "resnet": "BatchNorm",
    "const": "IdShort",
    "mult": "LearnScalar",
    "conv": "ConvShort",
    "brock": "(Brock et al.)",
    "he": "(Our)",
    "fi": "",
    "fo": "",
}


def _make_label(run_name):
    """Build a legend label from a run name of the form ``<model>_<token>_...``.

    The leading model token is dropped and each remaining token is mapped
    through ``_LABEL_TOKENS`` as a whole word (so a token is never partially
    rewritten), e.g. ``"ResNet18_const_he_fo"`` -> ``"IdShort (Our)"``.
    """
    words = (_LABEL_TOKENS.get(tok.lower(), tok.capitalize())
             for tok in run_name.split("_")[1:])
    return " ".join(word for word in words if word)


def _read_log(path_seed, seed):
    """Return the lines of one seed's log file, or ``None`` if it cannot be read."""
    try:
        with open(path_seed, "r", encoding="utf-8", errors="ignore") as f:
            return f.readlines()
    except FileNotFoundError:
        print(f"\t[ERROR] seed '{seed}' not found")
    except OSError as err:
        print(f"\t[ERROR] seed '{seed}' could not be read: {err}")
    return None


def _parse_accuracies(lines):
    """Extract per-epoch ``(train_accs, test_accs)`` lists from log lines.

    A line containing both "Epoch" and "loss" is a train record; otherwise a
    line containing "test" is a test record. In both cases the value is the
    text after the last ":". Matching lines whose value is not a number are
    skipped instead of aborting the whole plot.
    """
    train_accs, test_accs = [], []
    for line in lines:
        if "Epoch" in line and "loss" in line:
            target = train_accs
        elif "test" in line:
            target = test_accs
        else:
            continue
        try:
            target.append(float(line.rsplit(":", 1)[-1].strip()))
        except ValueError:
            # e.g. a status message mentioning "test" with no numeric value.
            continue
    return train_accs, test_accs


def _stack_runs(runs, min_length):
    """Stack the non-empty runs into a ``(n_runs, length)`` float array.

    Each run is padded with its own last value up to the longest run (and at
    least *min_length*), so early-stopped runs keep their final accuracy and
    longer runs are neither truncated nor turned into a ragged array.
    Returns ``None`` when no run has any data.
    """
    runs = [run for run in runs if run]
    if not runs:
        return None
    length = max(min_length, max(len(run) for run in runs))
    return np.array([run + [run[-1]] * (length - len(run)) for run in runs],
                    dtype=float)


def _plot_mean_std(ax, runs, label, idx):
    """Draw the across-seed mean of *runs* on *ax* with a shaded +/-1 std band."""
    mean = runs.mean(axis=0)
    std = runs.std(axis=0)
    x = np.arange(len(mean))
    color = COLORS[idx % len(COLORS)]  # same wrap-around for line and band
    ax.plot(x, mean, label=label,
            linestyle=LINESTYLE[0],
            marker=MARKES[idx % len(MARKES)],
            color=color,
            markersize=5, linewidth=2,
            markevery=list(range(0, len(mean), 10)) + [-1])
    ax.fill_between(x, mean - std, mean + std, alpha=0.25, color=color)


def create_plots(*,
                 models=("ResNet18", "ResNet50", "ResNet101"),
                 configurations=("resnet_he_fi", "const_he_fo", "mult_he_fo",
                                 "conv_he_fo", "const_brock_fi",
                                 "mult_brock_fi", "conv_brock_fi"),
                 seeds=(0, 42, 314),
                 basedir=None,
                 num_epochs=300):
    """Plot mean +/- std train and test accuracy across seeds, one figure per model.

    Each figure has a "Train" panel (left) and a "Test" panel (right). Every
    configuration contributes one curve, averaged over the seeds whose log
    file ``<basedir>/<model>_<configuration>_<seed>`` could be read and
    contained accuracy values. Missing seeds are left out of the average
    rather than counted as zeros.

    Args:
        models: Model-name prefixes of the log files.
        configurations: Configuration suffixes; their tokens also build the
            legend label.
        seeds: Seed suffixes; each one names a separate run/log file.
        basedir: Log directory; ``None`` means the module-level ``BASEDIR``
            (looked up at call time).
        num_epochs: Minimum curve length; shorter runs are padded with their
            last recorded value.
    """
    if basedir is None:
        basedir = BASEDIR

    for model in models:
        fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True)
        fig.suptitle(model)
        ax1.set_title("Train")
        ax2.set_title("Test")
        ax1.set_ylabel("Accuracy [%]")
        ax1.set_ylim([85, 99])
        # range() is end-exclusive: stop at 100 so the 99 limit gets a tick.
        ax1.set_yticks(range(85, 100))
        for ax in (ax1, ax2):
            ax.set_xlabel("Epoch")
            ax.grid(axis="y")

        for idx, conf in enumerate(configurations):
            path = os.path.join(basedir, f"{model}_{conf}")
            print(f"[INFO] Processing '{path}'")
            train_runs, test_runs = [], []
            for seed in seeds:
                lines = _read_log(f"{path}_{seed}", seed)
                if lines is None:
                    continue  # leave missing seeds out instead of averaging in zeros
                train_accs, test_accs = _parse_accuracies(lines)
                if not train_accs and not test_accs:
                    print(f"\t[WARN] seed '{seed}' contains no accuracy lines")
                train_runs.append(train_accs)
                test_runs.append(test_accs)

            train = _stack_runs(train_runs, num_epochs)
            test = _stack_runs(test_runs, num_epochs)
            run_name = os.path.basename(path)
            if train is None and test is None:
                print(f"\t[NO PLOT] '{run_name}'")
                continue

            label = _make_label(run_name)
            if train is not None:
                _plot_mean_std(ax1, train, label, idx)
            if test is not None:
                _plot_mean_std(ax2, test, label, idx)

        # One legend per figure, and only if at least one labelled curve exists.
        if ax1.get_legend_handles_labels()[0]:
            ax1.legend(loc="lower right")
        plt.show()


if __name__ == "__main__":
    # Executed as a script (not imported): render the figures for every model.
    create_plots()
