import numpy as np

from utils import reshape_param
import ipdb
 
 
def step(x, y, direction, grid):
    """Probe the cell reached from ``(x, y)`` by moving along ``direction``.

    Args:
        x: First (row) index of the current cell.
        y: Second (column) index of the current cell.
        direction: Pair ``(d_row, d_col)`` added to the current position.
        grid: Array indexed as ``grid[row, col]``; a value of 0 marks a cell
            that can be entered.

    Returns:
        Tuple ``(is_free, neighbor)`` where ``neighbor`` is the ``(row, col)``
        tuple of the probed cell and ``is_free`` is truthy only when that
        cell lies inside the grid and holds 0.
    """
    row = x + direction[0]
    col = y + direction[1]
    neighbor = (row, col)

    in_bounds = 0 <= row < grid.shape[0] and 0 <= col < grid.shape[1]
    if not in_bounds:
        # Do not index the grid with an out-of-range (or negative) position.
        return in_bounds, neighbor

    return grid[neighbor] == 0, neighbor
 
 
 
def DFS(nodes, level, paths, tmp_path, current_node, start_node):
    """Enumerate all paths from ``current_node`` to ``start_node`` across levels.

    A path may only step from a node in ``nodes[k]`` to a node in
    ``nodes[k + 1]`` that is at Manhattan distance 1. The search is a
    depth-first backtracking walk implemented with an explicit stack, so the
    supported path length is not bounded by the interpreter recursion limit.

    Args:
        nodes: Sequence of levels; each level is an iterable of ``(row, col)``
            nodes. ``current_node`` is taken to belong to ``nodes[level]``.
        level: Index of the level that ``current_node`` belongs to.
        paths: List that receives the discovered paths (mutated in place).
        tmp_path: List holding the nodes visited so far, ending with
            ``current_node``. It is used as scratch space and is restored to
            its initial content before returning.
        current_node: Node the search starts from.
        start_node: Node at which a path is complete.

    Returns:
        ``paths``, extended with one list per discovered path. Each path is
        stored reversed with respect to the walk, i.e. it begins at
        ``start_node`` and ends at the first element of ``tmp_path``.

    Raises:
        IndexError: If the walk needs a level beyond the end of ``nodes``.
    """
    if current_node == start_node:
        paths.append(tmp_path[::-1])
        return paths

    # Each frame is (node, iterator over the candidates of the next level).
    # The frame at position k of the stack sits on level ``level + k``.
    frames = [(current_node, iter(nodes[level + 1]))]

    while frames:
        node, candidates = frames[-1]
        descended = False

        for candidate in candidates:
            if abs(candidate[0] - node[0]) + abs(candidate[1] - node[1]) != 1:
                continue

            tmp_path.append(candidate)
            if candidate == start_node:
                # Complete path: store a reversed copy, then backtrack.
                paths.append(tmp_path[::-1])
                tmp_path.pop()
                continue

            # Go one level deeper; the partially consumed iterator of the
            # current frame is resumed once the new frame is exhausted.
            frames.append((candidate, iter(nodes[level + len(frames) + 1])))
            descended = True
            break

        if not descended:
            frames.pop()
            if frames:
                # Undo the append made when the finished frame was entered.
                # The root frame's node was supplied by the caller, so it
                # stays in ``tmp_path``.
                tmp_path.pop()

    return paths
 
 
def BFS(lvl_queue, queue, grid, visited, count, goal):
    """Expand a breadth-first frontier level by level until ``goal`` is found.

    The search is iterative, so the number of levels it can expand is not
    limited by the interpreter recursion limit.

    Args:
        lvl_queue: List of levels discovered so far; every expanded frontier
            is appended to it (mutated in place).
        queue: Current frontier, a sequence of ``(row, col)`` nodes.
        grid: Maze array passed on to ``step``; 0 marks a free cell.
        visited: Boolean array of the same shape as ``grid``; entries are set
            to True as cells are discovered (mutated in place).
        count: Number of levels expanded before this call.
        goal: ``(row, col)`` tuple of the target cell.

    Returns:
        Tuple ``(found, count, lvl_queue)``. When the goal is reached,
        ``found`` is True, ``count`` is the number of expanded levels and the
        last entry of ``lvl_queue`` is ``[goal]`` (not the full frontier).
        When the frontier runs empty, ``found`` is False and the last entry
        of ``lvl_queue`` is the empty frontier that ended the search.
    """
    directions = ((1, 0), (0, 1), (-1, 0), (0, -1))

    while len(queue) > 0:
        next_queue = []
        for node in queue:
            for direction in directions:
                valid, neighbor = step(node[0], node[1], direction, grid)
                if not valid or visited[neighbor]:
                    continue

                if neighbor == goal:
                    lvl_queue.append([goal])
                    return True, count + 1, lvl_queue

                visited[neighbor] = True
                next_queue.append(neighbor)

        lvl_queue.append(next_queue)
        queue = next_queue
        count += 1

    return False, count, lvl_queue
 
 
