from functools import partial

import numpy as np
import jax
from jax import lax
import jax.numpy as jnp
import flax
from flax import linen as nn
import distrax

from .jax_utils import extend_and_repeat, next_rng, JaxRNG


def update_target_network(main_params, target_params, tau):
    """Return the Polyak (soft) update of ``target_params`` towards ``main_params``.

    Every leaf of the two pytrees is combined as
    ``tau * main + (1 - tau) * target``; ``tau == 1`` copies the main
    parameters, ``tau == 0`` leaves the target unchanged. Both pytrees must
    share the same structure.
    """
    def _blend(online_leaf, target_leaf):
        return tau * online_leaf + (1.0 - tau) * target_leaf

    return jax.tree_util.tree_map(_blend, main_params, target_params)


def multiple_action_q_function(forward):
    """Decorator letting a Q-function forward accept several actions per state.

    When ``actions`` is 3-D and ``observations`` is 2-D, ``actions`` is read
    as ``(batch, num_actions, action_dim)``. The observations are repeated
    along a new axis 1 via ``extend_and_repeat`` and flattened together with
    the actions, so ``forward`` only ever sees flat 2-D inputs. The outputs
    are then folded back to ``(batch, num_actions)`` for Q-values and
    ``(batch, num_actions, feature_dim)`` for features. Any other input
    ranks are passed through unchanged.
    """
    def wrapped(self, observations, actions, **kwargs):
        n_states = observations.shape[0]
        tiled = observations.ndim == 2 and actions.ndim == 3

        if not tiled:
            q_values, features = forward(self, observations, actions, **kwargs)
            return q_values, features

        n_actions = actions.shape[1]
        obs_dim = observations.shape[-1]
        # NOTE: assumes extend_and_repeat(x, 1, k) inserts axis 1 and tiles k times (see jax_utils).
        flat_observations = extend_and_repeat(observations, 1, n_actions).reshape(-1, obs_dim)
        flat_actions = actions.reshape(-1, actions.shape[-1])

        q_values, features = forward(self, flat_observations, flat_actions, **kwargs)
        return (
            q_values.reshape(n_states, -1),
            features.reshape(n_states, -1, features.shape[-1]),
        )
    return wrapped


def final_layer(output_dim, orthogonal_init=False):
    """Build the output ``nn.Dense`` head with small-scale weight initialisation.

    The kernel uses an orthogonal initializer with gain ``1e-2`` when
    ``orthogonal_init`` is true, otherwise fan-in uniform variance scaling
    with scale ``1e-2``. Biases always start at zero.
    """
    if orthogonal_init:
        kernel_init = jax.nn.initializers.orthogonal(1e-2)
    else:
        kernel_init = jax.nn.initializers.variance_scaling(1e-2, 'fan_in', 'uniform')
    return nn.Dense(
        output_dim,
        kernel_init=kernel_init,
        bias_init=jax.nn.initializers.zeros,
    )


class Scalar(nn.Module):
    """A single learnable parameter named ``value``, initialised to ``init_value``.

    Calling the module returns the parameter's current value.
    """
    init_value: float

    def setup(self):
        def constant_init(_rng):
            # The PRNG key is ignored: the starting value is exactly init_value.
            return self.init_value

        self.value = self.param('value', constant_init)

    def __call__(self):
        return self.value


class FullyConnectedNetwork(nn.Module):
    """ReLU MLP with a linear head that also exposes its penultimate features.

    ``arch`` is a dash-separated list of hidden widths (e.g. ``'256-256'``).
    Calling the module returns ``(output, features)``. ``features`` are the
    last hidden activations after the optional stop-gradient and the optional
    centring + L2 normalisation. The head itself receives those features
    multiplied by ``sqrt(width)`` when ``feature_layer_norm`` is on.
    """
    output_dim: int
    arch: str = '256-256'
    orthogonal_init: bool = False
    feature_layer_norm: bool = False
    backprop_through_features: bool = True
    numerical_eps = 1e-8  # unannotated: a plain class constant, not a module field

    @nn.compact
    def __call__(self, input_tensor):
        widths = [int(w) for w in self.arch.split('-')]
        if self.orthogonal_init:
            dense_kwargs = {
                'kernel_init': jax.nn.initializers.orthogonal(jnp.sqrt(2.0)),
                'bias_init': jax.nn.initializers.zeros,
            }
        else:
            dense_kwargs = {}

        hidden = input_tensor
        for width in widths:
            hidden = nn.relu(nn.Dense(width, **dense_kwargs)(hidden))

        if not self.backprop_through_features:
            # Freeze the trunk: only the head receives gradients.
            hidden = jax.lax.stop_gradient(hidden)

        features = hidden
        if self.feature_layer_norm:
            centred = hidden - jnp.mean(hidden, axis=-1, keepdims=True)
            norm = jnp.linalg.norm(centred, ord=2, axis=-1, keepdims=True)
            features = centred / (norm + self.numerical_eps)
            # Unit-norm features rescaled so each entry has roughly unit magnitude.
            hidden = features * jnp.sqrt(features.shape[-1])

        head = final_layer(output_dim=self.output_dim, orthogonal_init=self.orthogonal_init)
        return head(hidden), features


