from abc import abstractmethod
import copy
from typing import Dict
import torch
from torch import Tensor
from torch_geometric.data import HeteroData, Batch
from torch_geometric.data.collate import repeat_interleave
from torch_geometric.utils import bipartite_subgraph
from typing_extensions import Self
from torch_geometric.typing import NodeType
from torch_geometric.utils import bipartite_subgraph

class BaseGraph(HeteroData):
    """Abstract heterogeneous graph built from an environment observation.

    Subclasses implement the construction / merging / query hooks below;
    this base class supplies batching (``batch_graph``) and node-induced
    subgraph extraction (``subgraph``) that work for any subclass.

    NOTE(review): ``@abstractmethod`` only blocks instantiation when the
    class's metaclass is ``ABCMeta``. Nothing here sets it, and
    ``HeteroData`` presumably does not use it, so these decorators likely
    document the contract rather than enforce it — confirm before relying
    on it.
    """

    def __init__(self) -> None:
        super().__init__()

    @classmethod
    @abstractmethod
    def from_observation(cls, observation: dict, original=False) -> Self:
        """Build a graph (observation -> graph).

        Args:
            observation: raw observation dict; its schema is defined by the
                implementing subclass.
            original: subclass-defined flag controlling which variant of
                the graph is built.
        """
        pass

    @classmethod
    @abstractmethod
    def combine_graph(cls, original_graph: Self, learned_graph: Self) -> Self:
        """Merge two graphs (original_graph + learned_graph -> combined).

        The combined graph should keep only unsatisfied clauses and
        candidate literals/variables.
        """
        pass

    @classmethod
    @abstractmethod
    def update_graph(cls, original_graph: Self, updated_graph: Self) -> Self:
        """Apply ``updated_graph`` onto ``original_graph``.

        Returns ``original_graph`` with the updates applied.
        """
        pass

    @abstractmethod
    def get_candidate_num(self) -> int:
        """Return the number of unassigned **literals**."""
        pass

    @abstractmethod
    def get_candidate_indices(self) -> torch.Tensor:
        """Return the indices of unassigned **literals**."""
        pass

    @abstractmethod
    def get_candidate_ptr(self) -> torch.Tensor:
        """Return the pointers to the start/end of each graph's **literal** indices."""
        pass

    @abstractmethod
    def get_candidate_batch(self) -> torch.Tensor:
        """Return the batch vector of **literals**."""
        pass

    @abstractmethod
    def get_score(
        self,
        assignment: torch.Tensor,
        hard: bool = False,
        clause_level: bool = False,
    ) -> torch.Tensor:
        """Return the score of ``assignment`` on this graph.

        ``hard`` and ``clause_level`` are subclass-defined flags (presumably
        hard vs. relaxed scoring and per-clause vs. aggregated output —
        confirm against the implementations).
        """
        pass

    @abstractmethod
    def is_trivial(self) -> bool:
        """Return whether this graph is trivial (criterion is subclass-defined)."""
        pass

    @classmethod
    def batch_graph(cls, graph_list: list[Self]) -> Batch:
        """Collate graphs into a ``Batch`` with per-edge-type ``batch``/``ptr``.

        ``Batch.from_data_list`` provides ``batch``/``ptr`` vectors for node
        stores; downstream code needs them for edges as well, so they are
        added here for every edge type of the batch.

        Args:
            graph_list: graphs to collate (must be non-empty).

        Returns:
            The collated batch. For each edge type, ``batch[edge_type].batch``
            maps every edge to the index of the graph it came from, and
            ``batch[edge_type].ptr`` holds the cumulative edge offsets
            (length ``len(graph_list) + 1``).
        """
        batch = Batch.from_data_list(graph_list)
        for edge_type in batch.edge_types:
            edge_store = batch[edge_type]
            # Put the new vectors next to the collated edges so they can be
            # combined without device mismatches; fall back to the default
            # device if this store has no edge_index.
            if 'edge_index' in edge_store:
                device = edge_store.edge_index.device
            else:
                device = None
            # Guard the lookup: indexing a HeteroData with an edge type it
            # does not have would silently add an empty store to the input.
            edge_counts = [
                g[edge_type].num_edges if edge_type in g.edge_types else 0
                for g in graph_list
            ]
            edge_store.batch = repeat_interleave(edge_counts, device=device)
            edge_store.ptr = torch.tensor(
                [0] + edge_counts, dtype=torch.long, device=device
            ).cumsum(dim=0)
        return batch

    def subgraph(self, subset_dict: Dict[str, Tensor]) -> Self:
        """Return the subgraph induced by the given node subsets.

        Args:
            subset_dict: maps a node type to the nodes to keep, given either
                as a boolean mask or as an index tensor. Index tensors are
                deduplicated and sorted. Node types absent from the dict keep
                all of their nodes. The caller's dict is not modified.

        Returns:
            A clone of this graph in which:
            * node-level attributes of each selected node type are sliced,
              and an explicit ``num_nodes`` is set to the kept count;
            * every edge type with at least one selected endpoint type keeps
              only the edges whose endpoints both survive, with
              ``edge_index`` relabelled to the new numbering and edge-level
              attributes sliced by the same mask;
            * all other attributes are copied unchanged.
        """
        data = self.clone()
        subset_dict = copy.copy(subset_dict)  # keep the caller's dict intact

        for node_type, subset in subset_dict.items():
            if subset.dtype == torch.bool:
                num_nodes = int(subset.sum())
            else:
                # Sorted, duplicate-free indices so that sliced node rows line
                # up with the relabelled edge_index (bipartite_subgraph is
                # expected to number kept nodes in ascending original-id
                # order — see the PyG docs).
                subset = torch.unique(subset, sorted=True)
                # Replacing the value of an existing key is safe mid-iteration.
                subset_dict[node_type] = subset
                # Count after deduplication so num_nodes matches the rows kept.
                num_nodes = subset.size(0)

            node_store = self[node_type]
            for key, value in node_store.items():
                if key == 'num_nodes':
                    data[node_type].num_nodes = num_nodes
                elif node_store.is_node_attr(key):
                    data[node_type][key] = value[subset]
                else:
                    data[node_type][key] = value

        for edge_type in self.edge_types:
            src, _, dst = edge_type
            src_subset = subset_dict.get(src)
            dst_subset = subset_dict.get(dst)
            if src_subset is None and dst_subset is None:
                # Neither endpoint type is filtered: the clone already holds
                # this edge store unchanged, so there is nothing to recompute.
                continue

            edge_store = self[edge_type]
            # Take the device from the edges themselves so the "keep all"
            # ranges below live where bipartite_subgraph will use them.
            device = edge_store.edge_index.device
            if src_subset is None:
                src_subset = torch.arange(self[src].num_nodes, device=device)
            if dst_subset is None:
                dst_subset = torch.arange(self[dst].num_nodes, device=device)

            edge_index, _, edge_mask = bipartite_subgraph(
                (src_subset, dst_subset),
                edge_store.edge_index,
                relabel_nodes=True,
                size=(self[src].num_nodes, self[dst].num_nodes),
                return_edge_mask=True,
            )

            for key, value in edge_store.items():
                if key == 'edge_index':
                    data[edge_type].edge_index = edge_index
                elif edge_store.is_edge_attr(key):
                    data[edge_type][key] = value[edge_mask]
                else:
                    data[edge_type][key] = value

        return data