class Maze_Gen:
    """Random generator for two-row mazes with a known shortest path.

    A maze is a ``(2, size)`` array in which 0 marks a free cell and 1 a
    wall. Every even column receives a wall in exactly one of its two rows,
    chosen at random; odd columns stay free, and the start and goal cells are
    always cleared. The start is ``(0, 0)`` and the goal ``(1, size - 1)``.
    """

    def __init__(self, seed, size):
        """Create a generator.

        Args:
            seed: Seed for ``numpy.random.default_rng``.
            size: Number of columns of the maze; must be even.
        """
        assert size % 2 == 0, "Size must be divisible by two!"

        self.size = size
        self.start = np.array([0, 0])

        # NOTE(review): ``p`` is not read anywhere in this class; ``get_maze``
        # draws walls with its own probability of 0.5.
        self.p = 0.3

        self.seed = seed
        self.rng = np.random.default_rng(seed)

        self.goal = np.array([1, self.size - 1])

    def get_maze(self):
        """Draw a random maze and a shortest path through it.

        Returns:
            Tuple ``(start, goal, grid, path, actions)``:

            * ``start`` / ``goal``: fresh copies of the start and goal
              positions, so callers may modify them without affecting the
              generator.
            * ``grid``: the ``(2, size)`` maze array (0 free, 1 wall).
            * ``path``: list of ``(row, col)`` tuples from start to goal; the
              first shortest path found by ``DFS``.
            * ``actions``: result of ``get_action(path)``.
        """
        start_cell = (self.start[0], self.start[1])
        goal_cell = (self.goal[0], self.goal[1])
        wall_prob = 0.5
        n_gates = self.size // 2

        # Regenerate until the goal is reachable from the start.
        while True:
            grid = np.zeros((2, self.size))

            # One wall per even column: top row where the draw is 1, bottom
            # row otherwise.
            top_walls = self.rng.binomial(1, wall_prob, (1, n_gates))
            grid[:, 0:2 * n_gates:2] = np.vstack((top_walls, 1 - top_walls))

            grid[0][0] = 0
            grid[1][-1] = 0
            grid[start_cell] = 0
            grid[goal_cell] = 0

            visited = np.full(grid.shape, False, dtype=bool)
            visited[start_cell] = True
            found, _, levels = BFS([[start_cell]], [start_cell], grid, visited, 0, goal_cell)
            if not found:
                continue

            # Walk the BFS levels backwards from the goal to recover the
            # shortest paths; ``DFS`` returns them ordered start -> goal.
            levels.reverse()
            paths = DFS(levels, 0, [], [goal_cell], goal_cell, start_cell)
            path = paths[0]

            return self.start.copy(), self.goal.copy(), grid, path, self.get_action(path)

    def get_action(self, path):
        """Translate a path of cells into action codes.

        Codes: 0 for row + 1, 1 for row - 1, 2 for column + 1 and 3 for
        column - 1.

        Args:
            path: Sequence of ``(row, col)`` positions; consecutive entries
                are expected to be neighbouring cells.

        Returns:
            A list holding a single list with one action code per move. An
            empty or single-cell path yields ``[[]]``.
        """
        moves = []
        for current, following in zip(path, path[1:]):
            d_row = following[0] - current[0]
            d_col = following[1] - current[1]
            if d_row == 1:
                moves.append(0)
            if d_row == -1:
                moves.append(1)
            if d_col == 1:
                moves.append(2)
            if d_col == -1:
                moves.append(3)

        return [moves]
 

def grid_to_support(goal, grid, support_size, rng):
    """Sample labelled cell coordinates from a maze grid.

    Cells holding 0 are labelled 1 and cells holding 1 are labelled 0. All
    labelled cells are shuffled with ``rng`` and the first ``support_size``
    of them are returned.

    Args:
        goal: Not used by this function.
        grid: 2-D array whose entries are compared against 0 and 1.
        support_size: Maximum number of samples to return.
        rng: Object providing an in-place ``shuffle`` method (for example a
            ``numpy.random.Generator``).

    Returns:
        Tuple ``(support_x, support_y)`` of float torch tensors: the
        ``(row, col)`` coordinates with shape ``(k, 2)`` and the matching
        labels with shape ``(k, 1)``.
    """
    free_rows, free_cols = np.where(grid == 0)
    free_cells = np.stack((free_rows, free_cols), axis=-1)

    wall_rows, wall_cols = np.where(grid == 1)
    wall_cells = np.stack((wall_rows, wall_cols), axis=-1)

    # Free cells first, walls second; labels follow the same order.
    coords = np.vstack((free_cells, wall_cells))
    labels = np.vstack((
        np.ones((len(free_cells), 1)),
        np.zeros((len(wall_cells), 1)),
    ))

    order = np.arange(len(coords))
    rng.shuffle(order)
    chosen = order[:support_size]

    x_tensor = torch.from_numpy(coords[chosen]).float()
    y_tensor = torch.from_numpy(labels[chosen]).float()
    return x_tensor, y_tensor


    
