import json
import numpy as np
import matplotlib.pyplot as plt


def get_baseline(null_returns, random_returns):
    """Return the larger of the mean null return and the mean random return.

    Args:
        null_returns: Array-like of returns collected under the "null" run.
        random_returns: Array-like of returns collected under the "random" run.

    Returns:
        ``max(np.mean(null_returns), np.mean(random_returns))`` as a NumPy
        floating scalar.
    """
    mean_null = np.mean(null_returns)
    mean_random = np.mean(random_returns)
    # NOTE(review): built-in max() keeps its first argument unless the second
    # compares greater, so a NaN mean_null propagates while a NaN mean_random
    # is ignored. Confirm this asymmetry is acceptable (np.maximum would
    # propagate NaN from either side).
    return max(mean_null, mean_random)


def _read_jsonl(path):
    """Load a JSON Lines file into a list of parsed records.

    Blank lines (e.g. a stray empty line at the end of the file) are skipped
    instead of being passed to ``json.loads``, which would raise on them.
    """
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


null_data = _read_jsonl("eval_null.jsonl")
random_data = _read_jsonl("eval_random.jsonl")

# Records from the two files are paired by position below. zip() would silently
# drop the surplus records of the longer file, so refuse to continue instead.
if len(null_data) != len(random_data):
    raise ValueError(
        f"eval_null.jsonl has {len(null_data)} records but "
        f"eval_random.jsonl has {len(random_data)}"
    )

# Domain name -> list of baselines, one get_baseline() value per row of that
# domain's "instance_returns".
mins = {}

for n_d, r_d in zip(null_data, random_data):
    domain = r_d["domain"]
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if n_d["domain"] != domain:
        raise ValueError(
            f"domain mismatch at the same position: "
            f"{n_d['domain']!r} (null) vs {domain!r} (random)"
        )

    null_returns = np.array(n_d["instance_returns"])
    random_returns = np.array(r_d["instance_returns"])

    # Rows are compared pairwise, so both arrays need a leading row axis of the
    # same length. Previously, extra random rows were silently ignored and
    # missing ones raised an opaque IndexError.
    if null_returns.ndim < 2 or random_returns.ndim < 2:
        raise ValueError(f"{domain}: instance_returns must be at least 2-D")
    if len(null_returns) != len(random_returns):
        raise ValueError(
            f"{domain}: {len(null_returns)} null rows vs "
            f"{len(random_returns)} random rows"
        )

    # Iterating a >=2-D array yields the same row slices as arr[i, :].
    # float() keeps NumPy scalar types out of the JSON output.
    baselines = [
        float(get_baseline(null_row, random_row))
        for null_row, random_row in zip(null_returns, random_returns)
    ]

    mins[domain] = baselines

    # Disabled diagnostic: histogram of the first 10 random-return rows with a
    # fitted normal curve, one PNG per row. Uncomment to inspect distributions.
    # for i in range(10):
    #     plt.hist(
    #         random_returns[i, :],
    #         bins=50,
    #         alpha=0.5,
    #         label=f"Random {r_d['domain']}",
    #         density=True,
    #     )
    #     # plot normal distribution
    #     mu, std = np.mean(random_returns[i, :]), np.std(random_returns[i, :])
    #     xmin, xmax = plt.xlim()
    #     x = np.linspace(xmin, xmax, 100)
    #     p = np.exp(-0.5 * ((x - mu) / std) ** 2) / (std * np.sqrt(2 * np.pi))
    #     plt.plot(x, p, "k", linewidth=2, label="Random Normal Fit")

    #     plt.savefig(f"random_{r_d['domain']}_{i}.png")
    #     plt.clf()

print(json.dumps(mins))
