import gym
from gym import spaces
import numpy as np
from collections import defaultdict 
from gym.utils import seeding

def fill_matrix(n_states, n_actions, rates, out_rate):
    """Build the transition tensor of a birth-death chain over ``n_states`` states.

    The result is indexed as ``matrix[next_state, state, action]``, so for a
    fixed ``(state, action)`` the slice ``matrix[:, state, action]`` is a
    probability distribution over next states.

    From state ``s`` under action ``a`` (with ``in_rate = rates[a]``):
        * move to ``s + 1`` with probability ``in_rate * (1 - out_rate)``
          (not possible from the last state),
        * move to ``s - 1`` with probability ``(1 - in_rate) * out_rate``
          (not possible from state 0),
        * stay in ``s`` with the remaining probability.

    Args:
        n_states: number of states in the chain.
        n_actions: number of actions; must equal ``len(rates)``.
        rates: one incoming rate per action.
        out_rate: outgoing rate, shared by all actions.

    Returns:
        A float ``numpy`` array of shape ``(n_states, n_states, n_actions)``.

    Raises:
        AssertionError: if ``n_actions`` does not match ``len(rates)``.
    """
    assert n_actions == len(rates), "expected exactly one rate per action"

    matrix = np.zeros((n_states, n_states, n_actions))
    last = n_states - 1

    for a, in_rate in enumerate(rates):
        # These two probabilities do not depend on the state, so compute them
        # once per action rather than once per (state, action) pair.
        up = in_rate * (1 - out_rate)
        down = (1 - in_rate) * out_rate

        for s in range(n_states):
            # Probability mass that leaves state s; boundary states lose one of
            # the two moves. With a single state (s == 0 == last) neither move
            # exists and the chain stays put with probability 1.
            leave = 0.0
            if s < last:
                matrix[s + 1, s, a] = up
                leave += up
            if s > 0:
                matrix[s - 1, s, a] = down
                leave += down
            matrix[s, s, a] = 1 - leave

    return matrix

class MediaStreaming(gym.Env):
    """
        Media streaming environment:
            where every state is reachable from every other state
            State: the buffer fill level, an integer in [0, buffer size]
            Actions: 0 = fast rate, 1 = slow rate
            Reward structure: Negative reward for choosing action 0 which corresponds to the fast rate
            Unsafe behaviour: when the buffer (s=0) is empty the streamer can't play their video :(
            Goal: keep the buffer non-empty (safety objective) while minimising the number of times the fast rate action (0) has to be picked (reward objective)
            Optimal policy: the optimal reward policy will always choose action 1 (slow rate)
                thus maximising reward but failing the safety objective by letting the buffer become empty

        API note: reset() returns (observation, info) whereas step() returns the
            4-tuple (observation, reward, done, info); the separate terminated /
            truncated flags are exposed through info.
    """

    # No render modes are registered, so render_mode must be None.
    metadata = {"render_modes": []}

    def __init__(self, seed, episode_length=40, render_mode=None):
        """
            seed: seed for the random generator used to sample transitions
            episode_length: number of steps after which an episode is ended
            render_mode: must be None or one of metadata["render_modes"]
        """

        self.np_random, _ = seeding.np_random(seed)

        self.episode_length = episode_length

        self._fast_rate = 0.9
        self._slow_rate = 0.1
        self._out_rate = 0.7
        self._buffer_size = 20

        # one state per possible buffer fill level, including the empty buffer
        self.n_states = self._buffer_size + 1
        self.n_actions = 2

        self.observation_space = spaces.Discrete(self.n_states)
        self.action_space = spaces.Discrete(self.n_actions)

        # indexed as [next_state, state, action]; action 0 uses the fast rate, action 1 the slow rate
        self.transition_matrix = fill_matrix(
            self.n_states,
            self.n_actions,
            [self._fast_rate, self._slow_rate],
            self._out_rate,
        )

        # I am not sure what states are safe end components we might need to include the temporal information in the state space?
        self._start_state = self._buffer_size // 2
        self._unsafe_state = 0

        self.atomic_predicates = {"start", "empty"}

        # every state maps to a set of labels; unlabelled states get an empty *set*
        # (not a dict) so all lookups return the same type
        self.labelling_fn = defaultdict(set)
        self.labelling_fn[self._start_state] = {"start"}
        self.labelling_fn[self._unsafe_state] = {"empty"}

        # setup the reward function only negative reward for action 0 otherwise reward = 0
        self.reward_fn = defaultdict(float)
        for s in range(self.n_states):
            self.reward_fn[(s, 0)] = -1.0

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        self._step_counter = 0

    def _transition(self, action):
        """sample a next state randomly from the transition matrix

            raises RuntimeError if the sampler rejects the distribution for the
            current (state, action) pair
        """
        probs = self.transition_matrix[:, self._agent_location, action]
        try:
            return self.np_random.choice(self.n_states, p=probs)
        except ValueError as err:
            # keep the diagnostics in the exception itself and preserve the original cause
            raise RuntimeError(
                f"invalid transition distribution for state {self._agent_location}, "
                f"action {action}: {probs}"
            ) from err

    def _get_labels(self):
        """return the labels for the current state"""
        return self.labelling_fn[self._agent_location]

    def _get_obs(self):
        """return the observation for the current state"""
        return self._agent_location

    def _get_info(self):
        """return the info for the current state"""
        return {"labels": self._get_labels()}

    def _get_reward(self, action):
        """return the reward for the current state and the given action"""
        # step() calls this after the state has been updated; the reward table
        # built in __init__ only depends on the action, so this does not matter.
        return self.reward_fn[(self._agent_location, action)]

    def _get_terminated(self):
        """check if the termination condition is satisfied"""
        # NOTE(review): this is the same step-limit test as _get_truncated, so
        # both flags are always equal.
        return self._step_counter >= self.episode_length

    def _get_truncated(self):
        """check if the episode has reached its step limit"""
        return self._step_counter >= self.episode_length

    def reset(self, seed=None, options=None):
        """reset the environment and return the start obs

            seed: when given, re-seeds the random generator; when None the
                generator keeps its current state
            options: accepted but unused
        """
        if seed is not None:
            # same seeding call as __init__, so reset(seed=...) is reproducible
            self.np_random, _ = seeding.np_random(seed)

        self._agent_location = self._start_state
        self._step_counter = 0

        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, info

    def step(self, action):
        """play a given action in the environment

            returns (observation, reward, done, info) where done is True once the
            step limit is reached; info additionally carries "is_truncated" and
            "is_terminated"
        """

        next_state = self._transition(action)
        self._agent_location = next_state

        # increment step counter
        self._step_counter += 1

        terminated = self._get_terminated()
        truncated = self._get_truncated()
        done = terminated or truncated
        reward = self._get_reward(action)
        observation = self._get_obs()
        info = self._get_info()
        info["is_truncated"] = truncated
        info["is_terminated"] = terminated

        if self.render_mode == "human":
            self._render_frame()

        return observation, reward, done, info

    def _render_frame(self):
        """rendering is not implemented for this environment"""
        raise NotImplementedError