import numpy as np
import torch
from autograd import grad
from autograd import numpy as np
from autograd.scipy.special import logsumexp
from tqdm import tqdm


def get_lossfunc(pis):
    """Build the loss function parameterised by the weight vector ``pis``.

    The returned function evaluates

        sum_i pis[i] * logsumexp(W[i, :]) - sum_{i,j} pis[i] * pis[j] * W[i, j]

    Args:
        pis: Array of weights. The arithmetic below requires it to broadcast
            against the row-wise ``logsumexp`` of ``W``, i.e. a 1-D array
            whose length equals the number of rows of ``W``.

    Returns:
        A function ``compute_loss(W)`` mapping a matrix ``W`` to a scalar
        loss value.
    """
    # The outer product pis pis^T depends only on ``pis``; build it once here
    # instead of on every loss (and gradient) evaluation.
    outer = np.outer(pis, pis)

    def compute_loss(W):
        return np.sum(pis * logsumexp(W, axis=1)) - np.sum(outer * W)

    return compute_loss


def compute_losses_over_time_numpy(pis, T, ss):
    """Run fixed-step gradient descent from ``W = 0`` and record the loss.

    The loss comes from ``get_lossfunc(pis)`` and its gradient is obtained
    through autograd's ``grad``.

    Args:
        pis: Weight array; its first dimension fixes the side length of the
            square matrix being optimised.
        T: Number of gradient steps to take.
        ss: Step size applied to each gradient step.

    Returns:
        A list of ``T + 1`` loss values: the loss before each step, followed
        by the loss after the final step.
    """
    dim = pis.shape[0]
    loss_fn = get_lossfunc(pis)
    grad_fn = grad(loss_fn)

    weights = np.zeros((dim, dim))
    # Record the starting loss, then one entry after every update.
    history = [loss_fn(weights)]
    for _ in tqdm(range(T), total=T, leave=False):
        weights = weights - ss * grad_fn(weights)
        history.append(loss_fn(weights))
    return history


def loss_at_0(pis):
    """Return ``log(d)``, where ``d`` is the size of ``pis`` along axis 0.

    For weights that sum to one, this matches the value the loss from
    ``get_lossfunc`` takes at the all-zero matrix.
    """
    d = pis.shape[0]
    return np.log(d)


def entropy(pis):
    """Return the Shannon entropy (in nats) of the weights ``pis``.

    Computes ``-sum(pis * log(pis))`` with the usual convention that a
    zero entry contributes zero, so distributions with zero-probability
    entries give a finite result rather than ``nan``.

    Args:
        pis: Array of weights. Negative entries are not specially handled
            and still produce ``nan``.

    Returns:
        The scalar entropy.
    """
    # Swap exact zeros for 1 inside the log only: log(1) == 0, so those
    # terms become 0 * 0 instead of 0 * -inf (which is nan).
    log_arg = np.where(pis == 0, 1.0, pis)
    return -np.sum(pis * np.log(log_arg))


def compute_pis(d, alpha):
    """Return normalised power-law weights over the ranks ``1..d``.

    Entry ``k`` (1-based) is proportional to ``1 / k**alpha``; the result is
    divided by its sum so that it adds up to one.

    Args:
        d: Number of entries to generate.
        alpha: Exponent of the power law.

    Returns:
        A 1-D array of length ``d``.
    """
    ranks = range(1, d + 1)
    unnormalised = np.array([1 / rank**alpha for rank in ranks])
    total = np.sum(unnormalised)
    return unnormalised / total


def compute_losses_over_time_torch(pis, T, ss):
    """Run fixed-step gradient descent from ``W = 0`` in torch and record the loss.

    Minimises ``dot(pis, logsumexp(W, dim=1)) - sum(outer(pis, pis) * W)``
    using a hand-written gradient, in float32, on CUDA when available and on
    the CPU otherwise.

    Args:
        pis: Sequence or array of weights; ``len(pis)`` fixes the side length
            of the square matrix being optimised.
        T: Number of gradient steps to take.
        ss: Step size applied to each gradient step.

    Returns:
        A list of ``T + 1`` Python floats: the loss before each step,
        followed by the loss after the final step.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The gradient is written out by hand, so autograd tracking is not needed.
    with torch.no_grad():
        probs = torch.tensor(pis, dtype=torch.float32, device=device)
        outer = probs[:, None] @ probs[None, :]

        def loss_fn(weights):
            row_lse = torch.logsumexp(weights, dim=1)
            return torch.dot(probs, row_lse) - torch.sum(outer * weights)

        def grad_fn(weights):
            # d/dW of probs[i] * logsumexp(W[i, :]) is probs[i] * softmax(W[i, :]);
            # the linear term contributes the constant -outer.
            return probs[:, None] * torch.softmax(weights, dim=1) - outer

        n = len(pis)
        weights = torch.zeros((n, n), dtype=torch.float32, device=device)
        # Record the starting loss, then one entry after every update.
        history = [loss_fn(weights).item()]
        for _ in tqdm(range(T), total=T, leave=False):
            weights = weights - ss * grad_fn(weights)
            history.append(loss_fn(weights).item())
    return history


# Generic entry point: currently bound to the torch implementation.
compute_losses_over_time = compute_losses_over_time_torch
