fixes for tests

This commit is contained in:
Neel Kant
2025-08-31 23:29:26 -07:00
parent 3329cb899a
commit ca76e2c5d8
2 changed files with 125 additions and 79 deletions

View File

@@ -161,10 +161,10 @@ class ConnectEntities(Tool):
ticks_added = ticks_after - ticks_before
if ticks_added > 0:
game_speed = self.game_state.instance.get_speed()
real_world_sleep = (
real_world_sleep = ( # noqa
ticks_added / 60 / game_speed if game_speed > 0 else 0
)
sleep(real_world_sleep)
# sleep(real_world_sleep)
if dry_run:
return {
@@ -241,18 +241,9 @@ class ConnectEntities(Tool):
if connection and len(connection) > 0:
return connection[0]
else:
# No entities were created but pathing was successful - find existing connecting entity
existing_entity = self._find_existing_connecting_entity(
target_pos, list(connection_types)[0]
)
return (
existing_entity
if existing_entity
else (
target
if isinstance(target, (Entity, EntityGroup))
else target_pos
)
# No entities were created but pathing was successful - get existing group at target
return self._get_existing_connection_group(
target_pos, list(connection_types)[0], target
)
else:
return connection
@@ -289,18 +280,9 @@ class ConnectEntities(Tool):
if connection and len(connection) > 0:
return connection[0]
else:
# No entities were created but pathing was successful - find existing connecting entity
existing_entity = self._find_existing_connecting_entity(
target_pos, list(connection_types)[0]
)
return (
existing_entity
if existing_entity
else (
target
if isinstance(target, (Entity, EntityGroup))
else target_pos
)
# No entities were created but pathing was successful - get existing group at target
return self._get_existing_connection_group(
target_pos, list(connection_types)[0], target
)
except Exception:
continue
@@ -797,7 +779,23 @@ class ConnectEntities(Tool):
self.rotate_end_belt_to_face(entity_groups[0], target_entity)
# self.rotate_final_belt_when_connecting_groups(entity_groups[0], source_entity)
# Return the properly grouped entities
# Get final groups and filter to relevant one
entity_groups = self.get_entities(
{
Prototype.TransportBelt,
Prototype.ExpressTransportBelt,
Prototype.FastTransportBelt,
Prototype.UndergroundBelt,
Prototype.FastUndergroundBelt,
Prototype.ExpressUndergroundBelt,
},
source_pos,
)
for group in entity_groups:
if source_pos in [entity.position for entity in group.belts]:
return cast(List[BeltGroup], [group])
return cast(List[BeltGroup], entity_groups)
def _update_belt_group(
@@ -918,15 +916,16 @@ class ConnectEntities(Tool):
self, groupable_entities: List[Entity], source_pos: Position
) -> List[PipeGroup]:
"""Process pipe groups"""
# Group the passed entities first
entity_groups = agglomerate_groupable_entities(groupable_entities)
entity_groups = self.get_entities(
{Prototype.Pipe, Prototype.UndergroundPipe}, source_pos
)
# Deduplicate pipes in groups
for group in entity_groups:
if hasattr(group, "pipes"):
group.pipes = _deduplicate_entities(group.pipes)
group.pipes = _deduplicate_entities(group.pipes)
if source_pos in [entity.position for entity in group.pipes]:
return [group]
return cast(List[PipeGroup], entity_groups)
return entity_groups
def _process_groups(
self,
@@ -979,9 +978,17 @@ class ConnectEntities(Tool):
self, groupable_entities: List[Entity], source_pos: Position
) -> List[ElectricityGroup]:
"""Process power pole groups"""
# Group the passed entities first
entity_groups = agglomerate_groupable_entities(groupable_entities)
return cast(List[ElectricityGroup], entity_groups)
return cast(
List[ElectricityGroup],
self.get_entities(
{
Prototype.SmallElectricPole,
Prototype.BigElectricPole,
Prototype.MediumElectricPole,
},
source_pos,
),
)
def _adjust_belt_position(
self, pos: Position, entity: Optional[Entity]
@@ -1242,42 +1249,72 @@ class ConnectEntities(Tool):
entities = self.get_entities(position=pos, radius=radius)
return bool(entities)
def _find_existing_connecting_entity(
self, target_pos: Position, connection_type: Prototype
) -> Optional[Union[Entity, EntityGroup]]:
def _get_existing_connection_group(
self, target_pos: Position, connection_type: Prototype, target_entity
) -> Union[Entity, EntityGroup, Position]:
"""
Find existing entity of the connection type closest to the target position.
This is used when no new entities were created but connection was successful.
Get existing connection group when no new entities were created.
This handles cases where entities are already connected.
"""
# Determine search radius based on connection type
if connection_type in (
Prototype.SmallElectricPole,
Prototype.MediumElectricPole,
Prototype.BigElectricPole,
):
search_radius = 10 # Power poles have larger reach
else:
search_radius = 5 # Belts, pipes have smaller reach
try:
# Try to get existing groups of the connection type at target position
if connection_type in (
Prototype.SmallElectricPole,
Prototype.MediumElectricPole,
Prototype.BigElectricPole,
):
# For power poles, get electricity groups
groups = self.get_entities(
{
Prototype.SmallElectricPole,
Prototype.MediumElectricPole,
Prototype.BigElectricPole,
},
target_pos,
radius=10,
)
elif connection_type in (
Prototype.TransportBelt,
Prototype.FastTransportBelt,
Prototype.ExpressTransportBelt,
):
# For belts, get belt groups
groups = self.get_entities(
{
Prototype.TransportBelt,
Prototype.FastTransportBelt,
Prototype.ExpressTransportBelt,
Prototype.UndergroundBelt,
Prototype.FastUndergroundBelt,
Prototype.ExpressUndergroundBelt,
},
target_pos,
radius=5,
)
elif connection_type in (Prototype.Pipe, Prototype.UndergroundPipe):
# For pipes, get pipe groups
groups = self.get_entities(
{Prototype.Pipe, Prototype.UndergroundPipe}, target_pos, radius=5
)
else:
groups = []
# Get entities of the connection type near target
entities = self.get_entities(connection_type, target_pos, search_radius)
if groups:
# Return the first group found
return groups[0]
if not entities:
return None
# If no groups found, return the target entity if it's an entity/group
if isinstance(target_entity, (Entity, EntityGroup)):
return target_entity
# Find closest entity to target position
closest_entity = None
closest_distance = float("inf")
# Fall back to returning the target position
return target_pos
for entity in entities:
entity_pos = entity.position if hasattr(entity, "position") else None
if entity_pos:
distance = target_pos.distance(entity_pos)
if distance < closest_distance:
closest_distance = distance
closest_entity = entity
return closest_entity
except Exception:
# If anything goes wrong, fall back to target entity or position
if isinstance(target_entity, (Entity, EntityGroup)):
return target_entity
return target_pos
def pickup_entities(
self,

View File

@@ -146,7 +146,13 @@ class GetEntities(Tool):
# Group entities when:
# 1. User explicitly requests group types, OR
# 2. User provides a position filter (suggesting they want nearby entities grouped), OR
# 3. No specific entities requested (get all entities - should be grouped)
# 3. No specific entities requested (get all entities - should be grouped), OR
# 4. User requests individual pole entities (restore original behavior - poles are always grouped)
pole_types = {
Prototype.SmallElectricPole,
Prototype.MediumElectricPole,
Prototype.BigElectricPole,
}
should_group = (
not entities # No filter = group everything
or any(
@@ -161,6 +167,9 @@ class GetEntities(Tool):
or (
entities and position is not None
) # Individual entities with position filter = group for convenience
or any(
proto in pole_types for proto in entities
) # Individual pole entities - always group
)
if should_group:
@@ -223,7 +232,14 @@ class GetEntities(Tool):
for entity in entities_list:
# Check entity prototype or group type
if hasattr(entity, "prototype") and entity.prototype in entities:
filtered_entities.append(entity)
# Exclude power poles from individual handling - they should be handled as groups
pole_types = {
Prototype.SmallElectricPole,
Prototype.MediumElectricPole,
Prototype.BigElectricPole,
}
if entity.prototype not in pole_types:
filtered_entities.append(entity)
elif hasattr(entity, "__class__"):
# Handle group entities
if entity.__class__.__name__ == "ElectricityGroup":
@@ -241,17 +257,10 @@ class GetEntities(Tool):
):
# Individual poles requested with position - return group for convenience
filtered_entities.append(entity)
elif (
any(pole_type in entities for pole_type in pole_types)
and position is None
):
# Individual poles requested without position - extract individual poles from group
for pole in entity.poles:
if (
hasattr(pole, "prototype")
and pole.prototype in entities
):
filtered_entities.append(pole)
elif any(pole_type in entities for pole_type in pole_types):
# Individual poles requested - return group (restores original behavior)
# Power poles are inherently networked, so groups are more useful than individuals
filtered_entities.append(entity)
elif entity.__class__.__name__ == "PipeGroup":
pipe_types = {Prototype.Pipe, Prototype.UndergroundPipe}
if Prototype.PipeGroup in group_requests: