"""Plot the behavior of a state machine over a dataset of trajectories.

The dataset is expected to have the following structure:

```
dataset_dir/
    0/
        transitions.json
        ground_truth_states.json
        0.zip
        1.zip
        ...
    1/
        ...
    ...
```

Here `transitions.json` is a JSON list of the form `["0.zip", "1.zip", ...]`
naming the episode's serialized transition files, relative to the episode
directory.

It is assumed every serialized tensor is `float32`.
"""
from swmpo.state_machine import deserialize_state_machine
from swmpo.transition import deserialize
from swmpo.state_machine import get_visited_states
from swmpo.plotting import plot_state_machine_errors
from pathlib import Path
import argparse
import json


def main(
        episode_dir: Path,
        output_dir: Path,
        state_machine_zip: Path,
        dt: float,
        ) -> None:
    """Run a serialized state machine over every episode and save results.

    Every `transitions.json` found recursively under `episode_dir` is
    treated as one episode. For each episode, a subdirectory named after
    the episode's directory is created in `output_dir`, and two files are
    written into it:

    - `visited_states.json`: the result of `get_visited_states`, dumped
      as JSON.
    - `state_machine_errors.svg`: the plot produced by
      `plot_state_machine_errors`, which is also given the contents of the
      episode's `ground_truth_states.json`.

    Args:
        episode_dir: Root directory of the dataset to search for episodes.
        output_dir: Existing directory in which the per-episode output
            directories are created.
        state_machine_zip: Path to the serialized state machine.
        dt: Forwarded unchanged to `get_visited_states` and
            `plot_state_machine_errors`.

    Raises:
        FileExistsError: If the output directory of an episode already
            exists (e.g. two episode directories share the same name).
    """
    state_machine = deserialize_state_machine(state_machine_zip)

    # `Path.glob` yields entries in a filesystem-dependent order; sort so
    # that episodes are processed (and logged) reproducibly.
    for episode_json in sorted(episode_dir.glob("**/transitions.json")):
        # Use a dedicated name for the per-episode directory instead of
        # rebinding the `episode_dir` parameter (the dataset root).
        trajectory_dir = episode_json.parent

        # Deserialize episode
        with open(episode_json, "rt", encoding="utf-8") as fp:
            transition_paths = json.load(fp)
        episode = [
            deserialize(trajectory_dir/transition_path)
            for transition_path in transition_paths
        ]

        # Deserialize ground truth states
        ground_truth_states_path = trajectory_dir/"ground_truth_states.json"
        with open(ground_truth_states_path, "rt", encoding="utf-8") as fp:
            ground_truth_states = json.load(fp)

        # Create output dir; deliberately fails if it already exists so
        # that results of a previous episode are never overwritten.
        output_episode_dir = output_dir/trajectory_dir.name
        output_episode_dir.mkdir()

        # Get visited states
        visited_states = get_visited_states(
            state_machine=state_machine,
            initial_state=0,  # we always start in state 0
            episode=episode,
            dt=dt,
        )

        # Write states. Serialize before opening the file so a
        # serialization error does not leave an empty file behind.
        visited_states_path = output_episode_dir/"visited_states.json"
        json_str = json.dumps(visited_states, indent=2)
        with open(visited_states_path, "wt", encoding="utf-8") as fp:
            fp.write(json_str)
        print(f"Wrote {visited_states_path}")

        # Plot state machine error
        error_plot_path = output_episode_dir/"state_machine_errors.svg"
        plot_state_machine_errors(
            state_machine=state_machine,
            episode=episode,
            output_path=error_plot_path,
            initial_state=0,
            ground_truth_visited_states=ground_truth_states,
            dt=dt,
        )
        print(f"Wrote {error_plot_path}")


if __name__ == "__main__":
    # Command-line entry point: parse the arguments, create the output
    # directory and run `main` over the dataset.
    parser = argparse.ArgumentParser(
        # This script evaluates an existing state machine; it does not
        # synthesize one, so the help text says so.
        prog='State machine plotting example',
        description=(
            'Plot the behavior of a state machine over a dataset of '
            'trajectories'
        ),
    )
    parser.add_argument(
        '--test_trajectory_dir',
        type=Path,
        required=True,
        help=(
            "Directory with transition dataset in the format described in "
            "this module's documentation"
        ),
    )
    parser.add_argument(
        '--output_dir',
        type=Path,
        required=True,
        help='Non-existing directory to write output files'
    )
    parser.add_argument(
        '--state_machine_zip',
        type=Path,
        required=True,
        help='Serialized state machine.'
    )
    parser.add_argument(
        '--dt',
        type=float,
        required=True,
        help='Integration constant for the dynamical system.',
    )
    args = parser.parse_args()
    # Fails with FileExistsError if the directory already exists, in line
    # with the "non-existing directory" contract of `--output_dir`.
    args.output_dir.mkdir()
    main(
        episode_dir=args.test_trajectory_dir,
        output_dir=args.output_dir,
        state_machine_zip=args.state_machine_zip,
        dt=args.dt,
    )
