from typing import Sequence, Tuple

import chex
import jax
import numpy as np
from flax import linen as nn
from flax.linen.initializers import Initializer, orthogonal
import jax.numpy as jnp
from stoix.base_types import Observation
from stoix.networks.torso import CNNTorso, MLPTorso
from stoix.networks.utils import parse_activation_fn
from jumanji.environments.logic.rubiks_cube.constants import Face

class CNN_TO_MLP(nn.Module):
    """Convolutional torso followed by an MLP torso.

    Attributes are forwarded unchanged to the two torsos:
    `channel_sizes`, `kernel_sizes`, `strides` and `activation` go to
    `CNNTorso`; `mlp_sizes`, `activation` and `use_layer_norm` go to
    `MLPTorso`.
    """

    channel_sizes: Sequence[int]
    kernel_sizes: Sequence[int]
    strides: Sequence[int]
    mlp_sizes: Sequence[int]
    activation: str = "relu"
    use_layer_norm: bool = False

    @nn.compact
    def __call__(self, observation: chex.Array) -> chex.Array:
        """Run the observation through the CNN torso and then the MLP torso.

        Args:
            observation: input array. When it has fewer than four
                dimensions a trailing axis of size one is appended before
                the convolutions (presumably the channel axis — confirm
                against `CNNTorso`).

        Returns:
            The output of the MLP torso.
        """
        # Inputs with fewer than four axes get an extra trailing axis.
        needs_extra_axis = observation.ndim < 4
        inputs = observation[..., None] if needs_extra_axis else observation

        conv_features = CNNTorso(
            self.channel_sizes,
            self.kernel_sizes,
            self.strides,
            self.activation,
        )(inputs)

        return MLPTorso(
            self.mlp_sizes,
            self.activation,
            self.use_layer_norm,
        )(conv_features)


def _paint_cells(
    layers: Tuple[chex.Array, ...],
    locations: chex.Array,
    colour: Sequence,
) -> Tuple[chex.Array, ...]:
    """Set cell `[loc[1], loc[0]]` of every layer, for each `loc` in `locations`.

    Args:
        layers: the colour layers, one array per channel.
        locations: sequence of locations; element 0 of each location is used
            as the second index and element 1 as the first index.
        colour: one value per layer, written into that layer.

    Returns:
        The updated layers as a tuple, in the same order.
    """
    # Use `len` rather than truthiness: the truth value of an array is
    # ambiguous. An empty sequence leaves the layers untouched.
    if len(locations) == 0:
        return tuple(layers)
    locations = jnp.asarray(locations)
    rows, cols = locations[:, 1], locations[:, 0]
    # One scatter per layer; every listed cell receives the same value, so
    # duplicate locations are harmless.
    return tuple(
        layer.at[rows, cols].set(value) for layer, value in zip(layers, colour)
    )


def process_pac_man_image(observation: Observation) -> chex.Array:
    """Process the `Observation` to be usable by the critic model.

    Each colour layer starts as a scaled copy of `observation.grid`
    (0.66, 0.0 and 0.33 for red, green and blue) and is then painted, in this
    order, with pellets, power-ups, the player and the ghosts. Later paints
    overwrite earlier ones. Cell `[0, 0]` is cleared at the very end.

    Args:
        observation: the observation as returned by the environment. The
            fields read are `grid`, `pellet_locations`,
            `power_up_locations`, `player_locations` (`.x` and `.y`),
            `ghost_locations` and `frightened_state_time`.

    Returns:
        rgb: a 2D, RGB image of the current observation: the three colour
            layers stacked along a new trailing axis.
    """
    grid = jnp.array(observation.grid)
    layers = (grid * 0.66, grid * 0.0, grid * 0.33)

    # Pellets are light orange.
    # The former `jnp.array(idx[i]).sum != 0` guard compared a bound method
    # with 0 and was therefore always true, so every listed pellet is painted.
    # Entries equal to (0, 0) are undone by the final reset of cell [0, 0] —
    # presumably that is how consumed pellets are stored; confirm with the
    # environment.
    layers = _paint_cells(layers, observation.pellet_locations, (0.6, 0.8, 1))

    # Power pellet is purple.
    layers = _paint_cells(layers, observation.power_up_locations, (0.5, 0, 0.5))

    # Player is yellow. Note the index order is [x, y] here, unlike the
    # [loc[1], loc[0]] order used for every other entity.
    player = observation.player_locations
    layers = tuple(
        layer.at[player.x, player.y].set(value)
        for layer, value in zip(layers, (1, 1, 0))
    )

    # Only the first element of `frightened_state_time` is used; `ravel`
    # lets a scalar (0-d) value be accepted as well as an array.
    frightened_time = jnp.ravel(jnp.asarray(observation.frightened_state_time))[0]
    scatter = frightened_time / 60

    # The first four ghosts all share one colour: red, with the green and blue
    # components raised by the frightened-time ratio.
    ghosts = jnp.asarray(observation.ghost_locations)[:4]
    layers = _paint_cells(layers, ghosts, (1, scatter, scatter))

    # Clear cell [0, 0] in every layer.
    layers = tuple(layer.at[0, 0].set(0) for layer in layers)

    return jnp.stack(list(layers), axis=-1)


