"""State machine synthesizer inspired by the STUN algorithm."""
from collections import defaultdict
from dataclasses import dataclass
from itertools import product
from sklearn.tree import DecisionTreeClassifier
from swmpo.transition import get_vector
from swmpo.transition import Transition
from swmpo.transition_predicates import get_transition_predicates
from swmpo.partition import StatePartitionItem
from swmpo.model import get_input_output_size, get_raw_error, get_relu_mlp
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from joblib import dump
from joblib import load
from pathlib import Path
import random
import json
import tempfile
import shutil
import zipfile


# Names of the entries stored inside a serialized state machine archive
# (written by `serialize_state_machine`, read by `deserialize_state_machine`).
LOCAL_MODELS_ARCHITECTURE = "local_models_structure.json"
LOCAL_MODELS_DIR = "local_models"
TRANSITION_PREDICATES_DIR = "transition_predicates"
TRANSITION_HISTOGRAM_PATH = "transition_histogram.json"
# Threshold that `get_visited_states` compares its consecutive-visit counter
# against before it allows a node change. With 0 a change is never blocked.
NUM_CONSECUTIVE = 0


@dataclass
class StateMachine:
    """State machine with one local model per node and a matrix of
    transition predicates between nodes.

    Nodes are identified by their index into `local_models`.
    """
    # Local model attached to each node (index == node id).
    local_models: list[torch.nn.Module]
    # transition_predicates[i][j]: classifier evaluated to decide whether to
    # move from node i to node j, or None when there is no predicate.
    transition_predicates: list[list[DecisionTreeClassifier | None]]
    # transition_histogram[i][j]: how many times transition (i, j) was taken
    # in the training data. Only for logging purposes.
    transition_histogram: list[list[int]]
    # Architecture of the local models; stored so that the models can be
    # rebuilt before their weights are loaded during deserialization.
    local_models_hidden_sizes: list[int]
    local_models_input_size: int
    local_models_output_size: int


@dataclass
class StateMachineOptimizationResult:
    """Bundle of a state machine and the loss log recorded alongside it."""
    # The resulting state machine.
    state_machine: StateMachine
    # Presumably the partition loss recorded at each optimization step;
    # the producer is not visible in this module -- TODO confirm.
    partition_loss_log: list[float]


def state_machine_model(
        state_machine: StateMachine,
        prev_state: torch.Tensor,
        prev_action: torch.Tensor,
        state: torch.Tensor,
        current_node: int,
        dt: float,
        ) -> tuple[torch.Tensor, int]:
    """Advance the state machine by one step.

    The transition `(prev_state, prev_action, state)` is turned into a
    feature vector, and every predicate leaving `current_node` is evaluated
    on it. A target node is accepted when its predicate's prediction is
    greater than zero. The first accepted node (lowest index) becomes the
    next node; when none is accepted the machine stays in `current_node`.

    Returns a `(next_state, next_node)` tuple. State prediction is not
    implemented yet, so `next_state` is the `state` argument itself.
    `dt` is currently unused.
    """
    features = get_vector(
        Transition(
            source_state=prev_state,
            action=prev_action,
            next_state=state,
        )
    )

    # Collect every node whose incoming predicate fires on this transition.
    accepted_nodes = []
    for candidate in range(len(state_machine.local_models)):
        classifier = state_machine.transition_predicates[current_node][candidate]
        if classifier is None:
            # No predicate for this pair: the transition can never fire.
            continue
        sample = features.detach().reshape(1, -1).numpy()
        if classifier.predict(sample) > 0.0:
            accepted_nodes.append(candidate)

    # Ties are broken in favour of the lowest node index.
    next_node = accepted_nodes[0] if accepted_nodes else current_node

    # TODO: Predict next state
    return (state, next_node)


def get_visited_states(
    state_machine: StateMachine,
    initial_state: int,
    episode: list[Transition],
    dt: float,
) -> list[int]:
    """Replay `episode` through the state machine and return the node trace.

    The result starts with `initial_state` and gains one entry per
    transition, so it has `len(episode) + 1` elements.

    When several transitions are accepted at once, the first accepted one
    is taken (see `state_machine_model`).
    """
    current_node = initial_state
    visited_nodes = [current_node]

    # Steps counted since the machine last changed node.
    consecutive_visits = 0

    for step in episode:
        _, next_node = state_machine_model(
            state_machine=state_machine,
            prev_state=step.source_state,
            prev_action=step.action,
            state=step.next_state,
            current_node=current_node,
            dt=dt,
        )

        # A proposed change is honoured only once the counter has reached
        # NUM_CONSECUTIVE; otherwise the machine stays put and keeps counting.
        may_switch = (
            next_node != current_node
            and consecutive_visits >= NUM_CONSECUTIVE
        )
        if may_switch:
            print(f"Transitioned from state {current_node} to {next_node} after {consecutive_visits} visits.")
            current_node = next_node
            consecutive_visits = 1  # the new node has been visited once
        else:
            consecutive_visits += 1

        visited_nodes.append(current_node)

    return visited_nodes


