
# Many functions here are adapted from https://github.com/BlackHC/BatchBALD,
# which is licensed under GNU GPL v3.

import math
import pdb

import torch
from torch import jit
import tqdm

def compute_multi_bald_batch(
    probs_B_K_C,
    acquisition_batch_size,
    device=None,
):
    """Greedily select a batch of candidate indices using BatchBALD scores.

    The first pick maximises the plain BALD score (mutual information).
    Every later pick maximises the joint entropy of "already selected points
    plus one candidate" minus that candidate's conditional entropy. The
    joint entropy is computed exactly while ``num_classes ** i`` class
    configurations fit into the sample budget, and is estimated from sampled
    configurations afterwards.

    Args:
        probs_B_K_C: 3-d tensor of class probabilities indexed as
            (candidate, model sample, class).
        acquisition_batch_size: Number of candidates to select. At most
            ``probs_B_K_C.shape[0]`` indices are returned, because a
            candidate is never selected twice.
        device: Device each chunk of probabilities is moved to before the
            joint-entropy computation.

    Returns:
        List of selected candidate indices (ints), in selection order.
    """
    assert len(probs_B_K_C.shape) == 3

    pool_size, k, num_classes = probs_B_K_C.shape

    # Once every candidate has been picked, all scores are -inf and argmax
    # would hand back index 0 again. Never run more rounds than there are
    # candidates, so the returned bag cannot contain duplicates.
    num_rounds = min(acquisition_batch_size, pool_size)

    # Scores for the first pick: plain BALD (mutual information).
    partial_multi_bald_B = mutual_information(torch.log(probs_B_K_C))

    # Conditional entropy of every candidate (paper eq. 9).
    conditional_entropies_B = batch_conditional_entropy_B(probs_B_K_C)

    with torch.no_grad():
        # Sample budget for the sampled joint entropy, rounded down to a
        # multiple of the number of model samples k.
        num_samples_per_ws = 40000 // k
        num_samples = num_samples_per_ws * k

        # Number of candidates scored per chunk (fixed here rather than
        # derived from the available device memory).
        multi_bald_batch_size = 16

        acquisition_bag = []
        prev_joint_probs_M_K = None

        for i in range(num_rounds):
            if i > 0:
                # Joint entropy of the current bag together with each candidate.
                joint_entropies_B = torch.empty((pool_size,), dtype=torch.float64)

                exact_samples = num_classes ** i
                if exact_samples <= num_samples:
                    # Extend the exact joint distribution by the most
                    # recently selected point only; earlier points are
                    # already folded into prev_joint_probs_M_K.
                    prev_joint_probs_M_K = joint_probs_M_K(
                        probs_B_K_C[acquisition_bag[-1]][None].to(device),
                        prev_joint_probs_M_K=prev_joint_probs_M_K,
                    )

                    batch_exact_joint_entropy(
                        probs_B_K_C, prev_joint_probs_M_K, multi_bald_batch_size, device, joint_entropies_B
                    )
                else:
                    # The exact joint is no longer needed; drop the reference.
                    prev_joint_probs_M_K = None

                    # Gather new traces for the whole current acquisition bag.
                    samples_M_K = sample_M_K(
                        probs_B_K_C[acquisition_bag].to(device), S=num_samples_per_ws
                    )

                    for joint_entropies_b, probs_b_K_C in tqdm.tqdm(
                        split_tensors(joint_entropies_B, probs_B_K_C, multi_bald_batch_size),
                        desc='sample batch entropy'
                    ):
                        joint_entropies_b.copy_(
                            sampled_batch(probs_b_K_C.to(device), samples_M_K), non_blocking=True
                        )

                    # Drop the reference to the samples before the next round.
                    del samples_M_K

                # Paper eq. 8 (well, almost): only the conditional entropy
                # of the current candidate is subtracted. The conditional
                # entropies of the already-selected examples are the same
                # for all candidates, so they do not change the argmax.
                partial_multi_bald_B = joint_entropies_B - conditional_entropies_B

            # Don't allow reselection.
            partial_multi_bald_B[acquisition_bag] = -math.inf

            acquisition_bag.append(partial_multi_bald_B.argmax().item())

    return acquisition_bag


