import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

#plt.rcParams['text.usetex'] = True #Let TeX do the typesetting
#plt.rcParams['text.latex.preamble'] = [r'\usepackage{sansmath}', r'\sansmath'] #Force sans-serif math mode (for axes labels)
#plt.rcParams['font.family'] = 'sans-serif' # ... for regular text
#plt.rcParams['font.sans-serif'] = 'Helvetica' # Choose a nice font here

# Parse a whitespace-separated table into its columns (in the timing tables below, each column holds one algorithm's results).
def extract_columns(s, fill=None):
    """Parse a whitespace-separated table of numbers into a list of columns.

    Blank lines are ignored. A cell containing '.' or '-' marks a missing
    measurement. By default such a cell is skipped, so that column ends up
    shorter than the others. Pass ``fill`` (e.g. ``0.0``) to record that value
    in its place instead, which keeps every column aligned with the rows.

    Args:
        s: Table text, one row per line, cells separated by whitespace.
        fill: Value stored for '.'/'-' cells, or None to skip them.

    Returns:
        One list of floats per column. The column count is the token count of
        the widest row. Returns [] when ``s`` contains no tokens at all.

    Raises:
        ValueError: if a cell is neither a placeholder nor a valid float.
    """
    rows = [line.split() for line in s.split('\n')]
    rows = [tokens for tokens in rows if tokens]
    if not rows:
        return []

    # Size by the widest row so a row longer than the first cannot overflow.
    columns = [[] for _ in range(max(len(tokens) for tokens in rows))]
    for tokens in rows:
        for column, token in zip(columns, tokens):
            if token in ('.', '-'):
                if fill is not None:
                    column.append(fill)
            else:
                column.append(float(token))
    return columns

def _plot_panel(position, title, x, y_text, labels, markers, ylim=None,
                axis_labels=False):
    """Draw one timing panel in the 1x4 subplot grid.

    Args:
        position: 1-based subplot index within the 1x4 grid.
        title: Panel title, e.g. 'd=8'.
        x: Shared x-axis values (the n values), one per table row.
        y_text: Whitespace-separated timing table with one column per
            algorithm; '.'/'-' cells mark runs with no measurement.
        labels: Legend label for each algorithm column, in column order.
        markers: Matplotlib marker for each algorithm column.
        ylim: Optional [bottom, top] y-limits; None leaves autoscaling on.
        axis_labels: When True, also label the x and y axes.

    Raises:
        ValueError: if the table's column count differs from len(labels).
    """
    plt.subplot(1, 4, position)
    columns = extract_columns(y_text)
    if len(columns) != len(labels):
        raise ValueError(
            f"{title}: expected {len(labels)} columns, got {len(columns)}")
    for column, label, marker in zip(columns, labels, markers):
        # extract_columns drops '.'/'-' cells, so a shorter column is plotted
        # against the leading x values. This only lines up when placeholders
        # sit at the end of a column, which holds for every table in main().
        plt.plot(x[:len(column)], column, label=label, marker=marker)
    plt.grid()
    plt.title(title)
    if axis_labels:
        plt.xlabel('n', fontsize='large')
        plt.ylabel('time (seconds)', fontsize='large')
    if ylim is not None:
        plt.gca().set_ylim(ylim)


