import numpy as np
import random
from scipy.special import logsumexp


class MoldovanGridworldMDP:
    """Slippery SIZE x SIZE gridworld with failure cells and a single goal cell.

    States are ``(row, col)`` tuples. Actions index ``MOVES``:
    0 = right (col + 1), 1 = up (row - 1), 2 = left (col - 1),
    3 = down (row + 1). With probability ``NOISE`` the chosen action is
    replaced by a uniformly random action (possibly the same one). A move that
    would leave the grid keeps the agent in place. Entering a cell of ``X`` or
    the goal ``G`` ends the episode.
    """

    # Failure cells: terminal, penalised by fail_cost (fail_cost2 for (2, 2)).
    X = {(1, 0), (2, 0), (2, 2)}
    # Start state used by policy_rollout.
    S = (4, 0)
    # Goal state: terminal, rewarded with goal_reward.
    G = (0, 1)
    # Probability that the chosen action is replaced by a uniformly random one.
    NOISE = 0.2
    # (d_row, d_col) displacement for each action index.
    MOVES = [(0, +1), (-1, 0), (0, -1), (+1, 0)]
    # Side length of the square grid.
    SIZE = 5

    def __init__(self, fail_cost, fail_cost2, step_cost=1, goal_reward=20.):
        """Store the reward parameters.

        Args:
            fail_cost: penalty for entering a failure cell other than (2, 2).
            fail_cost2: penalty for entering failure cell (2, 2).
            step_cost: cost charged on every transition.
            goal_reward: reward for entering the goal cell.
        """
        self.step_cost = step_cost
        self.fail_cost = fail_cost
        self.fail_cost2 = fail_cost2
        self.goal_reward = goal_reward

    def _check_action(self, action):
        """Raise ValueError unless ``action`` is a valid index into MOVES."""
        if not 0 <= action < len(self.MOVES):
            raise ValueError('bad action {}'.format(action))

    def _in_grid(self, state):
        """Return True if ``state`` lies inside the grid."""
        row, col = state
        return 0 <= row < self.SIZE and 0 <= col < self.SIZE

    def is_terminal_state(self, state):
        """Return True for failure cells and the goal cell."""
        return state in self.X or state == self.G

    def next_state_distribution(self, state, action):
        """Return the list of ``(next_state, prob)`` pairs for ``(state, action)``.

        Every move receives ``NOISE / 4`` probability and the intended move an
        extra ``1 - NOISE``. In-grid outcomes come first in MOVES order; all
        off-grid moves are redirected to ``state`` and merged into one final
        ``(state, total_prob)`` entry.

        Raises:
            ValueError: if ``action`` is not a valid action index.
        """
        # Validate up front: a negative index would otherwise silently alias
        # another action through Python's negative list indexing.
        self._check_action(action)
        row, col = state
        probs = [self.NOISE / len(self.MOVES)] * len(self.MOVES)
        probs[action] += 1 - self.NOISE
        outcomes, stay_probs = [], []
        for (d_row, d_col), p in zip(self.MOVES, probs):
            candidate = (row + d_row, col + d_col)
            if self._in_grid(candidate):
                outcomes.append((candidate, p))
            else:
                stay_probs.append(p)
        if stay_probs:
            outcomes.append((state, sum(stay_probs)))
        return outcomes

    def reward(self, state, action, next_state):
        """Return the reward for a transition, which depends only on ``next_state``.

        Every transition costs ``step_cost``. Entering the goal adds
        ``goal_reward``. Entering a failure cell subtracts ``fail_cost2`` for
        (2, 2) and ``fail_cost`` otherwise.
        """
        if next_state == self.G:
            return self.goal_reward - self.step_cost
        if next_state in self.X:
            penalty = self.fail_cost2 if next_state == (2, 2) else self.fail_cost
            return -penalty - self.step_cost
        return -self.step_cost

    def transition(self, state, action):
        """Sample the next state for taking ``action`` in ``state``.

        Uses the module-level ``random`` generator. Off-grid moves return
        ``state`` unchanged.

        Raises:
            ValueError: if ``action`` is not a valid action index. The check
                runs before the noise step, so an invalid action is always
                rejected rather than occasionally masked by a random action.
        """
        self._check_action(action)
        if random.random() < self.NOISE:
            action = random.randint(0, len(self.MOVES) - 1)
        d_row, d_col = self.MOVES[action]
        candidate = (state[0] + d_row, state[1] + d_col)
        return candidate if self._in_grid(candidate) else state

    def policy_rollout(self, policy, horizon=35):
        """Run one episode from ``S`` following ``policy`` and return its total reward.

        Args:
            policy: array indexed as ``policy[row, col]`` giving an action index.
            horizon: maximum number of steps before the episode is cut off.
        """
        state = self.S
        total_reward = 0.
        for _ in range(horizon):
            action = policy[state[0], state[1]]
            next_state = self.transition(state, action)
            total_reward += self.reward(state, action, next_state)
            state = next_state
            if self.is_terminal_state(state):
                break
        return total_reward

            
