import copy
import hashlib
import heapq
import itertools
import json
import math
import os
import random
import time
from contextlib import contextmanager
from functools import lru_cache
from typing import List, Dict, Tuple, Set, Optional

import numpy as np
import ray
import yaml
from tqdm import tqdm

from gen_data import generate_configs
from box1_llm import object_to_json_serializable
from inference.utils import seed_everything
from simulation.simenv.box1 import Box1Env, Action, ExecutionRes
from verl.trainer.main_ppo import get_cache_dir


@contextmanager
def _timer(label="Elapsed time"):
    start = time.time()
    try:
        yield
    finally:
        end = time.time()
        # print(f"{label}: {end - start:.4f} seconds")


def count_unique_envs(grid_n: int, grid_m: int, num_objects: int):
    """Count start/target layouts for ``num_objects`` objects on a grid.

    Every one of the ``grid_n * grid_m`` cells offers four candidate positions
    (sub-offsets 0.25 and 0.75 along each axis). With ``P`` candidate positions
    and ``k = num_objects`` the count is ``C(P, k) * C(P - k, k)``: choose the
    start positions, then choose disjoint target positions from the rest.

    NOTE(review): start and target positions are counted as unordered sets,
    i.e. which object goes to which target is not distinguished — confirm
    this is the intended notion of "unique".

    Returns 0 when there are fewer than ``2 * num_objects`` positions.
    """
    sub_offsets = (0.25, 0.75)
    positions_per_cell = len(sub_offsets) ** 2
    # len(range(...)) mirrors iterating the grid: non-positive sizes give 0.
    total_positions = len(range(grid_n)) * len(range(grid_m)) * positions_per_cell

    # Start and target positions must be pairwise disjoint.
    if total_positions < 2 * num_objects:
        return 0

    start_choices = math.comb(total_positions, num_objects)
    target_choices = math.comb(total_positions - num_objects, num_objects)
    return start_choices * target_choices


# Ray task for processing state batches
@ray.remote
def process_state_batch(current_hash, current_state, g_score, randomness_factor):
    """Expand one search state inside a Ray worker.

    Args:
        current_hash: Hash of ``current_state``; echoed back to the caller.
        current_state: The ``State`` to expand.
        g_score: Accumulated cost from the start state to ``current_state``.
        randomness_factor: Noise level. When positive, each step costs
            ``1 + U(-0.5, 0.5) * randomness_factor``; the heuristic is always
            scaled by ``1 + U(0, 1) * randomness_factor``.

    Returns:
        ``("goal", current_hash, None, None)`` if the state is a goal, otherwise
        ``("continue", current_hash, successors, None)`` where each successor
        is ``(f_score, heuristic, next_hash, g, actions, next_state)``.
    """
    if current_state.is_goal():
        return "goal", current_hash, None, None

    successors = []
    for actions, next_state in generate_valid_actions(current_state):
        next_hash = next_state.hash()

        if randomness_factor > 0:
            step_cost = 1.0 + (random.random() - 0.5) * randomness_factor
        else:
            step_cost = 1.0
        g_next = g_score + step_cost

        h_next = next_state.heuristic()
        # Randomly inflate the heuristic so repeated searches explore differently.
        jitter = 1.0 + random.random() * randomness_factor
        successors.append(
            (g_next + h_next * jitter, h_next, next_hash, g_next, actions, next_state)
        )

    return "continue", current_hash, successors, None


