from __future__ import annotations
import sys
sys.path.append('../')

import numpy as np
from collections import defaultdict 
# import all we need

import random
import omnisafe
from typing import Any, ClassVar
import torch
import torch.nn.functional as F

import gymnasium as gym
from gymnasium import spaces
from collections import defaultdict 
from omnisafe.envs.core import CMDP, env_register, env_unregister
import importlib.util

def fill_matrix(grid_size, n_states, n_actions, coloured_states, start_states, prob=0.0, safe_states=(), wall_states=(), bomb_states=()):
    """Construct the transition dynamics of the grid world as a dense array.

    States are numbered row-major on a ``grid_size x grid_size`` grid, i.e.
    the cell at ``(row, col)`` is state ``row * grid_size + col``.

    Args:
        grid_size: side length of the square grid.
        n_states: total number of states; must equal ``grid_size**2``.
        n_actions: number of actions to model. Actions ``0..n_actions-1`` are
            taken from the action table below (at most 9).
        coloured_states: states from which every action jumps to a start
            state chosen uniformly at random.
        start_states: states the agent is sent to when leaving a coloured state.
        prob: probability that a uniformly chosen *other* action is executed
            instead of the selected one (in non-safe states).
        safe_states: states in which the selected action is always executed.
        wall_states: states that cannot be entered; a move into one leaves
            the agent where it is.
        bomb_states: accepted for interface compatibility; not used when
            building the dynamics.

    Returns:
        Array ``matrix`` of shape ``(n_states, n_states, n_actions)`` where
        ``matrix[s_next, s, a]`` is the probability of reaching ``s_next``
        from ``s`` under action ``a``.
    """
    assert n_states == grid_size**2, "n_states must equal grid_size**2"

    grid = np.arange(n_states).reshape(grid_size, grid_size)

    # Each action is a (row delta, column delta) pair. Moves that would leave
    # the grid are clipped to the border.
    action_map = {
        0: (0, -1),   # column - 1
        1: (0, 1),    # column + 1
        2: (1, 0),    # row + 1
        3: (-1, 0),   # row - 1
        4: (0, 0),    # stay
        5: (-1, -1),  # row - 1, column - 1
        6: (1, -1),   # row + 1, column - 1 (was a duplicate of action 7)
        7: (-1, 1),   # row - 1, column + 1
        8: (1, 1),    # row + 1, column + 1
    }

    # Actions 0..n_actions-1 must all exist in the table, so n_actions may be
    # as large as the table itself.
    assert n_actions <= len(action_map), "n_actions exceeds the number of defined actions"

    matrix = np.zeros((n_states, n_states, n_actions))

    # Membership is tested for every cell, so build the lookup sets once.
    coloured = set(coloured_states)
    safe = set(safe_states)
    walls = set(wall_states)
    start_idx = np.asarray(list(start_states), dtype=np.intp)
    last = grid_size - 1

    def successor(row, col, action):
        """Return the state reached from (row, col) when `action` is executed."""
        d_row, d_col = action_map[action]
        next_row = min(max(row + d_row, 0), last)
        next_col = min(max(col + d_col, 0), last)
        target = int(grid[next_row, next_col])
        # walls cannot be entered: the agent stays where it is
        return int(grid[row, col]) if target in walls else target

    reset_dist = None  # uniform distribution over start states, built on first use

    for row in range(grid_size):
        for col in range(grid_size):
            state = int(grid[row, col])

            if state in coloured:
                # every action from a coloured state resets to a random start state
                if reset_dist is None:
                    reset_dist = np.zeros(n_states)
                    reset_dist[start_idx] = 1.0
                    reset_dist = reset_dist / np.sum(reset_dist)
                matrix[:, state, :] = reset_dist[:, None]
                continue

            successors = [successor(row, col, a) for a in range(n_actions)]

            # Safe states are deterministic. With a single action there is no
            # other action to slip to, so that case is deterministic as well
            # (this also avoids dividing by zero below).
            deterministic = state in safe or n_actions == 1
            intended_p = 1.0 if deterministic else 1.0 - prob
            slip_p = 0.0 if deterministic else prob / (n_actions - 1)

            for a in range(n_actions):
                matrix[successors[a], state, a] += intended_p

                if intended_p == 1.0:
                    continue

                # spread the remaining mass uniformly over the other actions
                for other, target in enumerate(successors):
                    if other != a:
                        matrix[target, state, a] += slip_p

    return matrix

