from abc import ABC, abstractmethod
from typing import Optional

import torch
import torch.nn.functional as F

from adl4cv.parameters.params import DefinitionSet, HyperParameterSet
from adl4cv.utils.distance import CorrelationDistance, DistanceDefinitionSet, CorrelationDistanceDefinitionSet, \
    CosineDistanceDefinitionSet, EuclideanDistanceDefinitionSet
from adl4cv.utils.utils import SerializableEnum


class EdgeAttributeType(SerializableEnum):
    """Serializable identifiers of the supported edge-attribute kinds."""

    # Edge attributes are disabled
    NO_EDGE_ATTRIB = "no_edge_attrib"
    # Correlation-based similarity
    CORRELATION_EDGE_ATTRIB = "correlation_edge_attrib"
    # Cosine-based similarity
    COSINE_EDGE_ATTRIB = "cosine_edge_attrib"
    # Euclidean-distance-based similarity
    EUCLIDEAN_EDGE_ATTRIB = "euclidean_edge_attrib"


class EdgeAttributeHyperParameterSet(HyperParameterSet):
    """HyperParameterSet of the edge-attribute builders (:class:`EdgeAttribute` and subclasses)."""

    def __init__(self,
                 shift_negative_attrib: bool = False,
                 **kwargs):
        """
        Creates new HyperParameterSet
        :param shift_negative_attrib: If True, the similarity matrix is shifted so that its
            minimum becomes zero and its diagonal is zeroed; if False, negative similarities
            are clipped to zero instead.
        :param kwargs: Forwarded to :func:`HyperParameterSet.__init__`
        """
        super().__init__(**kwargs)
        self.shift_negative_attrib = shift_negative_attrib


class EdgeAttribute(ABC):
    """Base class that turns node features into a dense matrix of edge attributes.

    The similarity between nodes is the negated distance computed by ``distance_metric``
    (subclasses may override :meth:`get_similarity`). Negative similarities are then
    either clipped to zero or shifted into the non-negative range, as selected by
    ``hyperparams.shift_negative_attrib``.
    """

    def __init__(self,
                 hyperparams: "Optional[EdgeAttributeHyperParameterSet]" = None,
                 distance_metric_def: "Optional[DistanceDefinitionSet]" = None):
        """
        :param hyperparams: Edge-attribute hyperparameters; a new default
            :class:`EdgeAttributeHyperParameterSet` is created when omitted.
        :param distance_metric_def: Definition of the distance metric; a new default
            :class:`DistanceDefinitionSet` is created when omitted. It is instantiated
            immediately and stored as ``self.distance_metric``.
        """
        # Build defaults per call: an instance in the signature would be created once at
        # import time, and the same (mutable) object would be shared by every EdgeAttribute.
        if hyperparams is None:
            hyperparams = EdgeAttributeHyperParameterSet()
        if distance_metric_def is None:
            distance_metric_def = DistanceDefinitionSet()
        self.params = hyperparams
        self.distance_metric = distance_metric_def.instantiate()

    def _set_negative_attrib(self, W):
        """Clip the negative entries of ``W`` to zero."""
        return F.relu(W)

    def _shift_negative_attrib(self, W):
        """Shift ``W`` so its smallest entry becomes zero, then zero its diagonal.

        Unlike clipping, the constant shift keeps the relative ordering of the entries.
        """
        W = W - torch.min(W)
        # Create the mask in W's dtype so the result keeps W's dtype; a default (float32)
        # eye would silently promote e.g. half-precision or integer inputs.
        off_diagonal = torch.ones_like(W) - torch.eye(*W.size(), dtype=W.dtype, device=W.device)
        return W * off_diagonal

    def get_similarity(self, x):
        """Return the negated pairwise distance matrix of ``x`` with itself."""
        return - self.distance_metric.get_distance_matrix(x, x)

    def get_attribs(self, x):
        """
        Compute the edge-attribute matrix for the node features ``x``.

        :param x: Node features, passed unchanged to :meth:`get_similarity`
        :return: Non-negative similarity matrix, shifted when
            ``params.shift_negative_attrib`` is set and clipped at zero otherwise
        """
        W = self.get_similarity(x)
        if self.params.shift_negative_attrib:
            return self._shift_negative_attrib(W)
        return self._set_negative_attrib(W)