class PacManNetwork(nn.Module):
    """Pac-Man torso: CNN over the rendered image plus scalar features, then an MLP.

    `channel_sizes`, `kernel_sizes`, `strides` and `activation` configure the
    `CNNTorso`; `mlp_sizes`, `activation` and `use_layer_norm` configure the
    `MLPTorso` applied to the concatenated features.
    """

    channel_sizes: Sequence[int]
    kernel_sizes: Sequence[int]
    strides: Sequence[int]
    mlp_sizes: Sequence[int]
    activation: str = "relu"
    use_layer_norm: bool = False

    @nn.compact
    def __call__(self, observation: Observation) -> chex.Array:
        """Embed a Pac-Man observation.

        Args:
            observation: the environment observation. `ghost_locations` is
                indexed with three axes below, so it is expected to have
                rank 3 (presumably a leading batch axis — confirm).

        Returns:
            The output of the MLP torso applied to the concatenation of the
            CNN embedding, player position, ghost coordinates and the
            frightened time divided by 60.
        """
        # NOTE(review): `process_pac_man_image` is applied to the observation
        # as a whole, not per example via `jax.vmap`, although the slicing
        # further down treats the observation as batched. Confirm that the
        # image is really built the way it is intended for batched inputs.
        image = process_pac_man_image(observation)

        # Player position, with x and y as the two entries of the last axis.
        player = observation.player_locations
        player_xy = jnp.stack([player.x, player.y], axis=-1)

        # Frightened time scaled by 60, with a trailing feature axis.
        frightened = (observation.frightened_state_time / 60)[..., None]

        # First and second coordinate of every ghost.
        ghosts = observation.ghost_locations
        ghost_first = ghosts[:, :, 0]
        ghost_second = ghosts[:, :, 1]

        cnn_torso = CNNTorso(
            self.channel_sizes,
            self.kernel_sizes,
            self.strides,
            self.activation,
        )
        image_embedding = cnn_torso(image)

        features = jnp.concatenate(
            [image_embedding, player_xy, ghost_first, ghost_second, frightened],
            axis=-1,
        )  # (B, H+...)

        mlp_torso = MLPTorso(self.mlp_sizes, self.activation, self.use_layer_norm)
        return mlp_torso(features)
    
    
class TetrisNetwork(nn.Module):
    """Tetris embedding: CNN over the grid, MLP over the tetromino, plus step count.

    `channel_sizes`, `kernel_sizes`, `strides` and `activation` configure the
    (non-flattening) `CNNTorso`; `mlp_sizes`, `activation` and
    `use_layer_norm` configure the `MLPTorso` applied to the tetromino.
    `time_limit` is the divisor used to normalise `observation.step_count`;
    it defaults to 400, the value that used to be hard-coded, and is named to
    match `MazeNetwork.time_limit`.
    """

    channel_sizes: Sequence[int]
    kernel_sizes: Sequence[int]
    strides: Sequence[int]
    mlp_sizes: Sequence[int]
    activation: str = "relu"
    use_layer_norm: bool = False
    time_limit: int = 400

    @nn.compact
    def __call__(self, observation: Observation) -> chex.Array:
        """Embed a batch of Tetris observations.

        Args:
            observation: observation exposing `grid`, `tetromino` and
                `step_count`. The grid embedding is transposed as a rank-4
                array, so `grid` is expected to carry a leading batch axis.

        Returns:
            A 2D array with one row per batch element: the flattened grid
            embedding, the flattened tetromino embedding and the normalised
            step count, concatenated along the last axis.
        """
        grid_net = CNNTorso(
            self.channel_sizes,
            self.kernel_sizes,
            self.strides,
            self.activation,
            flatten=False,
        )
        # A trailing axis of size one is appended to the float grid before the
        # convolutions. Shape with the reference config is presumably
        # [B, 2, 10, 64] — TODO confirm.
        grid_embeddings = grid_net(observation.grid.astype(float)[..., None])
        batch_size = grid_embeddings.shape[0]

        # Swap axes 1 and 2, then flatten everything after the batch axis.
        # The previous intermediate reshape to [B, N, -1] was immediately
        # flattened again, so a single reshape yields the same layout.
        grid_embeddings = jnp.transpose(grid_embeddings, [0, 2, 1, 3])
        grid_embeddings = jnp.reshape(grid_embeddings, [batch_size, -1])

        tetromino_net = MLPTorso(self.mlp_sizes, self.activation, self.use_layer_norm)
        tetromino_embeddings = tetromino_net(
            observation.tetromino.astype(float)
        ).reshape(batch_size, -1)

        norm_step_count = (observation.step_count / self.time_limit).reshape(
            batch_size, -1
        )

        # All three parts are [B, -1], so the concatenation is already flat.
        return jnp.concatenate(
            [grid_embeddings, tetromino_embeddings, norm_step_count], axis=-1
        )
    
    
   