@env_register
@env_unregister
class ColourBombGridWorldV2(CMDP):
    """
    Colour Bomb Gridworld environment

    Input attributes:
        random_action_probability: probability of a random actions being selected
        episode_length: length of the episode until truncation
        render_mode: None (no rendering) or "ascii" (print the grid to stdout
            after every reset/step)

    Other attributes:
        grid_size: size of the grid world
        ncol: number of columns
        nrow: number of rows
        n_states: number of states (grid_size^2)
        n_actions: number of actions
        n_automaton_states: number of states of the cost function's automaton
        observation_space: gym spaces object
        action_space: gym spaces object
        transition_matrix: array built by `fill_matrix`, indexed [next_state, state, action]
        labelling_fn: maps a state to the set of labels that hold in it
        reward_fn: the reward function of the environment
        _step_counter: number of steps taken since the last reset

    """
    _support_envs: ClassVar[list[str]] = ['ColourBomb15x15-v0', 'ColourBomb15x15-v1', 'ColourBomb15x15-v2', 'ColourBomb15x15-v3']  # Supported task names

    # Property file (providing `cost_function`) for each supported task name.
    # The paths are relative to the current working directory.
    _property_files: ClassVar[dict[str, str]] = {
        'ColourBomb15x15-v0': "./properties/colour_bomb_grid_world_v2/property_1.py",
        'ColourBomb15x15-v1': "./properties/colour_bomb_grid_world_v2/property_2.py",
        'ColourBomb15x15-v2': "./properties/colour_bomb_grid_world_v2/property_3.py",
        'ColourBomb15x15-v3': "./properties/colour_bomb_grid_world_v2/property_4.py",
    }

    need_auto_reset_wrapper = True  # Whether `AutoReset` Wrapper is needed
    need_time_limit_wrapper = True  # Whether `TimeLimit` Wrapper is needed
    metadata = {"render_modes": ["ascii"]}

    def __init__(self, env_id: str, seed=0, random_action_probability=0.1, episode_length=250, render_mode=None, **kwargs) -> None:
        """Build the 15x15 grid world and load the cost function for `env_id`.

        Raises:
            RuntimeError: if no property file is known for `env_id`.
        """

        np.random.seed(seed)

        self.env_id = env_id

        property_path = self._property_files.get(self.env_id)
        if property_path is None:
            raise RuntimeError(f"cost function not specified for env id {env_id}")

        # load the property module from its file and take its cost function
        spec = importlib.util.spec_from_file_location("property", property_path)
        properties = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(properties)

        self.cost_function = properties.cost_function

        self._num_envs = 1
        self.grid_size = 15
        self.ncol = self.grid_size
        self.nrow = self.grid_size

        # we need an abstract start state
        self.n_states = self.grid_size**2

        # I think we need 5 actions to guarantee the existence of safe end components as we need a stay still action
        self.n_actions = 5
        self.n_automaton_states = len(self.cost_function.dfa.states)

        self.random_action_probability = random_action_probability
        self.episode_length = episode_length

        # observation = [grid state index, automaton state index]
        self._observation_space = spaces.Box(low=np.array([0, 0]), high=np.array([self.n_states, self.n_automaton_states]), shape=(2,), dtype=np.float32)
        # Box action space with one entry per discrete action; `step` turns
        # the vector into action probabilities with a softmax
        self._action_space = spaces.Box(low=-5, high=2, shape=(self.n_actions,), dtype=np.float32)

        self._wall_states = [45,60,75,210,195,180,165,150, 142] + [211,212,213,214,215,216] + [220,221,222,223,224,209] + [183,184,169,185,186,187] + [192,177,162,161,160] + [143,144,129] + [138,139,140,141,125]+ [3,18,47,62,63,64,50,35,20,80,81,95] + [83,84,99,100,116,131,133,134] + [87,72,57,70,55,39,9,13,14,29,44,59]
        self._start_states = [16,199,178,112,26]
        self._bomb_states = [76, 181, 123,82,207,8,58]
        self._medic_states = [154, 93, 38, 205, 74]
        self._green_states = [170]
        self._blue_states = [121,122,136,137]
        self._yellow_states = [176,191]
        self._red_states = [88,89,103,104]
        self._pink_states = [53,54,68,69]

        # all the coloured states are deterministic and therefore each set of coloured states forms a safe end component
        self._safe_states = self._red_states + self._yellow_states + self._blue_states + self._pink_states + self._green_states + self._medic_states
        self._coloured_states = self._red_states + self._yellow_states + self._blue_states + self._pink_states + self._green_states

        self.transition_matrix = fill_matrix(self.grid_size, self.n_states, self.n_actions, self._coloured_states, self._start_states, prob=random_action_probability, safe_states=self._safe_states, wall_states=self._wall_states, bomb_states=self._bomb_states)

        # NOTE(review): the labels below also use "colour", which is not
        # listed here - confirm whether it should be an atomic predicate.
        self.atomic_predicates = {"start", "green", "yellow", "blue", "red", "pink", "bomb", "medic"}

        # unlabelled states map to the empty set of labels
        self.labelling_fn = defaultdict(set)

        for state in self._start_states:
            self.labelling_fn[state] = {"start"}
        for state in self._green_states:
            self.labelling_fn[state] = {"green", "colour"}
        for state in self._yellow_states:
            self.labelling_fn[state] = {"yellow", "colour"}
        for state in self._blue_states:
            self.labelling_fn[state] = {"blue", "colour"}
        for state in self._red_states:
            self.labelling_fn[state] = {"red", "colour"}
        for state in self._pink_states:
            self.labelling_fn[state] = {"pink", "colour"}
        for state in self._bomb_states:
            self.labelling_fn[state] = {"bomb"}
        for state in self._medic_states:
            self.labelling_fn[state] = {"medic"}

        # get a reward for entering a coloured state - maybe we need a different reward function since wouldn't an optimal policy go to the nearest coloured state and stay there?
        self.reward_fn = defaultdict(float)
        for s in self._coloured_states:
            self.reward_fn[s] = 1.0

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        self._step_counter = 0

    def set_seed(self, seed: int) -> None:
        """seed the `random` and `numpy.random` global generators"""
        random.seed(seed)
        np.random.seed(seed)

    def _transition(self, action):
        """sample a next state randomly from the transition matrix"""
        return np.random.choice(self.n_states, p=self.transition_matrix[:, self._agent_location, action])

    def _get_labels(self):
        """return the labels for the current state"""
        return self.labelling_fn[self._agent_location]

    def _get_obs(self):
        """return the observation for the current state"""
        return np.array([self._agent_location, self._automaton_state],dtype=np.float32)

    def _get_info(self):
        """return the info for the current state"""
        return {}

    def _get_reward(self):
        """return the reward for the current state"""
        return self.reward_fn[self._agent_location]

    def _get_cost(self):
        """step the cost function on the current labels and return the cost

        Side effect: stores the automaton state returned by the cost function.
        """
        labels = self._get_labels()
        cost, next_automaton_state = self.cost_function.step(labels)
        self._automaton_state = next_automaton_state
        return cost

    def _get_terminated(self):
        """the environment never terminates on its own"""
        return False

    def _get_truncated(self):
        """return whether the episode has reached its maximum length"""
        return self._step_counter >= self.episode_length

    def _render_frame(self) -> None:
        """print the grid as ASCII, one text line per grid row (row 0 first)

        Legend: A agent, S start, # wall, X bomb, + medic,
        G/B/Y/R/P green/blue/yellow/red/pink, . anything else.
        """
        symbols = {}
        legend = (
            (self._start_states, "S"),
            (self._wall_states, "#"),
            (self._bomb_states, "X"),
            (self._medic_states, "+"),
            (self._green_states, "G"),
            (self._blue_states, "B"),
            (self._yellow_states, "Y"),
            (self._red_states, "R"),
            (self._pink_states, "P"),
        )
        for states, symbol in legend:
            for state in states:
                symbols[state] = symbol
        # the agent is drawn on top of whatever cell it occupies
        symbols[int(self._agent_location)] = "A"

        # states are numbered row-major: state = row * ncol + col
        lines = [
            "".join(symbols.get(row * self.ncol + col, ".") for col in range(self.ncol))
            for row in range(self.nrow)
        ]
        print("\n".join(lines))
        print()

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, dict]:
        """reset the environment and return the start obs"""
        if seed is not None:
            self.set_seed(seed)

        self._agent_location = np.random.choice(self._start_states)

        # restart the cost function and feed it the labels of the start state
        labels = self._get_labels()
        self.cost_function.reset()
        _, automaton_state = self.cost_function.step(labels)
        self._automaton_state = automaton_state

        observation = torch.as_tensor(self._get_obs())
        info = self._get_info()
        self._step_counter = 0

        if self.render_mode == "ascii":
            self._render_frame()

        return observation, info

    def step(
        self,
        action: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]:
        """play a given action in the environment

        `action` is a vector with one entry per discrete action; it is turned
        into probabilities with a softmax and a discrete action is sampled.

        Returns (obs, reward, cost, terminated, truncated, info).
        """
        action_probs = F.softmax(action, dim=-1).detach().cpu().numpy()
        # Sample discrete action from the probabilities
        discrete_action = np.random.choice(self.n_actions, p=action_probs)

        next_state = self._transition(discrete_action)
        self._agent_location = next_state

        # increment step counter
        self._step_counter += 1

        terminated = torch.as_tensor(self._get_terminated())
        truncated = torch.as_tensor(self._get_truncated())
        reward = torch.as_tensor(self._get_reward())
        # the cost must be computed before the observation: it advances the
        # automaton state that the observation reports
        cost = torch.as_tensor(self._get_cost())
        obs = torch.as_tensor(self._get_obs())
        info = self._get_info()
        info['final_observation'] = obs

        if self.render_mode == "ascii":
            self._render_frame()

        return obs, reward, cost, terminated, truncated, info

    @property
    def max_episode_steps(self) -> int:
        """The max steps per episode."""
        return self.episode_length

    def render(self) -> Any:
        pass

    def close(self) -> None:
        pass