def get_local_model_errors(
        state_machine: StateMachine,
        episode: list[Transition],
        dt: float,
        ) -> list[list[float]]:
    """Evaluate every local model on every transition of the episode.

    The returned value is a list of `len(episode)` lists, each of size
    `len(state_machine.local_models)`: entry `[t][k]` is the raw error of
    local model `k` on transition `t`.
    """
    return [
        [
            get_raw_error(transition, model, dt)
            for model in state_machine.local_models
        ]
        for transition in episode
    ]


def _write_local_models(state_machine: StateMachine, output_dir: Path) -> None:
    """Save each local model's `state_dict` as `LOCAL_MODELS_DIR/<i>.pt`."""
    local_models_dir = output_dir/LOCAL_MODELS_DIR
    local_models_dir.mkdir()
    for i, model in enumerate(state_machine.local_models):
        torch.save(model.state_dict(), local_models_dir/f"{i}.pt")


def _write_local_models_architecture(
        state_machine: StateMachine,
        output_dir: Path,
        ) -> None:
    """Save the sizes needed to rebuild the local models as JSON."""
    architecture = dict(
        local_models_input_size=state_machine.local_models_input_size,
        local_models_output_size=state_machine.local_models_output_size,
        local_models_hidden_sizes=state_machine.local_models_hidden_sizes,
    )
    with open(output_dir/LOCAL_MODELS_ARCHITECTURE, "wt") as fp:
        json.dump(architecture, fp)


def _write_transition_predicates(
        state_machine: StateMachine,
        output_dir: Path,
        ) -> None:
    """Dump every predicate (`None` entries included) as `<i>-<j>.joblib`."""
    transition_predicates_dir = output_dir/TRANSITION_PREDICATES_DIR
    transition_predicates_dir.mkdir()
    state_indices = range(len(state_machine.local_models))
    for i, j in product(state_indices, state_indices):
        dump(
            state_machine.transition_predicates[i][j],
            transition_predicates_dir/f"{i}-{j}.joblib",
        )
    # NOTE: per-predicate tree diagrams (sklearn.tree.plot_tree) are not
    # generated.


def _write_transition_histogram(
        state_machine: StateMachine,
        output_dir: Path,
        ) -> None:
    """Save the raw transition counts as JSON."""
    with open(output_dir/TRANSITION_HISTOGRAM_PATH, "wt") as fp:
        json.dump(state_machine.transition_histogram, fp)


def _plot_transition_histogram(
        state_machine: StateMachine,
        output_dir: Path,
        ) -> None:
    """Render the transition counts as an annotated image (SVG)."""
    histogram = state_machine.transition_histogram
    fig = Figure()
    _ = FigureCanvas(fig)
    ax = fig.add_subplot()
    ax.imshow(histogram)
    # Write each count on top of its cell; x is the column, y is the row.
    for i, row in enumerate(histogram):
        for j, count in enumerate(row):
            ax.text(
                j,
                i,
                str(count),
                ha="center", va="center", color="w",
            )
    ticks = list(range(len(histogram)))
    tick_labels = [str(t) for t in ticks]
    ax.set_xticks(ticks, labels=tick_labels)
    ax.set_yticks(ticks, labels=tick_labels)
    ax.set_title("Transition histogram")
    fig.savefig(output_dir/"transition_histogram.svg")


def serialize_state_machine(
        state_machine: StateMachine,
        output_zip_path: Path,
        ):
    """Serialize the state machine into the ZIP archive `output_zip_path`.

    The archive contains the local model weights, the local model
    architecture, the transition predicates, the transition histogram and
    an SVG plot of that histogram. It can be read back with
    `deserialize_state_machine`.

    `output_zip_path` must have the `.zip` suffix; an `AssertionError` is
    raised otherwise.
    """
    assert output_zip_path.suffix == ".zip"
    # Everything is staged in a scratch directory and zipped in one go.
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        _write_local_models(state_machine, output_dir)
        _write_local_models_architecture(state_machine, output_dir)
        _write_transition_predicates(state_machine, output_dir)
        _write_transition_histogram(state_machine, output_dir)
        _plot_transition_histogram(state_machine, output_dir)

        # make_archive appends ".zip" to the base name by itself, so the
        # suffix is stripped from the requested path first.
        shutil.make_archive(str(output_zip_path.with_suffix("")), 'zip', output_dir)


