"""Synthesize a state machine from a dataset of transitions.
"""
from swmpo.partition import get_partition_modes
from swmpo.sequence_distance import get_best_permutation
from swmpo.transition_prunning.epsilon_prunning import StatePartitionItem
from swmpo.transition_prunning.epsilon_prunning import get_greedily_prunned_partition
from swmpo.state_machine import get_partition_induced_state_machine
from swmpo.partition import get_partition
from swmpo.state_machine import serialize_state_machine
from swmpo_experiments.loss_plotting import plot_loss_log
from pathlib import Path
import argparse
import random
from swmpo_experiments.plot_representation import plot_learned_representation
from swmpo_experiments.visited_states_plotting import plot_visited_states
from swmpo_experiments.trajectory_dataset import Dataset
from swmpo_experiments.trajectory_dataset import deserialize_dataset
import json


# Keep PyTorch from using multiple threads, so that the script doesn't take
# over the computer's resources. You can remove these lines if not running
# on a lab computer.
import torch
torch.set_num_threads(1)


def plot_partition(
    partition: list[StatePartitionItem],
    dataset: Dataset,
    output_dir: Path,
):
    """Write one SVG per episode comparing partition modes to ground truth.

    For the episode at index `i` in `dataset.episodes`, the modes assigned by
    `partition` are computed with `get_partition_modes`, passed through
    `get_best_permutation` together with `dataset.ground_truth_modes[i]`, and
    plotted next to those ground-truth modes in `output_dir/"<i>.svg"`.
    The path of every written file is printed.

    Parameters
    ----------
    partition:
        Partition used to label the transitions of each episode.
    dataset:
        Source of `episodes` and the index-aligned `ground_truth_modes`.
    output_dir:
        Directory that receives the plots. It is not created here.
    """
    for episode_index, episode in enumerate(dataset.episodes):
        svg_path = output_dir / f"{episode_index}.svg"

        # Mode sequence that the partition assigns to this episode.
        raw_modes = get_partition_modes(
            trajectory=episode,
            partition=partition,
        )

        # NOTE(review): presumably this relabels the partition modes so that
        # they line up with the ground-truth labels -- confirm against
        # `get_best_permutation`.
        reference_modes = dataset.ground_truth_modes[episode_index]
        aligned_modes = get_best_permutation(
            raw_modes,
            reference_modes,
            initial_state=-1,
        )

        # Plot both sequences in the same figure.
        plot_visited_states(
            visited_states={
                "partition": aligned_modes,
                "ground_truth": reference_modes,
            },
            available_indices=set(),
            output_path=svg_path,
        )
        print(f"Wrote {svg_path}")


