"""Transition prunning.

A mode transition occurs when the label changes between consecutive
transitions. For example, in this sequence we transition from mode 1 to mode 3
once and from mode 3 to mode 2 once:

| Transition | Mode |
| ---------- | ---- |
| $t_1$      | 1    |
| $t_2$      | 1    |
| $t_3$      | 3    |
| $t_4$      | 3    |
| $t_5$      | 2    |
| $t_6$      | 2    |

In each transition, the mode label is the index of the neural network that
should be used to predict the evolution of the system. These labels are
assigned by identifying the neural network in the ensemble with the minimum
error for that transition.

As an example of what pruning a transition means, pruning the transition to
mode 3 results in either of the following two sequences:

Option 1:

| Transition | Mode |
| ---------- | ---- |
| $t_1$      | 1    |
| $t_2$      | 1    |
| $t_3$      | 1    |
| $t_4$      | 1    |
| $t_5$      | 2    |
| $t_6$      | 2    |

Option 2:

| Transition | Mode |
| ---------- | ---- |
| $t_1$      | 1    |
| $t_2$      | 1    |
| $t_3$      | 2    |
| $t_4$      | 2    |
| $t_5$      | 2    |
| $t_6$      | 2    |

We are interested in pruning transitions that offer minimal benefit to
prediction accuracy. This is a way to navigate the complexity-accuracy
tradeoff: more complex state machines can be more accurate, but are more prone
to overfitting and less interpretable.
"""
from swmpo.model import get_raw_error
from swmpo.transition import Transition
from swmpo.transition import equals as is_transition_equal
from swmpo.partition import StatePartitionItem


# A mode label: the index of the partition subset (and of its local model)
# assigned to a transition.
Label = int


def is_forward_prunnable(labels: tuple[Label, ...], i: int) -> bool:
    """Tell whether position `i` can be merged into the run that precedes it.

    This holds when `i` is not the first position and the label right before
    `i` differs from the label at `i`, i.e. a mode change happens at `i`.
    """
    return i != 0 and labels[i - 1] != labels[i]


def forward_prune(labels: tuple[Label, ...], i: int) -> tuple[Label, ...]:
    """Return a copy of `labels` where the run starting at `i` takes the
    label found right before `i`.

    The run is the maximal stretch of consecutive positions, beginning at
    `i`, that carry the same label as position `i`.

    Note that for `i == 0` the "previous" label is `labels[-1]` (Python's
    negative indexing), so callers are expected to check
    `is_forward_prunnable` first.
    """
    relabeled = list(labels)
    position = i
    # Walk the run that starts at `i`, overwriting it with the previous label
    while position < len(labels) and labels[position] == labels[i]:
        relabeled[position] = labels[i - 1]
        position += 1
    return tuple(relabeled)


def is_backward_prunnable(labels: tuple[Label, ...], i: int) -> bool:
    """Tell whether the run starting at `i` can be merged into the run that
    follows it.

    Two things are required: `i` must be the start of a run (either the first
    position or a position where the label just changed), and some later
    position must carry a different label.
    """
    starts_run = i <= 0 or is_forward_prunnable(labels, i)
    if not starts_run:
        return False
    # There must be a later label change to merge into
    return any(labels[j] != labels[i] for j in range(i, len(labels)))


def get_next_change_index(labels: tuple[Label, ...], i: int) -> int:
    """Return the first index at or after `i` whose label differs from the
    label at `i`.

    Raises `ValueError` if every remaining label equals `labels[i]`.
    """
    change_at = next(
        (k for k in range(i, len(labels)) if labels[k] != labels[i]),
        None,
    )
    if change_at is None:
        raise ValueError("No transition in the future!")
    return change_at


def backward_prune(labels: tuple[Label, ...], i: int) -> tuple[Label, ...]:
    """Return a copy of `labels` where the positions from `i` up to the next
    label change take the label found at that change.

    Raises `ValueError` (from `get_next_change_index`) when there is no label
    change after `i`.
    """
    change_at = get_next_change_index(labels, i)
    incoming_label = labels[change_at]
    merged = list(labels)
    # Absorb everything in [i, change_at) into the run that follows
    for position in range(i, change_at):
        merged[position] = incoming_label
    return tuple(merged)


def is_prunnable(labels: tuple[Label, ...], i: int) -> bool:
    """Tell whether position `i` can be pruned in at least one direction
    (merged into the preceding run or into the following run)."""
    return is_forward_prunnable(labels, i) or is_backward_prunnable(labels, i)


