import numpy as np
import matplotlib.pyplot as plt
import bmi as bmi # import Benchmarking Mutual Information from https://github.com/cbg-ethz/bmi/tree/main

# Load the samples for the 4-bit teacher model. Unpacking splits the loaded
# object along its first axis, which must therefore have length 2: the first
# entry becomes X and the second becomes Y.
X, Y = np.load('teacher_2pow14.npy')

# Number of samples, taken as the size of X's first axis.
nsamps = X.shape[0]


# Neighborhood sizes (k) handed to the KSG ensemble estimator, and how many of
# them there are (this sizes the last axis of the results array).
#
# No estimator instance is built here: the partition loop further down
# constructs a fresh KSGEnsembleFirstEstimator from `karr` for every partition,
# so an instance created at this point would be overwritten before ever being
# fitted.
karr = [1, 5, 10, 15, 20]
numNeighbors = len(karr)

# Calculate the KSG estimate on data that has been split into `gamma`
# contiguous, equally sized, non-overlapping partitions, for every
# gamma = 1..maxGamma.
#
# Results are stored in an array indexed by
#   [gamma - 1, reps (index of the partition), index of the neighbor count in karr].
# Only the first `gamma` partition slots of row `gamma - 1` are written; the
# remaining entries keep their NaN fill value.
maxGamma = 10  # largest number of partitions; also sizes the first two axes
store = np.full((maxGamma, maxGamma, numNeighbors), np.nan)
for gamma in range(1, maxGamma + 1):
    # Partition size for this gamma. When gamma does not divide nsamps, the
    # trailing nsamps % gamma samples are not part of any partition.
    batch = nsamps // gamma
    for reps in range(gamma):  # calculate KSG for each partition
        start, stop = reps * batch, (reps + 1) * batch
        # A new estimator per partition, so nothing fitted on a previous
        # partition is carried over.
        estimator = bmi.estimators.KSGEnsembleFirstEstimator(neighborhoods=karr)
        estimator.fit(X[start:stop], Y[start:stop])
        # Values of the predictions mapping, in the mapping's own order
        # (presumably one entry per k in karr, in karr order -- confirm
        # against the bmi implementation).
        predictions = np.asarray(list(estimator.get_predictions().values()))
        # Divide by ln(2): converts the estimates to bits, assuming the
        # estimator reports them in nats.
        store[gamma - 1, reps, :] = predictions / np.log(2)


np.save('mi4_16k_rawdata.npy', store)