@jit.script
def logit_entropy(logits, dim: int, keepdim: bool = False):
    r"""Entropy $-\sum_i p_i \log p_i$ along `dim`, given log-probabilities.

    Terms whose probability `exp(logits)` is exactly zero (for example a
    logit of `-inf`) contribute 0 instead of `0 * -inf = NaN`, following the
    usual `0 log 0 = 0` convention. NaN inputs still propagate.

    The per-element products are cast to float64 before being summed.
    """
    probs = torch.exp(logits)
    # Mask on `probs == 0` (not `probs > 0`) so that NaN inputs are not hidden.
    p_log_p = torch.where(probs == 0, torch.zeros_like(probs), probs * logits)
    return -torch.sum(p_log_p.double(), dim=dim, keepdim=keepdim)

@jit.script
def logit_mean(logits, dim: int, keepdim: bool = False):
    r"""Log of the mean probability along `dim`, computed in log space.

    Evaluates $\log \left ( \frac{1}{n} \sum_i e^{\log p_i} \right )$ as
    `logsumexp(logits) - log(n)`, where `n` is the size of `dim`.

    The input holds log-probabilities ("logits").
    """
    num_terms = logits.shape[dim]
    log_total = torch.logsumexp(logits, dim=dim, keepdim=keepdim)
    return log_total - math.log(num_terms)

@jit.script
def mutual_information(logits_B_K_C):
    """BALD score per row: entropy of the mean prediction minus mean entropy.

    `logits_B_K_C` holds log-probabilities. Dimension 1 is averaged over and
    the last dimension is the one entropies are taken along.
    """
    # Entropy of the prediction averaged over dimension 1.
    averaged_logits_B_C = logit_mean(logits_B_K_C, dim=1)
    entropy_of_mean_B = logit_entropy(averaged_logits_B_C, dim=-1)

    # Average over dimension 1 of the individual prediction entropies.
    per_sample_entropy_B_K = logit_entropy(logits_B_K_C, dim=-1)
    mean_of_entropy_B = torch.mean(per_sample_entropy_B_K, dim=1)

    return entropy_of_mean_B - mean_of_entropy_B


def batch_exact_joint_entropy(probs_B_K_C, prev_joint_probs_M_K, chunk_size, device, out_joint_entropies_B):
    """Fill `out_joint_entropies_B` with exact joint entropies, chunk by chunk.

    Each chunk of `probs_B_K_C` (at most `chunk_size` rows) is moved to
    `device`, scored with `exact_batch` against `prev_joint_probs_M_K`, and
    the result is copied into the matching slice of the output tensor.
    Nothing is returned; the output tensor is written in place.
    """
    chunk_pairs = split_tensors(out_joint_entropies_B, probs_B_K_C, chunk_size)
    for out_chunk, probs_chunk in tqdm.tqdm(chunk_pairs, desc='batch_exact_joint_entropy'):
        chunk_entropy = exact_batch(probs_chunk.to(device), prev_joint_probs_M_K)
        out_chunk.copy_(chunk_entropy, non_blocking=True)


def batch_exact_joint_entropy_logits(logits_B_K_C, prev_joint_probs_M_K, chunk_size, device, out_joint_entropies_B):
    """Same as `batch_exact_joint_entropy`, but takes log-probabilities.

    Each chunk of `logits_B_K_C` is moved to `device` and exponentiated
    before being scored with `exact_batch`. The result is copied into the
    matching slice of `out_joint_entropies_B`, which is written in place.
    """
    chunk_pairs = split_tensors(out_joint_entropies_B, logits_B_K_C, chunk_size)
    for out_chunk, logits_chunk in tqdm.tqdm(chunk_pairs, desc='batch_exact_joint_entropy_logits'):
        probs_chunk = logits_chunk.to(device).exp()
        out_chunk.copy_(exact_batch(probs_chunk, prev_joint_probs_M_K), non_blocking=True)


