"""Generate and serialize trajectories."""
from pathlib import Path
import random
import concurrent.futures
import argparse
import json
from swmpo.transition import Transition
from swmpo_experiments.trajectory_dataset import serialize_trajectory
import salamander_env
from salamander_env.map_parser import write_webots_project
from salamander_env.map_parser import get_random_map
from salamander_env.map_parser import Map
from typing import NewType
import subprocess
import inspect
import torch
import tempfile


# Ground-truth mode label paired with every transition. NewType only creates a
# distinct type for static checkers; at runtime a Mode is a plain int.
Mode = NewType("Mode", int)
# Directory holding the installed salamander_env package (the parent of the
# file that `inspect` reports for the module).
salamander_module_dir = Path(inspect.getfile(salamander_env)).parents[0]


def get_preprocessed_controller_file(
    animation_output_dir: Path,
    controller_path: Path,
) -> str:
    """Returns the controller file as a string with the given
    animation output directory inserted.

    Every line of the controller that reads exactly `output_dir = ""`
    (ignoring surrounding whitespace) is replaced by an assignment of the
    absolute path of `animation_output_dir`. The placeholder's original
    indentation is preserved, so it may live inside a function or other
    indented block, and the path is written as a properly escaped string
    literal.

    The output of this function is to be saved in the same directory as
    the controller and opened with webots.

    Args:
        animation_output_dir: Directory the controller should write its
            animation output to.
        controller_path: Path of the controller source file to read.

    Returns:
        The full controller text with the placeholder replaced. If the
        placeholder does not occur, the text is returned unchanged.
    """
    original_line = 'output_dir = ""'
    # json.dumps produces a double-quoted literal that is also a valid Python
    # string literal: quotes, backslashes and control characters in the path
    # are escaped. For ordinary paths it equals f'"{path}"'.
    path_literal = json.dumps(
        str(animation_output_dir.absolute()),
        ensure_ascii=False,
    )

    with open(controller_path, "rt") as fp:
        controller_lines = fp.readlines()

    preprocessed_lines = []
    for line in controller_lines:
        if line.strip() != original_line:
            preprocessed_lines.append(line)
            continue
        # Keep the placeholder's indentation; dropping it would produce an
        # IndentationError when the placeholder is inside an indented block.
        indent = line[:len(line) - len(line.lstrip())]
        preprocessed_lines.append(f"{indent}output_dir = {path_literal}\n")
    return "".join(preprocessed_lines)


def _parse_controller_output(stdout: str) -> tuple[list, list, list]:
    """Extract per-step observations, actions and modes from webots stdout.

    Webots interleaves its own log lines with the controller's output. Only
    lines that decode to a JSON object containing the keys "observation",
    "mode" and "action" are treated as controller output; every other line,
    including valid JSON of another shape, is ignored.

    Returns:
        Three equally long lists: (states, actions, modes).
    """
    states = list()
    actions = list()
    modes = list()
    for line in stdout.splitlines():
        try:
            controller_info = json.loads(line)
        except json.JSONDecodeError:
            # Not controller output (e.g. a webots log line).
            continue
        # A bare number, string or list is valid JSON but not a step record;
        # indexing it would raise TypeError instead of being skipped.
        if not isinstance(controller_info, dict):
            continue
        try:
            state = controller_info["observation"]
            mode = controller_info["mode"]
            action = controller_info["action"]
        except KeyError:
            continue
        # Append only after all three lookups succeed so the lists stay aligned.
        states.append(state)
        actions.append(action)
        modes.append(mode)
    return states, actions, modes


