from collections import deque
from typing import Union

import torch
from problems.GC.gc_state import GCState
from dataclasses import dataclass
from problems.problem_base import Problem
from trainer.utils import get_adjacency_list


class DefectiveGC(Problem):
    """Defective graph-colouring problem.

    Creates ``DefectiveGCState`` objects whose defect bound is taken from
    ``defectiveness``.
    """

    NAME = "DefectiveGC"

    # Defect bound forwarded to every state built by ``make_state``.
    # The default of 1 is the value that used to be hard-coded there; override
    # it on an instance (``problem.defectiveness = 2``) or in a subclass.
    # NOTE(review): no ``__init__`` is added because the constructor signature
    # of ``Problem`` is not visible from this file.
    defectiveness: int = 1

    def make_state(self, *args):
        """Return the initial ``DefectiveGCState`` for a batch.

        Args:
            *args: Forwarded unchanged to ``DefectiveGCState.initialize`` after
                the defect bound, i.e. ``input, embeddings, n_nodes,
                throwback_size, decoding_type``.
        """
        return DefectiveGCState.initialize(self.defectiveness, *args)


@dataclass
class DefectiveGCState(GCState):
    """Graph-colouring state that tolerates a bounded number of colour clashes.

    A colour is admissible for a node as long as it occurs at most
    ``defectiveness`` times among the colours recorded for that node in
    ``neighbor_colors``. Because multiplicity matters, those colours are kept
    in lists.
    """

    # Maximum number of occurrences of a colour among a node's recorded
    # neighbour colours for that colour to still be admissible.
    defectiveness: int

    def concat_embeddings(self):
        """Always return ``True``.

        NOTE(review): presumably a flag telling the caller to concatenate
        embeddings -- confirm against ``GCState`` and its users.
        """
        return True

    @staticmethod
    def initialize(defectiveness: int, input, embeddings: torch.Tensor, n_nodes: int, throwback_size: int, decoding_type: str):
        """Build the initial state for a batch of graphs.

        Args:
            defectiveness: Defect bound stored on the state and used by
                ``color_is_admissible``.
            input: Batch object exposing ``num_graphs``; also handed to
                ``get_adjacency_list`` (presumably a torch_geometric batch --
                TODO confirm).
            embeddings: Tensor stored on the state; its device is used for
                every tensor created here.
            n_nodes: Size of the last dimension of ``visited`` and ``sol_ind``
                and the number of per-graph entries in ``neighbor_colors``.
            throwback_size: ``maxlen`` of the ``context_throwback`` deque.
            decoding_type: Stored unchanged on the state.

        Returns:
            A new ``DefectiveGCState`` with nothing visited, every ``sol_ind``
            entry set to -1 and empty neighbour-colour lists.
        """
        device = embeddings.device
        batch_size = input.num_graphs
        return DefectiveGCState(
            adj=get_adjacency_list(input, device),
            embeddings=embeddings,
            visited=torch.zeros(batch_size, 1, n_nodes, dtype=torch.uint8, device=device),
            # -1 everywhere; same values as the former ``zeros(...) - 1``.
            sol_ind=torch.full((batch_size, 1, n_nodes), -1, dtype=torch.int32, device=device),
            sol_rep={graph: [] for graph in range(batch_size)},
            prev_a=torch.zeros(batch_size, 1, dtype=torch.int, device=device),
            step=torch.zeros(1, device=device),
            # Lists, not sets: admissibility depends on how often a colour
            # occurs, so duplicates must be preserved.
            neighbor_colors={
                graph: {node: [] for node in range(n_nodes)}
                for graph in range(batch_size)
            },
            update_nodes=[],
            context_throwback=deque(maxlen=throwback_size),
            decoding_type=decoding_type,
            defectiveness=defectiveness,
        )

    def color_is_admissible(self, color: int, neighbor_colors: Union[list, set]):
        """Return whether ``color`` occurs at most ``defectiveness`` times.

        Args:
            color: Candidate colour.
            neighbor_colors: Colours recorded for the node. A list keeps
                multiplicity; a set is also accepted (each colour then counts
                at most once).

        Returns:
            ``True`` if the number of occurrences of ``color`` in
            ``neighbor_colors`` does not exceed ``self.defectiveness``.
        """
        if isinstance(neighbor_colors, (set, frozenset)):
            # Sets have no ``.count``; count by equality instead of failing
            # with AttributeError.
            occurrences = sum(1 for other in neighbor_colors if other == color)
        else:
            occurrences = neighbor_colors.count(color)
        return occurrences <= self.defectiveness

    def update_neighbor_colors(self, color, batch_index, node_index):
        """Append ``color`` to the colour list of ``node_index`` in graph ``batch_index``."""
        self.neighbor_colors[batch_index][node_index].append(color)
