import os
import pickle
import torch
import numpy as np
import pandas as pd
import argparse
import matplotlib.pyplot as plt
from collections import defaultdict
from tqdm.auto import tqdm
import scipy.stats
from sklearn.metrics import pairwise_distances

def read_entropy_files(dirpath):
    """Collect entropy values from every ``entropy_*`` file in *dirpath*.

    Each matching file is expected to begin with one line of the form
    ``entropy: [e1, e2, ...]``; only that first line is read. Values from
    all files are concatenated into a single flat list. Files are visited in
    sorted filename order so the result is deterministic (``os.listdir``
    returns entries in arbitrary order).

    Args:
        dirpath: Directory to scan (non-recursive).

    Returns:
        A list of floats. A file whose list is empty (``entropy: []``)
        contributes nothing instead of crashing.

    Raises:
        ValueError: if a file's first line has no ``:`` separator, or holds a
            value that cannot be parsed as a float.
    """
    entropies = []
    filenames = sorted(
        name for name in os.listdir(dirpath) if name.startswith("entropy_")
    )
    for filename in filenames:
        path = os.path.join(dirpath, filename)
        with open(path, "r", encoding="utf-8") as fh:
            line = fh.readline()
        _, sep, payload = line.partition(":")
        if not sep:
            raise ValueError(f"{path}: expected 'entropy: [...]', got {line!r}")
        # Drop surrounding whitespace/newline, then the enclosing brackets.
        payload = payload.strip().strip("[]")
        # Skip blank tokens so "[]" (or a trailing comma) doesn't hit float("").
        entropies.extend(float(tok) for tok in payload.split(",") if tok.strip())
    return entropies




# Load every per-file entropy value and print summary statistics.
entropy_dir = "metrics/entropy_results/"
entropies = read_entropy_files(entropy_dir)
print(len(entropies))
if not entropies:
    # Without this guard the fraction below divides by zero and np.mean
    # returns nan with a RuntimeWarning.
    raise SystemExit(f"No entropy values found in {entropy_dir}")

entropy_arr = np.asarray(entropies)
# Fixed threshold: report the share of entropies strictly above it.
# NOTE: the value printed is a fraction in [0, 1], not a percentage.
thresh = 0.125
print("Percent of entropies above threshold: ", np.mean(entropy_arr > thresh))
# The threshold that 95% of entropies lie above is the 5th percentile.
print("Threshold that 95% of entropies are above: ", np.percentile(entropy_arr, 5))
print("Mean entropy: ", np.mean(entropies))

# Plot a 3x2-inch histogram of all entropy values and save it as a PDF.
# Set the font family before any figure/axes/text exists so all text uses it.
plt.rcParams["font.family"] = "serif"
plt.figure(figsize=(3, 2))
plt.hist(entropies, bins=25, color="green", edgecolor="black", linewidth=0.5)
plt.xlabel("Entropy", fontsize=10)
plt.ylabel("Frequency", fontsize=10)
# Tick labels at 0.0, 0.2, ..., 1.0.
# NOTE(review): assumes entropies are normalized to [0, 1]; values above 1
# would still be plotted but without tick labels past 1.0 — confirm.
plt.xticks(np.arange(0, 1.1, 0.2), fontsize=8)
plt.yticks(fontsize=8)
plt.tight_layout()
figure_path = "metrics/figures/entropy.pdf"
# savefig does not create missing parent directories.
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
plt.savefig(figure_path, bbox_inches="tight", pad_inches=0.0)