from __future__ import annotations

from typing import Optional

import gym
import time
import numpy as np
import torch

from .walker_env import EnvConfig, BipedalWalkerCustom
from envs.registration import register as gym_register


def get_config(
    name="default",
    ground_roughness=0,
    pit_gap=None,
    stump_width=None,
    stump_height=None,
    stump_float=None,
    stair_height=None,
    stair_width=None,
    stair_steps=None,
):
    """Build an ``EnvConfig`` for ``BipedalWalkerCustom``.

    The obstacle parameters default to ``None`` rather than ``[]``. A list
    literal used as a default is created once at definition time and shared
    by every call, so all configs built without explicit values would alias
    the same list objects. Here each omitted parameter becomes a *fresh*
    empty list per call.

    Args:
        name: Name stored on the resulting EnvConfig.
        ground_roughness: Forwarded unchanged.
        pit_gap, stump_width, stump_height, stump_float, stair_height,
        stair_width, stair_steps: Forwarded unchanged when given (callers in
            this module pass lists, and ``0`` for ``stair_steps`` in some
            cases). ``None`` is replaced by a new ``[]``.

    Returns:
        The constructed ``EnvConfig``.
    """

    def _or_empty(value):
        # New list on every call, so no state is shared between configs.
        return [] if value is None else value

    return EnvConfig(
        name=name,
        ground_roughness=ground_roughness,
        pit_gap=_or_empty(pit_gap),
        stump_width=_or_empty(stump_width),
        stump_height=_or_empty(stump_height),
        stump_float=_or_empty(stump_float),
        stair_height=_or_empty(stair_height),
        stair_width=_or_empty(stair_width),
        stair_steps=_or_empty(stair_steps),
    )


class SeededBipedalWalker(BipedalWalkerCustom):
    """``BipedalWalkerCustom`` that remembers a seed and re-applies it on reset.

    ``reset(seed=s)`` replaces the remembered seed with ``int(s)``.
    ``reset()`` with no seed re-seeds with whatever was remembered last
    (the constructor's ``seed``, possibly ``None``).
    """

    def __init__(self, env_config: EnvConfig, seed: int | None = None):
        super().__init__(env_config=env_config, seed=seed)
        self._initial_seed = seed

    def reset(self, seed=None, **kwargs):
        """Re-seed the environment and return ``_reset_env()``'s result.

        Extra keyword arguments are accepted for API compatibility and ignored.
        """
        self._initial_seed = self._initial_seed if seed is None else int(seed)
        super().seed(self._initial_seed)
        return super()._reset_env()


# ---------------------------------------------------------------------
# default env
# ---------------------------------------------------------------------
class BipedalWalkerDefault(SeededBipedalWalker):
    """Walker built from ``get_config()`` with no parameter overrides."""

    def __init__(self, seed: int | None = None):
        super().__init__(env_config=get_config(), seed=seed)


## stump height
class BipedalWalkerMedStumps(SeededBipedalWalker):
    """Medium-stump variant: stump_height=[2, 2], stump_width=[1, 2], stump_float=[0, 1].

    Every other parameter keeps its ``get_config`` default.
    """

    def __init__(self, seed: int | None = None):
        stumps = dict(stump_height=[2, 2], stump_width=[1, 2], stump_float=[0, 1])
        super().__init__(env_config=get_config(**stumps), seed=seed)


class BipedalWalkerHighStumps(SeededBipedalWalker):
    """High-stump variant: stump_height=[5, 5], stump_width=[1, 2], stump_float=[0, 1].

    Every other parameter keeps its ``get_config`` default.
    """

    def __init__(self, seed: int | None = None):
        stumps = dict(stump_height=[5, 5], stump_width=[1, 2], stump_float=[0, 1])
        super().__init__(env_config=get_config(**stumps), seed=seed)


