import torch
import math
import numpy as np
import torch.nn.functional as F


def maxlin(x_orig, w_orig, eps, p):
    r"""Maximise a linear function over an l_p ball intersected with the unit box.

    Solves, independently for every batch row, for x in [0, 1]^d and p > 1,

        max <w, delta>  s.th.  ||delta||_p <= eps,  x + delta \in [0, 1]^d

    Args:
        x_orig: tensor whose first dimension is the batch; every sample is
            flattened internally. Entries are expected to lie in [0, 1].
        w_orig: tensor with the same per-sample element count as ``x_orig``
            (the linear objective).
        eps: radius of the l_p ball.
        p: norm exponent; must be > 1 because the code divides by ``p - 1``.

    Returns:
        Tensor shaped like ``w_orig`` holding the maximising ``delta``.
    """
    bs = x_orig.shape[0]
    small_const = 1e-10
    # reshape (instead of view) so non-contiguous inputs are accepted as well.
    x = x_orig.reshape(bs, -1)
    w = w_orig.reshape(bs, -1)

    # Room left inside the box in the direction of sign(w): x where w < 0,
    # 1 - x where w > 0 and 0 where w == 0. This is also the answer whenever
    # the norm constraint turns out not to be binding.
    gamma = x * (w < 0.0) + (1.0 - x) * (w > 0.0)
    delta = gamma.clone()

    # From here on only magnitudes matter; the sign is re-applied on return.
    w = w.abs()

    # Coordinates without any room get zero weight and a tiny gamma so that
    # the division below stays finite.
    blocked = gamma == 0.0
    gamma_adj, w_adj = gamma.clone(), w.clone()
    gamma_adj[blocked] = small_const
    w_adj[blocked] = 0.0

    # Per-coordinate threshold at which the box bound becomes active,
    # sorted ascending within every row.
    mus = w_adj / (p * (gamma_adj ** (p - 1)))
    mussorted, order = mus.sort(dim=1)
    gammasorted, wsorted = gamma.gather(1, order), w_adj.gather(1, order)

    # gammacum[:, k] = sum of gamma^p over the sorted coordinates k, k+1, ...
    gamma_pow = gammasorted**p
    gammacum = torch.cat(
        [torch.zeros([bs, 1], device=x.device), gamma_pow.cumsum(dim=1)],
        dim=1,
    )
    gammacum = gamma_pow.sum(dim=-1, keepdim=True) - gammacum
    # wcum[:, k] = sum of w^(p / (p - 1)) over the sorted coordinates 0..k.
    wcum = (wsorted ** (p / (p - 1))).cumsum(dim=1)

    mussorted[mussorted == 0] = small_const
    mucum = torch.cat(
        [
            torch.zeros([bs, 1], device=x.device),
            wcum / (p * mussorted) ** (p / (p - 1)),
        ],
        dim=1,
    )
    fs = gammacum + mucum - eps**p

    # Rows where simply moving to the box boundary would violate the norm
    # budget; only those rows get rescaled below.
    ind = fs[:, 0] > 0.0

    # Batched bisection over the columns of fs for the last index whose
    # value is still positive. The index tensors live on the same device as
    # fs so masked updates also work for non-CPU inputs.
    lb = torch.zeros(bs, dtype=torch.long, device=x.device)
    ub = lb + fs.shape[1]
    u = torch.arange(bs, device=x.device)
    for _ in range(math.ceil(math.log2(fs.shape[1]))):
        a = (lb + ub) // 2
        indnew = fs[u, a] > 0.0
        lb = torch.where(indnew, a, lb)
        ub = torch.where(indnew, ub, a)

    # For rows outside `ind`, lb can be 0 and lb - 1 wraps around; those
    # values are never used because only rows in `ind` are written below.
    pmstar = wcum[u, lb - 1] / (eps**p - gammacum[u, lb])
    pmstar[pmstar == 0] = small_const
    deltamax = w ** (1 / (p - 1)) / pmstar.unsqueeze(1) ** (1 / p)

    # Box-clipped solution for the rows where the norm constraint is active.
    delta[ind] = torch.min(delta[ind], deltamax[ind])

    return delta.reshape(w_orig.shape) * w_orig.sign()


def LMO(grad, x_0, eps, p=2):
    """Linear minimisation oracle over the l_p ball of radius eps around x_0.

    Returns ``x_0 + s`` where ``s`` minimises ``<grad, s>`` per sample subject
    to ``||s||_p <= eps``. No box constraint is applied here.

    Args:
        grad: gradient tensor; the first dimension is the batch and all
            remaining dimensions are treated as one flat vector per sample.
        x_0: centre of the ball, broadcastable against ``grad``; its first
            dimension provides the batch size.
        eps: radius of the ball.
        p: norm exponent, ``1 <= p <= inf`` (default 2).

    Returns:
        Tensor with the shape of ``grad`` broadcast against ``x_0``.

    Raises:
        ValueError: if ``p < 1`` (or NaN), for which no l_p norm exists.
    """
    n = x_0.shape[0]
    flat = grad.reshape(n, -1)
    # Shape that broadcasts one scalar per sample against grad, whatever
    # grad's rank is (for 4-D input this is the familiar (n, 1, 1, 1)).
    per_sample = (n,) + (1,) * (grad.dim() - 1)

    # -grad as we had -loss in the beginning
    if p == 2:
        # Guard against an all-zero gradient, which would give 0 / 0 = NaN.
        norm = flat.norm(p=2, dim=1).clamp(min=torch.finfo(flat.dtype).tiny)
        return -eps * grad / norm.view(per_sample) + x_0
    if p == 1:
        # The minimiser is a vertex of the l_1 ball: all mass goes on the
        # coordinate with the largest gradient *magnitude*, against its sign.
        idx = flat.abs().argmax(dim=1, keepdim=True)
        direction = torch.zeros_like(flat)
        direction.scatter_(1, idx, flat.gather(1, idx).sign())
        return -eps * direction.reshape(grad.shape) + x_0
    if p == math.inf:
        # Limit of the general formula below as p -> inf.
        return -eps * grad.sign() + x_0
    if 1 < p < math.inf:
        magnitude = grad.abs()
        denom = (magnitude ** (p / (p - 1))).reshape(n, -1).sum(1) ** (1 / p)
        # Same zero-gradient guard as in the p == 2 branch.
        denom = denom.clamp(min=torch.finfo(flat.dtype).tiny)
        return (
            -eps * grad.sign() * magnitude ** (1 / (p - 1)) / denom.view(per_sample)
        ) + x_0
    raise ValueError(f"LMO requires 1 <= p <= inf, got p={p!r}")
