better connect entities behavior for no new entities placed, better grouped entity behavior, better error messages'

This commit is contained in:
Neel Kant
2025-08-28 15:08:22 -07:00
parent b43ae1daba
commit 3329cb899a
10 changed files with 254 additions and 303 deletions

View File

@@ -268,7 +268,6 @@ class FactorioGymEnv(gym.Env):
self.last_observation = None
# Track last message timestamp for each agent
self.last_message_timestamps = {i: 0.0 for i in range(instance.num_agents)}
self._last_production_flows = {}
def get_observation(
self, agent_idx: int = 0, response: Optional[Response] = None
@@ -395,12 +394,10 @@ class FactorioGymEnv(gym.Env):
self.reset_instance(GameState.parse_raw(action.game_state.to_raw()))
namespace = self.instance.namespaces[agent_idx]
# Use last post_production_flows as pre_production_flows if available
if self._last_production_flows.get(agent_idx) is not None:
production_flows = self._last_production_flows[agent_idx]
else:
production_flows = namespace._get_production_stats()
start_production_flows = ProductionFlows.from_dict(production_flows)
# Calculate fresh production flows at the beginning of the step
start_production_flows = ProductionFlows.from_dict(
namespace._get_production_stats()
)
# Execute the action
initial_score, eval_time, result = self.instance.eval(
@@ -434,8 +431,6 @@ class FactorioGymEnv(gym.Env):
# Get post-execution flows and calculate achievements
current_flows = ProductionFlows.from_dict(namespace._get_production_stats())
achievements = calculate_achievements(start_production_flows, current_flows)
# Store for next step
self._last_production_flows[agent_idx] = current_flows.__dict__
# Create response object for observation
response = Response(
@@ -481,7 +476,6 @@ class FactorioGymEnv(gym.Env):
state: Optional[GameState] to reset to. If None, resets to initial state.
"""
self.instance.reset(state)
self._last_production_flows = {i: None for i in range(self.instance.num_agents)}
def reset(
self, options: Optional[Dict[str, Any]] = None, seed: Optional[int] = None

View File

@@ -109,7 +109,7 @@ class TrajectoryLogger:
raw_text = agent.observation_formatter.format_raw_text(observation.raw_text)
for line in raw_text.split("\n"):
if "Error" in line:
if "Error" in line or "Exception" in line:
print("raw_text Error:", line)
def add_iteration_time(self, iteration_time: float):

View File

@@ -31,6 +31,11 @@ end
-- Create agent characters script
global.actions.create_agent_characters = function(num_agents)
-- delete all character entities on the surface
for _, entity in pairs(game.surfaces[1].find_entities_filtered{type = "character"}) do
entity.destroy()
end
-- Initialize agent characters table
-- Destroy existing agent characters if they exist
if global.agent_characters then

View File

@@ -64,7 +64,6 @@ global.actions.request_path = function(player_index, start_x, start_y, goal_x, g
allow_paths_through_own_entities = allow_paths_through_own_entities
}
}
local request_id = surface.request_path(path_request)
global.clearance_entities[request_id] = clearance_entities
@@ -95,24 +94,27 @@ end
--end)
script.on_event(defines.events.on_script_path_request_finished, function(event)
game.print("Path request finished for ID: " .. event.id)
local request_data = global.path_requests[event.id]
if not request_data then
game.print("No request data found for ID: " .. event.id)
return
end
local player = global.agent_characters[request_data]
if not player then
return
end
-- local player = global.agent_characters[request_data]
-- if not player then
-- game.print("No player found for request ID: " .. event.id)
-- return
-- end
if event.path then
global.paths[event.id] = event.path
-- log("Path found for request ID: " .. event.id)
game.print("Path found for request ID: " .. event.id)
elseif event.try_again_later then
global.paths[event.id] = "busy"
-- log("Pathfinder busy for request ID: " .. event.id)
game.print("Pathfinder busy for request ID: " .. event.id)
else
global.paths[event.id] = "not_found"
-- log("Path not found for request ID: " .. event.id)
game.print("Path not found for request ID: " .. event.id)
end
end)

View File

@@ -194,7 +194,7 @@ class ConnectEntities(Tool):
# Resolve connection type if not provided
if not connection_types:
self._infer_connection_type(source, target)
connection_types = {self._infer_connection_type(source, target)}
else:
valid = self._validate_connection_types(connection_types)
if not valid:
@@ -237,7 +237,25 @@ class ConnectEntities(Tool):
if isinstance(target, (Entity, EntityGroup))
else None,
)
return connection[0] if not dry_run else connection
if not dry_run:
if connection and len(connection) > 0:
return connection[0]
else:
# No entities were created but pathing was successful - find existing connecting entity
existing_entity = self._find_existing_connecting_entity(
target_pos, list(connection_types)[0]
)
return (
existing_entity
if existing_entity
else (
target
if isinstance(target, (Entity, EntityGroup))
else target_pos
)
)
else:
return connection
except Exception as e:
last_exception = e
pass
@@ -268,7 +286,22 @@ class ConnectEntities(Tool):
if isinstance(target, (Entity, EntityGroup))
else None,
)
return connection[0]
if connection and len(connection) > 0:
return connection[0]
else:
# No entities were created but pathing was successful - find existing connecting entity
existing_entity = self._find_existing_connecting_entity(
target_pos, list(connection_types)[0]
)
return (
existing_entity
if existing_entity
else (
target
if isinstance(target, (Entity, EntityGroup))
else target_pos
)
)
except Exception:
continue
@@ -734,8 +767,8 @@ class ConnectEntities(Tool):
def _process_belt_groups(
self,
groupable_entities: List[Entity],
source_entity: Optional[Entity],
target_entity: Optional[Entity],
source_entity: Optional[Union[Entity, EntityGroup]],
target_entity: Optional[Union[Entity, EntityGroup]],
source_pos: Position,
) -> List[BeltGroup]:
"""Process transport belt groups"""
@@ -764,23 +797,7 @@ class ConnectEntities(Tool):
self.rotate_end_belt_to_face(entity_groups[0], target_entity)
# self.rotate_final_belt_when_connecting_groups(entity_groups[0], source_entity)
# Get final groups and filter to relevant one
entity_groups = self.get_entities(
{
Prototype.TransportBelt,
Prototype.ExpressTransportBelt,
Prototype.FastTransportBelt,
Prototype.UndergroundBelt,
Prototype.FastUndergroundBelt,
Prototype.ExpressUndergroundBelt,
},
source_pos,
)
for group in entity_groups:
if source_pos in [entity.position for entity in group.belts]:
return cast(List[BeltGroup], [group])
# Return the properly grouped entities
return cast(List[BeltGroup], entity_groups)
def _update_belt_group(
@@ -901,16 +918,15 @@ class ConnectEntities(Tool):
self, groupable_entities: List[Entity], source_pos: Position
) -> List[PipeGroup]:
"""Process pipe groups"""
entity_groups = self.get_entities(
{Prototype.Pipe, Prototype.UndergroundPipe}, source_pos
)
# Group the passed entities first
entity_groups = agglomerate_groupable_entities(groupable_entities)
# Deduplicate pipes in groups
for group in entity_groups:
group.pipes = _deduplicate_entities(group.pipes)
if source_pos in [entity.position for entity in group.pipes]:
return [group]
if hasattr(group, "pipes"):
group.pipes = _deduplicate_entities(group.pipes)
return entity_groups
return cast(List[PipeGroup], entity_groups)
def _process_groups(
self,
@@ -928,21 +944,6 @@ class ConnectEntities(Tool):
return entity_groups
def _process_pipe_groups(
self, groupable_entities: List[Entity], source_pos: Position
) -> List[PipeGroup]:
"""Process pipe groups"""
entity_groups = self.get_entities(
{Prototype.Pipe, Prototype.UndergroundPipe}, source_pos
)
for group in entity_groups:
group.pipes = _deduplicate_entities(group.pipes)
if source_pos in [entity.position for entity in group.pipes]:
return [group]
return entity_groups
def _attempt_to_get_entity(
self, position: Position, get_connectors: bool = False
) -> Union[Position, Entity, EntityGroup]:
@@ -978,17 +979,9 @@ class ConnectEntities(Tool):
self, groupable_entities: List[Entity], source_pos: Position
) -> List[ElectricityGroup]:
"""Process power pole groups"""
return cast(
List[ElectricityGroup],
self.get_entities(
{
Prototype.SmallElectricPole,
Prototype.BigElectricPole,
Prototype.MediumElectricPole,
},
source_pos,
),
)
# Group the passed entities first
entity_groups = agglomerate_groupable_entities(groupable_entities)
return cast(List[ElectricityGroup], entity_groups)
def _adjust_belt_position(
self, pos: Position, entity: Optional[Entity]
@@ -1249,6 +1242,43 @@ class ConnectEntities(Tool):
entities = self.get_entities(position=pos, radius=radius)
return bool(entities)
def _find_existing_connecting_entity(
self, target_pos: Position, connection_type: Prototype
) -> Optional[Union[Entity, EntityGroup]]:
"""
Find existing entity of the connection type closest to the target position.
This is used when no new entities were created but connection was successful.
"""
# Determine search radius based on connection type
if connection_type in (
Prototype.SmallElectricPole,
Prototype.MediumElectricPole,
Prototype.BigElectricPole,
):
search_radius = 10 # Power poles have larger reach
else:
search_radius = 5 # Belts, pipes have smaller reach
# Get entities of the connection type near target
entities = self.get_entities(connection_type, target_pos, search_radius)
if not entities:
return None
# Find closest entity to target position
closest_entity = None
closest_distance = float("inf")
for entity in entities:
entity_pos = entity.position if hasattr(entity, "position") else None
if entity_pos:
distance = target_pos.distance(entity_pos)
if distance < closest_distance:
closest_distance = distance
closest_entity = entity
return closest_entity
def pickup_entities(
self,
path_data: dict,

View File

@@ -39,8 +39,7 @@ def _construct_group(
id: int, entities: List[Entity], prototype: Prototype, position: Position
) -> EntityGroup:
if prototype == Prototype.TransportBelt or isinstance(entities[0], TransportBelt):
if len(entities) == 1:
return entities[0]
# Always return BeltGroup for consistent return types, even for single belts
inputs = [c for c in entities if c.is_source]
outputs = [c for c in entities if c.is_terminus]
inventory = Inventory()

View File

@@ -58,153 +58,6 @@ local function are_positions_in_same_network(pos1, pos2)
return network1 and network2 and network1 == network2
end
local function is_large_entity(entity)
if not entity then return false end
-- Check by size (3x3 or larger)
local prototype = game.entity_prototypes[entity.name]
if prototype then
local collision_box = prototype.collision_box
local width = math.abs(collision_box.right_bottom.x - collision_box.left_top.x)
local height = math.abs(collision_box.right_bottom.y - collision_box.left_top.y)
return width >= 2.6 and height >= 2.6 -- 3x3 entities have collision box ~2.8x2.8
end
return false
end
local function get_reserved_positions_around_entity(entity)
if not is_large_entity(entity) then
return {}
end
local pos = entity.position
local reserved = {}
-- Middle adjacent positions (reserved for inserters, chests, etc.)
local middle_adjacent = {
{x = pos.x, y = pos.y - 2}, -- North (middle)
{x = pos.x + 2, y = pos.y}, -- East (middle)
{x = pos.x, y = pos.y + 2}, -- South (middle)
{x = pos.x - 2, y = pos.y} -- West (middle)
}
for _, reserved_pos in pairs(middle_adjacent) do
table.insert(reserved, reserved_pos)
end
return reserved
end
local function get_corner_positions_around_entity(entity)
if not is_large_entity(entity) then
return {}
end
local pos = entity.position
local corners = {
{x = pos.x - 2, y = pos.y - 2}, -- Northwest corner
{x = pos.x + 2, y = pos.y - 2}, -- Northeast corner
{x = pos.x + 2, y = pos.y + 2}, -- Southeast corner
{x = pos.x - 2, y = pos.y + 2} -- Southwest corner
}
return corners
end
local function is_position_reserved_for_poles(position)
-- Find large entities within reasonable range
local nearby_entities = game.surfaces[1].find_entities_filtered{
position = position,
radius = 4, -- Search within 4 tiles
force = "player"
}
for _, entity in pairs(nearby_entities) do
if is_large_entity(entity) then
local reserved_positions = get_reserved_positions_around_entity(entity)
-- Check if current position is in reserved list
for _, reserved_pos in pairs(reserved_positions) do
local distance = math.sqrt(
(position.x - reserved_pos.x)^2 + (position.y - reserved_pos.y)^2
)
if distance < 0.5 then -- Within half a tile
return true, entity -- Position is reserved
end
end
end
end
return false, nil
end
local function find_alternative_pole_position(ideal_position, connection_type)
local search_radius = wire_reach[connection_type] or 4
local wire_reach_distance = wire_reach[connection_type] or 4
-- First, try corner positions of nearby large entities
local nearby_entities = game.surfaces[1].find_entities_filtered{
position = ideal_position,
radius = search_radius,
force = "player"
}
for _, entity in pairs(nearby_entities) do
if is_large_entity(entity) then
local corners = get_corner_positions_around_entity(entity)
for _, corner_pos in pairs(corners) do
local distance_to_ideal = math.sqrt(
(corner_pos.x - ideal_position.x)^2 + (corner_pos.y - ideal_position.y)^2
)
-- Check if corner is within wire reach of ideal position
if distance_to_ideal <= wire_reach_distance then
local can_place = game.surfaces[1].can_place_entity({
name = connection_type,
position = corner_pos,
force = "player"
})
if can_place then
return corner_pos
end
end
end
end
end
-- Fallback: try positions in expanding circles, avoiding reserved areas
for radius = 1, search_radius, 0.5 do
for angle = 0, 7 do -- 8 directions
local test_pos = {
x = ideal_position.x + radius * math.cos(angle * math.pi / 4),
y = ideal_position.y + radius * math.sin(angle * math.pi / 4)
}
-- Round to grid
test_pos.x = math.floor(test_pos.x * 2) / 2
test_pos.y = math.floor(test_pos.y * 2) / 2
local is_reserved, blocking_entity = is_position_reserved_for_poles(test_pos)
if not is_reserved then
local can_place = game.surfaces[1].can_place_entity({
name = connection_type,
position = test_pos,
force = "player"
})
if can_place then
return test_pos
end
end
end
end
return nil -- No suitable position found
end
local function is_position_saturated(position, reach)
-- Get nearby poles
local nearby_poles = game.surfaces[1].find_entities_filtered{
@@ -606,43 +459,13 @@ local function place_at_position(player, connection_type, current_position, dir,
if is_electric_pole then
if is_position_saturated(current_position, wire_reach[connection_type]) then
return -- No need to place another pole
end
-- Check if the position is reserved for non-pole entities
local is_reserved, blocking_entity = is_position_reserved_for_poles(current_position)
if is_reserved then
-- Try to find an alternative position (preferably at corners)
placement_position = find_alternative_pole_position(current_position, connection_type)
if not placement_position then
-- Fallback to original method but avoid reserved areas
local attempts = 0
local max_attempts = 10
repeat
placement_position = game.surfaces[1].find_non_colliding_position(connection_type, current_position, 2 + attempts * 0.5, 0.1)
attempts = attempts + 1
if placement_position then
local still_reserved, _ = is_position_reserved_for_poles(placement_position)
if not still_reserved then
break -- Found a good position
else
placement_position = nil -- Keep searching
end
end
until placement_position or attempts >= max_attempts
if not placement_position then
error("Cannot find suitable position to place " .. connection_type .. " without blocking important adjacent slots around large entities")
end
end
else
-- Original logic for non-reserved positions
placement_position = game.surfaces[1].find_non_colliding_position(connection_type, current_position, 2, 0.1)
if not placement_position then
error("Cannot find suitable position to place " .. connection_type)
end
placement_position = game.surfaces[1].find_non_colliding_position(connection_type, current_position, 2, 0.1)
if not placement_position then
error("Cannot find suitable position to place " .. connection_type)
end
else
local entities = game.surfaces[1].find_entities_filtered{
@@ -925,16 +748,8 @@ local function connect_entities(player_index, source_x, source_y, target_x, targ
-- game.print("Path length "..#raw_path)
-- game.print(serpent.line(start_position).." - "..serpent.line(end_position))
if not raw_path then
error("No path found for handle " .. path_handle .. ". Pathfinding may have failed.")
elseif raw_path == "not_found" then
error("Pathfinding failed: no valid path exists between source and target positions.")
elseif raw_path == "busy" then
error("Pathfinder is busy, try again later.")
elseif type(raw_path) ~= "table" then
error("Invalid path type: " .. type(raw_path) .. " (value: " .. serpent.line(raw_path) .. ")")
elseif #raw_path == 0 then
error("Empty path returned from pathfinder.")
if not raw_path or type(raw_path) ~= "table" or #raw_path == 0 then
error("Invalid path: " .. serpent.line(path))
end
-- game.print("Normalising", {print_skip=defines.print_skip.never})
@@ -978,14 +793,6 @@ local function connect_entities(player_index, source_x, source_y, target_x, targ
-- Get source and target entities
local source_entity = global.utils.get_closest_entity(player, {x = source_x, y = source_y})
local target_entity = global.utils.get_closest_entity(player, {x = target_x, y = target_y})
-- Validate that entities were found
if not source_entity then
error("No entity found at source position x=" .. source_x .. " y=" .. source_y .. " within radius 3")
end
if not target_entity then
error("No entity found at target position x=" .. target_x .. " y=" .. target_y .. " within radius 3")
end
if #connection_types == 1 and connection_types[1] == 'pipe-to-ground' then

View File

@@ -32,10 +32,49 @@ class GetEntities(Tool):
if not isinstance(entities, Set):
entities = set([entities])
# Serialize entity_names as a string
# Handle group prototypes by expanding them to their component types
expanded_entities = set()
group_requests = set()
for entity in entities:
if entity == Prototype.BeltGroup:
# For belt groups, search for all belt types
belt_types = {
Prototype.TransportBelt,
Prototype.FastTransportBelt,
Prototype.ExpressTransportBelt,
Prototype.UndergroundBelt,
Prototype.FastUndergroundBelt,
Prototype.ExpressUndergroundBelt,
}
expanded_entities.update(belt_types)
group_requests.add(Prototype.BeltGroup)
elif entity == Prototype.PipeGroup:
# For pipe groups, search for pipe types
pipe_types = {Prototype.Pipe, Prototype.UndergroundPipe}
expanded_entities.update(pipe_types)
group_requests.add(Prototype.PipeGroup)
elif entity == Prototype.ElectricityGroup:
# For electricity groups, search for pole types
pole_types = {
Prototype.SmallElectricPole,
Prototype.MediumElectricPole,
Prototype.BigElectricPole,
}
expanded_entities.update(pole_types)
group_requests.add(Prototype.ElectricityGroup)
else:
expanded_entities.add(entity)
# Use expanded entities for the Lua query
query_entities = expanded_entities
# Serialize entity_names as a string
entity_names = (
"[" + ",".join([f'"{entity.value[0]}"' for entity in entities]) + "]"
if entities
"["
+ ",".join([f'"{entity.value[0]}"' for entity in query_entities])
+ "]"
if query_entities
else "[]"
)
@@ -104,15 +143,24 @@ class GetEntities(Tool):
except Exception as e1:
print(f"Could not create {entity_data['name']} object: {e1}")
# Only group entities if user is looking for group types or no specific filter
should_group = not entities or any(
proto
in {
Prototype.ElectricityGroup,
Prototype.PipeGroup,
Prototype.BeltGroup,
}
for proto in entities
# Group entities when:
# 1. User explicitly requests group types, OR
# 2. User provides a position filter (suggesting they want nearby entities grouped), OR
# 3. No specific entities requested (get all entities - should be grouped)
should_group = (
not entities # No filter = group everything
or any(
proto
in {
Prototype.ElectricityGroup,
Prototype.PipeGroup,
Prototype.BeltGroup,
}
for proto in entities
) # Explicit group request
or (
entities and position is not None
) # Individual entities with position filter = group for convenience
)
if should_group:
@@ -177,22 +225,84 @@ class GetEntities(Tool):
if hasattr(entity, "prototype") and entity.prototype in entities:
filtered_entities.append(entity)
elif hasattr(entity, "__class__"):
# Check for group types
if (
entity.__class__.__name__ == "ElectricityGroup"
and Prototype.ElectricityGroup in entities
):
filtered_entities.append(entity)
elif (
entity.__class__.__name__ == "PipeGroup"
and Prototype.PipeGroup in entities
):
filtered_entities.append(entity)
elif (
entity.__class__.__name__ == "BeltGroup"
and Prototype.BeltGroup in entities
):
filtered_entities.append(entity)
# Handle group entities
if entity.__class__.__name__ == "ElectricityGroup":
pole_types = {
Prototype.SmallElectricPole,
Prototype.MediumElectricPole,
Prototype.BigElectricPole,
}
if Prototype.ElectricityGroup in group_requests:
# Explicit group request - return the group
filtered_entities.append(entity)
elif (
any(pole_type in entities for pole_type in pole_types)
and position is not None
):
# Individual poles requested with position - return group for convenience
filtered_entities.append(entity)
elif (
any(pole_type in entities for pole_type in pole_types)
and position is None
):
# Individual poles requested without position - extract individual poles from group
for pole in entity.poles:
if (
hasattr(pole, "prototype")
and pole.prototype in entities
):
filtered_entities.append(pole)
elif entity.__class__.__name__ == "PipeGroup":
pipe_types = {Prototype.Pipe, Prototype.UndergroundPipe}
if Prototype.PipeGroup in group_requests:
# Explicit group request - return the group
filtered_entities.append(entity)
elif (
any(pipe_type in entities for pipe_type in pipe_types)
and position is not None
):
# Individual pipes requested with position - return group for convenience
filtered_entities.append(entity)
elif (
any(pipe_type in entities for pipe_type in pipe_types)
and position is None
):
# Individual pipes requested without position - extract individual pipes from group
for pipe in entity.pipes:
if (
hasattr(pipe, "prototype")
and pipe.prototype in entities
):
filtered_entities.append(pipe)
elif entity.__class__.__name__ == "BeltGroup":
belt_types = {
Prototype.TransportBelt,
Prototype.FastTransportBelt,
Prototype.ExpressTransportBelt,
Prototype.UndergroundBelt,
Prototype.FastUndergroundBelt,
Prototype.ExpressUndergroundBelt,
}
if Prototype.BeltGroup in group_requests:
# Explicit group request - return the group
filtered_entities.append(entity)
elif (
any(belt_type in entities for belt_type in belt_types)
and position is not None
):
# Individual belts requested with position - return group for convenience
filtered_entities.append(entity)
elif (
any(belt_type in entities for belt_type in belt_types)
and position is None
):
# Individual belts requested without position - extract individual belts from group
for belt in entity.belts:
if (
hasattr(belt, "prototype")
and belt.prototype in entities
):
filtered_entities.append(belt)
elif entity.__class__.__name__ == "WallGroup":
# WallGroup doesn't have a corresponding Prototype, but include if present
filtered_entities.append(entity)

View File

@@ -76,6 +76,9 @@ class MoveTo(Tool):
if isinstance(response, int) and response == 0:
raise Exception("Could not move.")
if isinstance(response, str):
raise Exception(f"Could not move. {response}")
if response == "trailing" or response == "leading":
raise Exception("Could not lay entity, perhaps a typo?")

View File

@@ -25,6 +25,7 @@ end
global.actions.move_to = function(player_index, path_handle, trailing_entity, is_trailing)
--local player = global.agent_characters[player_index]
local player = global.agent_characters[player_index]
game.print("Moving to path with handle: " .. path_handle)
local path = global.paths[path_handle]
local surface = player.surface