def main(
    episode_dir: Path,
    output_dir: Path,
    hidden_sizes: list[int],
    learning_rate: float,
    optimization_iter_n: int,
    mode_model_iter_n: int,
    cluster_dimensionality_reduce: int | None,
    clustering_information_content_regularization_scale: float,
    clustering_mutual_information_regularization_scale: float,
    partition_latent_size: int,
    partition_size: int,
    device: str,
    batch_size: int,
    mutual_information_mini_batch_size: int,
    prunning_error_tolerance: float,
    dt: float,
    min_island_size: int,
    predicate_hyperparameters: dict,
    seed: str,
):
    """Optimize a partition, plot it, and synthesize state machines from it.

    The stages run in this order:

    1. Load the dataset from `episode_dir` and optimize a partition with
       `get_partition`.
    2. Write diagnostic plots under `output_dir`: per-episode partition plots
       in `partition/`, the learned representation in
       `mode_world_model_plot/`, and the loss curve in `loss_log.svg`.
    3. Synthesize a state machine from the partition and serialize it to
       `state_machine.zip`.
    4. Only if `prunning_error_tolerance > 0.0`: greedily prune the
       partition, synthesize a second state machine from the pruned
       partition, and serialize it to `state_machine_prunned.zip`.

    `output_dir` must already exist, and the two plot subdirectories must not
    (they are created with a plain `mkdir()`).

    All sub-seeds handed to the helpers are drawn, in a fixed order, from a
    single `random.Random(seed)` generator.
    """
    rng = random.Random(seed)
    dataset = deserialize_dataset(episode_dir)

    # Stage 1: partition optimization (consumes the first sub-seed).
    print("Optimizing partition")
    partition_seed = str(rng.random())
    optimization = get_partition(
        episodes=dataset.episodes,
        hidden_sizes=hidden_sizes,
        learning_rate=learning_rate,
        optimization_iter_n=optimization_iter_n,
        mode_model_iter_n=mode_model_iter_n,
        clustering_dimensionality_reduce=cluster_dimensionality_reduce,
        clustering_information_content_regularization_scale=clustering_information_content_regularization_scale,
        clustering_mutual_information_regularization_scale=clustering_mutual_information_regularization_scale,
        latent_size=partition_latent_size,
        dt=dt,
        size=partition_size,
        min_island_size=min_island_size,
        seed=partition_seed,
        batch_size=batch_size,
        mutual_information_mini_batch_size=mutual_information_mini_batch_size,
        device=device,
        verbose=True,
    )
    print("Done")
    partition = optimization.partition

    # Stage 2a: per-episode partition plots.
    partition_plot_dir = output_dir / "partition"
    partition_plot_dir.mkdir()
    plot_partition(
        partition=partition,
        dataset=dataset,
        output_dir=partition_plot_dir,
    )

    # Stage 2b: learned representation plots.
    mode_world_model_plot_dir = output_dir / "mode_world_model_plot"
    mode_world_model_plot_dir.mkdir()
    plot_learned_representation(
        mode_world_model=optimization.mode_world_model,
        partition=partition,
        ground_truth_modes=dataset.ground_truth_modes,
        episodes=dataset.episodes,
        output_dir=mode_world_model_plot_dir,
        device=device,
    )
    print(f"Wrote {mode_world_model_plot_dir}")

    # Stage 2c: optimization loss curve.
    loss_log_path = output_dir / "loss_log.svg"
    plot_loss_log(
        loss_log=optimization.loss_log,
        output_path=loss_log_path,
        label="Partition train loss",
    )
    print(f"Wrote {loss_log_path}")

    # Stage 3: state machine induced by the unpruned partition
    # (consumes the second sub-seed).
    print("Synthesizing state machine...")
    synthesis_seed = str(rng.random())
    state_machine = get_partition_induced_state_machine(
        partition=partition,
        predicate_hyperparameters=predicate_hyperparameters,
        seed=synthesis_seed,
    )
    print("Done")

    state_machine_output_path = output_dir / "state_machine.zip"
    serialize_state_machine(state_machine, state_machine_output_path)
    print(f"Wrote {state_machine_output_path}")

    # Stage 4 is optional: a non-positive tolerance disables pruning.
    if not (prunning_error_tolerance > 0.0):
        return

    print("Prunning partition...")
    prunned_partition = get_greedily_prunned_partition(
        partition=partition,
        episodes=dataset.episodes,
        dt=dt,
        error_tolerance=prunning_error_tolerance,
    )
    print("Done")

    # State machine induced by the pruned partition
    # (consumes the third sub-seed).
    print("Synthesizing prunned state machine...")
    prunned_synthesis_seed = str(rng.random())
    prunned_state_machine = get_partition_induced_state_machine(
        partition=prunned_partition,
        predicate_hyperparameters=predicate_hyperparameters,
        seed=prunned_synthesis_seed,
    )
    print("Done")

    prunned_state_machine_output_path = output_dir / "state_machine_prunned.zip"
    serialize_state_machine(
        prunned_state_machine,
        prunned_state_machine_output_path,
    )
    print(f"Wrote {prunned_state_machine_output_path}")


