from __future__ import annotations

import typing
import json
import logging
from collections import defaultdict
from scipy.optimize import minimize
import numpy as np

from core.occupancy_measure import OccupancyMeasure
from core.environment import Environment
from core.transition_function import TransitionFunction

if typing.TYPE_CHECKING:
    from typing import Any, List, Tuple, Dict


class AlgorithmMDP:
    """
    A class used to represent the Primal algorithm for the MDP.

    It keeps per-epoch visit counters for every (x, a) pair and (x, a, x')
    triple of the environment, and owns the transition-function estimate and
    the occupancy measure that are refreshed from those counters.
    """

    occupancy_measure: OccupancyMeasure
    transition_function: TransitionFunction
    temporal_horizon: int
    # NOTE(review): declared here but never assigned inside this class; the
    # loss estimator in `update` uses `learning_rate` as its additive term.
    exploration_gamma: float
    confidence_delta: float
    environment: Environment

    def __init__(
        self,
        temporal_horizon: int,
        confidence_delta: float,
        environment: Environment
    ):
        self.temporal_horizon = temporal_horizon
        self.confidence_delta = confidence_delta
        self.environment = environment
        # Each counter entry is a two-element list:
        # [count stored at the start of the current epoch, running count].
        self._sa_counter, self._sas_counter = {}, {}
        self._initialize_counters()
        self._epoch = 1

        self.transition_function = TransitionFunction(
            environment, temporal_horizon, confidence_delta
        )
        self.occupancy_measure = OccupancyMeasure(environment, self.transition_function)

    @property
    def learning_rate(self):
        """Returns the learning rate of the algorithm, 1 / sqrt(temporal_horizon)."""
        return 1 / np.sqrt(self.temporal_horizon)

    def _initialize_counters(self):
        """Initializes the counters of visits for every triple (x,a,x') (sas_counter)
           or every pair (x,a) (sa_counter) in the MDP to 0.

        The existing counter dicts are filled in place (their identity is
        preserved); every entry is reset to ``[0, 0]``.

        Returns
        ----------
        Tuple[Dict[Tuple[int,int],List[int]], Dict[Tuple[int,int,int],List[int]]]
            Tuple of dict counters (x,a), (x,a,x')
        """
        for pair in self.environment.sa_ind:
            self._sa_counter[pair] = [0, 0]
        for triple in self.environment.sas_ind:
            self._sas_counter[triple] = [0, 0]

        return self._sa_counter, self._sas_counter

    def _update_counters(self, path: Dict[Tuple[int, int, int], int]):
        """Updates the counters of visits for every triple (x,a,x') (sas_counter)
           or every pair (x,a) (sa_counter) in the path.
           If a running (x,a) counter reaches at least twice the value it had at
           the start of the epoch, a new epoch starts and the empirical transition
           function and confidence sets are updated.

        Parameters
        ----------
        path : Dict[Tuple[int, int, int], int]
            Mapping from every visited triple (x,a,x') to its number of visits

        Returns
        ----------
        Tuple[Dict[Tuple[int,int],List[int]], Dict[Tuple[int,int,int],List[int]]]
            Tuple of dict counters (x,a), (x,a,x')
        """
        for (x, a, x_p), visits in path.items():
            self._sa_counter[(x, a)][1] += visits
            self._sas_counter[(x, a, x_p)][1] += visits

        # Doubling test: running // max(start, 1) >= 2  <=>  running >= 2 * max(start, 1).
        # `any` (unlike `max` over a list) is also well defined with no counters.
        doubled = any(
            running >= 2 * max(start, 1)
            for start, running in self._sa_counter.values()
        )
        if doubled:
            self._epoch += 1
            # Snapshot the running counts as the new epoch's starting counts.
            for (x, a, _), sas_entry in self._sas_counter.items():
                sa_entry = self._sa_counter[(x, a)]
                sa_entry[0] = sa_entry[1]
                sas_entry[0] = sas_entry[1]

            self.transition_function = self.transition_function.update_epoch(
                self._sas_counter, self._sa_counter
            )
            self.occupancy_measure.update_space(self.transition_function)

        return self._sa_counter, self._sas_counter

    def update(
        self,
        loss_mdp: Dict[Tuple[int, int], float],
        path: Dict[Tuple[int, int, int], int],
    ):
        """Updates the occupancy measure using algorithm described in Appendix A.1 of
        "Learning Adversarial Markov Decision Processes with Bandit Feedback and
        Unknown Transition", Chi Jin et al.

        Parameters
        ----------
        loss_mdp : Dict[Tuple[int,int],float]
            loss_mdp for every (x,a) state action pair in the environment.
            It is modified in place: the entries of visited pairs are re-weighted.

        path : Dict[Tuple[int, int, int], int]
            Mapping from every visited triple (x,a,x') to its number of visits
        """

        # Calculate upper occupancy bound for every state-action pair visited
        ubs = self.occupancy_measure.compute_upper_bound(path)

        # Re-weight the loss of each visited pair exactly once. A batched path
        # may contain the same (x, a) with several successors x'; iterating over
        # the triples would divide that pair's loss once per successor.
        gamma = self.learning_rate
        visited_pairs = dict.fromkeys((x, a) for x, a, _ in path)
        for pair in visited_pairs:
            loss_mdp[pair] = loss_mdp[pair] / (ubs[pair] + gamma)

        # Update counters
        self._sa_counter, self._sas_counter = self._update_counters(path)

        # Update occupancy measure
        self.occupancy_measure = self.occupancy_measure.update_occupancy_measure(
            loss_mdp, self.learning_rate
        )