def gather_expand(data, dim, index):
    """`torch.gather` that first broadcasts `data` and `index` to each other.

    Every dimension except `dim` is expanded to the larger of the two sizes;
    along `dim` each tensor keeps its own size. `data` and `index` are
    expected to have the same number of dimensions, with sizes that either
    match or are 1 in one of the two tensors.
    """
    broadcast_shape = [max(d_size, i_size) for d_size, i_size in zip(data.shape, index.shape)]

    # Along the gather dimension each tensor keeps its own extent.
    data_target = broadcast_shape.copy()
    data_target[dim] = data.shape[dim]
    index_target = broadcast_shape.copy()
    index_target[dim] = index.shape[dim]

    return torch.gather(data.expand(data_target), dim, index.expand(index_target))


def split_tensors(output, input_tensor, chunk_size):
    """Pair up equally sized chunks of `output` and `input_tensor`.

    Both tensors are split along their first dimension into pieces of at
    most `chunk_size` rows. The result is a list of
    `(output_chunk, input_chunk)` tuples. The chunks come from
    `Tensor.split`, so writing into an output chunk writes into `output`.
    """
    assert len(output) == len(input_tensor)
    output_chunks = output.split(chunk_size)
    input_chunks = input_tensor.split(chunk_size)
    return [(out_chunk, in_chunk) for out_chunk, in_chunk in zip(output_chunks, input_chunks)]


@jit.script
def batch_multi_choices(probs_b_C, M: int):
    """Draw `M` class indices (with replacement) from every distribution.

    probs_b_C: Ni... x C, the last dimension holds the class probabilities.

    Returns:
        choices: Ni... x M, the sampled class indices.
    """
    num_classes = probs_b_C.shape[-1]

    # torch.multinomial wants a 2-d input: flatten all leading dimensions.
    flat_probs = probs_b_C.reshape([-1, num_classes])
    flat_choices = torch.multinomial(flat_probs, num_samples=M, replacement=True)

    # Restore the leading dimensions, with M in place of C.
    result_shape = list(probs_b_C.shape[:-1])
    result_shape.append(M)
    return flat_choices.reshape(result_shape)


### EXACT

@jit.script
def joint_probs_M_K_impl(probs_N_K_C, prev_joint_probs_M_K):
    """Extend a joint distribution over class configurations by N more points.

    For each of the N points in turn, every existing configuration (row of
    `prev_joint_probs_M_K`) is multiplied by each of the C class
    probabilities of that point, separately per column k. The number of rows
    therefore grows by a factor of C per point. Returns the extended joint
    with the configuration dimension first and the K dimension second.
    """
    assert prev_joint_probs_M_K.shape[1] == probs_N_K_C.shape[1]

    num_points, num_ws, num_classes = probs_N_K_C.shape

    # Work in K x M x 1 layout so each point's K x 1 x C block broadcasts.
    running_K_M_1 = prev_joint_probs_M_K.t().unsqueeze(2)

    # Using lots of memory: M is multiplied by C on every iteration.
    for n in range(num_points):
        point_K_1_C = probs_N_K_C[n].unsqueeze(1)
        running_K_M_1 = (running_K_M_1 * point_K_1_C).reshape([num_ws, -1, 1])

    return running_K_M_1.squeeze(2).t()


