import gym
from gym import vector
import h5py
from tqdm import tqdm
import math
from typing import (
    Callable,
    Literal,
    Dict,
    Optional,
    Tuple,
    List,
    Union,
    Any,
    cast,
    TypeVar,
)
from os import path
from gym.wrappers.vector_list_info import VectorListInfo
from time import sleep
import numpy as np
import torch

from datetime import datetime
from utils.env_sb3 import (
    RecordVideo,
    HistoryRecorder,
    WithReporter,
    ObsToTensor,
    TerminalOrTruncate,
    ActConverter,
    VecEnvLabeler,
)
import random
from utils.reporter import Reporter
from gym.vector import SyncVectorEnv

# Generic type variables for annotations.
# NOTE(review): neither O nor S is referenced by the functions in this
# module section; presumably O = observation type and S = state type for
# use by importers — confirm before relying on that meaning.
O = TypeVar("O")
S = TypeVar("S")


def glance(
    env: gym.Env,
    render: Union[Literal["rgb_array"], Literal["none"], Literal["human"]] = "human",
    random_seed=0,
    repeats=3,
) -> gym.Env:
    """Seed ``env`` and run ``repeats`` episodes of uniformly random actions.

    Quick sanity check of an environment. It first prints the action and
    observation spaces. Then, for each episode, it prints the summed reward
    and the number of steps taken.

    Args:
        env: environment using the old gym API (``seed()``,
            ``render(mode=...)`` and a 4-tuple ``step()``).
        render: ``"human"`` or ``"rgb_array"`` is forwarded to
            ``env.render(mode=...)``; ``"none"`` disables rendering.
        random_seed: seed applied to the env and to its action and
            observation spaces.
        repeats: number of episodes to run.

    Returns:
        The same ``env`` object, so calls can be chained.
    """
    assert render in ["rgb_array", "none", "human"]
    env.seed(random_seed)
    env.action_space.seed(random_seed)
    env.observation_space.seed(random_seed)
    # NOTE(review): this reset is immediately followed by another one at the
    # start of the first episode. It is kept so the env's RNG consumption, and
    # therefore the episode contents, stay unchanged.
    env.reset()

    print(env.action_space, env.observation_space)
    for _ in range(repeats):
        env.reset()
        total_reward = 0.0
        steps = 0
        done = False

        # The loop is driven by the episode-end flag returned by ``step``.
        while not done:
            if render != "none":
                env.render(mode=render)
                sleep(1 / 60)  # pause ~1/60 s between rendered frames

            _, rwd, done, _ = env.step(env.action_space.sample())
            steps += 1
            total_reward += rwd

        print(f"rwd is: {total_reward}, total steps: {steps}")
    return env



def make_vec_envs(
    envs: Union[str, Tuple[str, int]],
    max_steps: int,
    element_wrappers: Optional[List[Callable[[gym.Env], gym.Env]]] = None,
    whole_wrappers: Optional[List[Callable[[gym.Env], gym.Env]]] = None,
    with_reporter: Optional[Reporter] = None,
    overide_async: Optional[bool] = None,
    **kwargs
) -> gym.Env:
    """Create a gym vector env and wrap it in the project's standard wrapper stack.

    Args:
        envs: either an env id (one sub-env) or ``(env_id, num_envs)``.
        max_steps: forwarded to ``HistoryRecorder``.
        element_wrappers: wrappers applied to each sub-env by
            ``gym.vector.make``. ``None`` means no wrappers.
        whole_wrappers: extra wrappers applied to the vector env, after
            ``TerminalOrTruncate`` and before ``HistoryRecorder``.
            ``None`` means no extra wrappers.
        with_reporter: reporter passed to ``WithReporter``.
        overide_async: ``True``/``False`` forces async/sync vectorization.
            When ``None``, async is used only if there is more than one
            sub-env.
        **kwargs: forwarded to ``gym.vector.make``.

    Returns:
        The fully wrapped vector env. Wrappers are applied in the order
        listed below, so ``WithReporter`` is the outermost layer.
    """
    if isinstance(envs, str):
        env_name, vec_nums = envs, 1
    else:
        env_name, vec_nums = envs[0], envs[1]

    # ``None`` defaults (instead of ``[]``) avoid sharing one mutable list
    # between calls; callers passing their own lists get a private copy.
    element_wrappers = [] if element_wrappers is None else list(element_wrappers)
    whole_wrappers = [] if whole_wrappers is None else list(whole_wrappers)

    env = vector.make(
        env_name,
        vec_nums,
        wrappers=element_wrappers,
        asynchronous=vec_nums > 1 if overide_async is None else overide_async,
        **kwargs
    )

    for w in [
        # ObsToTensor,
        lambda e: VecEnvLabeler(e, vec_nums),
        VectorListInfo,
        TerminalOrTruncate,
        *whole_wrappers,
        lambda e: HistoryRecorder(e, max_steps, vec_nums),
        # ObsActConverter,
        ActConverter,
        lambda e: WithReporter(e, reporter=with_reporter),
    ]:
        env = w(env)

    return env


