"""
AppleGridMDP16 Environment
──────────────────────────
This environment defines a 16×16 grid with 4 agents and 3 apple trees.
- Agents move on the grid and collect apples for rewards; eating the very last apple left on the grid is heavily penalized.
- Each tree consists of a fixed cluster of apple positions.
- Apples regenerate probabilistically as long as their tree is still alive (has at least one apple left).
- Three disruption protocols are available (e.g., apple removal).
This environment is designed for evaluating PPO agents
under larger-scale and disruptive conditions.
"""

import numpy as np
import random


class AppleGridMDP16:
    """Multi-agent apple-harvesting grid world (16x16, 4 agents, 3 trees).

    The grid is indexed as ``grid[x, y]``. A cell value of 1 marks an apple
    and 0 an empty cell.

    State vector layout (produced by ``get_state``, consumed by
    ``set_state``): the ``(x, y)`` coordinates of every agent, flattened in
    agent order, followed by the row-major flattened grid.

    Actions (see ``move_agent``): 0 -> x - 1, 1 -> x + 1, 2 -> y - 1,
    3 -> y + 1. Any other value leaves the agent where it is.
    """

    def __init__(self, grid_size=(16, 16), regen_threshold=16, rate_regen=0.05):
        """Build the apple layout, draw agent start cells and reset.

        Args:
            grid_size: ``(rows, cols)`` of the grid. The hard-coded tree
                cells below reach index 15 on both axes, so the grid must be
                at least 16x16.
            regen_threshold: Stored on the instance; not read by any method
                of this class.
            rate_regen: Scale of the per-step probability that a tree regrows
                one apple; it is multiplied by the fraction of the tree's
                apples still present (see ``step``).
        """
        self.grid_size = grid_size
        self.regen_threshold = regen_threshold
        self.rate_regen = rate_regen

        # Define three apple clusters (trees)
        self.apple_groups = [
            [  # Tree 1
                (2, 7), (2, 8),
                (3, 5), (3, 6), (3, 7), (3, 8), (3, 9), (3, 10),
                (4, 5), (4, 6), (4, 7), (4, 8), (4, 9), (4, 10),
                (5, 7), (5, 8),
            ],
            [  # Tree 2
                (7, 2), (7, 3),
                (8, 0), (8, 1), (8, 2), (8, 3), (8, 4), (8, 5),
                (9, 0), (9, 1), (9, 2), (9, 3), (9, 4), (9, 5),
                (10, 2), (10, 3),
            ],
            [  # Tree 3
                (7, 12), (7, 13),
                (8, 10), (8, 11), (8, 12), (8, 13), (8, 14), (8, 15),
                (9, 10), (9, 11), (9, 12), (9, 13), (9, 14), (9, 15),
                (10, 12), (10, 13),
            ]
        ]

        # One flag per tree: 1 while the tree has at least one apple, else 0.
        self.tree_alives = [1] * len(self.apple_groups)

        # Flatten all apple positions
        self.apple_positions = [pos for group in self.apple_groups for pos in group]

        # Initial grid: every tree cell starts with an apple.
        self.initial_grid = np.zeros(grid_size, dtype=int)
        for x, y in self.apple_positions:
            self.initial_grid[x, y] = 1

        # Start cells are drawn once here and reused by every ``reset``.
        self.initial_agent_positions = self.generate_random_agent_positions(num_agents=4)
        self.reset()

    def generate_random_agent_positions(self, num_agents=4):
        """Draw ``num_agents`` distinct cells that are not apple cells.

        Uses the module-level ``random`` generator (rejection sampling), so
        seeding ``random`` makes the draw reproducible.

        Raises:
            ValueError: if there are fewer non-apple cells than agents; the
                sampling loop could otherwise never terminate.
        """
        rows, cols = self.grid_size
        free_cells = rows * cols - len(set(self.apple_positions))
        if num_agents > free_cells:
            raise ValueError(
                f"Cannot place {num_agents} agents on {free_cells} free cells."
            )

        positions = set()
        while len(positions) < num_agents:
            pos = (
                random.randint(0, rows - 1),
                random.randint(0, cols - 1)
            )
            if pos not in self.apple_positions and pos not in positions:
                positions.add(pos)
        return list(positions)

    def get_state(self):
        """Return the state as a 1-D array: agent coordinates, then the grid."""
        agent_state = np.array(self.agent_positions).flatten()
        grid_state = np.array(self.grid).flatten()
        return np.concatenate([agent_state, grid_state])

    def move_agent(self, agent_idx, action):
        """Move agent ``agent_idx`` one cell according to ``action``.

        A move that would leave the grid does not change the coordinate.
        A move onto a cell occupied by any agent is rejected, so the agent
        stays in place. Actions outside 0-3 also leave the agent in place.
        """
        x, y = self.agent_positions[agent_idx]
        new_x, new_y = x, y

        if action == 0 and x > 0:
            new_x -= 1
        elif action == 1 and x < self.grid_size[0] - 1:
            new_x += 1
        elif action == 2 and y > 0:
            new_y -= 1
        elif action == 3 and y < self.grid_size[1] - 1:
            new_y += 1

        if (new_x, new_y) not in self.agent_positions:
            self.agent_positions[agent_idx] = (new_x, new_y)

    def step(self, actions):
        """Advance the environment by one time step.

        Agents act sequentially in index order. Each agent moves (see
        ``move_agent``) and, if it lands on an apple, eats it. Eating yields
        +1, except that eating the last apple remaining across all trees
        yields -1000.

        Afterwards, each tree that still has apples may regrow one apple on
        a random empty tree cell not occupied by an agent. The probability
        is ``rate_regen * apples_alive / len(tree)``. A tree with no apples
        is marked dead in ``tree_alives`` and does not regrow.

        Args:
            actions: One action per agent, in agent order.

        Returns:
            ``(state, rewards)``: the ``get_state`` vector and a list holding
            one reward per entry of ``actions``.
        """
        rewards = [0] * len(actions)

        for i, action in enumerate(actions):
            self.move_agent(i, action)
            x, y = self.agent_positions[i]
            if self.grid[x, y] == 1:
                # Counted before removal: a count of 1 means this is the last apple.
                remaining = sum(self.grid[ax, ay] for ax, ay in self.apple_positions)
                rewards[i] = -1000 if remaining == 1 else 1
                self.grid[x, y] = 0

        # Regeneration per tree
        agent_positions_set = set(self.agent_positions)
        for tree_index, tree in enumerate(self.apple_groups):
            apples_alive = sum(self.grid[x, y] for x, y in tree)

            # Keep the flag in sync with the grid (also after set_state/reset).
            self.tree_alives[tree_index] = 1 if apples_alive > 0 else 0

            # If no apples remain, the tree is dead and does not regenerate
            if apples_alive == 0:
                continue

            # Probability of regenerating an apple in this tree
            regen_prob = self.rate_regen * (apples_alive / len(tree))

            # Empty spots inside the tree, not occupied by agents
            empty_spots = [
                (x, y) for (x, y) in tree
                if self.grid[x, y] == 0 and (x, y) not in agent_positions_set
            ]

            # Attempt to regenerate one apple
            if empty_spots and random.random() < regen_prob:
                new_apple = random.choice(empty_spots)
                self.grid[new_apple] = 1

        return self.get_state(), rewards

    def render(self):
        """Print the grid, drawing agent ``i`` as ``#i`` over its cell."""
        grid_display = self.grid.astype(str)
        for i, (x, y) in enumerate(self.agent_positions):
            grid_display[x, y] = f'#{i}'
        print("\n".join(" ".join(row) for row in grid_display))
        print()

    def reset(self):
        """Restore the initial apples, agent start cells and tree flags."""
        self.grid = self.initial_grid.copy()
        self.agent_positions = self.initial_agent_positions.copy()
        self._refresh_tree_alives()

    def _refresh_tree_alives(self):
        """Recompute ``tree_alives`` from the apples currently on the grid."""
        self.tree_alives = [
            1 if sum(self.grid[x, y] for x, y in tree) > 0 else 0
            for tree in self.apple_groups
        ]

    def trigger_disruption(self, magnitude=0.4):
        """Remove a random fraction of the apples currently on the grid.

        ``ceil(n * magnitude)`` of the ``n`` present apples are removed. The
        count is clamped to ``[0, n]``, so a magnitude above 1 removes every
        apple instead of making ``random.sample`` raise. Nothing happens when
        at most one apple is left. ``tree_alives`` is updated on the next
        ``step``.
        """
        current = [(x, y) for x, y in self.apple_positions if self.grid[x, y] == 1]
        if len(current) > 1:
            n_remove = int(np.ceil(len(current) * magnitude))
            n_remove = min(max(n_remove, 0), len(current))
            for x, y in random.sample(current, n_remove):
                self.grid[x, y] = 0

    def set_state(self, state):
        """Load a state vector laid out like ``get_state``'s output.

        Args:
            state: 1-D sequence of ``2 * num_agents`` agent coordinates
                followed by ``rows * cols`` grid cells. Float input (e.g.
                from a policy network) is accepted. It is cast back to the
                grid's integer dtype so positions stay valid indices.

        Raises:
            ValueError: if the grid part does not contain ``rows * cols``
                values. Nothing is modified in that case.
        """
        state = np.asarray(state).ravel()
        offset = 2 * len(self.initial_agent_positions)
        rows, cols = self.grid_size
        grid_part = state[offset:]
        if grid_part.size != rows * cols:
            raise ValueError("Grid size mismatch.")

        self.agent_positions = [
            (int(state[i]), int(state[i + 1])) for i in range(0, offset, 2)
        ]
        # astype returns a copy, so the caller's array is never aliased.
        self.grid = grid_part.reshape(rows, cols).astype(self.initial_grid.dtype)
        self._refresh_tree_alives()
