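# Compute the mean entropy of the "pi_sampling" log-probabilities in
# <file_path>/data/result.json, append it to <file_path>/entropy.json, and plot
# the accumulated history to <file_path>/entropy.png.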
import os
import json
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt

import argparse
parser = argparse.ArgumentParser(
    description="Append the mean entropy of data/result.json to entropy.json and plot the history."
)
parser.add_argument(
    '--file_path',
    type=str,
    default="",
    help='directory containing data/result.json; entropy.json and entropy.png are written here',
)
args = parser.parse_args()

# Input records: read below line by line, one JSON value per line (JSON Lines),
# despite the .json extension.
file_name = os.path.join(args.file_path, "data", "result.json")

def read_jsonl_file(filepath):
    """Parse a JSON Lines file into a list of decoded values.

    Whitespace-only lines are skipped. A line that fails to decode is reported
    on stdout and skipped, so one corrupt record does not abort the whole run.

    Args:
        filepath: Path of the file to read.

    Returns:
        list: The decoded JSON value of every valid line, in file order.
    """
    data = []
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
    with open(filepath, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():  # skip empty lines
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"Error decoding line: {line}\n{e}")
    return data

# Read the records (one JSON object per line) and convert each record's
# "pi_sampling" entries to a float vector of log-probabilities. Rows are kept as
# separate arrays rather than stacked into one 2-D array, so records whose
# vectors differ in length do not make np.array fail.
data = read_jsonl_file(file_name)
all_logp = [np.array([float(x) for x in entry["pi_sampling"]]) for entry in data]

# get entropy: per record H = -sum(p * logp) with p = exp(logp), then the mean
# over records.
per_record_entropy = []
for logp in all_logp:
    p = np.exp(logp)
    # logp == -inf means p == 0; by the convention 0 * log 0 = 0 those terms
    # contribute nothing (computing p * logp directly would give nan). Any
    # genuine nan in the input is still propagated.
    keep = ~np.isneginf(logp)
    per_record_entropy.append(-np.sum(p[keep] * logp[keep]))

if per_record_entropy:
    entropy = np.mean(per_record_entropy)
else:
    # No records: still produce nan (as np.mean of nothing would), but say why.
    print(f"Warning: no records found in {file_name}; entropy is NaN")
    entropy = float("nan")

# save entropy: append this run's value to the history kept in entropy.json,
# creating the file on the first run.
entropy_file = os.path.join(args.file_path, "entropy.json")
try:
    with open(entropy_file, "r", encoding="utf-8") as f:
        data = json.load(f)
except FileNotFoundError:
    data = {}
# Tolerate an existing history file that lacks the "entropy" list.
data.setdefault("entropy", [])

# float() turns a numpy scalar into a plain Python float for json.
data["entropy"].append(float(entropy))

# Write to a temporary file and atomically swap it in, so an interruption
# mid-write cannot truncate the accumulated history.
tmp_file = entropy_file + ".tmp"
with open(tmp_file, "w", encoding="utf-8") as f:
    json.dump(data, f, indent=4)
os.replace(tmp_file, entropy_file)
    
# plot entropy: the full history from entropy.json, one point per run of this
# script, with the x axis being the 0-based index into that history.
# NOTE(review): the "Epoch" label assumes the script runs once per epoch — confirm.
entropy = data["entropy"]
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Entropy")
ax.set_xlabel("Epoch")
ax.set_ylabel("Entropy")
# marker="o" keeps the plot readable when the history has a single entry
# (a lone point has no line segment to draw).
ax.plot(entropy, marker="o", label="Entropy")
ax.legend()
fig.savefig(os.path.join(args.file_path, "entropy.png"))
# Release the figure; pyplot keeps figures alive until they are closed.
plt.close(fig)
    

        

