import torch
from infomet import against_uniform, get_measure_fn
import numpy as np

# Softmax temperatures at which every score is evaluated (grouped by integer part).
temperatures = [
    0.1, 0.25, 0.5,
    1, 1.25, 1.5,
    2, 2.25, 2.5,
    3, 3.25, 3.5,
    4, 4.25, 4.5,
    5,
]

# Renyi orders: 0.0 .. 1.9 in steps of 0.1 (rounded to 3 decimals to clean up
# float noise), skipping exactly 1 -- presumably because the Renyi divergence
# formula is undefined there -- substituting 0.05 for 0, plus one large order 5.
alphas = [
    0.05 if a == 0 else round(a, 3)
    for a in np.arange(0, 2, 0.1)
    if a != 1
]
alphas.append(5)


def mk_score_function():
    """Build a closure that scores raw logits at every temperature.

    The returned ``compute_scores(logits)`` softmaxes ``logits / t`` over the
    last axis for each ``t`` in the module-level ``temperatures`` list and
    returns a dict whose keys are ``"<measure>-<t>"`` (and
    ``"renyi-<t>-<alpha>"`` for each ``alpha`` in ``alphas``).

    Measures per temperature:

    * ``doctor``     -- sum of squared probabilities,
    * ``kl``         -- infomet "kl" measure wrapped by ``against_uniform``,
    * ``fisher_rao`` -- infomet "fisher_rao" measure wrapped by ``against_uniform``,
    * ``energy``     -- ``-t * logsumexp(logits / t)``,
    * ``entropy``    -- Shannon entropy of the softmax,
    * ``msp``        -- maximum softmax probability (squeezed),
    * ``renyi``      -- infomet "renyi" measure (``against_uniform``) per alpha.

    Every value is detached from the autograd graph and moved to the CPU.
    """
    fisher_rao = against_uniform(get_measure_fn("fisher_rao"))
    renyi = against_uniform(get_measure_fn("renyi"))
    # alpha_div = against_uniform(get_measure_fn("alpha"))
    kl = against_uniform(get_measure_fn("kl"))

    def compute_scores(logits):
        scores = {}
        for t in temperatures:
            scaled = logits / t
            probas = scaled.softmax(-1)

            # Reductions below drop the class axis (-1), the axis softmax used.
            scores[f"doctor-{t}"] = (probas**2).sum(-1).detach().cpu()
            scores[f"kl-{t}"] = kl(probas).detach().cpu()
            # Detach + move to CPU like every other score, so no device tensor
            # or autograd graph escapes through this entry.
            scores[f"fisher_rao-{t}"] = fisher_rao(probas).detach().cpu()
            # logsumexp equals log(sum(exp(.))) but does not overflow to -inf
            # for large logits / small temperatures.
            scores[f"energy-{t}"] = (
                (-t * torch.logsumexp(scaled, dim=-1)).detach().cpu()
            )
            scores[f"entropy-{t}"] = (
                (-torch.xlogy(probas, probas).sum(-1)).detach().cpu()
            )

            # Max over the class axis; squeeze() drops remaining singleton
            # axes (kept as before for caller compatibility).
            msp, _ = torch.max(probas, dim=-1)
            scores[f"msp-{t}"] = msp.squeeze().detach().cpu()

            for a in alphas:
                scores[f"renyi-{t}-{a}"] = renyi(probas, alpha=a).detach().cpu()
                # scores[f"alpha-{t}-{a}"] = alpha_div(probas, alpha=a).detach().cpu()

        return scores

    return compute_scores


def mk_score_function_nostep():
    """Build a closure that scores already-computed probabilities.

    The returned ``compute_scores(probas, t)`` applies no temperature scaling:
    ``probas`` is used as given and ``t`` only labels the dict keys
    (``"<measure>-<t>"`` and ``"renyi-<t>-<alpha>"`` for each ``alpha`` in the
    module-level ``alphas``). Energy is not produced because it needs raw
    logits. Every value is detached from the autograd graph and moved to CPU.
    """
    fisher_rao = against_uniform(get_measure_fn("fisher_rao"))
    renyi = against_uniform(get_measure_fn("renyi"))
    # alpha_div = against_uniform(get_measure_fn("alpha"))
    kl = against_uniform(get_measure_fn("kl"))

    def compute_scores(probas, t):
        scores = {}

        # Sum of squared probabilities over the last (class) axis.
        scores[f"doctor-{t}"] = (probas**2).sum(-1).detach().cpu()
        scores[f"kl-{t}"] = kl(probas).detach().cpu()
        # Detach + move to CPU like every other score, so no device tensor or
        # autograd graph escapes through this entry.
        scores[f"fisher_rao-{t}"] = fisher_rao(probas).detach().cpu()
        scores[f"entropy-{t}"] = (-torch.xlogy(probas, probas).sum(-1)).detach().cpu()

        # Maximum probability over the class axis; squeeze() drops singleton axes.
        v, _ = torch.max(probas, dim=-1)
        scores[f"msp-{t}"] = v.squeeze().detach().cpu()

        for a in alphas:
            scores[f"renyi-{t}-{a}"] = renyi(probas, alpha=a).detach().cpu()

        return scores

    return compute_scores