class AlgorithmConstraints:
    """
    Class used to represent the Dual Algorithm for the constraints of the MDP.
    The algorithm is described in the paper
    "A Best-of-Both-Worlds Algorithm for Constrained MDPs with Long Term Constraints"

    It maintains one non-negative Lagrange multiplier per constraint and
    updates it with projected gradient descent.
    """

    def __init__(
        self,
        environment: Environment,
        temporal_horizon: int,
    ):
        self.temporal_horizon: int = temporal_horizon
        self._n_constraints = environment.n_constraints
        # One multiplier per constraint, starting from zero.
        self._lagrangian_vector = np.zeros(self._n_constraints, dtype=float)
        self.environment = environment

        if self.environment.n_constraints > 0:
            print("AlgorithmConstraints learning rate: ", self.learning_rate)

    @property
    def lagrangian_vector(self):
        """Returns the lagrangian vector of the dual algorithm."""
        return self._lagrangian_vector

    @property
    def learning_rate(self):
        """Returns the learning rate of the dual algorithm, 1 / sqrt(temporal_horizon)."""
        return 1 / np.sqrt(self.temporal_horizon)

    def update_lagrangian_vector(self, gradient: List[float]):
        """
        Updates the lagrangian vector using gradient descent and projection on the feasible set.

        The feasible set is the non-negative orthant, so the Euclidean
        projection is computed exactly by clipping every component at zero
        (no numerical solver is involved).

        Parameters
        ----------
        gradient : List[float]
            Gradient of the lagrangian vector

        Returns
        ----------
        numpy.ndarray or None
            Updated lagrangian vector, or None when there are no constraints
            (in which case nothing is updated)

        Raises
        ----------
        ValueError
            If the updated vector contains NaN or infinite entries
        """
        if self._n_constraints == 0:
            return None

        # Unconstrained gradient descent
        step = self.learning_rate * np.asarray(gradient, dtype=float)
        unconstrained = self.lagrangian_vector - step

        # Projection on the feasible set {x : x >= 0}: the closest point in
        # Euclidean norm is obtained component-wise as max(x_i, 0).
        projected = np.maximum(unconstrained, 0.0)
        if not np.all(np.isfinite(projected)):
            raise ValueError(
                "Constraints not satisfied, update of lagrangian produced "
                "non-finite values"
            )

        self._lagrangian_vector = projected
        return self._lagrangian_vector


