# Frozenlake-like gridworld
from collections import deque
from dataclasses import dataclass
from typing import Sequence

import gym
import numpy as np

from small_scale.policies import TabularPolicy, EpsilonGreedyPolicy


# Map layouts in the FrozenLake text format, keyed by grid size.  When a map is
# parsed, "S" is the start cell, "G" the goal and "H" a hole; every other
# character is treated as frozen (walkable) ice.
# NOTE(review): the original comment said these maps were "Modified so the ..."
# and broke off mid-sentence -- what was modified is not recorded; confirm.
MAPS = {
    "4x4": [
        "SFFF",
        "FHFH",
        "FFFH",
        "FHFG",
    ],
    "8x8": [
        "SFFFFFFF",
        "FFFFFFFF",
        "FFFHFFFF",
        "FFFFFHFF",
        "FFFHFFFF",
        "FHHFFFHF",
        "FHFFHFHF",
        "FFFHFFFG",
    ],
}

# One glyph per action; a glyph's position in the string is its action index.
ACTION_CHARS = "↑→↓←"
# Size of the discrete action set (one action per glyph above).
NUM_ACTIONS = len(ACTION_CHARS)


@dataclass(frozen=True)
class FrozenLakeState:
    """Immutable grid position of the agent, given as (row, column)."""

    # Row index; used as the first index into the environment's map array.
    y: int
    # Column index; used as the second index into the environment's map array.
    x: int


