import numpy as np
import heapq
import torch
from functools import partial
from comb_modules.utils import get_neighbourhood_func
from collections import namedtuple
from utils import maybe_parallelize
from perturbations import perturbations # JK 02/21

# Result of `dijkstra`: the path indicator array, the uniqueness flag and
# (optionally) the predecessor map.
DijkstraOutput = namedtuple("DijkstraOutput", ["shortest_path", "is_unique", "transitions"])


def dijkstra(matrix, neighbourhood_fn="8-grid", request_transitions=False):
    """Find a cheapest path from the top-left to the bottom-right grid cell.

    Every cell of ``matrix`` holds the cost of visiting that cell; the cost of
    a path is the sum over all cells it visits, both end points included.

    Args:
        matrix: 2-D array of cell costs. Dijkstra's algorithm only guarantees
            an optimal path when all costs are non-negative.
        neighbourhood_fn: Either a name that ``get_neighbourhood_func``
            resolves (the default is ``"8-grid"``), or a callable
            ``f(x, y, x_max, y_max)`` returning an iterable of neighbouring
            ``(x, y)`` cells.
        request_transitions: When true, the predecessor map built during the
            search is returned as well.

    Returns:
        DijkstraOutput: ``shortest_path`` is an array shaped like ``matrix``
        holding 1 on the cells of the path and 0 elsewhere; ``is_unique`` is
        true when no equal-cost alternative was recorded for the target cell;
        ``transitions`` maps each reached cell to its predecessor, or is
        ``None`` unless ``request_transitions`` is set.
    """
    x_max, y_max = matrix.shape
    if callable(neighbourhood_fn):
        neighbourhood = neighbourhood_fn
    else:
        neighbourhood = get_neighbourhood_func(neighbourhood_fn)
    neighbors_func = partial(neighbourhood, x_max=x_max, y_max=y_max)

    # 1.0e10 is the "not reached yet" sentinel for the tentative costs.
    costs = np.full_like(matrix, 1.0e10)
    costs[0][0] = matrix[0][0]
    # num_path is a uniqueness marker rather than an exact path count: a value
    # other than 1 means at least one cost tie was seen on the way to the cell.
    num_path = np.zeros_like(matrix)
    num_path[0][0] = 1
    priority_queue = [(matrix[0][0], (0, 0))]
    certain = set()
    transitions = dict()

    while priority_queue:
        _, (cur_x, cur_y) = heapq.heappop(priority_queue)
        if (cur_x, cur_y) in certain:
            # Stale heap entry: the cell was already settled through a cheaper
            # entry. Expanding it again would hit the tie branch below for the
            # neighbours it has already relaxed and wrongly flag the path as
            # non-unique.
            continue

        for x, y in neighbors_func(cur_x, cur_y):
            if (x, y) in certain:
                continue
            candidate = matrix[x][y] + costs[cur_x][cur_y]
            if candidate < costs[x][y]:
                costs[x][y] = candidate
                heapq.heappush(priority_queue, (costs[x][y], (x, y)))
                transitions[(x, y)] = (cur_x, cur_y)
                num_path[x, y] = num_path[cur_x, cur_y]
            elif candidate == costs[x][y]:
                num_path[x, y] += 1

        certain.add((cur_x, cur_y))

    # Walk the predecessor map back from the target to mark the path.
    cur_x, cur_y = x_max - 1, y_max - 1
    on_path = np.zeros_like(matrix)
    on_path[-1][-1] = 1
    while (cur_x, cur_y) != (0, 0):
        cur_x, cur_y = transitions[(cur_x, cur_y)]
        on_path[cur_x, cur_y] = 1.0

    is_unique = num_path[-1, -1] == 1

    return DijkstraOutput(
        shortest_path=on_path,
        is_unique=is_unique,
        transitions=transitions if request_transitions else None,
    )


def get_solver(neighbourhood_fn):
    """Build a solver that maps one cost matrix to its shortest-path array.

    Args:
        neighbourhood_fn: Neighbourhood specification forwarded to ``dijkstra``.

    Returns:
        A function taking a single matrix and returning the
        ``shortest_path`` field of the ``dijkstra`` result for it.
    """
    def solver(matrix):
        output = dijkstra(matrix, neighbourhood_fn)
        return output.shortest_path

    return solver