if __name__ == "__main__":
    # Command-line entry point: parse the arguments, translate the "0"
    # sentinels into their Python equivalents, load the predicate
    # hyperparameters, create the output directory and run `main`.
    parser = argparse.ArgumentParser(
        prog='State machine synthesis example',
        description='Synthesize a state machine',
    )
    parser.add_argument(
        '--train_trajectory_dir',
        type=Path,
        required=True,
        help=(
            "Directory with transition dataset in the format described in "
            "this module's documentation"
        ),
    )
    parser.add_argument(
        '--output_dir',
        type=Path,
        required=True,
        help='Non-existing directory to write output files'
    )
    parser.add_argument(
        '--hidden_sizes',
        nargs='+',
        type=int,
        required=True,
        help=(
            'Hidden layer sizes for the MLPs in the state machines as a'
            ' space-separated list of integers. Use "0" if the empty list'
            ' is desired.'
        )
    )
    parser.add_argument(
        '--learning_rate',
        type=float,
        required=True,
        help='Learning rate for the state machine optimization process'
    )
    parser.add_argument(
        '--prunning_error_tolerance',
        type=float,
        required=True,
        help='Error margin for a transition to be prunned. If `0.0`, then prunning is skipped.',
    )
    parser.add_argument(
        '--partition_latent_size',
        type=int,
        required=True,
        help=(
            "Size of the mode world model's mode representation, used during partitioning."
        ),
    )
    parser.add_argument(
        '--optimization_iter_n',
        type=int,
        required=True,
        help=(
            'Number of gradient descent iterations for the state machine'
            ' optimization process'
        ),
    )
    parser.add_argument(
        '--mode_model_iter_n',
        type=int,
        required=True,
        help=(
            'Number of gradient descent iterations for the local'
            ' mode-specific model training.'
        ),
    )
    parser.add_argument(
        '--cluster_dimensionality_reduce',
        type=int,
        required=True,
        help=(
            'Dimension of the UMAP-reduced space used for clustering representations. Use `0` for no reduction.'
        ),
    )
    parser.add_argument(
        '--state_n',
        type=int,
        required=True,
        help='Number of states in the state machine',
    )
    parser.add_argument(
        '--min_island_size',
        type=int,
        required=True,
        help='Minimum size for mode "islands" in the initial partitions (smaller islands get prunned).',
    )
    parser.add_argument(
        '--dt',
        type=float,
        required=True,
        help='Integration constant for the dynamical system.',
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        required=True,
        help='SGD batch size',
    )
    parser.add_argument(
        '--mutual_information_mini_batch_size',
        type=int,
        required=True,
        # This help text used to be a copy of the `--batch_size` one, which
        # made the two options indistinguishable in `--help`.
        help=(
            'Mini-batch size for the mutual information regularization'
            ' (separate from --batch_size)'
        ),
    )
    parser.add_argument(
        '--cluster_information_content_regularization_scale',
        type=float,
        required=True,
        help='Scaling for the information content regularization '
             ' term during clustering.',
    )
    parser.add_argument(
        '--cluster_mutual_information_regularization_scale',
        type=float,
        required=True,
        help='Scaling for the mutual information regularization '
             ' term during clustering.',
    )
    parser.add_argument(
        '--cuda_device',
        type=str,
        required=True,
        help='CUDA device for SGD optimization',
    )
    parser.add_argument(
        '--predicate_hyperparameters_json',
        type=Path,
        required=True,
        help='JSON with predicate synthesis hyperparameters'
    )
    parser.add_argument(
        '--seed',
        type=str,
        required=True,
        help='Random number generator seed'
    )
    args = parser.parse_args()

    # `nargs='+'` cannot express an empty list, so a list made only of zeros
    # is the command-line spelling of "no hidden layers".
    hidden_sizes = args.hidden_sizes
    if set(hidden_sizes) == {0}:
        hidden_sizes = []

    # `0` is the command-line spelling of "no dimensionality reduction".
    cluster_dimensionality_reduce = (
        None
        if args.cluster_dimensionality_reduce == 0
        else args.cluster_dimensionality_reduce
    )

    # Read the hyperparameters BEFORE creating the output directory: if the
    # JSON file is missing or malformed, no empty directory is left behind to
    # make the next invocation fail with `FileExistsError`.
    with open(args.predicate_hyperparameters_json, "rt", encoding="utf-8") as fp:
        predicate_hyperparameters = json.load(fp)

    # Deliberately fails if the directory already exists, so that results of
    # a previous run are never overwritten.
    args.output_dir.mkdir()

    main(
        episode_dir=args.train_trajectory_dir,
        output_dir=args.output_dir,
        hidden_sizes=hidden_sizes,
        learning_rate=args.learning_rate,
        optimization_iter_n=args.optimization_iter_n,
        mode_model_iter_n=args.mode_model_iter_n,
        cluster_dimensionality_reduce=cluster_dimensionality_reduce,
        clustering_information_content_regularization_scale=args.cluster_information_content_regularization_scale,
        clustering_mutual_information_regularization_scale=args.cluster_mutual_information_regularization_scale,
        partition_size=args.state_n,
        partition_latent_size=args.partition_latent_size,
        device=args.cuda_device,
        batch_size=args.batch_size,
        mutual_information_mini_batch_size=args.mutual_information_mini_batch_size,
        min_island_size=args.min_island_size,
        seed=args.seed,
        prunning_error_tolerance=args.prunning_error_tolerance,
        dt=args.dt,
        predicate_hyperparameters=predicate_hyperparameters,
    )
