import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from flax.linen.initializers import orthogonal
import distrax
import torch


class PredictabilityHead(nn.Module):
    """
    Transformer-based model to predict values from driver states.

    The transformer input is the token sequence
    ``[state_1 .. state_k, return_1 .. return_k, query]`` (length ``2k + 1``),
    processed by ``num_layers`` pre-LayerNorm residual blocks. The value is read
    off the final (query) token.

    Args:
        num_heads: number of attention heads per transformer block
        hidden_dim: width of hidden representation / token size
        num_layers: number of self-attention blocks
    """
    num_heads: int
    hidden_dim: int
    num_layers: int

    @nn.compact
    def __call__(
        self,
        driver_states: jnp.ndarray,   # [k,obs] or [B,k,obs]
        driver_returns: jnp.ndarray,  # [k] or [B,k]
        query_state: jnp.ndarray,     # [B,obs]
        training: bool = False,
    ) -> jnp.ndarray:                 # [B]
        """Predict the value of query_state.

        Args:
            driver_states: ``[k, obs]`` (shared by every query) or ``[B, k, obs]``.
            driver_returns: ``[k]`` or ``[B, k]``; ``driver_returns[..., i]`` is
                paired with ``driver_states[..., i, :]``.
            query_state: ``[B, obs]`` states to evaluate.
            training: unused (the model has no dropout); kept for API compatibility.

        Returns:
            ``[B]`` predicted values.

        Note:
            The positional-embedding table has ``k + 1`` rows, so ``k`` must match
            the value used when the parameters were initialised.
        """
        # Normalize shapes so that driver_states → [B, k, obs_dim]
        if driver_states.ndim == 2:  # no batch dimension yet: share drivers across the batch
            k, obs_dim = driver_states.shape
            batch_size = query_state.shape[0]
            driver_states = jnp.broadcast_to(driver_states, (batch_size, k, obs_dim))
            driver_returns = jnp.broadcast_to(driver_returns, (batch_size, k))
        else:
            k = driver_states.shape[1]

        h = self.hidden_dim

        # Token embeddings. Flax names submodules by creation order, so keep this order
        # to stay compatible with existing parameter trees.
        state_tokens = nn.Dense(h)(driver_states)                    # [B,k,h]
        return_tokens = nn.Dense(h)(driver_returns[..., None])       # [B,k,1]→[B,k,h]
        query_token = nn.Dense(h)(query_state)[:, None, :]           # [B,1,h]

        # Positional embeddings: rows 0..k-1 index the (state, return) pairs, row k
        # marks the query token.
        pos_embeddings = nn.Embed(num_embeddings=k + 1, features=h)(jnp.arange(k + 1))  # [k+1,h]
        pair_positions = pos_embeddings[:-1, :]                       # [k,h]
        # Self-attention is permutation-equivariant, so the shared index embedding is
        # the only thing that ties return_i to state_i. Giving every return token one
        # common embedding would erase that pairing.
        state_tokens = state_tokens + pair_positions                 # [B,k,h] + [k,h] → [B,k,h]
        return_tokens = return_tokens + pair_positions               # [B,k,h] + [k,h] → [B,k,h]
        query_token = query_token + pos_embeddings[-1, None, :]      # [B,1,h] + [1,h] → [B,1,h]

        # Concatenate tokens to form the full sequence
        full_sequence = jnp.concatenate([state_tokens, return_tokens, query_token], axis=1)  # [B,2k+1,h]

        # Transformer: pre-LayerNorm residual blocks
        x = full_sequence
        for _ in range(self.num_layers):
            # Self-attention block
            y = nn.LayerNorm()(x)
            y = nn.SelfAttention(num_heads=self.num_heads, qkv_features=h)(y)
            x = x + y
            # Feed-forward block
            y = nn.LayerNorm()(x)
            y = nn.relu(nn.Dense(h)(y))
            y = nn.Dense(h)(y)
            x = x + y

        query_representation = x[:, -1, :]                           # last (query) token
        value_prediction = nn.Dense(1)(query_representation)         # [B,1]
        return jnp.squeeze(value_prediction, -1)                     # [B]


