from abc import abstractmethod, ABC

import torch

from adl4cv.classification.log.logger import init_logger
from adl4cv.classification.model.graph._graph.edge_attrib import EdgeAttributeType
from adl4cv.parameters.params import DefinitionSet, HyperParameterSet
from adl4cv.utils.utils import SerializableEnum


class GraphType(SerializableEnum):
    """Identifiers of the available graph builders; passed as the type of their DefinitionSets."""
    # Used by DenseGraphBuilderDefinitionSet (all-ones adjacency matrix).
    DENSE = "dense"
    # Used by SSLDenseGraphBuilderDefinitionSet (all-ones adjacency matrix with a zero diagonal).
    SSLDENSE = "ssldense"
    # Used by DistinctGraphBuilderDefinitionSet (identity adjacency matrix).
    DISTINCT = "distinct"


class GraphBuilderHyperParameterSet(HyperParameterSet):
    """HyperParameterSet of a GraphBuilder"""

    def __init__(self,
                 edge_attrib_def: DefinitionSet = None,
                 **kwargs):
        """
        Creates new HyperParameterSet
        :param edge_attrib_def: Definition of the edge-attribute calculator the builder
            instantiates to compute edge attributes (may be None)
        :param kwargs: Forwarded to :func:`~HyperParameterSet.__init__`
        """
        super().__init__(**kwargs)
        self.edge_attrib_def = edge_attrib_def


class GraphBuilder(ABC):
    """Abstract base for builders that turn node features into a graph.

    Subclasses define the connectivity by implementing :meth:`_get_A`. Optional
    edge attributes come from an edge-attribute calculator that is lazily
    instantiated from ``params.edge_attrib_def``.
    """

    def __init__(self, hyperparams: "GraphBuilderHyperParameterSet" = None):
        """
        Creates a new graph builder
        :param hyperparams: Builder hyperparameters. When omitted, a fresh default
            GraphBuilderHyperParameterSet is created per instance, so builders never
            share (and mutate) a single default object.
        """
        self.params = hyperparams if hyperparams is not None else GraphBuilderHyperParameterSet()
        self._edge_attrib = None
        self.module_logger = init_logger(self.__class__.__name__)
        # Updated by get_graph() so that _get_A can allocate on the input's device.
        self.device = None

    @property
    def edge_attrib(self):
        """The edge-attribute calculator, or None if none is configured.

        Instantiated from ``params.edge_attrib_def`` on first access. Nothing is
        instantiated when that definition is None or its type is
        ``EdgeAttributeType.NO_EDGE_ATTRIB``.
        """
        if self._edge_attrib is None:
            edge_attrib_def = self.params.edge_attrib_def
            # The hyperparameter sets default edge_attrib_def to None; guard before reading .type.
            if edge_attrib_def is not None and edge_attrib_def.type is not EdgeAttributeType.NO_EDGE_ATTRIB:
                self.edge_attrib = edge_attrib_def.instantiate()
        return self._edge_attrib

    @edge_attrib.setter
    def edge_attrib(self, edge_attrib):
        """Replace the edge-attribute calculator and log the change."""
        self._edge_attrib = edge_attrib
        self.module_logger.info(f"Edge attribute calculator set to {self._edge_attrib}!")

    @abstractmethod
    def _get_A(self, W, num_nodes):
        """Return the adjacency matrix for ``num_nodes`` nodes.

        Every nonzero entry becomes an edge in :meth:`get_graph`. ``W`` is the
        edge-attribute tensor (or None) and may be ignored.
        """
        pass

    def _get_W(self, x):
        """Compute edge attributes for ``x``, or return None if no calculator is configured."""
        edge_attrib = self.edge_attrib  # read the (lazy) property only once
        return edge_attrib.get_attribs(x) if edge_attrib is not None else None

    def get_graph(self, x):
        """Build the graph for the node features ``x``.

        :param x: Node feature tensor; ``x.size(0)`` is taken as the number of nodes.
        :return: Tuple ``(edge_attribs, edge_index)``. ``edge_index`` holds the indices of
            the nonzero entries of the adjacency matrix, one column per edge (shape
            ``(2, E)`` assuming ``_get_A`` returns a 2-D matrix). ``edge_attribs`` holds the
            matching entries of ``W`` (shape ``(E, 1)`` when ``W`` is 2-D), or None when no
            edge-attribute calculator is configured.
        """
        self.device = x.device

        W = self._get_W(x)
        A = self._get_A(W, x.size(0))

        edge_index = torch.nonzero(A)
        edge_attribs = W[edge_index[:, 0], edge_index[:, 1]] if W is not None else None

        # NOTE(review): for 1-D gathered attributes this yields an (E, 1) column; behavior
        # for higher-dimensional W is unverified (.T reverses all dims).
        edge_attribs = torch.unsqueeze(edge_attribs, 0).T if edge_attribs is not None else None
        edge_index = edge_index.T

        return edge_attribs, edge_index

    def __str__(self):
        return f"{self.__class__.__name__} - {self.edge_attrib.__class__.__name__}"


