from __future__ import print_function
import collections
import functools
import random
import time, sys, math
from itertools import permutations

from ortools.linear_solver import pywraplp
import numpy as np
from ortools.sat.python import cp_model
import torch
import torch.nn as nn
import torch.nn.functional as F
import cvxpy as cp
from cvxpylayers.torch.cvxpylayer import CvxpyLayer
# from fast_soft_sort.pytorch_ops import soft_sort, soft_rank, phedron_project
# from sparsemax import Sparsemax

import sparsemax

from frank_wolfe import compute_Moreau_grad_softsort
from fold_opt.fold_opt import FoldOptLayer


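"""
Solve the OWA optimization problem with linear constraints
max OWA_w( Cx )
s.t.
    Ax == b
     x >= 0

Inputs:
Numpy arrays:
Matrix A, accessed as A[i][j]
vector b, accessed as b[i]
Matrix C, accessed as C[i][j] (one row per criterion; y = Cx)
w_perm: an iterable of permutations of the OWA weight vector w
        (one lower-bounding constraint z <= p @ y per permutation p)

A is NxM
b is N
C is PxM (its columns multiply the entries of x)

"""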
def owa_optim_lp(A, b, C, w_perm):
    """Solve max OWA_w(Cx) s.t. Ax == b, x >= 0 as a linear program with GLOP.

    The OWA objective is linearised with an epigraph variable ``z`` and one
    constraint ``z <= p @ y`` per weight vector ``p`` in ``w_perm``, where
    ``y = Cx``.

    Args:
        A: (N, M) numpy array of equality-constraint coefficients.
        b: length-N numpy array of right-hand sides.
        C: (P, Q) numpy array of criteria coefficients; column j multiplies
            x[j], so Q must not exceed M.
        w_perm: iterable of weight vectors (e.g. ``itertools.permutations(w)``),
            each of length at most P. It is consumed once.

    Returns:
        (soln_x, soln_y, opt): solution values of x (length M) and y (length P)
        as numpy arrays, and the solution value of z.
    """
    solver = pywraplp.Solver('OWA_OPT',
                             pywraplp.Solver.GLOP_LINEAR_PROGRAMMING)
    inf = solver.infinity()

    N, M = A.shape
    P, Q = C.shape

    # Decision variables and their bounds.
    x = [solver.NumVar(0, inf, "x[{}]".format(j)) for j in range(M)]   # x >= 0
    y = [solver.NumVar(-inf, inf, "y[{}]".format(i)) for i in range(P)]  # y is unbounded
    z = solver.NumVar(-inf, inf, "z")  # z is the OWA objective value

    # Ax == b: one equality constraint per row of A.
    for i in range(N):
        row = solver.Constraint(b[i], b[i])
        for j in range(M):
            row.SetCoefficient(x[j], A[i][j])

    # Criteria definition y[i] == C[i] @ x, written as C[i] @ x - y[i] == 0.
    for i in range(P):
        crit = solver.Constraint(0, 0)
        crit.SetCoefficient(y[i], -1)
        for j in range(Q):
            crit.SetCoefficient(x[j], C[i][j])

    # Lower-bounding constraints z <= p @ y (as z - p @ y <= 0), one per
    # element of w_perm. Creating them on demand (instead of pre-allocating
    # factorial(P) of them) accepts an iterable of any length.
    for p in w_perm:
        owa = solver.Constraint(-inf, 0)
        owa.SetCoefficient(z, 1)
        for j, p_j in enumerate(p):
            owa.SetCoefficient(y[j], -p_j)

    objective = solver.Objective()
    objective.SetCoefficient(z, 1)
    objective.SetMaximization()
    # NOTE(review): the solver status is not checked; on an infeasible or
    # unbounded model the returned values are not an optimum.
    solver.Solve()

    opt = z.solution_value()
    soln_x = np.array([v.solution_value() for v in x])
    soln_y = np.array([v.solution_value() for v in y])
    return soln_x, soln_y, opt   # x, y, z


def lp_optim(A, b, c):
    """Solve the LP  max c^T x  s.t.  Ax == b, x >= 0  with OR-Tools GLOP.

    Args:
        A: (N, M) numpy array of equality-constraint coefficients.
        b: length-N right-hand side.
        c: length-M objective coefficients.

    Returns:
        (soln_x, value): solution values of x as a numpy array and the
        objective value reported by the solver.

    Example inputs:
        A = np.array([[1, 2, 1, 0, 0], [3, -1, 0, -1, 0], [1, -1, 0, 0, 1]]).astype(float)
        b = np.array([14, 0, 2]).astype(float)
        c = np.array([3, 4, 0, 0, 0]).astype(float)
    """
    solver = pywraplp.Solver('LP_OPT',
                             pywraplp.Solver.GLOP_LINEAR_PROGRAMMING)
    n_rows, n_cols = A.shape
    inf = solver.infinity()

    # One nonnegative variable per column of A.
    x = [solver.NumVar(0, inf, f"x[{col}]") for col in range(n_cols)]

    # Equality rows: sum_j A[row][j] * x[j] == b[row].
    for row in range(n_rows):
        eq = solver.Constraint(b[row], b[row])
        for col, var in enumerate(x):
            eq.SetCoefficient(var, A[row][col])

    objective = solver.Objective()
    for col, var in enumerate(x):
        objective.SetCoefficient(var, c[col])
    objective.SetMaximization()

    solver.Solve()
    return np.array([var.solution_value() for var in x]), objective.Value()



def gini_indices(m):
    """Return the Gini OWA weights [m, m-1, ..., 1] / m as a float tensor.

    m is the number of criteria / items; the k-th weight (1-based) is
    (m - k + 1) / m, so the weights decrease linearly from 1 to 1/m.
    """
    descending = torch.arange(m, 0, -1, dtype=torch.float)
    return descending / m


def gini_indices_square(m):
    """Return the squared Gini OWA weights ([m, m-1, ..., 1] / m) ** 2.

    m is the number of criteria / items; the k-th weight (1-based) is
    ((m - k + 1) / m) ** 2, decreasing from 1 to 1/m**2.
    """
    linear = torch.arange(m, 0, -1, dtype=torch.float) / m
    return linear ** 2


def set_up_owa_solver_coef(C):
    """Build the simplex constraint data and OWA weights for a cost matrix.

    Args:
        C: (M, N) cost matrix (M tasks / criteria, N stocks); only its shape
            is read.

    Returns:
        (A, b, w): A is a (1, N) float array of ones and b == [1.0], together
        encoding sum(x) == 1; w is the length-M squared-Gini weight vector
        from gini_indices_square(M) as a float numpy array. C itself is not
        returned.
    """
    n_tasks, n_stocks = C.shape
    budget_row = np.ones((1, n_stocks), dtype=float)
    budget_rhs = np.array([1.0])
    weights = gini_indices_square(n_tasks).numpy().astype(float)
    return budget_row, budget_rhs, weights

	
def owa_optim_lp_portifolio_wrapper(cost):
    """Solve the OWA portfolio LP for a cost matrix with the OR-Tools model.

    Builds the budget constraint and squared-Gini weights with
    set_up_owa_solver_coef, then calls owa_optim_lp with every permutation of
    the weights. Returns owa_optim_lp's (soln_x, soln_y, opt) tuple.
    """
    A, b, w = set_up_owa_solver_coef(cost)
    return owa_optim_lp(A, b, cost, permutations(w))


def cvx_qp(N, eps=5e-1):
    """Build a differentiable quadratically-regularised LP over the simplex.

    Problem:  max_x  c @ x - eps * ||x||_2^2   s.t.  x >= 0, sum(x) == 1,
    with c (length N) as the layer parameter.

    Returns:
        A callable mapping a tensor for c to the optimal x (the solver is run
        with eps=1e-6 and max_iters=500_000).
    """
    x = cp.Variable(N)
    c = cp.Parameter(N)
    objective = cp.Maximize(c @ x - eps * cp.norm(x, p=2) ** 2)
    problem = cp.Problem(objective, [0 <= x, cp.sum(x) == 1])
    layer = CvxpyLayer(problem, parameters=[c], variables=[x])

    def solve(z):
        return layer(z, solver_args={'eps': 1e-6, 'max_iters': 500_000})[0]

    return solve


def owa_optim_lp_cvxpy(P, Q):
    """Build a differentiable cvxpy layer for the max-OWA LP over the simplex.

    Problem (with w = gini_indices_square(P)):
        max_{x, y, z}  z
        s.t.  sum(x) == 1,  x >= 0,  c @ x == y,
              z <= p @ y   for every permutation p of w.

    Args:
        P: number of criteria (rows of the parameter c).
        Q: number of items (columns of c, length of x).

    Returns:
        A callable mapping a tensor for the (P, Q) parameter c (optionally
        batched, as CvxpyLayer supports) to the optimal x, the first of the
        layer outputs (x, y, z).

    Note: there are P! permutations, so the model size grows factorially in P.
    """
    x = cp.Variable(Q, nonneg=True)
    y = cp.Variable(P)
    z = cp.Variable()
    c = cp.Parameter((P, Q))
    w = gini_indices_square(P).numpy().astype(float)

    constraints = [cp.sum(x) == 1, c @ x == y]
    constraints += [z <= p @ y for p in permutations(w)]

    problem = cp.Problem(cp.Maximize(z), constraints)
    layer = CvxpyLayer(problem, parameters=[c], variables=[x, y, z])

    def solve(c_value):
        return layer(c_value, solver_args={'eps': 1e-6, 'max_iters': 500_000})[0]

    return solve

	
def compute_moreau_grad_portfolio(c_batch, x, beta, w):
    """Gradient w.r.t. x of the Moreau-smoothed OWA of the per-task objectives.

    Args:
        c_batch: (B, M, N) cost tensor (B batch, M tasks, N items).
        x: (B, N) allocation.
        beta: smoothing parameter; the objectives are scaled by -1/beta
            before compute_Moreau_grad_softsort.
        w: OWA weights (shape as expected by compute_Moreau_grad_softsort —
            presumably length M; confirm).

    Returns:
        (B, N) tensor: the per-task weights from compute_Moreau_grad_softsort
        pulled back through the linear map x -> c_batch @ x (chain rule).
    """
    objectives = torch.einsum("imn, in-> im", c_batch, x)
    task_weights = compute_Moreau_grad_softsort(w=w, z=-objectives / beta)
    return torch.einsum("imn, im->in", c_batch, task_weights)


def compute_owa_subgrad(c, xk, beta, w):
    """Subgradient w.r.t. xk of the (unsmoothed) OWA of the objectives c @ xk.

    Each task's objective is ranked in ascending order; the task of rank k
    receives weight w[k], and the weighted costs are summed over tasks.

    Args:
        c: (B, M, N) cost tensor.
        xk: (B, N) current iterate.
        beta: unused (kept for a signature shared with the Moreau gradient).
        w: 1-D tensor of M OWA weights.

    Returns:
        (B, N) subgradient tensor.
    """
    objectives = torch.einsum("imn, in->im", c, xk)  # B x M
    _, order = (-objectives).sort(descending=True, dim=-1)
    ranks = order.argsort(dim=-1)
    per_task = torch.gather(w.repeat(ranks.shape[0], 1), 1, ranks)
    return torch.einsum("imn, im->in", c, per_task)



@functools.lru_cache(maxsize=32)
def _simplex_projection_layer(M):
    """Build, once per dimension M, a CvxpyLayer projecting onto the simplex."""
    x = cp.Variable((M))
    y = cp.Parameter((M))
    constraints = [0 <= x, cp.sum(x) == 1]
    problem = cp.Problem(cp.Minimize(cp.norm(x - y)**2), constraints)
    assert problem.is_dpp()
    return CvxpyLayer(problem, parameters=[y], variables=[x])


def cvxpy_projection(x_prime):
    """Differentiable Euclidean projection of each row of x_prime onto the simplex.

    Solves  min_x ||x - y||^2  s.t.  x >= 0, sum(x) == 1  row-wise via a
    CvxpyLayer. The layer (and its cvxpy canonicalisation) is built once per
    row length M and reused; previously it was rebuilt on every call, which is
    expensive because this function runs at every fold-opt update step.

    Args:
        x_prime: (B, M) tensor of points to project.

    Returns:
        (B, M) tensor of projected points.
    """
    _, M = x_prime.shape
    return _simplex_projection_layer(M)(x_prime)[0]




def owa_pgd_solver_wrapper(w, num_item, beta, num_iter, lr, use_subgrad):
    """Return a projected-gradient solver for the (smoothed) OWA portfolio problem.

    The returned solver runs ``num_iter`` steps of
    ``x <- sparsemax(x + lr * grad + 1e-5)`` starting from x = 0. ``grad`` is
    either the OWA subgradient (``use_subgrad=True``) or the Moreau-envelope
    gradient with smoothing ``beta``.

    Args:
        w: OWA weight vector.
        num_item: unused; the number of items is read from ``c.shape[-1]``.
        beta: default Moreau smoothing parameter of the returned solver.
        num_iter: number of projected-gradient iterations.
        lr: step size.
        use_subgrad: default gradient choice of the returned solver.

    Returns:
        ``pgd_solver(c, eps=1e-7, use_subgrad=use_subgrad, beta=beta)``. It maps
        a (B, M, N) cost tensor to a (B, N) allocation; ``eps`` is accepted for
        interface compatibility but unused.
    """
    projector = sparsemax.Sparsemax(dim=-1)

    def pgd_solver(c, eps=1e-7, use_subgrad=use_subgrad, beta=beta):
        # Allocate the iterate on c's device and dtype. torch.zeros would
        # always give a CPU float32 tensor, breaking the einsum in the gradient
        # helpers for CUDA or float64 costs.
        xk = c.new_zeros(c.shape[0], c.shape[-1])
        for _ in range(num_iter):
            if use_subgrad:
                grad = compute_owa_subgrad(c, xk, beta, w)
            else:
                grad = compute_moreau_grad_portfolio(c, xk, beta, w)
            # Sparsemax (assuming the standard simplex projection) is invariant
            # to adding the same constant to every entry. The 1e-5 shift
            # therefore only affects rounding.
            xk = projector((xk + lr * grad) + 1e-5)
        return xk

    return pgd_solver





def owa_fw_solver_wrapper(w, num_item, beta, num_iter):
    """Return a Frank-Wolfe solver for the Moreau-smoothed OWA portfolio problem.

    The returned ``fw_solver(c, eps=1e-5, use_subgrad=False)`` works as follows:
      * It reshapes c to (B, -1, num_item); fold-opt passes c as a batch of
        flattened vectors.
      * It runs ``num_iter`` Frank-Wolfe steps over the simplex with step size
        2 / (i + 2).
      * It returns the (B, num_item) iterate.

    ``eps`` and ``use_subgrad`` are accepted for interface compatibility but
    unused.
    """
    def fw_solver(c, eps=1e-5, use_subgrad=False):
        c = c.reshape(c.shape[0], -1, num_item)
        # Iterate on c's device and dtype (torch.zeros would force CPU float32).
        # The first step has size 2/2 == 1, so the zero start is overwritten
        # immediately.
        xk = c.new_zeros(c.shape[0], num_item)
        for i in range(num_iter):
            grad = compute_moreau_grad_portfolio(c, xk, beta, w)
            # Linear maximisation oracle over the simplex: the vertex e_j with
            # j = argmax grad.
            vertex = F.one_hot(torch.argmax(grad, dim=-1),
                               num_classes=num_item).to(xk.dtype)
            gamma = 2 / (i + 2)
            xk = (1 - gamma) * xk + gamma * vertex
        return xk

    return fw_solver

# Replace the above function with this one.
def foldopt_owa_pgd_solver_wrapper(w, num_item, beta, num_iter, lr,
                                   fold_n_iter=500, backprop_rule="FPI"):
    """Build a FoldOptLayer for the OWA portfolio problem.

    The layer is built from two pieces:
      * Forward solver: ``owa_fw_solver_wrapper`` (Frank-Wolfe, ``num_iter``
        iterations, no sparsemax involved). Whether it runs without autograd
        is up to FoldOptLayer; confirm there.
      * Update step (used by FoldOptLayer for the backward pass): a single
        projected-gradient step ``x <- Proj_simplex(x + lr * grad)``. Here grad
        is the Moreau OWA gradient and the projection is ``cvxpy_projection``.

    Args:
        w, num_item, beta, num_iter: forwarded to owa_fw_solver_wrapper; w and
            beta are also used in the update step.
        lr: step size of the update step.
        fold_n_iter: ``n_iter`` passed to FoldOptLayer (default 500, as before).
        backprop_rule: ``backprop_rule`` passed to FoldOptLayer (default "FPI").

    Returns:
        A callable that flattens a cost tensor to (B, -1) and applies the layer.
    """
    fw_solver = owa_fw_solver_wrapper(w, num_item, beta, num_iter)

    def update_step(c, xk):
        # FoldOptLayer hands over c flattened; restore (B, M, N) for the gradient.
        c_mat = c.reshape(c.shape[0], -1, xk.shape[1])
        grad = compute_moreau_grad_portfolio(c_mat, xk, beta, w)
        return cvxpy_projection(xk + lr * grad)

    fold_opt_layer = FoldOptLayer(lambda c: fw_solver(c),
                                  lambda c, x: update_step(c, x),
                                  n_iter=fold_n_iter,
                                  backprop_rule=backprop_rule)

    def fold_opt_layer_wrapper(c):
        return fold_opt_layer(c.reshape(c.shape[0], -1))

    return fold_opt_layer_wrapper



def owa_optim_lp_norm_cvxpy(P, Q, eps=1e-2):
    """Build a differentiable, quadratically-regularised max-OWA layer.

    Problem (with w = gini_indices_square(P)):
        max_{x, y, z}  z - eps * ||[x, y, z]||_2^2
        s.t.  sum(x) == 1,  x >= 0,  c @ x == y,
              z <= p @ y   for every permutation p of w.

    Args:
        P: number of criteria (rows of the parameter c).
        Q: number of items (columns of c, length of x).
        eps: weight of the squared-norm regulariser.

    Returns:
        A callable mapping a tensor for the (P, Q) parameter c (optionally
        batched, as CvxpyLayer supports) to the optimal x, the first of the
        layer outputs (x, y, z). The solver runs with eps=1e-7 and
        max_iters=200_000.

    Note: there are P! permutations, so the model size grows factorially in P.
    """
    x = cp.Variable(Q, nonneg=True)
    y = cp.Variable(P)
    z = cp.Variable()
    c = cp.Parameter((P, Q))
    w = gini_indices_square(P).numpy().astype(float)

    constraints = [cp.sum(x) == 1, c @ x == y]
    constraints += [z <= p @ y for p in permutations(w)]

    regularised = z - eps * cp.pnorm(cp.hstack([x, y, z]), p=2) ** 2
    problem = cp.Problem(cp.Maximize(regularised), constraints)
    layer = CvxpyLayer(problem, parameters=[c], variables=[x, y, z])

    def solve(c_value):
        return layer(c_value, solver_args={'eps': 1e-7, 'max_iters': 200_000})[0]

    return solve


class OWASubgradientLayer(torch.autograd.Function):
    """OWA value in forward, negated OWA subgradient in backward.

    forward(multi_obj (B x M), w_gini (M,)) returns, per row,
    ``sum_k w_gini[k] * sort(multi_obj)[k]`` (ascending sort).
    backward returns, for multi_obj, ``-w_gini[rank]`` where ``rank`` is each
    entry's ascending position, and None for w_gini.

    NOTE(review): backward ignores grad_output and negates the gradient, so it
    is not the chain-rule gradient of forward's value. This is presumably
    deliberate (ascending the OWA with a descent optimiser); confirm against
    callers.
    """

    @staticmethod
    def forward(ctx, multi_obj, w_gini):
        ctx.w_gini = w_gini  # not read by backward, which uses the saved tensors
        _, sigma = (-multi_obj).sort(descending=True, dim=-1)  # B x M
        ascending = torch.sort(multi_obj, dim=-1).values
        owa_loss = torch.einsum("m, im-> i", w_gini, ascending)
        ctx.save_for_backward(w_gini, sigma)
        return owa_loss

    @staticmethod
    def backward(ctx, grad_output):
        w_gini, sigma = ctx.saved_tensors
        rank = sigma.argsort(dim=-1)
        tiled = w_gini.repeat(rank.shape[0], 1)
        ranked_weights = torch.gather(tiled, 1, rank)
        return -ranked_weights.to(grad_output.device), None, None


class MoreauOWALossLayer(torch.autograd.Function):
    """OWA value in forward, negated Moreau-smoothed OWA gradient in backward.

    forward(multi_obj (B x M), w_gini (M,), beta) returns, per row,
    ``sum_k w_gini[k] * sort(multi_obj)[k]`` (ascending sort), i.e. the exact
    unsmoothed OWA value. backward returns
    ``-compute_Moreau_grad_softsort(w_gini, -multi_obj / beta)`` for multi_obj
    and None for w_gini and beta.

    NOTE(review): backward ignores grad_output and negates the gradient, so it
    is not the chain-rule gradient of forward's value. This is presumably
    deliberate (ascending the OWA with a descent optimiser); confirm against
    callers.
    """

    @staticmethod
    def forward(ctx, multi_obj, w_gini, beta):
        # Tensors needed in backward go through save_for_backward rather than
        # plain ctx attributes, as the torch.autograd.Function docs recommend.
        # beta is a plain value and stays on ctx.
        ctx.beta = beta
        ctx.save_for_backward(multi_obj, w_gini)
        ascending = torch.sort(multi_obj, dim=-1).values
        return torch.einsum("m, im-> i", w_gini, ascending)

    @staticmethod
    def backward(ctx, grad_output):
        multi_obj, w_gini = ctx.saved_tensors
        z = -multi_obj / ctx.beta
        grad = compute_Moreau_grad_softsort(w_gini, z)
        return -grad.to(grad_output.device), None, None








