import functools
from typing import Tuple, Dict, List

from brax.envs.wrappers import VectorGymWrapper, VectorWrapper
import gym
from gym.envs.mujoco.half_cheetah_v3 import HalfCheetahEnv
from gym.envs.mujoco.hopper_v3 import HopperEnv
from gym.envs.mujoco.inverted_pendulum import InvertedPendulumEnv
from gym.envs.mujoco.walker2d_v3 import Walker2dEnv
import torch
import torch.optim as optim
import torch.nn as nn
from envs import (
    BraxDynamicsAnt,
    BraxDynamicsHalfcheetah,
    BraxDynamicsHopper,
    BraxDynamicsHumanoid,
    BraxDynamicsInvertedPendulum,
    BraxDynamicsWalker,
    DomainRandomizationMujocoWrapper,
    MujocoDynamicsWrapper,
    ObsToNumpy,
    TruncatedAnt,
    TruncatedHumanoidV3,
    VecBoxTorchWrapper,
    RecordEpisodeStatistics,
    AntBenchmark,
    HopperBenchmark,
    Walker2dBenchmark,
    HumanoidStandUpBenchmark,
    HalfCheetahBenchmark,
    InvertedPendulumBenchmark,
    DomainRandomizationBenchmarkWrapper,
    # RecordStepStatistics,
)

from algos import IAgent
from models import (
    Actor,
    Critic,
    ModelPredictiveCodingSeperateCritic,
    ModelHydraPredictiveCodingCritic,
)
from project_agent import (
    MujocoSacAgent,
    MujocoSacMPC,
    MujocoSacMPCHydra,
    RaviAgent,
    PessimistExpertAgent,
    PessimistExpertAgentHydra,
    PessimistExpertAgentHydraDefault,
    PessimistExpertAgentHydraDefaultAdaptativeThreshold,
)

from bounds import BENCHMARK_MUJOCO_BOUNDS, MUJUCO_BOUNDS

# Brax env constructors, keyed by the ``env_name`` accepted by
# ``build_brax_envs`` (called with that builder's ``**kwargs``).
BRAX_ENVS = dict(
    Ant=BraxDynamicsAnt,
    Halfcheetah=BraxDynamicsHalfcheetah,
    Hopper=BraxDynamicsHopper,
    Humanoid=BraxDynamicsHumanoid,
    InvertedPendulum=BraxDynamicsInvertedPendulum,
    Walker=BraxDynamicsWalker,
)
# Mujoco env classes instantiated with no arguments by ``build_mujoco_env``
# and ``build_mujoco_envs`` before being wrapped in ``MujocoDynamicsWrapper``.
MUJOCO_ENVS = dict(
    Ant=TruncatedAnt,
    Halfcheetah=HalfCheetahEnv,
    Hopper=HopperEnv,
    Humanoid=TruncatedHumanoidV3,
    InvertedPendulum=InvertedPendulumEnv,
    Walker=Walker2dEnv,
)

# Benchmark env classes used by the ``*_benchmark_*`` builders (which pass
# ``**kwargs`` to the constructor) and by both domain-randomization builders
# (which call the constructor with no arguments).
MUJOCO_BENCHMARK_ENVS = dict(
    Ant=AntBenchmark,
    Halfcheetah=HalfCheetahBenchmark,
    Hopper=HopperBenchmark,
    Humanoid=HumanoidStandUpBenchmark,
    InvertedPendulum=InvertedPendulumBenchmark,
    Walker=Walker2dBenchmark,
)


# Episode length cap applied through ``gym.wrappers.TimeLimit`` by every
# env builder in this module.
MAX_NB_STEP = 1000