# Ray task for generating valid actions for a single robot
@ray.remote
def generate_valid_action_single_robot_ray(robot_id, robot_pos, env):
    """Generate candidate actions for one robot (runs as a Ray task).

    The search space is deliberately restricted:

    * If the robot's arm rests on an object that still needs moving, only
      carrying that object is considered. The object goes straight to its
      target when the target is reachable from the robot's base. Otherwise
      the three best-ranked positions from ``get_step_positions`` are tried,
      keeping those that are reachable and strictly reduce the object's
      Manhattan distance to its target.
    * Otherwise the robot considers moving its arm (without carrying) to
      whichever of the two misplaced objects with the lowest score
      ``arm_distance + 0.5 * distance_to_target`` are within reach.

    Every candidate is validated with ``verify_action`` and de-duplicated via
    ``Action.hash()``.

    Args:
        robot_id: Identifier of the robot; echoed back in the result.
        robot_pos: The robot's pose; ``base_pos`` and ``arm_pos`` are read.
        env: Environment providing ``objects`` and ``targets`` (not modified).

    Returns:
        ``(robot_id, [(action, touches_object), ...])``. The flag is always
        ``True`` for actions produced here.
    """
    # Only objects that are not yet on their targets are worth acting on.
    objects_to_move = {
        obj_id: obj_pos
        for obj_id, obj_pos in env.objects.items()
        if list(obj_pos) != list(env.targets[obj_id])
    }
    if not objects_to_move:
        return robot_id, []

    # The arm counts as "carrying" only when it rests on an object that still
    # needs moving. An arm left on an object it has already delivered must be
    # free to go and fetch another one. Previously such a robot fell into the
    # carrying branch, skipped the placed object and produced no actions at
    # all, leaving it idle for the rest of the search.
    carrying_object = False
    carried_obj_id = None
    for obj_id, obj_pos in objects_to_move.items():
        if list(robot_pos.arm_pos) == list(obj_pos):
            carrying_object = True
            carried_obj_id = obj_id
            break

    robot_actions = []
    seen_action_hashes = set()

    def _try_add(action):
        """Keep ``action`` if the environment accepts it and it is new."""
        if verify_action(action, env) and action.hash() not in seen_action_hashes:
            robot_actions.append(action)
            seen_action_hashes.add(action.hash())

    if carrying_object:
        obj_pos = env.objects[carried_obj_id]
        target_pos = env.targets[carried_obj_id]

        if can_robot_reach(robot_pos.base_pos, target_pos):
            # Best case: drop the object straight onto its target.
            _try_add(create_action(robot_id, robot_pos, obj_pos, target_pos, True))
        else:
            # Otherwise move it to an intermediate position strictly closer
            # (Manhattan distance) to the target; only the three best-ranked
            # candidates are considered to bound the branching factor.
            cur_distance = abs(obj_pos[0] - target_pos[0]) + abs(
                obj_pos[1] - target_pos[1]
            )
            step_positions = get_step_positions(robot_pos, obj_pos, target_pos, env)
            for step_pos in step_positions[:3]:
                new_distance = abs(step_pos[0] - target_pos[0]) + abs(
                    step_pos[1] - target_pos[1]
                )
                if (
                    can_robot_reach(robot_pos.base_pos, step_pos)
                    and new_distance < cur_distance
                ):
                    _try_add(
                        create_action(robot_id, robot_pos, obj_pos, step_pos, True)
                    )
    else:
        # Rank misplaced objects by arm distance plus half their remaining
        # distance to target (lower is better) and try the two best.
        scored_objects = []
        for obj_id, obj_pos in objects_to_move.items():
            arm_distance = abs(robot_pos.arm_pos[0] - obj_pos[0]) + abs(
                robot_pos.arm_pos[1] - obj_pos[1]
            )
            target_distance = abs(obj_pos[0] - env.targets[obj_id][0]) + abs(
                obj_pos[1] - env.targets[obj_id][1]
            )
            scored_objects.append((obj_id, obj_pos, arm_distance + 0.5 * target_distance))
        scored_objects.sort(key=lambda entry: entry[2])

        for _obj_id, obj_pos, _score in scored_objects[:2]:
            if can_robot_reach(robot_pos.base_pos, obj_pos):
                _try_add(
                    create_action(robot_id, robot_pos, robot_pos.arm_pos, obj_pos, False)
                )

    # Every candidate generated here touches an object, so the flag is always True.
    return robot_id, [(action, True) for action in robot_actions]


# Ray task for checking and applying actions
@ray.remote
def check_and_apply_actions_ray(combined_actions, env, state):
    """Verify a set of parallel actions and, if valid, simulate them.

    Returns:
        ``(combined_actions, next_state)`` when ``verify_parallel_actions``
        accepts the actions, otherwise the one-element tuple ``(False,)``.
    """
    if not verify_parallel_actions(combined_actions, env):
        return (False,)
    next_state = state.apply_actions(combined_actions)
    return (combined_actions, next_state)


