import sys
import numpy as np
import matplotlib
import matplotlib.pylab as plt
import mpl_toolkits.mplot3d.axes3d as axes3d
import matplotlib.patches as patches
from matplotlib import cm
import random

# Module-level side effect: with an infinite threshold NumPy prints arrays
# in full instead of summarizing large ones with "...".
np.set_printoptions(threshold=np.inf)

class GridWorld:
    ''' Grid-world MDP built from a plain-text specification.

        The first line of the specification is "numRows,numCols"; every
        following line describes one row of the grid, one character per
        cell: 'X' is a wall (stored as -1), 'D' a dead zone (stored as -2),
        '.' a free cell, 'S' a possible start cell and 'G' the goal (the
        last three are stored as 0).

        Coordinates named x are row indices and coordinates named y are
        column indices; a state index is y + x * numCols. Actions are the
        strings 'up', 'right', 'down' and 'left'.'''
    # Raw text of the specification.
    strMDP  = ''
    # Grid size and number of cells (walls included); -1 until parsed.
    numRows = -1
    numCols = -1
    numStates = -1
    # Cell codes, indexed as matrixMDP[row][col].
    matrixMDP = None
    # State adjacency matrix, filled lazily by _fillAdjacencyMatrix().
    adjMatrix = None
    # Optional reward vector installed through defineRewardFunction().
    rewardFunction = None
    # When True, non-goal cells are rewarded with -1 instead of 0.
    useNegativeRewards = False
    # Current position of the agent (row, column).
    currX = 0
    currY = 0
    # Position the agent is put back to by reset() (row, column).
    startX = 0
    startY = 0
    # Position of the goal cell (row, column).
    goalX = 0
    goalY = 0

    def __init__(self, path=None, strin=None, useNegativeRewards=False, stochastic=True, random_reset=False):
        ''' Builds the grid world from a file or from an in-memory string.

            path: file holding the MDP specification; it takes precedence
                over strin when both are given.
            strin: the MDP specification itself.
            useNegativeRewards: if True, non-goal cells are rewarded with -1
                instead of 0.
            stochastic: if True, a move may slip to either side of the
                requested direction.
            random_reset: if True, the start cell is drawn at random among
                the cells marked 'S', at parse time and on every reset().

            Raises SystemExit with a non-zero status when neither path nor
            strin is provided.'''
        if path is not None:
            self._readFile(path)
        elif strin is not None:
            self.strMDP = strin
        else:
            # Handing the message to sys.exit() prints it on stderr and
            # produces a non-zero exit status, so a missing specification
            # is reported as a failure instead of a successful exit.
            sys.exit('You are supposed to provide an MDP specification as input!')
        # _parseString() reads random_reset, so it has to be set beforehand.
        self.random_reset = random_reset
        self.start_states = None
        self._parseString()
        self.currX = self.startX
        self.currY = self.startY
        self.numStates = self.numRows * self.numCols
        self.useNegativeRewards = useNegativeRewards
        self.stochastic = stochastic

    def getAvailableStates(self):
        ''' Lists, in row-major order, the index of every cell that is
            neither a wall (-1) nor a dead zone (-2).'''
        return [
            self.getStateIndex(r, c)
            for r in range(self.numRows)
            for c in range(self.numCols)
            if self.matrixMDP[r][c] not in (-1, -2)
        ]
    
    def getAvailableActions(self, s):
        ''' Lists the actions whose target cell, seen from state s, lies
            inside the grid and is neither a wall (-1) nor a dead zone (-2).
            The order follows getActionSet().'''
        offsets = {
            'up': (-1, 0),
            'right': (0, 1),
            'down': (1, 0),
            'left': (0, -1)
        }
        row, col = self.getStateXY(s)

        def enterable(r, c):
            # Bounds first, so the grid is never indexed out of range.
            if not (0 <= r < self.numRows and 0 <= c < self.numCols):
                return False
            return self.matrixMDP[r][c] not in (-1, -2)

        return [
            name for name in self.getActionSet()
            if enterable(row + offsets[name][0], col + offsets[name][1])
        ]
    
    def _transition_probabilities(self, s, a):
        ''' Returns the unnormalized outcome weights of action a in state s
            as a dictionary {(row, col): weight}.

            Walls, dead zones, the goal and moves whose target is outside
            the grid or a wall yield {(row, col): 0.0} for the current cell.
            Otherwise the target cell weighs 0.6 and, when self.stochastic
            is set, each of the two sideways cells weighs 0.20; the weight
            of a sideways cell that cannot be entered goes to the current
            cell instead.'''
        row, col = self.getStateXY(s)
        here = (row, col)
        grid = self.matrixMDP

        # No outgoing weight from walls, dead zones or the goal.
        if grid[row][col] == -1 or grid[row][col] == -2:
            return {here: 0.0}
        if row == self.goalX and col == self.goalY:
            return {here: 0.0}

        forward = {
            'right': (0, 1),
            'left': (0, -1),
            'up': (-1, 0),
            'down': (1, 0)
        }
        vertical = [(-1, 0), (1, 0)]    # one row up, one row down
        horizontal = [(0, -1), (0, 1)]  # one column left, one column right
        sideways = {
            'right': vertical,
            'left': vertical,
            'up': horizontal,
            'down': horizontal
        }

        def passable(r, c):
            # Inside the grid and not a wall; dead zones can be entered.
            return 0 <= r < self.numRows and 0 <= c < self.numCols and grid[r][c] != -1

        d_row, d_col = forward[a]
        target = (row + d_row, col + d_col)
        if not passable(target[0], target[1]):
            return {here: 0.0}

        outcome = {target: 0.6}
        if self.stochastic:
            for s_row, s_col in sideways[a]:
                cell = (row + s_row, col + s_col)
                # A blocked slip keeps the agent where it is.
                landing = cell if passable(cell[0], cell[1]) else here
                outcome[landing] = outcome.get(landing, 0) + 0.20
        return outcome
    
    def transition_probabilities(self, s, a):
        ''' Returns the distribution over next states for action a taken in
            state s, as a dictionary {state index: probability}.

            The weights produced by _transition_probabilities() are divided
            by their sum, so the probabilities add up to 1. When every
            weight is zero (for instance in the goal state) the zeros are
            returned unchanged.

            Raises ValueError if a is not among getAvailableActions(s).'''
        if a not in self.getAvailableActions(s):
            raise ValueError('Action not available in state s')

        weights = self._transition_probabilities(s, a)
        total = sum(weights.values())
        # Normalize only when there is something to normalize: adding a
        # small epsilon to the denominator would bias every probability and
        # keep the distribution from summing to one.
        if total > 0:
            weights = {cell: weight / total for cell, weight in weights.items()}
        return {self.getStateIndex(row, col): prob for (row, col), prob in weights.items()}
    
    def _getNextState(self, action):
        ''' Samples where the given action leads from the agent's current
            position and returns that cell as (row, col). The agent itself
            is not moved.'''
        origin = self.getStateIndex(self.currX, self.currY)
        outcomes = self.transition_probabilities(origin, action)
        successors = list(outcomes)
        weights = [outcomes[successor] for successor in successors]
        picked, = random.choices(successors, weights=weights, k=1)
        return self.getStateXY(picked)  # row, col
     
    def getRewardXY(self, currX, currY, action):
        ''' Returns the reward attached to the cell (currX, currY): 1 for
            the goal, otherwise -1 when useNegativeRewards is set and 0 when
            it is not. The action argument is not used.

            Raises RuntimeError when the cell is a wall.'''
        if self.matrixMDP[currX][currY] == -1:
            # A bare "raise" outside an except block also ends up as a
            # RuntimeError, but with a message unrelated to the problem.
            raise RuntimeError(
                'No reward is defined for the wall cell ({}, {})'.format(currX, currY))
        if currX == self.goalX and currY == self.goalY:
            return 1
        return -1 if self.useNegativeRewards else 0
        
    def isTerminal(self):
        ''' Tells whether the agent currently stands on the goal or on a
            dead zone (-2); both end the episode.'''
        at_goal = (self.currX, self.currY) == (self.goalX, self.goalY)
        return at_goal or bool(self.matrixMDP[self.currX][self.currY] == -2)

    def step(self, action):
        ''' Applies the action and returns (observation, reward, done).

            observation is the index of the cell the agent ends up in,
            reward is the reward of that cell and done tells whether it is
            terminal. When no reward function has been defined and the agent
            already stands on a terminal cell, the agent is left where it
            is; otherwise its next cell is sampled with _getNextState().'''
        # "is None" on purpose: rewardFunction may hold a NumPy vector, for
        # which "== None" is evaluated element-wise and cannot be used as a
        # condition.
        stuck = self.rewardFunction is None and self.isTerminal()
        if not stuck:
            self.currX, self.currY = self._getNextState(action)
        observation = self.getStateIndex(self.currX, self.currY)
        reward = self.getRewardXY(self.currX, self.currY, action)
        done = self.isTerminal()
        return observation, reward, done

    def getStateIndex(self, x, y):
        ''' Maps the cell at row x and column y to its row-major index.'''
        return y + x * self.numCols
    
    def getStateXY(self, idx):
        ''' Given the index that uniquely identifies each state this method
            returns its equivalent coordinate (x, y), x being the row and y
            the column, both as plain ints.'''
        # divmod keeps the computation in integer arithmetic; going through
        # a float division loses precision for very large indices.
        x, y = divmod(idx, self.numCols)
        return int(x), int(y)

    def isXYWall(self, x, y):
        ''' Tells whether the cell at row x and column y is a wall (-1).'''
        return bool(self.matrixMDP[x][y] == -1)
    
    def isStateWall(self, idx):
        ''' Tells whether the state identified by idx is a wall.'''
        return self.isXYWall(*self.getStateXY(idx))
    
    def getCurrentState(self):
        ''' Returns the index of the cell the agent currently occupies.'''
        return self.getStateIndex(self.currX, self.currY)

    def getGoalState(self):
        ''' Returns the index of the goal cell.'''
        return self.getStateIndex(self.goalX, self.goalY)

    def isGoalState(self, idx):
        ''' Tells whether the state identified by idx is the goal.'''
        return self.getStateXY(idx) == (self.goalX, self.goalY)
    
    def getStartState(self):
        ''' Returns the index of the start cell (startX, startY).'''
        return self.getStateIndex(self.startX, self.startY)
 
    def _fillAdjacencyMatrix(self):
        ''' Builds self.adjMatrix and self.idxMatrix.

            idxMatrix[i][j] holds the row-major index of cell (i, j).
            adjMatrix is a symmetric numStates x numStates 0/1 matrix with a
            1 wherever two cells that are not walls (-1) touch horizontally
            or vertically. Dead zones (-2) count as ordinary cells here.'''
        self.adjMatrix = np.zeros((self.numStates, self.numStates), dtype=int)
        self.idxMatrix = np.arange(
            self.numRows * self.numCols, dtype=int).reshape(self.numRows, self.numCols)

        # A boolean mask gives constant-time wall checks; searching a list
        # of blocked cells for every cell and neighbour made the fill
        # quadratic in the number of cells.
        is_wall = np.asarray(self.matrixMDP) == -1

        for i in range(self.numRows):
            for j in range(self.numCols):
                if is_wall[i, j]:
                    continue
                here = self.idxMatrix[i, j]
                # Looking only down and right visits every undirected edge
                # exactly once; both directions are written to keep the
                # matrix symmetric.
                for ni, nj in ((i + 1, j), (i, j + 1)):
                    if ni < self.numRows and nj < self.numCols and not is_wall[ni, nj]:
                        there = self.idxMatrix[ni, nj]
                        self.adjMatrix[here, there] = 1
                        self.adjMatrix[there, here] = 1
       
    def getAdjacencyMatrix(self):
        ''' If I never did it before, I will fill the adjacency matrix.
            Otherwise I'll just return the one that was filled before.

            The matrix is therefore computed once; call
            _fillAdjacencyMatrix() explicitly if matrixMDP is modified
            afterwards.'''
        # "is None" instead of "== None": once adjMatrix holds a NumPy
        # array, "==" compares element-wise and cannot be used in an if.
        if self.adjMatrix is None:
            self._fillAdjacencyMatrix()
        return self.adjMatrix
                       
    def getGridDimensions(self):
        ''' Returns the grid size as (numRows, numCols), that is, height
            first and width second.'''
        dimensions = (self.numRows, self.numCols)
        return dimensions
    
    def getNextState(self, s, action):
        ''' Samples the index of the state reached by taking the given
            action in state s. The agent's own position is left untouched.'''
        distribution = self.transition_probabilities(s, action)
        candidates = []
        weights = []
        for state, probability in distribution.items():
            candidates.append(state)
            weights.append(probability)
        return random.choices(candidates, weights=weights, k=1)[0]
    
    def getRewardS(self, s, action):
        ''' Returns the reward of the state identified by s; this is
            getRewardXY() applied to the coordinates of s.'''
        row, col = self.getStateXY(s)
        reward = self.getRewardXY(row, col, action)
        return reward
        
    def defineRewardFunction(self, vector):
        ''' Stores vector as self.rewardFunction. The vector is meant to
            define the reward as its dot product with the feature
            representation.'''
        self.rewardFunction = vector
    
    def defineGoalState(self, idx):
        ''' Moves the goal to the state identified by idx and resets the
            agent.

            Returns True if the goal was properly set, otherwise returns
            False. Setting the goal fails, leaving it unchanged, when idx is
            outside 0..numStates-1 or identifies a wall.

            As a side effect the adjacency matrix is filled if it has never
            been before.'''
        # "is None" instead of "== None": once adjMatrix holds a NumPy
        # array, "==" compares element-wise and cannot be used in an if.
        if self.adjMatrix is None:
            self._fillAdjacencyMatrix()
        # The lower bound matters too: a negative index would otherwise
        # wrap around and silently select a cell counted from the end.
        if not 0 <= idx < self.numStates:
            return False
        x, y = self.getStateXY(idx)
        if self.matrixMDP[x][y] == -1:
            return False
        self.goalX = x
        self.goalY = y
        self.reset()
        return True
    
    def getNumStates(self):
        ''' Returns how many states the MDP has; walls are counted too.'''
        total_states = self.numStates
        return total_states
    
    def getActionSet(self):
        ''' Returns a fresh list with the names of the four directional
            actions.'''
        actions = [
            'up',
            'right',
            'down',
            'left',
        ]
        return actions
    
    def sampleAction(self):
        ''' Draws uniformly at random a position in the list returned by
            getActionSet() and returns that position (an int), not the
            action name.'''
        # NOTE(review): step() and transition_probabilities() compare
        # actions against the names from getActionSet(); callers have to
        # translate the returned index themselves.
        action_indices = range(len(self.getActionSet()))
        return random.choice(action_indices)
    
    def reset(self):
        ''' Puts the agent back on its start cell and returns the index of
            that cell. With random_reset enabled, a new start cell is first
            drawn from self.start_states; it must not be a wall, a dead zone
            or the goal.'''
        if self.random_reset:
            self.startX, self.startY = random.choice(self.start_states)
            start_cell = self.matrixMDP[self.startX][self.startY]
            assert start_cell != -1
            assert start_cell != -2
            assert self.startX != self.goalX or self.startY != self.goalY
        self.currX, self.currY = self.startX, self.startY
        return self.getStateIndex(self.currX, self.currY)
     
    def plot(self):
        ''' Draws the grid on the current matplotlib figure and shows it:
            walls (-1) are gray, every other cell is white, and dotted
            lines separate the cells. Row 0 is drawn at the top.'''
        plt.clf()
        for idx in range(self.getNumStates()):
            row, col = self.getStateXY(idx)
            shade = "gray" if self.matrixMDP[row][col] == -1 else "white"
            cell = patches.Rectangle(
                (col, self.numRows - row - 1),  # lower-left corner (x, y)
                1.0,                            # width
                1.0,                            # height
                facecolor = shade
            )
            plt.gca().add_patch(cell)

        plt.xlim([0, self.numCols])
        plt.ylim([0, self.numRows])
        # One dotted line per cell border, outer borders included.
        for x in range(self.numCols + 1):
            plt.axvline(x, color='k', linestyle=':')
        for y in range(self.numRows + 1):
            plt.axhline(y, color='k', linestyle=':')
        plt.show()
    
    def _readFile(self, path):
        file = open(path, 'r')
        for line in file:
            self.strMDP += line

    def _parseString(self):
        ''' Parses self.strMDP into the grid.

            The first line gives "numRows,numCols"; each following line is
            one grid row read character by character: 'X' is stored as -1
            (wall), 'D' as -2 (dead zone) and '.', 'S', 'G' as 0. 'S' cells
            are collected in self.start_states and 'G' sets the goal. Any
            other character is ignored.

            The start cell becomes the first 'S' found, or a random one when
            random_reset is enabled; without any 'S' it is left unchanged.'''
        header, *grid_lines = self.strMDP.split('\n')
        dimensions = header.split(',')
        self.numRows = int(dimensions[0])
        self.numCols = int(dimensions[1])
        self.matrixMDP = np.zeros((self.numRows, self.numCols))
        self.start_states = []

        for i, line in enumerate(grid_lines):
            for j, char in enumerate(line):
                if char == 'X':
                    self.matrixMDP[i][j] = -1
                elif char == 'D':
                    self.matrixMDP[i][j] = -2  # Dead zone
                elif char in ('.', 'S', 'G'):
                    self.matrixMDP[i][j] = 0
                    if char == 'S':
                        self.start_states.append((i, j))
                    elif char == 'G':
                        self.goalX = i
                        self.goalY = j

        if self.start_states:
            if self.random_reset:
                self.startX, self.startY = random.choice(self.start_states)
            else:
                # Default to the first start state
                self.startX, self.startY = self.start_states[0]
            

# XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
# XX.......................................XXX
# XX.......................................XXX
# XX.......................................XXX
# XS..XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX..GX
# XX..XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX.XXX
# XX.........................................X
# XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX