from __future__ import annotations

import typing
import numpy as np
from scipy.optimize import minimize

if typing.TYPE_CHECKING:
    from core.algorithm import Environment
    from core.transition_function import TransitionFunction
    from typing import Dict, Tuple, List


class OccupancyMeasure:
    """ Class that represents and defines the occupancy measure over the triples (x,a,x') of the MDP """
    def __init__(
        self,
        environment: Environment,
        transition_function: TransitionFunction,
        occupancy_measure: Dict[Tuple[int, int, int], float] = None,
    ):
        if occupancy_measure is None:
            occupancy_measure = dict()
        self._transition_function = transition_function
        self._occupancy_measure = occupancy_measure
        self._environment = environment


    @property
    def occupancy_measure(self):
        """Returns the occupancy measure defined on the triples (x,a,x')
        if not already defined, it is calculated as the uniform occupancy measure

        Returns
        -------
        Dict[Tuple[int, int, int], float]
            Occupancy measure
        """
        if not self._occupancy_measure:
            # Defining uniform occupancy measure over each triple
            for _, key in enumerate(self._environment.sas_ind.keys()):
                if self._environment.sas_ind[key] != -1:
                    layer = self._environment.get_layer(key[0]) + 1
                    self._occupancy_measure[key] = 1 / (
                        self._environment.sas_count[layer]
                        - self._environment.sas_count[layer - 1]
                    )

        return self._occupancy_measure

    @property
    def pairs_occupancy_measure(self):
        """Calculates the occupancy measure defined on the pairs (x,a)

        Returns
        -------
        Dict[Tuple[int, int], float]
            Pairs occupancy measure
        """
        pairs_occupancy_measure = dict()
        for key in self._occupancy_measure.keys():
            pairs_occupancy_measure[(key[0], key[1])] = sum(
                [
                    self._occupancy_measure[triple]
                    for triple in self._occupancy_measure.keys()
                    if triple[0] == key[0] and triple[1] == key[1]
                ]
            )
        return pairs_occupancy_measure

    def _optimize_occupancy_measure(
        self, loss: Dict[Tuple[int, int], float], lr: float
    ) -> Dict[Tuple[int, int, int], float]:
        """Updates the Occupancy Measure using algorithm described in Appendix A.1 of
        "Learning Adversarial Markov Decision Processes with Bandit Feedback and
        Unknown Transition", Chi Jin et al.

        Parameters
        ----------
        loss : Dict[Tuple[int,int],float]
            loss over state action pairs
        lr: float
            learning rate for optimization step

        Returns
        -------
        Dict[Tuple[int, int, int], float]
            Updated occupancy measure
        """

        occupancy_measure = self.occupancy_measure
        empiric_transition_function = (
            self._transition_function.empiric_transition_function
        )
        confidence = self._transition_function.confidence
        environment = self._environment

        # Defining optimization problem
        # MinimizationProblem Class is defined below
        prob = MinimizationProblem(
            confidence,
            empiric_transition_function,
            environment,
            occupancy_measure,
            lr,
            loss,
        )

        # Updates occupancy measure
        self._occupancy_measure = prob.minimize()

        return self._occupancy_measure

    def _greedy(
        self,
        f: List[float],
        x: int,
        a: int,
        layer: List[int],
        tf: Dict[Tuple[int, int, int], float],
        conf: Dict[Tuple[int, int, int], float],
    ):
        """Solves greedily the optimization problem needed to calculate
        Upper Occupancy Bound described in appendix A.2 and Algorithm 4
        of "Learning Adversarial Markov Decision Processes with Bandit
        Feedback and Unknown Transition", Chi Jin et al.

        Parameters
        ----------
        f: np.array[float]
            f(x) is the probability of passing through state x tilde starting from
            state x under transition function P and policy Pi
            Here only the slice of f defined over layer successive to layer(x) is passed
        x : int
            state
        a: int
            action
        tf: Dict[Tuple[int,int,int],float]
            transition function
        conf: Dict[Tuple[int,int,int],float
            confidence

        Returns
        -------
        float
            f dot p, used to calculate UOB
        """
        # Layer is already successive layer
        j_m = 1
        j_p = len(layer)
        sigma = np.argsort(f)

        while j_m < j_p:
            x_m, x_p = layer[sigma[j_m - 1]], layer[sigma[j_p - 1]]

            if (x, a, x_m) not in tf or (x, a, x_p) not in tf:
                tf[(x, a, x_m)] = tf.get((x, a, x_m), 0)
                tf[(x, a, x_p)] = tf.get((x, a, x_p), 0)
                conf[(x, a, x_m)] = conf.get((x, a, x_m), 0)
                conf[(x, a, x_p)] = conf.get((x, a, x_p), 0)

            delta_m = min(tf[(x, a, x_m)], conf[(x, a, x_m)])
            delta_p = min(1 - tf[(x, a, x_p)], conf[(x, a, x_p)])
            tf[(x, a, x_m)] -= min(delta_m, delta_p)
            tf[(x, a, x_p)] += min(delta_m, delta_p)

            if delta_m <= delta_p:
                conf[(x, a, x_p)] -= delta_m
                j_m += 1
            else:
                conf[(x, a, x_m)] -= delta_p
                j_p -= 1
        sigma_l = [(x, a, layer[y]) for y in sigma]

        return np.sum(np.multiply([tf[x] for x in sigma_l], f[sigma]))

    def _optimize_upper_bound(
        self, x: int, a: int, policy: Dict[Tuple[int, int], float]
    ) -> float:
        """Calculates the Upper Occupancy Bound for state-action pair (x,a)
        using Algorithm 3 Comp-UOB of "Learning Adversarial Markov Decision
        Processes with Bandit Feedback and Unknown Transition", Chi Jin et al.

        Parameters
        ----------
        x : int
            state
        a: int
            action
        policy: Dict[Tuple[int,int],float]
            Policy induced by occupancy_measure

        Returns
        -------
        float
            UOB for (x,a)
        """
        env = self._environment
        transition_function = self._transition_function.empiric_transition_function
        confidence = self._transition_function.confidence
        k = env.get_layer(x)
        flat_layers = [state for sublist in env.layers[: k + 1] for state in sublist]

        f = np.array([node == x for node in flat_layers], dtype=float)

        for i, layer in reversed(list(enumerate(env.layers[:k]))):
            for node in layer:
                gr = np.empty(len(env.actions[node]))
                for j, action in enumerate(env.actions[node]):
                    gr[j] = self._greedy(
                        f[env.layers[i + 1][0] : env.layers[i + 1][-1] + 1],
                        node,
                        action,
                        env.layers[i + 1],
                        dict(transition_function),
                        dict(confidence),
                    )

                f[node] = np.sum(
                    np.multiply(
                        [policy[(node, action)] for action in env.actions[node]], gr
                    )
                )

        return f[0] * policy[(x, a)]

    def compute_upper_bound(self, path) -> Dict[Tuple[int, int], float]:
        """Computes Upper Occupancy Bound for obtained trajectory used to estimate loss

        Parameters
        ----------
        path : List[Tuple[int,int]]
            Trajectory of state action pairs

        Returns
        -------
        Dict[Tuple[int, int], float]
            Upper occupancy bound for each state action pair of the trajectory
        """
        u = {}
        policy = self.from_occupancy_measure_to_policy()

        for x, a, x_p in path:
            u[(x, a)] = self._optimize_upper_bound(x, a, policy)

        return u

    def from_occupancy_measure_to_policy(self):
        """
        Calculate the policy induced by the occupancy measure
        At first, an unnormalized induced policy is calculated at (x,a) by summing
        over occupancy measure at all possible triples x': (x,a,*)
        The normalized policy at (x,a) is thus calculated by dividing over the sum of the
        unnormalized induced policy of all possible a, x': (x,*,*)

        Returns
        -------
            Dict[Tuple[int,int],float]
                Policy induced by occupancy_measure
        """
        occupancy_measure = self.occupancy_measure
        un_induced_policy = {}  # Unnormalized induced policy
        induced_policy = {}  # Induced policy from occupancy measure
        state_action_pairs = [triple[0:2] for triple in occupancy_measure.keys()]
        state_action_pairs = list(dict.fromkeys(state_action_pairs))  # Unique values

        for state_action in state_action_pairs:
            keys = [
                triple
                for triple in occupancy_measure.keys()
                if state_action[0:2] == triple[0:2]
            ]  # (x,a,x') for every valid triple starting with specific (x,a,*)
            prob = sum(
                [occupancy_measure[triple] for triple in keys]
            )  # sum_(x':(x,a,x')) q(x,a,x')
            un_induced_policy[state_action] = prob

        for state_action in state_action_pairs:
            induced_policy[state_action] = un_induced_policy[state_action] / sum(
                [
                    un_induced_policy[x]
                    for x in state_action_pairs
                    if state_action[0] == x[0]
                ]
            )

        return induced_policy

    def update_space(self, transition_function: TransitionFunction):
        """Updates the transition function"""
        self._transition_function = transition_function

    def update_occupancy_measure(self, loss, lr) -> OccupancyMeasure:
        """Updates the occupancy measure using the optimization problem

        Parameters
        ----------
        loss : Dict[Tuple[int,int],float]
            loss over state action pairs
        lr: float
            learning rate for optimization step

        Returns
        -------
        OccupancyMeasure
            Updated occupancy measure class
        """
        self._occupancy_measure = self._optimize_occupancy_measure(loss, lr)
        return self


