"""Tests for compositionality metrics"""

import random
import itertools
import string
import math

import pytest
import numpy as np

from ncc.compositionality_metrics import base
from ncc.compositionality_metrics import utils
from ncc import compositionality_metrics


# Attribute vocabularies for synthetic protocols: five readable names followed
# by 25 generated placeholders, so at most 30 values of each can be requested.
POSSIBLE_COLORS = [
    'blue', 'green', 'gold', 'yellow', 'red',
    *(f'color_{i}' for i in range(25)),
]
POSSIBLE_SHAPES = [
    'square', 'circle', 'ellipse', 'triangle', 'rectangle',
    *(f'shape_{i}' for i in range(25)),
]


def get_trivially_compositional_protocol(
        num_colors: int,
        num_shapes: int,
        *,
        rng=None,
) -> base.Protocol:
    """Generate a base.Protocol that is trivially compositional and should
    obtain maximum compositionality scores.

    Every (color, shape) object is named by a two-character message: one
    character for its color followed by one character for its shape. The
    characters come from a shuffled prefix of ``string.ascii_letters`` and are
    split between the two attributes, so no character names both a color and
    a shape.

    Args:
        num_colors: number of colors, taken from the front of POSSIBLE_COLORS.
        num_shapes: number of shapes, taken from the front of POSSIBLE_SHAPES.
        rng: optional ``random.Random`` used to shuffle the alphabet (pass a
            seeded instance for reproducible protocols). Defaults to the
            global ``random`` module state, as before.

    Returns:
        A list of ``((color, shape), message)`` pairs, one per combination,
        in ``itertools.product`` order (colors outer, shapes inner).

    Raises:
        ValueError: if more colors or shapes are requested than
            POSSIBLE_COLORS / POSSIBLE_SHAPES provide, or if
            ``num_colors + num_shapes`` exceeds the available alphabet.
    """
    # Without these checks the slices and zip() below truncate silently:
    # the protocol ends up smaller than requested, or a later mapping lookup
    # fails with an unhelpful KeyError.
    if not 0 <= num_colors <= len(POSSIBLE_COLORS):
        raise ValueError(
            f'num_colors must be between 0 and {len(POSSIBLE_COLORS)}, '
            f'got {num_colors}')
    if not 0 <= num_shapes <= len(POSSIBLE_SHAPES):
        raise ValueError(
            f'num_shapes must be between 0 and {len(POSSIBLE_SHAPES)}, '
            f'got {num_shapes}')
    num_symbols = num_colors + num_shapes
    if num_symbols > len(string.ascii_letters):
        raise ValueError(
            f'need {num_symbols} distinct symbols but only '
            f'{len(string.ascii_letters)} ascii letters are available')

    colors = POSSIBLE_COLORS[:num_colors]
    shapes = POSSIBLE_SHAPES[:num_shapes]

    alphabet = list(string.ascii_letters[:num_symbols])
    shuffle = random.shuffle if rng is None else rng.shuffle
    shuffle(alphabet)

    # The first num_colors symbols name colors; the rest name shapes.
    color_mapping = dict(zip(colors, alphabet[:num_colors]))
    shape_mapping = dict(zip(shapes, alphabet[num_colors:]))

    return [
        ((color, shape), color_mapping[color] + shape_mapping[shape])
        for color, shape in itertools.product(colors, shapes)
    ]


@pytest.mark.parametrize('num_shapes', [5, 20])
@pytest.mark.parametrize('num_colors', [5, 20])
@pytest.mark.parametrize(
    'metric,expected_score',
    [
        (compositionality_metrics.TopographicSimilarity(), 1),
        (compositionality_metrics.ContextIndependence(), 0.25),
        (compositionality_metrics.PositionalDisentanglement(), 1),
        (compositionality_metrics.BagOfWordsDisentanglement(), 1),
    ])
def test_metric_for_fully_compositional_protocol(
        metric,
        expected_score,
        num_colors,
        num_shapes
):
    """Each metric reports its expected score on a trivially compositional
    protocol, for every combination of grid sizes.

    ContextIndependence is expected to give 0.25 here; the other metrics are
    expected to give 1.
    """
    actual_score = metric.measure(
        get_trivially_compositional_protocol(num_colors, num_shapes)
    )
    np.testing.assert_almost_equal(actual_score, expected_score)


def test_get_protocol():
    """get_protocol pairs each label with its message.

    Features are rendered as 'name=value' strings. Per the expectation below,
    message symbol i is rendered as the i-th lowercase letter ('a' for 0).
    """
    messages = np.asarray([[1, 0], [3, 7]])
    labels = [(1, 0), (0, 2)]
    expected = [
        (('color=1', 'shape=0'), 'ba'),
        (('color=0', 'shape=2'), 'dh'),
    ]

    actual = compositionality_metrics.get_protocol(
        messages=messages,
        labels=labels,
        features_names=['color', 'shape'],
    )

    assert actual == expected