class FullyConnectedQFunction(nn.Module):
    """Q-network scoring (observation, action) pairs, returning ``(q, features)``.

    Observations and actions are concatenated on the last axis and passed
    through a ``FullyConnectedNetwork`` with a single output unit. The raw
    output is multiplied by ``scale``. When both ``q_min`` and ``q_max`` are
    finite, it is then squashed with ``tanh`` into ``(q_min, q_max)``.
    Through ``multiple_action_q_function`` the call also accepts 3-D actions
    (several actions per state).

    Raises:
        ValueError: if exactly one of ``q_min`` / ``q_max`` is finite.
    """
    observation_dim: int
    action_dim: int
    arch: str = '256-256'
    orthogonal_init: bool = False
    q_min: float = -np.inf
    q_max: float = np.inf
    feature_layer_norm: bool = False
    scale: float = 1.0
    backprop_through_features: bool = True

    @nn.compact
    @multiple_action_q_function
    def __call__(self, observations, actions):
        has_q_min = self.q_min != -np.inf
        has_q_max = self.q_max != np.inf
        # Validate the configuration before building/evaluating the network.
        if has_q_min != has_q_max:
            raise ValueError('q_min and q_max must both be set or both not be set')

        x = jnp.concatenate([observations, actions], axis=-1)
        q, features = FullyConnectedNetwork(output_dim=1,
                                            arch=self.arch,
                                            orthogonal_init=self.orthogonal_init,
                                            feature_layer_norm=self.feature_layer_norm,
                                            backprop_through_features=self.backprop_through_features
                                            )(x)
        q *= self.scale

        if has_q_min:
            # Affine map of tanh's (-1, 1) range onto (q_min, q_max).
            q = 0.5 * ((self.q_max - self.q_min) * jnp.tanh(q) + (self.q_max + self.q_min))

        # Drop the trailing unit output axis.
        return jnp.squeeze(q, -1), features

    @nn.nowrap
    def rng_keys(self):
        return ('params', )

    @nn.nowrap
    def feature_dim(self):
        """Width of the last hidden layer, i.e. the size of the returned features."""
        return int(self.arch.split('-')[-1])


class LinearDualFunction(nn.Module):
    """Two spectrally-normalised rank-``rank`` linear projections of ``features``.

    The module learns matrices ``a`` and ``b`` of shape ``(feature_dim, rank)``.
    Each is divided by its largest singular value and scaled by a learnable
    magnitude (``a_mag`` / ``b_mag``). It returns
    ``(features @ a, features @ b, a, b, a_mag, b_mag)``. When both
    ``dual_min`` and ``dual_max`` are given (non-NaN), the two projections are
    squashed with ``tanh`` into ``(dual_min, dual_max)``.

    Raises:
        ValueError: if exactly one of ``dual_min`` / ``dual_max`` is NaN.
    """
    rank: int = 1
    dual_min: float = np.nan  # NaN means "no lower bound configured"
    dual_max: float = np.nan  # NaN means "no upper bound configured"
    numerical_eps = 1e-8

    @nn.compact
    def __call__(self, features):
        # Use isnan rather than `is np.nan`: NaNs from config/flags or numpy
        # scalars are not the np.nan object and would slip past an identity test.
        min_unset = bool(np.isnan(self.dual_min))
        max_unset = bool(np.isnan(self.dual_max))
        if min_unset != max_unset:
            raise ValueError('dual_min and dual_max must both be set or both not be set')
        dual_limit = not min_unset

        param_dtype = jnp.float32
        kernel_init = jax.nn.initializers.normal(1)
        a = self.param('a', kernel_init, (jnp.shape(features)[-1], self.rank), param_dtype)
        a_mag = self.param('a_mag', jax.nn.initializers.constant(1), (1,), param_dtype)
        b = self.param('b', kernel_init, (jnp.shape(features)[-1], self.rank), param_dtype)
        b_mag = self.param('b_mag', jax.nn.initializers.constant(1), (1,), param_dtype)

        # Spectral norm = largest singular value; eps guards against division by zero.
        a_2_norm = jnp.linalg.svd(a, compute_uv=False).max()
        b_2_norm = jnp.linalg.svd(b, compute_uv=False).max()
        a = a_mag * (a / (a_2_norm + self.numerical_eps))
        b = b_mag * (b / (b_2_norm + self.numerical_eps))

        m_a = features @ a
        m_b = features @ b

        if dual_limit:
            # Affine map of tanh's (-1, 1) range onto (dual_min, dual_max).
            m_a = 0.5 * ((self.dual_max - self.dual_min) * jnp.tanh(m_a) + (self.dual_max + self.dual_min))
            m_b = 0.5 * ((self.dual_max - self.dual_min) * jnp.tanh(m_b) + (self.dual_max + self.dual_min))

        return m_a, m_b, a, b, a_mag, b_mag

    @nn.nowrap
    def rng_keys(self):
        return ('params',)


