import cvxpy as cp
import numpy
import torch
from cvxpylayers.torch.cvxpylayer import CvxpyLayer


def reorder(ordered, indices):
    """Gather entries of ``ordered`` along dimension 1 using ``indices``.

    For 2-D tensors this computes ``out[i][j] = ordered[i][indices[i][j]]``,
    i.e. each row of ``ordered`` is re-indexed by the matching row of
    ``indices``. The result has the shape of ``indices``.
    """
    return torch.gather(ordered, dim=1, index=indices)

def isotonic_regression( y ):
    """Return the non-decreasing vector closest to ``y`` in Euclidean norm.

    Builds and solves a fresh cvxpy problem for a single, unbatched input.
    No derivatives are propagated; for batched, differentiable solves use
    ``get_isotonic_regression_layer``.

    Args:
        y: 1-D sequence of numbers (list, numpy array or torch tensor).
            Torch tensors are detached and copied to the CPU before solving.

    Returns:
        numpy.ndarray ``x`` of length ``len(y)`` with ``x[0] <= x[1] <= ...``
        that minimises ``||x - y||_2``.

    Raises:
        RuntimeError: if the solver returns no solution.
    """
    # cvxpy coerces constants through numpy. A tensor that requires grad or
    # lives on a GPU cannot be converted implicitly, so convert explicitly.
    if isinstance(y, torch.Tensor):
        y = y.detach().cpu().numpy()
    y = numpy.asarray(y, dtype=float)
    N = len(y)

    x = cp.Variable(N)
    constraints = []
    if N > 1:
        # Differencing operator: (D @ x)_i = x_i - x_{i+1}, so D @ x <= 0
        # encodes x_1 <= x_2 <= ... <= x_N. With fewer than two entries there
        # is nothing to order, so skip the degenerate (0 x N) constraint.
        D = numpy.eye(N)[:N-1] - numpy.eye(N)[1:]
        constraints.append(D @ x <= 0)

    problem = cp.Problem(cp.Minimize(cp.norm(x - y)), constraints)
    problem.solve()
    if x.value is None:
        # Fail loudly instead of handing the caller a silent None.
        raise RuntimeError(
            f"isotonic regression solve failed (status: {problem.status})")
    return x.value

def get_isotonic_regression_layer( N ):
    """Build a differentiable, batch-capable isotonic-regression layer.

    This is the layer counterpart of ``isotonic_regression``. The returned
    ``CvxpyLayer`` takes the target vector ``y`` (length ``N``, optionally
    batched) as its single parameter. It solves for the ``x`` with
    ``x_1 <= x_2 <= ... <= x_N`` that minimises ``||x - y||_2``, and yields
    a tuple whose single entry is that ``x``.
    """
    fitted = cp.Variable(N)
    target = cp.Parameter(N)

    identity = torch.eye(N)
    # Row i of diff_op picks out x_i - x_{i+1}; "<= 0" enforces monotonicity.
    diff_op = identity[:-1] - identity[1:]

    objective = cp.Minimize(cp.norm(fitted - target))
    problem = cp.Problem(objective, [diff_op @ fitted <= 0])
    assert problem.is_dpp()
    return CvxpyLayer(problem, parameters=[target], variables=[fitted])


""" unfinished
def fair_ranking_LP( y ):
    N = len(y)

    P = cp.Variable((N,N))
    constraints = [ P >= 0 ]
    problem = cp.Problem( cp.Minimize(  cp.sum(P)  ) , constraints )
    assert problem.is_dpp()
    problem.solve()
    return x.value
"""


def get_fair_ranking_LP_layer( N ):
    """Return a differentiable layer solving the ranking LP of Singh.

    The LP chooses an N x N matrix ``P`` whose rows and columns each sum to 1,
    with ``P >= 0``, and maximises ``y @ (P @ b)``. Here ``y`` holds the
    relevance scores and ``b[j] = 1 / (j + 1)`` is the position-bias weight
    of rank ``j``. Row ``i`` of ``P`` therefore corresponds to item ``i`` and
    column ``j`` to position ``j``. No additional fairness constraints are
    imposed here.

    Args:
        N: number of items, equal to the number of ranking positions.

    Returns:
        A callable mapping relevance scores (length ``N``, optionally batched)
        to the optimal ``P``, i.e. the first output of the underlying
        ``CvxpyLayer``.
    """
    P = cp.Variable((N, N))
    y = cp.Parameter(N)  # relevance scores
    # Position-bias weights 1, 1/2, ..., 1/N. They are computed for exactly N
    # positions so the layer works for any N (a fixed-length table would cap it).
    b = 1.0 / numpy.arange(1, N + 1)
    constraints = [
        cp.sum(P, axis=1) == 1,  # each item's probability mass sums to 1
        cp.sum(P, axis=0) == 1,  # each position is filled with total mass 1
        P >= 0,
    ]
    problem = cp.Problem(cp.Maximize(y @ (P @ b)), constraints)
    assert problem.is_dpp()
    cvxlayer = CvxpyLayer(problem, parameters=[y], variables=[P])
    # CvxpyLayer returns a tuple of solutions (one per variable); expose only P.
    return (lambda z: cvxlayer(z)[0])