## pit gap
class BipedalWalkerMedPits(SeededBipedalWalker):
    """Medium pit-gap variant (``pit_gap=[5, 5]``); all else uses get_config defaults."""

    def __init__(self, seed: int | None = None):
        super().__init__(env_config=get_config(pit_gap=[5, 5]), seed=seed)


class BipedalWalkerWidePits(SeededBipedalWalker):
    """Wide pit-gap variant (``pit_gap=[10, 10]``); all else uses get_config defaults."""

    def __init__(self, seed: int | None = None):
        super().__init__(env_config=get_config(pit_gap=[10, 10]), seed=seed)


# stair height + number of stairs
class BipedalWalkerMedStairs(SeededBipedalWalker):
    """Medium-stairs variant.

    Uses ``stair_height=[2, 2]``, ``stair_steps=[5]`` and ``stair_width=[4, 5]``.
    Every other parameter keeps its ``get_config`` default.
    """

    def __init__(self, seed: int | None = None):
        stairs = dict(stair_height=[2, 2], stair_width=[4, 5], stair_steps=[5])
        super().__init__(env_config=get_config(**stairs), seed=seed)


class BipedalWalkerHighStairs(SeededBipedalWalker):
    """High-stairs variant.

    Uses ``stair_height=[5, 5]``, ``stair_steps=[9]`` and ``stair_width=[4, 5]``.
    Every other parameter keeps its ``get_config`` default.
    """

    def __init__(self, seed: int | None = None):
        stairs = dict(stair_height=[5, 5], stair_width=[4, 5], stair_steps=[9])
        super().__init__(env_config=get_config(**stairs), seed=seed)


# ground roughness
class BipedalWalkerMedRoughness(SeededBipedalWalker):
    """Rough-ground variant with ``ground_roughness=5`` and no other overrides."""

    def __init__(self, seed: int | None = None):
        super().__init__(env_config=get_config(ground_roughness=5), seed=seed)


class BipedalWalkerHighRoughness(SeededBipedalWalker):
    """Rough-ground variant with ``ground_roughness=9`` and no other overrides."""

    def __init__(self, seed: int | None = None):
        super().__init__(env_config=get_config(ground_roughness=9), seed=seed)


# everything maxed out
class BipedalWalkerInsane(SeededBipedalWalker):
    """All obstacle types at once, each at its "high" single-obstacle setting.

    Combines the HighStumps, WidePits and HighStairs parameters with
    ``ground_roughness=9`` (the HighRoughness value).
    """

    def __init__(self, seed: int | None = None):
        stumps = dict(stump_height=[5, 5], stump_width=[1, 2], stump_float=[0, 1])
        stairs = dict(stair_height=[5, 5], stair_steps=[9], stair_width=[4, 5])
        insane = get_config(ground_roughness=9, pit_gap=[10, 10], **stumps, **stairs)
        super().__init__(env_config=insane, seed=seed)


# Boundary environment
class BipedalWalkerBoundary(SeededBipedalWalker):
    """Mixed-obstacle walker with fixed intermediate parameter ranges.

    Values: stump_height=[2.2, 2.4], pit_gap=[4, 6], stair_height=[2.0, 2.2],
    stair_steps=[3, 6], ground_roughness=6. The stump width/float and stair
    width use the same ranges as the other variants in this module.
    """

    def __init__(self, seed: int | None = None):
        params = {
            "ground_roughness": 6,
            "pit_gap": [4, 6],
            "stump_height": [2.2, 2.4],
            "stump_width": [1, 2],
            "stump_float": [0, 1],
            "stair_height": [2.0, 2.2],
            "stair_width": [4, 5],
            "stair_steps": [3, 6],
        }
        super().__init__(env_config=get_config(**params), seed=seed)