def build_brax_envs(
    env_name: str, num_env: int = 1, seed: int = 42, **kwargs
) -> gym.Env:
    """Wrap a `brax` environment so it can be driven like a gym vector env.

    Args:
        env_name (str): Key into ``BRAX_ENVS`` selecting the environment.
        num_env (int, optional): Batch size given to ``VectorWrapper``.
            Defaults to 1.
        seed (int, optional): Seed given to ``VectorGymWrapper``.
            Defaults to 42.
        **kwargs: Forwarded to the selected brax environment constructor.

    Returns:
        gym.Env: The batched brax env wrapped by ``VectorGymWrapper``,
        ``ObsToNumpy`` and a ``TimeLimit`` of ``MAX_NB_STEP`` steps.
    """
    brax_cls = BRAX_ENVS[env_name]
    batched = VectorWrapper(env=brax_cls(**kwargs), batch_size=num_env)
    as_gym = ObsToNumpy(env=VectorGymWrapper(env=batched, seed=seed))
    return gym.wrappers.TimeLimit(env=as_gym, max_episode_steps=MAX_NB_STEP)


def build_mujoco_env(env_name: str, seed: int = 42, **kwargs):
    """Build a single, non-vectorised mujoco env.

    Args:
        env_name (str): Key into ``MUJOCO_ENVS``.
        seed (int, optional): Passed to ``env.seed``. Defaults to 42.
        **kwargs: Forwarded to ``MujocoDynamicsWrapper``.

    Returns:
        The env wrapped by ``MujocoDynamicsWrapper``, ``RescaleAction``
        (actions in ``[-1, 1]``) and ``TimeLimit(MAX_NB_STEP)``, already seeded.
    """
    base = MUJOCO_ENVS[env_name]()
    rescaled = gym.wrappers.RescaleAction(
        env=MujocoDynamicsWrapper(base, **kwargs), min_action=-1, max_action=1
    )
    limited = gym.wrappers.TimeLimit(env=rescaled, max_episode_steps=MAX_NB_STEP)
    limited.seed(seed=seed)
    return limited


def build_mujoco_benchmark_env(env_name: str, seed: int = 42, **kwargs):
    """Build a single, non-vectorised mujoco benchmark env.

    Args:
        env_name (str): Key into ``MUJOCO_BENCHMARK_ENVS``.
        seed (int, optional): Passed to ``env.seed``. Defaults to 42.
        **kwargs: Forwarded to the benchmark env constructor.

    Returns:
        The env wrapped by ``RescaleAction`` (actions in ``[-1, 1]``) and
        ``TimeLimit(MAX_NB_STEP)``, already seeded.
    """
    benchmark_cls = MUJOCO_BENCHMARK_ENVS[env_name]
    rescaled = gym.wrappers.RescaleAction(
        env=benchmark_cls(**kwargs), min_action=-1, max_action=1
    )
    limited = gym.wrappers.TimeLimit(env=rescaled, max_episode_steps=MAX_NB_STEP)
    limited.seed(seed=seed)
    return limited


def build_mujoco_envs(
    env_name: str, num_env: int = 1, seed: int = 42, **kwargs
) -> gym.Env:
    """Build mujoco envs as a gym vector env.

    Each sub-env is ``MUJOCO_ENVS[env_name]()`` wrapped by
    ``MujocoDynamicsWrapper`` (receiving ``**kwargs``), ``RescaleAction``
    (actions in ``[-1, 1]``) and ``TimeLimit(MAX_NB_STEP)``.

    Args:
        env_name (str): Key into ``MUJOCO_ENVS``.
        num_env (int, optional): Number of parallel sub-envs. Defaults to 1.
        seed (int, optional): Seed passed to the vector env's ``seed``.
            Defaults to 42.
        **kwargs: Forwarded to ``MujocoDynamicsWrapper`` for every sub-env.

    Returns:
        gym.Env: A ``gym.vector.SyncVectorEnv`` when ``num_env == 1``,
        otherwise a ``gym.vector.AsyncVectorEnv``.
    """

    def _make_env(**env_kwargs):
        env = MUJOCO_ENVS[env_name]()
        env = MujocoDynamicsWrapper(env, **env_kwargs)
        env = gym.wrappers.RescaleAction(env=env, min_action=-1, max_action=1)
        env = gym.wrappers.TimeLimit(env=env, max_episode_steps=MAX_NB_STEP)
        return env

    _env_maker = functools.partial(_make_env, **kwargs)  # HACK
    env_fns = [_env_maker for _ in range(num_env)]
    if num_env == 1:
        vec_env = gym.vector.SyncVectorEnv(env_fns)
    else:
        vec_env = gym.vector.AsyncVectorEnv(env_fns)
    # vec_env = AsyncRenderVectorEnv(env_fns)
    vec_env.seed(seed)
    return vec_env