def get_performance_hit(
        model_errors: list[list[float]],  # error[transition_i, model_i]
        l1: tuple[Label, ...],
        l2: tuple[Label, ...],
        ) -> float:
    """Return the largest per-transition change in model error between two
    labelings of the same episode.

    For every transition index `ti`, the error of the model selected by `l1`
    is compared with the error of the model selected by `l2`; the absolute
    difference is taken, so an improvement counts as much as a degradation.
    The maximum of those differences is returned.

    `l1` and `l2` are walked in lockstep; if their lengths differ, the extra
    entries of the longer one are ignored.

    Returns 0.0 when there is nothing to compare (empty labelings) instead of
    failing on `max()` of an empty sequence.
    """
    return max(
        (
            abs(model_errors[ti][l2i] - model_errors[ti][l1i])
            for ti, (l1i, l2i) in enumerate(zip(l1, l2))
        ),
        default=0.0,
    )


def prune(
        model_errors: list[list[float]],  # error[transition_i, model_i]
        l: tuple[Label, ...],
        i: int,
        ) -> tuple[Label, ...]:
    """Prune the label change at position `i` of `l` and return the new
    labeling.

    If only one direction (forward or backward) is possible, that one is
    applied. If both are possible, the direction with the smaller performance
    hit relative to `l` is chosen; ties go to the backward prune.

    Position `i` must be prunable in at least one direction (checked with an
    `assert`).
    """
    can_go_forward = is_forward_prunnable(l, i)
    can_go_backward = is_backward_prunnable(l, i)

    assert can_go_forward or can_go_backward

    if can_go_forward != can_go_backward:
        # Exactly one direction is available, so there is nothing to compare
        only_option = forward_prune if can_go_forward else backward_prune
        return only_option(l, i)

    forward_candidate = forward_prune(l, i)
    backward_candidate = backward_prune(l, i)

    forward_hit = get_performance_hit(model_errors, l, forward_candidate)
    backward_hit = get_performance_hit(model_errors, l, backward_candidate)

    # Forward only wins when it is strictly cheaper
    return forward_candidate if forward_hit < backward_hit else backward_candidate


def _is_epsilon_prunnable(
        model_errors: list[list[float]],  # error[transition_i, model_i]
        l: tuple[Label, ...],
        i: int,
        epsilon: float,
        ) -> bool:
    """Tell whether position `i` can be pruned with a performance hit
    strictly below `epsilon`.

    The hit is measured between `l` and the labeling that `prune` would
    produce for position `i`.
    """
    if is_prunnable(l, i):
        pruned = prune(model_errors, l, i)
        return get_performance_hit(model_errors, l, pruned) < epsilon
    return False


def is_epsilon_prunnable(
        model_errors: list[list[float]],  # error[transition_i, model_i]
        l: tuple[Label, ...],
        epsilon: float,
        ) -> bool:
    """Tell whether any position of `l` can be pruned with a performance hit
    strictly below `epsilon`."""
    return any(
        _is_epsilon_prunnable(model_errors, l, i, epsilon)
        for i in range(len(l))
    )


def greedy_epsilon_prune(
        model_errors: list[list[float]],
        l: tuple[Label, ...],
        epsilon: float,
        ) -> tuple[Label, ...]:
    """Prune the first (lowest-index) position of `l` whose pruning costs
    strictly less than `epsilon` and return the resulting labeling.

    Raises `ValueError` when no position qualifies.
    """
    first_match = next(
        (
            i
            for i in range(len(l))
            if _is_epsilon_prunnable(model_errors, l, i, epsilon)
        ),
        None,
    )
    if first_match is None:
        raise ValueError("No epsilon-prunnable transition found!")
    return prune(model_errors, l, first_match)


def get_label(
        transition: Transition,
        partition: list[StatePartitionItem],
        ) -> int:
    """Return the index of the first subset of the partition that contains
    the transition.

    Membership is decided with `is_transition_equal`. Raises `ValueError`
    when no subset contains the transition.
    """
    for index, item in enumerate(partition):
        if any(
            is_transition_equal(transition, member)
            for member in item.subset
        ):
            return index
    raise ValueError("Transition is not in any subset!")


def get_labels(
        episode: list[Transition],
        partition: list[StatePartitionItem],
        ) -> tuple[Label, ...]:
    """Label every transition of the episode with the index of the partition
    subset that contains it, preserving the episode order."""
    return tuple(
        get_label(transition=transition, partition=partition)
        for transition in episode
    )


# Identifies a transition as (episode index, transition index in that episode).
TransitionID = tuple[int, int]
# One set of transition IDs per partition subset; the list index is the label.
PartitionCode = list[set[TransitionID]]