def test_switch_table_get_protocol():
    """get_protocol on a 4x4 'switch table' grid, where the symbol used for
    one attribute depends on the value of the other."""
    # Row-major over (color, shape): 16 two-symbol messages.
    messages = np.asarray([
        [2, 4], [2, 5], [3, 1], [1, 1],
        [5, 4], [5, 5], [3, 6], [1, 6],
        [6, 4], [6, 5], [3, 2], [1, 2],
        [4, 4], [4, 5], [3, 0], [1, 0],
    ])
    # (0, 0), (0, 1), ..., (3, 3): the same order as the messages above.
    labels = list(itertools.product(range(4), range(4)))

    actual = compositionality_metrics.get_protocol(
        messages=messages,
        labels=labels,
        features_names=['color', 'shape'],
    )

    assert actual == [
        (('color=0', 'shape=0'), 'ce'),
        (('color=0', 'shape=1'), 'cf'),
        (('color=0', 'shape=2'), 'db'),
        (('color=0', 'shape=3'), 'bb'),

        (('color=1', 'shape=0'), 'fe'),
        (('color=1', 'shape=1'), 'ff'),
        (('color=1', 'shape=2'), 'dg'),
        (('color=1', 'shape=3'), 'bg'),

        (('color=2', 'shape=0'), 'ge'),
        (('color=2', 'shape=1'), 'gf'),
        (('color=2', 'shape=2'), 'dc'),
        (('color=2', 'shape=3'), 'bc'),

        (('color=3', 'shape=0'), 'ee'),
        (('color=3', 'shape=1'), 'ef'),
        (('color=3', 'shape=2'), 'da'),
        (('color=3', 'shape=3'), 'ba'),
    ]


def test_skip_epoch_wrapper():
    """With skip_epochs=2, SkipEpochMetricWrapper only re-evaluates the wrapped
    metric on every other call.

    The mock's score advances on calls 1, 3 and 5. Calls 2 and 4 return the
    previous value unchanged.
    """

    class MockMetric(base.Metric):
        """Metric whose score grows by 0.1 every time it is evaluated."""

        def __init__(self):
            self.value = 0
            self.name = 'mock'

        def measure(self, protocol: base.Protocol) -> float:
            self.value += 0.1
            return self.value

    protocol = get_trivially_compositional_protocol(5, 5)
    wrapper = compositionality_metrics.SkipEpochMetricWrapper(
        metric=MockMetric(),
        skip_epochs=2,
    )

    for expected_value in (0.1, 0.1, 0.2, 0.2, 0.3):
        np.testing.assert_almost_equal(wrapper.measure(protocol),
                                       expected_value)


def test_disentanglement_handles_constant_protocol():
    """Both disentanglement metrics return NaN when every object in the
    protocol receives the same message."""
    # (color=0, shape=0), (color=0, shape=1), (color=1, shape=0),
    # (color=1, shape=1), all mapped to the single message 'ba'.
    constant_protocol = [
        ((f'color={color}', f'shape={shape}'), 'ba')
        for color, shape in itertools.product(range(2), range(2))
    ]
    metrics = (
        compositionality_metrics.PositionalDisentanglement(),
        compositionality_metrics.BagOfWordsDisentanglement(),
    )
    for metric in metrics:
        assert math.isnan(metric.measure(constant_protocol))



def test_one_message_per_class_wrapper():
    """OneMessagePerClassWrapper hands the wrapped metric one message per
    derivation.

    Derivations with several messages are expected to keep 'ba'
    (color=0/shape=0) and 'ab' (color=0/shape=1). In this fixture, each is
    both the most frequent and the first-seen message for its derivation.
    Derivations with a single message keep it.
    """
    protocol = [
        (('color=0', 'shape=0'), 'ba'),
        (('color=0', 'shape=0'), 'ba'),
        (('color=0', 'shape=0'), 'bb'),
        (('color=0', 'shape=0'), 'bc'),

        (('color=0', 'shape=1'), 'ab'),
        (('color=0', 'shape=1'), 'ab'),
        (('color=0', 'shape=1'), 'ab'),
        (('color=0', 'shape=1'), 'ba'),
        (('color=0', 'shape=1'), 'ba'),

        (('color=1', 'shape=0'), 'aa'),
        (('color=1', 'shape=1'), 'bb'),
    ]
    expected = [
        (('color=0', 'shape=0'), 'ba'),
        (('color=0', 'shape=1'), 'ab'),
        (('color=1', 'shape=0'), 'aa'),
        (('color=1', 'shape=1'), 'bb'),
    ]

    # Protocols the wrapper actually delegated to the inner metric.
    received = []

    class Mock(base.Metric):
        """Records every protocol it is asked to measure."""
        name = 'mock'

        def measure(self, protocol):
            received.append(list(protocol))
            return 0

    ompc = utils.OneMessagePerClassWrapper(Mock())
    ompc.measure(protocol)

    # Check outside the mock: assertions only inside measure() would pass
    # vacuously if the wrapper never called the inner metric.
    assert received, 'wrapped metric was never called'
    for seen in received:
        assert len(seen) == len(expected)
        for entry in expected:
            assert entry in seen
