from os import listdir
from collections import defaultdict
from ast import literal_eval
import torch
import numpy as np
import matplotlib.pyplot as plt

from labellines import labelLine, labelLines

# Size-wise accuracies, grouped as results[dataset][model][size] -> list of
# accuracy values (one entry per parsed line).
results = {}

for f in listdir('longrange_results'):
    if not f.startswith('size_wise_acc'):
        continue
    print(f)
    with open('longrange_results/' + f, "r") as res:
        for line in res:
            if not line.strip():
                # A blank (e.g. trailing) line has no fields; unpacking it
                # below would raise ValueError.
                continue
            # Expected fields: model,dataset,seed,size,accuracy,nodes.
            # seed and nodes are parsed but not stored.
            model, dataset, seed, size, accuracy, nodes = line.split(',')
            size, accuracy, nodes = int(size), float(accuracy), int(nodes)
            results.setdefault(dataset, {}).setdefault(model, {}).setdefault(size, []).append(accuracy)

# Order in which models are plotted / printed. model_order2 lists the same
# models with 'ampatt' placed before 'ampact' and 'ampiter'.
model_order = ['neg', 'universal', 'itergnn', 'amprnn', 'ampgru', 'amplstm', 'ampact', 'ampiter', 'ampatt']
model_order2 = ['neg', 'universal', 'itergnn', 'amprnn', 'ampgru', 'amplstm', 'ampatt', 'ampact', 'ampiter']

if not results:
    # Without this check size_order would stay undefined and the script
    # would fail later with an unrelated-looking NameError.
    raise SystemExit("no 'size_wise_acc*' records found in longrange_results/")

# The size axis is taken from the first dataset that was loaded, using the
# first model in model_order; it is reused for every table printed below.
dataset = next(iter(results))
size_order = sorted(results[dataset][model_order[0]])

print(model_order)
print(size_order)

# Print one LaTeX row per model for the 'oddeven' dataset: each cell is
# \makebox{<mean>\rpm<std>} over the accuracies recorded for that size.
# Nothing is printed when 'oddeven' was not loaded.
if "oddeven" in results:
    for model in model_order:
        per_size = results["oddeven"][model]
        # str(...)[:4] truncates (does not round) to four characters.
        # NOTE(review): values that print in scientific notation (e.g. 1e-05)
        # would be cut to '1e-0' - confirm this cannot occur for these data.
        accs = [str(np.mean(per_size[size]))[:4] for size in size_order]
        stds = [str(np.std(per_size[size]))[:4] for size in size_order]
        # '\\makebox' spells the backslash explicitly; the previous '\m' was
        # an invalid escape sequence that only worked by accident.
        accs = ['\\makebox{{{}\\rpm{}}}'.format(acc, std) for acc, std in zip(accs, stds)]
        print(model, "oddeven")
        print(' & '.join(accs))

"""##############################################################################"""

# Display name used in plot labels and table rows for each model key.
str_for_model = {
    "neg": "NEG",
    "universal": "Universal",
    "itergnn": "IterGNN",
    "amprnn": "AMP-RNN",
    "ampgru": "AMP-GRU",
    "amplstm": "AMP-LSTM",
    "ampact": "AMP-ACT",
    "ampiter": "AMP-Iter",
    "ampatt": "AMP-ATT",
}

# Distance-wise accuracies, grouped as results[dataset][model][distance] ->
# list of accuracy values (one entry per parsed line).
results = {}

for f in listdir('longrange_results'):
    if not f.startswith('distance_wise_acc'):
        continue
    with open('longrange_results/' + f, "r") as res:
        for line in res:
            if not line.strip():
                # A blank (e.g. trailing) line has no fields; unpacking it
                # below would raise ValueError.
                continue
            # Expected fields: model,dataset,seed,distance,accuracy,nodes.
            # seed and nodes are parsed but not stored.
            model, dataset, seed, distance, accuracy, nodes = line.split(',')
            distance, accuracy, nodes = int(distance), float(accuracy), int(nodes)
            results.setdefault(dataset, {}).setdefault(model, {}).setdefault(distance, []).append(accuracy)

