from __future__ import annotations

import typing
import numpy as np

if typing.TYPE_CHECKING:
    from core.algorithm import Environment
    from typing import Dict, Tuple, List

class TransitionFunction:
    """
    A class used to represent the estimated Transition Function
    together with a confidence bound for every (x, a, x_p) triple.

    Attributes
    ----------
    confidence : Dict[Tuple[int,float,int], float]
        confidence bound for triple (x,a,x_p); built from the environment
        on first access if it is still empty
    empiric_transition_function :  Dict[Tuple[int,float,int], float]
        trans. function for x_p|x,a; built from the environment on first
        access if it is still empty
    _environment : Environment
        the mdp
    _par : float
        log(time_horizon * n_states * n_actions / confidence_delta), fixed
        with the environment and used to estimate the confidence

    Methods
    -------
    _initialize_attributes()
        Initializes the confidence and empiric_transition_function attributes based on the environment

    update_epoch(sas_counter: Dict[Tuple[int, float, int], int], sa_counter: Dict[Tuple[int, float], int])
        Updates the transition function and confidence based on counters when entering new epoch,
        using equation (6) of "Learning Adversarial MDPs with Bandit Feedback and Unknown Transition",
        Chi Jin et al.
    """

    def __init__(
        self,
        environment: Environment,
        time_horizon: float,
        confidence_delta: float,
        confidence: typing.Optional[Dict[Tuple[int, float, int], float]] = None,
        empiric_transition_function: typing.Optional[
            Dict[Tuple[int, float, int], float]
        ] = None,
    ):
        """Store the environment and compute the confidence parameter.

        Parameters
        ----------
        environment : Environment
            the mdp; ``n_states`` and ``n_actions`` are read here
        time_horizon : float
            horizon used inside the logarithm of the confidence parameter
        confidence_delta : float
            divisor inside the logarithm of the confidence parameter
        confidence : Dict[Tuple[int, float, int], float], optional
            pre-computed confidence bounds; a fresh empty dict when omitted
        empiric_transition_function : Dict[Tuple[int, float, int], float], optional
            pre-computed transition estimates; a fresh empty dict when omitted
        """
        # A fresh dict per instance: a `dict()` default would be created once
        # and shared by every instance, so in-place writes in `update_epoch`
        # would leak between otherwise unrelated objects.
        self._confidence = {} if confidence is None else confidence
        self._empiric_transition_function = (
            {}
            if empiric_transition_function is None
            else empiric_transition_function
        )
        self._environment = environment
        self._par = np.log(
            (time_horizon * self._environment.n_states * self._environment.n_actions)
            / confidence_delta
        )

    @property
    def confidence(self):
        """Confidence bounds; (re)built from the environment while still empty."""
        if not self._confidence:
            (
                self._confidence,
                self._empiric_transition_function,
            ) = self._initialize_attributes()
        return self._confidence

    @property
    def empiric_transition_function(self):
        """Transition estimates; (re)built from the environment while still empty."""
        if not self._empiric_transition_function:
            (
                self._confidence,
                self._empiric_transition_function,
            ) = self._initialize_attributes()
        return self._empiric_transition_function

    def _initialize_attributes(self):
        """Initializes the confidence and empiric_transition_function attributes based on the environment

        Every triple whose ``sas_ind`` entry is not -1 first receives the mass
        ``1 / (sas_count[layer + 1] - sas_count[layer])`` where ``layer`` is
        the layer of its source state (presumably a uniform mass over the
        triples of that layer -- confirm against ``Environment``). The masses
        are then normalized per (x, a) pair to obtain the initial transition
        estimate. The initial confidence is 0 for states with exactly one
        action and 0.5 otherwise.

        Returns
        -------
        tuple
            ``(confidence, empiric_transition_function)``, the two dicts that
            were also stored on the instance
        """
        env = self._environment
        sas_count = env.sas_count

        # Unnormalized mass q(x, a, x') for every valid triple.
        om = {}
        for key, index in env.sas_ind.items():
            if index != -1:
                layer = env.get_layer(key[0]) + 1
                om[key] = 1 / (sas_count[layer] - sas_count[layer - 1])

        # Group the masses by (x, a) in one pass instead of rescanning all
        # triples for every pair; per-pair order is kept so sums are unchanged.
        mass_by_pair = {}
        for triple, mass in om.items():
            mass_by_pair.setdefault(triple[0:2], []).append(mass)
        # sum_(x':(x,a,x')) q(x,a,x')
        normalizer = {pair: sum(masses) for pair, masses in mass_by_pair.items()}

        transition = {}
        conf = {}
        for triple, mass in om.items():
            transition[triple] = mass / normalizer[triple[0:2]]
            conf[triple] = 0 if len(env.actions[triple[0]]) == 1 else 0.5

        self._empiric_transition_function = transition
        self._confidence = conf
        return self._confidence, self._empiric_transition_function

    def update_epoch(
        self,
        sas_counter: Dict[Tuple[int, float, int], int],
        sa_counter: Dict[Tuple[int, float], int],
    ) -> TransitionFunction:
        """Updates the transition function and confidence based on counters when entering new epoch,
        using equation (6) of "Learning Adversarial MDPs with Bandit Feedback and Unknown Transition",
        Chi Jin et al.

        Only the triples present in ``sas_counter`` are written; the stored
        dicts are modified in place.

        Parameters
        ----------
        sas_counter: Dict[Tuple[int, float, int], int],
            Counters M(x,a,x_p). NOTE(review): each value is indexed with
            ``[0]``, so it must be a sequence holding the count first, not a
            bare int as annotated -- confirm against the caller.

        sa_counter: Dict[Tuple[int, float], int]
            Counters N(x,a); values are indexed with ``[0]`` as well.

        Returns
        -------
        TransitionFunction
            Updated transition function (``self``)
        """
        par = self._par
        for triple, transition_count in sas_counter.items():
            x, a, _ = triple
            visits = sa_counter[(x, a)][0]
            # max(1, .) guards both divisions for pairs visited 0 or 1 times.
            estimate = transition_count[0] / max(1, visits)
            reduced_visits = max(1, visits - 1)

            self._empiric_transition_function[triple] = estimate
            self._confidence[triple] = 2 * np.sqrt(
                (estimate * par) / reduced_visits
            ) + (14 * par) / (3 * reduced_visits)
        return self
