import math
import numpy as np


class DiscountALAgent(object):
    """
    The super class for discounted apprenticeship learning agent.

    Holds a tabular policy and a tabular reward function, both numpy arrays of
    shape ``[num_state, num_action]``. It also provides finite-horizon planning
    routines that recompute the policy from the current reward and a
    transition model.

    Args:
        num_state: number of states (M).
        num_action: number of actions (N).
        gamma: discount factor. The planning horizon is ``int(1 / (1 - gamma))``,
            so ``gamma`` must be < 1 (``gamma == 1`` raises ZeroDivisionError).
        max_num_iterations: stored on the instance; not used by this class
            (presumably consumed by subclasses -- confirm).
    """

    def __init__(self, num_state: int, num_action: int, gamma: float, max_num_iterations: int):
        self.n_state = num_state
        self.n_action = num_action
        self.gamma = gamma
        # Random initial stochastic policy: each row is normalized to sum to 1.
        tmp = np.random.random(size=(num_state, num_action))
        tmp = tmp / np.sum(tmp, axis=1, keepdims=True)
        self._policy = tmp
        # Random initial reward table, entries drawn uniformly from [-1, 1).
        self._reward_function = np.random.uniform(low=-1.0, high=1.0, size=(num_state, num_action))
        self.max_num_iterations = max_num_iterations

    @property
    def get_policy(self):
        """A copy of the current policy table, shape [num_state, num_action]."""
        return self._policy.copy()

    @property
    def get_reward_function(self):
        """A copy of the current reward table, shape [num_state, num_action]."""
        return self._reward_function.copy()

    def _policy_iteration(self, transition_probability: np.ndarray):
        raise NotImplementedError

    def _horizon(self) -> int:
        """Return the finite planning horizon ``int(1 / (1 - gamma))``."""
        return int(1.0 / (1.0 - self.gamma))

    def _bellman_backup(self, reward_matrix: np.ndarray, transition_probability: np.ndarray,
                        state_value: np.ndarray) -> np.ndarray:
        """Return ``R(s, a) + gamma * sum_s' P[s, a, s'] * V(s')`` as an [M, N] array.

        ``transition_probability`` is expected to have shape
        (num_state, num_action, num_state), read as P[s, a, s'] (inferred from
        the broadcast against the length-M value vector -- confirm with callers).
        """
        next_value = np.reshape(state_value, [1, 1, self.n_state])
        return reward_matrix + self.gamma * np.sum(transition_probability * next_value, axis=2)

    def _value_iteration(self, transition_probability: np.ndarray):
        """Finite-horizon (hard-max) value iteration on the current reward.

        Args:
            transition_probability: transition tensor P[s, a, s'].
        Returns:
            Deterministic greedy policy w.r.t. the final Q, a float32 array of
            shape [num_state, num_action].
        """
        reward_matrix = self.get_reward_function  # property already returns a copy
        q_function = np.zeros(shape=(self.n_state, self.n_action))
        for _ in range(self._horizon()):
            state_value = np.max(q_function, axis=1)
            q_function = self._bellman_backup(reward_matrix, transition_probability, state_value)
        return self._generate_stationary_greedy_policy(q_function)

    def softmax(self, x: np.ndarray):
        """Row-wise softmax over axis 1 of a 2-D array.

        The row maximum is subtracted before exponentiating for numerical
        stability; the result is mathematically unchanged.
        """
        exp_x = np.exp(x - np.max(x, axis=1, keepdims=True))
        softmax_x = exp_x / np.sum(exp_x, axis=1, keepdims=True)
        return softmax_x

    def _soft_actor_value_iteration(self, transition_probability: np.ndarray):
        """Finite-horizon soft (entropy-regularized) value iteration.

        At each step the acting policy is ``softmax(Q)`` and the state value is
        ``V(s) = sum_a pi(a|s) * (Q(s, a) - log pi(a|s))``.

        Args:
            transition_probability: transition tensor P[s, a, s'].
        Returns:
            Deterministic greedy policy w.r.t. the final soft Q (not the softmax
            policy), a float32 array of shape [num_state, num_action].
        """
        reward_matrix = self.get_reward_function  # property already returns a copy
        q = np.zeros(shape=(self.n_state, self.n_action))
        for _ in range(self._horizon()):
            # Use the local softmax policy of the current Q; the instance has no
            # `policy` attribute (its policy lives in `_policy`).
            policy = self.softmax(q)
            # Clip before the log so that zero-probability actions do not produce -inf.
            state_value = np.sum(policy * (q - np.log(np.maximum(policy, 1e-8))), axis=1)
            q = self._bellman_backup(reward_matrix, transition_probability, state_value)
        return self._generate_stationary_greedy_policy(q)

    def train_soft_policy_step(self, transition_probability: np.ndarray):
        """Replace the policy with the soft-value-iteration greedy policy."""
        self._policy = self._soft_actor_value_iteration(transition_probability)

    def train_policy_step(self, transition_probability: np.ndarray):
        """Replace the policy with the hard-value-iteration greedy policy."""
        self._policy = self._value_iteration(transition_probability)

    def _generate_stationary_greedy_policy(self, q_function: np.ndarray):
        """
        Args:
            q_function: Q function, a numpy array with shape [num_state, num_action].
        Returns:
            stat_greedy_policy: the policy acts greedily w.r.t Q_function, a numpy array with shape
            [num_state, num_action] and dtype float32. Each row is one-hot; ties go
            to the lowest-index action (np.argmax semantics).
        """
        stat_greedy_policy = np.zeros(shape=(self.n_state, self.n_action), dtype=np.float32)
        greedy_actions = np.argmax(q_function, axis=1)
        stat_greedy_policy[np.arange(self.n_state), greedy_actions] = 1.0
        return stat_greedy_policy

    # NOTE: a reward-update step (e.g. `train_reward_step`) is presumably
    # provided by subclasses; it is not defined here.