class Simulator:
    """Grid-world environment running on mazes produced by ``Maze_Gen``.

    Actions are integer codes: 0 adds one to the row index ("forward"),
    1 subtracts one ("backward"), 2 adds one to the column index ("right")
    and 3 subtracts one ("left"). Any other code leaves the agent in place.
    """

    def __init__(self, seed, grid_size, sparse=False):
        """Create the environment; call ``reset`` before the first ``step``.

        Args:
            seed: Seed forwarded to ``Maze_Gen``.
            grid_size: Maze width forwarded to ``Maze_Gen``; also used to
                scale the dense reward.
            sparse: When True every non-terminal step is rewarded with -1
                instead of the scaled distance to the goal.
        """
        self.sparse = sparse

        self.generator = Maze_Gen(seed, grid_size)

        self.grid_size = grid_size

        self.start = None
        self.goal = None
        self.grid = None
        self.optimal_policy = None
        self.actual_pos = None

    def reset(self):
        """Draw a new maze and place the agent on its start cell."""
        start, goal, self.grid, _, optimal_policy = self.generator.get_maze()

        # Copy the positions so that nothing done to this environment's state
        # can write through to arrays owned by the generator, and so that the
        # agent position never shares memory with ``start``.
        self.start = start.copy()
        self.goal = goal.copy()
        self.actual_pos = self.start.copy()
        self.optimal_policy = np.array(optimal_policy[0])

    def step(self, a):
        """Apply action ``a`` and report the outcome.

        A move that would leave the grid or enter a wall (a cell holding 1)
        is ignored and the agent stays where it is.

        Args:
            a: Action code, see the class docstring.

        Returns:
            Tuple ``(position, reward, done, info)``: a copy of the agent
            position after the move, the reward, whether the goal has been
            reached, and ``None``. Reaching the goal yields reward 0;
            otherwise the reward is -1 in sparse mode, or minus the Manhattan
            distance to the goal divided by ``2 * grid_size``.
        """
        if a == 0:  # 0: forward
            delta = (1, 0)
        elif a == 1:  # 1: backward
            delta = (-1, 0)
        elif a == 2:  # 2: right
            delta = (0, 1)
        elif a == 3:  # 3: left
            delta = (0, -1)
        else:
            delta = (0, 0)

        candidate = self.actual_pos + np.array(delta)

        # Explicit bounds check: NumPy wraps negative indices around instead
        # of raising, so an index lookup alone would let the agent walk off
        # the top/left edge.
        n_rows, n_cols = self.grid.shape
        inside = 0 <= candidate[0] < n_rows and 0 <= candidate[1] < n_cols
        if inside and self.grid[candidate[0], candidate[1]] != 1:
            self.actual_pos = candidate

        # Check if goal reached
        if self.actual_pos[0] == self.goal[0] and self.actual_pos[1] == self.goal[1]:
            return self.actual_pos.copy(), 0, True, None

        if self.sparse:
            r = -1
        else:
            distance = np.abs(self.actual_pos[0] - self.goal[0]) + np.abs(self.actual_pos[1] - self.goal[1])
            r = - distance / (self.grid_size * 2.)

        return self.actual_pos.copy(), r, False, None
 
    # def get_state(self):
    #     goal = np.zeros(self.grid.shape)
    #     goal[self.goal] = self.goal_val
    #     pos = np.zeros(self.grid.shape)
    #     pos[(self.actual_pos_x, self.actual_pos_y)] = self.pos_val
    #     if self.show_goal == 0:
    #         return np.expand_dims(np.reshape(self.visited_states, (-1)), 0).copy() #np.expand_dims(np.concatenate([np.expand_dims(self.grid, 0), np.expand_dims(pos, 0)], 0), 0)
    #     else:
    #         return np.expand_dims(np.concatenate([np.expand_dims(self.grid, 0), np.expand_dims(pos, 0), np.expand_dims(goal, 0)], 0), 0)
    #
    # def normalize_reward(self, r, T):
    #     return r #- 1 + T * (T-1)/2.
    #
    # def get_distance(self):
    #     return - (np.abs(self.actual_pos_x - self.goal[0]) + np.abs(self.actual_pos_y - self.goal[1])) / self.T
 
 
 