def joint_probs_M_K(probs_N_K_C, prev_joint_probs_M_K=None):
    """Extend (or start) a joint class-configuration distribution.

    When `prev_joint_probs_M_K` is omitted, the joint starts from a single
    all-ones row of shape (1, K) in float64 on the device of `probs_N_K_C`.
    The probabilities are converted to float64 before being handed to
    `joint_probs_M_K_impl`.
    """
    has_previous = prev_joint_probs_M_K is not None
    if has_previous:
        assert prev_joint_probs_M_K.shape[1] == probs_N_K_C.shape[1]

    _, num_ws, _ = probs_N_K_C.shape
    if not has_previous:
        # Multiplicative identity: one empty configuration per column k.
        prev_joint_probs_M_K = torch.ones(
            (1, num_ws), dtype=torch.float64, device=probs_N_K_C.device
        )

    return joint_probs_M_K_impl(probs_N_K_C.double(), prev_joint_probs_M_K)


@jit.script
def entropy_from_M_K(joint_probs_M_K):
    """Entropy (in nats) of the distribution obtained by averaging over K.

    Each row of `joint_probs_M_K` is averaged over dimension 1 to give the
    probability of that configuration, and the entropy of the resulting
    distribution is returned as a scalar tensor.

    Configurations whose averaged probability is exactly zero contribute 0
    (the `0 log 0 = 0` convention) instead of turning the result into NaN.
    """
    probs_M = torch.mean(joint_probs_M_K, dim=1, keepdim=False)
    nats_M = torch.where(
        probs_M == 0, torch.zeros_like(probs_M), -torch.log(probs_M) * probs_M
    )
    return torch.sum(nats_M)


@jit.script
def entropy_from_probs_b_M_C(probs_b_M_C):
    """Entropy (in nats) per batch row of a joint distribution over (M, C).

    Sums `-p * log(p)` over dimensions 1 and 2 and returns one value per
    entry of dimension 0. Entries that are exactly zero contribute 0 (the
    `0 log 0 = 0` convention) instead of `0 * -inf = NaN`.
    """
    p_log_p = torch.where(
        probs_b_M_C == 0,
        torch.zeros_like(probs_b_M_C),
        probs_b_M_C * torch.log(probs_b_M_C),
    )
    return -torch.sum(p_log_p, dim=[1, 2])


@jit.script
def entropy_joint_probs_B_M_C(probs_B_K_C, prev_joint_probs_M_K):
    """Joint probabilities of each previous configuration with each class.

    For every batch item b this computes
    `prev_joint_probs_M_K @ probs_B_K_C[b] / K`, i.e. the product of the two
    averaged over the K dimension. The result is a float64 tensor of shape
    (B, M, C) on the device of `probs_B_K_C`.
    """
    batch_size, num_ws, num_classes = probs_B_K_C.shape
    num_configs = prev_joint_probs_M_K.shape[0]

    result_B_M_C = torch.empty(
        (batch_size, num_configs, num_classes),
        dtype=torch.float64,
        device=probs_B_K_C.device,
    )

    # One (M x K) @ (K x C) product per batch item, written straight into
    # the preallocated output slice.
    for b in range(batch_size):
        torch.matmul(prev_joint_probs_M_K, probs_B_K_C[b], out=result_B_M_C[b])

    # Turn the sum over K into a mean over K.
    result_B_M_C.div_(num_ws)
    return result_B_M_C


def exact_batch(probs_B_K_C, prev_joint_probs_M_K=None):
    """Exact joint entropy of previous configurations plus each batch item.

    `probs_B_K_C` is converted to float64. When `prev_joint_probs_M_K` is
    omitted, a single all-ones (1, K) row is used in its place. Returns a
    float64 tensor with one entropy per batch item, on the device of
    `probs_B_K_C`.
    """
    if prev_joint_probs_M_K is not None:
        assert prev_joint_probs_M_K.shape[1] == probs_B_K_C.shape[1]

    device = probs_B_K_C.device
    batch_size, num_ws, _ = probs_B_K_C.shape
    probs_B_K_C = probs_B_K_C.double()

    if prev_joint_probs_M_K is None:
        prev_joint_probs_M_K = torch.ones((1, num_ws), dtype=torch.float64, device=device)

    joint_B_M_C = entropy_joint_probs_B_M_C(probs_B_K_C, prev_joint_probs_M_K)

    # Reduce to entropies in chunks of 256 batch items at a time.
    result_B = torch.zeros((batch_size,), dtype=torch.float64, device=device)
    for result_chunk, joint_chunk in split_tensors(result_B, joint_B_M_C, 256):
        result_chunk.copy_(entropy_from_probs_b_M_C(joint_chunk), non_blocking=True)

    return result_B


