import torch
import numpy as np

from chip.models.iterative_model import TomographicReconstruction


def angular_distance(a, b):
    """Return the element-wise distance between angles that wrap every 180 units.

    Angles are treated as periodic with period 180 (presumably degrees of a
    parallel-beam scan -- confirm against callers), so e.g. 10 and 170 are
    20 apart and 350 and 0 are 10 apart.

    Args:
        a: tensor of angles.
        b: tensor of angles, broadcastable against ``a``.

    Returns:
        Tensor of non-negative distances in ``[0, 90]``.
    """
    # Reduce into one period first; without this, differences larger than 180
    # made ``180 - |a - b|`` negative and the "distance" came out negative.
    diff = torch.remainder(torch.abs(a - b), 180)
    return torch.minimum(diff, 180 - diff)


def calculate_entropy(samples, bins=10, low=0, high=400):
    """Return the Shannon entropy (in bits) of a histogram of ``samples``.

    Args:
        samples: float tensor of values. ``torch.histc`` ignores values
            outside ``[low, high]``.
        bins: number of equal-width histogram bins (default 10).
        low: lower edge of the histogram range (default 0).
        high: upper edge of the histogram range (default 400).

    Returns:
        0-d tensor holding the entropy. If no sample falls inside the range,
        the histogram sums to zero and the result is NaN (0 / 0).
    """
    hist = torch.histc(samples, bins=bins, min=low, max=high)

    # Convert histogram counts to probabilities.
    probabilities = hist / torch.sum(hist)
    # A small epsilon keeps log2 finite for empty bins (0 * log2(0) would be NaN).
    # It adds a tiny positive bias (~bins * 2.7e-7 bits).
    probabilities += 1e-8

    return -torch.sum(probabilities * torch.log2(probabilities))


def entropy(images, angles):
    """Score projection angles by the entropy of sample sinograms.

    Each image in the batch is clipped to ``[0, 1]`` and forward-projected at
    ``angles``. For every sinogram pixel, the entropy of its values across the
    batch is computed with ``calculate_entropy``. NaN entropies are replaced
    by 0, and the pixel entropies are summed per angle row.

    Args:
        images: batch of images. Assumed to be (batch, 1, H, W), since
            ``squeeze(1)`` drops the channel dim -- TODO confirm.
        angles: angles (or angle indices) passed to
            ``TomographicReconstruction.forward``. Presumably the per-sample
            sinogram has one row per entry -- verify.

    Returns:
        ``(angle_entropy, entropy_sinogram, angle_entropy)``. ``angle_entropy``
        is returned twice, which is kept for interface compatibility.
    """
    batch_entropy = torch.vmap(calculate_entropy)

    clipped = torch.clip(images, 0, 1).squeeze(1)
    tr = TomographicReconstruction(clipped, use_sigmoid=False).to(clipped.device)
    with torch.no_grad():
        sol_sinograms = tr.forward(angles)
        # Rows = sinogram pixels, columns = samples: entropy is taken across samples.
        per_pixel = sol_sinograms.reshape(len(clipped), -1).T
        # Shape is inferred from the number of angles instead of a fixed
        # (180, 512); identical when 180 angles are requested.
        entropy_sinogram = batch_entropy(per_pixel).reshape(len(angles), -1)
        # All-out-of-range pixels yield NaN from calculate_entropy; treat them as 0.
        entropy_sinogram[torch.isnan(entropy_sinogram)] = 0.
        angle_entropy = torch.sum(entropy_sinogram, dim=1)

    return angle_entropy, entropy_sinogram, angle_entropy


def choose_top_k(scores, angles, k=1, radius=0, exclude_indices=None):
    """Greedily pick up to ``k`` indices with the highest scores.

    After each pick, the chosen index is removed from consideration. When
    ``radius`` is non-zero, every index whose angle lies strictly within
    ``radius`` of the chosen angle (per ``angular_distance``) is removed too.

    Args:
        scores: 1-D tensor with one score per candidate. It is not modified.
        angles: tensor of candidate angles aligned with ``scores``. Only read
            when ``radius`` is non-zero.
        k: maximum number of indices to return.
        radius: exclusion radius around each pick, in the units of ``angles``.
        exclude_indices: a single index, or a sequence/tensor of indices, that
            must never be selected. ``None`` (the default) excludes nothing.

    Returns:
        List of selected indices (Python ints) in selection order. It is
        shorter than ``k`` if the candidates run out first; in that case a
        warning is printed.
    """
    mask = torch.ones(len(scores), dtype=torch.bool, device=scores.device)

    if exclude_indices is not None:
        # as_tensor accepts a lone int (including 0), a list, or a tensor.
        # A plain truthiness test would skip index 0 and fail on tensors.
        excluded = torch.as_tensor(exclude_indices, dtype=torch.long, device=mask.device)
        mask[excluded] = False

    top_k = []
    for _ in range(k):
        if not mask.any():
            print("Warning: No angles available for selection.")
            break

        # Mask out-of-place so the caller's ``scores`` tensor is left untouched.
        masked_scores = scores.masked_fill(~mask, float('-inf'))
        max_index = torch.argmax(masked_scores).item()
        mask[max_index] = False  # remove selected index from mask
        top_k.append(max_index)

        # Mask elements within ``radius`` of the pick (moved to mask's device).
        if radius:
            near = angular_distance(angles, angles[max_index]) < radius
            mask[near.to(mask.device)] = False

    return top_k

