"""Implementations of algorithms for continuous control."""

from typing import Sequence, Tuple

import jax
import jax.numpy as jnp
from flax import linen as nn
from sklearn import tree

from jax_rl.networks.common import MLP

from emlp.reps import Rep, Scalar
from rpp.flax import EMLP
import numpy as np

def _vec(x):
    return x[..., 0] if x.ndim > 1 and x.shape[-1] == 1 else x

def PEDoubleCritic(state_rep, action_rep, G, ch:Sequence[int], hidden_dims, state_transform, action_transform):
    """Build a `_PEDoubleCritic` holding two independent (EMLP, MLP) head pairs.

    Each pair consists of an `EMLP` mapping `state_rep + action_rep` to
    `Scalar` under group `G` with channel spec `ch`, and an `MLP` with layer
    sizes `(*hidden_dims, 1)`. The two transforms are forwarded unchanged to
    the module.
    """
    def _head_pair():
        # Equivariant head first, then the standard head, matching the
        # (equivariant, standard) ordering the module indexes into.
        equivariant = EMLP(state_rep+action_rep, Scalar, G, ch)
        standard = MLP((*hidden_dims, 1))
        return equivariant, standard

    first_pair = _head_pair()
    second_pair = _head_pair()
    return _PEDoubleCritic(first_pair, second_pair, state_transform, action_transform)

class _PEDoubleCritic(nn.Module):
    """Double critic whose Q-values interpolate two heads per critic.

    Each critic is a pair `(equivariant, standard)` of modules evaluated on
    the same concatenated, transformed (state, action) input. `q_hybrid`
    mixes the two head outputs with a weight `lam`; `heads` exposes the raw
    head outputs.
    """
    critic1: Tuple[nn.Module, nn.Module]  # (equivariant, standard)
    critic2: Tuple[nn.Module, nn.Module]  # (equivariant, standard)
    state_transform: callable
    action_transform: callable

    def __call__(self, observations: jnp.ndarray,
                 actions: jnp.ndarray,
                 lam: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Alias for `q_hybrid`."""
        return self.q_hybrid(observations, actions, lam)

    def _flatten_BK(self, obs, act):
        """Collapse the two leading axes of 3-D observations (and of actions).

        When `obs` has shape `[B, K, S]` it is reshaped to `[B*K, S]` and
        `act` is reshaped to `[B*K, A]`, where `A` is the last axis of `act`.
        Inputs whose `obs` is not 3-D are returned unchanged.
        """
        if obs.ndim == 3:  # Allow [B, K, -1]
            B, K, S = obs.shape
            A = act.shape[-1]
            return obs.reshape(B*K, S), act.reshape(B*K, A)
        return obs, act

    def q_hybrid(self, observations: jnp.ndarray,
                 actions: jnp.ndarray,
                 lam: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Return both critics' Q-values as a `lam`-weighted blend of heads.

        For each critic the result is `(1 - lam) * equivariant + lam * standard`,
        so `lam == 0` selects the equivariant head and `lam == 1` the
        standard head.

        Args:
            observations: Observations; a 3-D `[B, K, S]` array is flattened
                to `[B*K, S]` before evaluation.
            actions: Actions, flattened alongside `observations`.
            lam: Mixing weight. It is cast to float32, detached from the
                gradient, and flattened to 1-D, so it must broadcast against
                the head outputs (presumably one value per flattened sample
                -- confirm against callers).

        Returns:
            `(q1, q2)`, the blended values of `critic1` and `critic2`.
        """
        # Reuse `heads` so the blended and raw paths can never diverge.
        c1_e, c1_s, c2_e, c2_s = self.heads(observations, actions)

        # No gradient flows into `lam` through the critic outputs.
        lam_sa = jax.lax.stop_gradient(lam.astype(jnp.float32)).reshape(-1)

        c1 = (1.0 - lam_sa) * c1_e + lam_sa * c1_s
        c2 = (1.0 - lam_sa) * c2_e + lam_sa * c2_s
        return c1, c2

    def heads(self, observations: jnp.ndarray, actions: jnp.ndarray) \
            -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Evaluate all four heads without blending.

        Observations and actions are flattened via `_flatten_BK`, passed
        through `state_transform` / `action_transform`, and concatenated
        along the last axis to form the shared input of every head. A
        trailing length-1 axis on each head output is dropped by `_vec`.

        Returns:
            A flat 4-tuple `(c1_e, c1_s, c2_e, c2_s)`: the equivariant and
            standard outputs of `critic1`, then those of `critic2`.
        """
        observations, actions = self._flatten_BK(observations, actions)

        state = self.state_transform(observations)
        t_actions = self.action_transform(actions)
        inputs = jnp.concatenate([state, t_actions], axis=-1)

        c1_e = _vec(self.critic1[0](inputs))
        c1_s = _vec(self.critic1[1](inputs))
        c2_e = _vec(self.critic2[0](inputs))
        c2_s = _vec(self.critic2[1](inputs))

        return c1_e, c1_s, c2_e, c2_s