def record_video(
    env: gym.Env, algo_name: str, activate_per_episode: int = 1, name_prefix: str = ""
) -> gym.Env:
    """Wrap ``env`` in ``RecordVideo`` under a timestamped ``vlog/`` folder.

    Videos are written to ``vlog/<algo_name>_<MM-DD_HH-MM-SS>``. An episode is
    recorded when its id is a multiple of ``activate_per_episode``. That value
    is used as a modulus, so it must be non-zero. File names are prefixed with
    ``algo_name``, plus ``_<name_prefix>`` when ``name_prefix`` is non-empty.
    """
    timestamp = datetime.now().strftime("%m-%d_%H-%M-%S")
    video_folder = f"vlog/{algo_name}_{timestamp}"

    if name_prefix == "":
        file_prefix = f"{algo_name}"
    else:
        file_prefix = f"{algo_name}_{name_prefix}"

    def should_record(episode_id):
        return episode_id % activate_per_episode == 0

    return RecordVideo(
        env,
        video_folder,
        episode_trigger=should_record,
        name_prefix=file_prefix,
    )


def expose_markers(
    viewer,
) -> Tuple[Callable[[Tuple[float, float, float], str,], None,], Callable[[], None],]:
    """Return an ``(add_marker, remove_markers)`` pair bound to ``viewer``.

    ``add_marker(pos, label)`` calls ``viewer.add_marker`` with a fixed red,
    semi-transparent style of size 0.5. ``remove_markers()`` empties
    ``viewer._markers`` in place. If ``viewer`` is ``None``, both returned
    callables accept any arguments and do nothing.
    """
    if viewer is None:

        def _noop(*args, **kwargs):
            pass

        return _noop, _noop

    def _add(
        pos: Tuple[float, float, float],
        label: str,
    ):
        # Fresh arrays on every call, so stored markers never share buffers.
        style = dict(
            type=2,
            specular=0.8,
            rgba=np.array([1.0, 0.0, 0.0, 0.8]),
            size=np.array([0.5, 0.5, 0.5]),
        )
        viewer.add_marker(pos=np.array(pos), label=label, **style)

    def _clear():
        # In-place deletion keeps the viewer's own list object.
        del viewer._markers[:]

    return _add, _clear


def seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG, and torch (CPU and all CUDA devices)."""
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def get_keys(h5file):
    """Return the names of all ``h5py.Dataset`` objects reachable from ``h5file``.

    Objects are visited with ``h5file.visititems``. Only datasets are
    collected (groups are skipped), in the order they are visited.
    """
    dataset_names = []

    def _collect(name, obj):
        # Always return None: a non-None return would stop visititems early.
        if not isinstance(obj, h5py.Dataset):
            return None
        dataset_names.append(name)
        return None

    h5file.visititems(_collect)
    return dataset_names


def get_d4rl_dataset(file_path: str, validate_env: gym.Env):
    """Load a D4RL-style HDF5 dataset from ``~/.d4rl/datasets`` and sanity-check it.

    Args:
        file_path: dataset name relative to ``~/.d4rl/datasets``, without the
            ``.hdf5`` extension.
        validate_env: environment whose observation/action space shapes the
            dataset arrays are checked against. A check is skipped when the
            corresponding space's ``shape`` is ``None``.

    Returns:
        Dict mapping each dataset name returned by ``get_keys`` to its
        contents read into memory. ``rewards`` and ``terminals`` are
        flattened from ``(N, 1)`` to ``(N,)`` when stored that way.

    Raises:
        AssertionError: if a required key is missing or a shape check fails.
    """
    data_dict = {}
    with h5py.File(
        path.expanduser(f"~/.d4rl/datasets/{file_path}.hdf5"), "r"
    ) as dataset_file:
        for k in tqdm(get_keys(dataset_file), desc="load datafile"):
            try:  # first try loading as an array
                data_dict[k] = dataset_file[k][:]
            except ValueError:  # scalar datasets can't be sliced; read with ()
                data_dict[k] = dataset_file[k][()]

    # Run a few quick sanity checks
    for key in ["observations", "actions", "rewards", "terminals"]:
        assert key in data_dict, "Dataset is missing key %s" % key
    N_samples = data_dict["observations"].shape[0]

    obs_shape = validate_env.observation_space.shape
    if obs_shape is not None:
        assert (
            data_dict["observations"].shape[1:] == obs_shape
        ), "Observation shape does not match env: %s vs %s" % (
            str(data_dict["observations"].shape[1:]),
            str(obs_shape),
        )

    # Same guard as for observations: spaces without a fixed shape can't be checked.
    act_shape = validate_env.action_space.shape
    if act_shape is not None:
        assert (
            data_dict["actions"].shape[1:] == act_shape
        ), "Action shape does not match env: %s vs %s" % (
            str(data_dict["actions"].shape[1:]),
            str(act_shape),
        )

    # Per-step scalar columns: squeeze a trailing singleton axis, then require 1-D.
    for key, message in (
        ("rewards", "Reward has wrong shape: %s"),
        ("terminals", "Terminals has wrong shape: %s"),
    ):
        if data_dict[key].shape == (N_samples, 1):
            data_dict[key] = data_dict[key][:, 0]
        # Report the shape of the column actually being checked.
        assert data_dict[key].shape == (N_samples,), message % (
            str(data_dict[key].shape)
        )
    return data_dict