class BipedalWalkerFull(SeededBipedalWalker):
    """Walker with keyword-configurable obstacle difficulty.

    - the initial seed is fixed to 1 (a later ``reset(seed=...)`` can change it)
    - stair_steps: list[int]
    - everything else is float (as intervals in a list[float])
    - stump_float / stump_width / stair_width use the fixed defaults

    Omitted arguments default to ``ground_roughness=9.0``, ``stair_steps=[9]``,
    ``stair_height=[5.0, 5.0]``, ``stump_height=[5.0, 5.0]`` and
    ``pit_gap=[10.0, 10.0]``. These are numerically the same values as
    BipedalWalkerInsane.
    """

    def __init__(
        self,
        *,
        ground_roughness: float = 9.0,
        stair_steps: list[int] | None = None,
        stair_height: list[float] | None = None,
        stump_height: list[float] | None = None,
        pit_gap: list[float] | None = None,
    ):
        # List defaults are built per call instead of living in the signature.
        # A signature default is evaluated once and shared by every instance,
        # so a mutation through one config would leak into all the others.
        config = get_config(
            stump_height=[5.0, 5.0] if stump_height is None else stump_height,
            stump_width=[1.0, 2.0],  # default
            stump_float=[0.0, 1.0],  # default
            pit_gap=[10.0, 10.0] if pit_gap is None else pit_gap,
            stair_height=[5.0, 5.0] if stair_height is None else stair_height,
            stair_steps=[9] if stair_steps is None else stair_steps,
            stair_width=[4.0, 5.0],  # default
            ground_roughness=ground_roughness,
        )
        super().__init__(env_config=config, seed=1)


