import os
import torch

# Print train/test statistics for the accepted models.

# from main_swag

# Directory holding the per-model ``results_<index>.pt`` files read below.
out_dir = "/network/scratch/name/mutual_info/v6"
print(out_dir)

# Hyperparameter grid. The nested loop that walks this grid assigns each
# combination its result-file index, so the order of these lists matters.
data_instances = [0, 1, 2]

archs = [
    [2, 256, 256, 128, 128, 5],
    [2, 128, 128, 64, 64, 5],
    [2, 64, 64, 32, 32, 5],
    [2, 32, 32, 16, 16, 5],
]

decays = [0.0, 1e-2, 1e-1]

lamb_lrs = [0.0, 5e-3, 1e-3]

seeds = list(range(3))

# Metrics collected for models whose training accuracy exceeds ``thresh``.
train_accs, test_accs = [], []
train_losses, test_losses = [], []

# Minimum training accuracy (strictly greater) for a model to be kept.
thresh = 0.85
# Counters: models loaded so far, and how many of those were rejected.
considered = skipped = 0

# Walk the grid in nested order. Every grid point advances model_ind (the
# result-file index), but only runs with an accepted lamb_lr are loaded.
# NOTE(review): this order presumably mirrors how main_swag numbered its
# runs -- confirm before changing any loop or list order.
model_ind = 0
for lamb_lr in lamb_lrs:
    for arch in archs:
        for decay in decays:

            for data_instance in data_instances:
                for seed in seeds:

                    # Only lamb_lr of 0 or 1e-3 is analysed; runs with any
                    # other lamb_lr are skipped but still consume an index.
                    if lamb_lr in (0.0, 1e-3):

                        # map_location="cpu" so results saved from a GPU job
                        # can still be loaded on a CPU-only node.
                        r = torch.load(
                            os.path.join(out_dir, "results_%d.pt" % model_ind),
                            map_location="cpu",
                        )

                        if r["train_acc"] > thresh:
                            train_accs.append(r["train_acc"])
                            test_accs.append(r["test_acc"])

                            train_losses.append(r["train_loss"])
                            test_losses.append(r["test_loss"])
                        else:
                            skipped += 1

                        considered += 1

                    model_ind += 1

# Turn the accepted per-model metrics into tensors for the summary statistics.
train_accs = torch.tensor(train_accs)
test_accs = torch.tensor(test_accs)
train_losses = torch.tensor(train_losses)
test_losses = torch.tensor(test_losses)

print("All:")
print(model_ind)

print("Considered:")
print(considered)

print("Skipped:")
print(skipped)

print("Resulting:")
print(train_accs.shape)


def _print_stats_row(label, values):
    """Print one LaTeX table row ``label & max & min & mean & std \\``.

    ``values`` is the tensor of accepted per-model metrics built above.
    If no model passed the threshold the tensor is empty: ``Tensor.max()``
    would raise on it, so ``n/a`` cells are printed instead. Note that
    ``std()`` is unbiased, so a single accepted model yields ``nan``.
    """
    if values.numel() == 0:
        print("%s & n/a & n/a & n/a & n/a \\\\" % label)
        return
    print("%s & %.4f & %.4f & %.4f & %.4f \\\\" % (label, values.max(), values.min(), values.mean(), values.std()))


_print_stats_row("Train loss", train_losses)
_print_stats_row("Train accuracy", train_accs)

_print_stats_row("Test loss", test_losses)
_print_stats_row("Test accuracy", test_accs)
