-
-
Notifications
You must be signed in to change notification settings - Fork 7.3k
Expand file tree
/
Copy pathConstraintTree.py
More file actions
157 lines (131 loc) · 6.65 KB
/
ConstraintTree.py
File metadata and controls
157 lines (131 loc) · 6.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from dataclasses import dataclass
from typing import Optional, TypeAlias
import heapq
from PathPlanning.TimeBasedPathPlanning.Node import NodePath, Position, PositionAtTime
# Identifier of an agent; used as the key of every per-agent ``paths`` mapping.
AgentId: TypeAlias = int
@dataclass(frozen=True)
class Constraint:
    """A (position, time step) pair produced at a detected conflict.

    Presumably the agent it is attached to must not occupy ``position`` at
    ``time``; that restriction is enforced by the planner, not in this module
    (NOTE(review): confirm against the consumer). Frozen, so instances are
    immutable and hashable as long as ``Position`` is hashable.
    """
    position: Position
    # Discrete time step of the conflict.
    time: int
@dataclass
class ConstrainedAgent:
    """Pairs one agent involved in a conflict with the constraint proposed for it."""
    # Agent the constraint applies to.
    agent: AgentId
    # Constraint that one branch of the conflict would impose on ``agent``.
    constraint: Constraint
@dataclass
class ForkingConstraint:
    """A conflict between exactly two agents, one ConstrainedAgent per agent.

    For a vertex conflict both entries carry the same (position, time); for an
    edge (swap) conflict the two positions differ. Presumably each entry seeds
    one child branch of the constraint tree (confirm against the search code).
    """
    constrained_agents: tuple[ConstrainedAgent, ConstrainedAgent]
@dataclass(frozen=True)
class AppliedConstraint:
    """A constraint bound to the single agent it restricts.

    Frozen, so it is immutable and usable in sets/dicts when ``Constraint``
    is hashable.
    """
    constraint: Constraint
    # Agent to which ``constraint`` applies.
    constrained_agent: AgentId
@dataclass
class ConstraintTreeNode:
    """One node of the conflict-based-search constraint tree.

    A node holds one planned path per agent, the summed cost of those paths,
    the index of its parent node, and the first conflict (if any) found among
    the paths. Nodes are ordered by ``__lt__`` so they can live in a heap.

    Only ``parent_idx``, ``constraint``, ``paths`` and ``cost`` are dataclass
    fields (they take part in the generated ``__eq__``/``__repr__``);
    ``all_constraints`` is a plain instance attribute set in ``__init__``.
    """

    # Key of the parent node in the tree's expanded-node table; -1 for the root.
    parent_idx: int
    # First conflict found in ``paths`` (a ForkingConstraint) or None when the
    # paths are conflict-free. The hint also admits an AppliedConstraint,
    # presumably assigned by the search after construction (confirm in caller).
    constraint: Optional[ForkingConstraint | AppliedConstraint]

    # Planned path for every agent.
    paths: dict[AgentId, NodePath]
    # Sum over all agents of ``path.goal_reached_time()``.
    cost: int

    def __init__(self, paths: dict[AgentId, NodePath], parent_idx: int, all_constraints: list[AppliedConstraint]):
        """Build a node from a complete set of agent paths.

        Args:
            paths: Mapping from agent id to that agent's planned path.
            parent_idx: Key of the parent node among expanded nodes, or -1
                for the root.
            all_constraints: Constraints under which ``paths`` were planned.
                Within this class it is only used to break cost ties in
                ``__lt__``.
        """
        self.paths = paths
        self.cost = sum(path.goal_reached_time() for path in paths.values())
        self.parent_idx = parent_idx
        # Conflict detection runs eagerly so the node is ready to branch on.
        self.constraint = self.get_constraint_point()
        self.all_constraints = all_constraints

    def __lt__(self, other) -> bool:
        """Order by cost, breaking ties in favour of fewer accumulated constraints."""
        if self.cost == other.cost:
            return len(self.all_constraints) < len(other.all_constraints)
        return self.cost < other.cost

    def get_constraint_point(self, verbose=False) -> Optional[ForkingConstraint]:
        """Find the earliest conflict between any two agents' paths.

        Time steps are scanned from 0 up to the latest goal-reached time. At the
        first time step where any conflict is found, an edge (swap) conflict is
        preferred; otherwise the first vertex conflict found is returned.

        Args:
            verbose: Print diagnostic messages about the conflicts found.

        Returns:
            A ForkingConstraint holding one ``Constraint`` per conflicting
            agent, or None when no conflict exists (including empty ``paths``).
        """
        # ``default=0`` keeps an empty ``paths`` mapping from raising in max().
        final_t = max((path.goal_reached_time() for path in self.paths.values()), default=0)
        # First agent recorded at each (position, time); later arrivals conflict.
        positions_at_time: dict[PositionAtTime, AgentId] = {}
        for t in range(final_t + 1):
            possible_constraints: list[ForkingConstraint] = []
            for agent_id, path in self.paths.items():
                # Previous position is needed to detect edge (swap) conflicts.
                last_position = path.get_position(t - 1) if t > 0 else None
                position = path.get_position(t)
                if position is None:
                    continue
                position_at_time = PositionAtTime(position, t)
                if position_at_time not in positions_at_time:
                    positions_at_time[position_at_time] = AgentId(agent_id)

                # Edge conflict: one earlier-processed agent was at ``position``
                # at t-1 and is at ``last_position`` at t, i.e. the two swapped.
                if last_position is not None:
                    new_position_at_last_time = PositionAtTime(position, t - 1)
                    old_position_at_new_time = PositionAtTime(last_position, t)
                    if (new_position_at_last_time in positions_at_time
                            and old_position_at_new_time in positions_at_time):
                        conflicting_agent_id1 = positions_at_time[new_position_at_last_time]
                        conflicting_agent_id2 = positions_at_time[old_position_at_new_time]
                        if conflicting_agent_id1 == conflicting_agent_id2 and conflicting_agent_id1 != agent_id:
                            if verbose:
                                print(f"Found edge constraint between with agent {conflicting_agent_id1} for {agent_id}")
                                print(f"\tpositions old: {old_position_at_new_time}, new: {position_at_time}")
                            # Both entries carry a Constraint, matching the
                            # declared type of ConstrainedAgent.constraint.
                            new_constraint = ForkingConstraint((
                                ConstrainedAgent(agent_id, Constraint(position=position, time=t)),
                                ConstrainedAgent(conflicting_agent_id1, Constraint(position=last_position, time=t)),
                            ))
                            possible_constraints.append(new_constraint)
                            continue

                # Vertex conflict: another agent already holds this (cell, time).
                if positions_at_time[position_at_time] != agent_id:
                    conflicting_agent_id = positions_at_time[position_at_time]
                    constraint = Constraint(position=position, time=t)
                    possible_constraints.append(ForkingConstraint((
                        ConstrainedAgent(agent_id, constraint),
                        ConstrainedAgent(conflicting_agent_id, constraint),
                    )))

            if possible_constraints:
                if verbose:
                    print(f"Choosing best constraint of {possible_constraints}")
                # Edge conflicts are the ones whose two constraints name different cells.
                for constraint in possible_constraints:
                    first, second = constraint.constrained_agents
                    if first.constraint.position != second.constraint.position:
                        if verbose:
                            print(f"\tFound edge conflict constraint: {constraint}")
                        return constraint
                # No edge conflict at this time step: use the first vertex conflict.
                if verbose:
                    print(f"\tReturning normal constraint: {possible_constraints[0]}")
                return possible_constraints[0]
        return None
