import torch
import torch.nn as nn
import torch.nn.functional as F
import optimization
import numpy as np
from fast_soft_sort.pytorch_ops import soft_sort, soft_rank, phedron_project


def compute_Moreau_grad(w, z):
    """Isotonic-regression step of the Moreau-envelope gradient of an OWA function.

    The utility vector ``z`` is sorted in decreasing order. The reversed, negated
    weights are subtracted from it. The result goes through the isotonic-regression
    layer built by ``optimization.get_isotonic_regression_layer``. That output is
    mapped back to the original order of ``z`` and added to ``z``.

    Args:
        w: OWA weight vector (assumed to have the same length as ``z`` -- TODO confirm).
        z: 1-D utility vector.

    Returns:
        ``x[sigma^{-1}] + z`` where ``x`` is the isotonic-regression output.
        As the original author notes, this is not yet the Moreau gradient itself;
        the next step of FWS completes it.
    """
    z_desc, order = torch.sort(z, descending=True)
    # argsort of a permutation gives its inverse permutation
    inverse_order = order.argsort()
    # reversed, negated weights
    neg_rev_w = -torch.flip(w, (0,))

    iso_layer = optimization.get_isotonic_regression_layer(z.shape[0])
    x = iso_layer(z_desc - neg_rev_w)[0]

    # undo the sort, then add the original utilities back
    return x[inverse_order] + z





def compute_Moreau_grad_softsort(w, z):
    """Call Blondel et al.'s ``phedron_project`` (L2 regularization) on ``z`` with weights ``w``.

    Callers in this file negate the input and output of this function when they
    need the OWA gradient step. The raw projection result is returned here; it is
    not yet the Moreau gradient, which is completed by the next step in FWS.

    Args:
        w: OWA weight vector, forwarded as ``input_w``.
        z: a 1-D utility vector, or a 2-D batch with one utility vector per row.
            A 1-D input is promoted to a batch of one before the call.

    Returns:
        The projection with *all* singleton dimensions squeezed out. A 1-D input,
        or a batch containing a single row, therefore comes back 1-D; callers
        re-add the batch dimension themselves.
    """
    batch = z.unsqueeze(0) if len(z.shape) == 1 else z
    projected = phedron_project(batch, input_w=w, regularization='l2')
    return projected.squeeze()




def permutahedron_diameter(N):
    """Maximum distance between two points on the permutahedron (placeholder bound).

    Currently returns ``2 * N``, which the authors believe is an upper bound.
    TODO: compute the diameter exactly.
    """
    upper_bound = 2 * N
    return upper_bound


def initialize_beta(w, pos_bias):
    """Initial smoothing parameter beta_0, as in Proposition 5.

    beta_0 = 2 * pos_bias[0] * D / ||w||_2, where D = permutahedron_diameter(len(w)).

    Args:
        w: OWA weight vector (tensor).
        pos_bias: position weights; only the first entry is used.
    """
    diameter = permutahedron_diameter(len(w))
    top_weight = pos_bias[0]
    numerator = 2 * top_weight * diameter
    return numerator / w.norm(p=2)


def compute_owa(w, u, group_item_mask=None):
    """Ordered weighted average (OWA) of utilities.

    Utilities are sorted in increasing order and combined with ``w``, so ``w[0]``
    weights the smallest utility.

    Args:
        w: OWA weight vector.
        u: 1-D utility vector, or 2-D tensor with one utility vector per row.
        group_item_mask: used only for 1-D ``u``. It is a (num_group x num_item)
            mask, and ``u`` is first replaced by the per-group mean. It is ignored
            for 2-D ``u``.

    Returns:
        A scalar tensor for 1-D ``u``; a 1-D tensor (one value per row) otherwise.
    """
    if len(u.shape) != 1:
        ascending_rows = torch.sort(u, dim=-1).values
        weighted = torch.einsum('j, ij->ij', w, ascending_rows)
        return weighted.sum(dim=1)

    if group_item_mask is not None:
        group_totals = (u * group_item_mask).sum(dim=1)
        u = group_totals / group_item_mask.sum(dim=1)
    ascending = torch.sort(u).values
    return w @ ascending


