#!/usr/bin/env python

# Python imports.
import numpy as np

# Other imports.
from simple_rl.tasks import FourRoomMDP


def create_distance_matrix(policy, actions, beta):
    """Build a state-to-state distance matrix from a policy's Q-values.

    For every state ``s`` the Q-values of all ``actions`` are turned into a
    Boltzmann distribution ``p(a) = exp(-beta * q(s, a)) / sum_a' exp(-beta *
    q(s, a'))`` and each action is given the distance ``-log p(a)``.  That
    distance is written into the matrix cell ``[index(s), index(s')]`` where
    ``s'`` is ``policy.transition_func(s, a)``.

    Args:
        policy: object providing ``get_states()``, ``get_q_value(s, a)`` and
            ``transition_func(s, a)``.  States must be hashable and every
            state returned by ``transition_func`` must appear in
            ``get_states()``.
        actions: iterable of actions; any number of actions is supported.
        beta: scalar multiplier applied to the Q-values inside the
            exponential.

    Returns:
        A ``(len(states), len(states))`` ``numpy.ndarray``.  Rows and columns
        follow the order of ``policy.get_states()``.  Cells that are never
        reached by a transition stay 0.

    Raises:
        KeyError: if ``transition_func`` returns a state that is not in
            ``policy.get_states()``.
    """
    states = policy.get_states()
    num_states = len(states)
    state_index = {state: idx for idx, state in enumerate(states)}

    matrix = np.zeros((num_states, num_states))

    # Materialise once so the actions can be traversed for both the Q-value
    # lookup and the transition lookup, even if a one-shot iterable is given.
    actions = list(actions)
    if not actions:
        return matrix

    for s in states:
        row = state_index[s]
        q = np.array([policy.get_q_value(s, a) for a in actions], dtype=float)

        # -log(softmax(-beta * q)) computed with the log-sum-exp shift so that
        # large |beta * q| cannot overflow/underflow the exponential.
        logits = -beta * q
        shift = np.max(logits)
        log_norm = shift + np.log(np.sum(np.exp(logits - shift)))
        distance = log_norm - logits

        for a, dist in zip(actions, distance):
            next_state = policy.transition_func(s, a)
            if next_state == s:
                # Several actions may leave the state unchanged; their
                # distances are accumulated on the diagonal.
                matrix[row, row] += dist
            else:
                # NOTE(review): if two actions reach the same neighbour the
                # later one overwrites the earlier one -- confirm intended.
                matrix[row, state_index[next_state]] = dist

    return matrix
    
