
import json
from math import log, sqrt

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_auc_score

def estimate_tv(
    num_samples: int,
    AUC: float,
) -> float:
    """Estimate a total-variation (TV) value from an AUC at a sample count.

    Evaluates ``sqrt(log(2 / (1 - AUC)) / num_samples)``, which is the
    algebraic inverse of ``AUC = 1 - 2 * exp(-num_samples * TV**2)``, and
    clamps the result to ``[0, 1]``.

    Args:
        num_samples: Number of samples the AUC was measured with.
        AUC: Measured area under the ROC curve.

    Returns:
        The clamped estimate. ``1.0`` is returned when ``AUC >= 1``, because
        the expression diverges there and its clamped value is the cap.

    Raises:
        ZeroDivisionError: If ``num_samples`` is 0 (and ``AUC < 1``).
        ValueError: If the quantity under the square root is negative, e.g.
            a negative ``num_samples`` or ``AUC < -1``.
    """
    if AUC >= 1:
        # 1 - AUC is zero (or negative) here, so 2 / (1 - AUC) would divide
        # by zero or feed log() a negative number; the limit of the clamped
        # expression as AUC -> 1 is the upper cap.
        return 1.0
    total_variation_lb = sqrt(log(2 / (1 - AUC)) / num_samples)
    return max(0.0, min(total_variation_lb, 1.0))

# TV values for which a theoretical AUC-vs-sample-count curve is drawn.
deltas = [0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.3, 0.4, 0.45, 0.5]

# Sample counts 1..250. Zero is excluded: at n = 0 the curve below evaluates
# to 1 - 2 = -1 for every TV value.
num_samples = np.arange(1, 250 + 1)

_ = plt.figure(dpi=300)
ax = plt.subplot(111)

for delta in deltas:
    # One curve per TV value: AUC = 1 - 2 * exp(-n * TV^2).
    # The curve is deliberately not clipped at 0.5, so it can fall below
    # chance level (and below zero) for small n.
    AUC_upper_bound = 1 - 2 * np.exp(-num_samples * delta**2)
    ax.plot(num_samples, AUC_upper_bound, label=f"TV={delta:.2f}")

# Axis labels do not depend on the curve, so set them once after the loop.
ax.set_xlabel("Number of Samples")
ax.set_ylabel("AUC")

# Measured results to overlay on the curves already drawn on `ax`.
# The file is read as a JSON list of records with "num_samples" and "AUC" keys.
fname = "./tv/MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3_preference_2_average_style=1.0_content=1.0_N=100.json"
with open(fname, "r", encoding="utf-8") as fin:
    data = json.load(fin)

# Drop records whose sample count is falsy (0 / missing value) and take BOTH
# the sample count and the AUC from the surviving records, so each AUC stays
# paired with the sample count it was measured at.
records = [d for d in data if d["num_samples"]]
num_samples = [d["num_samples"] for d in records]
AUC = [d["AUC"] for d in records]
TVs = [estimate_tv(n, auc) for n, auc in zip(num_samples, AUC)]
print(list(zip(num_samples, TVs)))

# The reported estimate comes from the last record (in file order) whose AUC
# is at most 0.99. If every record is above that threshold there is nothing
# to report, so fall back to None instead of leaving the name unbound.
best_estimate = next(
    (
        estimate_tv(n, auc)
        for n, auc in zip(reversed(num_samples), reversed(AUC))
        if auc <= 0.99
    ),
    None,
)
if best_estimate is None:
    print("Best Estimate: n/a (every AUC exceeds 0.99)")
    title = "Estimated TV: n/a"
else:
    print("Best Estimate: {:.4f}".format(best_estimate))
    title = "Estimated TV: {:.2f}".format(best_estimate)

print([d["AUC"] for d in data])
print(AUC)

ax.plot(num_samples, AUC, label="StyleParaphrase", color="black", linestyle="--")
# Start the y-axis at the lowest measured AUC (over all records in the file).
plt.ylim([min(d["AUC"] for d in data), 1.0])

plt.title(title)
plt.legend()
plt.savefig("chakraborty.png")
plt.close()