@jit.script
def conditional_entropy_from_probs_B_K_C(probs_B_K_C):
    """Conditional entropy per batch item: mean over K of each row's entropy.

    Sums `-p * log(p)` over the K and C dimensions and divides by K, giving
    one value per entry of dimension 0. Probabilities that are exactly zero
    contribute 0 (the `0 log 0 = 0` convention) instead of
    `0 * -inf = NaN`.
    """
    B, K, C = probs_B_K_C.shape
    p_log_p = torch.where(
        probs_B_K_C == 0,
        torch.zeros_like(probs_B_K_C),
        probs_B_K_C * torch.log(probs_B_K_C),
    )
    return -torch.sum(p_log_p, dim=[1, 2]) / K


def batch_conditional_entropy_B(probs_B_K_C):
    """Conditional entropy of every batch item, computed in chunks.

    The input is processed 8192 items at a time. Each chunk is converted to
    float64 and scored with `conditional_entropy_from_probs_B_K_C`. Returns
    a float64 tensor of length B that is allocated without an explicit
    device.
    """
    batch_size, _, _ = probs_B_K_C.shape
    result_B = torch.empty((batch_size,), dtype=torch.float64)

    for result_chunk, probs_chunk in split_tensors(result_B, probs_B_K_C, 8192):
        chunk_entropy = conditional_entropy_from_probs_B_K_C(probs_chunk.double())
        result_chunk.copy_(chunk_entropy, non_blocking=True)

    return result_B


### SAMPLING

# probs_N_K_C: #ys x #ws x #classes
# returns samples_M_K: #samples x #ws
# #samples = #ws * S (S = number of samples per w)
def sample_M_K_unified(probs_N_K_C, S=1000):
    """Sample class configurations from the K-averaged predictions.

    For each of the N points, `S * K` classes are drawn from the
    probabilities averaged over the K dimension. For every drawn
    configuration and every k, the probabilities of the drawn classes are
    multiplied across the N points. Returns the result with the sample
    dimension first and the K dimension second.
    """
    probs_N_K_C = probs_N_K_C.double()
    num_ws = probs_N_K_C.shape[1]

    mean_probs_N_1_C = torch.mean(probs_N_K_C, dim=1, keepdim=True)
    drawn_N_1_M = batch_multi_choices(mean_probs_N_1_C, S * num_ws).long()

    # Look up, for every k, the probability of each drawn class.
    drawn_probs_N_K_M = gather_expand(probs_N_K_C, dim=-1, index=drawn_N_1_M)

    # Plain product over the N points. (An exp-sum-log variant was
    # considered to avoid 0s, but is not used here.)
    joint_K_M = torch.prod(drawn_probs_N_K_M, dim=0, keepdim=False)

    return joint_K_M.t()