def _assemble_trajectory(
    states: list,
    actions: list,
    modes: list,
) -> list[tuple[Transition, Mode]]:
    """Pair consecutive observations into transitions labelled with a mode.

    Transition i goes from states[i] to states[i+1] under actions[i] and is
    labelled with modes[i]. The last observation only serves as a next state.
    """
    trajectory = list[tuple[Transition, Mode]]()
    # states[1:] is one shorter than the other lists, so zip yields exactly
    # len(states) - 1 steps (none when fewer than two observations exist).
    for source_state, next_state, action, mode in zip(
        states, states[1:], actions, modes,
    ):
        transition = Transition(
            source_state=torch.tensor(source_state),
            action=torch.tensor(action),
            next_state=torch.tensor(next_state),
        )
        trajectory.append((transition, mode))
    return trajectory


def _generate_trajectory(
    input_map: Map,
    timeout_s: float,
    animation_output_dir: Path,
) -> list[tuple[Transition, Mode]]:
    """Helper function to generate a trajectory with the corresponding
    ground-truth mode.

    Writes a webots project for `input_map` into a temporary directory. It then
    patches the controller so that it writes its animation to
    `animation_output_dir`. Webots is run headless under `xvfb-run` with
    `--batch --no-rendering`, and the JSON lines printed by the controller are
    parsed.

    Args:
        input_map: Map to simulate.
        timeout_s: Wall-clock limit for the webots process, in seconds.
        animation_output_dir: Directory the controller writes its animation to.

    Returns:
        A list of (transition, mode) pairs. It is empty if the controller
        reported fewer than two observations.

    Raises:
        subprocess.TimeoutExpired: if webots runs longer than `timeout_s`.
    """
    # Create a new webots directory with the given map
    with tempfile.TemporaryDirectory() as tmpdir:
        new_dir = Path(tmpdir)/"webots"
        write_webots_project(
            input_map=input_map,
            output_dir=new_dir,
        )

        # Patch the controller so it writes its animation to the requested
        # directory.
        controller_path = new_dir/"controllers"/"salamander"/"salamander.py"
        new_controller_file = get_preprocessed_controller_file(
            animation_output_dir,
            controller_path=controller_path,
        )
        with open(controller_path, "wt") as fp:
            fp.write(new_controller_file)

        # Run webots on the world file. The call must finish before the
        # temporary project directory is removed.
        new_webots_file_path = new_dir/"world"/"salamander.wbt"
        command = [
            "xvfb-run",
            "--auto-servernum",
            "webots",
            "--stdout",
            "--stderr",
            "--batch",
            "--no-rendering",
            str(new_webots_file_path),
        ]
        result = subprocess.run(
            command,
            capture_output=True,
            timeout=timeout_s,
        )

    # Webots output is not guaranteed to be valid UTF-8; undecodable bytes
    # can only occur in non-JSON lines, which are ignored anyway.
    stdout = result.stdout.decode(errors="replace")
    states, actions, modes = _parse_controller_output(stdout)
    return _assemble_trajectory(states, actions, modes)


class TrajectoryGenerationError(Exception):
    """Raised when generating a trajectory fails, e.g. when the simulation
    yields no transitions."""


def generate_trajectory(
    input_map: Map,
    timeout_s: float,
    output_dir: Path,
) -> None:
    """Simulate `input_map` in webots and serialize the resulting trajectory.

    The webots animation goes to `output_dir/"animation"`, which must not
    exist yet. The transitions and their ground-truth modes are serialized
    into `output_dir` with `serialize_trajectory`.

    Args:
        input_map: Map to simulate.
        timeout_s: Wall-clock limit for the simulation, in seconds.
        output_dir: Existing directory that receives this trajectory's files.

    Raises:
        TrajectoryGenerationError: if the simulation exceeds `timeout_s` or
            produces no transitions. `main` treats this as a recoverable
            failure and schedules another attempt.
    """
    animation_output_dir = output_dir/"animation"
    animation_output_dir.mkdir()
    try:
        data = _generate_trajectory(
            timeout_s=timeout_s,
            animation_output_dir=animation_output_dir,
            input_map=input_map,
        )
    except subprocess.TimeoutExpired as err:
        # A hung simulation is a failed attempt, not a fatal error. Report it
        # as the exception type the scheduler knows how to retry.
        raise TrajectoryGenerationError(
            f"Simulation timed out after {timeout_s} s"
        ) from err
    print(f"Wrote {animation_output_dir}")

    if not data:
        # Something went wrong while generating data
        raise TrajectoryGenerationError("Simulation produced no transitions")

    transitions = [transition for transition, _ in data]
    ground_truth_modes = [int(mode) for _, mode in data]

    # Serialize each transition
    serialize_trajectory(
        transitions=transitions,
        ground_truth_modes=ground_truth_modes,
        output_dir=output_dir,
    )