# JK 02/22
def get_solver_torch(neighbourhood_fn):
    """Build a batched solver that returns its result as a torch tensor.

    Args:
        neighbourhood_fn: Neighbourhood specification forwarded to ``dijkstra``.

    Returns:
        A function that iterates over its argument, runs ``dijkstra`` on each
        item and returns the stacked ``shortest_path`` arrays wrapped in a
        ``torch.Tensor``.
    """
    def solver(matrix):
        paths = []
        for mat in matrix:
            paths.append(dijkstra(mat, neighbourhood_fn).shortest_path)
        return torch.Tensor(np.asarray(paths))

    return solver


# Shortest-path layer that uses the blackbox-differentiation solver
class ShortestPath(torch.autograd.Function):
    """Shortest-path layer with a blackbox-differentiation backward pass.

    Both methods are static, so the layer is used through
    ``ShortestPath.apply(weights, solver, lambda_val)``.
    """

    @staticmethod
    def forward(ctx, weights, solver, lambda_val):
        """Run ``solver`` on every item of ``weights``.

        Args:
            ctx: Autograd context; the solver and ``lambda_val`` are stashed
                on it for the backward pass.
            weights: Tensor of costs. It is detached, moved to the CPU and
                split along its first dimension before being handed to the
                solver as numpy arrays.
            solver: Callable applied to each item through
                ``maybe_parallelize``.
            lambda_val: Scale applied to the incoming gradient when the
                weights are perturbed in ``backward``.

        Returns:
            A float32 CPU tensor holding the stacked solver outputs.
        """
        ctx.solver = solver
        ctx.lambda_val = lambda_val
        weights_numpy = weights.detach().cpu().numpy()
        suggested_tours = np.asarray(maybe_parallelize(solver, arg_list=list(weights_numpy)))
        # Convert once and save the very tensor that is returned.
        suggested_tours = torch.from_numpy(suggested_tours).float()
        ctx.save_for_backward(weights, suggested_tours)
        return suggested_tours

    @staticmethod
    def backward(ctx, grad_output):
        """Return the blackbox-differentiation gradient w.r.t. ``weights``.

        The weights are shifted by ``lambda_val * grad_output`` (clamped at
        zero), the solver is run again, and the difference between the new
        and the original solutions, divided by ``lambda_val``, is returned.
        ``solver`` and ``lambda_val`` receive no gradient.
        """
        weights, suggested_tours = ctx.saved_tensors
        # The gradient must live on the same device as the input it belongs
        # to; forward's output (and hence grad_output) is always on the CPU.
        weights_device = weights.device
        weights_numpy = weights.detach().cpu().numpy()

        assert grad_output.shape == suggested_tours.shape
        grad_output_numpy = grad_output.detach().cpu().numpy()
        weights_prime = np.maximum(weights_numpy + ctx.lambda_val * grad_output_numpy, 0.0)
        better_paths = np.asarray(maybe_parallelize(ctx.solver, arg_list=list(weights_prime)))
        better_paths = torch.from_numpy(better_paths).float()
        gradient = -(suggested_tours - better_paths) / ctx.lambda_val
        return gradient.to(weights_device), None, None


# JK 02/21
class ShortestPathDPO(torch.autograd.Function):
    """Shortest-path layer whose solver is wrapped by ``perturbations.perturbed``.

    Instances are used by calling ``forward`` directly; no ``backward`` is
    defined here.
    NOTE(review): differentiability presumably comes from the wrapper returned
    by ``perturbations.perturbed`` -- confirm against that module.
    """

    def __init__(self, lambda_val, neighbourhood_fn="8-grid"):
        """Create the perturbed solver.

        Args:
            lambda_val: Stored on the instance only; the noise scale passed
                to ``perturbations.perturbed`` is currently fixed at 0.5.
            neighbourhood_fn: Neighbourhood specification forwarded to
                ``get_solver_torch``.
        """
        self.lambda_val = lambda_val
        self.neighbourhood_fn = neighbourhood_fn
        solverget = get_solver_torch(neighbourhood_fn)
        self.solver = perturbations.perturbed(solverget,
                                              num_samples=5,
                                              sigma=0.5,  # not self.lambda_val
                                              noise='gumbel',
                                              batched=False)

    def forward(self, weights):
        """Apply the perturbed solver to every item of ``weights``.

        Args:
            weights: Tensor iterated along its first dimension; each item is
                passed to ``self.solver`` on its own.

        Returns:
            The per-item solver outputs stacked into one float tensor on
            ``weights.device``. The inputs and the individual outputs are
            also kept on the instance as ``weights`` and ``suggested_tours``.
        """
        print("new fwd")
        self.weights = weights
        self.suggested_tours = [self.solver(sample) for sample in weights]
        return torch.stack(self.suggested_tours).float().to(weights.device)