# probs_N_K_C: #ys x #ws x #classes
# returns samples_M_K: #samples x #ws
# #samples = #ws * S (S = number of samples per w)
def sample_M_K(probs_N_K_C, S=1000):
    """Sample class configurations separately from each of the K predictions.

    For every point and every k, `S` classes are drawn from that k's own
    class probabilities. Each drawn configuration is then scored under all K
    predictions by multiplying (in log space) the probabilities of the drawn
    classes across the N points. Returns the result with the sample
    dimension (K * S rows) first and the K dimension second.
    """
    probs_N_K_C = probs_N_K_C.double()
    num_ws = probs_N_K_C.shape[1]

    drawn_N_K_S = batch_multi_choices(probs_N_K_C, S).long()

    # Add a broadcast axis to each so every k is paired with the draws of
    # every other k: probs N x K x 1 x C against draws N x 1 x K x S.
    probs_N_K_1_C = probs_N_K_C[:, :, None, :]
    drawn_N_1_K_S = drawn_N_K_S[:, None, :, :]
    drawn_probs_N_K_K_S = gather_expand(probs_N_K_1_C, dim=-1, index=drawn_N_1_K_S)

    # exp sum log seems necessary to avoid 0s?
    log_joint_K_K_S = torch.sum(torch.log(drawn_probs_N_K_K_S), dim=0, keepdim=False)
    joint_K_M = torch.exp(log_joint_K_K_S).reshape((num_ws, -1))

    return joint_K_M.t()


@jit.script
def from_M_K(samples_M_K):
    """Average negative log-probability of the rows of `samples_M_K`.

    Each row is first averaged over dimension 1 to give one probability per
    row. The result is the mean of `-log` of those probabilities, returned
    as a scalar tensor.
    """
    row_probs_M = samples_M_K.mean(dim=1)
    return torch.mean(-torch.log(row_probs_M))


# probs_B_K_C: #batch x #ws x #classes
# samples_M_K: #samples x #ws
# returns entropy_B: #batch
def sampled_batch(probs_B_K_C, samples_M_K):
    """Sample-based joint entropy estimate for every batch item.

    Both inputs are converted to float64. For each batch item the sampled
    configuration probabilities are combined with the item's class
    probabilities and averaged over K. The entropy is then estimated with
    `importance_weighted_entropy_p_b_M_C`, using the K-averaged sample
    probabilities as the proposal. Returns a float64 tensor of length B that
    is allocated without an explicit device.
    """
    probs_B_K_C = probs_B_K_C.double()
    samples_M_K = samples_M_K.double()

    device = probs_B_K_C.device
    num_samples, num_ws = samples_M_K.shape
    batch_size, probs_num_ws, num_classes = probs_B_K_C.shape
    assert num_ws == probs_num_ws

    joint_B_M_C = torch.empty(
        (batch_size, num_samples, num_classes), dtype=torch.float64, device=device
    )
    # One (M x K) @ (K x C) product per batch item, written in place.
    for b in range(batch_size):
        torch.matmul(samples_M_K, probs_B_K_C[b], out=joint_B_M_C[b])
    joint_B_M_C /= num_ws

    # Proposal probability of each sample, shaped 1 x M x 1 for broadcasting.
    proposal_1_M_1 = samples_M_K.mean(dim=1, keepdim=True)[None]

    # Now we can compute the entropy.
    # We store it directly on the CPU to save GPU memory.
    result_B = torch.zeros((batch_size,), dtype=torch.float64)
    for result_chunk, joint_chunk in split_tensors(result_B, joint_B_M_C, 256):
        chunk_entropy = importance_weighted_entropy_p_b_M_C(joint_chunk, proposal_1_M_1, num_samples)
        result_chunk.copy_(chunk_entropy, non_blocking=True)

    return result_B


@jit.script
def importance_weighted_entropy_p_b_M_C(p_b_M_C, q_1_M_1, M: int):
    """Importance-weighted entropy estimate per batch row.

    Sums `-log(p) * p / q` over dimensions 1 and 2 and divides by `M`.
    `q_1_M_1` is broadcast against `p_b_M_C`.

    Entries with `p == 0` contribute 0 (the `0 log 0 = 0` convention)
    instead of NaN. This also covers a zero proposal `q` whenever the
    matching `p` entries are zero.
    """
    weighted_nats = torch.where(
        p_b_M_C == 0,
        torch.zeros_like(p_b_M_C),
        -torch.log(p_b_M_C) * p_b_M_C / q_1_M_1,
    )
    return torch.sum(weighted_nats, dim=[1, 2]) / M