class MinimizationProblem:
    """
    A class used to represent the dual optimization problem for updating the
    Occupancy Measure described in Appendix A.1 of "Learning Adversarial Markov
    Decision Processes with Bandit Feedback and Unknown Transition", Chi Jin et al.
    ...

    Attributes
    ----------
    big_b : Dict[Tuple[int, int, int], float]
        a dictionary that keeps track of the B variables, defined for each triple (x,a,x'),
        used for the final update in closed form
    big_z : Dict[int, np.longdouble]
        a dictionary that keeps track of the Z variables, defined for each layer k = 0...L-1,
        used for the final update in closed form
    _des_length : int
        the length of the dual variables vector to be found, equal to the
        number of states (beta) + 2*number of valid (x,a,x') triple (mu+, mu-)
    opt : bool
        a parameter to check if scipy.optimize.minimize is performing optimization
        (B,Z variables are not saved) or if B,Z are being calculated for the closed form update

    Methods
    -------
    _initialize_indexes()
        Calculates _des_length and _upd_occupancy_measure, which performs update using
        the unconstrained problem

    _generate_b(triple: Tuple[int,int,int], params: np.Array[Double])
        Calculates B variable for given triple

    _generate_zeta(i: int, params: np.Array[Double])
        Calculates Z variable for layer i

    _obj_function(params: np.Array[Double])
        Calculates obj_function of dual optimization problem

    _update(params: np.Array[Double])
        Performs projection step and closed form update on the unconstrained result,
        using parameters calculated via minimization of the dual problem

    minimize()
        Calculates optimal parameters of the dual optimization problem and calls
        update function to perform final projection step
    """

    def __init__(
        self,
        confidence,
        empirical_transition_function,
        env,
        occupancy_measure,
        lr,
        loss,
    ):
        self.big_b: Dict[Tuple[int, int, int], float] = dict()
        self.big_z: Dict[int, np.longdouble] = dict()
        self._env: Environment = env
        self._confidence: Dict[Tuple[int, int, int], float] = confidence
        self._des_length = None  # Length of the dual variables vector (n_states + 2*(x,a,x') triple -> beta, mu+, mu-)
        self._empirical_transition_function = empirical_transition_function
        self._occupancy_measure: Dict[Tuple[int, int, int], float] = occupancy_measure
        self._lr: float = lr
        self._loss: Dict[Tuple[int, int], float] = loss
        self._upd_occupancy_measure: Dict[Tuple[int, int, int], float] = {}
        self.opt: bool = False

    @property
    def des_length(self):
        """Returns the length of the dual variables vector to be found, equal to the
        number of states (beta) + 2*number of valid (x,a,x') triple"""
        if self._des_length is None:
            self._des_length, self._upd_occupancy_measure = self._initialize_indexes()
        return self._des_length

    @property
    def upd_occupancy_measure(self):
        """Returns the updated occupancy measure"""
        if not self._upd_occupancy_measure:
            self._des_length, self._upd_occupancy_measure = self._initialize_indexes()
        return self._upd_occupancy_measure

    def _initialize_indexes(self) -> Tuple[int, Dict[Tuple[int, int, int], float]]:
        """Calculates _des_length and _upd_occupancy_measure, which performs update using
        the unconstrained problem

        Returns
        -------
        Tuple[int, Dict[Tuple[int, int, int], float]]
            length of vector of dual variables and updated occupancy measure
        """
        self._des_length = self._env.n_states + 2 * self._env.sas_count[-1]

        # Optimization step without constraint
        for triple in self._occupancy_measure.keys():
            self._upd_occupancy_measure[triple] = self._occupancy_measure[
                triple
            ] * np.exp(-self._lr * self._loss[triple[0:2]])
        return self._des_length, self._upd_occupancy_measure

    def _generate_b(self, triple: Tuple[int, int, int], params) -> np.longdouble:
        """Calculates B variable for given triple.

        Parameters
        ----------
        triple : Tuple[int,int,int]
            State, action, state in next layer valid triple
        params: np.array[Double]
            Array of dual optimization variables [Beta(x),..., Mu+(x,a,x'),..., Mu-(x,a,x')]

        Returns
        -------
        double
            B variable
        """
        env = self._env
        sas_ind = env.sas_ind
        sas_count = env.sas_count
        n_states = env.n_states
        ind = sas_count[
            -1
        ]  # Total number of (x,a,x') triples, used for indexing paramaters to be optimized

        x = triple[0]
        a = triple[1]
        x_p = triple[2]

        i = env.get_layer(x)
        triples = list(sas_ind.keys())[sas_count[i] : sas_count[i + 1]]

        keys = [
            key for key in triples if triple[0:2] == key[0:2]
        ]  # All triples (x,a,*), * in layer i+1
        next_zeta = np.empty(
            len(keys), dtype=np.longdouble
        )  # Initialize vector for sum over all possible states in layer i+1

        for k, key in enumerate(keys):
            key_index = sas_ind[key]
            next_zeta[k] = (
                params[key_index + n_states] - params[key_index + n_states + ind]
            ) * self._empirical_transition_function[key] + (
                params[key_index + n_states] + params[key_index + n_states + ind]
            ) * self._confidence[
                key
            ]

        triple_index = sas_ind[triple]
        b = (
            params[x_p]
            - params[x]
            + params[triple_index + n_states + ind]
            - params[triple_index + n_states]
            - self._lr * self._loss[(x, a)]
            + np.sum(next_zeta)
        )

        if not self.opt:
            self.big_b[triple] = b  # store B value for final update
        return b

    def _generate_zeta(self, i: int, params) -> np.longdouble:
        """Calculates Z variable for given layer i
        Parameters
        ----------
        i : int
            Index of layer
        params: np.array[Double]
            Array of dual optimization variables [Beta(x),..., Mu+(x,a,x'),..., Mu-(x,a,x')]

        Returns
        -------
        double
            Z variable
        """
        env = self._env
        sas_ind = env.sas_ind
        sas_count = env.sas_count

        small_zeta = np.empty(
            sas_count[i + 1] - sas_count[i], dtype=np.longdouble
        )  # Initialize vector of function of triple (x,a,x') to be summed to obtain zeta for layer i
        triples = list(sas_ind.keys())[
            sas_count[i] : sas_count[i + 1]
        ]  # Retrieve valid triples (x,a,x') for layer i

        for j, triple in enumerate(triples):
            small_zeta[j] = self.upd_occupancy_measure[triple] * np.exp(
                self._generate_b(triple, params)
            )
        zeta = np.sum(small_zeta)

        if not self.opt:
            self.big_z[i] = zeta  # store zeta values for update
        return zeta

    def _obj_function(self, params) -> np.longdouble:
        """Calculates objective function of dual problem
        Parameters
        ----------
        params: np.array[Double]
            Array of dual optimization variables [Beta(x),..., Mu+(x,a,x'),..., Mu-(x,a,x')]

        Returns
        -------
        double
            The value of the objective function
        """

        layers = self._env.layers
        obj_dual = np.empty(
            1, dtype=np.longdouble
        )  # Initialize final objective function to be minimized
        big_zeta = np.empty(
            len(layers[:-1]), dtype=np.longdouble
        )  # Initialize a Z(mu,beta) for each layer

        for i, layer in enumerate(layers[:-1]):
            big_zeta[i] = self._generate_zeta(i, params)
        obj_dual = np.sum(np.log(big_zeta))
        return obj_dual

    def _update(self, params) -> Dict[Tuple[int, int, int], float]:
        """Performs projection step and closed form update on the unconstrained result,
        using parameters calculated via minimization of the dual problem

        Parameters
        ----------
        params: np.array[Double]
            Array of dual optimization variables [Beta(x),..., Mu+(x,a,x'),..., Mu-(x,a,x')]

        Returns
        -------
        Dict[Tuple[int, int, int], float]
            The updated occupancy measure
        """
        self._obj_function(params)
        final_occ = {}
        for triple in self._occupancy_measure.keys():
            if self.big_z[self._env.get_layer(triple[0])] != 0:
                final_occ[triple] = (
                    self.upd_occupancy_measure[triple] * np.exp(self.big_b[triple])
                ) / self.big_z[self._env.get_layer(triple[0])]
            else:
                raise (Exception("Division by zero"))

        return final_occ

    def minimize(self) -> Dict[Tuple[int, int, int], float]:
        """Calculates optimal parameters of the dual optimization problem and calls
        update function to perform final projection step

        Returns
        -------
        Dict[Tuple[int, int, int], float]
            The updated occupancy measure
        """
        self.opt = True  # Turns on opt mode so B and Z values are not stored

        x0 = np.ones(
            self.des_length
        )  # Initial guess: all 1s except beta for node x0 and xL
        x0[0] = 0
        x0[self._env.n_states - 1] = 0

        bnds = [(0, None)] * self.des_length  # Non-negativity constraints
        bnds[0] = (0, 0)  # First and last layer must have beta(x0) = beta(xL) = 0
        bnds[self._env.n_states - 1] = (
            0,
            0,
        )  # First and last layer must have beta(x0) = beta(xL) = 0

        results = minimize(
            self._obj_function,
            x0,
            bounds=bnds,
            method="SLSQP",  # L-BFGS-B has a bug and optimizes out of bounds after some thousands iterations
            options={"disp": False},
        )

        self.opt = False  # Turns off opt mode

        self._occupancy_measure = self._update(results.x)
        return self._occupancy_measure