class ActorCriticDiscreteAction(nn.Module):
    """MLP policy + value network for discrete action spaces.

    Policy and value each use their own two-layer, 64-unit MLP on the raw
    observation (there is no shared encoder).

    Attributes:
        action_dim: number of discrete actions (size of the logits vector).
        activation: ``"relu"`` selects ReLU; any other value selects tanh.

    Note:
        Flax names the ``Dense`` layers by creation order. The policy trunk is
        ``Dense_0``/``Dense_1`` (copied by ``load_feat_extractor_params``), the
        logits layer is ``Dense_2`` and the value MLP is ``Dense_3``..``Dense_5``.
        Keep this order to stay compatible with existing parameter trees.
    """
    action_dim: int
    activation: str = "tanh"  # or "relu"

    @nn.compact
    def __call__(self, observation: jnp.ndarray):
        """Return ``(policy, value)``.

        ``policy`` is a ``distrax.Categorical`` over ``action_dim`` actions.
        ``value`` is the critic output with its trailing size-1 axis squeezed.
        """
        act_fn = nn.relu if self.activation == "relu" else nn.tanh
        # Initializers must be passed by keyword: the second positional field of
        # nn.Dense is ``use_bias``, so a positional initializer is silently taken as
        # a (truthy) use_bias flag and the default kernel init is used instead.
        hidden_init = orthogonal(np.sqrt(2))

        # Policy
        policy_h = act_fn(nn.Dense(64, kernel_init=hidden_init)(observation))
        policy_h = act_fn(nn.Dense(64, kernel_init=hidden_init)(policy_h))
        logits = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01))(policy_h)
        policy = distrax.Categorical(logits=logits)

        # Value
        value_h = act_fn(nn.Dense(64, kernel_init=hidden_init)(observation))
        value_h = act_fn(nn.Dense(64, kernel_init=hidden_init)(value_h))
        value = nn.Dense(1, kernel_init=orthogonal(1.0))(value_h)

        return policy, jnp.squeeze(value, -1)


class ActorCriticContinuousAction(nn.Module):
    """Separate MLP actor (diagonal Gaussian) and critic networks.

    Policy and value each use their own two-layer, 64-unit MLP on the raw
    observation (there is no shared encoder). The policy standard deviation comes
    from a state-independent ``log_std`` parameter initialised to zeros (std = 1).

    Attributes:
        action_dim: dimensionality of the continuous action.
        activation: ``"relu"`` selects ReLU; any other value selects tanh.

    Note:
        Flax names the ``Dense`` layers by creation order. The policy trunk is
        ``Dense_0``/``Dense_1`` (copied by ``load_feat_extractor_params``), the mean
        layer is ``Dense_2`` and the value MLP is ``Dense_3``..``Dense_5``. Keep
        this order to stay compatible with existing parameter trees.
    """
    action_dim: int
    activation: str = "tanh"  # or "relu"

    @nn.compact
    def __call__(self, observation: jnp.ndarray):
        """Return ``(policy, value)``.

        ``policy`` is a ``distrax.MultivariateNormalDiag`` with scale
        ``exp(log_std)``. ``value`` is the critic output with its trailing size-1
        axis squeezed.
        """
        act_fn = nn.relu if self.activation == "relu" else nn.tanh
        # Initializers must be passed by keyword: the second positional field of
        # nn.Dense is ``use_bias``, so a positional initializer is silently taken as
        # a (truthy) use_bias flag and the default kernel init is used instead.
        hidden_init = orthogonal(np.sqrt(2))

        # Policy head
        policy_h = act_fn(nn.Dense(64, kernel_init=hidden_init)(observation))
        policy_h = act_fn(nn.Dense(64, kernel_init=hidden_init)(policy_h))
        mean = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01))(policy_h)
        log_std = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
        policy = distrax.MultivariateNormalDiag(mean, jnp.exp(log_std))

        # Value head
        value_h = act_fn(nn.Dense(64, kernel_init=hidden_init)(observation))
        value_h = act_fn(nn.Dense(64, kernel_init=hidden_init)(value_h))
        value = nn.Dense(1, kernel_init=orthogonal(1.0))(value_h)

        return policy, jnp.squeeze(value, -1)


class FeatExtractorDiscreteAction(nn.Module):
    """Feature extractor for discrete action spaces.

    Has the same layers as the policy trunk (``Dense_0``, ``Dense_1``) of
    ``ActorCriticDiscreteAction``, so it can be applied with the parameters
    returned by ``load_feat_extractor_params``. Produces 64 features per
    observation (last axis).

    Attributes:
        activation: ``"relu"`` selects ReLU; any other value selects tanh.
    """
    activation: str = "tanh"  # or "relu"

    @nn.compact
    def __call__(self, observation: jnp.ndarray):
        """Map ``observation`` to its 64-dimensional trunk features."""
        act_fn = nn.relu if self.activation == "relu" else nn.tanh
        # kernel_init must be a keyword: nn.Dense's second positional field is use_bias.
        hidden_init = orthogonal(np.sqrt(2))
        x = act_fn(nn.Dense(64, kernel_init=hidden_init)(observation))
        x = act_fn(nn.Dense(64, kernel_init=hidden_init)(x))
        return x


class FeatExtractorContinuousAction(nn.Module):
    """Feature extractor for continuous action spaces.

    Has the same layers as the policy trunk (``Dense_0``, ``Dense_1``) of
    ``ActorCriticContinuousAction``, so it can be applied with the parameters
    returned by ``load_feat_extractor_params``. Produces 64 features per
    observation (last axis).

    Attributes:
        activation: ``"relu"`` selects ReLU; any other value selects tanh.
    """
    activation: str = "tanh"  # or "relu"

    @nn.compact
    def __call__(self, observation: jnp.ndarray):
        """Map ``observation`` to its 64-dimensional trunk features."""
        act_fn = nn.relu if self.activation == "relu" else nn.tanh
        # kernel_init must be a keyword: nn.Dense's second positional field is use_bias.
        hidden_init = orthogonal(np.sqrt(2))
        x = act_fn(nn.Dense(64, kernel_init=hidden_init)(observation))
        x = act_fn(nn.Dense(64, kernel_init=hidden_init)(x))
        return x


def load_feat_extractor_params(actor_critic_params):
    """Build a Flax variables dict for a ``FeatExtractor*`` module.

    Selects the first two ``Dense`` layers (the policy trunk) from an actor-critic
    parameter mapping and wraps them under a ``"params"`` key.

    Args:
        actor_critic_params: the actor-critic *params* mapping itself (with keys
            ``"Dense_0"``, ``"Dense_1"``, ...), not the outer variables dict.

    Returns:
        ``{"params": {"Dense_0": ..., "Dense_1": ...}}``. The layer entries are the
        original objects, not copies.

    Raises:
        KeyError: if either trunk layer is missing.
    """
    trunk_layer_names = ("Dense_0", "Dense_1")
    trunk = {layer: actor_critic_params[layer] for layer in trunk_layer_names}
    return {"params": trunk}


class PyTorchContinuousActor(torch.nn.Module):
    """PyTorch implementation of continuous action actor.

    Rebuilds the policy branch (``Dense_0`` -> ``Dense_1`` -> ``Dense_2`` mean, plus
    ``log_std``) of a Flax ``ActorCriticContinuousAction`` from its parameters.

    Args:
        action_dim: dimensionality of the action (stored; shapes come from the params).
        activation: ``"relu"`` selects ReLU; any other value selects tanh.
        jax_param_dict: Flax variables dict whose ``"params"`` entry holds ``Dense_*``
            sub-dicts (``kernel`` of shape ``[in, out]`` and ``bias``) and a
            ``log_std`` array.
        device: torch device for the copied tensors.
    """
    def __init__(self, action_dim, activation, jax_param_dict, device="cpu"):
        super().__init__()
        self.activation = activation
        self.action_dim = action_dim
        self.device = device

        for name, param in jax_param_dict['params'].items():
            if 'Dense' in name:
                kernel = torch.from_numpy(np.array(param['kernel'])).to(device)
                bias = torch.from_numpy(np.array(param['bias'])).to(device)
                # Flax kernels are [in, out]; torch.nn.Linear stores weight as [out, in].
                layer = torch.nn.Linear(kernel.shape[0], kernel.shape[1])
                layer.weight = torch.nn.Parameter(kernel.T.contiguous())
                layer.bias = torch.nn.Parameter(bias)
                setattr(self, name, layer)

            if 'log_std' in name:
                # A buffer (not a plain attribute) so Module.to() / state_dict() see it.
                self.register_buffer(
                    'log_std', torch.from_numpy(np.array(param)).to(device)
                )

    def forward(self, observation):
        """Return the action distribution for ``observation`` (``[..., obs_dim]``).

        The result has batch shape ``observation.shape[:-1]`` and event shape
        ``[action_dim]``, including when ``action_dim == 1``.
        """
        act_fn = torch.relu if self.activation == "relu" else torch.tanh

        policy_h = act_fn(self.Dense_0(observation))
        policy_h = act_fn(self.Dense_1(policy_h))
        mean = self.Dense_2(policy_h)                                # [..., action_dim]
        std = torch.exp(self.log_std).expand_as(mean)
        # exp(log_std) is a standard deviation, so it goes on the Cholesky factor
        # (scale_tril); the covariance would be std**2. This mirrors
        # distrax.MultivariateNormalDiag(mean, scale_diag=exp(log_std)).
        return torch.distributions.MultivariateNormal(
            loc=mean, scale_tril=torch.diag_embed(std)
        )


class PytorchDiscreteActor(torch.nn.Module):
    """PyTorch implementation of discrete action actor.

    Rebuilds the policy branch (``Dense_0`` -> ``Dense_1`` -> ``Dense_2`` logits) of a
    Flax ``ActorCriticDiscreteAction`` from its parameters.

    Args:
        action_dim: number of discrete actions (stored; shapes come from the params).
        activation: ``"relu"`` selects ReLU; any other value selects tanh.
        jax_param_dict: Flax variables dict whose ``"params"`` entry holds ``Dense_*``
            sub-dicts with ``kernel`` of shape ``[in, out]`` and ``bias``.
        device: torch device for the copied tensors (default ``"cpu"``).
    """
    def __init__(self, action_dim, activation, jax_param_dict, device="cpu"):
        super().__init__()
        self.action_dim = action_dim
        self.activation = activation
        self.device = device

        for name, param in jax_param_dict['params'].items():
            if 'Dense' not in name:
                continue
            kernel = torch.from_numpy(np.array(param['kernel'])).to(device)
            bias = torch.from_numpy(np.array(param['bias'])).to(device)
            # Flax kernels are [in, out]; torch.nn.Linear(in, out) stores weight as [out, in].
            layer = torch.nn.Linear(kernel.shape[0], kernel.shape[1])
            layer.weight = torch.nn.Parameter(kernel.T.contiguous())
            layer.bias = torch.nn.Parameter(bias)
            setattr(self, name, layer)

    def forward(self, observation):
        """Return a ``Categorical`` over actions for ``observation`` (``[..., obs_dim]``)."""
        act_fn = torch.relu if self.activation == "relu" else torch.tanh

        policy_h = act_fn(self.Dense_0(observation))
        policy_h = act_fn(self.Dense_1(policy_h))
        logits = self.Dense_2(policy_h)
        return torch.distributions.Categorical(logits=logits)