"""
This file contains many utility functions for performing the tests in the TestExecutor
class. Please see that class for further documentation and usage.
"""
import gc
import warnings

import torch as t
from scipy.stats import binomtest, wilcoxon

############################################################
# TODO: FUNCTIONS TO REMOVE AFTER REFACTORING
############################################################


def tail_test(
    base_distribution,
    target,
    quantile=0.1,
    alternative="less",
    direction="target>base",
    return_empirical_quantile=False,
):
    """
    Perform a binomial tail test of ``target`` against ``base_distribution``.

    Counts how many samples of the base distribution ``target`` beats (in the
    sense given by ``direction``) and runs a binomial test of that count
    against the hypothesised success probability ``quantile``.

    :param base_distribution: samples of the base distribution; must support
        ``len`` and element-wise comparison with ``target`` (e.g. a torch
        tensor or numpy array)
    :param target: the target score compared against every base sample
    :param quantile: the quantile of the base distribution we compare against,
        used as the success probability ``p`` of the binomial test
    :param alternative: the alternative hypothesis forwarded to
        ``scipy.stats.binomtest``
    :param direction: "target>base" counts base samples strictly below
        ``target``; "target<base" counts base samples strictly above it
    :param return_empirical_quantile: if True, also return the fraction of
        base samples that were counted
    :return: the p-value, or ``(p_value, empirical_fraction)`` when
        ``return_empirical_quantile`` is True (for torch inputs the fraction
        is a 0-dim tensor)
    :raises ValueError: if ``base_distribution`` is empty or ``direction`` is
        not one of the supported values
    """
    n = len(base_distribution)
    if n == 0:
        raise ValueError("The base distribution is empty")

    if direction == "target<base":
        t_statistics = (target < base_distribution).sum()
    elif direction == "target>base":
        t_statistics = (target > base_distribution).sum()
    else:
        # Without this branch an unknown direction left ``t_statistics``
        # unbound and the call died with an UnboundLocalError.
        raise ValueError(
            f"direction must be 'target>base' or 'target<base', got {direction!r}"
        )

    # binomtest expects an integer ``k``; int() normalises torch/numpy scalars.
    result = binomtest(int(t_statistics), n=n, p=quantile, alternative=alternative)

    if return_empirical_quantile:
        return result.pvalue, t_statistics / n
    return result.pvalue


def wilcoxon_test(target, base, alternative="two-sided"):
    """
    Perform a paired Wilcoxon signed-rank test between ``target`` and ``base``.

    Having a circuit that's better than 95 % of the random circuits is the same
    as having a loss that's smaller than 5 % of the random circuits.

    :param target: the target, which typically is the original model output
        (torch tensor, numpy array or sequence of numbers)
    :param base: the base distribution, which typically is the candidate
        circuit output; must have the same length as ``target``
    :param alternative: the alternative hypothesis forwarded to
        ``scipy.stats.wilcoxon``
    :return: the p-value, or ``1`` when ``target`` and ``base`` are identical
    :raises ValueError: if ``target`` and ``base`` have different lengths
    """
    if len(target) != len(base):
        raise ValueError("The target and base distribution have different length")
    # Convert to numpy on the host: ``.numpy()`` raises for CUDA tensors, so
    # move to the CPU first. ``as_tensor`` also accepts numpy arrays and lists.
    target = t.as_tensor(target).detach().cpu().numpy()
    base = t.as_tensor(base).detach().cpu().numpy()
    # Identical inputs carry no evidence of a difference; short-circuit
    # instead of handing scipy an all-zero difference vector.
    if abs(target - base).mean() == 0:
        return 1

    result = wilcoxon(target, base, alternative=alternative)
    return result.pvalue


def find_redundant_edges(inflated_circuit, candidate_circuit):
    """
    Return the edges of ``inflated_circuit`` that do not appear in
    ``candidate_circuit``.

    Duplicate edges are collapsed, and the order of the returned list is
    unspecified because it follows set iteration order.
    """
    inflated_edges = set(inflated_circuit)
    return list(inflated_edges - set(candidate_circuit))


## these are for the permutation test
def gaussian_kernel(X, Y, sigma):
    """
    Return the Gaussian (RBF) kernel matrix between the rows of X and Y.

    Entry (i, j) is ``exp(-||x_i - y_j||^2 / (2 * sigma**2))``. All leading
    dimensions are flattened into rows, with the last dimension taken as the
    feature dimension.
    """
    rows = X.view(-1, 1, X.size(-1))
    cols = Y.view(1, -1, Y.size(-1))
    gamma = 1 / (2 * sigma**2)
    squared_distance = ((rows - cols) ** 2).sum(dim=2)
    return (-gamma * squared_distance).exp()