def deserialize_state_machine(
        zip_path: Path,
        ) -> StateMachine:
    """Load the state machine in the given ZIP file written by
    `swmpo.state_machine.serialize_state_machine`.

    The archive is unpacked into a temporary directory. The local models
    are rebuilt from the stored architecture before their weights are
    loaded, and are returned in evaluation mode. A missing predicate file
    for a pair `(i, j)` yields `None` in the predicate matrix.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        extracted_dir = Path(tmpdirname)

        # Unpack the whole archive before reading anything from it.
        with zipfile.ZipFile(zip_path, "r") as archive:
            archive.extractall(extracted_dir)

        # Architecture shared by all local models.
        with open(extracted_dir/LOCAL_MODELS_ARCHITECTURE, "rt") as fp:
            architecture = json.load(fp)
        local_models_hidden_sizes = architecture["local_models_hidden_sizes"]
        input_size = architecture["local_models_input_size"]
        output_size = architecture["local_models_output_size"]

        # Weight files are named "<node index>.pt"; order them numerically
        # so that list position matches the node index.
        weight_paths = list((extracted_dir/LOCAL_MODELS_DIR).glob("*.pt"))
        weight_paths.sort(key=lambda path: int(path.stem))
        local_models = []
        for weight_path in weight_paths:
            model = get_relu_mlp(
                input_size=input_size,
                hidden_sizes=local_models_hidden_sizes,
                output_size=output_size,
                seed="",  # Doesn't matter: we will overwrite the weights
            )
            model.load_state_dict(torch.load(weight_path, weights_only=True))
            local_models.append(model.eval())

        # Predicate matrix, one file per (source, target) pair.
        # NOTE(review): joblib's `load` is pickle-based, so only archives
        # from trusted sources should be opened here.
        predicates_dir = extracted_dir/TRANSITION_PREDICATES_DIR
        node_count = len(local_models)
        transition_predicates = []
        for i in range(node_count):
            row = []
            for j in range(node_count):
                predicate_path = predicates_dir/f"{i}-{j}.joblib"
                if predicate_path.exists():
                    row.append(load(predicate_path))
                else:
                    row.append(None)
            transition_predicates.append(row)

        # Transition histogram.
        with open(extracted_dir/TRANSITION_HISTOGRAM_PATH, "rt") as fp:
            transition_histogram = json.load(fp)

    return StateMachine(
        local_models=local_models,
        transition_predicates=transition_predicates,
        transition_histogram=transition_histogram,
        local_models_hidden_sizes=local_models_hidden_sizes,
        local_models_input_size=input_size,
        local_models_output_size=output_size,
    )


def get_partition_induced_state_machine(
    partition: list[StatePartitionItem],
    predicate_hyperparameters: dict,
    seed: str,
) -> StateMachine:
    """Build the state machine induced by a partition of the transitions.

    Each partition item becomes one node: its `local_model` is the node's
    model, and the transition predicates between nodes are obtained from
    `get_transition_predicates` on the items' subsets.

    Raises `AssertionError` if `partition` is empty or if the items do not
    all share the same hidden sizes, and `IndexError` if no subset
    contains any transition.
    """
    # Validate before doing any work: everything below needs at least one
    # item, and a single architecture is stored for all local models.
    assert partition, "Partition is empty!"
    all_hidden_sizes = {tuple(item.hidden_sizes) for item in partition}
    assert len(all_hidden_sizes) == 1, "Local models have different hidden sizes!"

    _random = random.Random(seed)

    # Characterize the transition predicates between the sets of the partition.
    subsets = [item.subset for item in partition]
    transition_predicates = get_transition_predicates(
        partition=subsets,
        predicate_hyperparameters=predicate_hyperparameters,
        seed=str(_random.random()),
    )

    # Only one transition is needed to read off the model input/output
    # sizes, so take the first one lazily instead of flattening all subsets.
    first_transition = next(
        (transition for subset in subsets for transition in subset),
        None,
    )
    if first_transition is None:
        raise IndexError("Partition contains no transitions!")
    input_size, output_size = get_input_output_size(first_transition)

    # Assemble state machine
    return StateMachine(
        local_models=[item.local_model for item in partition],
        transition_predicates=transition_predicates.transition_predicates,
        transition_histogram=transition_predicates.transition_histogram,
        local_models_hidden_sizes=partition[0].hidden_sizes,
        local_models_input_size=input_size,
        local_models_output_size=output_size,
    )


def get_state_machine_errors(
        state_machine: StateMachine,
        episode: list[Transition],
        initial_state: int,
        dt: float,
        ) -> list[float]:
    """Return the state machine's prediction error on each transition.

    The machine starts in node `initial_state` and is stepped once per
    transition. The error of a step is the norm of the difference between
    the predicted next state and the transition's actual next state. The
    result has `len(episode)` entries.
    """
    node = initial_state
    errors = []
    for transition in episode:
        # The call reads the old `node` before the unpacking rebinds it to
        # the node the machine moved to.
        predicted_next_state, node = state_machine_model(
            state_machine=state_machine,
            prev_state=transition.source_state,
            prev_action=transition.action,
            state=transition.next_state,
            current_node=node,
            dt=dt,
        )
        residual = predicted_next_state - transition.next_state
        errors.append(residual.norm().item())
    return errors