def build_mujoco_benchmark_envs(
    env_name: str, num_env: int = 1, seed: int = 42, **kwargs
) -> gym.Env:
    """Build mujoco benchmark envs as a gym vector env.

    Each sub-env is ``MUJOCO_BENCHMARK_ENVS[env_name](**kwargs)`` wrapped by
    ``RescaleAction`` (actions in ``[-1, 1]``) and ``TimeLimit(MAX_NB_STEP)``.

    Args:
        env_name (str): Key into ``MUJOCO_BENCHMARK_ENVS``.
        num_env (int, optional): Number of parallel sub-envs. Defaults to 1.
        seed (int, optional): Seed passed to the vector env's ``seed``.
            Defaults to 42.
        **kwargs: Forwarded to the benchmark env constructor of every sub-env.

    Returns:
        gym.Env: A ``gym.vector.SyncVectorEnv`` when ``num_env == 1``,
        otherwise a ``gym.vector.AsyncVectorEnv``.
    """

    def _make_env(**env_kwargs):
        env = MUJOCO_BENCHMARK_ENVS[env_name](**env_kwargs)
        env = gym.wrappers.RescaleAction(env=env, min_action=-1, max_action=1)
        env = gym.wrappers.TimeLimit(env=env, max_episode_steps=MAX_NB_STEP)
        return env

    _env_maker = functools.partial(_make_env, **kwargs)  # HACK
    env_fns = [_env_maker for _ in range(num_env)]
    if num_env == 1:
        vec_env = gym.vector.SyncVectorEnv(env_fns)
    else:
        vec_env = gym.vector.AsyncVectorEnv(env_fns)
    # vec_env = AsyncRenderVectorEnv(env_fns)
    vec_env.seed(seed)
    return vec_env


def build_domain_randomization_mujoco_envs(
    env_name: str,
    bound: Dict[str, List[float]],
    num_env: int = 1,
    seed: int = 42,
) -> gym.Env:
    """Build domain-randomized mujoco envs as a gym vector env.

    Each sub-env is ``MUJOCO_BENCHMARK_ENVS[env_name]()`` wrapped by
    ``DomainRandomizationMujocoWrapper(env, bound)``, ``RescaleAction``
    (actions in ``[-1, 1]``) and ``TimeLimit(MAX_NB_STEP)``.

    Args:
        env_name (str): Key into ``MUJOCO_BENCHMARK_ENVS``.
        bound (Dict[str, List[float]]): Bounds of the randomized env
            parameters, passed unchanged to ``DomainRandomizationMujocoWrapper``.
        num_env (int, optional): Number of parallel sub-envs. Defaults to 1.
        seed (int, optional): Seed passed to the vector env's ``seed``.
            Defaults to 42.

    Returns:
        gym.Env: A ``gym.vector.SyncVectorEnv`` when ``num_env == 1``,
        otherwise a ``gym.vector.AsyncVectorEnv``.
    """

    def _make_env(env_bound):
        # NOTE(review): the other mujoco builders pair the Mujoco*Wrapper with
        # ``MUJOCO_ENVS`` and the benchmark envs with
        # ``DomainRandomizationBenchmarkWrapper``. Using ``MUJOCO_BENCHMARK_ENVS``
        # here may be a copy-paste slip — confirm which env family
        # ``DomainRandomizationMujocoWrapper`` expects.
        env = MUJOCO_BENCHMARK_ENVS[env_name]()
        env = DomainRandomizationMujocoWrapper(env, env_bound)
        env = gym.wrappers.RescaleAction(env=env, min_action=-1, max_action=1)
        env = gym.wrappers.TimeLimit(env=env, max_episode_steps=MAX_NB_STEP)
        return env

    _env_maker = functools.partial(_make_env, bound)  # HACK
    env_fns = [_env_maker for _ in range(num_env)]
    if num_env == 1:
        vec_env = gym.vector.SyncVectorEnv(env_fns)
    else:
        vec_env = gym.vector.AsyncVectorEnv(env_fns)
    # vec_env = AsyncRenderVectorEnv(env_fns)
    vec_env.seed(seed)
    return vec_env


