#!/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy as np
from copy import deepcopy
from pomdp import constrained_pomdp


# In[2]:


def probsum(iterable):
    """Return the sum of the values in ``iterable``, added left to right.

    The accumulator starts at the integer 0, so an empty iterable yields 0.

    The addition is done with an explicit loop on purpose: the result is
    exactly the naive left-to-right sum. NOTE(review): the builtin ``sum``
    may use a different floating-point summation scheme on newer
    interpreters, which could change results in the last bit; confirm
    before swapping it in.
    """
    total = 0
    for value in iterable:
        total = total + value

    return total


# In[3]:


def gridworld(m,n,horizon,p,rew):
    """Assemble the ingredients of an m-by-n grid world and pass them to
    ``constrained_pomdp``.

    Cells are ``(row, col)`` tuples; a cell's flat index is ``n*row + col``,
    which is the index used for every vector and matrix built here.

    Parameters
    ----------
    m, n : int
        Number of rows and columns of the grid.
    horizon : int
        Number of time steps. States, actions, rewards and observation
        probabilities are keyed ``1..horizon``; transitions are keyed
        ``1..horizon-1``. A horizon below 1 raises ``KeyError`` because
        time step 1 is used as the template for all steps.
    p : float
        Probability of moving one cell in the commanded direction.
        Each of the two perpendicular directions gets ``(1-p)/2.0``; the
        direction opposite to the command is never taken.
    rew : indexable
        ``rew[i]`` is the reward of the cell with flat index ``i`` for
        ``i in range(m*n)``. The same reward is used for every action and
        every time step.

    Returns
    -------
    Whatever ``constrained_pomdp(...)`` returns. The constraint containers
    handed to it are empty.
    """
    shape = (m, n)
    num_cells = m*n
    states = {t: [(i, j) for i in range(shape[0]) for j in range(shape[1])]
              for t in range(1, horizon+1)}
    actions = {t: ['U', 'D', 'L', 'R'] for t in range(1, horizon+1)}

    # (row delta, col delta, action commanding this move, opposite action).
    # The order D, U, L, R is kept so probabilities are summed in a fixed
    # order and the floating-point remainder below is reproducible.
    moves = [
        (1, 0, 'D', 'U'),
        (-1, 0, 'U', 'D'),
        (0, -1, 'L', 'R'),
        (0, 1, 'R', 'L'),
    ]

    # One (num_cells x num_cells) matrix per action:
    # trans_step[a][from_index][to_index] = probability.
    trans_step = {a: np.zeros((num_cells, num_cells)) for a in actions[1]}
    for s in states[1]:
        x, y = s[0], s[1]
        for a in actions[1]:
            successors = []
            probs = []
            for dx, dy, toward, opposite in moves:
                nx, ny = x+dx, y+dy
                # Skip the move opposite to the command and any move that
                # would leave the grid.
                if a != opposite and 0 <= nx < m and 0 <= ny < n:
                    successors.append((nx, ny))
                    probs.append(p if a == toward else (1-p)/2.0)

            prob_sum = probsum(probs)

            # Probability mass of moves blocked by a wall stays on the
            # current cell.
            # NOTE(review): this is an exact float comparison, so rounding
            # in prob_sum can add a tiny self-loop or leave a row summing
            # to marginally more than 1.
            if prob_sum < 1:
                successors.append(s)
                probs.append(1-prob_sum)

            for (nx, ny), prob in zip(successors, probs):
                trans_step[a][n*x+y][n*nx+ny] = prob

    # The same per-action matrices are used at every step; each step gets
    # its own independent copy.
    transitions = {t: deepcopy(trans_step) for t in range(1, horizon)}

    # All initial probability mass sits on flat index 0, i.e. cell (0, 0).
    initial_dist = np.zeros(num_cells)
    initial_dist[0] = 1.0

    # The observation set is a copy of the state set.
    observations = deepcopy(states)

    constraints = {}
    constraint_val = {}
    constraint_indices = []

    # Fill the reward vector once, then give every (time step, action)
    # pair its own copy.
    cell_reward = np.zeros(num_cells, dtype=float)
    for i in range(num_cells):
        cell_reward[i] = rew[i]
    rewards = {t: {a: cell_reward.copy() for a in actions[t]} for t in states}

    # obs_prob[cell_index][neighbour_index]: uniform over the cell's
    # in-bounds 3x3 neighbourhood, the cell itself included.
    obs_prob = np.zeros((num_cells, num_cells))
    for s in states[1]:
        x, y = s[0], s[1]
        sur = [(x+dx, y+dy)
               for dx in [-1, 0, 1]
               for dy in [-1, 0, 1]
               if (0 <= x+dx < m) and (0 <= y+dy < n)]
        # 1.0 forces true division regardless of the interpreter's
        # integer-division rules (the original 1/count truncates to 0
        # under Python 2).
        weight = 1.0/len(sur)
        for nx, ny in sur:
            obs_prob[n*x+y][n*nx+ny] = weight

    observation_probability = {t: deepcopy(obs_prob) for t in states}

    return constrained_pomdp(initial_dist,states,actions,transitions,observations,observation_probability,rewards,constraints,constraint_val,constraint_indices,horizon)







