mirror of
https://github.com/JackHopkins/factorio-learning-environment.git
synced 2025-09-06 13:23:58 +00:00
better connect entities behavior for no new entities placed, better grouped entity behavior, better error messages'
This commit is contained in:
14
fle/env/gym_env/environment.py
vendored
14
fle/env/gym_env/environment.py
vendored
@@ -268,7 +268,6 @@ class FactorioGymEnv(gym.Env):
|
||||
self.last_observation = None
|
||||
# Track last message timestamp for each agent
|
||||
self.last_message_timestamps = {i: 0.0 for i in range(instance.num_agents)}
|
||||
self._last_production_flows = {}
|
||||
|
||||
def get_observation(
|
||||
self, agent_idx: int = 0, response: Optional[Response] = None
|
||||
@@ -395,12 +394,10 @@ class FactorioGymEnv(gym.Env):
|
||||
self.reset_instance(GameState.parse_raw(action.game_state.to_raw()))
|
||||
|
||||
namespace = self.instance.namespaces[agent_idx]
|
||||
# Use last post_production_flows as pre_production_flows if available
|
||||
if self._last_production_flows.get(agent_idx) is not None:
|
||||
production_flows = self._last_production_flows[agent_idx]
|
||||
else:
|
||||
production_flows = namespace._get_production_stats()
|
||||
start_production_flows = ProductionFlows.from_dict(production_flows)
|
||||
# Calculate fresh production flows at the beginning of the step
|
||||
start_production_flows = ProductionFlows.from_dict(
|
||||
namespace._get_production_stats()
|
||||
)
|
||||
|
||||
# Execute the action
|
||||
initial_score, eval_time, result = self.instance.eval(
|
||||
@@ -434,8 +431,6 @@ class FactorioGymEnv(gym.Env):
|
||||
# Get post-execution flows and calculate achievements
|
||||
current_flows = ProductionFlows.from_dict(namespace._get_production_stats())
|
||||
achievements = calculate_achievements(start_production_flows, current_flows)
|
||||
# Store for next step
|
||||
self._last_production_flows[agent_idx] = current_flows.__dict__
|
||||
|
||||
# Create response object for observation
|
||||
response = Response(
|
||||
@@ -481,7 +476,6 @@ class FactorioGymEnv(gym.Env):
|
||||
state: Optional[GameState] to reset to. If None, resets to initial state.
|
||||
"""
|
||||
self.instance.reset(state)
|
||||
self._last_production_flows = {i: None for i in range(self.instance.num_agents)}
|
||||
|
||||
def reset(
|
||||
self, options: Optional[Dict[str, Any]] = None, seed: Optional[int] = None
|
||||
|
2
fle/env/gym_env/trajectory_logger.py
vendored
2
fle/env/gym_env/trajectory_logger.py
vendored
@@ -109,7 +109,7 @@ class TrajectoryLogger:
|
||||
|
||||
raw_text = agent.observation_formatter.format_raw_text(observation.raw_text)
|
||||
for line in raw_text.split("\n"):
|
||||
if "Error" in line:
|
||||
if "Error" in line or "Exception" in line:
|
||||
print("raw_text Error:", line)
|
||||
|
||||
def add_iteration_time(self, iteration_time: float):
|
||||
|
@@ -31,6 +31,11 @@ end
|
||||
|
||||
-- Create agent characters script
|
||||
global.actions.create_agent_characters = function(num_agents)
|
||||
-- delete all character entities on the surface
|
||||
for _, entity in pairs(game.surfaces[1].find_entities_filtered{type = "character"}) do
|
||||
entity.destroy()
|
||||
end
|
||||
|
||||
-- Initialize agent characters table
|
||||
-- Destroy existing agent characters if they exist
|
||||
if global.agent_characters then
|
||||
|
18
fle/env/tools/admin/request_path/server.lua
vendored
18
fle/env/tools/admin/request_path/server.lua
vendored
@@ -64,7 +64,6 @@ global.actions.request_path = function(player_index, start_x, start_y, goal_x, g
|
||||
allow_paths_through_own_entities = allow_paths_through_own_entities
|
||||
}
|
||||
}
|
||||
|
||||
local request_id = surface.request_path(path_request)
|
||||
|
||||
global.clearance_entities[request_id] = clearance_entities
|
||||
@@ -95,24 +94,27 @@ end
|
||||
--end)
|
||||
|
||||
script.on_event(defines.events.on_script_path_request_finished, function(event)
|
||||
game.print("Path request finished for ID: " .. event.id)
|
||||
local request_data = global.path_requests[event.id]
|
||||
if not request_data then
|
||||
game.print("No request data found for ID: " .. event.id)
|
||||
return
|
||||
end
|
||||
|
||||
local player = global.agent_characters[request_data]
|
||||
if not player then
|
||||
return
|
||||
end
|
||||
-- local player = global.agent_characters[request_data]
|
||||
-- if not player then
|
||||
-- game.print("No player found for request ID: " .. event.id)
|
||||
-- return
|
||||
-- end
|
||||
|
||||
if event.path then
|
||||
global.paths[event.id] = event.path
|
||||
-- log("Path found for request ID: " .. event.id)
|
||||
game.print("Path found for request ID: " .. event.id)
|
||||
elseif event.try_again_later then
|
||||
global.paths[event.id] = "busy"
|
||||
-- log("Pathfinder busy for request ID: " .. event.id)
|
||||
game.print("Pathfinder busy for request ID: " .. event.id)
|
||||
else
|
||||
global.paths[event.id] = "not_found"
|
||||
-- log("Path not found for request ID: " .. event.id)
|
||||
game.print("Path not found for request ID: " .. event.id)
|
||||
end
|
||||
end)
|
140
fle/env/tools/agent/connect_entities/client.py
vendored
140
fle/env/tools/agent/connect_entities/client.py
vendored
@@ -194,7 +194,7 @@ class ConnectEntities(Tool):
|
||||
|
||||
# Resolve connection type if not provided
|
||||
if not connection_types:
|
||||
self._infer_connection_type(source, target)
|
||||
connection_types = {self._infer_connection_type(source, target)}
|
||||
else:
|
||||
valid = self._validate_connection_types(connection_types)
|
||||
if not valid:
|
||||
@@ -237,7 +237,25 @@ class ConnectEntities(Tool):
|
||||
if isinstance(target, (Entity, EntityGroup))
|
||||
else None,
|
||||
)
|
||||
return connection[0] if not dry_run else connection
|
||||
if not dry_run:
|
||||
if connection and len(connection) > 0:
|
||||
return connection[0]
|
||||
else:
|
||||
# No entities were created but pathing was successful - find existing connecting entity
|
||||
existing_entity = self._find_existing_connecting_entity(
|
||||
target_pos, list(connection_types)[0]
|
||||
)
|
||||
return (
|
||||
existing_entity
|
||||
if existing_entity
|
||||
else (
|
||||
target
|
||||
if isinstance(target, (Entity, EntityGroup))
|
||||
else target_pos
|
||||
)
|
||||
)
|
||||
else:
|
||||
return connection
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
pass
|
||||
@@ -268,7 +286,22 @@ class ConnectEntities(Tool):
|
||||
if isinstance(target, (Entity, EntityGroup))
|
||||
else None,
|
||||
)
|
||||
return connection[0]
|
||||
if connection and len(connection) > 0:
|
||||
return connection[0]
|
||||
else:
|
||||
# No entities were created but pathing was successful - find existing connecting entity
|
||||
existing_entity = self._find_existing_connecting_entity(
|
||||
target_pos, list(connection_types)[0]
|
||||
)
|
||||
return (
|
||||
existing_entity
|
||||
if existing_entity
|
||||
else (
|
||||
target
|
||||
if isinstance(target, (Entity, EntityGroup))
|
||||
else target_pos
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
@@ -734,8 +767,8 @@ class ConnectEntities(Tool):
|
||||
def _process_belt_groups(
|
||||
self,
|
||||
groupable_entities: List[Entity],
|
||||
source_entity: Optional[Entity],
|
||||
target_entity: Optional[Entity],
|
||||
source_entity: Optional[Union[Entity, EntityGroup]],
|
||||
target_entity: Optional[Union[Entity, EntityGroup]],
|
||||
source_pos: Position,
|
||||
) -> List[BeltGroup]:
|
||||
"""Process transport belt groups"""
|
||||
@@ -764,23 +797,7 @@ class ConnectEntities(Tool):
|
||||
self.rotate_end_belt_to_face(entity_groups[0], target_entity)
|
||||
# self.rotate_final_belt_when_connecting_groups(entity_groups[0], source_entity)
|
||||
|
||||
# Get final groups and filter to relevant one
|
||||
entity_groups = self.get_entities(
|
||||
{
|
||||
Prototype.TransportBelt,
|
||||
Prototype.ExpressTransportBelt,
|
||||
Prototype.FastTransportBelt,
|
||||
Prototype.UndergroundBelt,
|
||||
Prototype.FastUndergroundBelt,
|
||||
Prototype.ExpressUndergroundBelt,
|
||||
},
|
||||
source_pos,
|
||||
)
|
||||
|
||||
for group in entity_groups:
|
||||
if source_pos in [entity.position for entity in group.belts]:
|
||||
return cast(List[BeltGroup], [group])
|
||||
|
||||
# Return the properly grouped entities
|
||||
return cast(List[BeltGroup], entity_groups)
|
||||
|
||||
def _update_belt_group(
|
||||
@@ -901,16 +918,15 @@ class ConnectEntities(Tool):
|
||||
self, groupable_entities: List[Entity], source_pos: Position
|
||||
) -> List[PipeGroup]:
|
||||
"""Process pipe groups"""
|
||||
entity_groups = self.get_entities(
|
||||
{Prototype.Pipe, Prototype.UndergroundPipe}, source_pos
|
||||
)
|
||||
# Group the passed entities first
|
||||
entity_groups = agglomerate_groupable_entities(groupable_entities)
|
||||
|
||||
# Deduplicate pipes in groups
|
||||
for group in entity_groups:
|
||||
group.pipes = _deduplicate_entities(group.pipes)
|
||||
if source_pos in [entity.position for entity in group.pipes]:
|
||||
return [group]
|
||||
if hasattr(group, "pipes"):
|
||||
group.pipes = _deduplicate_entities(group.pipes)
|
||||
|
||||
return entity_groups
|
||||
return cast(List[PipeGroup], entity_groups)
|
||||
|
||||
def _process_groups(
|
||||
self,
|
||||
@@ -928,21 +944,6 @@ class ConnectEntities(Tool):
|
||||
|
||||
return entity_groups
|
||||
|
||||
def _process_pipe_groups(
|
||||
self, groupable_entities: List[Entity], source_pos: Position
|
||||
) -> List[PipeGroup]:
|
||||
"""Process pipe groups"""
|
||||
entity_groups = self.get_entities(
|
||||
{Prototype.Pipe, Prototype.UndergroundPipe}, source_pos
|
||||
)
|
||||
|
||||
for group in entity_groups:
|
||||
group.pipes = _deduplicate_entities(group.pipes)
|
||||
if source_pos in [entity.position for entity in group.pipes]:
|
||||
return [group]
|
||||
|
||||
return entity_groups
|
||||
|
||||
def _attempt_to_get_entity(
|
||||
self, position: Position, get_connectors: bool = False
|
||||
) -> Union[Position, Entity, EntityGroup]:
|
||||
@@ -978,17 +979,9 @@ class ConnectEntities(Tool):
|
||||
self, groupable_entities: List[Entity], source_pos: Position
|
||||
) -> List[ElectricityGroup]:
|
||||
"""Process power pole groups"""
|
||||
return cast(
|
||||
List[ElectricityGroup],
|
||||
self.get_entities(
|
||||
{
|
||||
Prototype.SmallElectricPole,
|
||||
Prototype.BigElectricPole,
|
||||
Prototype.MediumElectricPole,
|
||||
},
|
||||
source_pos,
|
||||
),
|
||||
)
|
||||
# Group the passed entities first
|
||||
entity_groups = agglomerate_groupable_entities(groupable_entities)
|
||||
return cast(List[ElectricityGroup], entity_groups)
|
||||
|
||||
def _adjust_belt_position(
|
||||
self, pos: Position, entity: Optional[Entity]
|
||||
@@ -1249,6 +1242,43 @@ class ConnectEntities(Tool):
|
||||
entities = self.get_entities(position=pos, radius=radius)
|
||||
return bool(entities)
|
||||
|
||||
def _find_existing_connecting_entity(
|
||||
self, target_pos: Position, connection_type: Prototype
|
||||
) -> Optional[Union[Entity, EntityGroup]]:
|
||||
"""
|
||||
Find existing entity of the connection type closest to the target position.
|
||||
This is used when no new entities were created but connection was successful.
|
||||
"""
|
||||
# Determine search radius based on connection type
|
||||
if connection_type in (
|
||||
Prototype.SmallElectricPole,
|
||||
Prototype.MediumElectricPole,
|
||||
Prototype.BigElectricPole,
|
||||
):
|
||||
search_radius = 10 # Power poles have larger reach
|
||||
else:
|
||||
search_radius = 5 # Belts, pipes have smaller reach
|
||||
|
||||
# Get entities of the connection type near target
|
||||
entities = self.get_entities(connection_type, target_pos, search_radius)
|
||||
|
||||
if not entities:
|
||||
return None
|
||||
|
||||
# Find closest entity to target position
|
||||
closest_entity = None
|
||||
closest_distance = float("inf")
|
||||
|
||||
for entity in entities:
|
||||
entity_pos = entity.position if hasattr(entity, "position") else None
|
||||
if entity_pos:
|
||||
distance = target_pos.distance(entity_pos)
|
||||
if distance < closest_distance:
|
||||
closest_distance = distance
|
||||
closest_entity = entity
|
||||
|
||||
return closest_entity
|
||||
|
||||
def pickup_entities(
|
||||
self,
|
||||
path_data: dict,
|
||||
|
@@ -39,8 +39,7 @@ def _construct_group(
|
||||
id: int, entities: List[Entity], prototype: Prototype, position: Position
|
||||
) -> EntityGroup:
|
||||
if prototype == Prototype.TransportBelt or isinstance(entities[0], TransportBelt):
|
||||
if len(entities) == 1:
|
||||
return entities[0]
|
||||
# Always return BeltGroup for consistent return types, even for single belts
|
||||
inputs = [c for c in entities if c.is_source]
|
||||
outputs = [c for c in entities if c.is_terminus]
|
||||
inventory = Inventory()
|
||||
|
205
fle/env/tools/agent/connect_entities/server.lua
vendored
205
fle/env/tools/agent/connect_entities/server.lua
vendored
@@ -58,153 +58,6 @@ local function are_positions_in_same_network(pos1, pos2)
|
||||
return network1 and network2 and network1 == network2
|
||||
end
|
||||
|
||||
local function is_large_entity(entity)
|
||||
if not entity then return false end
|
||||
|
||||
-- Check by size (3x3 or larger)
|
||||
local prototype = game.entity_prototypes[entity.name]
|
||||
if prototype then
|
||||
local collision_box = prototype.collision_box
|
||||
local width = math.abs(collision_box.right_bottom.x - collision_box.left_top.x)
|
||||
local height = math.abs(collision_box.right_bottom.y - collision_box.left_top.y)
|
||||
return width >= 2.6 and height >= 2.6 -- 3x3 entities have collision box ~2.8x2.8
|
||||
end
|
||||
|
||||
return false
|
||||
end
|
||||
|
||||
local function get_reserved_positions_around_entity(entity)
|
||||
if not is_large_entity(entity) then
|
||||
return {}
|
||||
end
|
||||
|
||||
local pos = entity.position
|
||||
local reserved = {}
|
||||
|
||||
-- Middle adjacent positions (reserved for inserters, chests, etc.)
|
||||
local middle_adjacent = {
|
||||
{x = pos.x, y = pos.y - 2}, -- North (middle)
|
||||
{x = pos.x + 2, y = pos.y}, -- East (middle)
|
||||
{x = pos.x, y = pos.y + 2}, -- South (middle)
|
||||
{x = pos.x - 2, y = pos.y} -- West (middle)
|
||||
}
|
||||
|
||||
for _, reserved_pos in pairs(middle_adjacent) do
|
||||
table.insert(reserved, reserved_pos)
|
||||
end
|
||||
|
||||
return reserved
|
||||
end
|
||||
|
||||
local function get_corner_positions_around_entity(entity)
|
||||
if not is_large_entity(entity) then
|
||||
return {}
|
||||
end
|
||||
|
||||
local pos = entity.position
|
||||
local corners = {
|
||||
{x = pos.x - 2, y = pos.y - 2}, -- Northwest corner
|
||||
{x = pos.x + 2, y = pos.y - 2}, -- Northeast corner
|
||||
{x = pos.x + 2, y = pos.y + 2}, -- Southeast corner
|
||||
{x = pos.x - 2, y = pos.y + 2} -- Southwest corner
|
||||
}
|
||||
|
||||
return corners
|
||||
end
|
||||
|
||||
local function is_position_reserved_for_poles(position)
|
||||
-- Find large entities within reasonable range
|
||||
local nearby_entities = game.surfaces[1].find_entities_filtered{
|
||||
position = position,
|
||||
radius = 4, -- Search within 4 tiles
|
||||
force = "player"
|
||||
}
|
||||
|
||||
for _, entity in pairs(nearby_entities) do
|
||||
if is_large_entity(entity) then
|
||||
local reserved_positions = get_reserved_positions_around_entity(entity)
|
||||
|
||||
-- Check if current position is in reserved list
|
||||
for _, reserved_pos in pairs(reserved_positions) do
|
||||
local distance = math.sqrt(
|
||||
(position.x - reserved_pos.x)^2 + (position.y - reserved_pos.y)^2
|
||||
)
|
||||
if distance < 0.5 then -- Within half a tile
|
||||
return true, entity -- Position is reserved
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return false, nil
|
||||
end
|
||||
|
||||
local function find_alternative_pole_position(ideal_position, connection_type)
|
||||
local search_radius = wire_reach[connection_type] or 4
|
||||
local wire_reach_distance = wire_reach[connection_type] or 4
|
||||
|
||||
-- First, try corner positions of nearby large entities
|
||||
local nearby_entities = game.surfaces[1].find_entities_filtered{
|
||||
position = ideal_position,
|
||||
radius = search_radius,
|
||||
force = "player"
|
||||
}
|
||||
|
||||
for _, entity in pairs(nearby_entities) do
|
||||
if is_large_entity(entity) then
|
||||
local corners = get_corner_positions_around_entity(entity)
|
||||
|
||||
for _, corner_pos in pairs(corners) do
|
||||
local distance_to_ideal = math.sqrt(
|
||||
(corner_pos.x - ideal_position.x)^2 + (corner_pos.y - ideal_position.y)^2
|
||||
)
|
||||
|
||||
-- Check if corner is within wire reach of ideal position
|
||||
if distance_to_ideal <= wire_reach_distance then
|
||||
local can_place = game.surfaces[1].can_place_entity({
|
||||
name = connection_type,
|
||||
position = corner_pos,
|
||||
force = "player"
|
||||
})
|
||||
|
||||
if can_place then
|
||||
return corner_pos
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- Fallback: try positions in expanding circles, avoiding reserved areas
|
||||
for radius = 1, search_radius, 0.5 do
|
||||
for angle = 0, 7 do -- 8 directions
|
||||
local test_pos = {
|
||||
x = ideal_position.x + radius * math.cos(angle * math.pi / 4),
|
||||
y = ideal_position.y + radius * math.sin(angle * math.pi / 4)
|
||||
}
|
||||
|
||||
-- Round to grid
|
||||
test_pos.x = math.floor(test_pos.x * 2) / 2
|
||||
test_pos.y = math.floor(test_pos.y * 2) / 2
|
||||
|
||||
local is_reserved, blocking_entity = is_position_reserved_for_poles(test_pos)
|
||||
if not is_reserved then
|
||||
local can_place = game.surfaces[1].can_place_entity({
|
||||
name = connection_type,
|
||||
position = test_pos,
|
||||
force = "player"
|
||||
})
|
||||
|
||||
if can_place then
|
||||
return test_pos
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return nil -- No suitable position found
|
||||
end
|
||||
|
||||
local function is_position_saturated(position, reach)
|
||||
-- Get nearby poles
|
||||
local nearby_poles = game.surfaces[1].find_entities_filtered{
|
||||
@@ -606,43 +459,13 @@ local function place_at_position(player, connection_type, current_position, dir,
|
||||
|
||||
if is_electric_pole then
|
||||
if is_position_saturated(current_position, wire_reach[connection_type]) then
|
||||
|
||||
return -- No need to place another pole
|
||||
end
|
||||
|
||||
-- Check if the position is reserved for non-pole entities
|
||||
local is_reserved, blocking_entity = is_position_reserved_for_poles(current_position)
|
||||
if is_reserved then
|
||||
-- Try to find an alternative position (preferably at corners)
|
||||
placement_position = find_alternative_pole_position(current_position, connection_type)
|
||||
if not placement_position then
|
||||
-- Fallback to original method but avoid reserved areas
|
||||
local attempts = 0
|
||||
local max_attempts = 10
|
||||
|
||||
repeat
|
||||
placement_position = game.surfaces[1].find_non_colliding_position(connection_type, current_position, 2 + attempts * 0.5, 0.1)
|
||||
attempts = attempts + 1
|
||||
|
||||
if placement_position then
|
||||
local still_reserved, _ = is_position_reserved_for_poles(placement_position)
|
||||
if not still_reserved then
|
||||
break -- Found a good position
|
||||
else
|
||||
placement_position = nil -- Keep searching
|
||||
end
|
||||
end
|
||||
until placement_position or attempts >= max_attempts
|
||||
|
||||
if not placement_position then
|
||||
error("Cannot find suitable position to place " .. connection_type .. " without blocking important adjacent slots around large entities")
|
||||
end
|
||||
end
|
||||
else
|
||||
-- Original logic for non-reserved positions
|
||||
placement_position = game.surfaces[1].find_non_colliding_position(connection_type, current_position, 2, 0.1)
|
||||
if not placement_position then
|
||||
error("Cannot find suitable position to place " .. connection_type)
|
||||
end
|
||||
placement_position = game.surfaces[1].find_non_colliding_position(connection_type, current_position, 2, 0.1)
|
||||
if not placement_position then
|
||||
error("Cannot find suitable position to place " .. connection_type)
|
||||
end
|
||||
else
|
||||
local entities = game.surfaces[1].find_entities_filtered{
|
||||
@@ -925,16 +748,8 @@ local function connect_entities(player_index, source_x, source_y, target_x, targ
|
||||
-- game.print("Path length "..#raw_path)
|
||||
-- game.print(serpent.line(start_position).." - "..serpent.line(end_position))
|
||||
|
||||
if not raw_path then
|
||||
error("No path found for handle " .. path_handle .. ". Pathfinding may have failed.")
|
||||
elseif raw_path == "not_found" then
|
||||
error("Pathfinding failed: no valid path exists between source and target positions.")
|
||||
elseif raw_path == "busy" then
|
||||
error("Pathfinder is busy, try again later.")
|
||||
elseif type(raw_path) ~= "table" then
|
||||
error("Invalid path type: " .. type(raw_path) .. " (value: " .. serpent.line(raw_path) .. ")")
|
||||
elseif #raw_path == 0 then
|
||||
error("Empty path returned from pathfinder.")
|
||||
if not raw_path or type(raw_path) ~= "table" or #raw_path == 0 then
|
||||
error("Invalid path: " .. serpent.line(path))
|
||||
end
|
||||
|
||||
-- game.print("Normalising", {print_skip=defines.print_skip.never})
|
||||
@@ -978,14 +793,6 @@ local function connect_entities(player_index, source_x, source_y, target_x, targ
|
||||
-- Get source and target entities
|
||||
local source_entity = global.utils.get_closest_entity(player, {x = source_x, y = source_y})
|
||||
local target_entity = global.utils.get_closest_entity(player, {x = target_x, y = target_y})
|
||||
|
||||
-- Validate that entities were found
|
||||
if not source_entity then
|
||||
error("No entity found at source position x=" .. source_x .. " y=" .. source_y .. " within radius 3")
|
||||
end
|
||||
if not target_entity then
|
||||
error("No entity found at target position x=" .. target_x .. " y=" .. target_y .. " within radius 3")
|
||||
end
|
||||
|
||||
|
||||
if #connection_types == 1 and connection_types[1] == 'pipe-to-ground' then
|
||||
|
166
fle/env/tools/agent/get_entities/client.py
vendored
166
fle/env/tools/agent/get_entities/client.py
vendored
@@ -32,10 +32,49 @@ class GetEntities(Tool):
|
||||
if not isinstance(entities, Set):
|
||||
entities = set([entities])
|
||||
|
||||
# Serialize entity_names as a string
|
||||
# Handle group prototypes by expanding them to their component types
|
||||
expanded_entities = set()
|
||||
group_requests = set()
|
||||
|
||||
for entity in entities:
|
||||
if entity == Prototype.BeltGroup:
|
||||
# For belt groups, search for all belt types
|
||||
belt_types = {
|
||||
Prototype.TransportBelt,
|
||||
Prototype.FastTransportBelt,
|
||||
Prototype.ExpressTransportBelt,
|
||||
Prototype.UndergroundBelt,
|
||||
Prototype.FastUndergroundBelt,
|
||||
Prototype.ExpressUndergroundBelt,
|
||||
}
|
||||
expanded_entities.update(belt_types)
|
||||
group_requests.add(Prototype.BeltGroup)
|
||||
elif entity == Prototype.PipeGroup:
|
||||
# For pipe groups, search for pipe types
|
||||
pipe_types = {Prototype.Pipe, Prototype.UndergroundPipe}
|
||||
expanded_entities.update(pipe_types)
|
||||
group_requests.add(Prototype.PipeGroup)
|
||||
elif entity == Prototype.ElectricityGroup:
|
||||
# For electricity groups, search for pole types
|
||||
pole_types = {
|
||||
Prototype.SmallElectricPole,
|
||||
Prototype.MediumElectricPole,
|
||||
Prototype.BigElectricPole,
|
||||
}
|
||||
expanded_entities.update(pole_types)
|
||||
group_requests.add(Prototype.ElectricityGroup)
|
||||
else:
|
||||
expanded_entities.add(entity)
|
||||
|
||||
# Use expanded entities for the Lua query
|
||||
query_entities = expanded_entities
|
||||
|
||||
# Serialize entity_names as a string
|
||||
entity_names = (
|
||||
"[" + ",".join([f'"{entity.value[0]}"' for entity in entities]) + "]"
|
||||
if entities
|
||||
"["
|
||||
+ ",".join([f'"{entity.value[0]}"' for entity in query_entities])
|
||||
+ "]"
|
||||
if query_entities
|
||||
else "[]"
|
||||
)
|
||||
|
||||
@@ -104,15 +143,24 @@ class GetEntities(Tool):
|
||||
except Exception as e1:
|
||||
print(f"Could not create {entity_data['name']} object: {e1}")
|
||||
|
||||
# Only group entities if user is looking for group types or no specific filter
|
||||
should_group = not entities or any(
|
||||
proto
|
||||
in {
|
||||
Prototype.ElectricityGroup,
|
||||
Prototype.PipeGroup,
|
||||
Prototype.BeltGroup,
|
||||
}
|
||||
for proto in entities
|
||||
# Group entities when:
|
||||
# 1. User explicitly requests group types, OR
|
||||
# 2. User provides a position filter (suggesting they want nearby entities grouped), OR
|
||||
# 3. No specific entities requested (get all entities - should be grouped)
|
||||
should_group = (
|
||||
not entities # No filter = group everything
|
||||
or any(
|
||||
proto
|
||||
in {
|
||||
Prototype.ElectricityGroup,
|
||||
Prototype.PipeGroup,
|
||||
Prototype.BeltGroup,
|
||||
}
|
||||
for proto in entities
|
||||
) # Explicit group request
|
||||
or (
|
||||
entities and position is not None
|
||||
) # Individual entities with position filter = group for convenience
|
||||
)
|
||||
|
||||
if should_group:
|
||||
@@ -177,22 +225,84 @@ class GetEntities(Tool):
|
||||
if hasattr(entity, "prototype") and entity.prototype in entities:
|
||||
filtered_entities.append(entity)
|
||||
elif hasattr(entity, "__class__"):
|
||||
# Check for group types
|
||||
if (
|
||||
entity.__class__.__name__ == "ElectricityGroup"
|
||||
and Prototype.ElectricityGroup in entities
|
||||
):
|
||||
filtered_entities.append(entity)
|
||||
elif (
|
||||
entity.__class__.__name__ == "PipeGroup"
|
||||
and Prototype.PipeGroup in entities
|
||||
):
|
||||
filtered_entities.append(entity)
|
||||
elif (
|
||||
entity.__class__.__name__ == "BeltGroup"
|
||||
and Prototype.BeltGroup in entities
|
||||
):
|
||||
filtered_entities.append(entity)
|
||||
# Handle group entities
|
||||
if entity.__class__.__name__ == "ElectricityGroup":
|
||||
pole_types = {
|
||||
Prototype.SmallElectricPole,
|
||||
Prototype.MediumElectricPole,
|
||||
Prototype.BigElectricPole,
|
||||
}
|
||||
if Prototype.ElectricityGroup in group_requests:
|
||||
# Explicit group request - return the group
|
||||
filtered_entities.append(entity)
|
||||
elif (
|
||||
any(pole_type in entities for pole_type in pole_types)
|
||||
and position is not None
|
||||
):
|
||||
# Individual poles requested with position - return group for convenience
|
||||
filtered_entities.append(entity)
|
||||
elif (
|
||||
any(pole_type in entities for pole_type in pole_types)
|
||||
and position is None
|
||||
):
|
||||
# Individual poles requested without position - extract individual poles from group
|
||||
for pole in entity.poles:
|
||||
if (
|
||||
hasattr(pole, "prototype")
|
||||
and pole.prototype in entities
|
||||
):
|
||||
filtered_entities.append(pole)
|
||||
elif entity.__class__.__name__ == "PipeGroup":
|
||||
pipe_types = {Prototype.Pipe, Prototype.UndergroundPipe}
|
||||
if Prototype.PipeGroup in group_requests:
|
||||
# Explicit group request - return the group
|
||||
filtered_entities.append(entity)
|
||||
elif (
|
||||
any(pipe_type in entities for pipe_type in pipe_types)
|
||||
and position is not None
|
||||
):
|
||||
# Individual pipes requested with position - return group for convenience
|
||||
filtered_entities.append(entity)
|
||||
elif (
|
||||
any(pipe_type in entities for pipe_type in pipe_types)
|
||||
and position is None
|
||||
):
|
||||
# Individual pipes requested without position - extract individual pipes from group
|
||||
for pipe in entity.pipes:
|
||||
if (
|
||||
hasattr(pipe, "prototype")
|
||||
and pipe.prototype in entities
|
||||
):
|
||||
filtered_entities.append(pipe)
|
||||
elif entity.__class__.__name__ == "BeltGroup":
|
||||
belt_types = {
|
||||
Prototype.TransportBelt,
|
||||
Prototype.FastTransportBelt,
|
||||
Prototype.ExpressTransportBelt,
|
||||
Prototype.UndergroundBelt,
|
||||
Prototype.FastUndergroundBelt,
|
||||
Prototype.ExpressUndergroundBelt,
|
||||
}
|
||||
if Prototype.BeltGroup in group_requests:
|
||||
# Explicit group request - return the group
|
||||
filtered_entities.append(entity)
|
||||
elif (
|
||||
any(belt_type in entities for belt_type in belt_types)
|
||||
and position is not None
|
||||
):
|
||||
# Individual belts requested with position - return group for convenience
|
||||
filtered_entities.append(entity)
|
||||
elif (
|
||||
any(belt_type in entities for belt_type in belt_types)
|
||||
and position is None
|
||||
):
|
||||
# Individual belts requested without position - extract individual belts from group
|
||||
for belt in entity.belts:
|
||||
if (
|
||||
hasattr(belt, "prototype")
|
||||
and belt.prototype in entities
|
||||
):
|
||||
filtered_entities.append(belt)
|
||||
elif entity.__class__.__name__ == "WallGroup":
|
||||
# WallGroup doesn't have a corresponding Prototype, but include if present
|
||||
filtered_entities.append(entity)
|
||||
|
3
fle/env/tools/agent/move_to/client.py
vendored
3
fle/env/tools/agent/move_to/client.py
vendored
@@ -76,6 +76,9 @@ class MoveTo(Tool):
|
||||
if isinstance(response, int) and response == 0:
|
||||
raise Exception("Could not move.")
|
||||
|
||||
if isinstance(response, str):
|
||||
raise Exception(f"Could not move. {response}")
|
||||
|
||||
if response == "trailing" or response == "leading":
|
||||
raise Exception("Could not lay entity, perhaps a typo?")
|
||||
|
||||
|
1
fle/env/tools/agent/move_to/server.lua
vendored
1
fle/env/tools/agent/move_to/server.lua
vendored
@@ -25,6 +25,7 @@ end
|
||||
global.actions.move_to = function(player_index, path_handle, trailing_entity, is_trailing)
|
||||
--local player = global.agent_characters[player_index]
|
||||
local player = global.agent_characters[player_index]
|
||||
game.print("Moving to path with handle: " .. path_handle)
|
||||
local path = global.paths[path_handle]
|
||||
local surface = player.surface
|
||||
|
||||
|
Reference in New Issue
Block a user