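"""Synthesize a Gaussian hidden Markov model (HMM) from a transition dataset.

The HMM is fit on all episodes of the training dataset and pickled to
`<output_dir>/hmm.pkl`. It is then used to decode the hidden-state sequence of
every trajectory directory found under the training directory (any directory
containing a `transitions.json` file). The decoded states are written to
`visited_states.json` and plotted against the trajectory's
`ground_truth_states.json` in `visited_states.svg`.
"""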
import os
# Limit OpenBLAS to a single thread. NOTE(review): this is deliberately placed
# above the third-party imports below, presumably because OpenBLAS reads the
# variable when it is first loaded — keep it above them.
os.environ['OPENBLAS_NUM_THREADS'] = '1'
from swmpo_experiments.visited_states_plotting import plot_visited_states
from swmpo_experiments.state_machine_synthesis import deserialize_dataset
from swmpo.transition import Transition
from swmpo.transition import deserialize
from swmpo.sequence_distance import get_best_permutation
from pathlib import Path
import argparse
import concurrent.futures
import json
import random
import pickle
import hmmlearn.hmm
import torch

# Restrict PyTorch to a single CPU thread (torch.set_num_threads sets the
# number of threads used for intra-op parallelism) so that the script does not
# take over the computer's resources. These lines can be removed when not
# running on a shared lab computer.
torch.set_num_threads(1)


def get_transition_vector(t: Transition) -> torch.Tensor:
    """Flatten a transition into a single feature vector.

    The result is `t.source_state`, `t.action` and `t.next_state`
    concatenated, in that order, along dimension 0.
    """
    parts = (t.source_state, t.action, t.next_state)
    return torch.cat(parts)


def get_hmm_visited_states(
        hmm: hmmlearn.hmm.GaussianHMM,
        episode: list[Transition],
        ) -> list[int]:
    """Decode an episode into one hidden-state index per transition.

    Each transition is encoded with `get_transition_vector`. The whole
    episode is passed to `hmm.predict` as a single sequence. The predicted
    labels are returned as plain Python ints, so they are JSON-serializable.
    """
    features = [
        get_transition_vector(transition).tolist()
        for transition in episode
    ]
    return list(map(int, hmm.predict(features)))


def plot_hmm_episode(
        trajectory_dir: Path,
        hmm_path: Path,
        output_dir: Path,
        ):
    """Decode one trajectory with a pickled HMM, then save and plot the states.

    Args:
        trajectory_dir: Directory containing `transitions.json` (a JSON list
            of transition file paths, relative to this directory) and
            `ground_truth_states.json` (presumably a JSON list of ground-truth
            state indices).
        hmm_path: Path to a pickled `hmmlearn.hmm.GaussianHMM`.
        output_dir: Existing directory that receives `visited_states.json`
            (the raw HMM state indices) and `visited_states.svg` (the HMM
            states after `get_best_permutation`, plotted against the ground
            truth).
    """
    # Rebuild the episode from the per-transition files listed in the index.
    transition_paths = json.loads(
        (trajectory_dir/"transitions.json").read_text()
    )
    episode = [
        deserialize(trajectory_dir/relative_path)
        for relative_path in transition_paths
    ]

    # NOTE(review): unpickling can run arbitrary code; hmm_path must be trusted.
    with open(hmm_path, "rb") as model_file:
        hmm: hmmlearn.hmm.GaussianHMM = pickle.load(model_file)

    ground_truth_visited_states = json.loads(
        (trajectory_dir/"ground_truth_states.json").read_text()
    )

    visited_states = get_hmm_visited_states(hmm, episode)

    # Persist the raw (un-permuted) HMM labels.
    visited_states_path = output_dir/"visited_states.json"
    visited_states_path.write_text(json.dumps(visited_states, indent=2))
    print(f"Wrote {visited_states_path}")

    # HMM state indices are arbitrary labels, so map them onto the ground
    # truth's labels with the best-matching permutation before plotting.
    visited_states_plot_path = output_dir/"visited_states.svg"
    relabelled = get_best_permutation(
        visited_states,
        ground_truth_visited_states,
        # The HMM has no designated initial state, so the permutation search
        # is left unconstrained (presumably what -1 means here).
        initial_state=-1,
    )
    plot_visited_states(
        output_path=visited_states_plot_path,
        available_indices=set(relabelled) | set(ground_truth_visited_states),
        visited_states={
            "hmm": relabelled,
            "ground_truth": ground_truth_visited_states,
        },
    )
    print(f"Wrote {visited_states_plot_path}")


def log_hmm(
        hmm: hmmlearn.hmm.GaussianHMM,
        episode_dir: Path,
        output_dir: Path,
        p: concurrent.futures.ProcessPoolExecutor,
        ):
    """Pickle `hmm`, then decode and plot every trajectory under `episode_dir`.

    The model is written to `<output_dir>/hmm.pkl`. Each directory under
    `episode_dir` (searched recursively) that contains a `transitions.json`
    gets its own numbered output directory `<output_dir>/input_episodes/<i>`.
    That directory is filled by `plot_hmm_episode`, submitted to the executor
    `p`. Workers load the model from `hmm.pkl` instead of receiving it as an
    argument.

    Trajectories are numbered in sorted path order, so the same dataset always
    produces the same numbering.

    Raises:
        FileExistsError: if `input_episodes` or one of its numbered
            subdirectories already exists.
        Any exception raised inside a worker is re-raised here when its
        future's result is collected.
    """
    # Serialize model
    hmm_path = output_dir/"hmm.pkl"
    with open(hmm_path, "wb") as file:
        pickle.dump(hmm, file)
    print(f"Serialized HMM to {hmm_path}")

    # Serialize state machine states for each episode. Previously this
    # directory was created but left empty while the per-episode outputs went
    # directly into output_dir; they are now nested inside it.
    episodes_dir = output_dir/"input_episodes"
    episodes_dir.mkdir()

    # Path.glob yields entries in filesystem-dependent order; sort so that
    # episode index i maps to the same trajectory on every run and machine.
    trajectory_jsons = sorted(episode_dir.glob("**/transitions.json"))

    futures = list()
    for i, trajectory_json in enumerate(trajectory_jsons):
        episode_output_dir = episodes_dir/f"{i}"
        episode_output_dir.mkdir()
        future = p.submit(
            plot_hmm_episode,
            trajectory_dir=trajectory_json.parent,
            hmm_path=hmm_path,
            output_dir=episode_output_dir,
        )
        futures.append(future)

    # Block until every episode is done, surfacing worker exceptions.
    for future in futures:
        future.result()


def main(
        episode_dir: Path,
        output_dir: Path,
        worker_n: int,
        component_n: int,
        seed: str,
        ):
    """Fit a Gaussian HMM on the training episodes and log it with `log_hmm`.

    Args:
        episode_dir: Training dataset directory. It is read with
            `deserialize_dataset` for fitting and scanned again by `log_hmm`
            for decoding and plotting.
        output_dir: Existing directory that receives all outputs.
        worker_n: Number of processes in the pool passed to `log_hmm`.
        component_n: Number of hidden states (`n_components`) of the HMM.
        seed: Seed for a private `random.Random`; the HMM's `random_state` is
            derived from it, so equal seeds give equal HMM random states.
    """
    rng = random.Random(seed)

    episodes = deserialize_dataset(episode_dir).episodes

    # A 24-bit integer drawn from the seeded generator.
    hmm_random_state = int.from_bytes(rng.randbytes(3), 'big', signed=False)
    hmm = hmmlearn.hmm.GaussianHMM(
        n_components=component_n,
        random_state=hmm_random_state,
    )

    # hmmlearn takes every sequence stacked row-wise, plus the length of each
    # sequence so it can tell where one episode ends and the next begins.
    features = [
        get_transition_vector(transition).tolist()
        for episode in episodes
        for transition in episode
    ]
    hmm.fit(features, lengths=[len(episode) for episode in episodes])

    with concurrent.futures.ProcessPoolExecutor(worker_n) as pool:
        log_hmm(
            hmm=hmm,
            episode_dir=episode_dir,
            output_dir=output_dir,
            p=pool,
        )


if __name__ == "__main__":
    # Command-line entry point: parse and validate arguments, create the
    # output directory, then run main().
    parser = argparse.ArgumentParser(
        prog='HMM synthesis example',
        description='Synthesize a HMM',
    )
    parser.add_argument(
        '--output_dir',
        type=Path,
        required=True,
        help='Non-existing directory to write output files'
    )
    parser.add_argument(
        '--component_n',
        type=int,
        required=True,
        help='Number of states in the state machine',
    )
    parser.add_argument(
        '--worker_n',
        type=int,
        required=True,
        help='Number of worker processes used to decode and plot episodes',
    )
    parser.add_argument(
        '--seed',
        type=str,
        required=True,
        help='Random number generator seed.'
    )
    parser.add_argument(
        '--train_trajectory_dir',
        type=Path,
        required=True,
        help=(
            "Directory with the training transition dataset. Every "
            "directory under it that contains a transitions.json file is "
            "also decoded with the fitted HMM and plotted"
        ),
    )
    args = parser.parse_args()
    # Validate before creating output_dir: otherwise a bad value fails later
    # (in hmmlearn or ProcessPoolExecutor) and leaves behind an empty
    # directory that makes the next run's mkdir() fail.
    if args.component_n < 1:
        parser.error("--component_n must be a positive integer")
    if args.worker_n < 1:
        parser.error("--worker_n must be a positive integer")
    args.output_dir.mkdir()
    main(
        episode_dir=args.train_trajectory_dir,
        output_dir=args.output_dir,
        component_n=args.component_n,
        seed=args.seed,
        worker_n=args.worker_n,
    )