class OldCorrelationEdgeAttributeDefinitionSet(DefinitionSet):
    """Definition of the legacy correlation edge attribute (:class:`OldCorrelationEdgeAttribute`)."""

    def __init__(self, hyperparams: Optional[EdgeAttributeHyperParameterSet] = None):
        """
        :param hyperparams: Edge-attribute hyperparameters; a new default set is created when
            omitted (a default instance in the signature would be shared by all definitions).
        """
        if hyperparams is None:
            hyperparams = EdgeAttributeHyperParameterSet()
        # NOTE(review): shares the CORRELATION tag with CorrelationEdgeAttributeDefinitionSet,
        # presumably so the legacy variant is a drop-in replacement — confirm.
        super().__init__(EdgeAttributeType.CORRELATION_EDGE_ATTRIB, hyperparams)

    def instantiate(self, *args, **kwargs):
        """Build the :class:`OldCorrelationEdgeAttribute` described by this definition."""
        return OldCorrelationEdgeAttribute(self.hyperparams, CorrelationDistanceDefinitionSet())


class OldCorrelationEdgeAttribute(EdgeAttribute):
    """Legacy correlation edge attribute computed directly from the features.

    :meth:`get_similarity` computes the Pearson correlation between rows of ``x``
    inline and does not use ``self.distance_metric``.
    """

    def get_similarity(self, x, eps: float = 1e-12):
        """
        Return the row-wise Pearson correlation matrix of ``x``.

        :param x: 2-D feature tensor; rows are compared over dim 1
        :param eps: Lower bound for each centered row's norm. A row with zero variance then
            gets zero correlation with every row (including itself) instead of NaN from 0/0.
        :return: Square matrix whose entry (i, j) is the correlation of rows i and j
        """
        centered = x - x.mean(dim=1, keepdim=True)
        norms = centered.norm(dim=1, keepdim=True).clamp(min=eps)
        normalized = centered / norms
        return torch.mm(normalized, normalized.t())


class CorrelationEdgeAttributeDefinitionSet(DefinitionSet):
    """Definition of the correlation edge attribute (:class:`CorrelationEdgeAttribute`)."""

    def __init__(self, hyperparams: Optional[EdgeAttributeHyperParameterSet] = None):
        """
        :param hyperparams: Edge-attribute hyperparameters; a new default set is created when
            omitted (a default instance in the signature would be shared by all definitions).
        """
        if hyperparams is None:
            hyperparams = EdgeAttributeHyperParameterSet()
        super().__init__(EdgeAttributeType.CORRELATION_EDGE_ATTRIB, hyperparams)

    def instantiate(self, *args, **kwargs):
        """Build a :class:`CorrelationEdgeAttribute` backed by the correlation distance."""
        return CorrelationEdgeAttribute(self.hyperparams, CorrelationDistanceDefinitionSet())


class CorrelationEdgeAttribute(EdgeAttribute):
    """Edge attribute whose similarity is the negated distance of the configured metric.

    All behaviour is inherited from :class:`EdgeAttribute`;
    :class:`CorrelationEdgeAttributeDefinitionSet` supplies the correlation distance.
    NOTE(review): if that distance is never negative (e.g. ``1 - corr``), the default
    clipping in ``get_attribs`` turns every attribute into 0 — confirm intended settings.
    """


