import math
import os
import pickle

import numpy as np
from pomegranate import *
from tqdm import *

from EDGE_4_3_1 import EDGE

# Run configuration: every path below is derived from one run identifier.
_run_name = "resnet18_gaussianDrp_0.4"
ip_data = "IP/" + _run_name
repr_data = "representations/" + _run_name  # holds the *_representations / *_labels .npy files

# Dropout rate of the run and the derived Gaussian-dropout noise scale p / (1 - p).
p = 0.4
drp_noise = p / (1 - p)

# Small positive constant added to |activation| so variances and log arguments never hit zero.
stab_e = 1e-6

def repr_entropy(nonoise_reprs, reprs, n_components=1000, chunk_size=10000, n_jobs=100):
    """Estimate the differential entropy H(Z) of the noisy representations.

    An equally weighted Gaussian mixture is built with one component per clean
    representation; only the first ``n_components`` rows of ``nonoise_reprs``
    are used. Each component is centred on its clean vector and has diagonal
    covariance ``drp_noise * (|x| + stab_e)``. H(Z) is then estimated as the
    negative mean log-likelihood of every row of ``reprs`` under that mixture.

    Args:
        nonoise_reprs: clean representations, one sample per row.
        reprs: noisy samples to score, one sample per row.
        n_components: number of clean rows used as mixture components.
        chunk_size: number of rows passed to each ``log_probability`` call.
        n_jobs: worker count forwarded to pomegranate's ``log_probability``.

    Returns:
        The Monte-Carlo entropy estimate (natural log units).

    Raises:
        ValueError: if no mixture components or no samples to score are given.
    """
    # NOTE(review): the samples in this script are drawn as x * (1 + drp_noise * N(0, 1)),
    # whose per-coordinate variance is drp_noise**2 * x**2, whereas the component
    # covariance here is drp_noise * |x|. Confirm which noise model is intended.
    components = [
        MultivariateGaussianDistribution(center, drp_noise * np.diag(np.abs(center) + stab_e))
        for center in tqdm(nonoise_reprs[:n_components])
    ]
    if not components:
        raise ValueError("nonoise_reprs must contain at least one representation")
    if len(reprs) == 0:
        raise ValueError("reprs must contain at least one sample")

    weights = np.full(len(components), 1.0 / len(components))
    gmm_z = GeneralMixtureModel(components, weights=weights)

    # Step by chunk_size up to len(reprs) so the final partial chunk is scored too.
    # Integer division previously dropped it, or raised ZeroDivisionError when
    # len(reprs) < chunk_size.
    log_probs = [
        np.atleast_1d(gmm_z.log_probability(reprs[start:start + chunk_size], n_jobs=n_jobs))
        for start in tqdm(range(0, len(reprs), chunk_size))
    ]
    return -np.concatenate(log_probs).mean()

def gaussian_noise_mi(reprs, nonoise_reprs, noise_variance):
    """Estimate I(X; Z) = H(Z) - H(Z | X) for multiplicatively noised representations.

    H(Z) is obtained from ``repr_entropy``. For H(Z | X), each coordinate of Z
    given its clean value x is treated as a 1-D Gaussian. Its standard deviation
    is ``noise_variance * (|x| + stab_e)``, so its entropy is
    ``log(sqrt(2*pi*e) * noise_variance * (|x| + stab_e))``. These entropies are
    summed over coordinates and averaged over the rows of ``nonoise_reprs``.

    Note: despite its name, ``noise_variance`` enters the formula as a standard
    deviation, not a variance.

    Args:
        reprs: noisy representations, one sample per row.
        nonoise_reprs: clean representations, one sample per row.
        noise_variance: scale of the multiplicative noise (see note above).

    Returns:
        The mutual-information estimate (natural log units).
    """
    h_z = repr_entropy(nonoise_reprs, reprs)
    # Entropy of N(mu, sigma^2) is log(sqrt(2*pi*e) * sigma); this is the constant part.
    gauss_const = math.sqrt(2 * math.pi * math.e) * noise_variance
    # Per-coordinate conditional entropies, summed over features and averaged over samples.
    per_coord = np.log(gauss_const * (np.abs(nonoise_reprs) + stab_e))
    h_z_given_x = per_coord.sum() / len(nonoise_reprs)
    return h_z - h_z_given_x

def _epoch_from_filename(fname):
    """Return the epoch encoded as the last ``_``-separated token before the first ``.``."""
    return int(fname.split(".")[0].split("_")[-1])


def _noisy_copies(clean, n_copies):
    """Return ``n_copies`` multiplicatively noised versions of every row of ``clean``.

    Each coordinate is scaled by ``1 + drp_noise * N(0, 1)``. Output row ``k``
    comes from input row ``k // n_copies``, so copies of one row stay adjacent.
    The noise is drawn as one ``(rows * n_copies, dim)`` block. In row-major
    order this consumes the same random stream as drawing one ``randn(dim)``
    vector per output row.
    """
    epsilon = np.random.randn(len(clean) * n_copies, clean.shape[1]) * drp_noise + 1
    return np.repeat(clean, n_copies, axis=0) * epsilon


# Results keyed by epoch: I(X;Z) ("mi_xz") and I(Z;Y) ("mi_zy") for each split.
test_comp_mi_xz = {}
test_comp_mi_zy = {}
train_comp_mi_xz = {}
train_comp_mi_zy = {}

test_labels = np.load(os.path.join(repr_data, "test_labels.npy"))
train_labels = np.load(os.path.join(repr_data, "train_labels.npy"))

# Sorted so files are processed (and reported) in a stable order.
for f in sorted(os.listdir(repr_data)):
    if "test_representations" in f:
        print(f)
        epoch = _epoch_from_filename(f)
        nonoise_reprs = np.load(os.path.join(repr_data, f), allow_pickle=True)
        # Deduplicate rows; `ind` maps each unique row back to its first label.
        nonoise_reprs, ind = np.unique(nonoise_reprs, axis=0, return_index=True)
        reprs = _noisy_copies(nonoise_reprs, 5)
        test_comp_mi_xz[epoch] = gaussian_noise_mi(reprs, nonoise_reprs, drp_noise)
        # Repeat labels the same way the rows were repeated so they stay aligned.
        test_comp_mi_zy[epoch] = EDGE(reprs, np.repeat(np.asarray(test_labels[ind]), 5),
                                      U=1000, gamma=[0.7, 0.01])

    if "train_representations" in f:
        print(f)
        epoch = _epoch_from_filename(f)
        nonoise_reprs = np.load(os.path.join(repr_data, f), allow_pickle=True)
        # One noisy copy per row, so reprs stays aligned with train_labels.
        reprs = _noisy_copies(nonoise_reprs, 1)
        train_comp_mi_xz[epoch] = gaussian_noise_mi(reprs, nonoise_reprs, drp_noise)
        train_comp_mi_zy[epoch] = EDGE(reprs, train_labels, U=1000, gamma=[0.7, 0.01])

# Persist each result dict; `with` guarantees the file is flushed and closed.
for out_name, results in (
    ("test_comp_mi_xz", test_comp_mi_xz),
    ("test_comp_mi_zy", test_comp_mi_zy),
    ("train_comp_mi_xz", train_comp_mi_xz),
    ("train_comp_mi_zy", train_comp_mi_zy),
):
    with open(os.path.join(repr_data, out_name), "wb") as out_file:
        pickle.dump(results, out_file)
