import random
import numpy as np
import marshal
import copy
from collections import deque

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from collections import defaultdict
import multiprocessing as mp
from functools import partial
from tqdm import tqdm
import numpy as np

def process_seed(seed: int,config, min_actions_to_succeed: int = 5):
    from .env import SokobanEnv
    env = SokobanEnv(config)
    env.reset(seed=seed)
    gt_action_sequence = get_shortest_action_path(
        env.env.room_fixed, 
        env.env.room_state, 
        MAX_DEPTH=min_actions_to_succeed,
    )
    if len(gt_action_sequence) > min_actions_to_succeed:
        return seed, []

    return seed, gt_action_sequence

def generate_seeds(size,config,min_actions_to_succeed=5,seed=0,n_candidate: int = 20000) -> list:
    action_count = defaultdict(int)
    num_processes = mp.cpu_count()
    print(f"Using {num_processes} processes for seed processing")
    pool = mp.Pool(processes=num_processes)
    process_seed_partial = partial(process_seed, min_actions_to_succeed=min_actions_to_succeed,config=config)
    seeds = range(seed, seed + n_candidate)
    results = list(tqdm(pool.imap(process_seed_partial, seeds), total=len(seeds), desc="Processing seeds"))
    pool.close()
    pool.join()

    valid_seeds_with_actions = [(seed, gt_action_sequence) for seed, gt_action_sequence in results if gt_action_sequence and len(gt_action_sequence) <= min_actions_to_succeed]
    seed_to_length = {seed: len(gt_action_sequence) for seed, gt_action_sequence in valid_seeds_with_actions}
    valid_seeds = [seed for seed, _ in valid_seeds_with_actions]
    action_lengths = [len(gt_action_sequence) for _, gt_action_sequence in valid_seeds_with_actions]
    for _, gt_action_sequence in valid_seeds_with_actions:
        for action in gt_action_sequence:
            action_count[action] += 1
    
    # Calculate basic statistics
    avg_length = np.mean(action_lengths) if action_lengths else 0
    median_length = np.median(action_lengths) if action_lengths else 0
    min_length = min(action_lengths) if action_lengths else 0
    max_length = max(action_lengths) if action_lengths else 0
    
    # Count frequency of each action length
    length_counts = {}
    for length in action_lengths:
        length_counts[length] = length_counts.get(length, 0) + 1
    
    # Calculate percentage of valid seeds
    valid_percentage = (len(valid_seeds) / n_candidate) * 100
    
    # Print statistics
    print("\nAction Sequence Statistics:")
    print(f"Total candidates processed: {n_candidate}")
    print(f"Valid seeds found: {len(valid_seeds)} ({valid_percentage:.2f}%)")
    print(f"Average action length: {avg_length:.2f}")
    print(f"Median action length: {median_length}")
    print(f"Min action length: {min_length}")
    print(f"Max action length: {max_length}")
    print("\nAction length distribution:")        
    for length in sorted(length_counts.keys()):
        count = length_counts[length]
        percentage = (count / len(action_lengths)) * 100
        print(f"  Length {length}: {count} instances ({percentage:.2f}%)")
    
    print("\nAction frequency:")
    for action, count in sorted(action_count.items(), key=lambda x: x[1], reverse=True):
        percentage = (count / len(valid_seeds_with_actions)) * 100
        print(f"  {action}: {count} instances ({percentage:.2f}%)")
    
    distribution_of_generated_seeds = defaultdict(int)
    for seed in valid_seeds[:size]:
        length = seed_to_length[seed]
        distribution_of_generated_seeds[length] += 1
    print("\nDistribution of generated seeds:")
    for length, count in sorted(distribution_of_generated_seeds.items(), key=lambda x: x[0]):
        percentage = (count / len(valid_seeds)) * 100
        print(f"  Length {length}: {count} instances ({percentage:.2f}%)")
    return valid_seeds[:size]
        
