import torch
from typing import NamedTuple
from utils.boolmask import mask_long2bool, mask_long_scatter
from problems.tsp.problem_tsp import nearest_neighbor_graph
import numpy as np


class StateDTSP(NamedTuple):
    """Immutable rollout state for a TSP in which nodes can arrive over time.

    Every method that advances the rollout returns a *new* state built with
    ``_replace``; tensors held by an existing state are never written to.
    Shapes below are the ones produced by :meth:`initialize`.
    """

    # Fixed input
    loc: torch.Tensor  # (batch, n_loc, coord_dim) node coordinates
    dist: torch.Tensor  # (batch, n_loc, n_loc) pairwise Euclidean distances

    ids: torch.Tensor  # (batch, 1) keeps track of original fixed data index of rows

    # State
    first_a: torch.Tensor  # (batch, 1) first node selected in each tour
    prev_a: torch.Tensor  # (batch, 1) most recently selected node
    visited_: torch.Tensor  # uint8 mask (batch, 1, n_loc) or int64 bitmask (batch, 1, ceil(n_loc / 64))
    lengths: torch.Tensor  # (batch, 1) accumulated tour length
    cur_coord: torch.Tensor  # (batch, 1, coord_dim) coordinates of prev_a; None before the first update
    i: torch.Tensor  # Keeps track of step
    graph: torch.Tensor  # presumably (batch, n_loc, n_loc) adjacency mask -- TODO confirm against caller

    @property
    def visited(self):
        """Return the visited mask with one entry per node.

        A uint8 ``visited_`` is returned unchanged; any other dtype is treated
        as a packed bitmask and expanded with ``mask_long2bool``.
        """
        if self.visited_.dtype == torch.uint8:
            return self.visited_
        else:
            return mask_long2bool(self.visited_, n=self.loc.size(-2))

    def __getitem__(self, key):
        """Return a state whose per-instance tensors are indexed by ``key``.

        Args:
            key: A tensor or a slice applied to the leading dimension.

        ``dist`` and ``i`` are carried over unchanged.
        """
        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
        # NOTE(review): `loc` and `graph` are sliced here but `ids` keeps the
        # original row numbers and `dist` is not sliced, while other methods
        # index `loc`/`dist` with `ids`. Confirm that callers only use keys
        # for which this is consistent.
        return self._replace(
            ids=self.ids[key],
            first_a=self.first_a[key],
            prev_a=self.prev_a[key],
            visited_=self.visited_[key],
            lengths=self.lengths[key],
            cur_coord=self.cur_coord[key] if self.cur_coord is not None else None,
            loc=self.loc[key],
            graph=self.graph[key],
        )

    @staticmethod
    def initialize(loc, graph, visited_dtype=torch.uint8):
        """Build the state at step 0 with no node visited.

        Args:
            loc: (batch, n_loc, coord_dim) node coordinates.
            graph: Graph tensor stored as-is on the state.
            visited_dtype: ``torch.uint8`` for a per-node mask; any other
                value selects the packed int64 bitmask.

        Returns:
            A fresh ``StateDTSP``.
        """
        batch_size, n_loc, _ = loc.size()
        prev_a = torch.zeros(batch_size, 1, dtype=torch.long, device=loc.device)
        return StateDTSP(
            loc=loc,
            dist=(loc[:, :, None, :] - loc[:, None, :, :]).norm(p=2, dim=-1),
            ids=torch.arange(batch_size, dtype=torch.int64, device=loc.device)[:, None],  # Add steps dimension
            # Separate storage so that first_a can never alias prev_a.
            first_a=prev_a.clone(),
            prev_a=prev_a,
            # Keep visited with depot so we can scatter efficiently (if there is an action for depot)
            visited_=(  # Visited as mask is easier to understand, as long more memory efficient
                torch.zeros(
                    batch_size, 1, n_loc,
                    dtype=torch.uint8, device=loc.device
                )
                if visited_dtype == torch.uint8
                else torch.zeros(batch_size, 1, (n_loc + 63) // 64, dtype=torch.int64, device=loc.device)  # Ceil
            ),
            lengths=torch.zeros(batch_size, 1, device=loc.device),
            cur_coord=None,
            i=torch.zeros(1, dtype=torch.int64, device=loc.device),  # Vector with length num_steps
            graph=graph
        )

    def get_final_cost(self):
        """Return the tour length including the edge back to the first node.

        Asserts that every node has been visited.
        """
        assert self.all_finished()

        return self.lengths + (self.loc[self.ids, self.first_a, :] - self.cur_coord).norm(p=2, dim=-1)

    def update(self, selected):
        """Advance the state by one step.

        Args:
            selected: Iterable with one entry per batch row. An entry that is
                ``None`` leaves that row's current node unchanged.

        Returns:
            A new ``StateDTSP``; ``self`` is left untouched.
        """
        # Work on a copy: writing into self.prev_a would also change the
        # previous state and any first_a that shares its storage, which made
        # the closing edge in get_final_cost collapse to zero length.
        prev_a = self.prev_a.clone()

        for j, action in enumerate(selected):
            if action is not None:
                # Update the state
                prev_a[j] = action

        cur_coord = self.loc[self.ids, prev_a]
        lengths = self.lengths
        if self.cur_coord is not None:  # Don't add length for first action (selection of start node)
            lengths = self.lengths + (cur_coord - self.cur_coord).norm(p=2, dim=-1)  # (batch_dim, 1)

        # Update should only be called with just 1 parallel step,
        # in which case we can check this way if we should update
        first_a = prev_a if self.i.item() == 0 else self.first_a

        if self.visited_.dtype == torch.uint8:
            # Add one dimension since we write a single value
            visited_ = self.visited_.scatter(-1, prev_a[:, :, None], 1)
        else:
            visited_ = mask_long_scatter(self.visited_, prev_a)

        return self._replace(first_a=first_a, prev_a=prev_a, visited_=visited_,
                             lengths=lengths, cur_coord=cur_coord, i=self.i + 1)

    def arrivals(self, new_node_locs):
        """Append newly arrived nodes and rebuild the derived tensors.

        Args:
            new_node_locs: Coordinates concatenated to ``loc`` along dim 1,
                i.e. (batch, n_new, coord_dim).

        Returns:
            A new ``StateDTSP`` with extended ``loc`` and ``visited_`` and
            recomputed ``dist`` and ``graph``. The new nodes are unvisited.
        """
        loc = torch.cat((self.loc, new_node_locs), 1)
        n_new = new_node_locs.size(1)
        lead_shape = self.visited_.shape[:-1]

        if self.visited_.dtype == torch.uint8:
            # One mask column per arriving node.
            pad_width = n_new
        else:
            # Packed bitmask: only grow when the node count crosses a 64-bit word.
            pad_width = max((loc.size(-2) + 63) // 64 - self.visited_.size(-1), 0)
        padding = self.visited_.new_zeros(*lead_shape, pad_width)
        visited_ = torch.cat((self.visited_, padding), -1)

        # Keep the distance matrix in step with loc (same formula as initialize).
        dist = (loc[:, :, None, :] - loc[:, None, :, :]).norm(p=2, dim=-1)

        # The neighbour graph is built on CPU and then moved back so that it
        # can be gathered with prev_a on whatever device the state lives on.
        cpu_locs = loc.cpu()
        graph = ~torch.BoolTensor(np.array([
            nearest_neighbor_graph(sample, neighbors=0.2, knn_strat='percentage')
            for sample in cpu_locs
        ]))
        graph = graph.to(loc.device)
        return self._replace(visited_=visited_, loc=loc, dist=dist, graph=graph)

    def all_finished(self):
        """Return whether every entry of the visited mask is set."""
        return self.visited.all()

    def get_current_node(self):
        """Return the most recently selected node per batch row."""
        return self.prev_a

    def get_mask(self):
        """Return the visited mask (see ``visited``)."""
        return self.visited

    def get_graph_mask(self):
        """Return the graph row of the current node for every batch row.

        At step 0 no node has been chosen yet, so an all-zero uint8 tensor of
        shape (batch, 1, n_loc) is returned instead.
        """
        batch_size, n_loc, _ = self.loc.size()
        if self.i.item() == 0:
            return torch.zeros(batch_size, 1, n_loc, dtype=torch.uint8, device=self.loc.device)
        else:
            return self.graph.gather(1, self.prev_a.unsqueeze(-1).expand(-1, -1, n_loc))

    def get_graph(self):
        """Return the stored graph tensor."""
        return self.graph

    def get_nn(self, k=None):
        """Return, for every node, the indices of its ``k`` smallest-distance nodes.

        Visited nodes get 1e6 added to their distance so they sort last.

        Args:
            k: Number of neighbours; defaults to ``n_loc`` minus the step count.
        """
        # Insert step dimension
        # Nodes already visited get inf so they do not make it
        if k is None:
            k = self.loc.size(-2) - self.i.item()  # Number of remaining
        return (self.dist[self.ids, :, :] + self.visited.float()[:, :, None, :] * 1e6).topk(k, dim=-1, largest=False)[1]

    def get_nn_current(self, k=None):
        """Not implemented: always fails the assertion below.

        The code after the assertion is unreachable and kept for reference.
        """
        assert False, "Currently not implemented, look into which neighbours to use in step 0?"
        # Note: if this is called in step 0, it will have k nearest neighbours to node 0, which may not be desired
        # so it is probably better to use k = None in the first iteration
        if k is None:
            k = self.loc.size(-2)
        k = min(k, self.loc.size(-2) - self.i.item())  # Number of remaining
        return (
            self.dist[
                self.ids,
                self.prev_a
            ] +
            self.visited.float() * 1e6
        ).topk(k, dim=-1, largest=False)[1]

    def construct_solutions(self, actions):
        """Return ``actions`` unchanged."""
        return actions