class FrozenLakeEnv(gym.Env):
    """FrozenLake-style gridworld backed by an explicit transition tensor.

    Cells are addressed as ``(y, x)`` (row, column) and flattened row-major
    into the integer observation ``y * wd + x``.  The four actions are
    0 = up, 1 = right, 2 = down, 3 = left; a move that would leave the grid
    keeps the agent on the boundary cell.

    Holes and the goal act as "terminal" cells in the transition tensor: every
    action taken in one of them sends the agent back to the start cell with
    probability 1.
    """

    # (dy, dx) displacement per action index, in the order up, right, down, left.
    _ACTION_DELTAS = ((-1, 0), (0, 1), (1, 0), (0, -1))

    def __init__(self, map, start_pos, goal_pos, slippery=0.0, loop=False):
        """Build the environment and precompute its transition tensor.

        Args:
            map: 2-D boolean array-like indexed ``[y, x]``; true marks a hole.
            start_pos: ``(y, x)`` of the start cell.
            goal_pos: ``(y, x)`` of the goal cell.
            slippery: probability mass spread uniformly over all four
                directions; the chosen action additionally receives the
                remaining ``1 - slippery``.
            loop: if true, ``step`` never reports ``done`` and
                ``get_terminal_matrix`` marks no state as terminal.

        Raises:
            ValueError: if ``map`` is not two-dimensional.
        """
        # Coerce to a boolean array so `|=` and truthiness checks behave.
        self.map = np.asarray(map, dtype=bool)
        if self.map.ndim != 2:
            raise ValueError("map must be a 2-D array")
        # NumPy shapes are (rows, columns), i.e. (height, width).
        self.ht, self.wd = self.map.shape
        self.state_dim = self.wd * self.ht
        self.start_state = FrozenLakeState(*(int(v) for v in start_pos))
        self.goal_state = FrozenLakeState(*(int(v) for v in goal_pos))
        self.loop = loop

        self.action_space = gym.spaces.Discrete(4)
        self.observation_space = gym.spaces.Discrete(self.wd * self.ht)
        self.slippery = slippery

        # Transition tensor P[s, a, s'] = Pr(s' | s, a).
        n = self.observation_space.n
        m = self.action_space.n
        P = np.zeros((n, m, n))
        start_index = self._to_index(self.start_state.y, self.start_state.x)
        goal_cell = (self.goal_state.y, self.goal_state.x)
        for y in range(self.ht):
            for x in range(self.wd):
                s = self._to_index(y, x)

                # Holes and the goal always lead back to the start cell.
                if self.map[y, x] or (y, x) == goal_cell:
                    P[s, :, start_index] = 1
                    continue

                # Destination for each direction, clamped to the grid.
                next_states = [
                    self._to_index(
                        min(max(y + dy, 0), self.ht - 1),
                        min(max(x + dx, 0), self.wd - 1),
                    )
                    for dy, dx in self._ACTION_DELTAS
                ]

                for action in range(m):
                    dist = np.full(len(next_states), self.slippery / 4)
                    dist[action] += 1 - self.slippery
                    # Accumulate: several directions can clamp to the same cell.
                    for sp, p in zip(next_states, dist):
                        P[s, action, sp] += p
        self.P = P

        self.state = self.start_state
        # Kept so that subclasses overriding reset() still run at construction.
        self.reset()

    def _to_index(self, y, x):
        """Flatten a ``(y, x)`` cell into its row-major state index."""
        return y * self.wd + x

    def get_transition_matrix(self):
        """Return the ``(state, action, next_state)`` transition tensor."""
        return self.P

    def get_terminal_matrix(self):
        """Return a boolean ``(state, action)`` array of terminal states.

        The goal and every hole are flagged for all actions, unless the
        environment was created with ``loop=True``, in which case nothing is.
        """
        terminal_states = np.zeros(self.state_dim, dtype=bool)
        if not self.loop:
            goal_index = self._to_index(self.goal_state.y, self.goal_state.x)
            terminal_states[goal_index] = True
            terminal_states |= self.map.flatten()
        return np.tile(terminal_states[:, None], (1, self.action_space.n))

    def get_reward_matrix(self):
        """Return a ``(state, action)`` array that is 1 in the goal state's row."""
        rv = np.zeros(self.state_dim)
        rv[self._to_index(self.goal_state.y, self.goal_state.x)] = 1
        return np.tile(rv[:, None], (1, self.action_space.n))

    def step(self, action):
        """Sample a transition for ``action`` and move the agent.

        Returns:
            ``(observation, reward, done, info)``: reward is 1 when the agent
            enters the goal cell and 0 otherwise; ``done`` is true when the
            agent enters the goal or a hole, and always false if ``loop`` is set.
        """
        s = self._to_index(self.state.y, self.state.x)
        sp = int(np.random.choice(self.state_dim, p=self.P[s, action]))
        y, x = divmod(sp, self.wd)

        # Reward and done-ness come from the state being entered.
        reached_goal = (y, x) == (self.goal_state.y, self.goal_state.x)
        reward = 1 if reached_goal else 0
        if self.loop:
            done = False
        else:
            done = bool(reached_goal or self.map[y, x])

        self.state = FrozenLakeState(y, x)
        return self.get_obs(), reward, done, {}

    def reset(self):
        """Put the agent back on the start cell and return its observation."""
        self.state = self.start_state
        return self.get_obs()

    def get_obs(self) -> int:
        """Return the flattened index of the agent's current cell."""
        return int(self._to_index(self.state.y, self.state.x))

    def render(self, mode="human"):
        """Return the map as text: ``#`` hole, ``_`` ice, ``S`` start, ``G`` goal.

        ``mode`` is accepted but not used.
        """
        rv = [["#" if hole else "_" for hole in row] for row in self.map]
        rv[self.start_state.y][self.start_state.x] = "S"
        rv[self.goal_state.y][self.goal_state.x] = "G"
        return "\n".join("".join(row) for row in rv)

    def render_policy(self, policy):
        """Return the map as text with the policy's argmax action per ice cell.

        ``policy.dist(state_index)`` must return something with ``argmax()``.
        Holes are drawn as ``#``; start and goal as ``S`` and ``G``.
        """
        rv = []
        for y in range(self.ht):
            row = []
            for x in range(self.wd):
                if self.map[y, x]:
                    row.append("#")
                else:
                    best = policy.dist(self._to_index(y, x)).argmax()
                    row.append(ACTION_CHARS[best])
            rv.append(row)
        rv[self.start_state.y][self.start_state.x] = "S"
        rv[self.goal_state.y][self.goal_state.x] = "G"
        return "\n".join("".join(row) for row in rv)

    def shortest_path_policy(self):
        """Return a tabular policy that follows shortest paths to the goal.

        A breadth-first search runs backwards from the goal over non-hole
        cells; each cell it reaches is assigned the action that leads to the
        cell it was discovered from.  Slipperiness is ignored.  Cells the
        search never reaches (holes, walled-off cells) and the goal itself
        keep action 0.
        """
        successor = np.zeros(self.map.shape, np.int32)
        seen = np.zeros(self.map.shape, dtype=bool)

        goal = (self.goal_state.y, self.goal_state.x)
        seen[goal] = True
        frontier = deque([goal])
        while frontier:
            y, x = frontier.popleft()
            for action, (dy, dx) in enumerate(self._ACTION_DELTAS):
                # Walk backwards: the predecessor reaches (y, x) via `action`.
                yn, xn = y - dy, x - dx
                if not (0 <= yn < self.ht and 0 <= xn < self.wd):
                    continue
                # Holes are never expanded; first discovery is the shortest.
                if seen[yn, xn] or self.map[yn, xn]:
                    continue
                seen[yn, xn] = True
                successor[yn, xn] = action
                frontier.append((yn, xn))

        return TabularPolicy(self.observation_space, self.action_space, successor.flatten())


