from typing import Union

import torch
import torch_geometric
import itertools
from dataclasses import dataclass
from trainer.utils import get_adjacency_list
from problems.problem_base import Problem, State
from collections import deque
from torch_geometric.transforms import LocalDegreeProfile
from scipy import stats

@dataclass
class GCState(State):
    """Decoding state for greedy graph coloring over a batch of graphs.

    Extends ``State`` with the bookkeeping needed to give each selected node the
    smallest color not already used by one of its neighbors.
    """

    # neighbor_colors[i][v]: set of colors already assigned to neighbors of node v
    # in graph i of the batch.
    neighbor_colors: dict
    # Flattened adjacency entries of the nodes selected in the last step; only
    # refreshed by ``update`` when ``decoding_type == 'local'``.
    update_nodes: list
    # History buffer created with ``maxlen=throwback_size``; not read or written
    # by this class itself.
    context_throwback: deque
    # Decoding mode; this class branches on 'local' and 'static'.
    decoding_type: str

    def concat_embeddings(self):
        """Always return ``True`` (presumably a flag consumed by the model; verify against ``State``)."""
        return True

    @staticmethod
    def initialize(defectiveness, input, embeddings: torch.Tensor, n_nodes: int, throwback_size: int,
                   decoding_type: str):
        """Create the initial state: no node visited and no color assigned.

        @param defectiveness not used by this implementation
        @param input batched graph object; ``input.num_graphs`` gives the batch size
        @param embeddings node embeddings, indexed as ``embeddings[i, node]``
        @param n_nodes number of nodes per graph
        @param throwback_size ``maxlen`` of the ``context_throwback`` deque
        @param decoding_type decoding mode stored on the state ('local', 'static', ...)
        @return a fresh ``GCState`` on the same device as ``embeddings``
        """
        device = embeddings.device
        batch_size = input.num_graphs
        return GCState(
            adj=get_adjacency_list(input, device),
            embeddings=embeddings,
            # visited[b, 0, v] becomes 1 once node v of graph b has been selected.
            visited=torch.zeros(batch_size, 1, n_nodes, dtype=torch.uint8, device=device),
            # sol_ind[b, 0, v] holds the color of node v, or -1 while it is uncolored.
            sol_ind=torch.full((batch_size, 1, n_nodes), -1, dtype=torch.int32, device=device),
            # sol_rep[b] gets one entry per color class used so far in graph b.
            sol_rep={k: [] for k in range(batch_size)},
            prev_a=torch.zeros(batch_size, 1, dtype=torch.int, device=device),
            step=torch.zeros(1, device=device),
            neighbor_colors={i: {j: set() for j in range(n_nodes)} for i in range(batch_size)},
            update_nodes=[],
            context_throwback=deque(maxlen=throwback_size),
            decoding_type=decoding_type,
        )

    # the transition function
    def update(self, selected):
        """Color one node per graph and advance the step counter.

        @param selected 1-D tensor with, for each graph in the batch, the index
            (within that graph) of the node to color in this step
        """
        self.prev_a = selected[:, None]
        self.visited = self.visited.scatter(-1, self.prev_a[:, :, None], 1)
        if self.decoding_type == 'local':
            # get_update_nodes replaces self.update_nodes wholesale.
            self.get_update_nodes(selected)

        colors = torch.tensor([self.color_node(i, node) for i, node in enumerate(selected)],
                              device=self.prev_a.device, dtype=torch.int, requires_grad=False)

        # Write each selected node's color into sol_ind (kept as int32).
        self.sol_ind = self.sol_ind.scatter(-1, self.prev_a[:, :, None], colors[:, None, None]).type(torch.int32)

        self.step = self.step + 1

    def get_update_nodes(self, selected):
        """Store in ``update_nodes`` the flattened adjacency entries of every selected node.

        @param selected 1-D tensor of per-graph node indices (see ``update``)
        """
        n_nodes = self.visited.size(-1)
        # adj is indexed by the batch-global node id i * n_nodes + node.
        self.update_nodes = list(itertools.chain.from_iterable(
            self.adj[i * n_nodes + int(node)] for i, node in enumerate(selected)))

    def color_node(self, i: int, node: torch.Tensor):
        """Greedily assign ``node`` the smallest color not used by any of its neighbors.

        The chosen color is also recorded as taken for every neighbor. The
        solution representation ``sol_rep[i]`` is updated as well.

        @param i index of the graph (into the batch of graphs)
        @param node the node to be colored (index within graph ``i``; 0-d tensor or int)
        @return the non-negative int color assigned to ``node``
        """
        n_nodes = self.visited.size(-1)
        node_idx = int(node)
        taken = self.neighbor_colors[i][node_idx]

        # Smallest admissible color; terminates because ``taken`` is a finite set.
        color = next(c for c in itertools.count() if self.color_is_admissible(c, taken))

        # Neighbors can no longer use this color. Adjacency entries are
        # batch-global ids, so reduce them modulo n_nodes.
        for u in self.adj[i * n_nodes + node_idx]:
            self.update_neighbor_colors(color, i, int(u) % n_nodes)

        self._update_sol_rep(i, node_idx, color)
        return color

    def _update_sol_rep(self, i: int, node_idx: int, color: int):
        """Fold a newly colored node into the solution representation of graph ``i``."""
        if self.decoding_type == 'static':
            # Static decoding needs no solution representation: one placeholder
            # per color class is enough to count the colors used.
            if color >= len(self.sol_rep[i]):
                self.sol_rep[i].append(1)
            return

        if color >= len(self.sol_rep[i]):
            # New color class: seed it with this node's embedding. Greedy choice
            # means color should equal len(self.sol_rep[i]) here.
            self.sol_rep[i].append(self.embeddings[i, node_idx])
        else:
            # Existing color class: element-wise max-pool the node's embedding into it.
            self.sol_rep[i][color] = torch.max(self.sol_rep[i][color], self.embeddings[i, node_idx])

    def is_done(self):
        """Return True once the number of steps reaches the number of nodes per graph."""
        return self.step.item() >= self.visited.size(-1)

    def get_cost(self):
        """Return the number of colors used per graph as a float tensor of shape (batch,).

        @return the cost tensor, or ``None`` while decoding is still in progress
        """
        if not self.is_done():
            return None
        batch_size = self.visited.size(0)
        return torch.tensor([len(self.sol_rep[k]) for k in range(batch_size)], dtype=torch.float,
                            device=self.prev_a.device, requires_grad=False)

    def color_is_admissible(self, color: int, neighbor_colors: Union[list, set]):
        """
        @param color the color to test
        @param neighbor_colors the list or set of colors of the neighbors
        @return true if the color is admissible, false otherwise
        """
        return color not in neighbor_colors

    def update_neighbor_colors(self, color, batch_index, node_index):
        """Mark ``color`` as used by a neighbor of node ``node_index`` in graph ``batch_index``."""
        self.neighbor_colors[batch_index][node_index].add(color)


def pp_state(state):
    """Debug helper: print the key fields of a coloring state to stdout, one per line."""
    labelled_attrs = (
        ("prev_a: ", "prev_a"),
        ("visited: ", "visited"),
        ("sol_ind: ", "sol_ind"),
        ('neighbor colors: ', "neighbor_colors"),
        ("update_nodes: ", "update_nodes"),
    )
    for label, attr_name in labelled_attrs:
        print(label, getattr(state, attr_name))