def build_domain_randomization_mujoco_benchmark_envs(
    env_name: str,
    bound: Dict[str, List[float]],
    num_env: int = 1,
    seed: int = 42,
) -> gym.Env:
    """Build domain-randomized mujoco benchmark envs as a gym vector env.

    Each sub-env is ``MUJOCO_BENCHMARK_ENVS[env_name]()`` wrapped by
    ``DomainRandomizationBenchmarkWrapper(env, params_bound=bound)``,
    ``RescaleAction`` (actions in ``[-1, 1]``) and ``TimeLimit(MAX_NB_STEP)``.

    Args:
        env_name (str): Key into ``MUJOCO_BENCHMARK_ENVS``.
        bound (Dict[str, List[float]]): Bounds of the randomized env
            parameters, passed as ``params_bound`` to
            ``DomainRandomizationBenchmarkWrapper``.
        num_env (int, optional): Number of parallel sub-envs. Defaults to 1.
        seed (int, optional): Seed passed to the vector env's ``seed``.
            Defaults to 42.

    Returns:
        gym.Env: A ``gym.vector.SyncVectorEnv`` when ``num_env == 1``,
        otherwise a ``gym.vector.AsyncVectorEnv``.
    """

    def _make_env(env_bound):
        env = MUJOCO_BENCHMARK_ENVS[env_name]()
        env = DomainRandomizationBenchmarkWrapper(env, params_bound=env_bound)
        env = gym.wrappers.RescaleAction(env=env, min_action=-1, max_action=1)
        env = gym.wrappers.TimeLimit(env=env, max_episode_steps=MAX_NB_STEP)
        return env

    _env_maker = functools.partial(_make_env, bound)  # HACK
    env_fns = [_env_maker for _ in range(num_env)]
    if num_env == 1:
        vec_env = gym.vector.SyncVectorEnv(env_fns)
    else:
        vec_env = gym.vector.AsyncVectorEnv(env_fns)
    # vec_env = AsyncRenderVectorEnv(env_fns)
    vec_env.seed(seed)
    return vec_env


def build_envs(env_type: str, device: torch.device, **kwargs) -> gym.Env:
    """Build an environment of the requested type and wrap it for torch.

    Args:
        env_type (str): One of ``"mujoco"``, ``"brax"``, ``"mujoco_single"``,
            ``"domain_randomization_mujoco"``, ``"mujoco_benchmark"`` or
            ``"domain_randomization_mujoco_benchmark"``; selects the matching
            ``build_*`` function of this module.
        device (torch.device): Device passed to ``VecBoxTorchWrapper``.
        **kwargs: Forwarded unchanged to the selected builder.

    Returns:
        gym.Env: The built env wrapped by ``RecordEpisodeStatistics``
        (``deque_size=100``) and then ``VecBoxTorchWrapper``.

    Raises:
        KeyError: If ``env_type`` is not one of the supported types.
    """
    map_build_fn = {
        "mujoco": build_mujoco_envs,
        "brax": build_brax_envs,
        "mujoco_single": build_mujoco_env,
        "domain_randomization_mujoco": build_domain_randomization_mujoco_envs,
        "mujoco_benchmark": build_mujoco_benchmark_envs,
        "domain_randomization_mujoco_benchmark": build_domain_randomization_mujoco_benchmark_envs,
    }
    try:
        build_fn = map_build_fn[env_type]
    except KeyError:
        # Keep KeyError for existing callers, but say what would have worked.
        raise KeyError(
            f"Unknown env_type {env_type!r}; expected one of {sorted(map_build_fn)}"
        ) from None
    env = build_fn(**kwargs)  # type: ignore
    # env = RecordStepStatistics(env=env, deque_size=100)
    env = RecordEpisodeStatistics(env=env, deque_size=100)
    env = VecBoxTorchWrapper(env=env, device=device)
    return env