def frozen_lake_env_from_string(mapsrc: Sequence[str], slippery=0.0, loop=False):
    """Build a ``FrozenLakeEnv`` from a textual map.

    Args:
        mapsrc: rows of the map, one string per row (e.g. ``MAPS["4x4"]``).
            ``"H"`` marks a hole, ``"S"`` the start and ``"G"`` the goal; any
            other character is plain ice.
        slippery: forwarded to ``FrozenLakeEnv``.
        loop: forwarded to ``FrozenLakeEnv``.

    Raises:
        ValueError: if the rows differ in length, or the map does not contain
            exactly one start and exactly one goal.
    """
    rows = list(mapsrc)
    if len({len(row) for row in rows}) > 1:
        raise ValueError("All map rows must have the same length")

    def mask(char):
        # Boolean grid that is True wherever the map shows `char`.
        return np.array([[c == char for c in row] for row in rows], dtype=bool)

    holes = mask("H")

    starts = mask("S")
    if starts.sum() != 1:
        raise ValueError("Map must have exactly one start position")
    start_pos = np.argwhere(starts).flatten()

    goals = mask("G")
    if goals.sum() != 1:
        raise ValueError("Map must have exactly one goal position")
    goal_pos = np.argwhere(goals).flatten()

    return FrozenLakeEnv(holes, start_pos, goal_pos, slippery=slippery, loop=loop)


def frozen_lake_policy_from_string(policy_str, epsilon=0.0):
    """Turn rows of action glyphs into a tabular policy.

    Each character of each row in ``policy_str`` is looked up in
    ``ACTION_CHARS`` to get an action index; the grid is flattened row by row
    into the policy table.  If ``epsilon`` is positive, the tabular policy is
    wrapped in an ``EpsilonGreedyPolicy``.
    """
    action_grid = np.array(
        [[ACTION_CHARS.index(glyph) for glyph in line] for line in policy_str],
        np.int32,
    )
    flat_actions = action_grid.ravel()

    states = gym.spaces.Discrete(flat_actions.shape[0])
    actions = gym.spaces.Discrete(NUM_ACTIONS)
    greedy = TabularPolicy(states, actions, flat_actions)

    if epsilon > 0:
        return EpsilonGreedyPolicy(greedy, actions, epsilon)
    return greedy

        