def get_shortest_action_path(room_fixed, room_state, MAX_DEPTH=100):
        """
        Get the shortest action path to push all boxes to the target spots.
        Use BFS to find the shortest path.
        NOTE currently only support one player, only one shortest solution
        =========================================================
        Parameters:
            room_state (np.ndarray): the state of the room
                - 0: wall
                - 1: empty space
                - 2: box target
                - 3: box on target
                - 4: box not on target
                - 5: player
            room_fixed (np.ndarray): the fixed part of the room
                - 0: wall
                - 1: empty space
                - 2: box target
            MAX_DEPTH (int): the maximum depth of the search
        =========================================================
        Returns:
            action_sequence (list): the action sequence to push all boxes to the target spots
        """
        
        # BFS queue stores (room_state, path)
        queue = deque([(copy.deepcopy(room_state), [])])
        explored_states = set()
        
        # Possible moves: up, down, left, right
        moves = [(-1,0), (1,0), (0,-1), (0,1)]
        actions = [1, 2, 3, 4] # Corresponding action numbers
        
        while queue:
            room_state, path = queue.popleft()
            if len(path) > MAX_DEPTH:
                return [] # No solution found

            # reduce the search space by checking if the state has been explored
            state_tohash = marshal.dumps(room_state)
            if state_tohash in explored_states:
                continue
            explored_states.add(state_tohash)
            

            # get information of the room
            player_pos = tuple(np.argwhere(room_state == 5)[0])
            boxes_on_target = set(map(tuple, np.argwhere((room_state == 3))))
            boxes_not_on_target = set(map(tuple, np.argwhere((room_state == 4))))
            boxes = boxes_on_target | boxes_not_on_target


            # Check if all boxes are on targets
            if not boxes_not_on_target:
                return path
                
            # Try each direction
            for move, action in zip(moves, actions):
                new_room_state = copy.deepcopy(room_state)
                new_player_pos = (player_pos[0] + move[0], player_pos[1] + move[1])
                
                # Check is new player position is wall or out of bound
                if new_player_pos[0] < 0 or new_player_pos[0] >= room_fixed.shape[0] \
                    or new_player_pos[1] < 0 or new_player_pos[1] >= room_fixed.shape[1] \
                    or room_fixed[new_player_pos] == 0:
                    continue
                    
                # If there's a box, check if we can push it
                if new_player_pos in boxes:
                    box_pos = new_player_pos # the original box position
                    new_box_pos = (new_player_pos[0] + move[0], new_player_pos[1] + move[1])
                    
                    # Can't push if hitting wall or another box or out of bound
                    if room_fixed[new_box_pos] == 0 or new_box_pos in boxes \
                        or new_box_pos[0] < 0 or new_box_pos[0] >= room_fixed.shape[0] \
                        or new_box_pos[1] < 0 or new_box_pos[1] >= room_fixed.shape[1]:
                        continue
                        
                    # move the box
                    
                    new_room_state[box_pos] = room_fixed[box_pos]
                    if room_fixed[new_box_pos] == 2:
                        new_room_state[new_box_pos] = 3
                    else:
                        new_room_state[new_box_pos] = 4
                
                # player moves
                new_room_state[player_pos] = room_fixed[player_pos]
                new_room_state[new_player_pos] = 5
                queue.append((new_room_state, path + [action]))
                        
        return [] # No solution found

# def plot_animation(imgs):
#     fig, ax = plt.subplots()
#     im = ax.imshow(imgs[0])
#     def init():
#         im.set_data(imgs[0])
#         return [im]
#     def update(i):
#         im.set_data(imgs[i])
#         return [im]
#     ani = animation.FuncAnimation(fig, update, frames=len(imgs), init_func=init, blit=True)
#     return ani

def plot_animation(imgs):
    height, width = imgs[0].shape[:2]
    fig = plt.figure(figsize=(width/100, height/100), dpi=500)
    
    ax = fig.add_axes([0, 0, 1, 1])
    
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)
    
    im = ax.imshow(imgs[0])
    def init():
        im.set_data(imgs[0])
        return [im]
    def update(i):
        im.set_data(imgs[i])
        return [im]
    ani = animation.FuncAnimation(fig, update, frames=len(imgs), init_func=init, blit=True)
    return ani