def compute_user_util(rel_score, p_rank, pos_bias):
    """User utility of (possibly fractional) rankings.

    Args:
        rel_score: (num_user x num_item) relevance scores.
        p_rank: (num_user x num_item x num_position) ranking tensor; entry
            [i, j, k] weights item j at position k for user i.
        pos_bias: (num_position,) position weights.

    Returns:
        (num_user,) tensor: sum over items/positions of relevance * ranking * bias.
    """
    relevance_per_position = torch.einsum("ij, ijk -> ik", rel_score, p_rank)
    return relevance_per_position @ pos_bias


def compute_user_util_fast(rel_score, sigma, pos_bias):
    """User utility of deterministic rankings given as permutations.

    Args:
        rel_score: (num_user x num_item) relevance scores.
        sigma: (num_user x num_item) index tensor; sigma[i, k] is the item
            placed at position k for user i.
        pos_bias: (num_position,) position weights.

    Returns:
        (num_user,) tensor of position-weighted relevance.
    """
    relevance_in_rank_order = rel_score.gather(1, sigma)
    return relevance_in_rank_order @ pos_bias


def compute_item_util(p_rank, pos_bias, group_item_mask, batchify=True, merits=None):
    """Item exposure (and group exposure) of ranking tensors.

    Args:
        p_rank: (num_query x num_item x num_position) ranking tensor; entry
            [i, j, k] weights item j at position k for query i.
        pos_bias: (num_position,) position weights.
        group_item_mask: optional item-group mask. In batched mode it is
            (num_query x num_group x num_item). In unbatched mode it is squeezed
            to (num_group x num_item).
        batchify: if True, exposure is kept per query (num_query x num_item).
            If False, it is summed over queries (num_item,).
        merits: optional (num_group,) merits. In batched mode, group exposures
            are divided by them; None skips the division. They are not used in
            unbatched mode.

    Returns:
        ``(v, v_group)``. ``v`` is the item exposure. ``v_group`` is the mean
        exposure per group. Without a mask there are no groups, so ``v_group``
        is ``v`` itself. Previously that case raised UnboundLocalError.
    """
    if batchify:
        v = torch.einsum("ijk, k->ij", p_rank, pos_bias)
        v_group = v
        if group_item_mask is not None:
            v_group_sum = torch.einsum("ij, imj->imj", v, group_item_mask).sum(dim=-1)
            group_cnt_inv = 1 / group_item_mask.sum(dim=-1)
            v_group = torch.einsum('im, im->im', v_group_sum, group_cnt_inv)
            if merits is not None:
                v_group = torch.einsum("im, m->im", v_group, 1 / merits)
    else:
        v = torch.einsum("ijk, k->j", p_rank, pos_bias)
        v_group = v
        # transform v into group means, used for the projection onto w_tilde
        if group_item_mask is not None:
            # num_group x num_item: row i is (ordered) group i, column is item
            group_item_mask = group_item_mask.squeeze()
            v_group = (v * group_item_mask).sum(dim=1) / group_item_mask.sum(dim=1)
    return v, v_group

def compute_item_util_fast(sigma, pos_bias_mat, group_item_mask, merits):
    """Per-query item exposure of deterministic rankings given as permutations.

    Only supports the batched version that uses local (per-query) item exposure.

    Args:
        sigma: (num_query x num_item) index tensor; sigma[i, k] is the item at
            position k for query i.
        pos_bias_mat: (num_query x num_position) position weights, one row per query.
        group_item_mask: optional (num_query x num_group x num_item) item-group mask.
        merits: optional (num_group,) merits that group exposures are divided by;
            None skips the division.

    Returns:
        ``(v, v_group)``. ``v[i, j]`` is the exposure of item j for query i.
        ``v_group`` is the mean exposure per group. Without a mask there are no
        groups, so ``v_group`` is ``v`` itself. Previously that case raised
        UnboundLocalError.
    """
    # sigma_inv[i, j] = position of item j, so gathering the bias gives its exposure
    sigma_inv = sigma.argsort()
    v = torch.gather(pos_bias_mat, 1, sigma_inv)
    if group_item_mask is None:
        return v, v
    v_group_sum = torch.einsum("ij, imj->imj", v, group_item_mask).sum(dim=-1)
    group_cnt_inv = 1 / group_item_mask.sum(dim=-1)
    v_group = torch.einsum('im, im->im', v_group_sum, group_cnt_inv)
    if merits is not None:
        v_group = torch.einsum("im, m->im", v_group, 1 / merits)
    return v, v_group