def main(output_path="ndim2-plot-v15.png", show=True):
    """Plot running time vs. n for four algorithms at d = 8, 16, 32 and 64.

    Draws a 1x4 row of line plots (one per d) from the hard-coded timing
    tables below, adds a shared legend, and saves the figure.

    Args:
        output_path: File the figure is saved to.
        show: When True, also open the figure with plt.show().
    """
    plt.style.use('default')
    plt.figure(figsize=(15,3))

    labels = ['Baseline-Naive', 'Baseline-KronMatMul', 'DJSSW19', 'FastKroneckerRegression']
    markers = ['o', '^', 's', 'x']

    # n values: one per row of every timing table below.
    x_values = \
    """
    128
    256
    512
    1024
    2048
    4096
    8192
    16384
    """
    x = extract_columns(x_values)[0]

    # --------------------------------------------------------------------------
    # d=8
    # --------------------------------------------------------------------------
    y_d8 = \
    """
    0.010638022	0.00075807	0.005583762	0.007416081
    0.034859425	0.001398892	0.005075821	0.015482532
    0.134323576	0.003691965	0.006370039	0.017212951
    0.50091633	0.011494756	0.008722743	0.019233132
    1.940263127	0.046609549	0.013988233	0.0228895
    7.525319565	0.18549168	0.023632491	0.042028427
    29.94789325	0.675973614	0.040401164	0.057790705
    -	3.971652607	0.077592327	0.101902692
    """

    # --------------------------------------------------------------------------
    # d=16
    # --------------------------------------------------------------------------
    y_d16 = \
    """
    0.141580555	0.001019571	0.115385713	0.094187023
    0.199180262	0.002173037	0.121229208	0.098755817
    0.548514559	0.006202522	0.121640987	0.09268078
    1.899420443	0.022152784	0.118420346	0.098242575
    7.167710972	0.09164493	0.128644473	0.111201766
    27.52374916	0.339914517	0.136500304	0.114001429
    -	2.285792954	0.159942177	0.13670338
    .	15.00283802	0.203591912	0.186859658
    """

    # --------------------------------------------------------------------------
    # d=32
    # --------------------------------------------------------------------------
    y_d32 = \
    """
    5.959340513	0.001928789	10.34777967	1.65440601
    7.035832866	0.003887378	10.47348228	1.80212774
    8.175744788	0.011755894	10.31561653	1.704491516
    13.62733798	0.047366304	10.33786235	1.680198117
    31.49962266	0.173916952	10.40883456	1.719295888
    -	1.128951463	10.25850971	1.673174591
    -	7.733512816	10.48557042	1.750514199
    -	29.71964358	10.58145285	1.861285648
    """

    # --------------------------------------------------------------------------
    # d=64
    # --------------------------------------------------------------------------
    y_d64 = \
    """
    966.1174844	0.005799663	1240.647135	29.6798138
    982.5171561	0.010165475	1243.164812	29.64527451
    970.6558238	0.026225677	1240.73648	29.37247658
    985.4096786	0.090038077	1239.860209	29.55461257
    -	0.668667386	1238.65576	29.43672288
    -	3.642191448	1243.263522	29.7029016
    -	14.51370957	1244.834666	29.89260021
    -	67.00368547	1241.556855	29.91433819
    """

    # The first three panels share a fixed y-range. The d=64 panel is left to
    # autoscale because its timings reach roughly 1245 seconds.
    shared_ylim = [-1.5, 32.5]
    panels = [
        ('d=8', y_d8, shared_ylim),
        ('d=16', y_d16, shared_ylim),
        ('d=32', y_d32, shared_ylim),
        ('d=64', y_d64, None),
    ]
    for position, (title, y_text, ylim) in enumerate(panels, start=1):
        # Only the leftmost panel carries the axis labels.
        _plot_panel(position, title, x, y_text, labels, markers,
                    ylim=ylim, axis_labels=(position == 1))

    # Output ----------------------------
    # The legend is attached to the last (d=64) axes. The negative
    # bbox_to_anchor offsets it left and below, presumably so that it sits
    # under the whole row of panels.
    plt.legend(labels=labels, loc='lower center',
               bbox_to_anchor=(-1.3, -0.28), fancybox=False, shadow=False, ncol=4)

    plt.savefig(output_path, transparent=True, bbox_inches='tight', dpi=256)
    if show:
        plt.show()

# Only plot when run as a script, so importing this module (e.g. to reuse
# extract_columns) does not open or save a figure.
if __name__ == "__main__":
    main()
