#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Simulate a predator prey environment.
Each agent can just observe itself (it's own identity) i.e. s_j = j and vision sqaure around it.

Design Decisions:
    - Memory cheaper than time (compute)
    - Using Vocab for class of box:
         -1 out of bound,
         indexing for predator agent (from 2?)
         ??? for prey agent (1 for fixed case, for now)
    - Action Space & Observation Space are according to an agent
    - Rewards -0.05 at each time step till the time
    - Episode never ends
    - Obs. State: Vocab of 1-hot < predator, preys & units >
"""

# core modules
import curses

# 3rd party modules
import gym
import numpy as np
from gym import spaces
from ipdb import set_trace


class PredatorPreyEnv(gym.Env):

    def __init__(self, args):
        self.__version__ = "0.0.1"

        # TODO: better config handling
        self.OUTSIDE_CLASS = 1
        self.PREY_CLASS = 2
        self.PREDATOR_CLASS = 3
        self.TIMESTEP_PENALTY = -0.05
        self.PREY_REWARD = 0
        self.POS_PREY_REWARD = 0.05
        self.episode_over = False

        # Initialize random number generator
        self.np_random = None

        # track reward collected over the complete episode
        self.episode_reward = 0

        self._multi_agent_init(args)

    def init_curses(self):
        self.stdscr = curses.initscr()
        curses.start_color()
        curses.use_default_colors()
        curses.init_pair(1, curses.COLOR_RED, -1)
        curses.init_pair(2, curses.COLOR_YELLOW, -1)
        curses.init_pair(3, curses.COLOR_CYAN, -1)
        curses.init_pair(4, curses.COLOR_GREEN, -1)

    def _multi_agent_init(self, args):

        # General variables defining the environment : CONFIG
        params = ['dim', 'vision', 'moving_prey', 'mode', 'enemy_comm']
        for key in params:
            setattr(self, key, getattr(args, key))

        self.nprey = args.nenemies
        self.npredator = args.num_agents
        self.dims = dims = (self.dim, self.dim)
        self.stay = not args.no_stay

        if args.moving_prey:
            raise NotImplementedError

        # (0: UP, 1: RIGHT, 2: DOWN, 3: LEFT, 4: STAY)
        # Define what an agent can do -
        if self.stay:
            self.naction = 5
        else:
            self.naction = 4

        self.action_space = spaces.MultiDiscrete([self.naction])

        self.BASE = (dims[0] * dims[1])
        self.OUTSIDE_CLASS += self.BASE
        self.PREY_CLASS += self.BASE
        self.PREDATOR_CLASS += self.BASE

        # Setting max vocab size for 1-hot encoding
        self.vocab_size = 1 + 1 + self.BASE + 1 + 1
        #          predator + prey + grid + outside

        # Observation for each agent will be flattened vision * vision * vocab_size array
        obs_dim = self.vocab_size * ((2 * self.vision) + 1) * ((2 * self.vision) + 1)

        # Number of agents that will have observations (predators + prey if enemy_comm enabled)
        self.n_agents = self.npredator if not self.enemy_comm else self.npredator + self.nprey

        self.observation_space = [spaces.Box(low=0, high=1,
                                             shape=(obs_dim,),
                                             dtype=np.float32) for _ in range(self.n_agents)]
        # Actual observation will be of the shape 1 * n_agents * (2v+1) * (2v+1) * vocab_size
        tmp_obs = self.reset()
        share_obs_dim = int(np.prod(tmp_obs[0].shape)) * self.n_agents
        self.share_observation_space = [spaces.Box(low=0, high=1, shape=(share_obs_dim,), dtype=np.float32)
                                        for _ in range(self.n_agents)]

        return

    def step(self, action):
        """
        The agents take a step in the environment.

        Parameters
        ----------
        action : list/ndarray of length m, containing the indexes of what lever each 'm' chosen agents pulled.

        Returns
        -------
        obs, reward, episode_over, info : tuple
            obs (object) :

            reward (float) : Ratio of Number of discrete levers pulled to total number of levers.
            episode_over (bool) : Will be true as episode length is 1
            info (dict) : diagnostic information useful for debugging.
        """
        if self.episode_over:
            raise RuntimeError("Episode is done")
        action = np.array(action).squeeze()
        action = np.atleast_1d(action)

        for i, a in enumerate(action):
            self._take_action(i, a)

        assert np.all(action <= self.naction), "Actions should be in the range [0,naction)."

        self.obs = self._get_obs()
        reward = self._get_reward()

        # Reshape rewards to have shape (n_agents, 1) for consistency with framework
        reward = reward[:self.npredator].reshape(self.npredator, 1)

        self.episode_reward += np.mean(reward)

        # Create dones array for all agents - shape should be (n_agents,)
        dones = np.array([self.episode_over] * self.npredator, dtype=bool)

        debug = {
            'predator_locs': self.predator_loc, 
            'prey_locs': self.prey_loc, 
            'episode_reward': self.episode_reward,
            'success': self.stat.get('success', 0),  # Add success flag
            'predators_on_prey': int(np.sum(self.reached_prey))  # Track partial progress
        }
        return self.obs, reward, dones, debug

    def reset(self):
        """
        Reset the state of the environment and returns an initial observation.

        Returns
        -------
        observation (object): the initial observation of the space.
        """
        self.episode_over = False
        self.reached_prey = np.zeros(self.npredator)
        self.episode_reward = 0  # Reset episode reward

        # Locations
        locs = self._get_cordinates()
        self.predator_loc, self.prey_loc = locs[:self.npredator], locs[self.npredator:]

        self._set_grid()

        # stat - like success ratio
        self.stat = dict()

        # Observation will be npredator * vision * vision ndarray
        self.obs = self._get_obs()

        return self.obs

    def seed(self, seed=None):
        """
        Set the seed for this env's random number generator(s).

        Parameters
        ----------
        seed : int or None
            The seed value. If None, a random seed will be chosen.

        Returns
        -------
        list
            Returns the list of seeds used in this env's random number generators.
            The first value in the list should be the "main" seed, or the value which
            a reproducer should pass to 'seed'.
        """
        self.np_random = np.random.RandomState(seed)
        # Also set the global numpy seed for any legacy code that might use np.random directly
        # Note: In production, you might want to avoid this and only use self.np_random
        if seed is not None:
            np.random.seed(seed)
        return [seed]

    def _get_cordinates(self):
        """Generate random coordinates for agents using the seeded random generator."""
        if self.np_random is not None:
            idx = self.np_random.choice(np.prod(self.dims), (self.npredator + self.nprey), replace=False)
        else:
            idx = np.random.choice(np.prod(self.dims), (self.npredator + self.nprey), replace=False)
        return np.vstack(np.unravel_index(idx, self.dims)).T

    def _set_grid(self):
        self.grid = np.arange(self.BASE).reshape(self.dims)
        # Mark agents in grid
        # self.grid[self.predator_loc[:,0], self.predator_loc[:,1]] = self.predator_ids
        # self.grid[self.prey_loc[:,0], self.prey_loc[:,1]] = self.prey_ids

        # Padding for vision
        self.grid = np.pad(self.grid, self.vision, 'constant', constant_values=self.OUTSIDE_CLASS)

        self.empty_bool_base_grid = self._onehot_initialization(self.grid)

    def _get_obs(self):
        self.bool_base_grid = self.empty_bool_base_grid.copy()

        for i, p in enumerate(self.predator_loc):
            self.bool_base_grid[p[0] + self.vision, p[1] + self.vision, self.PREDATOR_CLASS] += 1

        for i, p in enumerate(self.prey_loc):
            self.bool_base_grid[p[0] + self.vision, p[1] + self.vision, self.PREY_CLASS] += 1

        obs = []
        for p in self.predator_loc:
            slice_y = slice(p[0], p[0] + (2 * self.vision) + 1)
            slice_x = slice(p[1], p[1] + (2 * self.vision) + 1)
            obs.append(self.bool_base_grid[slice_y, slice_x])

        if self.enemy_comm:
            for p in self.prey_loc:
                slice_y = slice(p[0], p[0] + (2 * self.vision) + 1)
                slice_x = slice(p[1], p[1] + (2 * self.vision) + 1)
                obs.append(self.bool_base_grid[slice_y, slice_x])

        obs = np.stack(obs)
        # Flatten each agent's observation to 1D
        obs = obs.reshape(obs.shape[0], -1)
        return obs

    def _take_action(self, idx, act):
        # prey action
        if idx >= self.npredator:
            # fixed prey
            if not self.moving_prey:
                return
            else:
                raise NotImplementedError

        if self.reached_prey[idx] == 1:
            return

        # STAY action
        if act == 4:
            return

        # UP
        if act == 0 and self.grid[max(0,
                                      self.predator_loc[idx][0] + self.vision - 1),
        self.predator_loc[idx][1] + self.vision] != self.OUTSIDE_CLASS:
            self.predator_loc[idx][0] = max(0, self.predator_loc[idx][0] - 1)

        # RIGHT
        elif act == 1 and self.grid[self.predator_loc[idx][0] + self.vision,
        min(self.dims[1] - 1,
            self.predator_loc[idx][1] + self.vision + 1)] != self.OUTSIDE_CLASS:
            self.predator_loc[idx][1] = min(self.dims[1] - 1,
                                            self.predator_loc[idx][1] + 1)

        # DOWN
        elif act == 2 and self.grid[min(self.dims[0] - 1,
                                        self.predator_loc[idx][0] + self.vision + 1),
        self.predator_loc[idx][1] + self.vision] != self.OUTSIDE_CLASS:
            self.predator_loc[idx][0] = min(self.dims[0] - 1,
                                            self.predator_loc[idx][0] + 1)

        # LEFT
        elif act == 3 and self.grid[self.predator_loc[idx][0] + self.vision,
        max(0,
            self.predator_loc[idx][1] + self.vision - 1)] != self.OUTSIDE_CLASS:
            self.predator_loc[idx][1] = max(0, self.predator_loc[idx][1] - 1)

    def _get_reward(self):
        n = self.npredator if not self.enemy_comm else self.npredator + self.nprey
        reward = np.full(n, self.TIMESTEP_PENALTY)

        # Using the original vectorized approach for finding predators on prey
        # This handles single prey case efficiently
        on_prey = np.where(np.all(self.predator_loc == self.prey_loc, axis=1))[0]
        nb_predator_on_prey = on_prey.size

        if self.mode == 'cooperative':
            reward[on_prey] = self.POS_PREY_REWARD * nb_predator_on_prey
        elif self.mode == 'competitive':
            if nb_predator_on_prey:
                reward[on_prey] = self.POS_PREY_REWARD / nb_predator_on_prey
        elif self.mode == 'mixed':
            reward[on_prey] = self.PREY_REWARD
        else:
            raise RuntimeError("Incorrect mode, Available modes: [cooperative|competitive|mixed]")

        self.reached_prey[on_prey] = 1

        if np.all(self.reached_prey == 1) and self.mode == 'mixed':
            self.episode_over = True

        # Prey reward
        if nb_predator_on_prey == 0:
            reward[self.npredator:] = -1 * self.TIMESTEP_PENALTY
        else:
            # TODO: discuss & finalise
            reward[self.npredator:] = 0

        # Success ratio
        if self.mode != 'competitive':
            if nb_predator_on_prey == self.npredator:
                self.stat['success'] = 1
            else:
                self.stat['success'] = 0

        return reward

    def reward_terminal(self):
        return np.zeros_like(self._get_reward())

    def _onehot_initialization(self, a):
        ncols = self.vocab_size
        out = np.zeros(a.shape + (ncols,), dtype=int)
        out[self._all_idx(a, axis=2)] = 1
        return out

    def _all_idx(self, idx, axis):
        grid = list(np.ogrid[tuple(map(slice, idx.shape))])
        grid.insert(axis, idx)
        return tuple(grid)

    def render(self, mode='human', close=False):
        grid = np.zeros(self.BASE, dtype=object).reshape(self.dims)
        self.stdscr.clear()

        for p in self.predator_loc:
            if grid[p[0]][p[1]] != 0:
                grid[p[0]][p[1]] = str(grid[p[0]][p[1]]) + 'X'
            else:
                grid[p[0]][p[1]] = 'X'

        for p in self.prey_loc:
            if grid[p[0]][p[1]] != 0:
                grid[p[0]][p[1]] = str(grid[p[0]][p[1]]) + 'P'
            else:
                grid[p[0]][p[1]] = 'P'

        for row_num, row in enumerate(grid):
            for idx, item in enumerate(row):
                if item != 0:
                    if 'X' in item and 'P' in item:
                        self.stdscr.addstr(row_num, idx * 4, item.center(3), curses.color_pair(3))
                    elif 'X' in item:
                        self.stdscr.addstr(row_num, idx * 4, item.center(3), curses.color_pair(1))
                    else:
                        self.stdscr.addstr(row_num, idx * 4, item.center(3), curses.color_pair(2))
                else:
                    self.stdscr.addstr(row_num, idx * 4, '0'.center(3), curses.color_pair(4))

        self.stdscr.addstr(len(grid), 0, '\n')
        self.stdscr.refresh()

    def exit_render(self):
        curses.endwin()
