import numpy as np

def softmax(x, temp=1.0, axis=None):
    """Return the temperature-scaled softmax of the scores in ``x``.

    Args:
        x: Array-like of scores.
        temp: Temperature the shifted scores are divided by before
            exponentiation. Smaller values concentrate the result on the
            largest score; larger values flatten it.
        axis: Axis along which each set of scores is normalised. ``None``
            (the default) treats the whole array as a single set, so the
            entire result sums to 1. Use e.g. ``axis=-1`` to normalise every
            row of a 2-D array independently.

    Returns:
        ``numpy.ndarray`` with the same shape as ``x`` whose entries sum to 1
        along ``axis``.
    """
    x = np.asarray(x)
    # Subtracting the maximum leaves the result unchanged mathematically but
    # makes the largest exponent 0, which avoids overflow in ``np.exp``.
    e_x = np.exp((x - np.max(x, axis=axis, keepdims=True)) / temp)
    return e_x / e_x.sum(axis=axis, keepdims=True)

class Soft_Value_Iteration:
    """Tabular value iteration whose backup averages Q under a softmax policy.

    Each sweep replaces the usual ``max`` over next-state actions with the
    expectation of ``Q[ns]`` under a softmax distribution over ``Q[ns]`` at
    ``temperature``. After training, ``pi`` holds that same softmax
    distribution for every state.

    Attributes:
        n_states: Number of states (rows of ``Q`` and ``pi``).
        n_actions: Number of actions (columns of ``Q`` and ``pi``).
        discount: Factor applied to the next-state value in the backup.
        temperature: Softmax temperature used for the backup and the policy.
        Q: ``(n_states, n_actions)`` array of action values.
        pi: ``(n_states, n_actions)`` array of action probabilities.
    """

    def __init__(self, n_states, n_actions, discount=0.9, temperature=0.01):
        """Store the problem dimensions and hyper-parameters, then reset."""
        self.n_states = n_states
        self.n_actions = n_actions
        self.discount = discount
        self.temperature = temperature

        self.reset()

    def softmax(self, x):
        """Return the softmax of ``x`` at ``self.temperature``.

        Inside this method the bare name ``softmax`` resolves to the
        module-level function, not to this method.
        """
        return softmax(x, temp=self.temperature)

    def reset(self):
        """Set ``Q`` to zeros and ``pi`` to the uniform policy."""
        self.Q = np.zeros([self.n_states, self.n_actions])
        self.pi = np.ones([self.n_states, self.n_actions]) / self.n_actions

    def _boltzmann_policy(self, q_values):
        """Return the row-wise softmax of a Q-table at ``self.temperature``."""
        # Shift every row by its own maximum so the largest exponent is 0 and
        # ``np.exp`` cannot overflow; the normalised result is unaffected.
        shifted = q_values - q_values.max(axis=1, keepdims=True)
        weights = np.exp(shifted / self.temperature)
        return weights / weights.sum(axis=1, keepdims=True)

    def _backup(self, transition_matrix, reward_fn, termination_fn):
        """Return the Q-table produced by one full sweep from ``self.Q``."""
        # The soft value of a state depends only on that state, so compute it
        # once per sweep instead of once per (s, a, ns) triple.
        policy = self._boltzmann_policy(self.Q)
        next_values = np.sum(policy * self.Q, axis=1)

        Q_new = np.zeros([self.n_states, self.n_actions])
        for s in range(self.n_states):
            for a in range(self.n_actions):
                for ns in range(self.n_states):
                    prob = transition_matrix[ns, s, a]
                    if prob == 0.0:
                        # Impossible transition: skip it so the callbacks are
                        # never queried for it and a non-finite reward cannot
                        # poison the sum through ``0 * inf``.
                        continue

                    reward = reward_fn(ns, s, a)
                    done = termination_fn(ns, s, a)

                    Q_new[s, a] += prob * (
                        reward + (1 - done) * self.discount * next_values[ns]
                    )

        return Q_new

    def train(self, transition_matrix, reward_fn, termination_fn, steps=100, stopping_condition=1e-8):
        """Run soft value iteration and store the resulting ``Q`` and ``pi``.

        Iteration continues from the current ``self.Q``; call ``reset()``
        first to start from zeros.

        Args:
            transition_matrix: Array-like indexed as
                ``[next_state, state, action]`` whose entries weight each
                transition in the backup (transition probabilities).
            reward_fn: Callable ``(next_state, state, action) -> reward``.
            termination_fn: Callable ``(next_state, state, action) -> done``;
                a truthy/1 result removes the next-state value from the
                backup for that transition.
            steps: Maximum number of sweeps.
            stopping_condition: Iteration stops early once the largest
                absolute change of any Q entry in a sweep falls below this.
        """
        transition_matrix = np.asarray(transition_matrix)

        for i in range(steps):
            Q_new = self._backup(transition_matrix, reward_fn, termination_fn)

            # Compare element-wise: summing over actions first would let
            # opposite-signed changes cancel and stop the iteration too early.
            diff = np.max(np.abs(Q_new - self.Q))

            self.Q = Q_new

            if diff < stopping_condition:
                print(f"Value iteration terminated after {i+1} steps with stopping criteria loss  {diff}<{stopping_condition} ")
                break
            if i == (steps - 1):
                print(f"Value iteration terminated after {i+1} steps with loss {diff}")

        # Extract the policy at ``self.temperature`` so it is the same
        # distribution the backup evaluated. Written in place so existing
        # references to ``self.pi`` observe the update.
        self.pi[:] = self._boltzmann_policy(self.Q)

    def update(self, *args, **kwargs):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def policy(self, state, **kwargs):
        """Return the action-probability vector stored for ``state``."""
        return self.pi[state]

    def get_policy(self, **kwargs):
        """Return the full ``(n_states, n_actions)`` policy table."""
        return self.pi

    def step(self, state, **kwargs):
        """Sample an action index for ``state`` from the policy."""
        probs = self.policy(state, **kwargs)
        return np.random.choice(self.n_actions, p=probs)



