"""Transition pruning based on the size of transitions."""
from dataclasses import dataclass


@dataclass
class Island:
    """A run of consecutive, equal entries in a list of modes.

    `start` and `end` are indices into that list and are both inclusive.
    """

    mode: int  # value shared by every entry of the run
    start: int  # inclusive
    end: int  # inclusive

    @property
    def size(self) -> int:
        """Number of entries covered by the island."""
        # Both bounds are inclusive, hence the +1: a single-entry island
        # (start == end) has size 1, not 0.
        return self.end - self.start + 1


def get_islands(
    modes: list[int],
) -> list[Island]:
    """Split `modes` into its maximal runs of equal consecutive values.

    Returns one `Island` per run, in order of appearance, with inclusive
    `start`/`end` indices. The returned islands cover every index of `modes`
    exactly once. An empty input yields an empty list.
    """
    islands: list[Island] = []
    if not modes:
        return islands

    island_start = 0
    island_mode = modes[0]
    for i, mode in enumerate(modes):
        if mode != island_mode:
            # The previous run ended on the entry just before this one.
            islands.append(
                Island(mode=island_mode, start=island_start, end=i - 1)
            )
            island_start = i
            island_mode = mode

    # The run that is still open when the loop finishes always reaches the
    # last index; this also covers a trailing single-entry run.
    islands.append(
        Island(mode=island_mode, start=island_start, end=len(modes) - 1)
    )
    return islands


def get_surrounding_modes(
    modes: list[int],
    island: Island,
) -> set[int]:
    """Get the modes that surround the island.

    Returns the set made of the entry of `modes` directly before
    `island.start` and the entry directly after `island.end`, skipping
    whichever side falls outside the list. The result therefore holds zero,
    one or two distinct values.
    """
    surrounding_modes: set[int] = set()
    if island.start > 0:
        surrounding_modes.add(modes[island.start - 1])
    if island.end < len(modes) - 1:
        surrounding_modes.add(modes[island.end + 1])
    return surrounding_modes


def is_prunnable(
    modes: list[int],
    island: Island,
    min_island_size: int,
) -> bool:
    """True if and only if the given island is prunnable.

    An island is prunnable when its size is strictly less than
    `min_island_size` and its existing neighbours in `modes` all carry one
    and the same mode.
    """
    # Islands that are long enough are always kept.
    if island.size >= min_island_size:
        return False

    # Exactly one distinct neighbouring mode is required. Two distinct modes
    # means the island separates different modes; zero means the island spans
    # the whole list, so there is no mode to merge it into.
    return len(get_surrounding_modes(modes, island)) == 1


def get_prunned_island(
    modes: list[int],
    island: Island,
) -> list[int]:
    """Return a copy of `modes` with the island overwritten by its neighbour.

    Every entry from `island.start` to `island.end` (both inclusive) is
    replaced by the single mode returned by `get_surrounding_modes`. The
    input list is left untouched. An assertion fails when the island does not
    have exactly one surrounding mode.
    """
    neighbours = get_surrounding_modes(modes, island)
    assert len(neighbours) == 1, (
        "Only islands surrounded by the same mode are prunnable!"
    )
    fill_mode = neighbours.pop()

    # Work on a copy so the caller's list is not modified.
    pruned = list(modes)
    position = island.start
    while position <= island.end:
        pruned[position] = fill_mode
        position += 1
    return pruned


def prune_short_transitions_step(
    modes: list[int],
    min_island_size: int,
) -> list[int]:
    """Perform at most one pruning pass over `modes`.

    Collects the islands reported by `get_islands` that `is_prunnable`
    accepts for `min_island_size`. When none qualifies, the input list object
    itself is returned. Otherwise the candidate with the smallest `size` (the
    earliest one on ties) is passed to `get_prunned_island` and that result
    is returned. Assumes a non-empty list of modes.
    """
    # Gather every island that may be pruned.
    candidates = []
    for candidate in get_islands(modes):
        if is_prunnable(modes, candidate, min_island_size):
            candidates.append(candidate)

    # Nothing to prune: hand the input back as-is.
    if not candidates:
        return modes

    # Pick the smallest candidate; the strict comparison keeps the first one
    # encountered when several share the minimum size.
    smallest = candidates[0]
    for candidate in candidates[1:]:
        if candidate.size < smallest.size:
            smallest = candidate

    return get_prunned_island(modes, smallest)


def prune_short_transitions(
    modes: list[int],
    min_island_size: int,
) -> list[int]:
    """Apply `prune_short_transitions_step` repeatedly until nothing changes.

    The step is re-applied to its own output until two consecutive results
    compare equal element-wise. As a safety valve the loop also stops once
    its iteration counter exceeds 10000, in which case the latest result is
    returned even though it may still differ from the previous one. Assumes a
    non-empty list of modes.
    """
    previous = modes
    current = prune_short_transitions_step(previous, min_island_size)
    rounds = 0
    # Compare as tuples so only the contents matter, not the container type.
    while tuple(current) != tuple(previous) and rounds <= 10000:
        rounds += 1
        previous = current
        current = prune_short_transitions_step(current, min_island_size)
    return current