def FWS(rel_score, w_users, w_items, num_iter, lamb, group_item_mask=None, beta=None):
    """Frank-Wolfe with smoothing for the user/item OWA ranking objective (unbatched).

    Starts from the rankings that sort each row of ``rel_score`` decreasingly. It
    then runs ``num_iter`` Frank-Wolfe steps with step size 2/(t+2). Each step's
    linear minimizer sorts the smoothed gradient ``mu_hat``.

    Args:
        rel_score: (n x m) relevance scores for n users and m items. They are
            divided by their global norm first.
        w_users: OWA weights over user utilities.
        w_items: OWA weights over item exposures (or group mean exposures).
        num_iter: number of Frank-Wolfe iterations.
        lamb: trade-off; the user term is weighted (1 - lamb), the item term lamb.
        group_item_mask: optional item-group mask, squeezed to (num_group x num_item).
            When given, item exposures are replaced by per-group means.
        beta: smoothing parameter. If None, beta_0 = initialize_beta(w_items, pos_bias)
            and beta_t = beta_0 / sqrt(t). The original decay branch tested
            ``beta is None`` after beta had already been set, so it never ran.
            Otherwise the supplied beta is used unchanged at every iteration.

    Returns:
        ``(p_rank[0], final_item_exp, owa_objective)``. ``p_rank[0]`` is the
        (m x m) ranking matrix of the first user; this function is meant for a
        single user. ``final_item_exp`` is the item exposure summed over users.
        ``owa_objective`` is ``compute_owa`` of that exposure.
    """
    rel_score = rel_score / rel_score.norm()

    p_rank = F.one_hot(rel_score.argsort(dim=-1, descending=True)).double()
    p_rank = torch.einsum("ijk->ikj", p_rank)  # p_rank[i, j, k]: item j at position k

    pos_bias = 1.0 / torch.log2(torch.arange(rel_score.shape[1]).double() + 2)

    # Decay only the automatically chosen beta_0; a user-supplied beta stays fixed.
    decay_beta = beta is None
    if decay_beta:
        beta = initialize_beta(w_items, pos_bias)

    # Squeeze once, up front, so compute_owa below also sees the 2-D mask
    # even when num_iter == 0.
    if group_item_mask is not None:
        group_item_mask = group_item_mask.squeeze()

    for t in range(1, num_iter + 1):
        beta_t = beta / np.sqrt(t) if decay_beta else beta

        # user utility
        u = torch.einsum("ij, ijk -> ik", rel_score, p_rank) @ pos_bias
        # item exposure, summed over users
        v = torch.einsum("ijk, k->j", p_rank, pos_bias)
        if group_item_mask is not None:
            # mean group exposure; mask rows are (ordered) groups, columns items
            v = (v * group_item_mask).sum(dim=1) / group_item_mask.sum(dim=1)

        y1 = compute_Moreau_grad_softsort(w_users, u / beta_t)  # dim 1 x num_users
        y2 = compute_Moreau_grad_softsort(w_items, v / beta_t)  # dim 1 x num_items or num_group
        if w_users.shape[0] == 1:
            # the projection squeezes singleton dims; restore the user axis
            y1 = y1.unsqueeze(dim=0)

        mu1 = (1 - lamb) * torch.einsum("ij, i-> ij", rel_score, y1)
        mu2 = lamb * y2
        if group_item_mask is not None:
            masked_user = torch.einsum("ij, kj->ikj", mu1, group_item_mask)  # num_user x num_group x num_item
            masked_item = torch.einsum("k, kj->kj", mu2, group_item_mask)  # num_group x num_item
            # add each item's group term to the per-user term
            mu_hat = (masked_user + masked_item).sum(dim=1)
        else:
            mu_hat = mu1 + mu2

        # Qt: rank items by mu_hat in increasing order
        Qt = F.one_hot((-mu_hat).argsort(dim=1, descending=True)).double()
        Qt = torch.einsum("ijk->ikj", Qt)
        step = 2 / (t + 2)
        p_rank = (1 - step) * p_rank + step * Qt

    final_item_exp = torch.einsum("ijk, k->j", p_rank, pos_bias)
    owa_objective = compute_owa(w_items, final_item_exp, group_item_mask)

    return p_rank[0], final_item_exp, owa_objective  # only one user is expected





