import torch
import numpy as np
from utils import nodewt_to_edgewt, maybe_parallelize
from comb_modules.dijkstra import dijkstra
from perturbations import perturbations
import torch.optim as optim


def batch_sp(weights):
    """Solve a shortest-path problem for every weight grid in a batch.

    Args:
        weights: Iterable of per-example weight grids, typically a NumPy
            array converted directly from a Tensor; each element is passed
            to ``dijkstra`` in order.

    Returns:
        NumPy array built from the ``shortest_path`` attribute of each
        ``dijkstra`` result, in input order.
    """
    solutions = (dijkstra(grid) for grid in weights)
    return np.asarray([solution.shortest_path for solution in solutions])

def batch_torch_sp(weights):
    """Batched shortest-path solver with a tensor-in / tensor-out interface.

    Detaches ``weights`` from the autograd graph and copies it to host
    memory. ``.numpy()`` only works on CPU tensors, so without ``.cpu()``
    a CUDA input would raise. Every grid is then solved with ``batch_sp``
    and the result is wrapped back into a tensor.

    Args:
        weights: Tensor whose leading dimension indexes the batch (it is
            iterated element-wise by ``batch_sp``).

    Returns:
        Tensor of shortest-path solutions in PyTorch's default floating
        dtype (as ``torch.Tensor(...)`` produced before), placed on the same
        device as ``weights``.
    """
    weights_np = weights.detach().cpu().numpy()
    shortest_paths = batch_sp(weights_np)
    # Hand the solution back on the caller's device so it can be combined
    # with other tensors from the same device without a mismatch.
    return torch.as_tensor(
        shortest_paths, dtype=torch.get_default_dtype(), device=weights.device
    )



class ShortestPathDPO(torch.nn.Module):
    """Shortest-path layer built on ``perturbations.perturbed``.

    The forward pass runs ``batch_torch_sp`` through the perturbed wrapper.
    That wrapper is expected to provide the backward pass, since this class
    defines none. This is therefore a plain ``nn.Module`` rather than a
    ``torch.autograd.Function``. Current PyTorch raises when an instance of a
    Function subclass with a non-static ``forward`` is called.

    Args:
        lambda_val: Stored as ``self.lambda_val``; not used by the solver
            (the perturbation scale is ``sigma``).
        neighbourhood_fn: Stored as ``self.neighbourhood_fn``; not used here.
        num_samples: Forwarded to ``perturbations.perturbed``.
        sigma: Forwarded to ``perturbations.perturbed`` as the noise scale.
    """

    def __init__(self, lambda_val, neighbourhood_fn="8-grid",
                 num_samples=100, sigma=0.5):
        super().__init__()
        self.lambda_val = lambda_val
        self.neighbourhood_fn = neighbourhood_fn
        self.solver = perturbations.perturbed(
            batch_torch_sp,
            num_samples=num_samples,
            sigma=sigma,
            noise='normal',
            batched=True,
        )

    def forward(self, weights):
        """Return the perturbed shortest-path solution for ``weights``.

        The solver output is cast to float32 and moved to ``weights.device``.
        """
        return self.solver(weights).float().to(weights.device)


# Module-level perturbed solver: wraps ``batch_torch_sp`` with
# ``perturbations.perturbed`` (100 normal-noise samples, sigma=0.1, batched
# inputs). The script below calls it on a Parameter and backpropagates
# through the resulting loss.
fwd = perturbations.perturbed(
    batch_torch_sp,
    batched=True,
    noise='normal',
    sigma=0.1,
    num_samples=100,
)



def main():
    """Smoke-test the perturbed shortest-path solver on random 12x12 grids.

    The run proceeds in three steps:

    1. Solve every grid once with the exact solver and print the result.
    2. Print the MSE between the perturbed solver's output and
       identity-matrix targets.
    3. Take 10 SGD steps on the grid weights, printing the loss at each step.
    """
    weights = torch.rand(100, 12, 12)

    suggested_tours = batch_sp(weights.detach().cpu().numpy())
    print("suggested_tours = ")
    print(suggested_tours)

    mse = torch.nn.MSELoss()
    batch_size, grid_size = weights.shape[0], weights.shape[1]
    # One identity matrix per example in the batch.
    targets = torch.eye(grid_size).repeat(batch_size, 1, 1)
    cost_params = torch.nn.Parameter(weights)
    optimizer = optim.SGD([cost_params], lr=0.05)

    with torch.no_grad():
        shortest_paths = fwd(cost_params)
        loss = mse(shortest_paths, targets)
        print("loss = {}".format(loss.item()))

    # Backprop test: gradients flow through the perturbed solver.
    for _ in range(10):
        shortest_paths = fwd(cost_params)
        loss = mse(shortest_paths, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print("loss = {}".format(loss.item()))


# Only run the experiment when executed as a script, not on import.
if __name__ == "__main__":
    main()
