from __future__ import annotations

import typing
import numpy as np
import json

if typing.TYPE_CHECKING:
    from typing import Dict, Tuple, List
    from core.transition_function import TransitionFunction


class Environment:
    """
    A class used to represent the environment of the Markov Decision Process

    The state space is organised in layers. The index builders of this class
    only consider transitions that go from a state of layer j to a state of
    layer j + 1.

    Attributes
    ----------
    actions : List[List[int]]
        actions[i] is a list of the possible actions that can be chosen
        at state i
    adj_matrix : List[List[int]]
        adj_matrix[i][j] is truthy if state i and state j are connected
        with probability > 0
    layers : List[List[int]]
        layers[i] is a list of states contained in layer number i
    n_constraints : int
        is the number of constraints in the MDP
    sas_ind : Dict[Tuple[int, int, int], int]
        is a progressive indexing of all valid triples state, action, state_p (sas)
    sas_count : np.ndarray
        is a vector of cumulative counts. sas_count[0] is 0 and
        sas_count[j + 1] - sas_count[j] is the number of valid triples
        (x, a, x_p) with x in layer j
    sa_ind : Dict[Tuple[int, int], int]
        is a progressive indexing of all valid pairs state, action (sa)
    sa_count : np.ndarray
        is a vector of cumulative counts. sa_count[0] is 0 and
        sa_count[j + 1] - sa_count[j] is the number of pairs (x, a) that were
        indexed while visiting layer j (the last layer is never visited)
    n_states : int
        is the number of states in the MDP
    n_actions : int
        is the number of actions in the MDP (see _initialize_n_actions for
        the exact counting rule)

    Methods
    -------
    _read_config(path, data)
        Loads actions, adj_matrix, layers and n_constraints from a json file

    _ensure_indexes()
        Fills _sas_ind, _sas_count, _sa_ind and _sa_count in one pass

    _initialize_indexes()
        Calculates _sas_ind, _sas_count, _sa_ind and _sa_count

    _initialize_n_states()
        Calculates _n_states

    _initialize_n_actions()
        Calculates _n_actions

    get_layer(x: int)
        Returns the index of the layer that contains state x

    play_policy(policy: Dict[Tuple[int,int],float], time_step: int)
        Placeholder for playing an episode; returns path, losses and constraints
    """

    def __init__(self, data, path: str = "../config/env_config.json"):
        """Loads the configuration and prepares the lazily computed caches

        Parameters
        ----------
        data
            flag forwarded to _read_config; its truthiness selects which
            action layout is built
        path : str
            location of the json configuration file
        """
        self._read_config(path, data)
        # All the derived quantities below are computed on first access by the
        # corresponding property, so constructing an Environment stays cheap.
        self._sas_ind: Dict[Tuple[int, int, int], int] = None  # sas indexing
        self._sas_count: np.ndarray = None  # cumulative sas triple count per layer
        self._sa_ind: Dict[Tuple[int, int], int] = None  # sa indexing
        self._sa_count: np.ndarray = None  # cumulative sa pair count per layer
        self._n_states: int = None
        self._n_actions: int = None

    def _ensure_indexes(self) -> None:
        """Computes the four index structures together and caches them

        _initialize_indexes produces all four values in a single traversal, so
        every index property shares this helper instead of repeating the
        tuple assignment.
        """
        (
            self._sas_ind,
            self._sas_count,
            self._sa_ind,
            self._sa_count,
        ) = self._initialize_indexes()

    @property
    def sas_ind(self) -> Dict[Tuple[int, int, int], int]:
        """Returns the indexing of all valid triples state, action, state_p (sas)"""
        if self._sas_ind is None:
            self._ensure_indexes()
        return self._sas_ind

    @property
    def sas_count(self) -> np.ndarray:
        """Returns the cumulative number of valid triples (x,a,x_p) per layer"""
        if self._sas_count is None:
            self._ensure_indexes()
        return self._sas_count

    @property
    def sa_ind(self) -> Dict[Tuple[int, int], int]:
        """Returns the indexing of all valid pairs state, action (sa)"""
        # Guard on the field that is actually returned, like the sibling
        # properties do, so a stale/None _sa_ind is always rebuilt.
        if self._sa_ind is None:
            self._ensure_indexes()
        return self._sa_ind

    @property
    def sa_count(self) -> np.ndarray:
        """Returns the cumulative number of valid pairs state, action (sa) per layer"""
        if self._sa_count is None:
            self._ensure_indexes()
        return self._sa_count

    @property
    def n_states(self) -> int:
        """Returns the number of states in the MDP"""
        if self._n_states is None:
            self._n_states = self._initialize_n_states()
        return self._n_states

    @property
    def n_actions(self) -> int:
        """Returns the number of actions in the MDP"""
        if self._n_actions is None:
            self._n_actions = self._initialize_n_actions()
        return self._n_actions

    def _read_config(self, path, data):
        """Reads the configuration file and sets the public attributes

        The json file must contain the following variables:
             - actions_flight: number of actions of the flight
             - actions_pol: number of actions of the complementary asset
             - adj_matrix: adjacency matrix of the graph
             - layers: list of lists of states (one list for each layer)
             - n_constraints: number of constraints of the MDP

        The per-state action lists are generated from the two counters:
             - one state with actions 0 .. actions_flight - 1
             - two states (one when data is falsy) sharing the actions
               actions_flight .. actions_flight + actions_pol - 1
             - three states with a single dedicated action each, numbered
               consecutively after the previous range

        Parameters
        ----------
        path : str
            location of the json configuration file
        data
            truthy to build the six-state layout, falsy for the five-state one

        Raises
        ------
        OSError
            if the file cannot be opened
        json.JSONDecodeError
            if the file is not valid json
        KeyError
            if one of the variables listed above is missing
        """
        # Only the parsing needs the file handle; release it before building
        # the attributes. JSON text is UTF-8 by specification.
        with open(path, "r", encoding="utf-8") as f:
            variables = json.load(f)

        n_flight = variables["actions_flight"]
        n_pol = variables["actions_pol"]
        first_single_action = n_pol + n_flight

        flight_actions = list(range(n_flight))
        # Each complementary-asset state gets its own list object so that
        # mutating one state's actions never leaks into the other.
        n_pol_states = 2 if data else 1
        pol_actions = [
            list(range(n_flight, first_single_action)) for _ in range(n_pol_states)
        ]
        single_actions = [[first_single_action + offset] for offset in range(3)]

        self.actions = [flight_actions, *pol_actions, *single_actions]
        self.adj_matrix = variables["adj_matrix"]
        self.layers = variables["layers"]
        self.n_constraints = variables["n_constraints"]

    def _initialize_indexes(self):
        """Calculates _sas_ind, _sas_count, _sa_ind and _sa_count

        Every layer but the last is visited in order. For each of its states
        and each action available there, the pair (state, action) and every
        triple (state, action, state_p) with state_p in the next layer and
        adj_matrix[state][state_p] truthy receive the next free index.

        Returns
        -------
        Tuple[Dict[Tuple[int, int, int], int], np.ndarray, Dict[Tuple[int, int], int], np.ndarray]
            indexing of sas (state-action-state) triples, cumulative count of
            sas triples per layer, indexing of sa (state-action) pairs,
            cumulative count of sa pairs per layer
        """
        actions = self.actions
        adj_matrix = self.adj_matrix
        layers = self.layers

        sas_ind = {}  # state_action_state_p indexes
        sa_ind = {}  # state_action indexes
        next_sa = 0
        next_sas = 0
        # Entry 0 stays at zero; entry i + 1 holds the running total after
        # layer i has been processed.
        sa_count = np.zeros(len(layers), dtype=int)
        sas_count = np.zeros(len(layers), dtype=int)

        for i, layer in enumerate(layers[:-1]):
            next_layer = layers[i + 1]
            for node in layer:
                for action in actions[node]:
                    # A pair keeps the index it was given the first time.
                    if (node, action) not in sa_ind:
                        sa_ind[(node, action)] = next_sa
                        next_sa += 1

                    for node_p in next_layer:
                        if adj_matrix[node][node_p]:  # check if there is a connection
                            sas_ind[(node, action, node_p)] = next_sas
                            next_sas += 1
            sas_count[i + 1] = next_sas
            sa_count[i + 1] = next_sa

        return sas_ind, sas_count, sa_ind, sa_count

    def _initialize_n_states(self):
        """Calculates _n_states

        Returns
        -------
        int
            number of states in the MDP, i.e. the total length of all layers
        """
        return sum(len(x) for x in self.layers)

    def _initialize_n_actions(self):
        """Calculates _n_actions

        Sums the lengths of the per-state action lists, skipping the states
        that have at most one action.

        Returns
        -------
        int
            number of actions in MDP
        """
        # NOTE(review): states that share the same action range are each
        # counted, so shared actions contribute once per state - confirm this
        # is the intended definition.
        return sum(len(x) for x in self.actions if len(x) > 1)

    def get_layer(self, x: int) -> int:
        """Returns the index of the layer that contains state x

        Parameters
        ----------
        x : int
            state

        Returns
        -------
        int
            index of the first layer in self.layers that contains x

        Raises
        ------
        IndexError
            if x does not belong to any layer
        """
        for index, layer in enumerate(self.layers):
            if x in layer:
                return index
        raise IndexError(f"state {x!r} does not belong to any layer")

    def play_policy(
        self, policy: Dict[Tuple[int, int], float], time_step: int
    ) -> Tuple[List, List, List]:
        """Plays policy and returns result for given episode
            WARNING: IMPLEMENTED in SimulationEnvironment
            This is a placeholder for the actual implementation in child classes

        The base class has no transition function, so the set of candidate
        transitions is always empty: unless the initial state already is the
        final one, sampling fails with a ValueError raised by numpy.

        Parameters
        ----------
        policy : Dict[Tuple[int,int],float]
            Policy to be played; maps (state, action) to the probability of
            choosing action in state
        time_step : int
            not used by this placeholder

        Returns
        -------
        Tuple[List,List,List]
            Path, losses, constraints of the episode

        Raises
        ------
        ValueError
            raised by np.random.choice when there is nothing to sample from or
            the probabilities are not a valid distribution
        """
        # NOTE(review): an explicit NotImplementedError would be clearer, but
        # it would change the exception type seen by existing callers.
        path = []
        losses = []
        constraints = []

        state = self.layers[0][0]
        final_state = self.layers[-1][-1]

        while state != final_state:

            # Sample policy: restrict it to the pairs of the current state
            sa_pairs = {
                pair: prob for pair, prob in policy.items() if pair[0] == state
            }
            keys = list(sa_pairs)
            index = np.random.choice(len(keys), 1, p=list(sa_pairs.values()))[0]
            played_pair = keys[index]

            # Get transition: child classes fill this with the triples of
            # their transition function whose first two entries are equal to
            # played_pair, mapped to the transition probability.
            sas_triples = {}
            keys = list(sas_triples)
            index = np.random.choice(len(keys), 1, p=list(sas_triples.values()))[0]
            selected_transition = keys[index]

            # Observe reward
            # Depending on the environment, the reward can be calculated differently

            # Observe constraints
            # Depending on the environment, the constraints can be calculated differently

            # Store values
            path.append(selected_transition)
            state = selected_transition[2]

        return path, losses, constraints