def FWS_batch(rel_score, w_users, w_items, num_iter, lamb, p_rank=None, group_item_mask=None, beta=None, return_exp=True, merits=None, use_initial_beta=True):
    """Batched Frank-Wolfe (optionally smoothed) for the user/item OWA ranking objective.

    Each row of ``rel_score`` is ranked independently. Item exposures are computed
    per query by ``compute_item_util``. Each iteration takes a Frank-Wolfe step of
    size 2/(t+2) towards the ranking that sorts the smoothed gradient ``mu_hat``.

    Args:
        rel_score: (num_query x num_item) relevance scores, divided by their global norm.
        w_users: OWA weights over user utilities.
        w_items: OWA weights over (group) item exposures.
        num_iter: number of Frank-Wolfe iterations.
        lamb: trade-off; the user term is weighted (1 - lamb), the item term lamb.
        p_rank: optional starting ranking tensor (num_query x num_item x num_position).
            If None, start from sorting each row of ``rel_score`` decreasingly.
        group_item_mask: optional (num_query x num_group x num_item) item-group mask.
        beta: smoothing parameter.
            - None: smoothed FW starting from beta_0 = initialize_beta(w_items, pos_bias),
              with beta_t = beta_0 / sqrt(t). Previously this default raised TypeError.
            - >= 0: smoothed FW, beta_t = beta / sqrt(t). ``beta`` is replaced by
              initialize_beta(...) when ``use_initial_beta`` is False.
            - < 0: no decay; beta_t is fixed to initialize_beta(w_items, pos_bias).
        return_exp: if True, also return exposures and the OWA objective.
        merits: optional per-group merits, forwarded to ``compute_item_util``.
        use_initial_beta: if True (and beta >= 0), use the supplied ``beta`` as
            beta_0; otherwise beta_0 comes from ``initialize_beta``.

    Returns:
        ``(p_rank, v, v_group, owa_objective)`` if ``return_exp`` else ``p_rank``.
    """
    pos_bias = 1.0 / torch.log2(torch.arange(rel_score.shape[1]).double() + 2)

    rel_score = rel_score / rel_score.norm()
    if p_rank is None:
        p_rank = F.one_hot(rel_score.argsort(dim=-1, descending=True)).double()
        p_rank = torch.einsum("ijk->ikj", p_rank)  # p_rank[i, j, k]: item j at position k

    if beta is None:
        smoothing = True
        use_initial_beta = False
    else:
        smoothing = bool(beta >= 0)
        if not smoothing:
            use_initial_beta = False
    if not use_initial_beta:
        beta = initialize_beta(w_items, pos_bias)

    for t in range(1, num_iter + 1):
        # (the old `smoothing | use_initial_beta` reduces to `smoothing`,
        # because use_initial_beta is forced False whenever smoothing is off)
        beta_t = beta / np.sqrt(t) if smoothing else beta

        u = compute_user_util(rel_score, p_rank, pos_bias)
        v, v_group = compute_item_util(p_rank, pos_bias, group_item_mask, batchify=True, merits=merits)

        # Negations: chain rule around the projection code, which is called on -x.
        y1 = -compute_Moreau_grad_softsort(w_users, -u / beta_t)  # dim 1 x num_users
        y2 = -compute_Moreau_grad_softsort(w_items, -v_group / beta_t)  # dim num_user x num_items or num_group
        if rel_score.shape[0] == 1:
            # the projection squeezes singleton dims; restore the query axis
            y1 = y1.unsqueeze(dim=0)
            y2 = y2.unsqueeze(dim=0)

        mu1 = (1 - lamb) * torch.einsum("ij, i-> ij", rel_score, y1)
        mu2 = lamb * y2
        if group_item_mask is not None:
            masked_user = torch.einsum("ij, imj->imj", mu1, group_item_mask)  # num_user x num_group x num_item
            masked_item = torch.einsum("im, imj->imj", mu2, group_item_mask)  # num_user x num_group x num_item
            mu_hat = (masked_user + masked_item).sum(dim=1)  # num_user x num_item
        else:
            mu_hat = mu1 + mu2

        # Qt ranks items by mu_hat in increasing order; one_hot of the inverse
        # permutation gives Qt[i, j, k] = 1 iff item j is at position k.
        sigma_hat = (-mu_hat).argsort(dim=1, descending=True)
        Qt = F.one_hot(sigma_hat.argsort()).double()
        step = 2 / (t + 2)
        p_rank = (1 - step) * p_rank + step * Qt

    if not return_exp:
        return p_rank

    v, v_group = compute_item_util(p_rank, pos_bias, group_item_mask, merits=merits)
    owa_objective = compute_owa(w_items, v_group, group_item_mask)
    return p_rank, v, v_group, owa_objective