class CosineEdgeAttributeDefinitionSet(DefinitionSet):
    """Definition of the cosine edge attribute (:class:`CosineEdgeAttribute`)."""

    def __init__(self, hyperparams: Optional[EdgeAttributeHyperParameterSet] = None):
        """
        :param hyperparams: Edge-attribute hyperparameters; a new default set is created when
            omitted (a default instance in the signature would be shared by all definitions).
        """
        if hyperparams is None:
            hyperparams = EdgeAttributeHyperParameterSet()
        # Tag with the cosine type; it was previously mislabelled as CORRELATION_EDGE_ATTRIB.
        super().__init__(EdgeAttributeType.COSINE_EDGE_ATTRIB, hyperparams)

    def instantiate(self, *args, **kwargs):
        """Build a :class:`CosineEdgeAttribute` backed by the cosine distance."""
        return CosineEdgeAttribute(self.hyperparams, CosineDistanceDefinitionSet())


class CosineEdgeAttribute(EdgeAttribute):
    """Edge attribute whose similarity is the negated distance of the configured metric.

    All behaviour is inherited from :class:`EdgeAttribute`;
    :class:`CosineEdgeAttributeDefinitionSet` supplies the cosine distance.
    """


class DistanceEdgeAttributeHyperParameterSet(EdgeAttributeHyperParameterSet):
    """Edge-attribute hyperparameters for the distance-based edge attribute."""

    def __init__(self,
                 shift_negative_attrib: bool = False,
                 norm: bool = True,
                 **kwargs):
        """
        :param shift_negative_attrib: Shift (True) or clip (False) negative similarities
        :param norm: Whether the distance matrix is divided by its maximum before negation
        :param kwargs: Forwarded to the parent hyperparameter set
        """
        super().__init__(shift_negative_attrib=shift_negative_attrib, **kwargs)
        self.norm = norm


class DistanceEdgeAttributeDefinitionSet(DefinitionSet):
    """Definition of the Euclidean-distance edge attribute (:class:`DistanceEdgeAttribute`)."""

    def __init__(self, hyperparams: Optional[DistanceEdgeAttributeHyperParameterSet] = None):
        """
        :param hyperparams: Distance edge-attribute hyperparameters; a new default set is
            created when omitted (a default instance in the signature would be shared by all
            definitions).
        """
        if hyperparams is None:
            hyperparams = DistanceEdgeAttributeHyperParameterSet()
        super().__init__(EdgeAttributeType.EUCLIDEAN_EDGE_ATTRIB, hyperparams)

    def instantiate(self, *args, **kwargs):
        """Build a :class:`DistanceEdgeAttribute` backed by the Euclidean distance."""
        return DistanceEdgeAttribute(self.hyperparams, EuclideanDistanceDefinitionSet())


class DistanceEdgeAttribute(EdgeAttribute):
    """Edge attribute whose similarity is the negated, optionally max-normalized, distance."""

    def get_similarity(self, x):
        """
        Return ``-D``, where ``D`` is the pairwise distance matrix of ``x`` with itself.

        When ``self.params.norm`` is set, ``D`` is first divided by its maximum entry. If every
        distance is zero (e.g. identical features), ``D`` is left unscaled instead of becoming
        NaN through 0/0. An empty ``D`` is returned negated without normalization.

        NOTE(review): with ``shift_negative_attrib=False``, ``get_attribs`` clips ``-D`` at zero,
        which yields all-zero attributes if the metric's distances are non-negative — confirm.
        """
        dist = self.distance_metric.get_distance_matrix(x, x)
        if self.params.norm and dist.numel() > 0:
            max_dist = dist.max()
            # Divide by 1 instead of 0 when all distances vanish; torch.where keeps this
            # on-device rather than forcing a host sync via a Python comparison.
            dist = dist / torch.where(max_dist > 0, max_dist, torch.ones_like(max_dist))
        return -dist


class NoEdgeAttributeDefinitionSet(DefinitionSet):
    """Definition used when no edge attributes are to be computed."""

    def __init__(self):
        """Tag the definition as NO_EDGE_ATTRIB; it carries no hyperparameters."""
        no_hyperparams = None
        super().__init__(EdgeAttributeType.NO_EDGE_ATTRIB, no_hyperparams)

    def instantiate(self, *args, **kwargs):
        """Return ``None``: there is no edge-attribute object to build."""
        return None
