"""RRS Notes:
- The more samples, the better the estimate. If the distributions really are close,
  the AUC stays roughly constant as the number of samples increases, and hence
  the overlap is estimated as high. Conversely, if the distributions are not close,
  the AUC climbs to 1.0 quickly and thus the overlap is estimated as low.

  We stop as soon as we achieve an AUC of 1.0; otherwise we would be overcounting,
  which makes the distributions seem closer than they really are.
"""
import math
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributions as dists
from sklearn.metrics import roc_auc_score, auc


def compute_AURA(num_samples: list[int], AUCs: list[float]) -> float:
    """Return the normalised area under the overlap upper-bound curve.

    For every (sample count ``n``, AUC) pair an upper bound on the overlap is
    computed as ``max(0, 1 - sqrt(log(2 / (1 - AUC + 1e-10)) / n))``. The
    bounds are integrated over the sample counts with ``auc`` and the area is
    divided by ``max(num_samples)``.

    Only the points up to and including the first AUC equal to 1 are used;
    later points are dropped.

    Args:
        num_samples: Sample counts at which the AUCs were measured.
        AUCs: AUC measured at the corresponding entry of ``num_samples``.

    Returns:
        The normalised area, or ``0.0`` when fewer than two points remain
        after truncation or when ``auc`` rejects the points with a
        ``ValueError``.

    Raises:
        ZeroDivisionError: If a used sample count is 0.
        ValueError: If a used sample count is negative (square root of a
            negative number).
    """
    try:
        # Keep the first perfectly separated point and drop everything after it.
        last_index = AUCs.index(1) + 1
    except ValueError:
        # No AUC reached 1: use every point.
        last_index = len(AUCs)

    counts = num_samples[:last_index]
    # The 1e-10 term keeps the logarithm finite when the AUC is exactly 1.
    upper_bound_overlap = [
        max(1 - math.sqrt(math.log(2 / (1 - metric + 1e-10)) / nsamples), 0)
        for nsamples, metric in zip(counts, AUCs[:last_index])
    ]

    # A curve needs at least two points to enclose any area.
    if len(upper_bound_overlap) < 2:
        return 0.0

    try:
        area = auc(counts, upper_bound_overlap)
    except ValueError:
        # Best effort: points that ``auc`` cannot integrate score as no overlap.
        return 0.0

    return area / max(num_samples)

def do_plot(base, comp, savename):
    """Save overlaid histograms of two samples to ``./aura/gaussian/<savename>``.

    Args:
        base: Values drawn as the histogram labelled "Base".
        comp: Values drawn as the histogram labelled "Comparison".
        savename: File name of the image to write inside ``./aura/gaussian/``.
    """
    import os

    out_path = f"./aura/gaussian/{savename}"
    # Make sure the output directory exists before saving into it.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    fig = plt.figure()
    try:
        plt.hist(base, label="Base", bins=50, alpha=0.5)
        plt.hist(comp, label="Comparison", bins=50, alpha=0.5)
        # Without a legend the labels passed to hist() are never drawn.
        plt.legend()
        plt.savefig(out_path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

def _chunk_means(values: list, size: int) -> list:
    """Return the mean of each consecutive ``size``-long chunk of ``values``."""
    return [
        np.mean(values[start:start + size])
        for start in range(0, len(values), size)
    ]


def main():
    """Print the AURA score of a standard normal against shifted normals.

    Samples are drawn from N(0, 1) and from unit-scale normals centred at each
    value in ``locs``. For every chunk size in ``num_samples`` the samples are
    averaged chunk by chunk, and the ROC AUC separating the base chunk means
    (label 0) from the comparison chunk means (label 1) is recorded. The AUCs
    are then turned into one AURA value per comparison, printed next to its
    location.

    Returns:
        0, which the caller uses as the process exit status.
    """
    base_dist = dists.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([1.0]))
    locs = [0.1, 0.2, 0.4, 0.8, 10.]
    comp_dists = [
        dists.Normal(loc=torch.tensor([loc]), scale=torch.tensor([1.0]))
        for loc in locs
    ]

    N = 12_000
    num_samples = [1, 10, 25, 50, 100, 200]
    # A longer schedule that can be used instead:
    # num_samples = [1, 10, 25, 50, 100, 200, 300, 400, 500]

    # Draw the base samples first and the comparisons afterwards, so that a
    # seeded generator is consumed in a fixed order.
    base_values = base_dist.sample((N,)).tolist()
    comp_values = [dist.sample((N,)).tolist() for dist in comp_dists]

    # The base chunk means depend only on the chunk size, so compute them once
    # instead of once per comparison distribution.
    base_means = [_chunk_means(base_values, nsamples) for nsamples in num_samples]

    for loc, values in zip(locs, comp_values):
        AUCs = []
        for nsamples, b in zip(num_samples, base_means):
            c = _chunk_means(values, nsamples)
            # For visual debugging, do_plot(b, c, <file name>) can be called here.
            metric = roc_auc_score([0] * len(b) + [1] * len(c), b + c)
            AUCs.append(metric)

        AURA = compute_AURA(num_samples, AUCs)
        print(loc, round(AURA, 2))

    return 0

if __name__ == "__main__":
    # Seed torch's RNG before any sampling so repeated runs draw the same
    # samples.
    torch.manual_seed(43)
    exit_status = main()
    sys.exit(exit_status)
