"""Generate and serialize trajectories."""
from swmpo.transition import serialize
from swmpo.transition import Transition
from pathlib import Path
import concurrent.futures
import argparse
import random
from autonomous_car_verification.simulator.Car import World
from autonomous_car_verification.simulator.Car import THROTTLE
from autonomous_car_verification.simulator.plot_trajectories_7a import Modes
from autonomous_car_verification.simulator.plot_trajectories_7a import ComposedModePredictor
from autonomous_car_verification.simulator.plot_trajectories_7a import ComposedSteeringPredictor
from autonomous_car_verification.simulator.plot_trajectories_7a import reverse_lidar
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
import matplotlib as mpl
import json
import inspect
import torch
import autonomous_car_verification


def get_world(dt: float, seed: str) -> World:
    """Return a freshly constructed ``World`` stepping ``dt`` per tick.

    ``seed`` is accepted for interface compatibility but is currently not
    used by this function.
    """
    return World(time_step=dt)


def generate_trajectory(
        dt: float,
        seed: str,
        output_dir: Path,
        ) -> tuple[list[str], Path]:
    """Roll out the pretrained modular policy for one episode and serialize it.

    The episode runs until the environment reports ``terminated`` or
    ``truncated``. The following files are written into ``output_dir``
    (created if it does not exist yet):

    - one file per transition, named ``f"{i}.zip"`` left-padded with ``"0"``
      to 10 characters;
    - ``transitions.json``: the list of those file names in step order;
    - ``ground_truth_states.json``: the integer index of the mode predicted
      for every visited observation (one more entry than there are
      transitions, because the final observation is labelled too);
    - ``trajectory.png``: the halls plus the visited positions coloured by
      mode index.

    Args:
        dt: Simulation time step passed to ``World``.
        seed: Seed forwarded to ``World.reset``.
        output_dir: Directory that receives the files listed above.

    Returns:
        The transition file names (relative to ``output_dir``) and the path
        of the trajectory plot.
    """
    # Callers are expected to provide the directory; creating it here keeps
    # a missing directory from aborting the episode at the first write.
    output_dir.mkdir(parents=True, exist_ok=True)

    w = get_world(dt, seed)
    observation, _ = w.reset(seed=seed)

    # Load the pretrained policy files stored next to the simulator module.
    pretrained_dir = Path(
        inspect.getfile(autonomous_car_verification.simulator.Car)
    ).parent
    mode_predictor = ComposedModePredictor(
        pretrained_dir/'big.yml',
        pretrained_dir/'straight_little.yml',
        pretrained_dir/'square_right_little.yml',
        pretrained_dir/'square_left_little.yml',
        pretrained_dir/'sharp_right_little.yml',
        pretrained_dir/'sharp_left_little.yml',
        True,
    )
    action_scale = float(w.action_space.high[0])
    steering_ctrl = ComposedSteeringPredictor(
        pretrained_dir/'tanh64x64_right_turn_lidar.yml',
        pretrained_dir/'tanh64x64_sharp_turn_lidar.yml',
        action_scale,
    )

    # Stable integer index for each mode, ordered by the enum's value.
    mode_map = {
        mode: i
        for i, mode in enumerate(sorted(Modes, key=lambda m: m.value))
    }

    transitions = []
    ground_truth_states = []
    while True:
        # Captured before any reversal below, so stored transitions always
        # hold the observation exactly as returned by the environment.
        source_state = torch.tensor(observation.tolist())

        mode = mode_predictor.predict(observation)
        if mode in (Modes.SQUARE_LEFT, Modes.SHARP_LEFT):
            # Only right-turn and sharp-turn steering networks are loaded;
            # left-turn modes get a reversed LiDAR scan (presumably a mirror
            # image so those networks apply) — TODO confirm.
            observation = reverse_lidar(observation)
        delta = steering_ctrl.predict(observation, mode)

        observation, _reward, terminated, truncated, _info = w.step(
            delta, THROTTLE
        )

        transitions.append(Transition(
            source_state=source_state,
            action=torch.tensor(delta[0].tolist()),
            next_state=torch.tensor(observation.tolist()),
        ))
        ground_truth_states.append(mode_map[mode])

        if terminated or truncated:
            break

    # Label the final observation too, so every visited state has a mode.
    ground_truth_states.append(mode_map[mode_predictor.predict(observation)])

    # Serialize each transition under a zero-padded file name.
    transition_paths = [
        f"{i}.zip".rjust(10, "0")
        for i in range(len(transitions))
    ]
    for transition, transition_path in zip(transitions, transition_paths):
        serialize(transition, output_dir/transition_path)

    # Episode index: the ordered list of transition files.
    directory_path = output_dir/"transitions.json"
    with open(directory_path, "wt") as fp:
        json.dump(transition_paths, fp)
    print(f"Wrote {directory_path}")

    ground_truth_states_path = output_dir/"ground_truth_states.json"
    with open(ground_truth_states_path, "wt") as fp:
        json.dump(ground_truth_states, fp)
    print(f"Wrote {ground_truth_states_path}")

    # Plot the halls and the visited positions coloured by mode index.
    figure_path = output_dir/"trajectory.png"
    fig = Figure()
    _ = FigureCanvas(fig)  # attach the Agg canvas used for rendering
    ax = fig.add_subplot()
    w.plotHalls(ax=ax)
    # Integer inputs index the colormap's colour table directly, so row k
    # of mode_colors is the colour for mode index k.
    mode_colors = mpl.colormaps['Pastel1'](list(mode_map.values()))
    # NOTE(review): w.allX / w.allY are presumably the positions logged by
    # World over the episode; TODO confirm their length matches
    # ground_truth_states.
    ax.scatter(
        w.allX,
        w.allY,
        s=5,
        c=[mode_colors[state] for state in ground_truth_states],
    )
    fig.savefig(figure_path)
    print(f"Wrote {figure_path}")

    return transition_paths, figure_path


