import torch
import torch_geometric
from dataclasses import dataclass
from trainer.utils import get_adjacency_list
from problems.problem_base import Problem, State
from collections import deque
from torch_geometric.transforms import LocalDegreeProfile
from scipy import stats

class MVC(Problem):
    """Vertex cover problem ("MVC").

    Decoding-state construction is delegated to ``MVCState.initialize``.

    TODO: cost evaluation (``MVCState.get_cost``) checks the vertex-cover
    condition graph by graph in a Python-level loop. A check batched over all
    graphs at once would be faster.
    """

    NAME = "MVC"

    @staticmethod
    def make_state(*args):
        """Create the initial decoding state; ``args`` are forwarded to ``MVCState.initialize``."""
        return MVCState.initialize(*args)


@dataclass
class MVCState(State):
    """Decoding state for the vertex cover problem.

    In addition to the fields handed to the ``State`` base in ``initialize``
    (``adj``, ``embeddings``, ``visited``, ``sol_ind``, ``sol_rep``, ``prev_a``,
    ``step``), this state keeps the fields declared below.
    """

    # Flat list of neighbours of the nodes selected in the last step; only
    # refreshed by ``update`` when ``decoding_type == 'local'``.
    update_nodes: list
    # One tensor of selected node indices per ``update`` call (stacked in get_cost).
    order: list
    # Bounded deque (maxlen=throwback_size); not read or written in this class.
    context_throwback: deque
    # Controls which parts of the state ``update`` refreshes ('local', 'static', ...).
    decoding_type: str
    # The batched input graph(s).
    input:torch_geometric.data.Data

    def concat_embeddings(self):
        """Always False for this problem (flag consumed by the caller)."""
        return False

    @staticmethod
    def initialize(input, embeddings, n_nodes, throwback_size, decoding_type):
        """Build the initial, empty-solution state for a batch of graphs.

        Args:
            input: batched graph; ``input.num_graphs`` gives the batch size.
            embeddings: node embeddings; only ``.device`` and ``.size(-1)`` are read here.
            n_nodes: width of the ``visited`` / ``sol_ind`` masks (nodes per graph).
            throwback_size: ``maxlen`` of ``context_throwback``.
            decoding_type: stored on the state; see ``update``.
        """
        batch_size = input.num_graphs
        device = embeddings.device

        return MVCState(
            adj=get_adjacency_list(input, device),  # neighbors of dummy node = []
            input=input,
            embeddings=embeddings,
            visited=torch.zeros(batch_size, 1, n_nodes, dtype=torch.uint8, device=device),
            sol_ind=torch.zeros(batch_size, 1, n_nodes, dtype=torch.int32, device=device),
            # Very negative start value so the running max in `update` is
            # taken from the first real selected embedding.
            sol_rep=torch.full((batch_size, 1, embeddings.size(-1)), -100000.0, device=device),
            prev_a=torch.zeros(batch_size, 1, dtype=torch.int, device=device),
            step=torch.zeros(1, device=device),
            update_nodes=[],
            order=[],
            context_throwback=deque(maxlen=throwback_size),
            decoding_type=decoding_type,
        )

    def update(self, selected):
        """Record the nodes chosen in this step and refresh the derived state.

        Args:
            selected: 1-D tensor holding one local node index per graph in the
                batch (1-D is required by the scatter into ``visited``).

        Note: ``sol_ind`` is not modified here.
        """
        self.prev_a = selected[:, None]
        self.order.append(selected)
        self.visited = self.visited.scatter(-1, self.prev_a[:, :, None], 1)

        if self.decoding_type == 'local':
            # get_update_nodes overwrites update_nodes entirely.
            self.get_update_nodes(selected)

        if self.decoding_type != 'static':
            # Element-wise running max over the embeddings of all selected nodes.
            batch_size = selected.size(0)
            index = self.prev_a[:, None].expand(batch_size, 1, self.embeddings.size(-1))
            last_action_embedding = self.embeddings.gather(1, index)
            self.sol_rep = torch.max(self.sol_rep, last_action_embedding)

        self.step = self.step + 1

    def get_update_nodes(self, selected):
        """Set ``update_nodes`` to the neighbours of every node in ``selected``.

        ``adj`` is indexed by the batch-flattened node id
        ``graph_idx * n_nodes + node``.
        """
        n_nodes = self.visited.size(-1)
        # One .tolist() instead of indexing the list with 0-d tensors per element.
        self.update_nodes = [
            neighbour
            for graph_idx, node in enumerate(selected.tolist())
            for neighbour in self.adj[graph_idx * n_nodes + node]
        ]

    def is_done(self):
        """True once as many steps have been taken as there are node slots."""
        return self.step.item() >= self.visited.size(-1)

    @staticmethod
    def _prefix_cover_size(pi, edge_index, num_nodes):
        """Return the smallest k such that ``pi[:k]`` covers every edge in ``edge_index``.

        Coverage is monotone in k (a longer prefix is a superset), so k is
        found by binary search over [0, num_nodes]. ``num_nodes`` is returned
        when no shorter prefix is a cover (the full order is assumed to be one).
        """
        device = edge_index.device
        # Index tensors must be long and on the same device as the tensor they index.
        pi = pi.to(device=device, dtype=torch.long)
        lo, hi = 0, num_nodes
        while lo < hi:
            # Invariant: every prefix shorter than lo is not a cover;
            # the prefix of length hi is (assumed to be) a cover.
            mid = (lo + hi) // 2
            mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
            mask[pi[:mid]] = True
            covered_edges = mask[edge_index[0]] | mask[edge_index[1]]
            if bool(torch.all(covered_edges)):
                hi = mid
            else:
                lo = mid + 1
        return lo

    def get_cost(self):
        """Return the vertex-cover size found for each graph, or None if not done.

        For each graph, the cost is the length of the shortest prefix of its
        selection order that covers all of its edges.

        Returns:
            A float tensor of shape (num_graphs,) on the CPU once ``is_done()``
            is true; otherwise ``None``.
        """
        if not self.is_done():
            return None

        graphs = self.input.to_data_list()
        order = torch.stack(self.order, 1)  # row g = selection order for graph g
        costs = [
            self._prefix_cover_size(pi, graph.edge_index, graph.num_nodes)
            for graph, pi in zip(graphs, order)
        ]
        return torch.tensor(costs, dtype=torch.float)


def pp_state(state):
    """Print the main fields of a decoding state, for debugging.

    ``edge_sets`` is not declared by ``MVCState``, so it is read with a
    ``None`` fallback instead of raising ``AttributeError`` mid-print.
    """
    print("prev_a: ", state.prev_a)
    print("visited: ", state.visited)
    print("sol_ind: ", state.sol_ind)
    print("update_nodes: ", state.update_nodes)
    print("edge_sets: ", getattr(state, "edge_sets", None))


