import argparse
import numpy as np
import time
import matplotlib.pyplot as plt
# Legend labels, keyed by the position of a file in ``args.datafiles``
# (``main`` looks up ``label_idx[idx]`` while enumerating the data files).
label_idx = {0: "Noisy", 1: "Noise free"}


def _read_loss_file(fname):
    """Parse one loss file into ``(mean_losses, std_losses)`` lists of arrays.

    Non-blank lines alternate: the 1st, 3rd, ... line goes to the first list
    and the 2nd, 4th, ... line to the second. Each line is parsed as
    whitespace-separated numbers.
    """
    with open(fname, "r") as fhandle:
        # Blank lines (e.g. a trailing newline) carry no data; dropping them
        # keeps the mean/std alternation aligned.
        rows = [np.fromstring(line, sep=" ") for line in fhandle if line.strip()]
    return rows[0::2], rows[1::2]


def _pick(values, idx, default):
    """Return ``values[idx]`` (cycling if too short), or ``default`` if empty/None."""
    if not values:
        return default
    return values[idx % len(values)]


def main(args):
    """Plot the per-step loss curves of every data file on a 2x2 grid and save a PDF.

    Args:
        args: Namespace with the attributes
            ``datafiles`` (paths of the loss files to plot),
            ``figsize`` (two floats: width, height),
            ``whspace`` (two floats: ``wspace``, ``hspace`` of the grid),
            ``marker`` (errorbar ``fmt`` string per data file, may be None),
            ``linecolor`` (colour per data file, may be None).

    Each data file contributes one curve with error bars to each panel; only
    the first four mean/std pairs of a file are drawn (one per panel). The
    figure is written to ``losspath-<YYYYmmdd-HHMMSS>.pdf`` in the current
    working directory.
    """
    fig, ax = plt.subplots(
        2, 2, figsize=tuple(args.figsize),
        gridspec_kw={"wspace": args.whspace[0], "hspace": args.whspace[1]}, dpi=300
    )
    # Row-major order: ax[0][0], ax[0][1], ax[1][0], ax[1][1].
    panels = list(ax.flat)

    for idx, fname in enumerate(args.datafiles):
        mean_losses, std_losses = _read_loss_file(fname)
        # Fall back to matplotlib's "CN" cycle colours / a circle marker when
        # the option was not given, instead of failing on ``None[idx]``.
        color = _pick(args.linecolor, idx, f"C{idx}")
        marker = _pick(args.marker, idx, "o")
        # Files beyond the known labels are labelled by their path.
        label = label_idx.get(idx, fname)

        # zip() stops at the shortest input, so files with fewer than four
        # mean/std pairs simply leave the remaining panels untouched.
        for step, (panel, y, yerr) in enumerate(zip(panels, mean_losses, std_losses), start=1):
            x = list(range(len(y)))
            panel.plot(x, y, c=color)
            panel.errorbar(
                x, y, yerr, capsize=3.0, fmt=marker, label=label,
                color=color, markerfacecolor="none", markeredgecolor=color
            )
            if step > 2:
                # Only the bottom row gets an x-axis label.
                panel.set_xlabel("number iterations", fontsize=13)
            panel.set_ylabel(
                r"$C_{{{}}}(\sigma_{{{}}},\rho_{{{}}})$".format(step, step, step), fontsize=13
            )

    for panel in panels:
        panel.legend(prop={"size": 13})
    fig.suptitle("(a) Loss of QSSM learning in each step", fontsize=15)
    fig.subplots_adjust(top=0.925)
    # time.ctime() contains spaces and colons; colons are not valid in Windows
    # file names, so use a compact, sortable timestamp instead.
    stamp = time.strftime("%Y%m%d-%H%M%S")
    fig.savefig(f"losspath-{stamp}.pdf")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point: collect the plotting options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Plot loss curves with error bars from one or more data files."
    )
    # Required: without it args.datafiles is None and main() cannot iterate it.
    parser.add_argument(
        "--datafiles", nargs="+", required=True,
        help="paths of the loss data files to plot"
    )
    parser.add_argument(
        "--figsize", type=float, nargs=2, default=[6.4, 4.8],
        help="figure width and height"
    )
    parser.add_argument(
        "--whspace", type=float, nargs=2, default=[0.25, 0.2],
        help="wspace and hspace between the subplots"
    )
    parser.add_argument(
        "--marker", nargs="+",
        help="matplotlib errorbar fmt string, one per data file"
    )
    parser.add_argument(
        "--linecolor", nargs="+",
        help="matplotlib colour, one per data file"
    )
    main(parser.parse_args())
