from argparse import ArgumentParser
from collections import defaultdict
from collections.abc import Sequence
from pathlib import Path

import gymnasium
import h5py
import numpy as np
from numpy.random import Generator, default_rng
import tomlkit
from tqdm import tqdm

from offline.envs.custom import DATASETS_CONFIG_ROOT
from offline.envs.custom.navigate import ACTION_BOUND, LARGE_SIZE, SIZE
from offline.envs.utils import DATA_ROOT
from offline.types import FloatArray

# Waypoint routes for the scripted policy; ``--policy N`` follows SUBGOALS[N].
# ``policy`` subtracts the observation from the current waypoint, so these are
# expressed in the observation frame. The coordinates are for the default
# maze; ``rollout`` rescales them by LARGE_SIZE / SIZE for the large variant.
SUBGOALS: tuple[tuple[FloatArray, ...], ...] = tuple(
    tuple(np.asarray(point, dtype=np.float32) for point in route)
    for route in (
        # Route 0
        ([0, 3], [1, 3], [1, 0], [2, 0], [2, 3], [3, 3]),
        # Route 1
        ([1, 0], [1, 1.5], [2, 1.5], [2, 0], [3, 0], [3, 3]),
    )
)


def build_argument_parser():
    """Build the command-line parser for the dataset-generation script.

    Every option's destination matches a keyword parameter of ``main``, so the
    parsed namespace can be passed straight through with ``main(**vars(args))``.

    Returns:
        The configured ``ArgumentParser``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--force", action="store_true", help="Overwrite existing dataset files."
    )
    parser.add_argument(
        "--large",
        action="store_true",
        help="Use the large Navigate variant (waypoints are rescaled to match).",
    )
    parser.add_argument(
        "--noise",
        type=float,
        default=0.5,
        help=(
            "Standard deviation of the Gaussian noise added to the pre-tanh "
            "action; values <= 0 disable noise."
        ),
    )
    parser.add_argument(
        "--penalty",
        action="store_true",
        help="Use the Penalty variant of the environment.",
    )
    # Restrict to valid indices into SUBGOALS. Otherwise an out-of-range value
    # only fails later with IndexError, and a negative one silently aliases
    # another route while the dataset is labelled e.g. "Policy-1".
    parser.add_argument(
        "--policy",
        type=int,
        default=0,
        dest="policy_type",
        choices=range(len(SUBGOALS)),
        help="Index of the scripted waypoint route to follow.",
    )
    parser.add_argument(
        "--render",
        action="store_true",
        help="Render the rollout with render_mode='human'.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for the first environment reset and the action-noise RNG.",
    )
    parser.add_argument(
        "-s",
        "--silent",
        action="store_true",
        help="Collect the rollout without writing any files.",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=1000000,
        help="Minimum number of transitions to collect (episodes are completed).",
    )
    return parser


def policy(
    observation: FloatArray,
    subgoals: Sequence[FloatArray],
    action_space: gymnasium.spaces.Box,
    loc: FloatArray,
    scale: FloatArray,
    *,
    tolerance: float = 0.1,
):
    """Compute the scripted pre-squash action towards the current waypoint.

    The agent heads straight for ``subgoals[0]``. Once it is closer than
    ``tolerance`` (Euclidean norm), and more waypoints remain, that waypoint is
    dropped and the next one becomes the target. The last waypoint is never
    dropped.

    The displacement is clipped to the action bounds and mapped to [-1, 1]
    via ``loc``/``scale``. It is then passed through ``arctanh``, so the caller
    can add noise in unbounded space and re-squash with ``tanh``.

    Args:
        observation: Current observation. It is subtracted from the
            waypoints, so it is assumed to share their coordinate frame.
        subgoals: Remaining waypoints, current target first.
        action_space: Source of the ``low``/``high`` action bounds.
        loc: Centre of the action box, ``(high + low) / 2``.
        scale: Half-width of the action box, ``(high - low) / 2``.
        tolerance: Distance below which a waypoint counts as reached.

    Returns:
        A tuple ``(action, subgoals)``: the pre-squash action clipped to
        [-2, 2], and the (possibly advanced) waypoint sequence.
    """
    difference = subgoals[0] - observation
    if np.linalg.norm(difference) < tolerance and len(subgoals) > 1:
        subgoals = subgoals[1:]
        difference = subgoals[0] - observation
    action = np.clip(difference, action_space.low, action_space.high)
    # Rounding in (x - loc) / scale can push a saturated action marginally
    # past +/-1. arctanh then returns NaN, which np.clip would not repair, so
    # clamp to the valid domain first.
    normalized = np.clip((action - loc) / scale, -1, 1)
    # arctanh(+/-1) = +/-inf is expected for saturated actions and is clipped
    # to +/-2 below; silence the divide-by-zero warning it would emit.
    with np.errstate(divide="ignore"):
        unbounded = np.arctanh(normalized)
    return np.clip(unbounded, -2, 2), subgoals


def rollout(
    env: gymnasium.Env,
    large: bool,
    noise: float,
    steps: int,
    rng: Generator | None,
    seed: int,
    subgoals: Sequence[FloatArray],
) -> dict[str, np.ndarray]:
    """Collect at least ``steps`` transitions by following the scripted policy.

    Episodes always run until ``terminal`` or ``timeout``, so the result may
    hold more than ``steps`` rows. For every step this records:

    - the observation the action was chosen from;
    - the squashed action sent to ``env.step``;
    - the reward, terminal flag and timeout flag;
    - every ``info`` entry, under the key ``infos/<key>``.

    Args:
        env: Environment whose action space must be a ``gymnasium.spaces.Box``.
        large: If true, waypoints are rescaled by ``LARGE_SIZE / SIZE``.
        noise: Standard deviation of the Gaussian noise added before ``tanh``.
            Only used when ``rng`` is given.
        steps: Minimum number of transitions to collect.
        rng: Noise generator; ``None`` means noise-free actions.
        seed: Seed for the first ``env.reset``; later resets are unseeded.
        subgoals: Waypoint route for the scripted policy.

    Returns:
        A mapping from each recorded key to ``np.stack`` of its per-step values.
    """
    space = env.action_space
    assert isinstance(space, gymnasium.spaces.Box)
    centre = (space.high + space.low) / 2
    half_range = (space.high - space.low) / 2

    buffers: dict[str, list] = defaultdict(list)
    collected = 0
    obs, _ = env.reset(seed=seed)
    start_route = (
        tuple(waypoint * LARGE_SIZE / SIZE for waypoint in subgoals)
        if large
        else subgoals
    )

    with tqdm(total=steps) as bar:
        while collected < steps:
            # Each episode restarts the route from its first waypoint.
            route = start_route
            finished = False
            while not finished:
                buffers["observations"].append(obs)
                raw, route = policy(obs, route, space, centre, half_range)
                perturbation = (
                    0 if rng is None else rng.normal(0, noise, space.shape)
                )
                # Re-squash into (-1, 1), then map back onto the action box.
                squashed = centre + np.tanh(raw + perturbation) * half_range
                obs, reward, terminal, timeout, info = env.step(squashed)
                collected += 1
                bar.update(1)
                buffers["actions"].append(squashed)
                buffers["rewards"].append(reward)
                buffers["terminals"].append(terminal)
                buffers["timeouts"].append(timeout)
                for info_key, info_value in info.items():
                    buffers[f"infos/{info_key}"].append(info_value)
                finished = terminal or timeout
            obs, _ = env.reset()

    return {name: np.stack(rows) for name, rows in buffers.items()}


def save(
    *,
    config_file_path: Path,
    data_file_path: Path,
    dataset_name: str,
    env_id: str,
    max_score: float,
    min_score: float,
    results: dict[str, np.ndarray],
):
    """Write the collected arrays to HDF5 and the dataset metadata to TOML.

    Each entry of ``results`` becomes an HDF5 dataset named by its key. The
    TOML file records the dataset name, the environment id and the reference
    max/min scores. Both files are opened in ``"w"`` mode, so existing files
    are overwritten. The data file is written before the config file.
    """
    with h5py.File(data_file_path, "w") as hdf:
        for name, array in results.items():
            hdf[name] = array

    metadata = dict(
        dataset_name=dataset_name,
        env_id=env_id,
        ref_max_score=max_score,
        ref_min_score=min_score,
    )
    with open(config_file_path, "w", encoding="utf-8") as toml_file:
        tomlkit.dump(metadata, toml_file)


def main(
    force: bool,
    large: bool,
    noise: float,
    penalty: bool,
    policy_type: int,
    render: bool,
    seed: int,
    silent: bool,
    steps: int,
):
    """Generate an offline Navigate dataset with a noisy scripted policy.

    The dataset is collected by rolling out the waypoint route
    ``SUBGOALS[policy_type]`` in the matching ``Navigate*-v0`` environment.
    Unless ``silent``, the transitions are written to ``DATA_ROOT`` as HDF5
    and the dataset config to ``DATASETS_CONFIG_ROOT`` as TOML.

    Args:
        force: Overwrite existing output files instead of refusing.
        large: Use the large environment variant.
        noise: Action-noise standard deviation; values <= 0 disable noise.
        penalty: Use the Penalty environment variant.
        policy_type: Index into ``SUBGOALS``.
        render: Render with ``render_mode="human"``.
        seed: Seed for the first reset and the noise RNG.
        silent: Collect the rollout without writing any files.
        steps: Minimum number of transitions to collect.

    Raises:
        ValueError: If ``policy_type`` is not a valid index into ``SUBGOALS``.
            Also raised if an output file already exists and neither
            ``force`` nor ``silent`` is set.
    """
    # Reject bad indices up front; a negative index would otherwise silently
    # select another route while the label says e.g. "Policy-1".
    if not 0 <= policy_type < len(SUBGOALS):
        raise ValueError(
            f"policy_type must be in [0, {len(SUBGOALS)}), got {policy_type}"
        )
    DATA_ROOT.mkdir(exist_ok=True, parents=True)
    DATASETS_CONFIG_ROOT.mkdir(exist_ok=True, parents=True)
    label = (
        f"OfflineNavigate{'Large' if large else ''}"
        f"{'Penalty' if penalty else ''}Policy{policy_type}-v0"
    )

    file_name = label.replace("-", "_")
    config_file_path = DATASETS_CONFIG_ROOT / f"{file_name}.toml"
    data_file_path = DATA_ROOT / f"{file_name}.hdf5"
    # Only guard against overwriting when the run will actually write files.
    if not (force or silent):
        if config_file_path.is_file():
            raise ValueError(f"Config file already exists: {config_file_path}")
        if data_file_path.is_file():
            raise ValueError(f"Data file already exists: {data_file_path}")
    env_name = (
        f"Navigate{'Large' if large else ''}"
        f"{'Penalty' if penalty else ''}-v0"
    )
    env = gymnasium.make(
        env_name, render_mode="human" if render else "rgb_array"
    )
    rng = default_rng(seed) if noise > 0 else None

    try:
        results = rollout(
            env=env,
            large=large,
            noise=noise,
            steps=steps,
            rng=rng,
            seed=seed,
            subgoals=SUBGOALS[policy_type],
        )
    finally:
        # Release the environment (and any human-mode render window) even if
        # the rollout fails part-way.
        env.close()

    if silent:
        return

    assert env.spec is not None
    max_episode_steps = env.spec.max_episode_steps
    assert max_episode_steps is not None
    save(
        config_file_path=config_file_path,
        data_file_path=data_file_path,
        dataset_name=label,
        env_id=env_name,
        max_score=-1 * env.spec.kwargs["size"] * 2 / ACTION_BOUND,
        min_score=-1 * max_episode_steps,
        results=results,
    )


if __name__ == "__main__":
    # Option destinations match main()'s keyword parameters one-to-one.
    cli_args = build_argument_parser().parse_args()
    main(**vars(cli_args))
