diff --git a/graph.py b/graph.py index 06bc21d..acc6230 100644 --- a/graph.py +++ b/graph.py @@ -83,6 +83,28 @@ def dfs(graph, node, visited=None): for node in dfs(graph, neighbor, visited=visited): yield node +def prim(graph,r=None): + if r is None: + r = list(graph.nodes)[0] + Vt = set(); + Vt.add(r); + d = { node: Infinity() for node in graph.nodes } + d[r] = 0; + for _, neighbor, distance in graph.all_edges_of(r): + d[neighbor] = distance + + while Vt < graph.nodes: + u = min(graph.nodes - Vt, key=lambda node: d[node]) + if d[u] == Infinity(): + break + Vt.add(u) + for _, neighbor, distance in graph.all_edges_of(u): + if neighbor not in Vt: + d[neighbor] = min([ d[neighbor], distance ]) + + return d; + + def main(): g = Graph() for i in range(1,6): @@ -100,6 +122,10 @@ def main(): print(d) + d = prim(g) + + print(d) + print( list(bfs(g,1)) ) print( list(dfs(g,1)) ) @@ -114,6 +140,26 @@ def main(): print( list(bfs(g2,1)) ) print( list(dfs(g2,1)) ) + g3 = Graph() + g3.add_node(1) + g3.add_node(2) + g3.add_node(3) + g3.add_node(4) + g3.add_node(5) + g3.add_node(6) + + g3.add_edge(1,2,20) + g3.add_edge(1,3,10) + g3.add_edge(2,4,30) + g3.add_edge(2,5,20) + g3.add_edge(3,4,10) + g3.add_edge(3,5,20) + g3.add_edge(4,6,10) + g3.add_edge(5,6,30) + + d = prim(g3,1) + print(d) + if __name__ == '__main__':