def solve_sokoban(env, saved_animation_path):
    """
    Solve the given sokoban environment and save the animation
    """
    actions = get_shortest_action_path(env.room_fixed, env.room_state)
    print(f"Found {len(actions)} actions: {actions}")
    imgs = []
    img_before_action = env.render('rgb_array')
    imgs.append(img_before_action)
    for action in actions:
        env.step(action)
        img_after_action = env.render('rgb_array')
        imgs.append(img_after_action)
    ani = plot_animation(imgs)
    ani.save(saved_animation_path)



        





def add_random_player_movement(room_state, room_structure, move_probability=0.5, continue_probability=0.5, max_steps=3):
    """
    Randomly move the player after reverse_playing to make the level more challenging, also fix the problem that in generated map, the player is always adjacent to the box
    
    Parameters:
        room_state (np.ndarray): Current state of the room
        room_structure (np.ndarray): Fixed structure of the room
        move_probability (float): Probability of moving the player at all (0.0-1.0)
        continue_probability (float): Probability of continuing to move after each step (0.0-1.0)
        max_steps (int): Maximum number of steps the player can move (1-3)
    
    Returns:
        np.ndarray: Updated room state with randomly moved player
    """
    # Check if we should move the player at all
    if random.random() > move_probability:
        return room_state
    
    # Find player position
    player_pos = np.where(room_state == 5)
    player_pos = np.array([player_pos[0][0], player_pos[1][0]])
    
    # Keep track of previous positions to avoid moving back
    previous_positions = [tuple(player_pos)]
    
    # Make 1-3 random moves
    steps_taken = 0
    while steps_taken < max_steps:
        # Get all valid moves (can't move into walls or boxes)
        valid_moves = []
        for action in range(4):  # 0: up, 1: down, 2: left, 3: right
            change = CHANGE_COORDINATES[action]
            next_pos = player_pos + change
            
            # Check if next position is valid (empty space or target) and not a previous position
            if (room_state[next_pos[0], next_pos[1]] in [1, 2] and 
                tuple(next_pos) not in previous_positions):
                valid_moves.append((action, next_pos))
        
        # If no valid moves, break
        if not valid_moves:
            break
        
        # Choose a random valid move
        chosen_action, next_pos = random.choice(valid_moves)
        # print(f"player_pos: {player_pos}, next_pos: {next_pos}")
        
        # Move player
        room_state[player_pos[0], player_pos[1]] = room_structure[player_pos[0], player_pos[1]]
        room_state[next_pos[0], next_pos[1]] = 5
        
        # Update player position and track previous position
        player_pos = next_pos
        previous_positions.append(tuple(player_pos))
        
        steps_taken += 1
        
        # Decide whether to continue moving
        if steps_taken >= max_steps or random.random() > continue_probability:
            break
    
    return room_state



"""
Following code is adapted from the nicely written gym_sokoban repo
"""

def generate_room(dim=(13, 13), p_change_directions=0.35, num_steps=25, num_boxes=3, tries=4, second_player=False, search_depth=100):
    """
    Generates a Sokoban room, represented by an integer matrix. The elements are encoded as follows:
    wall = 0
    empty space = 1
    box target = 2
    box not on target = 3
    box on target = 4
    player = 5

    :param dim:
    :param p_change_directions:
    :param num_steps:
    :param num_boxes:
    :param tries:
    :param second_player:
    :return: Numpy 2d Array, box mapping, action sequence
    """
    room_state = np.zeros(shape=dim)
    room_structure = np.zeros(shape=dim)

    # Some times rooms with a score == 0 are the only possibility.
    # In these case, we try another model.
    for t in range(tries):
        room = room_topology_generation(dim, p_change_directions, num_steps)
        room = place_boxes_and_player(room, num_boxes=num_boxes, second_player=second_player)

        # Room fixed represents all not movable parts of the room
        room_structure = np.copy(room)
        room_structure[room_structure == 5] = 1

        # Room structure represents the current state of the room including movable parts
        room_state = room.copy()
        room_state[room_state == 2] = 4

        room_state, box_mapping, action_sequence = reverse_playing(room_state, room_structure, search_depth)
        room_state[room_state == 3] = 4

        if box_displacement_score(box_mapping) > 0:
            break

    if box_displacement_score(box_mapping) == 0:
        raise RuntimeWarning('Generated Model with score == 0')

    room_state = add_random_player_movement(
        room_state, 
        room_structure,
        move_probability=0.5,       # 50% chance the player will move
        continue_probability=0.5,   # 50% chance to continue moving after each step
        max_steps=3                 # Maximum of 3 steps
    )

    return room_structure, room_state, box_mapping, action_sequence


