mirror of
https://github.com/JackHopkins/factorio-learning-environment.git
synced 2025-09-06 13:23:58 +00:00
fixes for tests
This commit is contained in:
169
fle/env/tools/agent/connect_entities/client.py
vendored
169
fle/env/tools/agent/connect_entities/client.py
vendored
@@ -161,10 +161,10 @@ class ConnectEntities(Tool):
|
||||
ticks_added = ticks_after - ticks_before
|
||||
if ticks_added > 0:
|
||||
game_speed = self.game_state.instance.get_speed()
|
||||
real_world_sleep = (
|
||||
real_world_sleep = ( # noqa
|
||||
ticks_added / 60 / game_speed if game_speed > 0 else 0
|
||||
)
|
||||
sleep(real_world_sleep)
|
||||
# sleep(real_world_sleep)
|
||||
|
||||
if dry_run:
|
||||
return {
|
||||
@@ -241,18 +241,9 @@ class ConnectEntities(Tool):
|
||||
if connection and len(connection) > 0:
|
||||
return connection[0]
|
||||
else:
|
||||
# No entities were created but pathing was successful - find existing connecting entity
|
||||
existing_entity = self._find_existing_connecting_entity(
|
||||
target_pos, list(connection_types)[0]
|
||||
)
|
||||
return (
|
||||
existing_entity
|
||||
if existing_entity
|
||||
else (
|
||||
target
|
||||
if isinstance(target, (Entity, EntityGroup))
|
||||
else target_pos
|
||||
)
|
||||
# No entities were created but pathing was successful - get existing group at target
|
||||
return self._get_existing_connection_group(
|
||||
target_pos, list(connection_types)[0], target
|
||||
)
|
||||
else:
|
||||
return connection
|
||||
@@ -289,18 +280,9 @@ class ConnectEntities(Tool):
|
||||
if connection and len(connection) > 0:
|
||||
return connection[0]
|
||||
else:
|
||||
# No entities were created but pathing was successful - find existing connecting entity
|
||||
existing_entity = self._find_existing_connecting_entity(
|
||||
target_pos, list(connection_types)[0]
|
||||
)
|
||||
return (
|
||||
existing_entity
|
||||
if existing_entity
|
||||
else (
|
||||
target
|
||||
if isinstance(target, (Entity, EntityGroup))
|
||||
else target_pos
|
||||
)
|
||||
# No entities were created but pathing was successful - get existing group at target
|
||||
return self._get_existing_connection_group(
|
||||
target_pos, list(connection_types)[0], target
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
@@ -797,7 +779,23 @@ class ConnectEntities(Tool):
|
||||
self.rotate_end_belt_to_face(entity_groups[0], target_entity)
|
||||
# self.rotate_final_belt_when_connecting_groups(entity_groups[0], source_entity)
|
||||
|
||||
# Return the properly grouped entities
|
||||
# Get final groups and filter to relevant one
|
||||
entity_groups = self.get_entities(
|
||||
{
|
||||
Prototype.TransportBelt,
|
||||
Prototype.ExpressTransportBelt,
|
||||
Prototype.FastTransportBelt,
|
||||
Prototype.UndergroundBelt,
|
||||
Prototype.FastUndergroundBelt,
|
||||
Prototype.ExpressUndergroundBelt,
|
||||
},
|
||||
source_pos,
|
||||
)
|
||||
|
||||
for group in entity_groups:
|
||||
if source_pos in [entity.position for entity in group.belts]:
|
||||
return cast(List[BeltGroup], [group])
|
||||
|
||||
return cast(List[BeltGroup], entity_groups)
|
||||
|
||||
def _update_belt_group(
|
||||
@@ -918,15 +916,16 @@ class ConnectEntities(Tool):
|
||||
self, groupable_entities: List[Entity], source_pos: Position
|
||||
) -> List[PipeGroup]:
|
||||
"""Process pipe groups"""
|
||||
# Group the passed entities first
|
||||
entity_groups = agglomerate_groupable_entities(groupable_entities)
|
||||
entity_groups = self.get_entities(
|
||||
{Prototype.Pipe, Prototype.UndergroundPipe}, source_pos
|
||||
)
|
||||
|
||||
# Deduplicate pipes in groups
|
||||
for group in entity_groups:
|
||||
if hasattr(group, "pipes"):
|
||||
group.pipes = _deduplicate_entities(group.pipes)
|
||||
group.pipes = _deduplicate_entities(group.pipes)
|
||||
if source_pos in [entity.position for entity in group.pipes]:
|
||||
return [group]
|
||||
|
||||
return cast(List[PipeGroup], entity_groups)
|
||||
return entity_groups
|
||||
|
||||
def _process_groups(
|
||||
self,
|
||||
@@ -979,9 +978,17 @@ class ConnectEntities(Tool):
|
||||
self, groupable_entities: List[Entity], source_pos: Position
|
||||
) -> List[ElectricityGroup]:
|
||||
"""Process power pole groups"""
|
||||
# Group the passed entities first
|
||||
entity_groups = agglomerate_groupable_entities(groupable_entities)
|
||||
return cast(List[ElectricityGroup], entity_groups)
|
||||
return cast(
|
||||
List[ElectricityGroup],
|
||||
self.get_entities(
|
||||
{
|
||||
Prototype.SmallElectricPole,
|
||||
Prototype.BigElectricPole,
|
||||
Prototype.MediumElectricPole,
|
||||
},
|
||||
source_pos,
|
||||
),
|
||||
)
|
||||
|
||||
def _adjust_belt_position(
|
||||
self, pos: Position, entity: Optional[Entity]
|
||||
@@ -1242,42 +1249,72 @@ class ConnectEntities(Tool):
|
||||
entities = self.get_entities(position=pos, radius=radius)
|
||||
return bool(entities)
|
||||
|
||||
def _find_existing_connecting_entity(
|
||||
self, target_pos: Position, connection_type: Prototype
|
||||
) -> Optional[Union[Entity, EntityGroup]]:
|
||||
def _get_existing_connection_group(
|
||||
self, target_pos: Position, connection_type: Prototype, target_entity
|
||||
) -> Union[Entity, EntityGroup, Position]:
|
||||
"""
|
||||
Find existing entity of the connection type closest to the target position.
|
||||
This is used when no new entities were created but connection was successful.
|
||||
Get existing connection group when no new entities were created.
|
||||
This handles cases where entities are already connected.
|
||||
"""
|
||||
# Determine search radius based on connection type
|
||||
if connection_type in (
|
||||
Prototype.SmallElectricPole,
|
||||
Prototype.MediumElectricPole,
|
||||
Prototype.BigElectricPole,
|
||||
):
|
||||
search_radius = 10 # Power poles have larger reach
|
||||
else:
|
||||
search_radius = 5 # Belts, pipes have smaller reach
|
||||
try:
|
||||
# Try to get existing groups of the connection type at target position
|
||||
if connection_type in (
|
||||
Prototype.SmallElectricPole,
|
||||
Prototype.MediumElectricPole,
|
||||
Prototype.BigElectricPole,
|
||||
):
|
||||
# For power poles, get electricity groups
|
||||
groups = self.get_entities(
|
||||
{
|
||||
Prototype.SmallElectricPole,
|
||||
Prototype.MediumElectricPole,
|
||||
Prototype.BigElectricPole,
|
||||
},
|
||||
target_pos,
|
||||
radius=10,
|
||||
)
|
||||
elif connection_type in (
|
||||
Prototype.TransportBelt,
|
||||
Prototype.FastTransportBelt,
|
||||
Prototype.ExpressTransportBelt,
|
||||
):
|
||||
# For belts, get belt groups
|
||||
groups = self.get_entities(
|
||||
{
|
||||
Prototype.TransportBelt,
|
||||
Prototype.FastTransportBelt,
|
||||
Prototype.ExpressTransportBelt,
|
||||
Prototype.UndergroundBelt,
|
||||
Prototype.FastUndergroundBelt,
|
||||
Prototype.ExpressUndergroundBelt,
|
||||
},
|
||||
target_pos,
|
||||
radius=5,
|
||||
)
|
||||
elif connection_type in (Prototype.Pipe, Prototype.UndergroundPipe):
|
||||
# For pipes, get pipe groups
|
||||
groups = self.get_entities(
|
||||
{Prototype.Pipe, Prototype.UndergroundPipe}, target_pos, radius=5
|
||||
)
|
||||
else:
|
||||
groups = []
|
||||
|
||||
# Get entities of the connection type near target
|
||||
entities = self.get_entities(connection_type, target_pos, search_radius)
|
||||
if groups:
|
||||
# Return the first group found
|
||||
return groups[0]
|
||||
|
||||
if not entities:
|
||||
return None
|
||||
# If no groups found, return the target entity if it's an entity/group
|
||||
if isinstance(target_entity, (Entity, EntityGroup)):
|
||||
return target_entity
|
||||
|
||||
# Find closest entity to target position
|
||||
closest_entity = None
|
||||
closest_distance = float("inf")
|
||||
# Fall back to returning the target position
|
||||
return target_pos
|
||||
|
||||
for entity in entities:
|
||||
entity_pos = entity.position if hasattr(entity, "position") else None
|
||||
if entity_pos:
|
||||
distance = target_pos.distance(entity_pos)
|
||||
if distance < closest_distance:
|
||||
closest_distance = distance
|
||||
closest_entity = entity
|
||||
|
||||
return closest_entity
|
||||
except Exception:
|
||||
# If anything goes wrong, fall back to target entity or position
|
||||
if isinstance(target_entity, (Entity, EntityGroup)):
|
||||
return target_entity
|
||||
return target_pos
|
||||
|
||||
def pickup_entities(
|
||||
self,
|
||||
|
35
fle/env/tools/agent/get_entities/client.py
vendored
35
fle/env/tools/agent/get_entities/client.py
vendored
@@ -146,7 +146,13 @@ class GetEntities(Tool):
|
||||
# Group entities when:
|
||||
# 1. User explicitly requests group types, OR
|
||||
# 2. User provides a position filter (suggesting they want nearby entities grouped), OR
|
||||
# 3. No specific entities requested (get all entities - should be grouped)
|
||||
# 3. No specific entities requested (get all entities - should be grouped), OR
|
||||
# 4. User requests individual pole entities (restore original behavior - poles are always grouped)
|
||||
pole_types = {
|
||||
Prototype.SmallElectricPole,
|
||||
Prototype.MediumElectricPole,
|
||||
Prototype.BigElectricPole,
|
||||
}
|
||||
should_group = (
|
||||
not entities # No filter = group everything
|
||||
or any(
|
||||
@@ -161,6 +167,9 @@ class GetEntities(Tool):
|
||||
or (
|
||||
entities and position is not None
|
||||
) # Individual entities with position filter = group for convenience
|
||||
or any(
|
||||
proto in pole_types for proto in entities
|
||||
) # Individual pole entities - always group
|
||||
)
|
||||
|
||||
if should_group:
|
||||
@@ -223,7 +232,14 @@ class GetEntities(Tool):
|
||||
for entity in entities_list:
|
||||
# Check entity prototype or group type
|
||||
if hasattr(entity, "prototype") and entity.prototype in entities:
|
||||
filtered_entities.append(entity)
|
||||
# Exclude power poles from individual handling - they should be handled as groups
|
||||
pole_types = {
|
||||
Prototype.SmallElectricPole,
|
||||
Prototype.MediumElectricPole,
|
||||
Prototype.BigElectricPole,
|
||||
}
|
||||
if entity.prototype not in pole_types:
|
||||
filtered_entities.append(entity)
|
||||
elif hasattr(entity, "__class__"):
|
||||
# Handle group entities
|
||||
if entity.__class__.__name__ == "ElectricityGroup":
|
||||
@@ -241,17 +257,10 @@ class GetEntities(Tool):
|
||||
):
|
||||
# Individual poles requested with position - return group for convenience
|
||||
filtered_entities.append(entity)
|
||||
elif (
|
||||
any(pole_type in entities for pole_type in pole_types)
|
||||
and position is None
|
||||
):
|
||||
# Individual poles requested without position - extract individual poles from group
|
||||
for pole in entity.poles:
|
||||
if (
|
||||
hasattr(pole, "prototype")
|
||||
and pole.prototype in entities
|
||||
):
|
||||
filtered_entities.append(pole)
|
||||
elif any(pole_type in entities for pole_type in pole_types):
|
||||
# Individual poles requested - return group (restores original behavior)
|
||||
# Power poles are inherently networked, so groups are more useful than individuals
|
||||
filtered_entities.append(entity)
|
||||
elif entity.__class__.__name__ == "PipeGroup":
|
||||
pipe_types = {Prototype.Pipe, Prototype.UndergroundPipe}
|
||||
if Prototype.PipeGroup in group_requests:
|
||||
|
Reference in New Issue
Block a user