def plan_actions(
    env: "Box1Env",
    max_iterations: int = 10000,
    randomness_factor: float = 0.15,
    max_plan_number: int = 3,
) -> List[List["Action"]]:
    """
    Search for action sequences that move every object to its target in a Box1Env.

    Runs a randomized, batched best-first (A*-style) search. Each round pops up
    to 32 of the most promising frontier states and expands them in parallel as
    Ray tasks (``process_state_batch``). Noise perturbs both step costs and the
    heuristic weight, so repeated runs can discover different plans.

    Args:
        env: Environment to plan for. Not modified; ``State`` wraps a deep copy.
        max_iterations: Maximum number of expansion *rounds*. Each round expands
            a whole batch of states, not a single state.
        randomness_factor: Initial noise level. It is multiplied by 1.5 whenever
            the best heuristic seen has not improved for 2000 rounds (checked
            every 100 rounds).
        max_plan_number: Stop once this many distinct plans have been found.

    Returns:
        A list of plans (possibly empty). Each plan is the per-step action list
        produced by ``reconstruct_path``; a step is whatever
        ``generate_valid_actions`` yielded (a single action or a list of
        parallel actions).
    """
    print(f"Starting search algorithm with randomness factor: {randomness_factor}...")

    # Initialize Ray if not already started
    if not ray.is_initialized():
        ray.init(_temp_dir=get_cache_dir())

    class PrioritizedItem:
        """Heap entry ordered by (priority, heuristic, hash); never compares States."""

        def __init__(self, priority, heuristic, hash_val):
            self.priority = priority
            self.heuristic = heuristic
            self.hash_val = hash_val

        def __lt__(self, other):
            if self.priority != other.priority:
                return self.priority < other.priority
            if self.heuristic != other.heuristic:
                return self.heuristic < other.heuristic
            return self.hash_val < other.hash_val

    open_set = []  # heap of PrioritizedItem
    closed_set = set()  # hashes of states that have been expanded
    found_plans = []
    plan_hashes = set()  # str() of each found plan, used to drop duplicates

    initial_state = State(env)
    initial_hash = initial_state.hash()

    g_scores = {initial_hash: 0.0}  # best known cost from the start state
    came_from = {}  # state hash -> (parent hash, actions that led here)

    # Seed the frontier with a randomly inflated heuristic for the start state.
    random_factor = random.random() * randomness_factor
    heapq.heappush(
        open_set,
        PrioritizedItem(
            initial_state.heuristic() * (1.0 + random_factor),
            initial_state.heuristic(),
            initial_hash,
        ),
    )
    states = {initial_hash: initial_state}  # every state ever queued, by hash
    last_time = time.time()
    iterations = 0

    max_batch_size = 32

    # Progress tracking used to detect stagnation.
    best_heuristic = initial_state.heuristic()
    last_improvement = 0
    no_improvement_limit = 2000

    while (
        open_set and iterations < max_iterations and len(found_plans) < max_plan_number
    ):
        iterations += 1

        if iterations % 100 == 0:
            print(
                f"Iteration {iterations}, queue size: {len(open_set)}, best heuristic: {best_heuristic:.2f}, time: {time.time() - last_time:.2f}s"
            )
            last_time = time.time()

            # Stuck: raise the noise level to escape the local minimum.
            if iterations - last_improvement > no_improvement_limit:
                print(
                    f"No improvement after {no_improvement_limit} iterations, increasing randomness"
                )
                randomness_factor *= 1.5
                last_improvement = iterations

        # Expand up to half of the frontier per round (at least one state, at
        # most max_batch_size). The loop guard keeps open_set non-empty, so
        # batch_size never exceeds len(open_set) and no separate single-state
        # path is needed.
        batch_size = min(max_batch_size, max(1, len(open_set) // 2))

        batch_futures = []
        for _ in range(min(batch_size, len(open_set))):
            item = heapq.heappop(open_set)
            current_hash = item.hash_val

            # Defensive guard: never expand the same state twice.
            if current_hash in closed_set:
                continue

            current_state = states[current_hash]

            h_val = current_state.heuristic()
            if h_val < best_heuristic:
                best_heuristic = h_val
                last_improvement = iterations

            batch_futures.append(
                process_state_batch.remote(
                    current_hash,
                    current_state,
                    g_scores[current_hash],
                    randomness_factor,
                )
            )

        for status, current_hash, successors, _ in ray.get(batch_futures):
            closed_set.add(current_hash)

            if status == "goal":
                print(f"Solution found after {iterations} iterations!")
                path = reconstruct_path(came_from, current_hash, states)
                path_hash = str(path)

                if path_hash not in plan_hashes:
                    print(f"Unique solution #{len(found_plans) + 1} found!")
                    found_plans.append(path)
                    plan_hashes.add(path_hash)

                if len(found_plans) >= max_plan_number:
                    break

            # Goal results carry no successors (None).
            if not successors:
                continue

            for (
                f_score,
                heuristic,
                next_hash,
                tentative_g,
                actions,
                next_state,
            ) in successors:
                if next_hash in closed_set:
                    continue

                if next_hash not in g_scores or tentative_g < g_scores[next_hash]:
                    came_from[next_hash] = (current_hash, actions)
                    g_scores[next_hash] = tentative_g

                    # Only brand-new states are queued. A cheaper path to an
                    # already-queued state updates its g-score and parent (read
                    # when it is expanded) but keeps its original heap priority.
                    if next_hash not in states:
                        states[next_hash] = next_state
                        heapq.heappush(
                            open_set,
                            PrioritizedItem(f_score, heuristic, next_hash),
                        )

    if found_plans:
        print(f"Total found plans: {len(found_plans)}")
    else:
        print(f"Failed to find solution after {iterations} iterations")

    return found_plans


def generate_valid_actions(state: "State") -> "List[Tuple[Action | List[Action], State]]":
    """
    Enumerate successor ``(actions, next_state)`` pairs for ``state``.

    Candidates come from two sources:

    1. Single-robot actions, generated in parallel with one Ray task per
       relevant robot. A robot is relevant if its arm rests on any object, or
       if its base can reach an object that still needs moving.
    2. Joint actions for groups of 2-3 robots. The group size is also capped by
       the number of misplaced objects. Groups are drawn from the (up to) five
       robots with the most candidate actions, using each robot's top two
       actions. Combinations are cheaply pre-filtered, capped at 20 per group
       and verified in parallel with ``check_and_apply_actions_ray``.

    The result is sorted by ``next_state.heuristic() / (1 + 0.2 * n_actions)``
    (ascending, so joint actions get a bonus). It is then truncated to
    ``min(50 + 10 * n_misplaced_objects, 100)`` entries.

    Args:
        state: Current search state (not modified).

    Returns:
        A list of ``(actions, next_state)`` tuples. ``actions`` is a single
        ``Action`` for single-robot moves and a list of ``Action`` objects for
        joint moves. Empty when every object is already on its target.
    """
    env = state.env
    valid_action_sets = []

    # Only objects that are not yet on their targets need any work.
    objects_to_move = {
        obj_id: obj_pos
        for obj_id, obj_pos in env.objects.items()
        if list(obj_pos) != list(env.targets[obj_id])
    }
    if not objects_to_move:
        return valid_action_sets

    relevant_robots = {}
    for robot_id, robot_pos in env.robots.items():
        arm = list(robot_pos.arm_pos)
        arm_on_object = any(arm == list(obj_pos) for obj_pos in env.objects.values())
        if arm_on_object or any(
            can_robot_reach(robot_pos.base_pos, obj_pos)
            for obj_pos in objects_to_move.values()
        ):
            relevant_robots[robot_id] = robot_pos

    # Every Ray task below receives the same env (and state). Putting them in
    # the object store once avoids re-serializing them for each task; Ray
    # resolves the references back to values inside the task.
    env_ref = ray.put(env)

    # 1) Candidate actions for each relevant robot, one Ray task per robot.
    single_robot_actions = {}  # robot_id -> [(action, touches_object), ...]
    with _timer("Single robot actions"):
        futures = [
            generate_valid_action_single_robot_ray.remote(robot_id, robot_pos, env_ref)
            for robot_id, robot_pos in relevant_robots.items()
        ]
        for robot_id, actions in ray.get(futures):
            if actions:
                single_robot_actions[robot_id] = actions

    with _timer("Add single robot actions"):
        for actions in single_robot_actions.values():
            for action, touches_object in actions:
                if touches_object:
                    valid_action_sets.append((action, state.apply_actions([action])))

    def _is_conflict_free(combo):
        """Cheap pre-check: no two actions end at the same position and no
        position is the start of two carrying actions."""
        end_positions = set()
        carried_from = set()
        for action in combo:
            end = tuple(action.pos_e)
            if end in end_positions:
                return False
            end_positions.add(end)
            if action.carry:
                start = tuple(action.pos_s)
                if start in carried_from:
                    return False
                carried_from.add(start)
        return True

    # 2) Joint actions for small groups of robots.
    with _timer("Combination of robot actions"):
        if len(single_robot_actions) >= 2:
            # At most five robots, those with the most candidate actions first.
            robot_ids = sorted(
                single_robot_actions.keys(),
                key=lambda r_id: len(single_robot_actions[r_id]),
                reverse=True,
            )[:5]

            with _timer("Generate tasks"):
                combos_to_check = []
                # No point grouping more robots than there are objects to move.
                max_group_size = min(3, len(objects_to_move))

                for group_size in range(2, max_group_size + 1):
                    for robot_group in itertools.combinations(robot_ids, group_size):
                        filtered_actions = {
                            r_id: [a for a in single_robot_actions[r_id] if a[1]]
                            for r_id in robot_group
                        }
                        if any(len(acts) == 0 for acts in filtered_actions.values()):
                            continue

                        # Top two candidates per robot keep the product small.
                        limited_actions = [
                            filtered_actions[r_id][:2] for r_id in robot_group
                        ]
                        # Strip the touch flags, keeping only the actions.
                        combos = [
                            [flagged[0] for flagged in combo]
                            for combo in itertools.product(*limited_actions)
                        ]
                        viable_combos = [c for c in combos if _is_conflict_free(c)]
                        # Process at most 20 combinations per group.
                        combos_to_check.extend(viable_combos[:20])

            with _timer("Check and apply actions"):
                state_ref = ray.put(state)
                batch_size = 50  # submit at most this many Ray tasks at a time
                for start in range(0, len(combos_to_check), batch_size):
                    futures = [
                        check_and_apply_actions_ray.remote(combo, env_ref, state_ref)
                        for combo in combos_to_check[start : start + batch_size]
                    ]
                    for result in ray.get(futures):
                        # (combo, next_state) on success, (False,) otherwise.
                        if result and result[0]:
                            valid_action_sets.append(result)

    def action_score(action_set):
        """Lower is better: heuristic of the resulting state, discounted for
        actions involving more robots (more parallel work)."""
        next_state = action_set[1]
        action_count = len(action_set[0]) if isinstance(action_set[0], list) else 1
        return next_state.heuristic() / (1 + 0.2 * action_count)

    # Sort by score and cap the branching factor based on remaining work.
    valid_action_sets = sorted(valid_action_sets, key=action_score)
    branching_limit = min(50 + 10 * len(objects_to_move), 100)
    return valid_action_sets[:branching_limit]


def reconstruct_path(came_from, final_hash, states):
    """
    Walk parent links back from ``final_hash`` and return the actions in order.

    Args:
        came_from: Maps a state hash to ``(parent_hash, actions)``.
        final_hash: Hash of the state the path should end at.
        states: Accepted for interface compatibility; not used.

    Returns:
        The ``actions`` entries from the start state to ``final_hash``. The walk
        stops at the first hash without a ``came_from`` entry, i.e. the start.
    """
    steps = []
    node = final_hash
    while node in came_from:
        node, step_actions = came_from[node]
        steps.append(step_actions)
    # Links were collected goal-to-start; flip to start-to-goal.
    return steps[::-1]


def verify_parallel_actions(actions, env):
    """
    Check whether ``actions`` can be executed together in one step.

    Verification runs on a deep copy, so ``env`` itself is never touched.

    Args:
        actions: List of Action objects intended to run in parallel.
        env: Environment to verify against.

    Returns:
        True when ``env.verify`` reports ``ExecutionRes.Success``, else False.
    """
    sandbox = copy.deepcopy(env)
    outcome = sandbox.verify(actions)
    return outcome.success == ExecutionRes.Success


def verify_action(action, env):
    """Return True if the single ``action`` passes ``env.verify``.

    This is the one-element case of ``verify_parallel_actions``: it runs on a
    deep copy of ``env`` and compares against ``ExecutionRes.Success``.
    """
    return verify_parallel_actions([action], env)


def is_position_occupied(position, env):
    """Return True if some object in ``env.objects`` sits exactly at ``position``.

    ``position`` may be any sequence (tuple, list, ...). It is normalized to a
    tuple before comparison; previously a list argument never matched because
    a tuple never compares equal to a list in Python.
    """
    wanted = tuple(position)
    return any(tuple(obj_pos) == wanted for obj_pos in env.objects.values())


def create_action(robot_id, robot_pos, start_pos, end_pos, carry):
    """Build an ``Action`` moving robot ``robot_id`` from ``start_pos`` to ``end_pos``.

    The robot's current ``arm_pos``/``base_pos`` and both endpoints are copied
    into fresh lists, so the action never aliases the caller's coordinates.
    ``carry`` says whether the move transports an object.
    """
    return Action(
        robot_id=robot_id,
        arm_pos=list(robot_pos.arm_pos),
        base_pos=list(robot_pos.base_pos),
        pos_s=list(start_pos),
        pos_e=list(end_pos),
        carry=carry,
    )


def can_robot_reach(base_pos, target_pos):
    """Return True when ``target_pos`` is strictly less than one unit from
    ``base_pos`` along both the x and the y axis."""
    dx = abs(base_pos[0] - target_pos[0])
    dy = abs(base_pos[1] - target_pos[1])
    return dx < 1.0 and dy < 1.0


def get_robot_can_reach_positions(base_pos, env):
    """List candidate arm positions around ``base_pos``.

    Candidates are ``base_pos`` shifted by -0.25, -0.75, 0.25 or 0.75 on each
    axis (x offset varies slowest). Positions not strictly inside
    ``(0, grid_n) x (0, grid_m)`` and the base position itself are dropped.
    """
    # Right now we only consider the 0.25 / 0.75 sub-cell offsets.
    deltas = [-0.25, -0.75, 0.25, 0.75]
    reachable = []
    for dx in deltas:
        x = base_pos[0] + dx
        for dy in deltas:
            y = base_pos[1] + dy
            if not (0 < x < env.grid_n and 0 < y < env.grid_m):
                continue
            if x == base_pos[0] and y == base_pos[1]:
                continue
            reachable.append((x, y))
    return reachable


def get_step_positions(robot_pos, start_pos, end_pos, env):
    """
    Candidate intermediate positions for moving an object towards ``end_pos``.

    Takes the positions reachable from the robot's base
    (``get_robot_can_reach_positions``), as lists. It drops any whose x or y
    coordinate is a multiple of 0.5, and returns the rest sorted (stably) by
    Manhattan distance to ``end_pos``, closest first. ``start_pos`` is
    accepted for interface compatibility but not used.
    """
    candidates = [
        list(p) for p in get_robot_can_reach_positions(robot_pos.base_pos, env)
    ]
    candidates = [p for p in candidates if p[0] % 0.5 != 0 and p[1] % 0.5 != 0]
    candidates.sort(key=lambda p: abs(p[0] - end_pos[0]) + abs(p[1] - end_pos[1]))
    return candidates


class State:
    """
    A node in the planning search space.

    Wraps a private deep copy of the environment, so successor states can be
    produced by simulating actions without touching the original environment.
    """

    def __init__(self, env):
        # Deep copy so simulate() on this state never mutates the caller's env.
        self.env = copy.deepcopy(env)
        # Lazily filled by heuristic(). Safe to cache because apply_actions()
        # always builds a new State instead of mutating this one.
        self._heuristic_value = None

    def __str__(self):
        res = "State:(\n"
        for obj_id, obj_pos in self.env.objects.items():
            res += f"\tObject {obj_id}: {obj_pos}\tTarget: {self.env.targets[obj_id]}\n"
        res += "\n"
        for robot_id, robot_pos in self.env.robots.items():
            res += f"\tRobot {robot_id}: {robot_pos.base_pos} -> {robot_pos.arm_pos}\n"
        res += ")\n"
        return res

    def apply_actions(self, actions):
        """Return a new State with ``actions`` simulated; ``self`` is unchanged."""
        new_state = State(self.env)
        new_state.env.simulate(actions)
        return new_state

    def is_goal(self):
        """Check if this state represents the goal (delegates to env.check_final())."""
        return self.env.check_final()

    def hash(self):
        """Return a deterministic integer identifying this state.

        Two states get the same value when every object and every robot arm is
        at the same position, with coordinates truncated to hundredths. Robot
        base positions are not part of the key.

        The value is a 64-bit BLAKE2b digest of the canonical key rather than a
        combination of built-in ``hash()`` values. Built-in hashes of ``str``
        are salted per process, while these hashes may be computed in different
        processes (e.g. the driver and Ray workers) and then compared. With the
        digest the result is stable across processes as long as ``repr()`` of
        the object/robot ids is (true for ``str`` and ``int`` ids).
        """
        objects = self.env.objects
        robots = self.env.robots
        key = (
            tuple(
                (
                    repr(obj_id),
                    int(objects[obj_id][0] * 100),
                    int(objects[obj_id][1] * 100),
                )
                for obj_id in sorted(objects.keys())
            ),
            tuple(
                (
                    repr(r_id),
                    int(robots[r_id].arm_pos[0] * 100),
                    int(robots[r_id].arm_pos[1] * 100),
                )
                for r_id in sorted(robots.keys())
            ),
        )
        digest = hashlib.blake2b(repr(key).encode("utf-8"), digest_size=8).digest()
        return int.from_bytes(digest, "big")

    def heuristic(self):
        """
        Estimate the remaining cost to reach the goal (cached after first call).

        For each object not on its target, the estimate adds 10x its Manhattan
        distance to the target, plus one of:

        * 5x the distance from the base of the robot whose arm is on the object
          to the target, if such a robot exists;
        * otherwise 5x the smallest (arm-to-object + base-to-target) distance
          over robots whose base can reach the object;
        * otherwise a penalty of 20x the object's distance to its target.

        The weights make this larger than the ~1-per-step action cost, so it is
        not admissible: the search is greedy rather than optimal (presumably
        by design).
        """
        if self._heuristic_value is not None:
            return self._heuristic_value

        total_distance = 0.0

        # Only objects not at their targets contribute.
        objects_to_move = {}
        for obj_id, obj_pos in self.env.objects.items():
            target_pos = self.env.targets[obj_id]
            if list(obj_pos) != list(target_pos):
                objects_to_move[obj_id] = (obj_pos, target_pos)

        if not objects_to_move:
            self._heuristic_value = 0.0
            return 0.0

        # obj_id -> robot_id whose arm rests on that object
        robot_covering = {}
        for obj_id, (obj_pos, _) in objects_to_move.items():
            for robot_id, robot_pos in self.env.robots.items():
                if list(obj_pos) == list(robot_pos.arm_pos):
                    robot_covering[obj_id] = robot_id
                    break

        # robot_id -> obj_ids its base can reach
        robot_can_reach = {}
        for robot_id, robot_pos in self.env.robots.items():
            robot_can_reach[robot_id] = []
            for obj_id, (obj_pos, _) in objects_to_move.items():
                if can_robot_reach(robot_pos.base_pos, obj_pos):
                    robot_can_reach[robot_id].append(obj_id)

        for obj_id, (obj_pos, target_pos) in objects_to_move.items():
            obj_distance = abs(obj_pos[0] - target_pos[0]) + abs(
                obj_pos[1] - target_pos[1]
            )
            total_distance += obj_distance * 10  # Weight distance heavily

            if obj_id in robot_covering:
                robot_id = robot_covering[obj_id]
                robot_pos = self.env.robots[robot_id]
                robot_distance = abs(robot_pos.base_pos[0] - target_pos[0]) + abs(
                    robot_pos.base_pos[1] - target_pos[1]
                )
                total_distance += robot_distance * 5
            else:
                # Not covered: closest robot that can reach it.
                min_robot_distance = float("inf")
                for robot_id, reachable_objs in robot_can_reach.items():
                    if obj_id in reachable_objs:
                        robot_pos = self.env.robots[robot_id]
                        arm_to_obj = abs(robot_pos.arm_pos[0] - obj_pos[0]) + abs(
                            robot_pos.arm_pos[1] - obj_pos[1]
                        )
                        robot_to_target = abs(
                            robot_pos.base_pos[0] - target_pos[0]
                        ) + abs(robot_pos.base_pos[1] - target_pos[1])
                        total_dist = arm_to_obj + robot_to_target
                        min_robot_distance = min(min_robot_distance, total_dist)

                if min_robot_distance < float("inf"):
                    total_distance += min_robot_distance * 5
                else:
                    # No robot can reach this object - severe penalty
                    total_distance += obj_distance * 20

        self._heuristic_value = total_distance
        return total_distance


def remove_redundant_move(actions):
    """Drop robot actions that never lead to that robot carrying something.

    The plan is scanned from the last step backwards. A robot becomes
    "active" at the latest step (in original order) where one of its actions
    has ``carry`` set; from there on (going backwards) every action of that
    robot is kept. Actions of a robot that come after its last carrying
    action, or of a robot that never carries, are discarded. Steps left
    with no actions are removed.

    Args:
        actions: Sequence of plan steps; each step is a single action or a
            list of actions. Every action exposes ``robot_id`` and ``carry``.

    Returns:
        A new list of steps, each a list of actions, in original order.
    """
    active_robots = set()
    kept_steps = []
    for step in reversed(actions):
        step_list = step if isinstance(step, list) else [step]
        survivors = []
        for act in step_list:
            # Register the carrier before the membership test so the carrying
            # action itself is kept.
            if act.carry:
                active_robots.add(act.robot_id)
            if act.robot_id in active_robots:
                survivors.append(act)
        if survivors:
            kept_steps.append(survivors)
    kept_steps.reverse()
    return kept_steps


def squeeze_move(env, actions):
    """Greedily merge consecutive plan steps whenever the merged step is valid.

    Steps are accumulated into a pending step while ``env.verify`` reports
    ``ExecutionRes.Success`` for the combined action list. When adding the
    next step would fail verification, the pending step is committed: it is
    appended to the result and simulated on a private copy of ``env``. The
    failing step then starts a new pending step.

    Args:
        env: Environment in the state before the plan's first step. It is
            deep-copied, so the caller's object is not modified.
        actions: Sequence of plan steps; each is a single action or a list
            of actions.

    Returns:
        List of (possibly merged) steps, each a list of actions. No empty
        steps are emitted.
    """
    env = copy.deepcopy(env)  # simulate on a private copy
    merged_steps = []
    pending = []
    for step in actions:
        step_list = step if isinstance(step, list) else [step]
        candidate = pending + step_list
        if env.verify(candidate).success == ExecutionRes.Success:
            pending = candidate
            continue
        # Only commit a non-empty pending step. Previously a step that failed
        # verification on its own (with nothing pending) caused an empty `[]`
        # step to be emitted and simulated.
        if pending:
            merged_steps.append(pending)
            env.simulate(pending)
        pending = step_list
    if pending:
        merged_steps.append(pending)
    return merged_steps


# Debug switch: when True, solve_box_environment renders the initial env to
# tmp-initial.png and each replayed plan step to tmp-step-{i}.png.
SHOW = False


def solve_box_environment(
    outfile=None, env_path=None, randomness_factor=0.15, env_config=None
):
    """
    Solve a Box1Env environment with randomized A* search.

    Each plan returned by ``plan_actions`` is post-processed with
    ``remove_redundant_move`` and ``squeeze_move``, then replayed on a copy
    of the environment. Only plans that pass ``check_final`` after replay
    are kept.

    Args:
        outfile: Optional writable text file. Each verified plan is written
            as one JSON line ``{"env": ..., "plan": ...}``.
        env_path: Path to load environment from, or None to create a new one
        randomness_factor: Factor to determine how much randomness to introduce (0.0-1.0)
        env_config: Keyword arguments for ``Box1Env`` when creating a new
            environment (ignored when ``env_path`` is given). If None, a
            2x4 grid with 3 objects is used.

    Returns:
        List of post-processed action plans that reach the final state when
        replayed (possibly empty), or None if the planner found no plan.
    """
    # Initialize Ray if not already started
    if not ray.is_initialized():
        ray.init()

    # Create or load environment
    if env_path:
        env = Box1Env.load_json(env_path)
    else:
        if env_config is not None:
            env = Box1Env(**env_config)
        else:
            env = Box1Env(grid_n=2, grid_m=4, num_objects=3)
        env.create()
        if SHOW:
            env.visualize(out_file_path="tmp-initial.png")

    # Run planning algorithm
    action_plans = plan_actions(
        env, max_iterations=1000, randomness_factor=randomness_factor, max_plan_number=1
    )

    if not action_plans:
        print("No solution found!")
        return None

    solved_plans = []
    for action_plan in action_plans:
        action_plan = remove_redundant_move(action_plan)
        action_plan = squeeze_move(env, action_plan)

        # Replay on a copy so `env` keeps its initial state; it is serialized
        # below together with the plan.
        test_env = copy.deepcopy(env)
        for i, step_actions in enumerate(action_plan):
            if not isinstance(step_actions, list):
                step_actions = [step_actions]
            if SHOW:
                # Verify against the replay env (the state *before* this step).
                # The previous code checked against the initial `env`, which is
                # wrong for every step after the first.
                exec_res = test_env.verify(step_actions)
            test_env.simulate(step_actions)
            if SHOW:
                test_env.visualize(
                    step_actions,
                    exec_res,
                    out_file_path=f"tmp-step-{i}.png",
                )

        if not test_env.check_final():
            continue

        solved_plans.append(action_plan)
        if outfile is not None:
            out_obj = {
                "env": env.to_json(),
                "plan": object_to_json_serializable(action_plan),
            }
            outfile.write(json.dumps(out_obj) + "\n")

    return solved_plans


# YAML config for the small "full"-mode training set: grid_n / grid_m /
# num_objects are parallel lists (one entry per config). Not referenced in the
# visible part of this file (main() uses large_train_env_config_str), so it is
# presumably kept for reference or for other callers. TODO confirm.
train_env_config_str = """
name: "box1"
# grid_n: [2, 3, 4, 2, 3, 4]
# grid_m: [2, 3, 4, 2, 3, 4]
# num_objects: [1, 1, 1, 2, 2, 2]
grid_n: [2, 3, 4]
grid_m: [2, 3, 4]
num_objects: [2, 2, 2]
robot_mode: "full"
robot_as: "full"
movement: "full"
"""


large_train_env_config_str = """
name: "box1"
# grid_n: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
# grid_m: [2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6]
# num_objects: [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]

# grid_n: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
# grid_m: [3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6]
# num_objects: [6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]

grid_n: [2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6]
grid_m: [6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6]
num_objects: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]

robot_mode: "full"
robot_as: "full"
movement: "full"
"""


def main(arg_configs):
    """Generate solved Box1Env instances for one object count and write them as JSONL.

    For every generated config whose ``num_objects`` equals
    ``arg_configs.numobj``, distinct environments are sampled with
    deterministic seeds. The number sampled is capped at 420 or the
    combinatorial count from ``count_unique_envs``, whichever is smaller.
    Each distinct environment is solved and written to
    ``debuggg-train-large-full-split5-obj@{numobj}.jsonl``.

    Args:
        arg_configs: Parsed CLI arguments; ``arg_configs.numobj`` selects the
            object count to generate data for.
    """
    # Initialize Ray if not already started
    if not ray.is_initialized():
        ray.init(num_cpus=64, _temp_dir=get_cache_dir())  # Adjust based on your machine

    # (grid_n, grid_m, num_objects) triples. They replace the corresponding
    # lists from large_train_env_config_str.
    config_tuple = [
        # (4, 4, 4),
        # (2, 6, 6),
        (6, 6, 6),
    ]
    tmp_config = yaml.safe_load(large_train_env_config_str)
    tmp_config["grid_n"] = [x[0] for x in config_tuple]
    tmp_config["grid_m"] = [x[1] for x in config_tuple]
    tmp_config["num_objects"] = [x[2] for x in config_tuple]

    env_configs = generate_configs(tmp_config)
    split_idx = arg_configs.numobj
    out_file_path = f"debuggg-train-large-full-split5-obj@{split_idx}.jsonl"

    with open(out_file_path, "w") as f:
        for env_config in env_configs:
            # Use an equality filter rather than `< continue` / `> break`, so
            # this does not depend on generate_configs returning configs sorted
            # by num_objects.
            if env_config["num_objects"] != split_idx:
                continue

            seen_envs: Set[str] = set()
            num_generated = 0
            attempts = 0

            total_unique_envs_per_config = min(
                420,
                count_unique_envs(
                    grid_n=env_config["grid_n"],
                    grid_m=env_config["grid_m"],
                    num_objects=env_config["num_objects"],
                ),
            )
            pbar_desc = f"N={env_config['grid_n']}, M={env_config['grid_m']}, Obj={env_config['num_objects']}"
            with tqdm(total=total_unique_envs_per_config, desc=pbar_desc) as pbar:
                while num_generated < total_unique_envs_per_config:
                    attempts += 1
                    seed = 42 + attempts
                    seed_everything(seed)

                    # Create environment (only used for duplicate detection)
                    env = Box1Env(**env_config)
                    env.create()
                    env_json = json.dumps(env.to_json(), sort_keys=True)

                    if env_json in seen_envs:
                        continue  # Skip if already generated
                    seen_envs.add(env_json)

                    # solve_box_environment builds its own env from env_config
                    # using the global RNG. Reseed so it rebuilds the env that was
                    # just deduplicated. Without this, it would build a different
                    # env from the RNG state left after the create() above, which
                    # makes the dedup ineffective.
                    seed_everything(seed)
                    solve_box_environment(
                        outfile=f, env_config=env_config, randomness_factor=0.0
                    )
                    num_generated += 1
                    pbar.update(1)

if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser(
        description="Generate solved Box1Env plans for a single object count."
    )
    # Required: main() selects configs by comparing num_objects to this value,
    # so a missing flag (None) would make the run fail or generate nothing.
    parser.add_argument(
        "--numobj",
        type=int,
        required=True,
        help="Only generate environments with exactly this many objects.",
    )
    configs = parser.parse_args()
    main(configs)