class TanhGaussianPolicy(nn.Module):
    """Diagonal Gaussian policy whose samples are squashed through ``tanh``.

    The base network emits ``2 * action_dim`` values, split into a mean and
    a raw log-std. The log-std is affinely rescaled by two learnable scalars
    and clipped to ``[-20, 2]``.
    """
    observation_dim: int
    action_dim: int
    arch: str = '256-256'
    orthogonal_init: bool = False
    log_std_multiplier: float = 1.0
    log_std_offset: float = -1.0

    def setup(self):
        # These attribute names determine the parameter-tree keys; keep them stable.
        self.base_network = FullyConnectedNetwork(
            output_dim=self.action_dim * 2,
            arch=self.arch,
            orthogonal_init=self.orthogonal_init,
        )
        self.log_std_multiplier_module = Scalar(self.log_std_multiplier)
        self.log_std_offset_module = Scalar(self.log_std_offset)

    def dist(self, observations):
        """Return the tanh-transformed Gaussian action distribution for ``observations``."""
        raw_output, _ = self.base_network(observations)
        mean, raw_log_std = jnp.split(raw_output, 2, axis=-1)
        scaled_log_std = self.log_std_multiplier_module() * raw_log_std + self.log_std_offset_module()
        std = jnp.exp(jnp.clip(scaled_log_std, -20.0, 2.0))
        gaussian = distrax.MultivariateNormalDiag(mean, std)
        squash = distrax.Block(distrax.Tanh(), ndims=1)
        return distrax.Transformed(gaussian, squash)

    def log_prob(self, observations, actions):
        """Log-density of ``actions``; 3-D actions get observations repeated along axis 1."""
        obs = (
            extend_and_repeat(observations, 1, actions.shape[1])
            if actions.ndim == 3
            else observations
        )
        return self.dist(obs).log_prob(actions)

    def __call__(self, observations, deterministic=False, repeat=None):
        """Return ``(actions, log_prob)``.

        With ``deterministic`` the action is ``tanh`` of the Gaussian mean;
        otherwise a sample is drawn using the ``'noise'`` RNG stream. A
        non-None ``repeat`` first replicates observations along axis 1.
        """
        obs = observations if repeat is None else extend_and_repeat(observations, 1, repeat)
        action_distribution = self.dist(obs)
        if not deterministic:
            samples, log_prob = action_distribution.sample_and_log_prob(seed=self.make_rng('noise'))
            return samples, log_prob
        samples = jnp.tanh(action_distribution.distribution.mean())
        log_prob = action_distribution.log_prob(samples)
        return samples, log_prob

    @nn.nowrap
    def rng_keys(self):
        return ('params', 'noise')


class SamplerPolicy:
    """Host-side callable that samples actions from a Flax policy with given params.

    ``act`` is jitted with ``self`` and ``deterministic`` as static arguments;
    ``params``, the RNG key and the observations are traced.
    """

    def __init__(self, policy, params):
        self.policy = policy
        self.params = params

    def update_params(self, params):
        """Swap in new parameters and return ``self`` for chaining."""
        self.params = params
        return self

    @partial(jax.jit, static_argnames=('self', 'deterministic'))
    def act(self, params, rng, observations, deterministic):
        rngs = JaxRNG(rng)(self.policy.rng_keys())
        return self.policy.apply(params, observations, deterministic, repeat=None, rngs=rngs)

    def __call__(self, observations, deterministic=False):
        """Return actions for ``observations`` as host arrays; log-probs are discarded."""
        key = next_rng()
        actions, _ = self.act(self.params, key, observations, deterministic=deterministic)
        # Sanity check only (skipped under `python -O`).
        assert jnp.all(jnp.isfinite(actions))
        return jax.device_get(actions)
