Files
factorio-learning-environment/fle/env/tools/agent/get_entities/client.py

329 lines
15 KiB
Python

from time import sleep
from typing import List, Set, Union
from fle.env.entities import Position, Entity, EntityGroup
from fle.env.game_types import Prototype
from fle.env.tools.agent.connect_entities.groupable_entities import (
agglomerate_groupable_entities,
)
from fle.env.tools import Tool
class GetEntities(Tool):
def __init__(self, connection, game_state):
super().__init__(connection, game_state)
def __call__(
self,
entities: Union[Set[Prototype], Prototype] = set(),
position: Position = None,
radius: float = 1000,
) -> List[Union[Entity, EntityGroup]]:
"""
Get entities within a radius of a given position.
:param entities: Set of entity prototypes to filter by. If empty, all entities are returned.
:param position: Position to search around. Can be a Position object or "player" for player's position.
:param radius: Radius to search within.
:return: Found entities
"""
try:
if not isinstance(position, Position) and position is not None:
raise ValueError("The second argument must be a Position object")
if not isinstance(entities, Set):
entities = set([entities])
# Handle group prototypes by expanding them to their component types
expanded_entities = set()
group_requests = set()
for entity in entities:
if entity == Prototype.BeltGroup:
# For belt groups, search for all belt types
belt_types = {
Prototype.TransportBelt,
Prototype.FastTransportBelt,
Prototype.ExpressTransportBelt,
Prototype.UndergroundBelt,
Prototype.FastUndergroundBelt,
Prototype.ExpressUndergroundBelt,
}
expanded_entities.update(belt_types)
group_requests.add(Prototype.BeltGroup)
elif entity == Prototype.PipeGroup:
# For pipe groups, search for pipe types
pipe_types = {Prototype.Pipe, Prototype.UndergroundPipe}
expanded_entities.update(pipe_types)
group_requests.add(Prototype.PipeGroup)
elif entity == Prototype.ElectricityGroup:
# For electricity groups, search for pole types
pole_types = {
Prototype.SmallElectricPole,
Prototype.MediumElectricPole,
Prototype.BigElectricPole,
}
expanded_entities.update(pole_types)
group_requests.add(Prototype.ElectricityGroup)
else:
expanded_entities.add(entity)
# Use expanded entities for the Lua query
query_entities = expanded_entities
# Serialize entity_names as a string
entity_names = (
"["
+ ",".join([f'"{entity.value[0]}"' for entity in query_entities])
+ "]"
if query_entities
else "[]"
)
# We need to add a small 50ms sleep to ensure that the entities have updated after previous actions
sleep(0.05)
if position is None:
response, time_elapsed = self.execute(
self.player_index, radius, entity_names
)
else:
response, time_elapsed = self.execute(
self.player_index, radius, entity_names, position.x, position.y
)
if not response:
return []
if (not isinstance(response, dict) and not response) or isinstance(
response, str
): # or (isinstance(response, dict) and not response):
raise Exception("Could not get entities", response)
entities_list = []
for raw_entity_data in response:
if isinstance(raw_entity_data, list):
continue
entity_data = self.clean_response(raw_entity_data)
# Find the matching Prototype
matching_prototype = None
for prototype in Prototype:
if prototype.value[0] == entity_data["name"].replace("_", "-"):
matching_prototype = prototype
break
if matching_prototype is None:
print(
f"Warning: No matching Prototype found for {entity_data['name']}"
)
continue
# Apply standard filtering
if entities and matching_prototype not in entities:
continue
metaclass = matching_prototype.value[1]
while isinstance(metaclass, tuple):
metaclass = metaclass[1]
# Process nested dictionaries (like inventories)
for key, value in entity_data.items():
if isinstance(value, dict):
entity_data[key] = self.process_nested_dict(value)
entity_data["prototype"] = matching_prototype
# remove all empty values from the entity_data dictionary
entity_data = {
k: v for k, v in entity_data.items() if v or isinstance(v, int)
}
try:
entity = metaclass(**entity_data)
entities_list.append(entity)
except Exception as e1:
print(f"Could not create {entity_data['name']} object: {e1}")
# Group entities when:
# 1. User explicitly requests group types, OR
# 2. User provides a position filter (suggesting they want nearby entities grouped), OR
# 3. No specific entities requested (get all entities - should be grouped)
should_group = (
not entities # No filter = group everything
or any(
proto
in {
Prototype.ElectricityGroup,
Prototype.PipeGroup,
Prototype.BeltGroup,
}
for proto in entities
) # Explicit group request
or (
entities and position is not None
) # Individual entities with position filter = group for convenience
)
if should_group:
# get all pipes into a list
pipes = [
entity
for entity in entities_list
if hasattr(entity, "prototype")
and entity.prototype in (Prototype.Pipe, Prototype.UndergroundPipe)
]
group = agglomerate_groupable_entities(pipes)
[entities_list.remove(pipe) for pipe in pipes]
entities_list.extend(group)
poles = [
entity
for entity in entities_list
if hasattr(entity, "prototype")
and entity.prototype
in (
Prototype.SmallElectricPole,
Prototype.BigElectricPole,
Prototype.MediumElectricPole,
)
]
group = agglomerate_groupable_entities(poles)
[entities_list.remove(pole) for pole in poles]
entities_list.extend(group)
walls = [
entity
for entity in entities_list
if hasattr(entity, "prototype")
and entity.prototype == Prototype.StoneWall
]
group = agglomerate_groupable_entities(walls)
[entities_list.remove(wall) for wall in walls]
entities_list.extend(group)
belt_types = (
Prototype.TransportBelt,
Prototype.FastTransportBelt,
Prototype.ExpressTransportBelt,
Prototype.UndergroundBelt,
Prototype.FastUndergroundBelt,
Prototype.ExpressUndergroundBelt,
)
belts = [
entity
for entity in entities_list
if hasattr(entity, "prototype") and entity.prototype in belt_types
]
group = agglomerate_groupable_entities(belts)
[entities_list.remove(belt) for belt in belts]
entities_list.extend(group)
# Final filtering after grouping is complete
if entities:
filtered_entities = []
for entity in entities_list:
# Check entity prototype or group type
if hasattr(entity, "prototype") and entity.prototype in entities:
filtered_entities.append(entity)
elif hasattr(entity, "__class__"):
# Handle group entities
if entity.__class__.__name__ == "ElectricityGroup":
pole_types = {
Prototype.SmallElectricPole,
Prototype.MediumElectricPole,
Prototype.BigElectricPole,
}
if Prototype.ElectricityGroup in group_requests:
# Explicit group request - return the group
filtered_entities.append(entity)
elif (
any(pole_type in entities for pole_type in pole_types)
and position is not None
):
# Individual poles requested with position - return group for convenience
filtered_entities.append(entity)
elif (
any(pole_type in entities for pole_type in pole_types)
and position is None
):
# Individual poles requested without position - extract individual poles from group
for pole in entity.poles:
if (
hasattr(pole, "prototype")
and pole.prototype in entities
):
filtered_entities.append(pole)
elif entity.__class__.__name__ == "PipeGroup":
pipe_types = {Prototype.Pipe, Prototype.UndergroundPipe}
if Prototype.PipeGroup in group_requests:
# Explicit group request - return the group
filtered_entities.append(entity)
elif (
any(pipe_type in entities for pipe_type in pipe_types)
and position is not None
):
# Individual pipes requested with position - return group for convenience
filtered_entities.append(entity)
elif (
any(pipe_type in entities for pipe_type in pipe_types)
and position is None
):
# Individual pipes requested without position - extract individual pipes from group
for pipe in entity.pipes:
if (
hasattr(pipe, "prototype")
and pipe.prototype in entities
):
filtered_entities.append(pipe)
elif entity.__class__.__name__ == "BeltGroup":
belt_types = {
Prototype.TransportBelt,
Prototype.FastTransportBelt,
Prototype.ExpressTransportBelt,
Prototype.UndergroundBelt,
Prototype.FastUndergroundBelt,
Prototype.ExpressUndergroundBelt,
}
if Prototype.BeltGroup in group_requests:
# Explicit group request - return the group
filtered_entities.append(entity)
elif (
any(belt_type in entities for belt_type in belt_types)
and position is not None
):
# Individual belts requested with position - return group for convenience
filtered_entities.append(entity)
elif (
any(belt_type in entities for belt_type in belt_types)
and position is None
):
# Individual belts requested without position - extract individual belts from group
for belt in entity.belts:
if (
hasattr(belt, "prototype")
and belt.prototype in entities
):
filtered_entities.append(belt)
elif entity.__class__.__name__ == "WallGroup":
# WallGroup doesn't have a corresponding Prototype, but include if present
filtered_entities.append(entity)
entities_list = filtered_entities
return entities_list
except Exception as e:
raise Exception(f"Error in GetEntities: {e}")
def process_nested_dict(self, nested_dict):
"""Helper method to process nested dictionaries"""
if isinstance(nested_dict, dict):
if all(isinstance(key, int) for key in nested_dict.keys()):
return [
self.process_nested_dict(value) for value in nested_dict.values()
]
else:
return {
key: self.process_nested_dict(value)
for key, value in nested_dict.items()
}
return nested_dict