def binary_search_order(indices):
    """Reorder ``indices`` so that every prefix of the result is spread out.

    The sequence is split in half, each half is reordered recursively, and the
    two results are interleaved (left, right, left, right, ...). For a length
    that is a power of two this gives bit-reversal order, e.g.
    ``range(8)`` -> ``[0, 4, 2, 6, 1, 5, 3, 7]``.

    Sequences of length 0 or 1 are returned as-is (the same object). Longer
    inputs produce a new list.
    """
    n = len(indices)
    if n < 2:
        return indices

    half = n // 2
    left = binary_search_order(indices[:half])
    right = binary_search_order(indices[half:])

    # The right half is never shorter than the left (half = n // 2) and at most
    # one longer, so pairing leaves at most one trailing element.
    merged = [item for pair in zip(left, right) for item in pair]
    merged.extend(left[len(right):])
    merged.extend(right[len(left):])
    return merged


def uniform(num_angles, num_selected):
    """Return the ``num_selected``-th index of ``range(num_angles)`` in binary-search order.

    Presumably ``num_selected`` is the number of angles already acquired, so
    successive calls walk a coarse-to-fine, evenly spread sequence -- confirm
    against callers. Raises ``IndexError`` when ``num_selected >= num_angles``.
    """
    spread_order = binary_search_order(list(range(num_angles)))
    return spread_order[num_selected]

def random(tr, indices):
    """Return one element of ``indices`` drawn uniformly with NumPy's global RNG.

    ``tr`` is not used. Presumably it is accepted so this function shares a
    signature with the other selection strategies -- confirm against callers.
    """
    picked = np.random.choice(indices)
    return picked


def mean_gradient_norm(tr, samples, angles):
    """Average ``gradient_norm_squared`` scores over a collection of samples.

    Each sample is wrapped as ``TomographicReconstruction(sample,
    use_sigmoid=False)`` and used as the target for ``tr``. The per-sample
    score vectors are stacked and averaged along the sample axis. An empty
    ``samples`` makes ``torch.stack`` raise.
    """
    per_sample_scores = []
    for sample in samples:
        target = TomographicReconstruction(sample, use_sigmoid=False)
        per_sample_scores.append(gradient_norm_squared(tr, target, angles))
    return torch.stack(per_sample_scores).mean(dim=0)


def gradient_norm_squared(tr, target_tr, angles):
    """Score each angle by the squared L2 norm of the image gradient it induces.

    For every angle, the mean squared error between ``tr``'s projection at that
    single angle and the matching slice of ``target_tr``'s projection is
    back-propagated. The sum of squares of ``tr.img.grad`` is then recorded.

    Args:
        tr: model whose gradients are measured. ``tr.zero_grad()`` is called
            before each backward pass, so afterwards ``tr.img.grad`` holds the
            last angle's gradient.
        target_tr: model providing the target sinogram; evaluated under
            ``torch.no_grad()``.
        angles: angles passed to ``forward``. The i-th angle is paired with the
            i-th leading slice of ``target_tr.forward(angles)``; ``zip`` stops
            at the shorter of the two.

    Returns:
        1-D tensor on ``tr.device`` with one score per paired angle. It is
        empty if there are no pairs.
    """
    device = tr.device
    with torch.no_grad():
        target_sinograms = target_tr.forward(angles)

    scores = []
    # Distinct loop name so the zipped sinogram tensor is not shadowed.
    for angle, target_row in zip(angles, target_sinograms):
        tr.zero_grad()
        prediction = tr.forward(torch.tensor([angle], device=device))
        squared_error = ((prediction - target_row) ** 2).mean()
        squared_error.backward()
        # Reduce on-device to avoid a host transfer (and sync) per angle.
        scores.append(tr.img.grad.detach().pow(2).sum())

    if not scores:
        return torch.tensor([], device=device)
    return torch.stack(scores).to(device)




# def mle_gradient_norm(tr, angles, device='cpu', **kwargs):
#     mle_tr = tr.get_mle_tr()
#     with torch.no_grad():
#         mle_sinogram = mle_tr.forward(torch.tensor(angles, device=device))
#     # squared_error = ((tr.forward(indices) - mle_tr.forward(indices))**2).sum(dim=1)
#
#     scores = []
#     for angle, target_sinogram in zip(angles, mle_sinogram):
#         tr.zero_grad()
#         squared_error = ((tr.forward(torch.tensor([angle], device=device)) - target_sinogram) ** 2).mean()
#         squared_error.backward()
#         gradient = tr.img.grad.detach().cpu().numpy()
#         scores.append(np.linalg.norm(gradient))
#
#     return angles[np.argmax(scores)]


# def mse_loss(tr, indices, **kwargs):
#     mle_tr = tr.get_mle_tr()
#     with torch.no_grad():
#         mle_sinogram = mle_tr.forward(indices)
#         est_sinogram = tr.forward(indices)
#
#         scores = ((mle_sinogram - est_sinogram) ** 2).mean(axis=1)
#
#     return indices[np.argmax(scores)]

# def uniform_offset(num_angles, num_selected, shift='random'):
#     if shift == 'random':
#         shift = np.random.randint(num_angles)
#     return (uniform(num_angles, num_selected) + shift) % num_angles


# def variance(tr, indices, **kwargs):
#     with torch.no_grad():
#         img = tr.get_img()
#         var = img * (1 - img)
#         var_tr = TomographicReconstruction(prior=var, theta=tr.theta, use_sigmoid=False)
#         scores = var_tr.forward(indices).mean(dim=1).squeeze().detach().cpu().numpy()
#     return indices[np.argmax(scores)]