

from os import listdir
from os.path import isfile, join
import os
import torch

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Directory holding one torch-saved result file per run; element [0] of each
# loaded object is that run's MI curve.
DEFAULT_PATH = "/network/scratch/name/mutual_info/v2/out_dir"
# Where the mean +/- std plot is written (relative to the working directory).
DEFAULT_OUT = "MIs.png"


def load_mis(path):
    """Load element ``[0]`` of every regular file in *path* and stack them.

    Each file is read with ``torch.load`` and its first element is taken as one
    run's MI tensor. The tensors are stacked along a new leading dimension, so
    dim 0 of the result indexes runs. All tensors must share one shape, as
    ``torch.stack`` requires.

    Files are processed in sorted name order so the stacked tensor (and
    anything printed from it) is reproducible. ``os.listdir`` order is
    arbitrary.

    Args:
        path: directory containing the saved result files.

    Returns:
        Tensor of shape ``(num_files, *per_run_shape)`` on the CPU.

    Raises:
        ValueError: if *path* contains no regular files (subdirectories are
            ignored).
    """
    names = sorted(name for name in listdir(path) if isfile(join(path, name)))
    if not names:
        # Without this check torch.stack fails with an unhelpful
        # "expects a non-empty TensorList" error.
        raise ValueError(f"no result files found in {path!r}")

    curves = []
    for name in names:
        # NOTE(review): torch.load unpickles the file -- only point this at
        # trusted output directories.
        # map_location="cpu" lets results saved from a GPU run be loaded (and
        # plotted) on a CPU-only machine.
        res = torch.load(join(path, name), map_location="cpu")
        curves.append(res[0])
    return torch.stack(curves, dim=0)


def mean_std(mis):
    """Return the mean and standard deviation of *mis* across runs (dim 0).

    The std is torch's default Bessel-corrected estimate. With a single run
    that estimate is undefined (NaN), which would silently erase the shaded
    band in the plot. In that case a zero std is returned instead.

    Args:
        mis: tensor whose dim 0 indexes runs.

    Returns:
        ``(avg, std)``: tensors shaped like ``mis[0]``, detached from any
        autograd graph.
    """
    # Detach so the results can be converted to NumPy for plotting even if
    # the saved tensors carried requires_grad=True.
    mis = mis.detach()
    avg = mis.mean(dim=0)
    if mis.shape[0] > 1:
        std = mis.std(dim=0)
    else:
        std = torch.zeros_like(avg)
    return avg, std


def plot_mean_std(avg, std, out_file=DEFAULT_OUT):
    """Plot *avg* as a line with a shaded ``avg +/- std`` band and save it.

    Args:
        avg: 1-D tensor of per-position means; the x axis is its index.
        std: tensor broadcastable against *avg* giving the band half-width.
        out_file: destination image path.
    """
    fig, ax = plt.subplots(1)
    xs = range(avg.shape[0])
    ax.plot(xs, avg.cpu().numpy())
    ax.fill_between(
        xs,
        (avg - std).cpu().numpy(),
        (avg + std).cpu().numpy(),
        alpha=0.2,
    )
    fig.tight_layout()
    fig.savefig(out_file, bbox_inches="tight")
    # Release the figure; this script runs headless (Agg backend).
    plt.close(fig)


def main(argv=None):
    """Load all runs, print them, and save the mean +/- std plot.

    With no command-line arguments this reads ``DEFAULT_PATH`` and writes
    ``DEFAULT_OUT``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Plot mean +/- std of saved MI curves across runs."
    )
    parser.add_argument(
        "path",
        nargs="?",
        default=DEFAULT_PATH,
        help="directory of torch-saved result files",
    )
    parser.add_argument(
        "-o", "--out", default=DEFAULT_OUT, help="output image file"
    )
    args = parser.parse_args(argv)

    mis = load_mis(args.path)
    print(mis.shape)
    print(mis)

    avg, std = mean_std(mis)
    plot_mean_std(avg, std, args.out)


if __name__ == "__main__":
    main()