def main(
        trajectory_n: int,
        worker_n: int,
        simulation_timeout_s: float,
        output_dir: Path,
        seed: str,
        ):
    """Generate `trajectory_n` trajectories in parallel under `output_dir`.

    The function seeds `trajectory_n` random maps from `seed`. Each job runs
    on a map drawn with replacement from that pool, so a map may be used
    several times and others not at all. Job k writes into `output_dir/"k"`.

    A job that fails with `TrajectoryGenerationError` or a simulation timeout
    is replaced by a new job in the next free numbered directory. The failed
    directory is left on disk, and retries are unbounded. Any other worker
    exception propagates once the executor has shut down.

    Args:
        trajectory_n: Number of jobs scheduled initially.
        worker_n: Number of worker processes.
        simulation_timeout_s: Wall-clock limit per simulation, in seconds.
        output_dir: Existing directory that receives numbered sub-directories.
        seed: Seed for all random choices (maps and map assignment).
    """
    _random = random.Random(seed)

    maps = [
        get_random_map(seed=str(_random.random()))
        for _ in range(trajectory_n)
    ]

    # Gather trajectories
    with concurrent.futures.ProcessPoolExecutor(worker_n) as p:
        futures = list()
        next_index = 0

        def schedule_generation() -> None:
            """Submit one generation job into the next numbered directory."""
            nonlocal next_index
            input_map = _random.choice(maps)
            trajectory_output_dir = output_dir/f"{next_index}"
            trajectory_output_dir.mkdir()
            next_index += 1
            futures.append(p.submit(
                generate_trajectory,
                output_dir=trajectory_output_dir,
                timeout_s=simulation_timeout_s,
                input_map=input_map,
            ))

        # Initial scheduling of maps
        for _ in range(trajectory_n):
            schedule_generation()

        # Schedule more maps if previous ones failed. A timeout is retried
        # like any other failed attempt instead of aborting the whole run.
        while futures:
            future = futures.pop()
            try:
                future.result()
            except (TrajectoryGenerationError, subprocess.TimeoutExpired):
                schedule_generation()


if __name__ == "__main__":
    # Command-line entry point: parse arguments, create the output
    # directory, then generate the trajectories.
    parser = argparse.ArgumentParser(
        prog='Generate trajectories',
        description='Generate and serialize trajectories.',
    )
    parser.add_argument(
        '--output_dir',
        type=Path,
        required=True,
        help='Non-existing directory to write output files.'
    )
    parser.add_argument(
        '--trajectory_n',
        type=int,
        required=True,
        help='Number of trajectories (episodes) to generate for synthesizing the state machine.'
    )
    parser.add_argument(
        '--worker_n',
        type=int,
        required=True,
        help='Number of workers to sample episodes.',
    )
    parser.add_argument(
        '--seed',
        type=str,
        required=True,
        help='Random number generator seed.'
    )
    parser.add_argument(
        '--simulation_timeout_s',
        type=float,
        default=300,
        help='Maximum wall-clock seconds for a single simulation (default: 300).',
    )
    args = parser.parse_args()
    # parents=True allows nested new paths; an existing directory is still
    # rejected (FileExistsError), as the --output_dir help requires.
    args.output_dir.mkdir(parents=True)
    main(
        trajectory_n=args.trajectory_n,
        worker_n=args.worker_n,
        output_dir=args.output_dir,
        simulation_timeout_s=args.simulation_timeout_s,
        seed=args.seed,
    )