# Underreaching figure: one line per model, accuracy averaged over pairs of
# consecutive distances (d-1, d) for the 'oddeven' dataset.
dataset = 'oddeven'
plt.figure(figsize=(9, 4.5))
for m in model_order:
    per_distance = results[dataset][m]
    # Every second distance key (in dict insertion order), limited to < 22.
    # xs stays in module scope: the table printed afterwards reads it.
    xs = [d for d in list(per_distance)[1::2] if d < 22]
    pair_sums = [np.array(per_distance[d - 1]) + np.array(per_distance[d]) for d in xs]
    # Only needed for the optional shaded band below.
    stds = np.array([np.std(pair_sum) for pair_sum in pair_sums])
    # Halving the mean of the element-wise sum gives the pair average.
    ys = np.array([np.mean(pair_sum) / 2 for pair_sum in pair_sums])
    plt.plot(xs, ys, label=str_for_model[m])
    # Optional band: plt.fill_between(xs, ys - stds, ys + stds, alpha=0.3)
# Inline line labels instead of a legend; xvals has one entry per plotted line.
labelLines(
    plt.gca().get_lines(),
    align=True,
    xvals=[2, 4.5, 6.3, 2., 6.5, 10, 16, 19, 19],
    fontsize=14,
)
plt.savefig('underreaching.pdf')
plt.clf()

print("###################UNDERREACHING########")

# xs and dataset are the values left in module scope by the plotting code
# that ran just before this (xs is the list from the last model plotted).
print(xs)
for d in xs:
    numbers = []
    for m in model_order2:
        # Element-wise sum of the accuracies at distances d-1 and d.
        pair_sum = np.array(results[dataset][m][d - 1]) + np.array(results[dataset][m][d])
        # NOTE(review): the std is taken of the sum divided by 4 while the
        # mean uses the sum divided by 2 - confirm the extra factor of 2 is
        # intended. Behaviour is left unchanged here.
        std = np.std(pair_sum / 4)
        mean = np.mean(pair_sum / 2)
        # '\\makebox' spells the backslash explicitly; the previous '\m' was
        # an invalid escape sequence. str(...)[:4] truncates, it does not round.
        numbers.append('\\makebox{{{}\\rpm{}}}'.format(str(mean)[:4], str(std)[:4]))
    print("{}-{} & ".format(d - 1, d) + " & ".join(numbers) + "\\\\")



"""##############################################################################"""

# Training-range accuracies, grouped as results[dataset][model][size] ->
# list of accuracy values (one entry per parsed line).
results = {}

for f in listdir('longrange_results'):
    if not f.startswith('training_range_accs'):
        continue
    with open('longrange_results/' + f, "r") as res:
        for line in res:
            if not line.strip():
                # A blank (e.g. trailing) line has no fields; unpacking it
                # below would raise ValueError.
                continue
            # These files have five fields: model,dataset,seed,size,accuracy.
            # There is no nodes column, so nothing named nodes is converted
            # here (doing so would reuse a stale value from an earlier loop).
            model, dataset, seed, size, accuracy = line.split(',')
            size, accuracy = int(size), float(accuracy)
            results.setdefault(dataset, {}).setdefault(model, {}).setdefault(size, []).append(accuracy)

# Oversmoothing figure: mean accuracy per size for every model on 'oddeven'.
dataset = 'oddeven'
plt.figure(figsize=(9, 4.5))
for model_key in model_order:
    per_size = results[dataset][model_key]
    sizes = list(per_size)  # dict insertion order
    # Sizes are drawn at evenly spaced positions and shown as tick labels.
    positions = range(len(sizes))
    means = np.array([np.mean(per_size[size]) for size in sizes])
    # Only needed for the optional shaded band below.
    spreads = np.array([np.std(per_size[size]) for size in sizes])
    plt.plot(positions, means, label=str_for_model[model_key])
    # Optional band: plt.fill_between(positions, means - spreads, means + spreads, alpha=0.3)
    plt.xticks(positions, sizes)
# Inline line labels instead of a legend; xvals has one entry per plotted line.
labelLines(
    plt.gca().get_lines(),
    align=True,
    xvals=[1.6, 3, 6.2, 2.5, 4.5, 6.2, 4.5, 6, 0.3],
    fontsize=14,
)
plt.savefig('oversmoothing.pdf')
plt.clf()

print("###################OVERSMOOTHING########")

# Print one LaTeX row per model for the 'oddeven' dataset: each cell is
# \makebox{<mean>\rpm<std>} over the accuracies recorded for that size.
# Nothing is printed when 'oddeven' is not among the loaded datasets.
# NOTE(review): size_order was derived from the size-wise results loaded at
# the top of the script; this assumes the same sizes are keys here - confirm.
if "oddeven" in results:
    for model in model_order2:
        per_size = results["oddeven"][model]
        print(size_order)
        # str(...)[:4] truncates (does not round) to four characters.
        accs = [str(np.mean(per_size[size]))[:4] for size in size_order]
        stds = [str(np.std(per_size[size]))[:4] for size in size_order]
        # '\\makebox' spells the backslash explicitly; the previous '\m' was
        # an invalid escape sequence that only worked by accident.
        accs = ['\\makebox{{{}\\rpm{}}}'.format(acc, std) for acc, std in zip(accs, stds)]
        print(model, "oddeven")
        print(' & '.join(accs))


