"""Compare the partitioning of trajectories between models."""
import argparse
from pathlib import Path
from collections import defaultdict
import json
from swmpo.state_machine import get_visited_states
from swmpo_experiments.visited_states_plotting import plot_visited_states
from swmpo_experiments.hmm_synthesis import get_hmm_visited_states
from swmpo.state_machine import deserialize_state_machine
from swmpo_experiments.state_machine_synthesis import deserialize_dataset
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from swmpo.sequence_distance import get_error
from swmpo.sequence_distance import get_best_permutation
import hmmlearn.hmm
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
import numpy as np
import pickle


def get_algorithms_y_pred(
    per_episode_visited_states: list[dict[str, list[int]]],
) -> tuple[list[int], dict[str, list[int]]]:
    """Flatten per-episode visited states into per-algorithm label lists.

    Args:
        per_episode_visited_states: one mapping per episode from an
            algorithm id to the states it visited. Every mapping must
            contain the key "ground_truth".

    Returns:
        (y_true, (alg_id -> y_pred)): the ground-truth states of all
        episodes concatenated in order, and, for every other key, the
        concatenation of that algorithm's states in the same episode order.

    Raises:
        KeyError: if an episode has no "ground_truth" entry.
    """
    algorithms_y_pred = defaultdict[str, list[int]](list)
    Y_true = list[int]()
    for episode_visited_states in per_episode_visited_states:
        Y_true.extend(episode_visited_states["ground_truth"])

        for algorithm, visited_states in episode_visited_states.items():
            if algorithm == "ground_truth":
                continue
            algorithms_y_pred[algorithm].extend(visited_states)

    # Hand back a plain dict so that looking up an unknown algorithm raises
    # KeyError instead of silently creating an empty prediction list.
    return Y_true, dict(algorithms_y_pred)


def get_classification_report(
    per_episode_visited_states: list[dict[str, list[int]]],
) -> dict[str, dict]:
    """Build one scikit-learn classification report per algorithm.

    Args:
        per_episode_visited_states: one mapping per episode from an
            algorithm id to the states it visited, including the
            "ground_truth" entry.

    Returns:
        A mapping from algorithm id to the dictionary produced by
        `classification_report(..., output_dict=True)`.
    """
    Y_true, algorithms_y_pred = get_algorithms_y_pred(
        per_episode_visited_states
    )

    reports = dict[str, dict]()
    for algorithm, Y_pred in algorithms_y_pred.items():
        # Use the sorted union of true and predicted modes: a mode that only
        # shows up in the predictions must still get a name (scikit-learn
        # rejects a target_names list whose length differs from the number
        # of labels), and sorting pins each name to its label instead of
        # relying on set iteration order.
        labels = sorted(set(Y_true) | set(Y_pred))
        target_names = [f"mode {mode}" for mode in labels]
        reports[algorithm] = classification_report(
            Y_true,
            Y_pred,
            labels=labels,
            target_names=target_names,
            output_dict=True,
        )
    return reports


def plot_partition_errors(
        errors: dict[str, list[int]],
        output_path: Path,
        ):
    """Save a box plot of the per-model errors to `output_path`.

    Args:
        errors: maps a model label to its list of error values; one box is
            drawn per label, in the mapping's iteration order.
        output_path: destination file handed to `Figure.savefig`.
    """
    # Create matplotlib figure
    fig = Figure()
    _ = FigureCanvas(fig)
    ax = fig.add_subplot()

    # Plot partition errors, one box per model labelled on the x axis
    ax.set_ylabel("Edit distance")
    labels = list(errors)
    error_list = list(errors.values())
    ax.boxplot(error_list, tick_labels=labels)

    # Save figure. The title is added before tight_layout() so the layout
    # pass can reserve room for it instead of letting it overlap the axes.
    fig.suptitle('Visited states error')
    fig.tight_layout()
    fig.savefig(output_path)

