import numpy as np
import torch
from torch import Tensor


# Map test statistics to probabilities by evaluating the empirical CDF of a reference sample.
def p_value_fn(test_statistic: np.ndarray, X: np.ndarray, w=None):
    """Evaluate the empirical CDF of a reference sample at the test statistics.

    Every column of ``X`` is handled as its own reference sample. The sample
    is padded with one artificial point at or below its minimum and one at or
    above its maximum, sorted, and paired with the levels
    ``1/(N+2), 2/(N+2), ..., 1``. The result is the linear interpolation
    (``np.interp``) of that curve at ``test_statistic``, column by column.

    Args:
        test_statistic: values to evaluate, shape (n, m) or (n,). Array-likes
            (lists, tuples) and torch tensors are accepted; tensors are
            detached and moved to the CPU first.
        X: reference sample, shape (N, m) or (N,). Same accepted types.
        w: optional multiplier applied to the ECDF levels. It is reshaped to
            ``(1, -1)`` and broadcast against the ``(N + 2, m)`` level matrix,
            i.e. one factor per column (or a single factor).

    Returns:
        np.ndarray: interpolated ECDF values, shape (n, m).
    """

    def _as_array(value):
        # Tensor.numpy() fails for tensors that require grad or live on an
        # accelerator, so normalise those before handing them to NumPy.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    test_statistic = _as_array(test_statistic)
    X = _as_array(X)
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    if test_statistic.ndim == 1:
        test_statistic = test_statistic.reshape(-1, 1)

    n_ref = len(X)
    col_min = X.min(0)
    col_max = X.max(0)
    # Push the artificial end points outwards: a positive minimum is shrunk
    # towards zero and a non-positive one is scaled away from zero (the
    # opposite for the maximum), so the padded range always covers the sample.
    lower_bound = col_min * np.where(col_min > 0, 1.0 / n_ref, float(n_ref))
    upper_bound = col_max * np.where(col_max > 0, float(n_ref), 1.0 / n_ref)

    padded = np.concatenate(
        (lower_bound.reshape(1, -1), X, upper_bound.reshape(1, -1)), axis=0
    )
    padded = np.sort(padded, axis=0)

    n_points, n_cols = padded.shape
    levels = np.arange(1, n_points + 1).reshape(-1, 1) / n_points
    y_ecdf = np.tile(levels, (1, n_cols))
    if w is not None:
        y_ecdf = y_ecdf * _as_array(w).reshape(1, -1)

    # Interpolate each column against its own sorted reference sample.
    columns = [
        np.interp(stat_col, ref_col, level_col)
        for stat_col, ref_col, level_col in zip(test_statistic.T, padded.T, y_ecdf.T)
    ]
    return np.stack(columns, axis=1)


@torch.no_grad()
def blahut_arimoto(channel: Tensor, max_iter: int = int(1e6), threshold: float = 1e-6, device=torch.device("cpu"), verbose=True):
    """Run the Blahut-Arimoto fixed-point iteration for per-sample detector weights.

    Notation: C = number of classification labels, K = number of detectors,
    N = number of samples.

    Args:
        channel: tensor or ``np.ndarray`` of shape (N, K, C) whose entries sum
            to 1 along the last axis. A 2-D input of shape (N, K) is expanded
            to (N, K, 2) as ``[p, 1 - p]``. NumPy input is converted to
            float32; non-floating tensors are cast to float32.
        max_iter: maximum number of update steps.
        threshold: stop once the relative change of the weights
            (``||w_new - w|| / ||w||``) drops below this value.
        device: device on which the iteration runs.
        verbose: when true, print the last relative change that was measured.

    Returns:
        Tensor: weights of shape (N, K) on the CPU. Each row sums to 1.

    Raises:
        AssertionError: if the channel does not sum to 1 along the last axis.
    """
    if isinstance(channel, np.ndarray):
        channel = torch.from_numpy(channel).float()
    elif not channel.is_floating_point():
        channel = channel.float()

    if channel.ndim == 2:
        # Binary case: the second class is the complement of the given one.
        channel = torch.stack((channel, 1 - channel), dim=2)

    channel = channel.to(device)
    num_samples, num_detectors = channel.shape[0], channel.shape[1]

    # Compare against ones of the same dtype/device as the data; a plain
    # torch.ones(...) breaks torch.allclose for float64 or accelerator input.
    row_sums = channel.sum(dim=2)
    if not torch.allclose(row_sums, torch.ones_like(row_sums)):
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        raise AssertionError("Channel probabilities do not sum to 1")

    # Start from uniform weights, shape N x K x 1.
    weights = torch.ones(num_samples, num_detectors, 1, dtype=channel.dtype, device=channel.device) / num_detectors

    # Defined up front so the report below also works when max_iter == 0.
    tolerance = torch.tensor(float("inf"))
    for _ in range(max_iter):
        # q: weights times channel, normalised over the detectors for each class.
        q = weights * channel
        q = q / q.sum(dim=1, keepdim=True)

        # New weights: product over classes of q ** channel, renormalised over detectors.
        updated = torch.prod(torch.pow(q, channel), dim=2, keepdim=True)
        updated = updated / updated.sum(dim=1, keepdim=True)

        tolerance = torch.linalg.norm(updated - weights) / torch.linalg.norm(weights)
        weights = updated
        if tolerance < threshold:
            break
    if verbose:
        print("Optimization finished with tolerance: ", tolerance.item())
    return weights.squeeze(dim=-1).detach().cpu()


@torch.no_grad()
def benchmark_blahut_arimoto(channel_sizes=(2, 10, 100, 1000), device="cpu"):
    """Time ``blahut_arimoto`` on random channels and plot the timings.

    For every entry of ``channel_sizes`` a random ``1000 x size`` channel is
    drawn with ``torch.rand`` and solved with ``threshold=1e-4``. The elapsed
    wall-clock time is printed, and all timings are plotted against the sizes
    with matplotlib (``plt.show()`` is called).

    Args:
        channel_sizes: iterable of second-dimension sizes to benchmark.
        device: device specification passed to ``torch.device``.

    Returns:
        list[float]: elapsed seconds, one entry per channel size.
    """
    import time

    import matplotlib.pyplot as plt

    # Materialise once: the sizes are iterated here and plotted afterwards.
    sizes = list(channel_sizes)
    target_device = torch.device(device)

    times = []
    for size in sizes:
        channel = torch.rand(1000, size)

        # perf_counter is a monotonic, high-resolution clock intended for
        # measuring intervals; time.time() can jump with system clock changes.
        start_time = time.perf_counter()
        blahut_arimoto(channel, threshold=1e-4, verbose=False, device=target_device)
        elapsed_time = time.perf_counter() - start_time

        print(f"Elapsed time for channel size {size}: {elapsed_time}")
        times.append(elapsed_time)

    plt.plot(sizes, times)
    plt.show()

    return times

if __name__ == "__main__":
    # Script entry point: read the target device from the command line
    # (-d/--device, default "cpu"), echo it, and run the benchmark on it.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("-d", "--device", type=str, default="cpu")
    device_name = cli.parse_args().device
    print("DEVICE:", device_name)
    benchmark_blahut_arimoto(device=device_name)