class ConstrainedMDP:
    """Class used to represent the PD-DP algorithm.

    Episodes played with `round_play` are buffered on the instance and later
    consumed by `round_update`, which updates the dual (lagrangian) and the
    primal (occupancy measure) algorithms.
    """
    # NOTE: these three are assigned on the class itself by
    # `instantiate_algorithms`, so they are shared by all instances unless an
    # instance overrides them.
    algorithm_mdp: AlgorithmMDP
    algorithm_constraints: AlgorithmConstraints
    environment: Environment

    def __init__(self):
        # Parallel FIFO buffers: index i of each list belongs to the same episode.
        self.stored_path = []
        self.stored_rewards = []
        self.stored_pure_rewards = []
        self.stored_constraints = []

    @staticmethod
    def instantiate_algorithms(
        environment: Environment,
        path_sim="../config/sim_config.json",
    ):
        """
        Instantiates the algorithms using the parameters in the simulation config file.

        The created algorithms and the environment are stored as attributes of
        the `ConstrainedMDP` class.

        Parameters
        ----------
        environment : Environment
            Environment the algorithms interact with

        path_sim : str
            Path of the JSON simulation config; it must contain the keys
            "confidence_delta", "temporal_horizon", "n_batch" and "mean_update"
        """
        with open(path_sim, "r", encoding="utf-8") as f:
            variables = json.load(f)

        confidence_delta = variables["confidence_delta"]
        temporal_horizon = variables["temporal_horizon"]
        n_batch = variables["n_batch"]
        mean_update = variables["mean_update"]

        # Updates the temporal horizon if mean_update
        # is True in order to take into account the number of updates
        if mean_update:
            temporal_horizon = temporal_horizon // n_batch

        ConstrainedMDP.algorithm_mdp = AlgorithmMDP(
            temporal_horizon, confidence_delta, environment
        )
        ConstrainedMDP.algorithm_constraints = AlgorithmConstraints(
            environment, temporal_horizon
        )
        ConstrainedMDP.environment = environment

    @property
    def policy(self):
        """Returns the policy of the algorithm, derived from the current occupancy measure."""
        return self.algorithm_mdp.occupancy_measure.from_occupancy_measure_to_policy()

    def round_play(self, timestep: int):
        """Plays a round of the algorithm and stores the observed path,
        rewards and constraints for a later call to `round_update`.

        Parameters
        ----------
        timestep : int
            timestep for adaptive lr

        Returns
        ----------
        Tuple[Any, numpy.ndarray]
            Current policy and current lagrangian vector
        """
        # Observe rewards, path, constraints from environment
        path, rewards, constraints = self.environment.play_policy(self.policy, timestep)

        # Initialize and fill constraints matrix and rewards array
        sa_ind = self.environment.sa_ind
        n_pairs = len(sa_ind)
        constraints_matrix = np.zeros(
            (n_pairs, self.environment.n_constraints), dtype=float
        )
        rewards_array = np.zeros(n_pairs, dtype=float)
        pure_rewards = np.zeros(n_pairs, dtype=float)

        for x, a, _ in path:
            idx = sa_ind[(x, a)]
            # Constraints accumulate over repeated visits; rewards are overwritten.
            constraints_matrix[idx, :] += constraints[idx, :]
            pure_rewards[idx] = rewards[idx]
            rewards_array[idx] = rewards[idx]

        # Append to stored path, rewards, constraints
        self.stored_path.append(path)
        self.stored_rewards.append(rewards_array)
        self.stored_pure_rewards.append(pure_rewards)
        self.stored_constraints.append(constraints_matrix)

        return self.policy, self.algorithm_constraints.lagrangian_vector

    def round_update(self, mean_update: bool):
        """Performs the update of the algorithm using
            the stored path, rewards, constraints

        This is a generator: one update is performed, and one result yielded,
        per stored episode (or a single one when `mean_update` is True).
        Nothing is done and nothing is yielded when no episode is stored.

        Parameters
        ----------
        mean_update : bool
            whether to update the algorithm using the mean of the stored path, rewards, constraints

        Yields
        ----------
        Tuple[Any, numpy.ndarray, List[numpy.ndarray], numpy.ndarray]
            Policy after the update, lagrangian vector after the update,
            loss analysis terms (ordered as of environment.sa_ind) and the
            running (x,a) visit counts (ordered as of environment.sa_ind)
        """
        # Averaging an empty buffer would produce NaN losses and a spurious update.
        if not self.stored_rewards:
            return

        if mean_update:
            # Collapse the whole buffer into one averaged episode whose path
            # is the concatenation of all stored paths.
            self.stored_constraints = [np.mean(self.stored_constraints, axis=0)]
            self.stored_rewards = [np.mean(self.stored_rewards, axis=0)]
            self.stored_pure_rewards = [np.mean(self.stored_pure_rewards, axis=0)]
            self.stored_path = [
                [item for sublist in self.stored_path for item in sublist]
            ]

        # The count is fixed up front: episodes stored while the generator is
        # suspended are left for the next call.
        for _ in range(len(self.stored_rewards)):
            # Retrieve path, rewards, constraints in FIFO order
            rewards_array = self.stored_rewards.pop(0)
            pure_rewards = self.stored_pure_rewards.pop(0)
            constraints_matrix = self.stored_constraints.pop(0)
            path = self.stored_path.pop(0)

            # Build a path dict that counts the number of visits for every triple (x,a,x') in the path
            path_dict = defaultdict(int)
            for triple in path:
                path_dict[triple] += 1

            # Constraint term of the loss, evaluated with the multipliers
            # held before this round's dual update.
            constraint_penalty = np.dot(
                constraints_matrix, self.algorithm_constraints.lagrangian_vector
            )

            # Calculate loss
            loss = constraint_penalty - rewards_array
            loss_mdp = dict(zip(self.environment.sa_ind.keys(), list(loss)))

            # Loss analysis for debugging and dashboard visualization
            loss_analysis = [
                -rewards_array + pure_rewards,
                -pure_rewards,
                constraint_penalty,
            ]

            # Update lagrangian vector
            # NOTE(review): assumes pairs_occupancy_measure iterates in the same
            # order as the rows of constraints_matrix (environment.sa_ind) - confirm.
            gradient = -np.dot(
                list(
                    self.algorithm_mdp.occupancy_measure.pairs_occupancy_measure.values()
                ),
                constraints_matrix,
            )
            self.algorithm_constraints.update_lagrangian_vector(gradient)

            # Update occupancy measure
            self.algorithm_mdp.update(loss_mdp, path_dict)

            yield (
                self.policy,  # Recomputed from the updated occupancy measure
                self.algorithm_constraints.lagrangian_vector,
                loss_analysis,
                np.array(
                    [
                        self.algorithm_mdp._sa_counter[x][1]
                        for x in self.environment.sa_ind
                    ]
                ),
            )