def room_topology_generation(dim=(10, 10), p_change_directions=0.35, num_steps=15):
    """
    Generate a room topology, which consits of empty floors and walls.

    :param dim:
    :param p_change_directions:
    :param num_steps:
    :return:
    """
    dim_x, dim_y = dim

    # The ones in the mask represent all fields which will be set to floors
    # during the random walk. The centered one will be placed over the current
    # position of the walk.
    masks = [
        [
            [0, 0, 0],
            [1, 1, 1],
            [0, 0, 0]
        ],
        [
            [0, 1, 0],
            [0, 1, 0],
            [0, 1, 0]
        ],
        [
            [0, 0, 0],
            [1, 1, 0],
            [0, 1, 0]
        ],
        [
            [0, 0, 0],
            [1, 1, 0],
            [1, 1, 0]
        ],
        [
            [0, 0, 0],
            [0, 1, 1],
            [0, 1, 0]
        ]
    ]

    # Possible directions during the walk
    directions = [(1, 0), (0, 1), (-1, 0), (0, -1)]
    direction = random.sample(directions, 1)[0]

    # Starting position of random walk
    position = np.array([
        random.randint(1, dim_x - 1),
        random.randint(1, dim_y - 1)]
    )

    level = np.zeros(dim, dtype=int)

    for s in range(num_steps):

        # Change direction randomly
        if random.random() < p_change_directions:
            direction = random.sample(directions, 1)[0]

        # Update position
        position = position + direction
        position[0] = max(min(position[0], dim_x - 2), 1)
        position[1] = max(min(position[1], dim_y - 2), 1)

        # Apply mask
        mask = random.sample(masks, 1)[0]
        mask_start = position - 1
        level[mask_start[0]:mask_start[0] + 3, mask_start[1]:mask_start[1] + 3] += mask

    level[level > 0] = 1
    level[:, [0, dim_y - 1]] = 0
    level[[0, dim_x - 1], :] = 0

    return level


def place_boxes_and_player(room, num_boxes, second_player):
    """
    Places the player and the boxes into the floors in a room.

    :param room:
    :param num_boxes:
    :return:
    """
    # Get all available positions
    possible_positions = np.where(room == 1)
    num_possible_positions = possible_positions[0].shape[0]
    num_players = 2 if second_player else 1

    if num_possible_positions <= num_boxes + num_players:
        raise RuntimeError('Not enough free spots (#{}) to place {} player and {} boxes.'.format(
            num_possible_positions,
            num_players,
            num_boxes)
        )

    # Place player(s)
    ind = np.random.randint(num_possible_positions)
    player_position = possible_positions[0][ind], possible_positions[1][ind]
    room[player_position] = 5

    if second_player:
        ind = np.random.randint(num_possible_positions)
        player_position = possible_positions[0][ind], possible_positions[1][ind]
        room[player_position] = 5

    # Place boxes
    for n in range(num_boxes):
        possible_positions = np.where(room == 1)
        num_possible_positions = possible_positions[0].shape[0]

        ind = np.random.randint(num_possible_positions)
        box_position = possible_positions[0][ind], possible_positions[1][ind]
        room[box_position] = 2

    return room


# Global variables used for reverse playing.
explored_states = set()
num_boxes = 0
best_room_score = -1
best_room = None
best_box_mapping = None