class ValueIteration:
    """Tabular Q-value iteration / policy evaluation for a gridworld task.

    ``task`` must provide ``next_state_distribution(s, a)`` (iterable of
    ``(next_state, prob)``), ``reward(s, a, next_state)`` and
    ``is_terminal_state(s)``; states are ``(row, col)`` pairs indexing ``Q``.

    With ``beta == 0`` each backup is the ordinary expectation over next
    states. Otherwise it is the exponential-utility certainty equivalent
    ``(1 / beta) * log(sum_i p_i * exp(beta * backup_i))``, which is >= the
    expectation for ``beta > 0`` and <= it for ``beta < 0``.
    """

    def __init__(self, task, beta, gamma=1.):
        self.task = task
        self.gamma = gamma
        self.beta = beta
        # Q[row, col, action] for a 5x5 grid with 4 actions.
        self.Q = np.zeros((5, 5, 4))

    def _iterate(self, policy=None):
        """Apply one synchronous backup to every (state, action) pair.

        Terminal next states contribute only their immediate reward. Other
        next states bootstrap from the current Q: with the greedy action when
        ``policy`` is None, or with ``policy[row, col]`` otherwise.

        Returns:
            The sup-norm change ``max |Q_new - Q_old|``.
        """
        risk_neutral = self.beta == 0  # loop-invariant; decide once per sweep
        n_rows, n_cols, n_actions = self.Q.shape
        new_Q = np.zeros_like(self.Q)
        for y in range(n_rows):
            for x in range(n_cols):
                s = (y, x)
                for a in range(n_actions):
                    values, probs = [], []
                    for s1, p in self.task.next_state_distribution(s, a):
                        r = self.task.reward(s, a, s1)
                        if self.task.is_terminal_state(s1):
                            backup = r
                        else:
                            if policy is None:
                                a1 = np.argmax(self.Q[s1[0], s1[1], :])
                            else:
                                a1 = policy[s1[0], s1[1]]
                            backup = r + self.gamma * self.Q[s1[0], s1[1], a1]
                        # In the risk-sensitive case logsumexp expects beta * backup.
                        values.append(backup if risk_neutral else backup * self.beta)
                        probs.append(p)
                    if risk_neutral:
                        new_Q[y, x, a] = np.sum(np.multiply(values, probs))
                    else:
                        new_Q[y, x, a] = (1 / self.beta) * logsumexp(values, b=probs)
        error = np.max(np.abs(new_Q - self.Q))
        self.Q = new_Q
        return error

    def _extract_policy(self):
        """Return the greedy policy ``argmax_a Q[row, col, a]`` as an int array."""
        # np.argmax returns the first maximal index, so tie-breaking is unchanged.
        return np.argmax(self.Q, axis=2).astype(int)

    def solve(self, policy=None, tol=1e-12, max_iter=9999):
        """Iterate backups until the sup-norm change drops below ``tol``.

        Args:
            policy: optional int array ``policy[row, col]``. When given, the
                backups evaluate that policy instead of maximising over next
                actions.
            tol: convergence threshold on ``max |Q_new - Q_old|``.
            max_iter: maximum number of sweeps.

        Returns:
            ``(Q, pi)``, where ``pi`` is greedy with respect to ``Q``. If ``tol``
            is not reached within ``max_iter`` sweeps, a ``RuntimeWarning`` is
            issued and the current (unconverged) estimates are returned.
        """
        import warnings

        error = float('inf')
        for t in range(max_iter):
            error = self._iterate(policy)
            if t % 10 == 0:
                print('iter {} error {}'.format(t, error))
            if error < tol:
                print('converged at iter {} with error {}'.format(t, error))
                return self.Q, self._extract_policy()
        # Previously control fell off the end here and returned None, breaking
        # callers that unpack the result; report the non-convergence instead.
        warnings.warn(
            'value iteration did not converge after {} iters (last error {})'.format(
                max_iter, error),
            RuntimeWarning,
        )
        return self.Q, self._extract_policy()
        
