import torch
import os
from .util import *
import scipy.stats as stats
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import sys
import itertools
from datetime import datetime
import argparse

# Dataset name: used as the data sub-directory and as the first field of every
# result key assembled later in this script.
dataset_root = "CIFAR10"

# All locations live under one scratch root on the cluster filesystem.
# NOTE(review): in_dir is not read anywhere in this script -- confirm it is
# still needed.
_scratch_root = "/network/scratch/name/"
in_dir = _scratch_root + "mutual_info/v2/"
data_dir = os.path.join(_scratch_root + "datasets/", dataset_root)
out_dir = in_dir + "out_dir_backup"

batch_sizes = [64, 128, 1024]
# NOTE: these floats are turned into text with str() when result-file keys are
# built later in this script. str(1e-5) gives "1e-05", while the other two give
# plain decimals ("0.0001", "0.001"). Keep the values exactly as written, or the
# keys will no longer match the files on disk.
decays = [1e-5, 0.0001, 0.001]
model_names = ["PreResNet56", "PreResNet83", "PreResNet110"]

# Full Cartesian grid of (batch_size, weight_decay, model_name) tuples.
# model_name varies fastest and batch_size slowest, which is the same order
# that three nested for-loops produce.
settings = list(itertools.product(batch_sizes, decays, model_names))

# Report the size of the hyper-parameter grid before scanning it.
num_settings = len(settings)
print(f"num settings: {num_settings}")


# For every grid setting, print the stored result if its file exists in
# out_dir; otherwise report that the setting is being skipped.
for batch_size, wd, model_name in settings:
    setting = (batch_size, wd, model_name)
    # Key layout: <dataset>_0_<batch_size>_<wd>_<model_name>. wd is rendered
    # with str(), so 1e-5 appears as "1e-05".
    # NOTE(review): the literal "0" field is unexplained here -- confirm what
    # it encodes (e.g. a seed or run index) before relying on it.
    theta_D_key = "%s_0_%s_%s_%s" % (dataset_root, batch_size, wd, model_name)
    theta_D_p = os.path.join(out_dir, "theta_D_key_%s.pt" % theta_D_key)
    if os.path.exists(theta_D_p):
        # map_location="cpu": the file may contain tensors saved from a GPU.
        # This script only prints them, so loading onto the CPU avoids a
        # failure on machines without CUDA.
        MI_theta_D = torch.load(theta_D_p, map_location="cpu")
        print(setting)
        print(MI_theta_D)
    else:
        print("Skipping %s" % str(setting))
    print("---")