"""Trajectory datasets.


Datasets are expected to have the following structure:

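```
dataset_dir/
    0/
        transitions.json
        ground_truth_states.json
        000000.zip
        000001.zip
        ...
    1/
        ...
    ...
```

Here `transitions.json` is a JSON list of the episode's transition files,
given relative to the episode directory and in episode order, e.g.
`["000000.zip", "000001.zip", ...]` (the names `serialize_trajectory` writes).
`ground_truth_states.json` is a JSON list with the episode's ground-truth
modes (integers); it must be present in every episode directory.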

It is assumed every serialized tensor is `float32`.
"""
from swmpo.transition import deserialize
from swmpo.transition import Transition
import json
from dataclasses import dataclass
from pathlib import Path
from swmpo.transition import serialize


@dataclass
class Dataset:
    """Collection of episodes together with their ground-truth modes.

    As built by `deserialize_dataset`, both lists have one entry per episode
    and entries at the same index belong to the same episode.
    """
    # One list of transitions per episode.
    episodes: list[list[Transition]]
    # One list of integer ground-truth modes per episode, as loaded from that
    # episode's `ground_truth_states.json`.
    ground_truth_modes: list[list[int]]


def deserialize_dataset(dataset_dir: Path) -> Dataset:
    """Load every episode found under `dataset_dir` into a `Dataset`.

    An episode is any directory, at any depth below `dataset_dir`, that
    contains a `transitions.json` file. That file must hold a JSON list of
    transition file paths relative to the episode directory; the transitions
    are deserialized in the listed order. The same directory must also
    contain `ground_truth_states.json`, whose JSON content is stored
    unchanged as the episode's ground-truth modes.

    Args:
        dataset_dir: Root directory that is searched recursively.

    Returns:
        A `Dataset` in which `episodes[i]` and `ground_truth_modes[i]` were
        read from the same episode directory. It is empty if no
        `transitions.json` is found.

    Raises:
        FileNotFoundError: If an episode directory has no
            `ground_truth_states.json`.
    """
    episodes = []
    ground_truths = []
    # NOTE: the index files are sorted as paths, not numerically, so an
    # episode directory named "10" is loaded before one named "2".
    for episode_json in sorted(dataset_dir.glob("**/transitions.json")):
        episode_dir = episode_json.parent

        # JSON is read as UTF-8 explicitly rather than with the
        # locale-dependent default encoding.
        with open(episode_json, "rt", encoding="utf-8") as fp:
            transition_paths = json.load(fp)
        episode = [
            deserialize(episode_dir/transition_path)
            for transition_path in transition_paths
        ]

        ground_truth_path = episode_dir/"ground_truth_states.json"
        with open(ground_truth_path, "rt", encoding="utf-8") as fp:
            ground_truth_modes = json.load(fp)

        # Append both in the same iteration so the two lists stay aligned.
        episodes.append(episode)
        ground_truths.append(ground_truth_modes)

    return Dataset(
        episodes=episodes,
        ground_truth_modes=ground_truths,
    )


def serialize_trajectory(
    transitions: list[Transition],
    ground_truth_modes: list[int],
    output_dir: Path,
) -> None:
    """Write one episode into `output_dir` using the on-disk dataset layout.

    The following files are written:

    - one file per transition, named after its index zero-padded to six
      digits (`000000.zip`, `000001.zip`, ...);
    - `transitions.json`, a JSON list of those file names in episode order;
    - `ground_truth_states.json`, a JSON dump of `ground_truth_modes`.

    The path of each JSON file is printed after it is written.

    Args:
        transitions: Transitions of the episode, in order.
        ground_truth_modes: Ground-truth modes of the episode.
        output_dir: Episode directory. It is created, including missing
            parents, if it does not exist yet.
    """
    # Make sure the episode directory exists so a fresh dataset can be
    # written without the caller preparing the directory tree first.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Serialize each transition. Six-digit zero padding keeps the file names
    # in episode order when sorted, for up to 10**6 transitions.
    transition_paths = []
    for index, transition in enumerate(transitions):
        transition_path = f"{index:06d}.zip"
        serialize(transition, output_dir/transition_path)
        transition_paths.append(transition_path)

    # Serialize episode directory
    directory_path = output_dir/"transitions.json"
    with open(directory_path, "wt", encoding="utf-8") as fp:
        json.dump(transition_paths, fp)
    print(f"Wrote {directory_path}")

    # Serialize ground truth states
    ground_truth_states_path = output_dir/"ground_truth_states.json"
    with open(ground_truth_states_path, "wt", encoding="utf-8") as fp:
        json.dump(ground_truth_modes, fp)
    print(f"Wrote {ground_truth_states_path}")