def compute_mahalanobis(hidden, ref_mahalanobis):
    """Return the Mahalanobis distance of *hidden* from a reference Gaussian.

    Args:
        hidden: feature vector(s). Singleton axes of ``hidden - mean`` are
            squeezed away, so a ``(d,)`` or ``(1, d)`` input yields a 0-d
            tensor, and a ``(batch, d)`` input yields one distance per row.
        ref_mahalanobis: ``(mean, cov)`` pair; ``mean`` must broadcast against
            *hidden* and ``cov`` is a ``(d, d)`` invertible matrix.

    Returns:
        ``sqrt(delta^T cov^{-1} delta)`` with ``delta = hidden - mean``.

    Raises:
        RuntimeError (torch linalg error) if ``cov`` is singular.
    """
    mean, cov = ref_mahalanobis
    delta = (hidden - mean).squeeze()
    # Solve cov @ x = delta rather than forming cov^{-1} explicitly: better
    # conditioned, and (unlike torch.dot) it also handles a batch of rows.
    solved = torch.linalg.solve(cov, delta.unsqueeze(-1)).squeeze(-1)
    return torch.sqrt((delta * solved).sum(-1))


def compute_projection(prob, set, div):
    """Return the smallest value of ``div(prob, set)``.

    *div* is called once with the scored distribution and the whole candidate
    collection; presumably it returns one divergence per candidate (TODO
    confirm against callers). The global minimum is returned via ``torch.min``
    as a 0-d tensor.

    Note: the parameter name ``set`` shadows the builtin but is kept so that
    keyword callers keep working.
    """
    return torch.min(div(prob, set))


def mk_iproj_score_function(
    sizes=(10, 20, 50, 100, 250, 500, 750, 1000, 1500, 2000),
):
    """Build a closure computing information-projection (min-divergence) scores.

    The returned ``iproj(set_dist, probs, sets)`` writes into ``set_dist``
    (mutating it in place) and returns it. For every prefix ``size`` in
    *sizes*, ``compute_projection`` takes the minimum divergence between
    ``probs`` and ``sets[:size]`` for each measure, stored as a Python value
    (via ``.tolist()``) under:

    * ``"iproj-kl-<size>"``
    * ``"iproj-frao-<size>"`` (Fisher-Rao)
    * ``"iproj-renyi-<alpha>-<size>"`` for each ``alpha`` in module ``alphas``

    Each measure is called as ``measure(ref_dist=<candidates>, hypo_dist=probs)``.

    Args:
        sizes: prefix lengths of ``sets`` to evaluate. Slicing silently
            truncates, so sizes larger than ``len(sets)`` repeat the
            full-set value under their own key.
    """
    fisher_rao = get_measure_fn("fisher_rao")
    renyi = get_measure_fn("renyi")
    # alpha_div = get_measure_fn("alpha")
    kl = get_measure_fn("kl")

    def iproj(set_dist, probs, sets):
        # Min-divergence of `probs` to `subset` under `measure`, as a CPU Python value.
        def project(subset, measure, **kwargs):
            value = compute_projection(
                probs,
                subset,
                lambda x, y: measure(ref_dist=y, hypo_dist=x, **kwargs),
            )
            return value.detach().cpu().tolist()

        for size in sizes:
            subset = sets[:size]
            set_dist[f"iproj-kl-{size}"] = project(subset, kl)
            set_dist[f"iproj-frao-{size}"] = project(subset, fisher_rao)

            for a in alphas:
                set_dist[f"iproj-renyi-{a}-{size}"] = project(subset, renyi, alpha=a)
                # set_dist[f"iproj-alpha-{a}-{size}"] = project(subset, alpha_div, alpha=a)

        return set_dist

    return iproj
