import sys

sys.path.append(".")

from src.tools.sharpness_tools import sam_flatness, shannon_entropy, low_pass, frob_norm, fishr, calc_eigenvalues, \
    entropy, entropy_grad


def measure_sharpness(metric, model, data_loader):
    """Compute one sharpness/flatness measure of ``model`` over ``data_loader``.

    Dispatches on ``metric`` to the matching helper imported from
    ``src.tools.sharpness_tools``, using fixed hyperparameters for each one.
    Only the selected helper is invoked.

    Args:
        metric: Name of the measure. One of ``"sam"``, ``"shannon"``,
            ``"low_pass"``, ``"frob"``, ``"fishr"``, ``"eig_avg"``,
            ``"max_eig"``, ``"entropy"``, ``"entropy_grad"``.
        model: Passed unchanged to the selected helper as ``model=``.
        data_loader: Passed unchanged to the selected helper as ``data_loader=``.

    Returns:
        The value returned by the selected helper. For ``"eig_avg"`` it is
        ``.sum()`` of the first element returned by ``calc_eigenvalues``. For
        ``"max_eig"`` it is the last entry of that first element.

    Raises:
        NotImplementedError: If ``metric`` is not one of the supported names.
            The message names the rejected value and lists the valid ones.
    """

    def _eigenvalues():
        # calc_eigenvalues returns a pair; only the first element is used.
        eigs, _ = calc_eigenvalues(model=model, data_loader=data_loader, max_itr=10, draws=5)
        return eigs

    # Each entry is a zero-argument callable. Nothing is computed, and no
    # helper name is looked up, until the requested metric is selected.
    measures = {
        "sam": lambda: sam_flatness(model=model, data_loader=data_loader, epsilon=0.05, tol=1e-6),
        "shannon": lambda: shannon_entropy(model=model, data_loader=data_loader),
        "low_pass": lambda: low_pass(model=model, data_loader=data_loader, sigma=0.01, mcmc_itr=100),
        "frob": lambda: frob_norm(model=model, data_loader=data_loader, mcmc_itr=10),
        "fishr": lambda: fishr(model=model, data_loader=data_loader),
        # NOTE(review): despite the "avg" name this returns the SUM of the
        # eigenvalue estimates (unchanged from earlier behavior) — confirm intent.
        "eig_avg": lambda: _eigenvalues().sum(),
        # NOTE(review): taking the last entry as the maximum assumes
        # calc_eigenvalues returns eigenvalues in ascending order — confirm.
        "max_eig": lambda: _eigenvalues()[-1],
        "entropy": lambda: entropy(model=model, data_loader=data_loader, gamma=10, mcmc_itr=100),
        "entropy_grad": lambda: entropy_grad(model=model, data_loader=data_loader),
    }

    try:
        compute = measures[metric]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which previously also fell
        # through to NotImplementedError.
        raise NotImplementedError(
            f"Unsupported sharpness metric {metric!r}; expected one of: {', '.join(measures)}"
        ) from None
    return compute()
