"""
AppleGridMDP Environment
────────────────────────
A multi-agent apple-harvesting gridworld used to evaluate QMIX agents.

Key features:
- Each step, every agent moves one cell in one of 4 directions (up, down,
  left, right); an agent that lands on an apple harvests it for a reward
  of 1.
- Apples regrow probabilistically on fixed spawn sites: the per-site
  regrowth probability is rate * (current apples / threshold), so regrowth
  slows as apples run out and stops once the threshold count is reached.
- Three disruption protocols can be triggered during evaluation runs:
    1. Resource removal      -> trigger_disruption(magnitude)
    2. Regeneration rate cut -> trigger_disruption_rate(start=True/False)
    3. Agent randomization   -> trigger_disruption_agent(start=True/False)

These methods are designed to be called at specific timesteps in the 
evaluation loop to simulate shocks to the system. The environment is 
ready to be integrated into QMIX testing pipelines.
"""

import numpy as np
import random 
from typing import List, Tuple

Action = int  # 0-up, 1-down, 2-left, 3-right
MOVE_DELTAS = {
    0: (-1, 0),  # up
    1: (1, 0),   # down
    2: (0, -1),  # left
    3: (0, 1),   # right
}

class AppleGridMDP:
    def __init__(self, grid_size=(8, 8), regen_threshold=16, rate_regen=0.05, 
                 n_agents=2, episode_limit=1000, seed: int | None = None):
        self.grid_size = grid_size
        self.regen_threshold = regen_threshold
        self.rate_regen = rate_regen
        self.agent_random_start = False
        self.n_agents = n_agents
        self.episode_limit = episode_limit
        self._rng = random.Random(seed)
        np.random.seed(seed)

        # Apple spawn sites
        self.apple_sites: List[Tuple[int, int]] = [
            (2, 3), (2, 4),
            (3, 1), (3, 2), (3, 3), (3, 4), (3, 5), (3, 6),
            (4, 1), (4, 2), (4, 3), (4, 4), (4, 5), (4, 6),
            (5, 4), (5, 3),
        ]

        # Default agent positions (expanded if n_agents > 2)
        self._initial_agent_positions: List[Tuple[int, int]] = [
            (1, 1),
            (self.grid_size[0] - 2, self.grid_size[1] - 2),
        ]
        while len(self._initial_agent_positions) < self.n_agents:
            x, y = np.random.randint(0, 8, size=2).tolist()
            if (x, y) not in self._initial_agent_positions and (x, y) not in self.apple_sites:
                self._initial_agent_positions.append((x, y))

        self.reset()

    @property
    def n_actions(self) -> int:
        return 4
    
    def get_state(self):
        agent_state = np.array(self.agent_positions).flatten()
        grid_state = np.array(self.grid).flatten()
        return np.concatenate([agent_state, grid_state])

    def move_agent(self, agent_idx, action):
        x, y = self.agent_positions[agent_idx]
        new_x, new_y = x, y
        if action == 0 and x > 0: new_x -= 1
        elif action == 1 and x < self.grid_size[0] - 1: new_x += 1
        elif action == 2 and y > 0: new_y -= 1
        elif action == 3 and y < self.grid_size[1] - 1: new_y += 1
        if (new_x, new_y) not in self.agent_positions:
            self.agent_positions[agent_idx] = (new_x, new_y)

    def step(self, actions: List[Action]):
        assert len(actions) == self.n_agents

        per_agent_reward = np.zeros(self.n_agents, dtype=np.float32)

        # Move agents
        new_positions = self.agent_positions.copy()
        for i, action in enumerate(actions):
            if self.agent_random_start:
                dx, dy = MOVE_DELTAS.get(random.randint(0, 3), (0, 0))
            else:
                dx, dy = MOVE_DELTAS.get(action, (0, 0))
            x, y = self.agent_positions[i]
            nx, ny = x + dx, y + dy
            nx = min(max(nx, 0), self.grid_size[0] - 1)
            ny = min(max(ny, 0), self.grid_size[1] - 1)
            if (nx, ny) not in new_positions:
                new_positions[i] = (nx, ny)
        self.agent_positions = new_positions
        
        # Harvest apples
        for i, (x, y) in enumerate(self.agent_positions):
            if self.grid[int(x), int(y)] == 1:
                per_agent_reward[i] = 1.0
                self.grid[int(x), int(y)] = 0

        # Regeneration
        current_apples = int(self.grid.sum())
        missing = self.regen_threshold - current_apples
        if missing > 0 and self.regen_threshold > 0:
            regen_prob = self.rate_regen * (current_apples / self.regen_threshold)
            available_sites = [s for s in self.apple_sites if self.grid[s] == 0 and s not in self.agent_positions]
            self._rng.shuffle(available_sites)
            grown = 0
            for site in available_sites:
                if grown >= missing:
                    break
                if self._rng.random() < regen_prob:
                    self.grid[site] = 1
                    grown += 1

        self.t += 1
        terminated = bool(self.t >= self.episode_limit or self.grid.sum() == 0)
        return per_agent_reward, float(per_agent_reward.sum()), terminated, {}

    def render(self):
        grid_display = self.grid.astype(str)
        for i, (x, y) in enumerate(self.agent_positions):
            grid_display[x, y] = '#' if i == 0 else '*'
        print("\n".join(" ".join(row) for row in grid_display))
        print()

    def reset(self):
        self.grid = np.zeros((self.grid_size[0], self.grid_size[1]), dtype=np.int8)
        for x, y in self.apple_sites:
            self.grid[x, y] = 1
        self.agent_positions = self._initial_agent_positions.copy()
        self.t = 0
        return self.get_obs()

    def trigger_disruption(self, magnitude=0.4):
        current = [(x, y) for x, y in self.apple_sites if self.grid[x, y] == 1]
        if len(current) > 1:
            to_remove = random.sample(current, int(np.ceil(len(current) * magnitude)))
            for x, y in to_remove:
                self.grid[x, y] = 0

    def trigger_disruption_rate(self, start=True):
        if start:
            self.rate_regen -= 0.01
        else:
            self.rate_regen += 0.01

    def trigger_disruption_agent(self, start=True):
        self.agent_random_start = start

    def set_state(self, state):
        self.agent_positions = [(state[0], state[1]), (state[2], state[3])]
        grid_state = np.array(state[4:]).reshape(self.grid_size)
        if grid_state.shape != self.grid.shape:
            raise ValueError("Grid size mismatch.")
        self.grid = grid_state.copy()
    
    def get_obs(self) -> List[np.ndarray]:
        return [self.get_obs_agent(a) for a in range(self.n_agents)]

    def get_obs_agent(self, agent_id: int) -> np.ndarray:
        pos = np.array(self.agent_positions[agent_id], dtype=np.float32)
        return np.concatenate((self.grid.flatten(), pos))

    def get_obs_size(self) -> int:
        return int(self.grid_size[0] ** 2 + 2)

    def get_state(self) -> np.ndarray:
        flat_pos = np.array(self.agent_positions).flatten().astype(np.float32)
        return np.concatenate((flat_pos, self.grid.flatten())).astype(np.float32)

    def get_state_size(self) -> int:
        return int(self.n_agents * 2 + self.grid_size[0] ** 2)

    def get_avail_actions(self) -> List[np.ndarray]:
        return [self.get_avail_actions_agent(i) for i in range(self.n_agents)]

    def get_avail_actions_agent(self, agent_id: int) -> np.ndarray:
        x, y = self.agent_positions[agent_id]
        avail = np.ones(self.n_actions, dtype=np.int8)
        if int(x) == 0: avail[0] = 0
        if int(x) == self.grid_size[0] - 1: avail[1] = 0
        if int(y) == 0: avail[2] = 0
        if int(y) == self.grid_size[1] - 1: avail[3] = 0
        return avail

    def get_total_actions(self) -> int:
        return self.n_actions
