"""Collect data using a controller."""
from terrain_mass.task import Task
from terrain_mass.gymnasium import get_distance_to_target
from swmpo_experiments.terrain_mass_utils.mpc.planner import get_plan
from swmpo_experiments.terrain_mass_utils.ground_truth_states import get_ground_truth_states
from swmpo.transition import Transition
from swmpo.transition import serialize as serialize_transition
from swmpo.transition import deserialize as deserialize_transition
from swmpo.transition import serialize
from terrain_mass.plotting import plot_animation
from dataclasses import dataclass
from pathlib import Path
import tempfile
import zipfile
import random
import torch
import json
import pickle
import shutil
import concurrent.futures
import multiprocessing as mp


# Names of the members stored inside a serialized `TaskEpisode` archive
# (written by `serialize_task_episode`, read by `deserialize_task_episode`).
# JSON file holding the episode's target position.
TARGET_POSITION_PATH = "target_position.json"
# JSON file holding the ordered list of per-transition archive names.
TRANSITION_PATHS = "transition_paths.json"
# Pickle file holding the `Task` object of the episode.
TASK_PATH = "task.pickle"


@dataclass
class TaskEpisode:
    """A recorded episode bundled with the task it was collected on."""
    # Transitions of the episode, in episode order.
    episode: list[Transition]
    # Target position stored alongside the episode (written as JSON by
    # `serialize_task_episode`).
    target_position: tuple[float, float]
    # Task the episode belongs to; pickled as a whole when serialized.
    task: Task


def get_episode_transitions(
        task: Task,
        mpc_plan_len: int,
        mpc_iter_n: int,
        simulation_step_n: int,
        mpc_initial_stdev: float,
        seed: str,
        ) -> list[Transition]:
    """Roll out a single episode on `task`, choosing actions with MPC.

    Starting from the environment's initial state, an action is chosen
    at every step (by `get_plan`), the environment is stepped with it,
    and the resulting `(source_state, action, next_state)` triple is
    recorded.

    Parameters:
    - `task`: supplies the environment, the target position, `dt` and
      the success distance.
    - `mpc_plan_len`: number of two-component actions in the candidate
      plan handed to `get_plan`.
    - `mpc_iter_n`: forwarded to `get_plan` as `iter_n`.
    - `simulation_step_n`: maximum number of simulation steps.
    - `mpc_initial_stdev`: forwarded to `get_plan` as `initial_stdev`.
    - `seed`: seeds the local random generator that produces the
      per-step seeds given to `get_plan`.

    Returns the recorded transitions, one per simulated step. The
    rollout stops early as soon as the distance from the new state to
    the target drops below `task.success_distance_to_target`.
    """
    rng = random.Random(seed)

    current_state = task.environment.get_initial_state()
    recorded: list[Transition] = []
    goal = torch.tensor(task.target_position)

    # Number of upcoming steps that should use a uniformly random action
    # instead of MPC. Nothing currently grants any budget, so the random
    # branch below is never taken. (A previously used, now disabled,
    # heuristic granted `simulation_step_n//10` random steps on a coin
    # flip whenever the terrain reported by the environment differed
    # between the two most recent states.)
    random_steps_left = 0

    for _step in range(simulation_step_n):
        if random_steps_left > 0:
            # Random policy: each action component is uniform in [-1, 1)
            random_steps_left -= 1
            action = torch.tensor((
                rng.random()*2-1,
                rng.random()*2-1,
            ))
        else:
            # NOTE(review): the candidate plan is rebuilt from zeros on
            # every step, so the shift below always operates on an
            # all-zero plan and the plan returned by the previous
            # `get_plan` call is never reused as a warm start. Confirm
            # whether this is intended before changing it, since it
            # affects the generated data.
            blank_plan = torch.tensor([(0.0, 0.0)] * mpc_plan_len)

            # Drop the first action and repeat the last one to keep the
            # plan length unchanged
            shifted_plan = torch.cat([
                blank_plan[1:],
                blank_plan[-1].unsqueeze(0),
            ])
            optimized_plan = get_plan(
                initial_state=current_state,
                initial_candidate_plan=shifted_plan,
                environment=task.environment,
                iter_n=mpc_iter_n,
                target_position=task.target_position,
                dt=task.dt,
                initial_stdev=mpc_initial_stdev,
                success_distance_to_target=task.success_distance_to_target,
                seed=str(rng.random()),
                action_min=task.environment.action_min,
                action_max=task.environment.action_max,
                verbose=False,
            )

            # Only the first action of the plan is executed
            action = optimized_plan[0]

        # Advance the simulation by one step
        previous_state = current_state
        current_state = task.environment.step(
            x=previous_state,
            action=action,
            dt=task.dt,
        )

        recorded.append(Transition(
            source_state=previous_state,
            action=action,
            next_state=current_state,
        ))

        # Stop once the target has been reached
        reached_target = (
            get_distance_to_target(current_state, goal)
            < task.success_distance_to_target
        )
        if reached_target:
            break

    return recorded


