From 3329cb899acbc529a62f6a57856ff5ebc5cf226d Mon Sep 17 00:00:00 2001 From: Neel Kant Date: Thu, 28 Aug 2025 15:08:22 -0700 Subject: [PATCH] better connect entities behavior for no new entities placed, better grouped entity behavior, better error messages' --- fle/env/gym_env/environment.py | 14 +- fle/env/gym_env/trajectory_logger.py | 2 +- .../admin/create_agent_characters/server.lua | 5 + fle/env/tools/admin/request_path/server.lua | 18 +- .../tools/agent/connect_entities/client.py | 140 +++++++----- .../connect_entities/groupable_entities.py | 3 +- .../tools/agent/connect_entities/server.lua | 205 +----------------- fle/env/tools/agent/get_entities/client.py | 166 +++++++++++--- fle/env/tools/agent/move_to/client.py | 3 + fle/env/tools/agent/move_to/server.lua | 1 + 10 files changed, 254 insertions(+), 303 deletions(-) diff --git a/fle/env/gym_env/environment.py b/fle/env/gym_env/environment.py index 13953545..2723e65a 100644 --- a/fle/env/gym_env/environment.py +++ b/fle/env/gym_env/environment.py @@ -268,7 +268,6 @@ class FactorioGymEnv(gym.Env): self.last_observation = None # Track last message timestamp for each agent self.last_message_timestamps = {i: 0.0 for i in range(instance.num_agents)} - self._last_production_flows = {} def get_observation( self, agent_idx: int = 0, response: Optional[Response] = None @@ -395,12 +394,10 @@ class FactorioGymEnv(gym.Env): self.reset_instance(GameState.parse_raw(action.game_state.to_raw())) namespace = self.instance.namespaces[agent_idx] - # Use last post_production_flows as pre_production_flows if available - if self._last_production_flows.get(agent_idx) is not None: - production_flows = self._last_production_flows[agent_idx] - else: - production_flows = namespace._get_production_stats() - start_production_flows = ProductionFlows.from_dict(production_flows) + # Calculate fresh production flows at the beginning of the step + start_production_flows = ProductionFlows.from_dict( + namespace._get_production_stats() + ) # Execute the action initial_score, eval_time, result = self.instance.eval( @@ -434,8 +431,6 @@ class FactorioGymEnv(gym.Env): # Get post-execution flows and calculate achievements current_flows = ProductionFlows.from_dict(namespace._get_production_stats()) achievements = calculate_achievements(start_production_flows, current_flows) - # Store for next step - self._last_production_flows[agent_idx] = current_flows.__dict__ # Create response object for observation response = Response( @@ -481,7 +476,6 @@ class FactorioGymEnv(gym.Env): state: Optional[GameState] to reset to. If None, resets to initial state. """ self.instance.reset(state) - self._last_production_flows = {i: None for i in range(self.instance.num_agents)} def reset( self, options: Optional[Dict[str, Any]] = None, seed: Optional[int] = None diff --git a/fle/env/gym_env/trajectory_logger.py b/fle/env/gym_env/trajectory_logger.py index ec10b5f5..403642c8 100644 --- a/fle/env/gym_env/trajectory_logger.py +++ b/fle/env/gym_env/trajectory_logger.py @@ -109,7 +109,7 @@ class TrajectoryLogger: raw_text = agent.observation_formatter.format_raw_text(observation.raw_text) for line in raw_text.split("\n"): - if "Error" in line: + if "Error" in line or "Exception" in line: print("raw_text Error:", line) def add_iteration_time(self, iteration_time: float): diff --git a/fle/env/tools/admin/create_agent_characters/server.lua b/fle/env/tools/admin/create_agent_characters/server.lua index 6d16ac36..28dcf8c5 100644 --- a/fle/env/tools/admin/create_agent_characters/server.lua +++ b/fle/env/tools/admin/create_agent_characters/server.lua @@ -31,6 +31,11 @@ end -- Create agent characters script global.actions.create_agent_characters = function(num_agents) + -- delete all character entities on the surface + for _, entity in pairs(game.surfaces[1].find_entities_filtered{type = "character"}) do + entity.destroy() + end + -- Initialize agent characters table -- Destroy existing agent characters if they exist if global.agent_characters then diff --git a/fle/env/tools/admin/request_path/server.lua b/fle/env/tools/admin/request_path/server.lua index 1f31d345..e319acc9 100644 --- a/fle/env/tools/admin/request_path/server.lua +++ b/fle/env/tools/admin/request_path/server.lua @@ -64,7 +64,6 @@ global.actions.request_path = function(player_index, start_x, start_y, goal_x, g allow_paths_through_own_entities = allow_paths_through_own_entities } } - local request_id = surface.request_path(path_request) global.clearance_entities[request_id] = clearance_entities @@ -95,24 +94,27 @@ end --end) script.on_event(defines.events.on_script_path_request_finished, function(event) + game.print("Path request finished for ID: " .. event.id) local request_data = global.path_requests[event.id] if not request_data then + game.print("No request data found for ID: " .. event.id) return end - local player = global.agent_characters[request_data] - if not player then - return - end + -- local player = global.agent_characters[request_data] + -- if not player then + -- game.print("No player found for request ID: " .. event.id) + -- return + -- end if event.path then global.paths[event.id] = event.path - -- log("Path found for request ID: " .. event.id) + game.print("Path found for request ID: " .. event.id) elseif event.try_again_later then global.paths[event.id] = "busy" - -- log("Pathfinder busy for request ID: " .. event.id) + game.print("Pathfinder busy for request ID: " .. event.id) else global.paths[event.id] = "not_found" - -- log("Path not found for request ID: " .. event.id) + game.print("Path not found for request ID: " .. event.id) end end) \ No newline at end of file diff --git a/fle/env/tools/agent/connect_entities/client.py b/fle/env/tools/agent/connect_entities/client.py index de2612a5..8e189b4b 100644 --- a/fle/env/tools/agent/connect_entities/client.py +++ b/fle/env/tools/agent/connect_entities/client.py @@ -194,7 +194,7 @@ class ConnectEntities(Tool): # Resolve connection type if not provided if not connection_types: - self._infer_connection_type(source, target) + connection_types = {self._infer_connection_type(source, target)} else: valid = self._validate_connection_types(connection_types) if not valid: @@ -237,7 +237,25 @@ class ConnectEntities(Tool): if isinstance(target, (Entity, EntityGroup)) else None, ) - return connection[0] if not dry_run else connection + if not dry_run: + if connection and len(connection) > 0: + return connection[0] + else: + # No entities were created but pathing was successful - find existing connecting entity + existing_entity = self._find_existing_connecting_entity( + target_pos, list(connection_types)[0] + ) + return ( + existing_entity + if existing_entity + else ( + target + if isinstance(target, (Entity, EntityGroup)) + else target_pos + ) + ) + else: + return connection except Exception as e: last_exception = e pass @@ -268,7 +286,22 @@ class ConnectEntities(Tool): if isinstance(target, (Entity, EntityGroup)) else None, ) - return connection[0] + if connection and len(connection) > 0: + return connection[0] + else: + # No entities were created but pathing was successful - find existing connecting entity + existing_entity = self._find_existing_connecting_entity( + target_pos, list(connection_types)[0] + ) + return ( + existing_entity + if existing_entity + else ( + target + if isinstance(target, (Entity, EntityGroup)) + else target_pos + ) + ) except Exception: continue @@ -734,8 +767,8 @@ class ConnectEntities(Tool): def _process_belt_groups( self, groupable_entities: List[Entity], - source_entity: Optional[Entity], - target_entity: Optional[Entity], + source_entity: Optional[Union[Entity, EntityGroup]], + target_entity: Optional[Union[Entity, EntityGroup]], source_pos: Position, ) -> List[BeltGroup]: """Process transport belt groups""" @@ -764,23 +797,7 @@ class ConnectEntities(Tool): self.rotate_end_belt_to_face(entity_groups[0], target_entity) # self.rotate_final_belt_when_connecting_groups(entity_groups[0], source_entity) - # Get final groups and filter to relevant one - entity_groups = self.get_entities( - { - Prototype.TransportBelt, - Prototype.ExpressTransportBelt, - Prototype.FastTransportBelt, - Prototype.UndergroundBelt, - Prototype.FastUndergroundBelt, - Prototype.ExpressUndergroundBelt, - }, - source_pos, - ) - - for group in entity_groups: - if source_pos in [entity.position for entity in group.belts]: - return cast(List[BeltGroup], [group]) - + # Return the properly grouped entities return cast(List[BeltGroup], entity_groups) def _update_belt_group( @@ -901,16 +918,15 @@ class ConnectEntities(Tool): self, groupable_entities: List[Entity], source_pos: Position ) -> List[PipeGroup]: """Process pipe groups""" - entity_groups = self.get_entities( - {Prototype.Pipe, Prototype.UndergroundPipe}, source_pos - ) + # Group the passed entities first + entity_groups = agglomerate_groupable_entities(groupable_entities) + # Deduplicate pipes in groups for group in entity_groups: - group.pipes = _deduplicate_entities(group.pipes) - if source_pos in [entity.position for entity in group.pipes]: - return [group] + if hasattr(group, "pipes"): + group.pipes = _deduplicate_entities(group.pipes) - return entity_groups + return cast(List[PipeGroup], entity_groups) def _process_groups( self, @@ -928,21 +944,6 @@ class ConnectEntities(Tool): return entity_groups - def _process_pipe_groups( - self, groupable_entities: List[Entity], source_pos: Position - ) -> List[PipeGroup]: - """Process pipe groups""" - entity_groups = self.get_entities( - {Prototype.Pipe, Prototype.UndergroundPipe}, source_pos - ) - - for group in entity_groups: - group.pipes = _deduplicate_entities(group.pipes) - if source_pos in [entity.position for entity in group.pipes]: - return [group] - - return entity_groups - def _attempt_to_get_entity( self, position: Position, get_connectors: bool = False ) -> Union[Position, Entity, EntityGroup]: @@ -978,17 +979,9 @@ class ConnectEntities(Tool): self, groupable_entities: List[Entity], source_pos: Position ) -> List[ElectricityGroup]: """Process power pole groups""" - return cast( - List[ElectricityGroup], - self.get_entities( - { - Prototype.SmallElectricPole, - Prototype.BigElectricPole, - Prototype.MediumElectricPole, - }, - source_pos, - ), - ) + # Group the passed entities first + entity_groups = agglomerate_groupable_entities(groupable_entities) + return cast(List[ElectricityGroup], entity_groups) def _adjust_belt_position( self, pos: Position, entity: Optional[Entity] @@ -1249,6 +1242,43 @@ class ConnectEntities(Tool): entities = self.get_entities(position=pos, radius=radius) return bool(entities) + def _find_existing_connecting_entity( + self, target_pos: Position, connection_type: Prototype + ) -> Optional[Union[Entity, EntityGroup]]: + """ + Find existing entity of the connection type closest to the target position. + This is used when no new entities were created but connection was successful. + """ + # Determine search radius based on connection type + if connection_type in ( + Prototype.SmallElectricPole, + Prototype.MediumElectricPole, + Prototype.BigElectricPole, + ): + search_radius = 10 # Power poles have larger reach + else: + search_radius = 5 # Belts, pipes have smaller reach + + # Get entities of the connection type near target + entities = self.get_entities(connection_type, target_pos, search_radius) + + if not entities: + return None + + # Find closest entity to target position + closest_entity = None + closest_distance = float("inf") + + for entity in entities: + entity_pos = entity.position if hasattr(entity, "position") else None + if entity_pos: + distance = target_pos.distance(entity_pos) + if distance < closest_distance: + closest_distance = distance + closest_entity = entity + + return closest_entity + def pickup_entities( self, path_data: dict, diff --git a/fle/env/tools/agent/connect_entities/groupable_entities.py b/fle/env/tools/agent/connect_entities/groupable_entities.py index 7e680abd..70fb7ea8 100644 --- a/fle/env/tools/agent/connect_entities/groupable_entities.py +++ b/fle/env/tools/agent/connect_entities/groupable_entities.py @@ -39,8 +39,7 @@ def _construct_group( id: int, entities: List[Entity], prototype: Prototype, position: Position ) -> EntityGroup: if prototype == Prototype.TransportBelt or isinstance(entities[0], TransportBelt): - if len(entities) == 1: - return entities[0] + # Always return BeltGroup for consistent return types, even for single belts inputs = [c for c in entities if c.is_source] outputs = [c for c in entities if c.is_terminus] inventory = Inventory() diff --git a/fle/env/tools/agent/connect_entities/server.lua b/fle/env/tools/agent/connect_entities/server.lua index e521ad7c..cac4a546 100644 --- a/fle/env/tools/agent/connect_entities/server.lua +++ b/fle/env/tools/agent/connect_entities/server.lua @@ -58,153 +58,6 @@ local function are_positions_in_same_network(pos1, pos2) return network1 and network2 and network1 == network2 end -local function is_large_entity(entity) - if not entity then return false end - - -- Check by size (3x3 or larger) - local prototype = game.entity_prototypes[entity.name] - if prototype then - local collision_box = prototype.collision_box - local width = math.abs(collision_box.right_bottom.x - collision_box.left_top.x) - local height = math.abs(collision_box.right_bottom.y - collision_box.left_top.y) - return width >= 2.6 and height >= 2.6 -- 3x3 entities have collision box ~2.8x2.8 - end - - return false -end - -local function get_reserved_positions_around_entity(entity) - if not is_large_entity(entity) then - return {} - end - - local pos = entity.position - local reserved = {} - - -- Middle adjacent positions (reserved for inserters, chests, etc.) - local middle_adjacent = { - {x = pos.x, y = pos.y - 2}, -- North (middle) - {x = pos.x + 2, y = pos.y}, -- East (middle) - {x = pos.x, y = pos.y + 2}, -- South (middle) - {x = pos.x - 2, y = pos.y} -- West (middle) - } - - for _, reserved_pos in pairs(middle_adjacent) do - table.insert(reserved, reserved_pos) - end - - return reserved -end - -local function get_corner_positions_around_entity(entity) - if not is_large_entity(entity) then - return {} - end - - local pos = entity.position - local corners = { - {x = pos.x - 2, y = pos.y - 2}, -- Northwest corner - {x = pos.x + 2, y = pos.y - 2}, -- Northeast corner - {x = pos.x + 2, y = pos.y + 2}, -- Southeast corner - {x = pos.x - 2, y = pos.y + 2} -- Southwest corner - } - - return corners -end - -local function is_position_reserved_for_poles(position) - -- Find large entities within reasonable range - local nearby_entities = game.surfaces[1].find_entities_filtered{ - position = position, - radius = 4, -- Search within 4 tiles - force = "player" - } - - for _, entity in pairs(nearby_entities) do - if is_large_entity(entity) then - local reserved_positions = get_reserved_positions_around_entity(entity) - - -- Check if current position is in reserved list - for _, reserved_pos in pairs(reserved_positions) do - local distance = math.sqrt( - (position.x - reserved_pos.x)^2 + (position.y - reserved_pos.y)^2 - ) - if distance < 0.5 then -- Within half a tile - return true, entity -- Position is reserved - end - end - end - end - - return false, nil -end - -local function find_alternative_pole_position(ideal_position, connection_type) - local search_radius = wire_reach[connection_type] or 4 - local wire_reach_distance = wire_reach[connection_type] or 4 - - -- First, try corner positions of nearby large entities - local nearby_entities = game.surfaces[1].find_entities_filtered{ - position = ideal_position, - radius = search_radius, - force = "player" - } - - for _, entity in pairs(nearby_entities) do - if is_large_entity(entity) then - local corners = get_corner_positions_around_entity(entity) - - for _, corner_pos in pairs(corners) do - local distance_to_ideal = math.sqrt( - (corner_pos.x - ideal_position.x)^2 + (corner_pos.y - ideal_position.y)^2 - ) - - -- Check if corner is within wire reach of ideal position - if distance_to_ideal <= wire_reach_distance then - local can_place = game.surfaces[1].can_place_entity({ - name = connection_type, - position = corner_pos, - force = "player" - }) - - if can_place then - return corner_pos - end - end - end - end - end - - -- Fallback: try positions in expanding circles, avoiding reserved areas - for radius = 1, search_radius, 0.5 do - for angle = 0, 7 do -- 8 directions - local test_pos = { - x = ideal_position.x + radius * math.cos(angle * math.pi / 4), - y = ideal_position.y + radius * math.sin(angle * math.pi / 4) - } - - -- Round to grid - test_pos.x = math.floor(test_pos.x * 2) / 2 - test_pos.y = math.floor(test_pos.y * 2) / 2 - - local is_reserved, blocking_entity = is_position_reserved_for_poles(test_pos) - if not is_reserved then - local can_place = game.surfaces[1].can_place_entity({ - name = connection_type, - position = test_pos, - force = "player" - }) - - if can_place then - return test_pos - end - end - end - end - - return nil -- No suitable position found -end - local function is_position_saturated(position, reach) -- Get nearby poles local nearby_poles = game.surfaces[1].find_entities_filtered{ @@ -606,43 +459,13 @@ local function place_at_position(player, connection_type, current_position, dir, if is_electric_pole then if is_position_saturated(current_position, wire_reach[connection_type]) then + return -- No need to place another pole end - -- Check if the position is reserved for non-pole entities - local is_reserved, blocking_entity = is_position_reserved_for_poles(current_position) - if is_reserved then - -- Try to find an alternative position (preferably at corners) - placement_position = find_alternative_pole_position(current_position, connection_type) - if not placement_position then - -- Fallback to original method but avoid reserved areas - local attempts = 0 - local max_attempts = 10 - - repeat - placement_position = game.surfaces[1].find_non_colliding_position(connection_type, current_position, 2 + attempts * 0.5, 0.1) - attempts = attempts + 1 - - if placement_position then - local still_reserved, _ = is_position_reserved_for_poles(placement_position) - if not still_reserved then - break -- Found a good position - else - placement_position = nil -- Keep searching - end - end - until placement_position or attempts >= max_attempts - - if not placement_position then - error("Cannot find suitable position to place " .. connection_type .. " without blocking important adjacent slots around large entities") - end - end - else - -- Original logic for non-reserved positions - placement_position = game.surfaces[1].find_non_colliding_position(connection_type, current_position, 2, 0.1) - if not placement_position then - error("Cannot find suitable position to place " .. connection_type) - end + placement_position = game.surfaces[1].find_non_colliding_position(connection_type, current_position, 2, 0.1) + if not placement_position then + error("Cannot find suitable position to place " .. connection_type) end else local entities = game.surfaces[1].find_entities_filtered{ @@ -925,16 +748,8 @@ local function connect_entities(player_index, source_x, source_y, target_x, targ -- game.print("Path length "..#raw_path) -- game.print(serpent.line(start_position).." - "..serpent.line(end_position)) - if not raw_path then - error("No path found for handle " .. path_handle .. ". Pathfinding may have failed.") - elseif raw_path == "not_found" then - error("Pathfinding failed: no valid path exists between source and target positions.") - elseif raw_path == "busy" then - error("Pathfinder is busy, try again later.") - elseif type(raw_path) ~= "table" then - error("Invalid path type: " .. type(raw_path) .. " (value: " .. serpent.line(raw_path) .. ")") - elseif #raw_path == 0 then - error("Empty path returned from pathfinder.") + if not raw_path or type(raw_path) ~= "table" or #raw_path == 0 then + error("Invalid path: " .. serpent.line(path)) end -- game.print("Normalising", {print_skip=defines.print_skip.never}) @@ -978,14 +793,6 @@ local function connect_entities(player_index, source_x, source_y, target_x, targ -- Get source and target entities local source_entity = global.utils.get_closest_entity(player, {x = source_x, y = source_y}) local target_entity = global.utils.get_closest_entity(player, {x = target_x, y = target_y}) - - -- Validate that entities were found - if not source_entity then - error("No entity found at source position x=" .. source_x .. " y=" .. source_y .. " within radius 3") - end - if not target_entity then - error("No entity found at target position x=" .. target_x .. " y=" .. target_y .. " within radius 3") - end if #connection_types == 1 and connection_types[1] == 'pipe-to-ground' then diff --git a/fle/env/tools/agent/get_entities/client.py b/fle/env/tools/agent/get_entities/client.py index 1cfe5e1c..0e002ba4 100644 --- a/fle/env/tools/agent/get_entities/client.py +++ b/fle/env/tools/agent/get_entities/client.py @@ -32,10 +32,49 @@ class GetEntities(Tool): if not isinstance(entities, Set): entities = set([entities]) - # Serialize entity_names as a string + # Handle group prototypes by expanding them to their component types + expanded_entities = set() + group_requests = set() + + for entity in entities: + if entity == Prototype.BeltGroup: + # For belt groups, search for all belt types + belt_types = { + Prototype.TransportBelt, + Prototype.FastTransportBelt, + Prototype.ExpressTransportBelt, + Prototype.UndergroundBelt, + Prototype.FastUndergroundBelt, + Prototype.ExpressUndergroundBelt, + } + expanded_entities.update(belt_types) + group_requests.add(Prototype.BeltGroup) + elif entity == Prototype.PipeGroup: + # For pipe groups, search for pipe types + pipe_types = {Prototype.Pipe, Prototype.UndergroundPipe} + expanded_entities.update(pipe_types) + group_requests.add(Prototype.PipeGroup) + elif entity == Prototype.ElectricityGroup: + # For electricity groups, search for pole types + pole_types = { + Prototype.SmallElectricPole, + Prototype.MediumElectricPole, + Prototype.BigElectricPole, + } + expanded_entities.update(pole_types) + group_requests.add(Prototype.ElectricityGroup) + else: + expanded_entities.add(entity) + + # Use expanded entities for the Lua query + query_entities = expanded_entities + + # Serialize entity_names as a string entity_names = ( - "[" + ",".join([f'"{entity.value[0]}"' for entity in entities]) + "]" - if entities + "[" + + ",".join([f'"{entity.value[0]}"' for entity in query_entities]) + + "]" + if query_entities else "[]" ) @@ -104,15 +143,24 @@ class GetEntities(Tool): except Exception as e1: print(f"Could not create {entity_data['name']} object: {e1}") - # Only group entities if user is looking for group types or no specific filter - should_group = not entities or any( - proto - in { - Prototype.ElectricityGroup, - Prototype.PipeGroup, - Prototype.BeltGroup, - } - for proto in entities + # Group entities when: + # 1. User explicitly requests group types, OR + # 2. User provides a position filter (suggesting they want nearby entities grouped), OR + # 3. No specific entities requested (get all entities - should be grouped) + should_group = ( + not entities # No filter = group everything + or any( + proto + in { + Prototype.ElectricityGroup, + Prototype.PipeGroup, + Prototype.BeltGroup, + } + for proto in entities + ) # Explicit group request + or ( + entities and position is not None + ) # Individual entities with position filter = group for convenience ) if should_group: @@ -177,22 +225,84 @@ class GetEntities(Tool): if hasattr(entity, "prototype") and entity.prototype in entities: filtered_entities.append(entity) elif hasattr(entity, "__class__"): - # Check for group types - if ( - entity.__class__.__name__ == "ElectricityGroup" - and Prototype.ElectricityGroup in entities - ): - filtered_entities.append(entity) - elif ( - entity.__class__.__name__ == "PipeGroup" - and Prototype.PipeGroup in entities - ): - filtered_entities.append(entity) - elif ( - entity.__class__.__name__ == "BeltGroup" - and Prototype.BeltGroup in entities - ): - filtered_entities.append(entity) + # Handle group entities + if entity.__class__.__name__ == "ElectricityGroup": + pole_types = { + Prototype.SmallElectricPole, + Prototype.MediumElectricPole, + Prototype.BigElectricPole, + } + if Prototype.ElectricityGroup in group_requests: + # Explicit group request - return the group + filtered_entities.append(entity) + elif ( + any(pole_type in entities for pole_type in pole_types) + and position is not None + ): + # Individual poles requested with position - return group for convenience + filtered_entities.append(entity) + elif ( + any(pole_type in entities for pole_type in pole_types) + and position is None + ): + # Individual poles requested without position - extract individual poles from group + for pole in entity.poles: + if ( + hasattr(pole, "prototype") + and pole.prototype in entities + ): + filtered_entities.append(pole) + elif entity.__class__.__name__ == "PipeGroup": + pipe_types = {Prototype.Pipe, Prototype.UndergroundPipe} + if Prototype.PipeGroup in group_requests: + # Explicit group request - return the group + filtered_entities.append(entity) + elif ( + any(pipe_type in entities for pipe_type in pipe_types) + and position is not None + ): + # Individual pipes requested with position - return group for convenience + filtered_entities.append(entity) + elif ( + any(pipe_type in entities for pipe_type in pipe_types) + and position is None + ): + # Individual pipes requested without position - extract individual pipes from group + for pipe in entity.pipes: + if ( + hasattr(pipe, "prototype") + and pipe.prototype in entities + ): + filtered_entities.append(pipe) + elif entity.__class__.__name__ == "BeltGroup": + belt_types = { + Prototype.TransportBelt, + Prototype.FastTransportBelt, + Prototype.ExpressTransportBelt, + Prototype.UndergroundBelt, + Prototype.FastUndergroundBelt, + Prototype.ExpressUndergroundBelt, + } + if Prototype.BeltGroup in group_requests: + # Explicit group request - return the group + filtered_entities.append(entity) + elif ( + any(belt_type in entities for belt_type in belt_types) + and position is not None + ): + # Individual belts requested with position - return group for convenience + filtered_entities.append(entity) + elif ( + any(belt_type in entities for belt_type in belt_types) + and position is None + ): + # Individual belts requested without position - extract individual belts from group + for belt in entity.belts: + if ( + hasattr(belt, "prototype") + and belt.prototype in entities + ): + filtered_entities.append(belt) elif entity.__class__.__name__ == "WallGroup": # WallGroup doesn't have a corresponding Prototype, but include if present filtered_entities.append(entity) diff --git a/fle/env/tools/agent/move_to/client.py b/fle/env/tools/agent/move_to/client.py index 3bf610fb..d229b26c 100644 --- a/fle/env/tools/agent/move_to/client.py +++ b/fle/env/tools/agent/move_to/client.py @@ -76,6 +76,9 @@ class MoveTo(Tool): if isinstance(response, int) and response == 0: raise Exception("Could not move.") + if isinstance(response, str): + raise Exception(f"Could not move. {response}") + if response == "trailing" or response == "leading": raise Exception("Could not lay entity, perhaps a typo?") diff --git a/fle/env/tools/agent/move_to/server.lua b/fle/env/tools/agent/move_to/server.lua index a03dce61..05324015 100644 --- a/fle/env/tools/agent/move_to/server.lua +++ b/fle/env/tools/agent/move_to/server.lua @@ -25,6 +25,7 @@ end global.actions.move_to = function(player_index, path_handle, trailing_entity, is_trailing) --local player = global.agent_characters[player_index] local player = global.agent_characters[player_index] + game.print("Moving to path with handle: " .. path_handle) local path = global.paths[path_handle] local surface = player.surface