def plot_confusion_matrix(
    Y_true: list[int],
    Y_pred: list[int],
    output_path: Path,
):
    """Plot the confusion matrix of `Y_pred` against `Y_true` and save it.

    The matrix is computed with `normalize="true"` over the sorted union of
    the labels found in both sequences, drawn as an image, and annotated
    with the value of every cell.

    Args:
        Y_true: ground-truth labels.
        Y_pred: predicted labels, aligned element-wise with `Y_true`.
        output_path: destination file handed to `Figure.savefig`.
    """
    labels = sorted(set(Y_true) | set(Y_pred))
    matrix = confusion_matrix(
        Y_true,
        Y_pred,
        labels=labels,
        normalize="true",
    )

    # Create matplotlib figure
    fig = Figure()
    _ = FigureCanvas(fig)
    ax = fig.add_subplot()

    # Plot matrix
    ax.imshow(matrix)

    # Show all ticks and label them with the respective list entries
    tick_positions = np.arange(len(labels))
    ax.set_xticks(
        tick_positions,
        labels=labels,
        rotation=45,
        rotation_mode="anchor",
        horizontalalignment="right",
    )
    ax.set_yticks(tick_positions, labels=labels)

    # Annotate every cell. Rows are the y coordinate and columns the x
    # coordinate. Values are rounded to two decimals because the raw float
    # repr (e.g. 0.3333333333333333) does not fit inside a cell.
    for (row, col), value in np.ndenumerate(matrix):
        ax.text(
            col,
            row,
            f"{value:.2f}",
            ha="center",
            va="center",
            color="w"
        )

    # Save figure. The title is added before tight_layout() so the layout
    # pass can reserve room for it instead of letting it overlap the axes.
    fig.suptitle('Confusion matrix')
    fig.tight_layout()
    fig.savefig(output_path)
    print(f"Wrote {output_path}")