def main(
        trajectory_n: int,
        worker_n: int,
        output_dir: Path,
        dt: float,
        seed: str,
        ):
    """Generate ``trajectory_n`` episodes in parallel under ``output_dir``.

    Episode ``i`` is written to ``output_dir/str(i)``. That directory is
    created here before the episode is submitted, because
    ``generate_trajectory`` writes straight into it. Each episode gets its
    own seed, drawn from a ``random.Random`` seeded with ``seed``, so the
    per-episode seeds are deterministic for a given ``seed``.

    Args:
        trajectory_n: Number of episodes to generate.
        worker_n: Number of worker processes.
        output_dir: Parent directory for the per-episode directories.
        dt: Simulation time step forwarded to each episode.
        seed: Seed for the per-episode seed generator.

    Raises:
        Any exception raised inside a worker is re-raised here when its
        result is collected.
    """
    rng = random.Random(seed)

    with concurrent.futures.ProcessPoolExecutor(worker_n) as pool:
        futures = []
        for i in range(trajectory_n):
            trajectory_output_dir = output_dir/f"{i}"
            trajectory_output_dir.mkdir(parents=True, exist_ok=True)
            futures.append(pool.submit(
                generate_trajectory,
                output_dir=trajectory_output_dir,
                dt=dt,
                seed=str(rng.random()),
            ))

        # Surface worker failures instead of silently dropping them.
        for future in futures:
            future.result()


if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        prog='Generate trajectories',
        description='Generate and serialize trajectories.',
    )
    # Every option is mandatory; declared as (flag, type, help) in the
    # order they should appear in --help.
    for flag, flag_type, flag_help in (
        ('--output_dir', Path,
         'Non-existing directory to write output files.'),
        ('--trajectory_n', int,
         'Number of episodes to synthesize the state machine.'),
        ('--dt', float,
         'Integration constant for the dynamical system.'),
        ('--worker_n', int,
         'Number of workers to sample episodes.'),
        ('--seed', str,
         'Random number generator seed.'),
    ):
        cli.add_argument(flag, type=flag_type, required=True, help=flag_help)
    parsed = cli.parse_args()
    # mkdir() without exist_ok raises if the directory already exists,
    # which enforces the "non-existing directory" requirement.
    parsed.output_dir.mkdir()
    # The parsed namespace keys match main()'s keyword parameters exactly.
    main(**vars(parsed))
