import numpy as np
from tqdm import tqdm

def softmax(x, axis=None):
    """Return the numerically stable softmax of ``x``.

    Args:
        x: array-like of scores.
        axis: axis along which to normalise. ``None`` (the default) treats
            the whole array as a single set of scores, which is the
            original behaviour; pass e.g. ``axis=-1`` to normalise each row
            of a 2-D array independently.

    Returns:
        An array with the same shape as ``x`` whose entries along ``axis``
        (or overall, when ``axis`` is ``None``) sum to 1.
    """
    x = np.asarray(x)
    # Subtracting the max leaves the result unchanged (it cancels in the
    # ratio) but keeps exp() from overflowing on large scores.
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

class Value_Iteration:
    """Tabular value iteration with a Boltzmann (softmax) policy read-out.

    ``train`` slices ``transition_matrix`` as ``[:, state, action]`` and
    weights that slice against a per-successor vector, so the expected layout
    is ``[next_state, state, action]``. ``reward_map`` is added to the
    discounted successor values, so it must broadcast against
    ``(n_states,)``.

    Attributes:
        V: state values, shape ``(n_states,)``.
        Q: state-action values, shape ``(n_states, n_actions)``; filled in by
            ``train`` from the final ``V``.
        pi: action probabilities per state, shape ``(n_states, n_actions)``.
    """

    def __init__(self, n_states, n_actions, discount=0.9, temperature=0.01):
        """Create the planner and initialise ``V``, ``Q`` and ``pi``.

        Args:
            n_states: number of states in the tabular MDP.
            n_actions: number of actions available in every state.
            discount: factor applied to successor-state values in the backup.
            temperature: softmax temperature used to turn ``Q`` into ``pi``;
                smaller values give a greedier policy and ``0`` gives the
                exact greedy policy (ties split evenly).
        """
        self.n_states = n_states
        self.n_actions = n_actions
        self.discount = discount
        self.temperature = temperature

        self.reset()

    def reset(self):
        """Zero ``V`` and ``Q`` and set ``pi`` to the uniform policy."""
        self.V = np.zeros(self.n_states, dtype=np.float32)
        self.Q = np.zeros([self.n_states, self.n_actions], dtype=np.float32)
        self.pi = np.ones([self.n_states, self.n_actions], dtype=np.float32) / self.n_actions

    def train(self, transition_matrix, reward_map, steps=100, stopping_condition=1e-8):
        """Solve for ``V`` by value iteration, then derive ``Q`` and ``pi``.

        Args:
            transition_matrix: array indexed ``[next_state, state, action]``.
            reward_map: rewards broadcastable to ``(n_states,)``, aligned with
                the successor-state axis.
            steps: maximum number of full sweeps over all states.
            stopping_condition: stop early once the largest change to any
                ``V[s]`` within one sweep is below this threshold.
        """
        self.V = np.zeros(self.n_states, dtype=np.float32)

        for i in tqdm(range(steps)):
            diff = self._sweep(transition_matrix, reward_map)

            if diff < stopping_condition:
                print(f"Value iteration terminated after {i+1} steps with stopping criteria loss  {diff}<{stopping_condition} ")
                break

            if i == (steps - 1):
                print(f"Value iteration terminated after {i+1} steps with loss {diff}")

        self._extract_policy(transition_matrix, reward_map)

    def _state_q_values(self, transition_matrix, reward_map, s):
        """Return the one-step lookahead value of every action in state ``s``."""
        # Reward plus discounted value of each successor, read from the current
        # (possibly partially updated) V.
        backed_up = reward_map + self.discount * self.V
        # q[a] = sum_t backed_up[t] * transition_matrix[t, s, a]
        return backed_up @ transition_matrix[:, s, :]

    def _sweep(self, transition_matrix, reward_map):
        """Apply one in-place Bellman optimality backup to every state.

        States later in the sweep already see the values updated earlier in
        the same sweep. Returns the largest absolute change to any ``V[s]``.
        """
        diff = 0.0
        for s in range(self.n_states):
            old = self.V[s]
            self.V[s] = np.max(self._state_q_values(transition_matrix, reward_map, s))
            diff = max(diff, np.abs(self.V[s] - old))
        return diff

    def _extract_policy(self, transition_matrix, reward_map):
        """Fill ``Q`` from the current ``V`` and set ``pi`` to its Boltzmann policy."""
        backed_up = reward_map + self.discount * self.V
        # Q[s, a] = sum_t backed_up[t] * transition_matrix[t, s, a]
        q = np.tensordot(backed_up, transition_matrix, axes=(0, 0))
        self.Q[:] = q

        # Shift by the per-state max so exp() cannot overflow; the shift
        # cancels when each row is normalised.
        q_max = q.max(axis=1, keepdims=True)
        if self.temperature == 0:
            # Zero temperature is the greedy limit of the softmax; dividing by
            # it directly would produce NaN probabilities.
            weights = (q == q_max).astype(np.float64)
        else:
            weights = np.exp((q - q_max) / self.temperature)
        # Write in place so arrays previously returned by get_policy() update.
        self.pi[:] = weights / weights.sum(axis=1, keepdims=True)

    def update(self, *args, **kwargs):
        """Not implemented; this planner is fit with ``train`` instead."""
        raise NotImplementedError

    def policy(self, state, **kwargs):
        """Return the action probabilities ``pi[state]``; extra kwargs are ignored."""
        return self.pi[state]

    def get_policy(self, **kwargs):
        """Return the full ``(n_states, n_actions)`` policy table; kwargs are ignored."""
        return self.pi

    def step(self, state, **kwargs):
        """Sample an action index for ``state`` from ``pi`` using NumPy's global RNG."""
        probs = self.policy(state, **kwargs)
        return np.random.choice(self.n_actions, p=probs)