def reverse_playing(room_state, room_structure, search_depth=100):
    """
    This function plays Sokoban reverse in a way, such that the player can
    move and pull boxes.
    It ensures a solvable level with all boxes not being placed on a box target.
    :param room_state:
    :param room_structure:
    :param search_depth:
    :return: 2d array, box mapping, action sequence
    """
    global explored_states, num_boxes, best_room_score, best_room, best_box_mapping, best_action_sequence

    # Box_Mapping is used to calculate the box displacement for every box
    box_mapping = {}
    box_locations = np.where(room_structure == 2)
    num_boxes = len(box_locations[0])
    for l in range(num_boxes):
        box = (box_locations[0][l], box_locations[1][l])
        box_mapping[box] = box

    # explored_states globally stores the best room state and score found during search
    explored_states = set()
    best_room_score = -1
    best_room = None
    best_box_mapping = box_mapping
    best_action_sequence = []

    depth_first_search(room_state, room_structure, box_mapping, box_swaps=0, last_pull=(-1, -1), ttl=search_depth, action_sequence=[])

    return best_room, best_box_mapping, best_action_sequence


def depth_first_search(room_state, room_structure, box_mapping, box_swaps=0, last_pull=(-1, -1), ttl=300, action_sequence=[]):
    """
    Searches through all possible states of the room.
    This is a recursive function, which stops if the ttl is reduced to 0 or
    over 1.000.000 states have been explored.
    :param room_state:
    :param room_structure:
    :param box_mapping:
    :param box_swaps:
    :param last_pull:
    :param ttl:
    :param action_sequence:
    :return:
    """
    global explored_states, num_boxes, best_room_score, best_room, best_box_mapping, best_action_sequence

    ttl -= 1
    if ttl <= 0 or len(explored_states) >= 300000:
        return

    state_tohash = marshal.dumps(room_state)

    # Only search this state, if it not yet has been explored
    if not (state_tohash in explored_states):

        # Add current state and its score to explored states
        room_score = box_swaps * box_displacement_score(box_mapping)
        if np.where(room_state == 2)[0].shape[0] != num_boxes:
            room_score = 0

        if room_score > best_room_score:
            best_room = room_state.copy()
            best_room_score = room_score
            best_box_mapping = box_mapping.copy()
            best_action_sequence = action_sequence.copy()

        explored_states.add(state_tohash)

        for action in ACTION_LOOKUP.keys():
            # The state and box mapping need to be copied to ensure
            # every action starts from a similar state.

            # TODO: A tentitive try here to make less moves
            if action >= 4:
                continue

            room_state_next = room_state.copy()
            box_mapping_next = box_mapping.copy()

            room_state_next, box_mapping_next, last_pull_next = \
                reverse_move(room_state_next, room_structure, box_mapping_next, last_pull, action)

            box_swaps_next = box_swaps
            if last_pull_next != last_pull:
                box_swaps_next += 1
            
            action_sequence_next = action_sequence + [action]
            # action_sequence_next = action_sequence + [(action, box_mapping_next != box_mapping)] # add whether a box is moved
            depth_first_search(room_state_next, room_structure, box_mapping_next, box_swaps_next, last_pull_next, ttl, action_sequence_next)
            

def reverse_move(room_state, room_structure, box_mapping, last_pull, action):
    """
    Perform reverse action. Where all actions in the range [0, 3] correspond to
    push actions and the ones greater 3 are simmple move actions.
    :param room_state:
    :param room_structure:
    :param box_mapping:
    :param last_pull:
    :param action:
    :return:
    """
    player_position = np.where(room_state == 5)
    player_position = np.array([player_position[0][0], player_position[1][0]])

    change = CHANGE_COORDINATES[action % 4]
    next_position = player_position + change

    # Check if next position is an empty floor or an empty box target
    if room_state[next_position[0], next_position[1]] in [1, 2]:

        # Move player, independent of pull or move action.
        room_state[player_position[0], player_position[1]] = room_structure[player_position[0], player_position[1]]
        room_state[next_position[0], next_position[1]] = 5

        # In addition try to pull a box if the action is a pull action
        if action < 4:
            possible_box_location = change[0] * -1, change[1] * -1
            possible_box_location += player_position

            if room_state[possible_box_location[0], possible_box_location[1]] in [3, 4]:
                # Perform pull of the adjacent box
                room_state[player_position[0], player_position[1]] = 3
                room_state[possible_box_location[0], possible_box_location[1]] = room_structure[
                    possible_box_location[0], possible_box_location[1]]

                # Update the box mapping
                for k in box_mapping.keys():
                    if box_mapping[k] == (possible_box_location[0], possible_box_location[1]):
                        box_mapping[k] = (player_position[0], player_position[1])
                        last_pull = k

    return room_state, box_mapping, last_pull