def _get_greedily_prunned_partition(
        model_errors: list[list[float]],
        all_labels: tuple[tuple[Label, ...], ...],
        partition_code: PartitionCode,
        episode_i: int,
        error_tolerance: float,
        ) -> PartitionCode:
    """Helper function to prune a partition with respect to a single episode.

    Starting from `all_labels[episode_i]`, label changes are pruned greedily
    for as long as an epsilon-prunable position exists and the performance
    hit of the candidate labeling, measured against the episode's original
    labels, does not exceed `error_tolerance`.

    Args:
        model_errors: errors for the episode being pruned, indexed as
            `model_errors[transition_i][model_i]`.
        all_labels: the original labels of every episode.
        partition_code: the current assignment of transition IDs to subsets.
            It is not modified.
        episode_i: index of the episode to prune.
        error_tolerance: the epsilon used for both the per-step and the
            cumulative performance-hit checks.

    Returns:
        A copy of `partition_code` in which the transitions of episode
        `episode_i` are filed under their labels after pruning. Entries that
        belong to other episodes are carried over untouched, so reassignments
        made by earlier calls are preserved.
    """
    epsilon = error_tolerance
    original_labels = tuple(all_labels[episode_i])
    labels = original_labels

    # Greedily prune transitions until there are no more transitions
    # to prune. The round cap is a safety net against a runaway loop.
    max_rounds = 100
    rounds = 0
    while is_epsilon_prunnable(model_errors, labels, epsilon):
        if rounds > max_rounds:
            break
        rounds += 1
        candidate = greedy_epsilon_prune(model_errors, labels, epsilon)
        # Each step is cheap on its own; also bound the drift accumulated
        # with respect to the labels we started from.
        drift = get_performance_hit(model_errors, original_labels, candidate)
        if drift > epsilon:
            break
        labels = candidate

    # Start from the incoming code rather than from `all_labels`: the latter
    # holds the unpruned labels of every episode, so rebuilding from it would
    # throw away the pruning already recorded for other episodes.
    new_partition_code = [set(subset) for subset in partition_code]
    for transition_j, (old_label, new_label) in enumerate(
            zip(original_labels, labels)):
        transition_id = (episode_i, transition_j)
        new_partition_code[old_label].discard(transition_id)
        new_partition_code[new_label].add(transition_id)
    return new_partition_code


def get_greedily_prunned_partition(
        partition: list[StatePartitionItem],
        episodes: list[list[Transition]],
        error_tolerance: float,
        dt: float,
        ) -> list[StatePartitionItem]:
    """Reassign transitions between subsets so that the given episodes switch
    subsets less often.

    Every transition in every episode must already belong to some subset of
    the partition.

    A run of consecutive transitions that share a subset may be moved to a
    neighbouring run's subset when doing so induces at most
    `error_tolerance` prediction error. Episodes are processed one after
    another by `_get_greedily_prunned_partition`, each call receiving the
    partition code returned by the previous one.

    The result has as many subsets as `partition`; new subset `i` keeps the
    local model and hidden sizes of old subset `i` and corresponds roughly to
    it. The new subsets are assembled from the transitions of `episodes`.

    Progress is printed to standard output.
    """
    # Error of every local model on every transition of every episode:
    # all_model_errors[episode][transition][model]
    all_model_errors = [
        [
            [
                get_raw_error(
                    transition=transition,
                    model=item.local_model,
                    dt=dt,
                )
                for item in partition
            ]
            for transition in episode
        ]
        for episode in episodes
    ]

    all_labels = tuple(
        get_labels(episode, partition)
        for episode in episodes
    )

    # DEBUG
    step = 0
    print("Prunning transitions")
    print(f"Iter {step}")
    print("\n".join(
        f"Subset {k}: {len(item.subset)}"
        for k, item in enumerate(partition)
    ))
    # /DEBUG

    # Encode the partition as sets of (episode, transition) index pairs so
    # that membership can be tracked without comparing transitions.
    partition_code = [set() for _ in partition]
    for episode_i, episode_labels in enumerate(all_labels):
        for transition_i, label in enumerate(episode_labels):
            partition_code[label].add((episode_i, transition_i))

    # Prune the partition with information from each episode
    for episode_i, episode_errors in enumerate(all_model_errors):
        partition_code = _get_greedily_prunned_partition(
            model_errors=episode_errors,
            all_labels=all_labels,
            partition_code=partition_code,
            episode_i=episode_i,
            error_tolerance=error_tolerance,
        )

        # DEBUG
        step += 1
        print(f"Iter {step}")
        print("\n".join(
            f"Subset {k}: {len(subset)}"
            for k, subset in enumerate(partition_code)
        ))
        # /DEBUG

    # Reassemble partition: same models, subsets rebuilt from the code
    return [
        StatePartitionItem(
            local_model=item.local_model,
            hidden_sizes=item.hidden_sizes,
            subset=[
                episodes[episode_i][transition_i]
                for episode_i, transition_i in code
            ],
        )
        for item, code in zip(partition, partition_code)
    ]