class ConstraintTree:
    """Bookkeeping for the conflict-based-search constraint tree.

    Pending nodes live in a min-heap ordered by ``ConstraintTreeNode.__lt__``.
    Expanded nodes are stored under consecutive integer keys so a node can
    refer to its parent through ``parent_idx``.
    """

    # Nodes whose children have been created, keyed by expansion index.
    expanded_nodes: dict[int, ConstraintTreeNode]
    # Min-heap (managed with ``heapq``) of nodes still waiting to be expanded.
    nodes_to_expand: list[ConstraintTreeNode]

    def __init__(self, initial_solution: dict[AgentId, NodePath]):
        """Seed the tree with a root node built from ``initial_solution``.

        The root has parent index -1 and no applied constraints.
        """
        self.nodes_to_expand = []
        self.expanded_nodes = {}
        heapq.heappush(self.nodes_to_expand, ConstraintTreeNode(initial_solution, -1, []))

    def get_next_node_to_expand(self) -> Optional[ConstraintTreeNode]:
        """Pop and return the lowest-ordered pending node, or None if none remain."""
        if not self.nodes_to_expand:
            return None
        return heapq.heappop(self.nodes_to_expand)

    def add_node_to_tree(self, node: ConstraintTreeNode) -> None:
        """Queue ``node`` for later expansion.

        The node is only pushed onto the pending heap; no children are
        generated here and nothing is returned.
        """
        heapq.heappush(self.nodes_to_expand, node)

    def get_ancestor_constraints(self, parent_index: int) -> list[AppliedConstraint]:
        """Collect the constraints applied along a chain of expanded ancestors.

        Starting at the expanded node stored under ``parent_index``, follow
        ``parent_idx`` links until -1 is reached. Constraints are returned
        nearest ancestor first.

        Raises:
            RuntimeError: If a visited node's ``constraint`` is not an
                AppliedConstraint.
            KeyError: If an index on the chain is not in ``expanded_nodes``.
        """
        constraints: list[AppliedConstraint] = []
        while parent_index != -1:
            node = self.expanded_nodes[parent_index]
            if not isinstance(node.constraint, AppliedConstraint):
                raise RuntimeError(f"Expected AppliedConstraint, but got: {node.constraint}")
            constraints.append(node.constraint)
            parent_index = node.parent_idx
        return constraints

    def add_expanded_node(self, node: ConstraintTreeNode) -> int:
        """Record ``node`` as expanded and return its key in ``expanded_nodes``."""
        # Keys stay unique because this class never removes expanded nodes.
        node_idx = len(self.expanded_nodes)
        self.expanded_nodes[node_idx] = node
        return node_idx

    def expanded_node_count(self) -> int:
        """Return how many nodes have been expanded so far."""
        return len(self.expanded_nodes)