"""A transition is a triple consisting of a source state, an action
and a next state.

Each of those is a one-dimensional `torch.Tensor`.
"""
from dataclasses import dataclass
from pathlib import Path
import tempfile
import shutil
import zipfile
import torch


@dataclass
class Transition:
    """A (source state, action, next state) triple.

    Per the module docstring, each field is expected to be a
    one-dimensional `torch.Tensor`; nothing here enforces that.

    Note: the dataclass-generated `__eq__` relies on `==` between tensors,
    which is element-wise. Use `equals` from this module to compare two
    transitions.
    """
    # State the transition starts from.
    source_state: torch.Tensor
    # Action associated with the transition.
    action: torch.Tensor
    # State the transition ends in.
    next_state: torch.Tensor


def equals(t1: Transition, t2: Transition) -> bool:
    """Tell whether two transitions hold identical tensors.

    Fields are compared pairwise with `torch.equal` (same size and same
    elements), in the order source state, action, next state. Evaluation
    stops at the first field that differs.
    """
    return (
        torch.equal(t1.source_state, t2.source_state)
        and torch.equal(t1.action, t2.action)
        and torch.equal(t1.next_state, t2.next_state)
    )


def get_vector(transition: Transition) -> torch.Tensor:
    """Flatten a transition into one tensor.

    The source state, the action and the next state are concatenated, in
    that order, with `torch.cat` (along its default first dimension).
    """
    parts = (
        transition.source_state,
        transition.action,
        transition.next_state,
    )
    return torch.cat(parts)


# File names of the three tensors inside a serialized transition archive:
# `serialize` writes one file per field under these names and `deserialize`
# loads them back from the extracted archive.
SOURCE_STATE_PATH = "source_state.pt"
ACTION_PATH = "action.pt"
NEXT_STATE_PATH = "next_state.pt"


def serialize(t: Transition, output_zip_path: Path):
    """Serialize a transition into a ZIP file.

    Each tensor is written with `torch.save` into a temporary directory
    under a fixed file name (`SOURCE_STATE_PATH`, `ACTION_PATH`,
    `NEXT_STATE_PATH`), and that directory is then archived to
    `output_zip_path`.

    Args:
        t: The transition to write.
        output_zip_path: Destination archive; its suffix must be ".zip".

    Raises:
        ValueError: If `output_zip_path` does not end in ".zip".
    """
    # Validate explicitly instead of with `assert`: assertions are stripped
    # under `python -O`, and without this check a path such as "out.dat"
    # would silently produce "out.zip" rather than the requested file.
    if output_zip_path.suffix != ".zip":
        raise ValueError(
            f"output_zip_path must have a '.zip' suffix, got: {output_zip_path}"
        )

    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)
        torch.save(t.source_state, output_dir / SOURCE_STATE_PATH)
        torch.save(t.action, output_dir / ACTION_PATH)
        torch.save(t.next_state, output_dir / NEXT_STATE_PATH)
        # `make_archive` appends ".zip" to the base name itself, so the
        # suffix is dropped from the requested path first.
        archive_base = output_zip_path.with_suffix("")
        shutil.make_archive(str(archive_base), "zip", output_dir)


def deserialize(zip_path: Path) -> Transition:
    """Load a transition from a ZIP file such as the one `serialize` writes.

    The archive is extracted into a temporary directory, and the three
    tensors are read from the files named `SOURCE_STATE_PATH`,
    `ACTION_PATH` and `NEXT_STATE_PATH`.
    """
    with tempfile.TemporaryDirectory() as scratch:
        extracted_dir = Path(scratch)

        # Unpack the whole archive into the scratch directory.
        with zipfile.ZipFile(zip_path, "r") as archive:
            archive.extractall(extracted_dir)

        # Load while the scratch directory still exists. `weights_only=True`
        # asks torch to restrict unpickling to tensors and plain data.
        source_state, action, next_state = (
            torch.load(extracted_dir / member, weights_only=True)
            for member in (SOURCE_STATE_PATH, ACTION_PATH, NEXT_STATE_PATH)
        )

    return Transition(
        source_state=source_state,
        action=action,
        next_state=next_state,
    )
