"""Comparing mode sequences against a ground-truth mode sequence."""
import math
from itertools import islice, permutations


def _get_levenshtein(a: list[int], b: list[int]) -> float:
    """Levenshtein distance."""
    # Adapted from Wikipedia
    m = len(a)
    n = len(b)

    # d is an m by n array of zeros
    d = [
        [0 for _ in range(n+1)]
        for _ in range(m+1)
    ]
    for i in range(1, m+1):
        d[i][0] = i

    for j in range(1, n+1):
        d[0][j] = j

    for j in range(1, n+1):
        for i in range(1, m+1):
            if a[i-1] == b[j-1]:
                substitutionCost = 0
            else:
                substitutionCost = 1

            d[i][j] = min(
                d[i-1][j] + 1,                   # deletion
                d[i][j-1] + 1,                   # insertion
                d[i-1][j-1] + substitutionCost)  # substitution

    return d[m][n]


def get_error(visited_states: list[int], ground_truth: list[int]) -> float:
    """Return the edit distance between `visited_states` and `ground_truth`.

    Delegates to `_get_levenshtein`, so the result is the minimum number of
    single-element insertions, deletions and substitutions needed to turn
    one sequence into the other.
    """
    distance = _get_levenshtein(a=visited_states, b=ground_truth)
    return distance


def _get_best_permutation(
    sequence: list[int],
    ground_truth: list[int],
    indices: list[int],
    max_permutations: int = 1001,
) -> dict[int, int]:
    """Find the relabelling of `indices` that best aligns `sequence` with
    `ground_truth`.

    Each permutation of `indices` is read as a mapping ``original -> new``
    and applied to `sequence`. Values of `sequence` that are not in
    `indices` are left unchanged. The mapping whose relabelled sequence has
    the smallest `get_error` against `ground_truth` wins. Ties keep the
    earliest permutation tried.

    For example, with ``indices=[2, 3, 4]`` the result could be
    ``{2: 3, 3: 2, 4: 4}``, which turns ``[0, 2, 3, 2, 4]`` into
    ``[0, 3, 2, 3, 4]``.

    Args:
        sequence: Mode sequence to relabel.
        ground_truth: Reference mode sequence to compare against.
        indices: Mode labels that may be permuted among themselves.
        max_permutations: Upper bound on how many permutations (in
            `itertools.permutations` order) are evaluated. The number of
            permutations grows factorially with ``len(indices)``, so large
            index sets are only partially searched. Defaults to 1001.

    Returns:
        The best mapping found. It has an entry for every element of
        `indices` and, mapped to itself, every element of `sequence` that is
        not in `indices`, so ``perm[x]`` is defined for each ``x`` in
        `sequence`.

    Raises:
        ValueError: If `max_permutations` is less than 1.
    """
    if max_permutations < 1:
        raise ValueError("max_permutations must be at least 1")

    best_error = math.inf
    best_perm: dict[int, int] = {}
    # TODO: large index sets are truncated rather than searched exhaustively.
    for permutation in islice(permutations(indices), max_permutations):
        perm = dict(zip(indices, permutation))
        # Labels outside `indices` (e.g. a fixed initial mode) map to themselves.
        new_sequence = [perm.get(value, value) for value in sequence]
        new_error = get_error(new_sequence, ground_truth)
        if new_error < best_error:
            best_error = new_error
            best_perm = perm
            if best_error == 0:
                # An exact match cannot be beaten. Later permutations could
                # only tie, and ties never replace the current best.
                break

    # permutations() always yields at least one tuple (the empty tuple when
    # `indices` is empty), so best_perm has been assigned by now.
    for value in sequence:
        best_perm.setdefault(value, value)
    return best_perm


def get_best_permutation(
    sequence: list[int],
    ground_truth: list[int],
    initial_state: int,
) -> list[int]:
    """Relabel `sequence` so that it best matches `ground_truth`.

    Returns a sequence ``new`` with ``new[i] = perm[sequence[i]]``, where
    ``perm`` is a permutation of the mode labels that keeps `initial_state`
    fixed. Among the permutations searched by `_get_best_permutation`, it is
    the one that minimises ``get_error(new, ground_truth)``.

    For example, with ``initial_state=1`` the sequence ``[1, 2, 3, 2, 4]``
    could become ``[1, 3, 2, 3, 4]``.

    Args:
        sequence: Mode sequence to relabel.
        ground_truth: Reference mode sequence.
        initial_state: Mode label that must not be relabelled.

    Returns:
        The relabelled sequence, the same length as `sequence`.
    """
    # Candidate labels are every mode seen in either sequence, so a mode in
    # `sequence` may be renamed to one that only `ground_truth` uses. The
    # initial mode is excluded so that it always maps to itself. Sorting
    # makes the permutation order, and therefore truncation and
    # tie-breaking, independent of set iteration order.
    indices = sorted((set(sequence) | set(ground_truth)) - {initial_state})
    perm = _get_best_permutation(
        sequence=sequence,
        ground_truth=ground_truth,
        indices=indices,
    )
    return [perm[value] for value in sequence]
