"""Topographic similarity class"""

import itertools
from typing import Callable, List, Union, Tuple  # pylint: disable=import-only-modules

import editdistance
import gin
from scipy import stats
from scipy.spatial import distance

from ncc.compositionality_metrics import base


@gin.configurable
class TopographicSimilarity(base.Metric):
    """Topographic similarity of a protocol.

    The score is the Spearman rank correlation between two lists of pairwise
    distances taken over the same unordered pairs of protocol entries: one
    list measured between messages, the other between the corresponding
    inputs (derivations).

    Args:
        input_metric: Callable taking two derivations and returning the
            distance between them. Defaults to ``scipy.spatial.distance.hamming``.
        messages_metric: Callable taking two messages and returning the
            distance between them. Defaults to ``editdistance.eval``.
    """

    name = 'topographic similarity'

    def __init__(
            self,
            input_metric: Callable = distance.hamming,
            messages_metric: Callable = editdistance.eval
    ):
        self.input_metric = input_metric
        self.messages_metric = messages_metric

    def measure(self, protocol: base.Protocol) -> float:
        """Compute the topographic similarity of ``protocol``.

        Args:
            protocol: Iterated twice; each item is unpacked as a
                ``(derivation, message)`` pair.
                NOTE(review): assumes iterating ``base.Protocol`` yields such
                pairs -- confirm against its definition.

        Returns:
            The ``correlation`` field of ``scipy.stats.spearmanr`` applied to
            the message distances and the input distances.
        """
        # Both lists are built in protocol order so that the k-th message
        # distance and the k-th input distance refer to the same pair.
        messages = [message for _, message in protocol]
        derivations = [derivation for derivation, _ in protocol]

        distance_messages = self._compute_distances(
            sequence=messages,
            metric=self.messages_metric)
        distance_inputs = self._compute_distances(
            sequence=derivations,
            metric=self.input_metric)
        # NOTE(review): with very small protocols there are too few pairs for
        # a meaningful rank correlation; the result then depends on how
        # scipy handles tiny samples -- confirm callers never hit this.
        return stats.spearmanr(
            distance_messages,
            distance_inputs
        ).correlation

    def _compute_distances(
            self,
            sequence: List[Union[str, Tuple[str, str]]],
            metric: Callable
    ) -> List[float]:
        """Return ``metric`` evaluated on every unordered pair of ``sequence``.

        Pairs are visited as (0, 1), (0, 2), ..., (1, 2), ... by position, so
        two calls on equally long sequences produce index-aligned results.

        Args:
            sequence: Elements to compare pairwise.
            metric: Callable invoked as ``metric(earlier, later)``.

        Returns:
            One distance per pair; empty when ``sequence`` has fewer than two
            elements.
        """
        # combinations() yields pairs in the same order as the nested
        # "for i / for j > i" loop, without copying a list slice per row.
        return [
            metric(element_1, element_2)
            for element_1, element_2 in itertools.combinations(sequence, 2)
        ]