class DenseGraphGraphBuilderHyperParameterSet(GraphBuilderHyperParameterSet):
    """HyperParameterSet of the DenseGraphBuilder"""

    def __init__(self,
                 edge_attrib_def: DefinitionSet = None,
                 **kwargs):
        """
        Creates new HyperParameterSet
        :param edge_attrib_def: Definition of the edge-attribute calculator (may be None)
        :param kwargs: Forwarded to :func:`~GraphBuilderHyperParameterSet.__init__`
        """
        # Pass by keyword so a reordering of the parent's parameters cannot silently misbind it.
        super().__init__(edge_attrib_def=edge_attrib_def, **kwargs)


class DenseGraphBuilderDefinitionSet(DefinitionSet):
    """Definition of the DenseGraphBuilder (type ``GraphType.DENSE``)"""

    def __init__(self, hyperparams: DenseGraphGraphBuilderHyperParameterSet = None):
        """
        Creates new DefinitionSet
        :param hyperparams: Hyperparameters of the builder. When omitted, a fresh
            DenseGraphGraphBuilderHyperParameterSet is created per definition instead of
            sharing one default object evaluated at import time.
        """
        if hyperparams is None:
            hyperparams = DenseGraphGraphBuilderHyperParameterSet()
        super().__init__(GraphType.DENSE, hyperparams)

    def instantiate(self, *args, **kwargs):
        """Create a DenseGraphBuilder from the stored hyperparameters; extra arguments are ignored."""
        return DenseGraphBuilder(self.hyperparams)


class DenseGraphBuilder(GraphBuilder):
    """Graph builder that connects every node to every node, including itself."""

    def _get_A(self, W, num_nodes):
        # The topology depends only on the node count; W is not consulted.
        return torch.ones(num_nodes, num_nodes, device=self.device)


class SSLDenseGraphGraphBuilderHyperParameterSet(DenseGraphGraphBuilderHyperParameterSet):
    """HyperParameterSet of the SSLDenseGraphBuilder; adds nothing to the dense builder's set."""


class SSLDenseGraphBuilderDefinitionSet(DefinitionSet):
    """Definition of the SSLDenseGraphBuilder (type ``GraphType.SSLDENSE``)"""

    def __init__(self, hyperparams: SSLDenseGraphGraphBuilderHyperParameterSet = None):
        """
        Creates new DefinitionSet
        :param hyperparams: Hyperparameters of the builder. When omitted, a fresh
            SSLDenseGraphGraphBuilderHyperParameterSet is created per definition instead of
            sharing one default object evaluated at import time.
        """
        if hyperparams is None:
            hyperparams = SSLDenseGraphGraphBuilderHyperParameterSet()
        super().__init__(GraphType.SSLDENSE, hyperparams)

    def instantiate(self, *args, **kwargs):
        """Create an SSLDenseGraphBuilder from the stored hyperparameters; extra arguments are ignored."""
        return SSLDenseGraphBuilder(self.hyperparams)


class SSLDenseGraphBuilder(GraphBuilder):
    """Graph builder that connects every pair of distinct nodes, with no self-loops."""

    def _get_A(self, W, num_nodes):
        # All-ones matrix minus the identity: a zero diagonal removes the self-loops.
        full = torch.ones((num_nodes, num_nodes), device=self.device)
        diagonal = torch.eye(num_nodes, num_nodes, device=self.device)
        return full - diagonal


class DistinctGraphBuilderDefinitionSet(DefinitionSet):
    """Definition of the DistinctGraphBuilder (type ``GraphType.DISTINCT``)"""

    def __init__(self, hyperparams: GraphBuilderHyperParameterSet = None):
        """
        Creates new DefinitionSet
        :param hyperparams: Hyperparameters of the builder. When omitted, a fresh
            GraphBuilderHyperParameterSet is created per definition instead of sharing one
            default object evaluated at import time.
        """
        if hyperparams is None:
            hyperparams = GraphBuilderHyperParameterSet()
        super().__init__(GraphType.DISTINCT, hyperparams)

    def instantiate(self, *args, **kwargs):
        """Create a DistinctGraphBuilder from the stored hyperparameters; extra arguments are ignored."""
        return DistinctGraphBuilder(self.hyperparams)


class DistinctGraphBuilder(GraphBuilder):
    """Graph builder whose only edges are self-loops; nodes are never linked to each other."""

    def _get_A(self, W, num_nodes):
        # Identity matrix: only the diagonal (node -> itself) is nonzero.
        return torch.eye(num_nodes, device=self.device)