class EmbeddingRubiks(nn.Module):
    """Rubik's Cube embedding: sticker embedding plus step-count embedding, then an MLP.

    Attributes:
        cube_embedding_dim: feature size of the `nn.Embed` applied to each
            cube entry.
        step_count_embed_dim: output size of the `nn.Dense` applied to the
            normalised step count.
        mlp_sizes: layer sizes forwarded to `MLPTorso`.
        time_limit: divisor used to normalise `observation.step_count`;
            defaults to 200, the value that used to be hard-coded.
    """

    cube_embedding_dim: int
    step_count_embed_dim: int
    mlp_sizes: Sequence[int]
    time_limit: int = 200

    @nn.compact
    def __call__(
        self,
        observation: Observation,
    ) -> chex.Array:
        """Embed a Rubik's Cube observation.

        Args:
            observation: observation exposing `cube` and `step_count`. The
                last three axes of `cube` are flattened together with the
                embedding features; any axes before them are kept as batch
                axes. `step_count` is expected to have those same batch axes.

        Returns:
            The output of `MLPTorso` applied to the concatenated cube and
            step-count embeddings.
        """
        cube = observation.cube
        batch_shape = cube.shape[:-3]

        # Cube embedding: one vector per entry, flattened per batch element.
        cube_embedding = nn.Embed(
            num_embeddings=len(Face), features=self.cube_embedding_dim
        )(cube).reshape(*batch_shape, -1)

        # Step count embedding. `[..., None]` adds the feature axis after all
        # batch axes, matching the batch handling of the cube reshape above;
        # for a 1D step count it is identical to the former `[:, None]`.
        step_count_embedding = nn.Dense(self.step_count_embed_dim)(
            observation.step_count[..., None] / self.time_limit
        )

        embedding = jnp.concatenate([cube_embedding, step_count_embedding], axis=-1)

        return MLPTorso(self.mlp_sizes)(embedding)


def process_observation(observation: Observation) -> chex.Array:
    """Mark the agent and the target on a float copy of the walls array.

    Args:
        observation: observation exposing `walls`, `agent_position` and
            `target_position`; both positions are converted to index tuples.

    Returns:
        The walls cast to float, with the agent cell set to 2 and the target
        cell set to 3 (the target is written last), plus a trailing channel
        axis of size one.
    """
    agent_marker, target_marker = 2, 3
    agent_index = tuple(observation.agent_position)
    target_index = tuple(observation.target_position)

    board = observation.walls.astype(float)
    board = board.at[agent_index].set(agent_marker)
    board = board.at[target_index].set(target_marker)

    # Append the channels axis.
    return board[..., None]

class MazeNetwork(nn.Module):
    """Maze torso: 3x3 convolutions over the processed grid, step count, then an MLP.

    Attributes:
        conv_n_channels: number of features of each successive `nn.Conv`.
        mlp_units: layer sizes forwarded to `MLPTorso`.
        time_limit: divisor used to normalise `observation.step_count`.
    """

    conv_n_channels: Sequence[int]
    mlp_units: Sequence[int]
    time_limit: int

    @nn.compact
    def __call__(self, observation: Observation) -> chex.Array:
        """Embed a batch of maze observations.

        Args:
            observation: batched observation; `process_observation` is mapped
                over its leading axis with `jax.vmap`.

        Returns:
            The output of `MLPTorso` applied to the flattened convolutional
            features concatenated with the normalised step count.
        """
        # Per-example grid with a trailing channel axis, presumably (B, G, G, 1).
        feature_map = jax.vmap(process_observation)(observation)

        # Stack of 3x3 convolutions, each followed by a ReLU.
        for n_channels in self.conv_n_channels:
            conv = nn.Conv(features=n_channels, kernel_size=(3, 3))
            feature_map = nn.relu(conv(feature_map))

        # Flatten everything after the batch axis: (B, H).
        flat_features = jnp.reshape(feature_map, (feature_map.shape[0], -1))

        # Step count scaled by the time limit, with a feature axis: (B, 1).
        step_fraction = jnp.expand_dims(observation.step_count, axis=-1)
        step_fraction = step_fraction / self.time_limit

        joined = jnp.concatenate([flat_features, step_fraction], axis=-1)  # (B, H+1)

        return MLPTorso(self.mlp_units)(joined)