def main(
        output_dir: Path,
        episode_dir: Path,
        hmm_pkl: Path,
        state_machine_zip: Path,
        dt: float,
        ):
    """Compare the FSM and HMM partitions of every episode to ground truth.

    For each episode of the dataset, the visited states of both models are
    relabelled with `get_best_permutation` against the ground-truth modes,
    plotted, and scored with `get_error`. Afterwards the per-model errors,
    the per-episode visited states, a classification report and one
    confusion matrix per model are written to `output_dir`.

    Args:
        output_dir: directory that receives all output files; it must
            already exist, since this function does not create it.
        episode_dir: directory handed to `deserialize_dataset`.
        hmm_pkl: pickle file holding the HMM.
        state_machine_zip: file handed to `deserialize_state_machine`.
        dt: forwarded to `get_visited_states`.
    """
    # Deserialize dataset
    dataset = deserialize_dataset(episode_dir)

    # Load state machine
    state_machine = deserialize_state_machine(state_machine_zip)

    # Load HMM
    # NOTE: unpickling can execute arbitrary code embedded in the file, so
    # `hmm_pkl` must come from a trusted source.
    with open(hmm_pkl, "rb") as file:
        hmm: hmmlearn.hmm.GaussianHMM = pickle.load(file)

    errors = defaultdict(list)
    per_episode_visited_states = list[dict[str, list[int]]]()
    for i, episode in enumerate(dataset.episodes):
        # Get ground truth states
        ground_truth_visited_states = dataset.ground_truth_modes[i]

        # Get state machine visited states
        fsm_visited_states = get_visited_states(state_machine, 0, episode, dt)
        fsm_visited_states = get_best_permutation(
            fsm_visited_states,
            ground_truth_visited_states,
            initial_state=0,
        )

        # Get HMM visited states
        hmm_visited_states = get_hmm_visited_states(hmm, episode)
        hmm_visited_states = get_best_permutation(
            hmm_visited_states,
            ground_truth_visited_states,
            initial_state=-1,  # HMM is unconstrained to find best permutation
        )

        # Truncate all three sequences to a common length so they can be
        # compared element-wise in the report and the confusion matrices.
        minlen = min(
            len(fsm_visited_states),
            len(hmm_visited_states),
            len(ground_truth_visited_states),
        )

        visited_states = dict(
            hmm=hmm_visited_states[:minlen],
            fsm=fsm_visited_states[:minlen],
            ground_truth=ground_truth_visited_states[:minlen],
        )
        per_episode_visited_states.append(visited_states)

        # Plot visited states
        visited_states_plot_path = output_dir/f"visited_states_{i}.svg"
        available_indices = {
            *hmm_visited_states,
            *fsm_visited_states,
            *ground_truth_visited_states,
        }
        plot_visited_states(
            visited_states=visited_states,
            available_indices=available_indices,
            output_path=visited_states_plot_path,
        )

        # Get errors (computed on the full, untruncated sequences)
        errors["hmm"].append(
            get_error(hmm_visited_states, ground_truth_visited_states)
        )
        errors["fsm"].append(
            get_error(fsm_visited_states, ground_truth_visited_states)
        )

    # Log errors
    errors_path = output_dir/"visited_states_errors.json"
    with open(errors_path, "wt") as fp:
        json.dump(errors, fp)

    # Plot errors
    errors_plot_path = output_dir/"errors.svg"
    plot_partition_errors(
        errors=errors,
        output_path=errors_plot_path,
    )

    # Log visited states
    episode_visited_states_path = output_dir/"per_episode_visited_states.json"
    with open(episode_visited_states_path, "wt") as fp:
        json.dump(per_episode_visited_states, fp)

    # Log classification report
    reports = get_classification_report(per_episode_visited_states)
    reports_path = output_dir/"classification_report.json"
    with open(reports_path, "wt") as fp:
        json.dump(reports, fp, indent=2,)

    # Plot confusion matrices
    Y_true, algorithms_y_pred = get_algorithms_y_pred(
        per_episode_visited_states
    )

    for alg_id, Y_pred in algorithms_y_pred.items():
        # Spell out the suffix: given a name without one, savefig appends
        # its default format (png unless reconfigured) to the file name, so
        # the path logged by plot_confusion_matrix would not match the file
        # actually written.
        matrix_path = output_dir / f"{alg_id}_confusion_matrix.png"
        plot_confusion_matrix(
            Y_true=Y_true,
            Y_pred=Y_pred,
            output_path=matrix_path,
        )


if __name__ == "__main__":
    # Command-line entry point: parse the arguments, create the output
    # directory and run the comparison.
    parser = argparse.ArgumentParser(
        prog='HMM/FSM benchmarking',
        description='Benchmark structured world model partitions',
    )
    parser.add_argument(
        '--test_trajectory_dir',
        type=Path,
        required=True,
        help=(
            "Directory with transition dataset in the format described in "
            "this module's documentation"
        ),
    )
    parser.add_argument(
        '--state_machine_zip',
        type=Path,
        required=True,
        help='ZIP file with the state machine.'
    )
    parser.add_argument(
        '--hmm_pkl',
        type=Path,
        required=True,
        help='PKL file with the HMM.'
    )
    parser.add_argument(
        '--output_dir',
        type=Path,
        required=True,
        help='Non-existing directory to write output files'
    )
    parser.add_argument(
        '--dt',
        type=float,
        required=True,
        help='Integration constant for the dynamical system.',
    )
    args = parser.parse_args()
    # mkdir() without exist_ok raises if the directory is already there, so
    # the results of an earlier run are never overwritten.
    args.output_dir.mkdir()
    main(
        episode_dir=args.test_trajectory_dir,
        output_dir=args.output_dir,
        hmm_pkl=args.hmm_pkl,
        state_machine_zip=args.state_machine_zip,
        dt=args.dt,
    )