def deserialize_task_episode(
    zip_path: Path,
) -> TaskEpisode:
    """Load a `TaskEpisode` from the zip archive at `zip_path`.

    The archive is extracted into a temporary directory, from which the
    target position (JSON), the ordered list of per-transition archive
    names (JSON), each transition, and the pickled task are read. The
    temporary directory is removed before returning.

    SECURITY: the task is restored with `pickle.load`, which can execute
    arbitrary code. Only call this on archives from a trusted source.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(output_dir)

        with open(output_dir/TARGET_POSITION_PATH, "rt") as fp:
            # JSON has no tuple type, so the position comes back as a
            # list; convert it so the field matches its declared
            # `tuple[float, float]` type and round-trips unchanged.
            target_position = tuple(json.load(fp))

        with open(output_dir/TRANSITION_PATHS, "rt") as fp:
            transition_zip_paths = json.load(fp)

        # Paths are stored relative to the archive root, in episode order
        episode = [
            deserialize_transition(output_dir/transition_zip_path)
            for transition_zip_path in transition_zip_paths
        ]

        with open(output_dir/TASK_PATH, "rb") as fp:
            # NOTE: unpickling untrusted data is unsafe (see docstring)
            task = pickle.load(fp)

    task_episode = TaskEpisode(
        episode=episode,
        target_position=target_position,
        task=task,
    )
    return task_episode


def serialize_task_episode(
        task_episode: TaskEpisode,
        output_zip_path: Path,
        ):
    """Write `task_episode` to a zip archive at `output_zip_path`.

    `output_zip_path` must end in ".zip". The archive contains one
    serialized transition per step (named "<index>.zip"), a JSON list of
    those names, the target position as JSON, and the pickled task.
    """
    assert output_zip_path.suffix == ".zip"
    with tempfile.TemporaryDirectory() as staging_name:
        staging_dir = Path(staging_name)

        # Serialize every transition and remember its archive-relative name
        member_names = []
        for index, transition in enumerate(task_episode.episode):
            member_name = f"{index}.zip"
            member_names.append(member_name)
        for transition, member_name in zip(task_episode.episode, member_names):
            serialize_transition(transition, staging_dir/member_name)

        with open(staging_dir/TARGET_POSITION_PATH, "wt") as fp:
            json.dump(task_episode.target_position, fp)

        with open(staging_dir/TRANSITION_PATHS, "wt") as fp:
            json.dump(member_names, fp)

        with open(staging_dir/TASK_PATH, "wb") as fp:
            pickle.dump(task_episode.task, fp)

        # `make_archive` appends ".zip" itself, so hand it the path
        # without its suffix
        archive_base = output_zip_path.with_suffix("")
        shutil.make_archive(str(archive_base), 'zip', staging_dir)


def sample_task_episode(
    task: Task,
    mpc_initial_stdev: float,
    mpc_iter_n: int,
    mpc_plan_len: int,
    simulation_step_n: int,
    seed: str,
    animation_fps: int,
    output_dir: Path,
) -> None:
    """Sample one episode on `task` and write its artifacts to `output_dir`.

    `output_dir` is assumed to exist. The following files are written:
    - one serialized transition per step, with zero-padded names;
    - "transitions.json": the ordered list of those names;
    - "ground_truth_states.json": the result of `get_ground_truth_states`;
    - "trajectory.mp4": an animation of the visited states.
    """
    transitions = get_episode_transitions(
        task=task,
        mpc_iter_n=mpc_iter_n,
        mpc_plan_len=mpc_plan_len,
        simulation_step_n=simulation_step_n,
        mpc_initial_stdev=mpc_initial_stdev,
        seed=seed,
    )

    # Write every transition under a name left-padded with zeros to a
    # width of ten characters
    transition_names: list[str] = []
    for index, transition in enumerate(transitions):
        transition_names.append(f"{index}.zip".rjust(10, "0"))
    for transition, name in zip(transitions, transition_names):
        serialize(transition, output_dir/name)

    # Index listing the transition files in episode order
    index_path = output_dir/"transitions.json"
    with open(index_path, "wt") as fp:
        json.dump(transition_names, fp)
    print(f"Wrote {index_path}")

    # Ground truth states of the episode
    truth_path = output_dir/"ground_truth_states.json"
    truth = get_ground_truth_states(
        environment_instance=task.environment,
        episode=transitions,
    )
    with open(truth_path, "wt") as fp:
        json.dump(truth, fp)
    print(f"Wrote {truth_path}")

    # Animation of the states reached after each step
    animation_path = output_dir/"trajectory.mp4"
    visited_states = []
    for transition in transitions:
        visited_states.append(transition.next_state)
    plot_animation(
        visited_states,
        mass_radius=0.1,
        environment=task.environment,
        fps=animation_fps,
        target_position=task.target_position,
        output_path=animation_path,
    )
    print(f"Wrote {animation_path}")


def sample_task_episodes(
    worker_n: int,
    seed: str,
    tasks: list[Task],
    mpc_iter_n: int,
    mpc_plan_len: int,
    simulation_step_n: int,
    mpc_initial_stdev: float,
    animation_fps: int,
    output_dir: Path,
) -> None:
    """Sample one episode per task using a pool of `worker_n` processes.

    `output_dir` is assumed to exist. For the task at position `i` a new
    directory `output_dir/"<i>"` is created (it must not exist yet) and
    filled by `sample_task_episode`. Every task receives its own seed
    drawn from a generator seeded with `seed`. Results are awaited in
    submission order, so an exception raised by a worker propagates
    from here.
    """
    seed_source = random.Random(seed)

    spawn_context = mp.get_context('spawn')
    with concurrent.futures.ProcessPoolExecutor(
            worker_n,
            mp_context=spawn_context,
            ) as pool:
        # Submit all episodes up front
        pending = []
        for task_index, task in enumerate(tasks):
            episode_dir = output_dir/f"{task_index}"
            episode_dir.mkdir()
            pending.append(pool.submit(
                sample_task_episode,
                task=task,
                mpc_iter_n=mpc_iter_n,
                mpc_plan_len=mpc_plan_len,
                mpc_initial_stdev=mpc_initial_stdev,
                simulation_step_n=simulation_step_n,
                animation_fps=animation_fps,
                output_dir=episode_dir,
                seed=str(seed_source.random()),
            ))

        # Wait for each episode in turn and report progress
        for finished_n, future in enumerate(pending, start=1):
            future.result()
            print(f"Collecting data ({finished_n}/{len(tasks)})...")
