import torch
import torch.nn as nn
import torch.nn.functional as F
import optimization
import frank_wolfe
from fast_soft_sort.pytorch_ops import soft_sort, soft_rank, phedron_project
import fast_soft_sort.numpy_ops as numpy_ops


# rel_score = torch.Tensor([[1,2,3,4,5,6,7],[1,2,3,4,5,6,7]])
# p_rank = F.one_hot(rel_score.argsort(dim=1, descending=True)).float()
# pos_bias = 1. / torch.arange(1., rel_score.shape[-1] + 1)

# u = (torch.einsum("ij, ijk -> ik", rel_score, p_rank) * pos_bias).sum(dim=1) # sort utilities x pos_bias
# #item utility
# v = torch.einsum("ijk, k->j", p_rank, pos_bias)


# print(rel_score)

# print(p_rank)

# print(pos_bias)


# print(v)


# print(u)


# u_closest = optimization.isotonic_regression(u)

# print(u_closest)


# input = torch.Tensor([0,1,2,3,4,5,6])
# print(input)
# output = optimization.isotonic_regression(input)
# print(output)



# layer = optimization.get_isotonic_regression_layer(len(input))

# print('isotonic layer inputs')
# print(input.repeat(2,1))
# print('isotonic layer outputs')
# print(layer(input.repeat(2,1)))


# print('-------sorting test--------')
# def original_order(ordered, indices):
#     return ordered.gather(1, indices.argsort(1))

# def reorder(ordered, indices):
#     return ordered.gather(1, indices)



# a = torch.rand(3,5)*100

# ordered, indices = torch.sort(a,descending=True)

# original = original_order(ordered, indices)

# reordered = reorder( a, indices )


# print('a')
# print( a )
# print('ordered')
# print( ordered )
# print('indices')
# print( indices )
# print('original')
# print( original )
# print('reordered')
# print( reordered )


# print('------------gini_indices test--------------')
# print(frank_wolfe.gini_indices(10))



# print('-------moreau grad test--------')



# w = torch.rand(5)*100
# z = torch.rand(5)*100

# print('w')
# print( w )
# print('z')
# print( z )

# print('entering compute_Moreau_grad')

# y_cvx = frank_wolfe.compute_Moreau_grad(w, z)

# print('exited compute_Moreau_grad')

# print('y_cvx')
# print( y_cvx )





# ---------------------------------------------------------------------------
# Projection test: runs numpy_ops.Projection on a single (negated) row of `s`,
# then the batched torch ops phedron_project and soft_rank on all of `s`, and
# prints the results for manual comparison.
# NOTE(review): phedron_project appears to be a fork-specific addition to
# fast_soft_sort; its exact semantics are not visible here — confirm against
# its source before relying on the printed values.
# ---------------------------------------------------------------------------
print('-------------------projection test-----------------------')

torch.manual_seed(0)  # fixed seed so the printed tensors are reproducible

L = 6  # items per row

N = 3  # rows in the batch
s = torch.rand( N, L ).double() * 0.1  # (N, L) float64 scores in [0, 0.1)

# Descending weights L, L-1, ..., 1 as float64 (same values as
# flip(arange(L)) + 1). The previous code passed the undefined bare name
# `float64` as the dtype, which raised NameError; torch.float64 is the dtype.
rho = torch.arange(L, 0, -1, dtype=torch.float64)
w = torch.rand( L ).double() *10  # only printed below; not passed to any op


print('rho')
print( rho )

print('w')
print( w )

print('s')
print( s )



# Single-row reference using the NumPy implementation, weighted by rho.
# NOTE(review): the reason for negating the row is not stated here — confirm.
proj = numpy_ops.Projection(-s[0], input_w=rho)
proj_out = proj.compute()
print('proj_out')
print( proj_out )

print("Testing phedron.....\n")

# Batched torch op over all N rows with the same 1-D weight vector rho
# (presumably broadcast across rows — verify in phedron_project).
phedron_out = phedron_project(s, input_w = rho, regularization = 'l2')


print("Testing soft rank....\n")

softrank_out = soft_rank(s, regularization_strength = 1.0, regularization = 'l2')


print('phedron_out')
print( phedron_out )

print('softrank_out')
print( softrank_out )
