# import copy
from typing import Callable, Dict, Optional, Sequence, Tuple, Union, no_type_check

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from optax import GradientTransformation
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import is_image_space, maybe_transpose
from stable_baselines3.common.utils import is_vectorized_observation


class Flatten(nn.Module):
    """
    Collapse every axis after the first one into a single axis.

    Counterpart of the PyTorch ``nn.Flatten()`` layer.
    """

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Keep the leading axis untouched and let ``-1`` absorb all the others
        leading_size = x.shape[0]
        flat_shape = (leading_size, -1)
        return x.reshape(flat_shape)


class BaseJaxPolicy(BasePolicy):
    """
    Policy base class extending the SB3 ``BasePolicy``.

    It adds two jitted static helpers that query an actor state, plus
    NumPy-side pre-processing of observations and post-processing of actions.
    """

    def __init__(self, *args, **kwargs):
        # No extra setup here: every argument is handed over to the parent class
        super().__init__(*args, **kwargs)

    @staticmethod
    @jax.jit
    def sample_action(actor_state, obervations, key):
        """
        Sample an action from the distribution returned by the actor.

        :param actor_state: Object exposing ``apply_fn`` and ``params``.
        :param obervations: Input given to ``apply_fn`` right after the parameters.
        :param key: Value forwarded as ``seed`` to the distribution ``sample()`` method.
        :return: The sampled action.
        """
        return actor_state.apply_fn(actor_state.params, obervations).sample(seed=key)

    @staticmethod
    @jax.jit
    def select_action(actor_state, obervations):
        """
        Return the mode of the distribution returned by the actor.

        :param actor_state: Object exposing ``apply_fn`` and ``params``.
        :param obervations: Input given to ``apply_fn`` right after the parameters.
        :return: The result of calling ``mode()`` on the distribution.
        """
        dist = actor_state.apply_fn(actor_state.params, obervations)
        return dist.mode()

    @no_type_check
    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Compute the action(s) for the given observation(s).

        :param observation: Single or batched observation, array or dict of arrays.
        :param state: Returned as-is, it is not used for the computation.
        :param episode_start: Not used by this implementation.
        :param deterministic: Forwarded to ``self._predict()``.
        :return: The action(s) as a NumPy array and the untouched ``state``.
        """
        batched_obs, vectorized_env = self.prepare_obs(observation)
        raw_actions = self._predict(batched_obs, deterministic=deterministic)

        # Back to NumPy, with one row per sample and the action space shape after it
        actions = np.array(raw_actions).reshape((-1, *self.action_space.shape))

        if isinstance(self.action_space, spaces.Box):
            if not self.squash_output:
                # Unbounded outputs (e.g. samples from a Gaussian) may leave the
                # action space: clip them to its bounds
                actions = np.clip(actions, self.action_space.low, self.action_space.high)
            else:
                # Squashed outputs: clip to [-1, 1] against numerical instability,
                # then rescale to the action space domain
                actions = self.unscale_action(np.clip(actions, -1, 1))

        if vectorized_env:
            return actions, state

        # The input was a single observation: drop the batch dimension again
        return actions.squeeze(axis=0), state  # type: ignore[call-overload]

    def prepare_obs(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[np.ndarray, bool]:
        """
        Convert an observation to a batched NumPy array.

        :param observation: Array-like observation, or dict of arrays when the
            observation space is a ``spaces.Dict``.
        :return: The batched observation and whether the input was already
            vectorized according to ``is_vectorized_observation()``.
        """
        obs_space = self.observation_space
        vectorized_env = False

        if isinstance(observation, dict):
            assert isinstance(obs_space, spaces.Dict)
            # Minimal dict support: the entries are batched one by one and
            # concatenated along axis 1, following the key order of the space.
            # Note: image entries are not transposed in this branch.
            keys = list(obs_space.keys())
            first_key = keys[0]
            # Only the first entry is inspected to detect a vectorized input
            vectorized_env = is_vectorized_observation(observation[first_key], obs_space[first_key])
            batched_entries = [observation[key].reshape(-1, *obs_space[key].shape) for key in keys]
            observation = np.concatenate(batched_entries, axis=1)
        elif is_image_space(obs_space):
            # Images may come channel-last or channel-first: let SB3 decide
            # whether a transpose is required
            observation = maybe_transpose(observation, obs_space)
        else:
            observation = np.array(observation)

        if not isinstance(obs_space, spaces.Dict):
            assert isinstance(observation, np.ndarray)
            vectorized_env = is_vectorized_observation(observation, obs_space)
            # Make sure there is a leading batch dimension
            observation = observation.reshape((-1, *obs_space.shape))  # type: ignore[misc]

        assert isinstance(observation, np.ndarray)
        return observation, vectorized_env

    def set_training_mode(self, mode: bool) -> None:
        """
        Store the training flag on the policy.

        Only ``self.training`` is updated, nothing is forwarded to an actor or a critic.

        :param mode: New value of the flag.
        """
        self.training = mode


class ContinuousCritic(nn.Module):
    """
    MLP critic: takes an observation and an action, outputs one value per sample.

    Every hidden layer is ``Dense -> [Dropout] -> [LayerNorm] -> activation``.
    """

    # Width of each hidden layer, in order
    net_arch: Sequence[int]
    # Whether to insert a LayerNorm in every hidden layer
    use_layer_norm: bool = False
    # Dropout is inserted only when this is set and strictly positive
    dropout_rate: Optional[float] = None
    # Non-linearity applied at the end of every hidden layer
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        with_dropout = self.dropout_rate is not None and self.dropout_rate > 0

        # Flattened observation and action are joined along the last axis
        hidden = jnp.concatenate([Flatten()(x), action], -1)

        for width in self.net_arch:
            hidden = nn.Dense(width)(hidden)
            if with_dropout:
                # Dropout is always active here (``deterministic=False``)
                hidden = nn.Dropout(rate=self.dropout_rate)(hidden, deterministic=False)
            if self.use_layer_norm:
                hidden = nn.LayerNorm()(hidden)
            hidden = self.activation_fn(hidden)

        # Final linear layer with a single output unit
        return nn.Dense(1)(hidden)


class VectorCritic(nn.Module):
    """
    ``n_critics`` copies of ``ContinuousCritic`` evaluated at once with ``nn.vmap``.

    The copies share the same inputs and architecture but not their parameters.
    The outputs are stacked along a new leading axis.
    """

    # Settings forwarded unchanged to every ``ContinuousCritic``
    net_arch: Sequence[int]
    use_layer_norm: bool = False
    dropout_rate: Optional[float] = None
    # Size of the vmapped axis, i.e. number of critics
    n_critics: int = 2
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, obs: jnp.ndarray, action: jnp.ndarray):
        # Idea taken from https://github.com/perrin-isir/xpag
        # Similar to https://github.com/tinkoff-ai/CORL for PyTorch
        ensemble_cls = nn.vmap(
            ContinuousCritic,
            # inputs are not mapped: every critic receives the same obs/action
            in_axes=None,
            # one output per critic, stacked on the first axis
            out_axes=0,
            axis_size=self.n_critics,
            # parameters are not shared between the critics
            variable_axes={"params": 0},
            # a different rng per critic (different initializations)
            split_rngs={"params": True, "dropout": True},
        )
        ensemble = ensemble_cls(
            net_arch=self.net_arch,
            use_layer_norm=self.use_layer_norm,
            dropout_rate=self.dropout_rate,
            activation_fn=self.activation_fn,
        )
        return ensemble(obs, action)


class AutoVectorCritic:
    """
    Collection of ``ContinuousCritic`` networks with different hyperparameters.

    One critic, with its own optimizer, is created for every combination of
    ``(net_arch, activation_fn, optimizer_class, learning_rate)``. The critics
    are ordered with ``net_archs`` as the outermost loop and ``learning_rates``
    as the innermost one. The matching optimizers are exposed through
    ``self.tx`` (an ``AutoVectorOptimizer``).

    :param net_archs: Hidden layer sizes to try, one sequence per architecture.
    :param use_layer_norm: Forwarded to every ``ContinuousCritic``.
    :param dropout_rate: Forwarded to every ``ContinuousCritic``.
    :param activation_fns: Activation functions to try.
    :param optimizer_classes: Callables that take a learning rate and return an
        object with ``init`` and ``update`` attributes (e.g. ``optax.adam``).
    :param learning_rates: Learning rates to try.
    """

    def __init__(
        self,
        net_archs: Sequence[Sequence[int]],
        use_layer_norm: bool = False,
        dropout_rate: Optional[float] = None,
        # Tuples instead of lists: default values must not be mutable objects
        activation_fns: Sequence[Callable[[jnp.ndarray], jnp.ndarray]] = (nn.relu,),
        optimizer_classes: Sequence[Callable] = (optax.adam,),
        learning_rates: Sequence[float] = (0.001,),
    ):
        self.n_critics = len(net_archs) * len(activation_fns) * len(optimizer_classes) * len(learning_rates)

        self.init_fns = []
        self.apply_fns = []
        optimizers_init = []
        optimizers_update = []

        for net_arch in net_archs:
            for activation_fn in activation_fns:
                for optimizer_class in optimizer_classes:
                    for learning_rate in learning_rates:
                        model = ContinuousCritic(
                            net_arch=net_arch,
                            use_layer_norm=use_layer_norm,
                            dropout_rate=dropout_rate,
                            activation_fn=activation_fn,
                        )
                        self.init_fns.append(model.init)
                        self.apply_fns.append(model.apply)

                        optimizer = optimizer_class(learning_rate)
                        optimizers_init.append(optimizer.init)
                        optimizers_update.append(optimizer.update)

        self.tx = AutoVectorOptimizer(optimizers_init, optimizers_update)

    def init(self, key_init: jax.Array, obs: jnp.ndarray, action: jnp.ndarray, **kwargs):
        """
        Initialize the parameters of every critic.

        :param key_init: Random key, split into one key per critic.
        :param obs: Observation passed to every ``init`` function.
        :param action: Action passed to every ``init`` function.
        :param kwargs: Accepted but ignored.
        :return: List with the variables of each critic, in critic order.
        """
        keys = jax.random.split(key_init, self.n_critics)

        def init_one_net(init_fn, key):
            return init_fn(key, obs, action)

        # The init functions are the leaves: they are paired one-to-one with the keys
        return jax.tree.map(init_one_net, self.init_fns, list(keys))

    def apply(self, params, indexes: jax.Array, obs: jnp.ndarray, action: jnp.ndarray, **kwargs):
        """
        Evaluate the critics selected by ``indexes``.

        Critics that are not selected go through the zero-output branch of
        ``jax.lax.cond`` instead of their ``apply`` function.

        :param params: Sequence with the variables of each critic, as returned by ``init()``.
        :param indexes: Integer array with the positions of the critics to evaluate.
        :param obs: Observation passed to every selected critic.
        :param action: Action passed to every selected critic.
        :param kwargs: Accepted but ignored (nothing is forwarded to the critics).
        :return: Outputs of the selected critics stacked along the first axis,
            in the order given by ``indexes``.
        """
        batch_size = jnp.atleast_2d(obs).shape[0]

        # Explicit boolean mask (True for the selected critics), so that
        # ``jax.lax.cond`` receives a boolean predicate and not a float
        selected = jnp.zeros(self.n_critics, dtype=bool).at[indexes].set(True)

        def skip_net(_param, _obs, _action):
            # Placeholder with the shape expected from a critic: (batch_size, 1)
            return jnp.zeros((batch_size, 1))

        def apply_one_net(apply_fn, param, keep):
            return jax.lax.cond(keep, apply_fn, skip_net, param, obs, action)

        outputs = jax.tree.map(apply_one_net, self.apply_fns, params, list(selected))

        # Stack once, outside of the mapped function, then gather the requested rows
        stacked = jnp.stack(outputs)
        return jax.vmap(lambda idx: stacked[idx])(indexes)


class AutoVectorOptimizer:
    """
    Drive several optimizers through a single ``init`` / ``update`` pair.

    The i-th init/update function is applied to the i-th entry of the
    sequences given to ``init()`` and ``update()``.

    :param optimizers_init: One ``init(params)`` function per optimizer.
    :param optimizers_apply: One ``update(grads, opt_state, params)`` function per optimizer.
    """

    def __init__(self, optimizers_init: Sequence[Callable], optimizers_apply: Sequence[Callable]) -> None:
        self.optimizers_init, self.optimizers_apply = optimizers_init, optimizers_apply

    def init(self, params: Sequence[Dict]) -> Sequence[Dict]:
        """
        Call every init function on its own entry of ``params``.

        :param params: One entry per optimizer.
        :return: What each init function returned, in the same order.
        """

        def _init_one(init_fn, param):
            return init_fn(param)

        # The functions are the leaves: each one receives its whole entry of ``params``
        return jax.tree.map(_init_one, self.optimizers_init, params)

    def update(self, grads: Sequence[Dict], opt_state: Sequence[Dict], params: Sequence[Dict]):
        """
        Call every update function on its own gradients, state and parameters.

        :param grads: One entry per optimizer.
        :param opt_state: One entry per optimizer.
        :param params: One entry per optimizer.
        :return: What each update function returned, in the same order.
        """

        def _update_one(update_fn, grad, state, param):
            return update_fn(grad, state, param)

        return jax.tree.map(_update_one, self.optimizers_apply, grads, opt_state, params)
