import jax
import jax.numpy as jnp
from jax import random

from typing import Any, Dict, List, Tuple, Union
from enum import IntEnum
from gymnasium import spaces
from environments.jax_entities import Agent, Goal
# Agent and Goal are project-local entity types (imported above); this module
# only constructs them when the environment is reset.

# Global elements
# Index, along the first axis of the environment grid, of the layer that marks
# agent positions. Goal layers are numbered from 1 upward (see `goal_layers`).
_LAYER_AGENTS = 0

class Actions(IntEnum):
    """Discrete moves an agent can select.

    `len(Actions)` is used to size the environment's action spaces.
    """

    # NOTE(review): how each direction maps onto grid axes is decided by the
    # step logic, which is not part of this section of the file.
    Up = 0
    Down = 1
    Left = 2
    Right = 3
    Noop = 4

class MOResourceGathering:
    """Multi-objective, multi-agent resource-gathering grid world on JAX arrays.

    The world is stored as a grid of shape ``(1 + reward_dim, height, width)``:
    layer ``_LAYER_AGENTS`` marks the cells occupied by agents and layers
    ``1..reward_dim`` mark the chests ("goals") of each objective. ``reset``
    samples a new layout from an explicit JAX PRNG key.

    The ``"objective_clusters"`` and ``"geographical_clusters"`` strategies are
    accepted by the constructor, but their sampling is not implemented here, so
    ``reset`` raises ``NotImplementedError`` for them.
    """

    # env_size -> (grid_size, min chests per objective, max chests per
    # objective, value stored in ``__max_steps__``).
    _SIZE_PRESETS = {
        "small": ((10, 10), 3, 7, 1e3),
        "medium": ((15, 15), 5, 12, 5e3),
        "large": ((20, 20), 10, 20, 2e4),
    }
    _SAMPLING_STRATEGIES = (
        "uniform",
        "debug",
        "objective_clusters",
        "geographical_clusters",
    )

    def __init__(
        self,
        env_size: str,
        sampling_strategy: str,
        n_agents: int,
        reward_dim: int,
        agents_finite_size_bags: bool,
        agents_bags_size: Union[int, List[int]],
        partial_observability: bool = False,
        centralized_controller: bool = False,
        sensor_range: int = 4,
        local_reward: bool = False,
        agent_specific_objectives: bool = False,
        agents_objectives: Union[List[List[int]], None] = None,
        normalize_reward: bool = True,
    ):
        """Configure the environment and build its action/observation spaces.

        Args:
            env_size: One of ``"small"``, ``"medium"``, ``"large"``; selects
                the grid size and the per-objective chest-count range.
            sampling_strategy: One of ``"uniform"``, ``"debug"``,
                ``"objective_clusters"``, ``"geographical_clusters"``.
            n_agents: Number of agents.
            reward_dim: Number of objectives (one grid layer each).
            agents_finite_size_bags: Stored on the instance; not otherwise
                used by the code in this class.
            agents_bags_size: Bag size per agent, passed to each ``Agent``.
                A single int is applied to every agent.
            partial_observability: If true, each agent gets its own
                observation ``Box`` sized from ``sensor_range``; otherwise a
                single ``Box`` covers the whole grid.
            centralized_controller: If true, the action space is one
                ``Discrete(len(Actions) ** n_agents)``; otherwise a ``Tuple``
                of one ``Discrete(len(Actions))`` per agent.
            sensor_range: Only kept when ``partial_observability`` is true
                (stored as 0 otherwise).
            local_reward: Stored on the instance; not otherwise used here.
            agent_specific_objectives: If true, ``agents_objectives`` gives
                each agent's objective indices; otherwise every agent gets all
                objectives.
            agents_objectives: Per-agent lists of objective indices. Required
                (one list per agent) when ``agent_specific_objectives`` is true.
            normalize_reward: Stored on the instance; not otherwise used here.

        Raises:
            ValueError: Unknown ``env_size`` or ``sampling_strategy``, or
                ``agent_specific_objectives`` is set without one objective list
                per agent.
        """
        self.current_step = 0
        self.env_size = env_size
        if env_size not in self._SIZE_PRESETS:
            raise ValueError("Unknown environment size")
        (
            self.grid_size,
            self.min_reward_amplitude,
            self.max_reward_amplitude,
            self.__max_steps__,
        ) = self._SIZE_PRESETS[env_size]

        if sampling_strategy not in self._SAMPLING_STRATEGIES:
            raise ValueError("Unknown sampling strategy")
        self.sampling_strategy = sampling_strategy

        self.n_agents = n_agents
        self.reward_dim = reward_dim
        self.partial_observability = partial_observability
        self.centralized_controller = centralized_controller
        self.local_reward = local_reward
        self.agent_specific_objectives = agent_specific_objectives
        # sensor range is only used when partial observability is enabled
        self.sensor_range = sensor_range if partial_observability else 0
        self.normalize_reward = normalize_reward

        # Chest count per objective; the real values are drawn in reset().
        self.reward_amplitude = jnp.zeros(self.reward_dim)
        self.agents_finite_size_bags = agents_finite_size_bags
        # Accept a single bag size and apply it to every agent.
        if isinstance(agents_bags_size, int):
            agents_bags_size = [agents_bags_size] * n_agents
        self.agents_bags_size = agents_bags_size

        # gymnasium's Box infers its shape only from NumPy arrays or scalars,
        # so JAX arrays cannot be used as bounds: pass scalars plus a shape.
        self.reward_space = spaces.Box(
            low=0.0,
            high=float(self.max_reward_amplitude),
            shape=(reward_dim,),
            dtype=jnp.float32,
        )

        # Grid layer index of each objective (layer 0 is the agent layer).
        self.goal_layers = list(range(1, 1 + self.reward_dim))

        # Agent objectives setup
        if agent_specific_objectives:
            if agents_objectives is None or len(agents_objectives) < n_agents:
                raise ValueError(
                    "agent_specific_objectives requires one objective list per agent"
                )
            self.agents_objectives = agents_objectives
        else:
            self.agents_objectives = [
                list(range(self.reward_dim)) for _ in range(n_agents)
            ]

        self.grid = jnp.zeros((1 + self.reward_dim, *self.grid_size))

        # Setup action space
        n_actions = len(Actions)
        if self.centralized_controller:
            self.action_space = spaces.Discrete(n_actions ** n_agents)
        else:
            self.action_space = spaces.Tuple(
                tuple(spaces.Discrete(n_actions) for _ in range(n_agents))
            )

        # Observation space setup. All dimensions are plain Python ints so
        # they can be used directly as Box shapes.
        if self.partial_observability:
            window_cells = (self.sensor_range + 1) * (self.sensor_range + 1)
            if agent_specific_objectives:
                self.state_space_dim = [
                    window_cells * (1 + len(self.agents_objectives[i]))
                    + len(self.agents_objectives[i])
                    + 3
                    for i in range(self.n_agents)
                ]
                self.observation_space = spaces.Tuple(
                    tuple(self._unit_box(dim) for dim in self.state_space_dim)
                )
            else:
                self.state_space_dim = (
                    window_cells * (1 + self.reward_dim) + self.reward_dim + 3
                )
                self.observation_space = spaces.Tuple(
                    tuple(
                        self._unit_box(self.state_space_dim)
                        for _ in range(n_agents)
                    )
                )
        else:
            height, width = self.grid_size
            self.state_space_dim = (
                height * width * (1 + self.reward_dim)
                + (self.reward_dim + 1) * self.n_agents
            )
            self.observation_space = self._unit_box(self.state_space_dim)

        # Populated by reset().
        self.agents = []
        self.goals = [[] for _ in range(self.reward_dim)]

        self.agent_order = list(range(n_agents))

        self.collected_chests = jnp.zeros(self.reward_dim)

    @staticmethod
    def _unit_box(dim: int):
        """Return a flat ``Box`` of length ``dim`` bounded by 0 and 1."""
        return spaces.Box(low=0.0, high=1.0, shape=(dim,))

    @staticmethod
    def _sample_distinct_coordinates(
        grid_size: Tuple[int, int], count: int, key: jnp.ndarray
    ) -> jnp.ndarray:
        """Sample ``count`` distinct interior ``(i, j)`` cells of the grid.

        Interior means ``1 <= i <= height - 2`` and ``1 <= j <= width - 2``.

        NOTE(review): reset() does not call this at present; it looks intended
        for centroids of the cluster-based strategies -- confirm.

        Raises:
            ValueError: If ``count`` exceeds the number of interior cells.
        """
        height, width = grid_size
        rows, cols = max(height - 2, 0), max(width - 2, 0)
        if count > rows * cols:
            raise ValueError(
                "Cannot sample more distinct coordinates than available in the grid."
            )
        # Permuting flat cell indices guarantees distinct cells.
        flat = jax.random.permutation(key, rows * cols)[:count]
        cols = max(cols, 1)  # avoid a zero divisor when the interior is empty
        return jnp.stack([1 + flat // cols, 1 + flat % cols], axis=1)

    @staticmethod
    def _uniform_positions(
        grid_size: Tuple[int, int],
        reward_amplitude: jnp.ndarray,
        n_agents: int,
        key: jnp.ndarray,
    ) -> Tuple[List[jnp.ndarray], jnp.ndarray]:
        """Draw distinct, uniformly random cells for every chest and agent.

        Args:
            grid_size: ``(height, width)`` of the grid.
            reward_amplitude: Number of chests to place for each objective.
            n_agents: Number of agent start cells to draw.
            key: JAX PRNG key.

        Returns:
            ``(chests, agents)`` where ``chests[i]`` is an ``(amplitude_i, 2)``
            array of ``(x, y)`` cells for objective ``i`` and ``agents`` is an
            ``(n_agents, 2)`` array of ``(x, y)`` cells. No cell is repeated.

        Raises:
            ValueError: If the grid has fewer cells than chests plus agents.
        """
        height, width = grid_size
        counts = [int(c) for c in reward_amplitude]
        n_chests = sum(counts)
        needed = n_chests + n_agents
        if needed > height * width:
            raise ValueError(
                "Grid is too small to place all chests and agents on distinct cells."
            )
        # One permutation of the flat cell indices yields every position at
        # once and makes them distinct by construction.
        flat = jax.random.permutation(key, height * width)[:needed]
        cells = jnp.stack([flat % width, flat // width], axis=1)

        chests = []
        offset = 0
        for count in counts:
            chests.append(cells[offset:offset + count])
            offset += count
        return chests, cells[n_chests:needed]

    def _sample_positions(
        self, reward_amplitude: jnp.ndarray, key: jnp.ndarray
    ) -> Tuple[List[jnp.ndarray], jnp.ndarray, jnp.ndarray]:
        """Pick chest and agent cells according to ``self.sampling_strategy``.

        Returns:
            ``(chests, agents, amplitude)``: per-objective ``(k, 2)`` arrays of
            ``(x, y)`` chest cells, an ``(n_agents, 2)`` array of agent cells,
            and the number of chests actually placed per objective.

        Raises:
            ValueError: The debug layout cannot serve the configured number of
                objectives/agents, or the grid is too small.
            NotImplementedError: The strategy has no sampling implementation.
        """
        if self.sampling_strategy == "debug":
            # Fixed layout: three chests per objective in three grid corners
            # and agents near the origin. Coordinates go up to 9, which fits
            # every size preset (the smallest grid is 10x10).
            layout = jnp.array(
                [
                    [[9, 9], [9, 8], [8, 9]],
                    [[0, 9], [0, 8], [1, 9]],
                    [[9, 0], [9, 1], [8, 0]],
                ]
            )
            starts = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
            # JAX clamps out-of-range indices instead of raising, so guard
            # explicitly rather than silently duplicating positions.
            if self.reward_dim > layout.shape[0] or self.n_agents > starts.shape[0]:
                raise ValueError(
                    "The debug layout supports at most 3 objectives and 4 agents."
                )
            chests = [layout[i] for i in range(self.reward_dim)]
            amplitude = jnp.array([layout.shape[1]] * self.reward_dim)
            return chests, starts[:self.n_agents], amplitude

        if self.sampling_strategy == "uniform":
            chests, agent_cells = self._uniform_positions(
                self.grid_size, reward_amplitude, self.n_agents, key
            )
            return chests, agent_cells, reward_amplitude

        raise NotImplementedError(
            f"Sampling strategy '{self.sampling_strategy}' is not implemented"
        )

    def reset(self, key: jnp.ndarray) -> Tuple[Union[Tuple[Any], jnp.ndarray], Dict[str, Any]]:
        """Sample a new layout and reset the episode state.

        Args:
            key: JAX PRNG key; it is split internally so that the chest counts
                and the positions are drawn from independent streams.

        Returns:
            ``(observation, info)`` where ``info`` holds ``"total_demand"``
            (chests placed per objective) and ``"agents_bags_size"``.

        Raises:
            ValueError: The layout cannot be placed (see ``_sample_positions``).
            NotImplementedError: The sampling strategy is not implemented.
        """
        key, subkey = jax.random.split(key)
        reward_amplitude = jax.random.randint(
            subkey,
            shape=(self.reward_dim,),
            minval=self.min_reward_amplitude,
            maxval=self.max_reward_amplitude + 1,
        )

        key, subkey = jax.random.split(key)
        chests_positions, agent_positions, reward_amplitude = self._sample_positions(
            reward_amplitude, subkey
        )

        # The grid is indexed [layer, y, x]; positions are stored as (x, y).
        grid = jnp.zeros((1 + self.reward_dim, *self.grid_size))
        goals = []
        for layer, chests in zip(self.goal_layers, chests_positions):
            goals.append([Goal("goal", x, y) for x, y in chests])
            grid = grid.at[layer, chests[:, 1], chests[:, 0]].set(1)

        agents = []
        for i in range(self.n_agents):
            x, y = agent_positions[i]
            agents.append(
                Agent(i, x, y, len(self.agents_objectives[i]), self.agents_bags_size[i])
            )
        grid = grid.at[_LAYER_AGENTS, agent_positions[:, 1], agent_positions[:, 0]].set(1)

        self.grid = grid
        self.agents = agents
        self.goals = goals
        self.reward_amplitude = reward_amplitude
        self.current_step = 0
        self.collected_chests = jnp.zeros(self.reward_dim)

        # NOTE(review): `_get_obs` is not defined in this class as shown; it
        # must be supplied (e.g. by a subclass) before reset() can return.
        return self._get_obs(), {
            "total_demand": self.reward_amplitude,
            "agents_bags_size": self.agents_bags_size,
        }
    

if __name__ == "__main__":
    # Manual smoke test: build a medium, partially observable environment with
    # 5 agents and 4 objectives, then print the result of one reset.
    # reset() requires an explicit PRNG key, and bag sizes are given per agent.
    e = MOResourceGathering("medium", "uniform", 5, 4, True, [4] * 5, True)
    print(e.reset(jax.random.PRNGKey(0)))