"""##############################################################################"""

# Size-wise accuracies again, grouped as results[dataset][model][size] ->
# list of accuracy values (one entry per parsed line).
results = {}

for f in listdir('longrange_results'):
    if not f.startswith('size_wise_acc'):
        continue
    with open('longrange_results/' + f, "r") as res:
        for line in res:
            if not line.strip():
                # A blank (e.g. trailing) line has no fields; unpacking it
                # below would raise ValueError.
                continue
            # Expected fields: model,dataset,seed,size,accuracy,nodes.
            # seed and nodes are parsed but not stored.
            model, dataset, seed, size, accuracy, nodes = line.split(',')
            size, accuracy, nodes = int(size), float(accuracy), int(nodes)
            results.setdefault(dataset, {}).setdefault(model, {}).setdefault(size, []).append(accuracy)



# First oversquashing figure: mean accuracy per size for every model on
# the 'oddeven' dataset, models drawn in model_order.
dataset = 'oddeven'
plt.figure(figsize=(9, 4.5))
for model_key in model_order:
    per_size = results[dataset][model_key]
    sizes = list(per_size)  # dict insertion order
    # Sizes are drawn at evenly spaced positions and shown as tick labels.
    positions = range(len(sizes))
    means = np.array([np.mean(per_size[size]) for size in sizes])
    # Only needed for the optional shaded band below.
    spreads = np.array([np.std(per_size[size]) for size in sizes])
    plt.plot(positions, means, label=str_for_model[model_key])
    # Optional band: plt.fill_between(positions, means - spreads, means + spreads, alpha=0.3)
    plt.xticks(positions, sizes)
# Inline line labels instead of a legend; xvals has one entry per plotted line.
labelLines(
    plt.gca().get_lines(),
    align=True,
    xvals=[0.4, 0.8, 6.2, 1.2, 5, 2.2, 5, 6, 0.3],
    fontsize=14,
)
plt.savefig('oversquashing1.pdf')
plt.clf()

# Second oversquashing figure: mean accuracy per size for every model on
# the 'multioddeven' dataset, models drawn in model_order2.
dataset = 'multioddeven'
plt.figure(figsize=(9, 4.5))
for model_key in model_order2:
    per_size = results[dataset][model_key]
    sizes = list(per_size)  # dict insertion order
    # Sizes are drawn at evenly spaced positions and shown as tick labels.
    positions = range(len(sizes))
    means = np.array([np.mean(per_size[size]) for size in sizes])
    # Only needed for the optional shaded band below.
    spreads = np.array([np.std(per_size[size]) for size in sizes])
    plt.plot(positions, means, label=str_for_model[model_key])
    # Optional band: plt.fill_between(positions, means - spreads, means + spreads, alpha=0.3)
    plt.xticks(positions, sizes)
# Inline line labels instead of a legend; xvals has one entry per plotted line.
labelLines(
    plt.gca().get_lines(),
    align=True,
    xvals=[0.4, 1.2, 0.3, 3.5, 2, 0.6, 5, 6, 1.5],
    fontsize=14,
)
plt.savefig('oversquashing2.pdf')
plt.clf()

print("###########OVERSQUASHING#########")


def _latex_cells(per_size):
    """Return the ' & '-joined \\makebox{mean\\rpm std} cells for size_order.

    per_size maps a size to its list of accuracies. Mean and std are
    truncated (not rounded) to the first four characters of their string form.
    """
    cells = []
    for size in size_order:
        mean = str(np.mean(per_size[size]))[:4]
        std = str(np.std(per_size[size]))[:4]
        # '\\makebox' spells the backslash explicitly; the previous '\m' was
        # an invalid escape sequence that only worked by accident.
        cells.append('\\makebox{{{}\\rpm{}}}'.format(mean, std))
    return ' & '.join(cells)


# Two LaTeX rows per model: the 'oddeven' row opens a two-row \multirow cell
# with the model's display name, the 'multioddeven' row closes with \midrule.
for model in model_order2:
    dataset = "oddeven"
    print("{\\multirow{2}{*}{" + str_for_model[model] + "}} & " + _latex_cells(results[dataset][model]) + "\\\\")
    dataset = "multioddeven"
    print(" & " + _latex_cells(results[dataset][model]) + "\\\\\\midrule")