# ---------------------------------------------------------------------
# seed-driven envs (terrain parameter ranges re-sampled on every reset)
# ---------------------------------------------------------------------
class BipedalWalkerZeroShot(BipedalWalkerCustom):
    """Walker whose terrain parameters are re-sampled from the level seed on reset.

    The constructor builds a placeholder config (empty obstacle lists, zero
    roughness). Each ``reset`` then draws stump-height, pit-gap and roughness
    values from ``np.random.RandomState(level_seed)``, so a given level seed
    always yields the same parameters.
    """

    def __init__(self, seed: int | None = None):
        base_config = get_config(
            stump_height=[],
            stump_width=[],
            stump_float=[],
            pit_gap=[],
            stair_height=[],
            stair_steps=0,
            stair_width=[],
            ground_roughness=0,
        )
        super().__init__(env_config=base_config, seed=seed)
        self.level_seed = seed

    @staticmethod
    def _clock_seed() -> int:
        """Return a wall-clock-derived seed in ``[0, 999_999]``.

        This uses integer microseconds, so the result is always a valid int.
        Slicing the string form of a float (``str(t)[-6:]``) instead raises
        ``ValueError`` whenever the repr has fewer than six fractional digits,
        because the slice then includes the decimal point.
        """
        return (time.time_ns() // 1_000) % 1_000_000

    def reset(self, seed=None, **kwargs):
        """Re-seed, re-sample the terrain parameters and reset the episode.

        Args:
            seed: New level seed. When ``None``, the stored level seed is
                reused. If none was ever set, one is derived from the clock
                and kept for subsequent resets.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            The result of ``BipedalWalkerCustom._reset_env()``.
        """
        if seed is not None:
            self.level_seed = int(seed)
        elif self.level_seed is None:
            self.level_seed = self._clock_seed()

        super().seed(self.level_seed)

        # Deterministic RNG: the same level seed yields the same parameters.
        rng = np.random.RandomState(self.level_seed)

        stump_high = rng.uniform(2.4, 2.6)
        gap_high = rng.uniform(6, 8)
        roughness = rng.uniform(6, 8)

        config = get_config(
            stump_height=[2.4, stump_high],
            stump_width=[1, 2],
            stump_float=[0, 1],
            pit_gap=[5, gap_high],
            stair_height=[2, 2.4],
            stair_steps=[5, 9],
            stair_width=[],
            ground_roughness=roughness,
        )

        super().re_init(config, self.level_seed)
        return super()._reset_env()


## PCG "Extremely Challenging" Env
class BipedalWalkerXChal(BipedalWalkerCustom):
    """Procedurally generated "extremely challenging" walker.

    On every ``reset``, stump-height, pit-gap and roughness values are drawn
    from ``np.random.RandomState(level_seed)``. No stairs are configured:
    the stair lists are empty and ``stair_steps=0``.
    """

    def __init__(self, seed: int | None = None):
        base_config = get_config(
            stump_height=[],
            stump_width=[],
            stump_float=[],
            pit_gap=[],
            stair_height=[],
            stair_steps=0,
            stair_width=[],
            ground_roughness=0,
        )
        super().__init__(env_config=base_config, seed=seed)
        self.level_seed = seed

    @staticmethod
    def _clock_seed() -> int:
        """Return a wall-clock-derived seed in ``[0, 999_999]``.

        This uses integer microseconds, so the result is always a valid int.
        Slicing the string form of a float (``str(t)[-6:]``) instead raises
        ``ValueError`` whenever the repr has fewer than six fractional digits,
        because the slice then includes the decimal point.
        """
        return (time.time_ns() // 1_000) % 1_000_000

    def reset(self, seed=None, **kwargs):
        """Re-seed, re-sample the terrain parameters and reset the episode.

        Args:
            seed: New level seed. When ``None``, the stored level seed is
                reused. If none was ever set, one is derived from the clock
                and kept for subsequent resets.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            The result of ``BipedalWalkerCustom._reset_env()``.
        """
        if seed is not None:
            self.level_seed = int(seed)
        elif self.level_seed is None:
            self.level_seed = self._clock_seed()

        super().seed(self.level_seed)

        # Deterministic RNG: the same level seed yields the same parameters.
        rng = np.random.RandomState(self.level_seed)

        stump_high = rng.uniform(2.4, 3.0)
        gap_high = rng.uniform(6, 8)
        roughness = rng.uniform(4.5, 8)

        config = get_config(
            stump_height=[0, stump_high],
            stump_width=[1, 2],
            stump_float=[0, 1],
            pit_gap=[0, gap_high],
            stair_height=[],
            stair_steps=0,
            stair_width=[],
            ground_roughness=roughness,
        )

        super().re_init(config, self.level_seed)
        return super()._reset_env()


## POET Rose
# POET "rose" environment parameters, keyed by rose id.
# BipedalWalkerPOETRose reads each row as:
#   [0] ground_roughness
#   [1], [2] stump_height interval  -> stump_height=[row[1], row[2]]
#   [4], [3] pit_gap interval       -> pit_gap=[row[4], row[3]]  (reversed order)
# NOTE(review): for "3b", row[4]=4.8 > row[3]=4.48, so pit_gap=[4.8, 4.48]
# has its bounds reversed compared with every other row. Confirm against
# the source table.
roses = {
    "1a": [5.6, 2.4, 2.82, 6.4, 4.48],
    "1b": [5.44, 1.8, 2.82, 6.72, 4.48],
    "2a": [7.2, 1.98, 2.82, 7.2, 5.6],
    "2b": [5.76, 2.16, 2.76, 7.2, 1.6],
    "3a": [5.28, 1.98, 2.76, 7.2, 4.8],
    "3b": [4.8, 2.4, 2.76, 4.48, 4.8],
}


class BipedalWalkerPOETRose(SeededBipedalWalker):
    """Walker configured from one row of the ``roses`` table.

    Args:
        rose_id: Key into ``roses``. An unknown key raises ``KeyError``.
        seed: Initial seed forwarded to ``SeededBipedalWalker``.

    Stump width/float use the module's standard ranges, and all stair
    parameters are left empty.
    """

    def __init__(self, rose_id="1a", seed: int | None = None):
        row = roses[rose_id]
        roughness = row[0]
        stump_range = [row[1], row[2]]
        gap_range = [row[4], row[3]]  # column 4 comes first in the interval
        rose_config = get_config(
            ground_roughness=roughness,
            stump_height=stump_range,
            stump_width=[1, 2],
            stump_float=[0, 1],
            pit_gap=gap_range,
            stair_height=[],
            stair_steps=[],
            stair_width=[],
        )
        super().__init__(env_config=rose_config, seed=seed)


class BipedalWalkerPOETRose1a(BipedalWalkerPOETRose):
    """POET rose environment ``"1a"``."""

    def __init__(self, seed: int | None = None):
        super().__init__(seed=seed, rose_id="1a")


class BipedalWalkerPOETRose1b(BipedalWalkerPOETRose):
    """POET rose environment ``"1b"``."""

    def __init__(self, seed: int | None = None):
        super().__init__(seed=seed, rose_id="1b")


class BipedalWalkerPOETRose2a(BipedalWalkerPOETRose):
    """POET rose environment ``"2a"``."""

    def __init__(self, seed: int | None = None):
        super().__init__(seed=seed, rose_id="2a")


class BipedalWalkerPOETRose2b(BipedalWalkerPOETRose):
    """POET rose environment ``"2b"``."""

    def __init__(self, seed: int | None = None):
        super().__init__(seed=seed, rose_id="2b")


class BipedalWalkerPOETRose3a(BipedalWalkerPOETRose):
    """POET rose environment ``"3a"``."""

    def __init__(self, seed: int | None = None):
        super().__init__(seed=seed, rose_id="3a")


class BipedalWalkerPOETRose3b(BipedalWalkerPOETRose):
    """POET rose environment ``"3b"``."""

    def __init__(self, seed: int | None = None):
        super().__init__(seed=seed, rose_id="3b")


# ---------------------------------------------------------------------
# Gym registration
# ---------------------------------------------------------------------
# Import path of this module, used to build "module:Class" entry points.
# Fall back to ``__name__`` so ``module_path`` is always bound, even when the
# loader exposes neither ``name`` nor ``fullname`` (e.g. when executed via
# exec(), where ``__loader__`` resolves to the builtins' importer).
module_path = (
    getattr(__loader__, "name", None)
    or getattr(__loader__, "fullname", None)
    or __name__
)


def _register_envs() -> None:
    """Register every walker variant with ``max_episode_steps=2000``.

    Registration order matches the per-call listing this loop replaces.
    """
    entries = [
        ("BipedalWalker-Default-v0", "BipedalWalkerDefault"),
        ("BipedalWalker-Med-Roughness-v0", "BipedalWalkerMedRoughness"),
        ("BipedalWalker-High-Roughness-v0", "BipedalWalkerHighRoughness"),
        ("BipedalWalker-Med-StumpHeight-v0", "BipedalWalkerMedStumps"),
        ("BipedalWalker-High-StumpHeight-v0", "BipedalWalkerHighStumps"),
        ("BipedalWalker-Med-Stairs-v0", "BipedalWalkerMedStairs"),
        ("BipedalWalker-High-Stairs-v0", "BipedalWalkerHighStairs"),
        ("BipedalWalker-Med-PitGap-v0", "BipedalWalkerMedPits"),
        ("BipedalWalker-Wide-PitGap-v0", "BipedalWalkerWidePits"),
        ("BipedalWalker-Insane-v0", "BipedalWalkerInsane"),
        ("BipedalWalker-XChal-v0", "BipedalWalkerXChal"),
        ("BipedalWalker-Full-v0", "BipedalWalkerFull"),
        ("BipedalWalker-ZeroShot-v0", "BipedalWalkerZeroShot"),
    ]
    entries += [
        (f"BipedalWalker-POET-Rose-{rose_id}-v0", f"BipedalWalkerPOETRose{rose_id}")
        for rose_id in ("1a", "1b", "2a", "2b", "3a", "3b")
    ]
    # NOTE(review): BipedalWalkerBoundary is defined in this module but is not
    # registered here. Confirm whether that is intentional.
    for env_id, class_name in entries:
        gym_register(
            id=env_id,
            entry_point=f"{module_path}:{class_name}",
            max_episode_steps=2000,
        )


_register_envs()