def _fw_group_exposure(v, group_item_mask, merits):
    """Merit-normalised mean exposure per group, or ``v`` itself when there is no mask.

    ``v`` is (num_query x num_item); ``group_item_mask`` is
    (num_query x num_group x num_item); ``merits`` is (num_group,) or None.
    """
    if group_item_mask is None:
        return v
    v_group_sum = torch.einsum("ij, imj->imj", v, group_item_mask).sum(dim=-1)
    group_cnt_inv = 1 / group_item_mask.sum(dim=-1)
    v_group = torch.einsum('im, im->im', v_group_sum, group_cnt_inv)
    if merits is not None:
        v_group = torch.einsum("im, m->im", v_group, 1 / merits)
    return v_group


def FWS_batch_fast(rel_score, w_users, w_items, num_iter, lamb,
                   p_rank=None, group_item_mask=None, beta=None, return_exp=True,
                   use_initial_beta=True, tol=2e-3, is_train=True, merits=None):
    """Batched Frank-Wolfe (optionally smoothed) with incrementally tracked utilities.

    Same algorithm as ``FWS_batch``, but the dense ranking tensor is never
    contracted again after initialisation. User utilities ``u``, item exposures
    ``v`` and group exposures ``v_group`` are kept as running Frank-Wolfe averages.
    They use exactly the same step as ``p_rank``, so they always equal the
    utilities of the current ``p_rank``. (Previously the averaging weight lagged
    ``p_rank``'s by one iteration, so the gradients were evaluated at the wrong point.)

    Args:
        rel_score: (num_query x num_item) relevance scores; softmax-normalised
            along the last dimension.
        w_users: OWA weights over user utilities.
        w_items: OWA weights over (group) item exposures.
        num_iter: number of Frank-Wolfe iterations.
        lamb: trade-off; the user term is weighted (1 - lamb), the item term lamb.
            With ``lamb == 0`` (or ``num_iter == 0``) the initial ranking is returned.
        p_rank: ignored. The algorithm always starts from the ranking that sorts
            each row of ``softmax(rel_score)`` decreasingly.
        group_item_mask: optional (num_query x num_group x num_item) item-group mask.
        beta: smoothing parameter.
            - None: smoothed FW from beta_0 = initialize_beta(w_items, pos_bias).
              Previously this default raised TypeError.
            - >= 0: smoothed FW, beta_t = beta / sqrt(t). ``beta`` is replaced by
              initialize_beta(...) when ``use_initial_beta`` is False.
            - < 0: no decay; beta_t is fixed to initialize_beta(w_items, pos_bias).
        return_exp: if True, also return exposures and the OWA objective.
        use_initial_beta: if True (and beta >= 0), use the supplied ``beta`` as beta_0.
        tol, is_train: accepted for interface compatibility; currently unused.
        merits: optional (num_group,) merits that group exposures are divided by.

    Returns:
        ``(p_rank, v, v_group, owa_objective)`` if ``return_exp`` else ``p_rank``.
    """
    rel_score = torch.nn.functional.softmax(rel_score, dim=-1)
    sigma = rel_score.argsort(dim=-1, descending=True)
    p_rank = F.one_hot(sigma).double()
    p_rank = torch.einsum("ijk->ikj", p_rank)  # p_rank[i, j, k]: item j at position k
    pos_bias = 1.0 / torch.log2(torch.arange(rel_score.shape[1]).double() + 2)
    pos_bias_mat = pos_bias.repeat(rel_score.shape[0], 1)

    if beta is None:
        smoothing = True
        use_initial_beta = False
    else:
        smoothing = bool(beta >= 0)
        if not smoothing:
            use_initial_beta = False
    if not use_initial_beta:  # also covers regular (non-smoothed) FW
        beta = initialize_beta(w_items, pos_bias)

    u = compute_user_util_fast(rel_score, sigma, pos_bias)
    v = torch.gather(pos_bias_mat, 1, sigma.argsort())
    v_group = _fw_group_exposure(v, group_item_mask, merits)

    if num_iter == 0 or lamb == 0:
        if return_exp:
            owa_objective = compute_owa(w_items, v_group, group_item_mask)
            return p_rank, v, v_group, owa_objective
        return p_rank

    for t in range(1, num_iter + 1):
        beta_t = beta / np.sqrt(t) if smoothing else beta

        # Negations: chain rule around Blondel's projection code, which is called on -x.
        y1 = -compute_Moreau_grad_softsort(w_users, -u / beta_t)  # dim 1 x num_users
        y2 = -compute_Moreau_grad_softsort(w_items, -v_group / beta_t)
        if rel_score.shape[0] == 1:
            # the projection squeezes singleton dims; restore the query axis
            y1 = y1.unsqueeze(dim=0)
            y2 = y2.unsqueeze(dim=0)

        mu1 = (1 - lamb) * torch.einsum("ij, i-> ij", rel_score, y1)
        mu2 = lamb * y2
        if group_item_mask is not None:
            masked_user = torch.einsum("ij, imj->imj", mu1, group_item_mask)  # num_user x num_group x num_item
            masked_item = torch.einsum("im, imj->imj", mu2, group_item_mask)  # num_user x num_group x num_item
            mu_hat = (masked_user + masked_item).sum(dim=1)  # num_user x num_item
        else:
            mu_hat = mu1 + mu2

        # Linear minimiser: rank items by mu_hat in increasing order.
        sigma_hat = (-mu_hat).argsort(dim=1, descending=True)
        sigma_hat_inv = sigma_hat.argsort()
        Qt = F.one_hot(sigma_hat_inv).double()

        # One Frank-Wolfe step, applied with the SAME weights to the ranking and
        # to every quantity that is linear in it.
        keep, step = t / (t + 2), 2 / (t + 2)
        p_rank = keep * p_rank + step * Qt
        u = keep * u + step * (torch.gather(rel_score, 1, sigma_hat) @ pos_bias)
        v = keep * v + step * torch.gather(pos_bias_mat, 1, sigma_hat_inv)
        v_group = _fw_group_exposure(v, group_item_mask, merits)

    if return_exp:
        owa_objective = compute_owa(w_items, v_group, group_item_mask)
        return p_rank, v, v_group, owa_objective
    return p_rank


# user_util = []