def hsic(X, Y, sigma):
    """
    Compute the empirical HSIC statistic between paired samples X and Y.

    Uses Gaussian kernels with bandwidth ``sigma`` for both X and Y and returns
    ``trace(H K H H L H) / (n - 1)**2``, where ``H = I - 1/n`` is the centering
    matrix and ``n`` is the number of rows (samples).

    :param X: samples, one per row (first dimension is the sample axis)
    :param Y: paired samples with the same number of rows as X
    :param sigma: kernel bandwidth passed to ``gaussian_kernel``
    :return: 0-dim tensor holding the HSIC statistic
    """
    n = X.size(0)

    K = gaussian_kernel(X, X, sigma)
    L = gaussian_kernel(Y, Y, sigma)

    # Build H directly with the kernel's dtype and on X's device. Previously
    # it was a default-dtype (float32) CPU tensor, which made ``t.mm`` fail
    # for float64 inputs and forced a host-to-device copy on every call.
    H = t.eye(n, dtype=K.dtype, device=X.device) - 1.0 / n
    K_centered = t.mm(t.mm(H, K), H)
    L_centered = t.mm(t.mm(H, L), H)

    hsic_statistic = t.trace(t.mm(K_centered, L_centered)) / (n - 1) ** 2
    return hsic_statistic


def permutation_test(X, Y, num_permutations=1000):
    """
    HSIC permutation test for dependence between paired samples X and Y.

    The observed HSIC statistic is compared against HSIC values computed after
    randomly permuting the rows of Y, which breaks any pairing between X and Y.

    :param X: tensor of shape (n,) or (n, d)
    :param Y: tensor with exactly the same shape as X
    :param num_permutations: number of random permutations to simulate
    :return: dict with keys
        "hsic" (float, observed statistic),
        "p_value" (float, fraction of permuted statistics >= observed),
        "simulated_statistics" (tensor of the permuted statistics)
    :raises RuntimeError: if X and Y differ in shape or have more than 2 dims
    :raises ValueError: if ``num_permutations`` is smaller than 1
    """
    if X.shape != Y.shape:
        raise RuntimeError("X and Y should have the same shape")
    if num_permutations < 1:
        # Zero permutations previously gave a NaN p-value (mean of an empty
        # tensor), and a negative count failed inside t.zeros.
        raise ValueError("num_permutations must be at least 1")

    if X.dim() == 1:
        # Treat each scalar as a 1-feature sample.
        X = X.view(X.shape[0], -1)
        Y = Y.view(Y.shape[0], -1)
        warnings.warn(
            "X is not a 2D tensor, converting X and Y to 2D tensor", UserWarning
        )
    if X.dim() > 2:
        raise RuntimeError(
            f"permutation_test only supports 1D or 2D tensors, got {X.dim()}D inputs"
        )
    # Median heuristic: the median X-Y pairwise distance is used as the
    # bandwidth for both kernels.
    # NOTE(review): if every pairwise distance is zero, sigma is 0 and the
    # kernel produces NaN entries; consider a fallback bandwidth.
    sigma = t.cdist(X, Y, p=2).median()
    n = X.size(0)
    hsic_observed = hsic(X, Y, sigma)
    # Store the permuted statistics with the observed statistic's dtype. This
    # was previously always float32, which silently truncated float64 results.
    hsic_permutations = t.zeros(
        num_permutations, dtype=hsic_observed.dtype, device=X.device
    )

    for i in range(num_permutations):
        Y_permuted = Y[t.randperm(n)]
        hsic_permutations[i] = hsic(X, Y_permuted, sigma)

    p_value = (hsic_permutations >= hsic_observed).float().mean()

    return {
        "hsic": hsic_observed.item(),
        "p_value": p_value.item(),
        "simulated_statistics": hsic_permutations,
    }


def free_memory(tensors_to_delete=None, clear_cache=True, run_gc=True):
    """
    Release tensor references and reclaim Python / CUDA memory.

    :param tensors_to_delete: optional container of tensors to release. If it
        supports ``clear()`` (e.g. list, dict, set), it is emptied so that it
        no longer holds references. Any other references held by the caller
        must be dropped separately for the memory to actually be freed.
    :param clear_cache: if True, call ``torch.cuda.empty_cache()``
    :param run_gc: if True, run ``gc.collect()``
    """
    # The original check was inverted (``is None``), so the default call
    # iterated over None and raised TypeError. Deleting a loop variable also
    # freed nothing; clearing the container actually drops its references.
    if tensors_to_delete is not None and hasattr(tensors_to_delete, "clear"):
        tensors_to_delete.clear()
    # Collect garbage first. Tensors freed by the collector then return their
    # blocks to the CUDA caching allocator before the cache is emptied.
    if run_gc:
        gc.collect()
    if clear_cache:
        t.cuda.empty_cache()
