# one-file-projects/astar.py
#
# 109 lines
# 2.4 KiB
# Python

from math import sqrt
class PriorityQueue:
    """Minimal priority queue of nodes; the lowest priority value comes out first.

    Each node's priority is stored in a dict, so nodes must be hashable.
    Putting a node that is already queued replaces its priority.
    Ties are broken by current list position, because ``list.sort`` is stable.
    """

    def __init__(self):
        self.nlist = []    # queued nodes, kept sorted by ascending priority
        self.plookup = {}  # node -> priority; keys are exactly the nodes in nlist

    def put(self, node, priority):
        """Queue *node* with *priority*, or update its priority if already queued."""
        # The dict mirrors nlist, so this membership test is O(1) instead of a
        # list scan. It also fails on unhashable nodes *before* nlist is touched.
        if node not in self.plookup:
            self.nlist.append(node)
        self.plookup[node] = priority
        self.__sort()

    def __sort(self):
        # Stable sort: nodes with equal priority keep their existing relative order.
        self.nlist.sort(key=self.plookup.__getitem__)

    def __contains__(self, node):
        """Return True if *node* is currently queued."""
        return node in self.plookup

    def __getitem__(self, key):
        """Return the node(s) at index or slice *key*, in priority order."""
        return self.nlist[key]

    def get(self):
        """Remove and return the node with the lowest priority.

        Raises IndexError if the queue is empty.
        """
        m = self.nlist.pop(0)
        del self.plookup[m]
        return m

    def empty(self):
        """Return True if no nodes are queued."""
        return not self.nlist
class AStar:
    """Generic A* search over a graph defined by subclass hooks.

    Subclasses must override:
      * ``cost(node1, node2)`` -- cost of stepping from node1 to its successor node2;
      * ``esteemed_cost(node, end)`` -- heuristic estimate of the remaining cost
        from node to end;
      * ``get_successors(node)`` -- iterable of nodes reachable from node in one step.

    Nodes are used as dict keys and set members, so they must be hashable.
    Closed nodes are never reopened, so the returned path is only guaranteed to
    be cheapest when ``esteemed_cost`` is consistent, i.e.
    esteemed_cost(n, end) <= cost(n, m) + esteemed_cost(m, end) for every successor m.
    """

    def solve(self, begin, end):
        """Search from *begin* to *end*.

        Returns the list of nodes from *begin* to *end* (both included), or None
        if the open list is exhausted without reaching *end*.
        """
        # set up data structures for search
        glookup = {}         # node -> cost of the best known path from begin
        pdlookup = {}        # node -> predecessor on that best known path
        openlist = PriorityQueue()
        closedlist = set()   # a set keeps the per-successor membership test O(1)
        openlist.put(begin, 0)
        glookup[begin] = 0
        while not openlist.empty():
            currentNode = openlist.get()
            if currentNode == end:
                return self.__path(pdlookup, begin, end)
            closedlist.add(currentNode)
            self.__expandNode(currentNode, glookup, pdlookup, openlist, closedlist, end)
        return None

    def __path(self, pdlookup, begin, end):
        """Rebuild the begin->end route by following predecessor links back from end."""
        route = [end]
        node = end
        while node != begin:
            node = pdlookup[node]
            route.append(node)
        return list(reversed(route))

    def __expandNode(self, node, glookup, pdlookup, openlist, closedlist, end):
        """Relax the step from *node* to each of its not-yet-closed successors.

        A successor is queued, or re-prioritised, only when the path through
        *node* is strictly cheaper than the one already recorded for it.
        """
        for successor in self.get_successors(node):
            if successor in closedlist:
                continue
            tentative_g = glookup[node] + self.cost(node, successor)
            if successor in openlist and glookup[successor] <= tentative_g:
                continue
            pdlookup[successor] = node
            glookup[successor] = tentative_g
            # put() replaces the priority in place if successor is already open
            openlist.put(successor, tentative_g + self.esteemed_cost(successor, end))

    # The hooks below must be overridden. Raising a bare string is itself a
    # TypeError (Python 2.6+ and 3), which used to hide these messages, so
    # they raise NotImplementedError instead.
    def cost(self, node1, node2):
        """Return the cost of moving from *node1* to its successor *node2*."""
        raise NotImplementedError("cost must be overwritten")

    def esteemed_cost(self, node, end):
        """Return a heuristic estimate of the remaining cost from *node* to *end*."""
        raise NotImplementedError("esteemed_cost must be overwritten")

    def get_successors(self, node):
        """Return an iterable of the nodes reachable from *node* in one step."""
        raise NotImplementedError("get_successors must be overwritten")
if __name__ == '__main__':
    # Testing: shortest path on an unbounded 8-connected grid where every move,
    # straight or diagonal, costs 1.
    def addnodes(*nodes):
        """Return the component-wise sum of 2-D coordinate tuples."""
        result = (0, 0)
        for node in nodes:
            result = (result[0] + node[0], result[1] + node[1])
        return result

    class Test(AStar):
        def cost(self, node1, node2):
            # uniform step cost, diagonals included
            return 1

        def esteemed_cost(self, node, end):
            # Chebyshev distance is the exact remaining cost on this grid, so it
            # is admissible and consistent. Euclidean distance would overestimate
            # a diagonal step (sqrt(2) > 1) and could yield non-shortest paths.
            return max(abs(node[0] - end[0]), abs(node[1] - end[1]))

        def get_successors(self, node):
            offsets = [(1, 0), (0, 1), (-1, 0), (0, -1),
                       (1, 1), (1, -1), (-1, -1), (-1, 1)]
            return [addnodes(offset, node) for offset in offsets]

    test = Test()
    print(test.solve((1, 1), (3, 5)))