def build_bounds(env_type: str, env_name: str, bound_name: str):
    """Look up a named parameter-bound entry for an environment.

    Args:
        env_type (str): ``"mujoco"`` (reads ``MUJUCO_BOUNDS``) or
            ``"mujoco_benchmark"`` (reads ``BENCHMARK_MUJOCO_BOUNDS``).
        env_name (str): Name of the environment; presumably the same names
            used as keys of the env registries — verify against ``bounds``.
        bound_name (str): Name of the bound set to fetch.

    Returns:
        The entry ``BOUNDS[env_type][env_name][bound_name]``, unchanged.
        Presumably the ``Dict[str, List[float]]`` expected as ``bound`` by the
        ``build_domain_randomization_*`` builders — confirm against ``bounds``.

    Raises:
        KeyError: If ``env_type``, ``env_name`` or ``bound_name`` is missing.
    """
    bounds_by_type = {
        "mujoco": MUJUCO_BOUNDS,
        "mujoco_benchmark": BENCHMARK_MUJOCO_BOUNDS,
    }
    return bounds_by_type[env_type][env_name][bound_name]


def build_networks(
    observation_dim: int, action_dim: int, mpc: bool = False
) -> Tuple[nn.Module, nn.Module, nn.Module]:
    """Instantiate the actor and the twin Q-networks used by SAC.

    Args:
        observation_dim (int): Size of the observation space.
        action_dim (int): Size of the action space.
        mpc (bool): When True, build ``ModelPredictiveCodingSeperateCritic``
            critics (next-state prediction head) instead of ``Critic``.

    Returns:
        Tuple[nn.Module, nn.Module, nn.Module]: Actor, Q1, Q2
    """
    actor = Actor(input_dim=observation_dim, action_dim=action_dim)
    # TODO: ADD MPC SHARED NETWORK
    if mpc:
        critic_cls = ModelPredictiveCodingSeperateCritic
    else:
        critic_cls = Critic
    q1, q2 = (
        critic_cls(obs_size=observation_dim, action_size=action_dim)
        for _ in range(2)
    )
    return actor, q1, q2


def build_networks_hydra(
    observation_dim: int,
    action_dim: int,
    mpc: bool = True,
) -> Tuple[nn.Module, nn.Module, nn.Module]:
    """Instantiate the actor and twin Q-networks for the hydra SAC variants.

    Args:
        observation_dim (int): Size of the observation space.
        action_dim (int): Size of the action space.
        mpc (bool): When True (default) build
            ``ModelHydraPredictiveCodingCritic`` critics, otherwise
            ``ModelPredictiveCodingSeperateCritic`` critics.

    Returns:
        Tuple[nn.Module, nn.Module, nn.Module]: Actor, Q1, Q2
    """
    actor = Actor(input_dim=observation_dim, action_dim=action_dim)
    critic_cls = ModelPredictiveCodingSeperateCritic
    if mpc:
        critic_cls = ModelHydraPredictiveCodingCritic
    critic_kwargs = {"obs_size": observation_dim, "action_size": action_dim}
    first_critic = critic_cls(**critic_kwargs)
    second_critic = critic_cls(**critic_kwargs)
    return actor, first_critic, second_critic