def box_displacement_score(box_mapping):
    """
    Calculates the sum of all Manhattan distances, between the boxes
    and their origin box targets.
    :param box_mapping:
    :return:
    """
    score = 0
    
    for box_target in box_mapping.keys():
        box_location = np.array(box_mapping[box_target])
        box_target = np.array(box_target)
        dist = np.sum(np.abs(box_location - box_target))
        score += dist

    return score


TYPE_LOOKUP = {
    0: 'wall',
    1: 'empty space',
    2: 'box target',
    3: 'box on target',
    4: 'box not on target',
    5: 'player'
}

ACTION_LOOKUP = {
    0: 'push up',
    1: 'push down',
    2: 'push left',
    3: 'push right',
    4: 'move up',
    5: 'move down',
    6: 'move left',
    7: 'move right',
}

# Moves are mapped to coordinate changes as follows
# 0: Move up
# 1: Move down
# 2: Move left
# 3: Move right
CHANGE_COORDINATES = {
    0: (-1, 0),
    1: (1, 0),
    2: (0, -1),
    3: (0, 1)
}
def sokoban_state_to_sentences(state_dict):
    """
    Convert Sokoban game state dictionary to descriptive sentences about spatial relationships.
    
    Args:
        state_dict (dict): Dictionary containing:
            - player_position: tuple (row, col)
            - box_positions: list of tuples [(row, col), ...]
            - target_positions: list of tuples [(row, col), ...]
            - wall_positions: list of tuples [(row, col), ...] (ignored)
            - grid_size: tuple (rows, cols)
    
    Returns:
        list: List of descriptive sentences
    """
    sentences = []
    
    player_pos = state_dict['player_position']
    box_positions = state_dict['box_positions']
    target_positions = state_dict['target_positions']
    
    def get_relative_position(pos1, pos2):
        """
        Get relative position description between two positions.
        pos1 is the reference point, pos2 is described relative to pos1.
        """
        row1, col1 = pos1
        row2, col2 = pos2
        
        if pos1 == pos2:
            return "at the same place as"
        
        # Determine row relationship
        if row1 == row2:
            if col1 > col2:
                return "at the same row and to the left of"
            else:  # col1 < col2
                return "at the same row and to the right of"
        elif col1 == col2:
            if row1 > row2:
                return "above and at the same column as"
            else:  # row1 < row2
                return "below and at the same column as"
        else:
            # Different row and column
            row_desc = "above" if row1 > row2 else "below"
            col_desc = "on the left side" if col1 > col2 else "on the right side"
            return f"{row_desc} and {col_desc} of"
    
    # Describe each box relative to player
    for i, box_pos in enumerate(box_positions):
        box_relation = get_relative_position(player_pos, box_pos)
        sentences.append(f"box{i} is {box_relation} the player")
    
    # Describe each target relative to player
    for i, target_pos in enumerate(target_positions):
        target_relation = get_relative_position(player_pos, target_pos)
        sentences.append(f"target{i} is {target_relation} the player")
    
    # Describe each target relative to each box
    for i, target_pos in enumerate(target_positions):
        for j, box_pos in enumerate(box_positions):
            target_box_relation = get_relative_position(box_pos, target_pos)
            sentences.append(f"target{i} is {target_box_relation} box{j}")
    
    return sentences