from __future__ import annotations
import sys
sys.path.append('../')

import gymnasium as gym
from gymnasium import spaces
import numpy as np
from collections import defaultdict 

import random
import omnisafe
from typing import Any, ClassVar
import torch
import torch.nn.functional as F

from omnisafe.envs.core import CMDP, env_register, env_unregister

import importlib.util

def fill_matrix(n_states, n_actions, rates, out_rate):
    """Build the transition tensor of a buffer-level birth-death chain.

    The returned array is indexed as ``matrix[next_state, state, action]``,
    so ``matrix[:, s, a]`` is the distribution over next states when action
    ``a`` is taken in state ``s``.  From any state the level can only move
    up by one, move down by one, or stay where it is.

    Args:
        n_states: Number of buffer levels (states ``0 .. n_states - 1``).
        n_actions: Number of actions; must equal ``len(rates)``.
        rates: Per-action "in" rate, i.e. ``rates[a]`` is used for action ``a``.
        out_rate: The "out" rate shared by all actions.

    Returns:
        A ``numpy`` array of shape ``(n_states, n_states, n_actions)`` whose
        columns ``matrix[:, s, a]`` each sum to one.

    Raises:
        ValueError: If ``n_actions`` does not match ``len(rates)``.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if n_actions != len(rates):
        raise ValueError(
            f"expected one rate per action: n_actions={n_actions}, "
            f"len(rates)={len(rates)}"
        )

    matrix = np.zeros((n_states, n_states, n_actions))
    last = n_states - 1

    for a, in_rate in enumerate(rates):
        # Probability of the level rising / falling by one; invariant in s.
        up = in_rate * (1 - out_rate)
        down = (1 - in_rate) * out_rate

        for s in range(n_states):
            if last == 0:
                # A single-level buffer has nowhere to move: pure self-loop.
                matrix[s, s, a] = 1.0
            elif s == 0:
                # Lowest level: cannot go below zero, the "down" mass stays.
                matrix[s + 1, s, a] = up
                matrix[s, s, a] = 1 - up
            elif s == last:
                # Highest level: cannot overflow, the "up" mass stays.
                matrix[s - 1, s, a] = down
                matrix[s, s, a] = 1 - down
            else:
                matrix[s + 1, s, a] = up
                matrix[s, s, a] = 1 - (up + down)
                matrix[s - 1, s, a] = down

    return matrix

@env_register
@env_unregister
class MediaStreaming(CMDP):
    """Media streaming buffer environment with an automaton-based cost.

    Dynamics:
        The buffer level (``0 .. buffer_size``) evolves according to the
        tensor built by ``fill_matrix``: action 0 uses the fast in-rate and
        action 1 the slow in-rate, with a shared out-rate.  The buffer starts
        half full.
    Observation:
        ``[buffer_level, automaton_state]`` as ``float32``.
    Action:
        A real vector with one entry per discrete action.  It is passed
        through a softmax and a discrete action is sampled from the result.
    Reward:
        ``-1`` whenever action 0 is taken (in any state), ``0`` otherwise.
    Cost:
        Produced by the ``cost_function`` object loaded from the property
        file, which is fed the labels of the current buffer level
        (``"start"`` at the initial level, ``"empty"`` at level 0).
    Episode end:
        After ``episode_length`` steps both ``terminated`` and ``truncated``
        are reported as true.

    NOTE(review): an earlier description of this environment stated that the
    reward is negative when the buffer is *empty* and that the unsafe
    behaviour is invoking the fast rate more than ``episode_length / 2``
    times.  The reward table below penalises action 0 instead, and the cost
    semantics live in the property file -- confirm which is intended.
    """

    _support_envs: ClassVar[list[str]] = ['MediaStreaming-v0']  # Supported task names

    need_auto_reset_wrapper = True  # Whether `AutoReset` Wrapper is needed
    need_time_limit_wrapper = True  # Whether `TimeLimit` Wrapper is needed
    metadata = {"render_modes": ["text"]}

    def __init__(
        self,
        env_id: str,
        seed: int = 0,
        episode_length: int = 40,
        render_mode: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the environment and load the cost function for ``env_id``.

        Args:
            env_id: Task name; only ``'MediaStreaming-v0'`` is handled.
            seed: Seed applied to NumPy's global random state.
            episode_length: Number of steps after which an episode ends.
            render_mode: ``None`` or one of ``metadata["render_modes"]``.
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            RuntimeError: If no cost function is configured for ``env_id``.
        """
        np.random.seed(seed)

        self.env_id = env_id

        # The property file is resolved relative to the current working directory.
        if self.env_id == 'MediaStreaming-v0':
            spec = importlib.util.spec_from_file_location("property", "./properties/media_streaming/property_1.py")
        else:
            raise RuntimeError(f"cost function not specified for env id {env_id}")

        properties = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(properties)

        self.cost_function = properties.cost_function

        self._num_envs = 1

        self._fast_rate = 0.9
        self._slow_rate = 0.1
        self._out_rate = 0.7
        self._buffer_size = 20

        self.n_states = self._buffer_size + 1
        self.n_actions = 2
        self.n_automaton_states = len(self.cost_function.dfa.states)

        self.episode_length = episode_length

        self._observation_space = spaces.Box(
            low=np.array([0, 0]),
            high=np.array([self.n_states, self.n_automaton_states]),
            shape=(2,),
            dtype=np.float32,
        )
        # Box action space with one entry per discrete action; `step` turns
        # the entries into sampling probabilities with a softmax.
        self._action_space = spaces.Box(low=-5, high=2, shape=(self.n_actions,), dtype=np.float32)

        # Action 0 -> fast rate, action 1 -> slow rate.
        self.transition_matrix = fill_matrix(
            self.n_states,
            self.n_actions,
            [self._fast_rate, self._slow_rate],
            self._out_rate,
        )

        self._start_state = self._buffer_size // 2
        self._unsafe_state = 0

        self.atomic_predicates = {"start", "empty"}

        # Unlabelled buffer levels map to an empty *set*, the same type as the
        # explicit labels below.
        self.labelling_fn = defaultdict(set)
        self.labelling_fn[self._start_state] = {"start"}
        self.labelling_fn[self._unsafe_state] = {"empty"}

        # Reward table keyed by (state, action): -1 for action 0 in every
        # state; any other pair falls back to the default 0.0.
        self.reward_fn = defaultdict(float)
        for s in range(self.n_states):
            self.reward_fn[(s, 0)] = -1.0

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        # The buffer level and automaton state are initialised in `reset`.
        self._step_counter = 0

    def set_seed(self, seed: int) -> None:
        """Seed Python's ``random`` module and NumPy's global random state."""
        random.seed(seed)
        np.random.seed(seed)

    def _transition(self, action):
        """Sample the next buffer level from the transition tensor."""
        next_state_probs = self.transition_matrix[:, self._agent_location, action]
        return np.random.choice(self.n_states, p=next_state_probs)

    def _get_labels(self):
        """Return the set of labels attached to the current buffer level."""
        return self.labelling_fn[self._agent_location]

    def _get_obs(self):
        """Return ``[buffer_level, automaton_state]`` as a float32 array."""
        return np.array([self._agent_location, self._automaton_state], dtype=np.float32)

    def _get_info(self):
        """Return a fresh, empty info dictionary."""
        return {}

    def _get_reward(self, action):
        """Look up the reward for the current buffer level and ``action``."""
        return self.reward_fn[(self._agent_location, action)]

    def _get_cost(self):
        """Advance the cost function on the current labels and return the cost.

        Side effect: stores the automaton state returned by the cost function
        in ``self._automaton_state``.
        """
        labels = self._get_labels()
        cost, next_automaton_state = self.cost_function.step(labels)
        self._automaton_state = next_automaton_state
        return cost

    def _get_terminated(self):
        """Return whether the step budget of the episode is exhausted."""
        return self._step_counter >= self.episode_length

    def _get_truncated(self):
        """Return whether the step budget of the episode is exhausted."""
        return self._step_counter >= self.episode_length

    def _render_frame(self) -> None:
        """Print a one-line text summary of the current state."""
        print(
            f"step={self._step_counter} "
            f"buffer={self._agent_location}/{self._buffer_size} "
            f"automaton_state={self._automaton_state}"
        )

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, dict]:
        """Reset the environment and return the initial observation.

        Args:
            seed: If given, re-seeds the random number generators first.
            options: Unused.

        Returns:
            The initial observation tensor and an empty info dictionary.
        """
        if seed is not None:
            self.set_seed(seed)

        # The buffer starts half full.
        self._agent_location = self._start_state

        # Restart the cost function and feed it the labels of the start level
        # so the observed automaton state already accounts for them.
        labels = self._get_labels()
        self.cost_function.reset()
        _, automaton_state = self.cost_function.step(labels)
        self._automaton_state = automaton_state

        observation = torch.as_tensor(self._get_obs())
        info = self._get_info()
        self._step_counter = 0

        if self.render_mode == "text":
            self._render_frame()

        return observation, info

    def step(
        self,
        action: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]:
        """Sample a discrete action from ``action`` and advance one step.

        Args:
            action: Real-valued tensor with one entry per discrete action;
                a softmax over its last dimension gives the probabilities
                from which the discrete action is drawn.

        Returns:
            ``(obs, reward, cost, terminated, truncated, info)``; ``info``
            carries the observation under ``'final_observation'``.
        """
        action_probs = F.softmax(action, dim=-1).detach().cpu().numpy()
        # Sample a discrete action from the probabilities.
        discrete_action = np.random.choice(self.n_actions, p=action_probs)

        next_state = self._transition(discrete_action)
        self._agent_location = next_state

        self._step_counter += 1

        terminated = torch.as_tensor(self._get_terminated())
        truncated = torch.as_tensor(self._get_truncated())
        # The reward is looked up with the *new* buffer level.
        reward = torch.as_tensor(self._get_reward(discrete_action))
        # `_get_cost` updates the automaton state, so it must run before the
        # observation is built.
        cost = torch.as_tensor(self._get_cost())
        obs = torch.as_tensor(self._get_obs())
        info = self._get_info()
        info.update({'final_observation': obs})

        if self.render_mode == "text":
            self._render_frame()

        return obs, reward, cost, terminated, truncated, info

    @property
    def max_episode_steps(self) -> int:
        """The max steps per episode."""
        return self.episode_length

    def render(self) -> Any:
        """No-op; text frames are printed from `reset` and `step`."""
        pass

    def close(self) -> None:
        """No-op; the environment holds no resources to release."""
        pass