def build_optimizer(
    actor_parameters: nn.Parameter,  # type: ignore
    critic_parameters: nn.Parameter,  # type: ignore
    learning_rate_actor: float,
    learning_rate_critic: float,
) -> Tuple[optim.Optimizer, optim.Optimizer]:
    """Create the Adam optimizers for the actor and the critic.

    Args:
        actor_parameters: Iterable of actor parameters (or param groups).
        critic_parameters: Iterable of critic parameters (or param groups).
        learning_rate_actor (float): Learning rate of the actor optimizer.
        learning_rate_critic (float): Learning rate of the critic optimizer.

    Returns:
        Tuple[optim.Optimizer, optim.Optimizer]: ``(actor_optimizer,
        critic_optimizer)``, both ``optim.Adam`` with ``eps=1e-5``.
    """
    adam_eps = 1e-5
    return (
        optim.Adam(params=actor_parameters, lr=learning_rate_actor, eps=adam_eps),
        optim.Adam(params=critic_parameters, lr=learning_rate_critic, eps=adam_eps),
    )


def build_optimizer_mpc(
    actor_parameters: nn.Parameter,  # type: ignore
    critic_parameters: nn.Parameter,  # type: ignore
    learning_rate_actor: float,
    learning_rate_critic: float,
) -> Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer]:
    """Create Adam optimizers for the actor, the critic and the MPC objective.

    The critic and MPC optimizers are built over the same critic parameters
    with the same learning rate, each holding its own Adam state.

    Args:
        actor_parameters: Iterable of actor parameters (or param groups).
        critic_parameters: Iterable of critic parameters (or param groups);
            may be a one-shot iterator such as ``module.parameters()``.
        learning_rate_actor (float): Learning rate of the actor optimizer.
        learning_rate_critic (float): Learning rate of the critic and MPC
            optimizers.

    Returns:
        Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer]:
        ``(actor_optimizer, critic_optimizer, mpc_optimizer)``, all
        ``optim.Adam`` with ``eps=1e-5``.
    """
    # critic_parameters is consumed by two optimizers; materialize it once so
    # a generator (e.g. ``module.parameters()``) is not exhausted by the first
    # ``optim.Adam``, which would make the next one raise
    # "optimizer got an empty parameter list".
    critic_parameters = list(critic_parameters)

    actor_optimizer = optim.Adam(
        params=actor_parameters, lr=learning_rate_actor, eps=1e-5
    )
    critic_optimizer = optim.Adam(
        params=critic_parameters, lr=learning_rate_critic, eps=1e-5
    )
    mpc_optimizer = optim.Adam(
        params=critic_parameters, lr=learning_rate_critic, eps=1e-5
    )
    return actor_optimizer, critic_optimizer, mpc_optimizer


def build_optimizer_mpc_no_reg(
    actor_parameters: nn.Parameter,  # type: ignore
    critic_parameters: nn.Parameter,  # type: ignore
    learning_rate_actor: float,
    learning_rate_critic: float,
) -> Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer, optim.Optimizer]:
    """Create Adam optimizers for the actor, critic, MPC and unregularized critic.

    The critic, MPC and no-reg critic optimizers are all built over the same
    critic parameters with the same learning rate, each holding its own Adam
    state.

    Args:
        actor_parameters: Iterable of actor parameters (or param groups).
        critic_parameters: Iterable of critic parameters (or param groups);
            may be a one-shot iterator such as ``module.parameters()``.
        learning_rate_actor (float): Learning rate of the actor optimizer.
        learning_rate_critic (float): Learning rate of the three critic-side
            optimizers.

    Returns:
        Tuple[optim.Optimizer, ...]: ``(actor_optimizer, critic_optimizer,
        mpc_optimizer, critic_no_reg_optimizer)``, all ``optim.Adam`` with
        ``eps=1e-5``.
    """
    # critic_parameters is consumed by three optimizers; materialize it once so
    # a generator (e.g. ``module.parameters()``) is not exhausted by the first
    # ``optim.Adam``, which would make the next one raise
    # "optimizer got an empty parameter list".
    critic_parameters = list(critic_parameters)

    actor_optimizer = optim.Adam(
        params=actor_parameters, lr=learning_rate_actor, eps=1e-5
    )
    critic_optimizer = optim.Adam(
        params=critic_parameters, lr=learning_rate_critic, eps=1e-5
    )
    mpc_optimizer = optim.Adam(
        params=critic_parameters, lr=learning_rate_critic, eps=1e-5
    )
    critic_no_reg_optimizer = optim.Adam(
        params=critic_parameters, lr=learning_rate_critic, eps=1e-5
    )
    return actor_optimizer, critic_optimizer, mpc_optimizer, critic_no_reg_optimizer