import torch
def state_to_action(old_state, agent, device):
    """Feed a single numpy state to ``agent`` and return its output.

    The state is converted to a torch tensor, given a leading axis of size
    one, cast to float and moved to ``device`` before ``agent`` is called.
    """
    state_batch = torch.from_numpy(old_state).unsqueeze(0)
    state_batch = state_batch.float().to(device)
    return agent(state_batch)

def evaluate_reward_maze(model, learner, shifter, hypernetwork, grid_size, support_size, device, seed,
                         max_steps=30):
    """Roll out a maze-conditioned policy for one episode and summarise its rewards.

    A maze is drawn from ``Simulator(seed, grid_size)``, a support set is
    sampled from its grid with ``grid_to_support`` and used to build the
    policy; the policy is then followed greedily (arg-max over its output)
    until the goal is reached or ``max_steps`` actions have been taken.

    Args:
        model: Name of the policy family. ``'meta-fun'`` calls
            ``learner(x_s, y_s, x)`` directly; ``'eigen'``, ``'cavia'`` and
            ``'leo'`` use ``learner.encode``; ``'hyper'`` uses
            ``learner.encoder``. In the last four cases the encoding is fed
            through ``hypernetwork`` and ``reshape_param`` and the result is
            passed to ``shifter`` as its second argument.
        learner: Object providing the calls described above.
        shifter: Callable ``shifter(x, theta_1)`` with a ``theta_0``
            attribute (unused for ``'meta-fun'``).
        hypernetwork: Callable mapping the encoding to a parameter array
            (unused for ``'meta-fun'``).
        grid_size: Maze width passed to ``Simulator``.
        support_size: Number of support samples requested.
        device: Torch device on which the policy is evaluated.
        seed: Seed for both the maze and the support-set sampling.
        max_steps: Maximum number of actions in the episode.

    Returns:
        Dict with keys ``'avg_reward'``, ``'final_reward'`` and
        ``'max_reward'`` computed over the rewards of every step taken,
        including the terminal one.

    Raises:
        ValueError: If ``model`` is not a known name or ``max_steps`` is
            smaller than 1.
    """
    if max_steps < 1:
        raise ValueError("max_steps must be at least 1, got {!r}".format(max_steps))

    env = Simulator(seed, grid_size)
    env.reset()
    start = env.start
    goal = env.goal
    grid = env.grid

    ## Initialize learner by encoding grid
    rng = np.random.default_rng(seed)
    x_s, y_s = grid_to_support(goal, grid, support_size, rng)

    if model == 'meta-fun':
        # NOTE(review): unlike the other branches the support set is not
        # moved to ``device`` here - confirm that ``learner`` handles this.
        def agent(x):
            return learner(x_s, y_s, x)
    elif model in ('eigen', 'cavia', 'leo'):
        zs = learner.encode(x_s.to(device), y_s.to(device))
        theta_1 = reshape_param(hypernetwork(zs), shifter.theta_0)

        def agent(x):
            return shifter(x, theta_1)
    elif model == 'hyper':
        zs = learner.encoder(x_s.to(device), y_s.to(device)).squeeze(0)
        theta_1 = reshape_param(hypernetwork(zs), shifter.theta_0)

        def agent(x):
            return shifter(x, theta_1)
    else:
        raise ValueError("Unknown model type: {!r}".format(model))

    old_state = start
    rewards = []
    for _ in range(max_steps):
        # The policy input is the current position followed by the goal.
        state_transformed = np.hstack((old_state, goal))
        action = state_to_action(state_transformed, agent, device)

        action = torch.argmax(action, -1)
        new_state, reward, done, _ = env.step(action.detach().cpu().numpy()[0])

        # Record the reward before testing ``done`` so that the terminal
        # reward is part of the statistics.
        rewards.append(reward)
        old_state = new_state
        if done:
            break

    reward_dict = {
        'avg_reward': np.mean(rewards),
        'final_reward': rewards[-1],
        'max_reward': np.max(rewards),
    }

    return reward_dict
 
if __name__ == '__main__':
    # Smoke test: draw a few mazes and show each grid with its action list.
    # ``Maze_Gen`` requires a seed and an even maze width.
    generator = Maze_Gen(seed=0, size=10)
    for _ in range(20):
        s, g, grid, optimal_states, policy = generator.get_maze()
        print(grid)
        print(policy)
        print()
 
 