def build_agent(
    observation_dim: int,
    action_dim: int,
    device: torch.device,
) -> IAgent[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create a Soft Actor-Critic agent (``MujocoSacAgent``).

    Args:
        observation_dim (int): Observation dimensions
        action_dim (int): Action dimension
        device (torch.device): Device forwarded to the agent.

    Returns:
        IAgent[torch.Tensor, torch.Tensor, torch.Tensor]: Agent wrapping an
        actor and two plain ``Critic`` Q-networks from ``build_networks``.
    """
    actor, *twin_critics = build_networks(
        observation_dim=observation_dim, action_dim=action_dim
    )
    q1, q2 = twin_critics
    return MujocoSacAgent(
        actor=actor,
        q_value_network_1=q1,
        q_value_network_2=q2,
        device=device,
        action_size=action_dim,
    )


def build_ravi_agent(
    observation_dim: int, action_dim: int, device: torch.device
) -> IAgent[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create a ``RaviAgent`` with nominal and robust twin Q-networks.

    Args:
        observation_dim (int): Observation dimensions
        action_dim (int): Action dimension
        device (torch.device): Device forwarded to the agent.

    Returns:
        IAgent[torch.Tensor, torch.Tensor, torch.Tensor]: Agent built from two
        ``build_networks`` calls: the first supplies the actor and nominal
        critics, the second only its two critics (used as robust critics).
    """
    nominal = build_networks(observation_dim=observation_dim, action_dim=action_dim)
    # The actor at index 0 of this second set is not used.
    robust = build_networks(observation_dim=observation_dim, action_dim=action_dim)
    return RaviAgent(
        actor=nominal[0],
        q_value_network_1=nominal[1],
        q_value_network_2=nominal[2],
        q_robust_1=robust[1],
        q_robust_2=robust[2],
        device=device,
        action_size=action_dim,
    )


def build_mpc_sac_agent(
    observation_dim: int, action_dim: int, device: torch.device
) -> IAgent[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create a ``MujocoSacMPC`` agent.

    Args:
        observation_dim (int): Observation dimensions
        action_dim (int): Action dimension
        device (torch.device): Device forwarded to the agent.

    Returns:
        IAgent[torch.Tensor, torch.Tensor, torch.Tensor]: Agent wrapping an
        actor and two ``ModelPredictiveCodingSeperateCritic`` Q-networks
        (``build_networks(mpc=True)``).
    """
    networks = build_networks(
        observation_dim=observation_dim, action_dim=action_dim, mpc=True
    )
    actor, q1, q2 = networks
    return MujocoSacMPC(
        actor=actor,
        q_value_network_1=q1,
        q_value_network_2=q2,
        device=device,
        action_size=action_dim,
    )


def build_mpc_sac_hydra_agent(
    observation_dim: int,
    action_dim: int,
    device: torch.device,
) -> IAgent[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create a ``MujocoSacMPCHydra`` agent.

    Args:
        observation_dim (int): Observation dimensions
        action_dim (int): Action dimension
        device (torch.device): Device forwarded to the agent.

    Returns:
        IAgent[torch.Tensor, torch.Tensor, torch.Tensor]: Agent wrapping an
        actor and two ``ModelHydraPredictiveCodingCritic`` Q-networks
        (``build_networks_hydra(mpc=True)``).
    """
    actor, *critics = build_networks_hydra(
        observation_dim=observation_dim,
        action_dim=action_dim,
        mpc=True,
    )
    return MujocoSacMPCHydra(
        actor=actor,
        q_value_network_1=critics[0],
        q_value_network_2=critics[1],
        device=device,
        action_size=action_dim,
    )


def build_pessimist_agent(
    observation_dim: int, action_dim: int, device: torch.device, threshold: float = 1.5
):
    """
    Build a ``PessimistExpertAgent`` for continuous control.

    Args:
        observation_dim (int): Observation dimensions
        action_dim (int): Action dimension
        device (torch.device): Device forwarded to the agent.
        threshold (float, optional): Forwarded as ``threshold_pc``.
            Defaults to 1.5.

    Returns:
        PessimistExpertAgent: Agent built around the actor and the first
        ``ModelPredictiveCodingSeperateCritic`` from ``build_networks(mpc=True)``.
    """
    # The agent takes a single critic; the second one returned is unused.
    actor, critic, _ = build_networks(
        observation_dim=observation_dim, action_dim=action_dim, mpc=True
    )
    return PessimistExpertAgent(
        actor=actor,
        critic=critic,
        device=device,
        action_size=action_dim,
        threshold_pc=threshold,
    )


def build_pessimist_agent_hydra(
    observation_dim: int, action_dim: int, device: torch.device, threshold: float = 1.5
):
    """
    Build a ``PessimistExpertAgentHydra`` for continuous control.

    Args:
        observation_dim (int): Observation dimensions
        action_dim (int): Action dimension
        device (torch.device): Device forwarded to the agent.
        threshold (float, optional): Forwarded as ``threshold_pc``.
            Defaults to 1.5.

    Returns:
        PessimistExpertAgentHydra: Agent built around the actor and the first
        ``ModelHydraPredictiveCodingCritic`` from
        ``build_networks_hydra(mpc=True)``.
    """
    # The agent takes a single critic; the second one returned is unused.
    actor, critic, _ = build_networks_hydra(
        observation_dim=observation_dim, action_dim=action_dim, mpc=True
    )
    return PessimistExpertAgentHydra(
        actor=actor,
        threshold_pc=threshold,
        critic=critic,
        device=device,
        action_size=action_dim,
    )


def build_pessimist_agent_hydra_default(
    observation_dim: int, action_dim: int, device: torch.device, threshold: float = 1.5
):
    """
    Build a ``PessimistExpertAgentHydraDefault`` for continuous control.

    Args:
        observation_dim (int): Observation dimensions
        action_dim (int): Action dimension
        device (torch.device): Device forwarded to the agent.
        threshold (float, optional): Forwarded as ``threshold_pc``.
            Defaults to 1.5.

    Returns:
        PessimistExpertAgentHydraDefault: Agent built around the actor and the
        first ``ModelHydraPredictiveCodingCritic`` from
        ``build_networks_hydra(mpc=True)``.
    """
    # The agent takes a single critic; the second one returned is unused.
    actor, critic, _ = build_networks_hydra(
        observation_dim=observation_dim, action_dim=action_dim, mpc=True
    )
    return PessimistExpertAgentHydraDefault(
        actor=actor,
        threshold_pc=threshold,
        critic=critic,
        device=device,
        action_size=action_dim,
    )


def build_pessimist_agent_hydra_default_adaptative_threshold(
    observation_dim: int, action_dim: int, device: torch.device
):
    """
    Build a ``PessimistExpertAgentHydraDefaultAdaptativeThreshold``.

    Unlike the other pessimist builders, no threshold is passed to the agent.

    Args:
        observation_dim (int): Observation dimensions
        action_dim (int): Action dimension
        device (torch.device): Device forwarded to the agent.

    Returns:
        PessimistExpertAgentHydraDefaultAdaptativeThreshold: Agent built around
        the actor and the first ``ModelHydraPredictiveCodingCritic`` from
        ``build_networks_hydra(mpc=True)``.
    """
    # The agent takes a single critic; the second one returned is unused.
    actor, critic, _ = build_networks_hydra(
        observation_dim=observation_dim, action_dim=action_dim, mpc=True
    )
    return PessimistExpertAgentHydraDefaultAdaptativeThreshold(
        actor=actor,
        critic=critic,
        device=device,
        action_size=action_dim,
    )
