A*の実装
はじめに
A*を実装してみた。 実装したのはA* だが、ヒューリスティック関数が常に0を返すため、動作はダイクストラ法と変わらない。
A*とは
A*は経路探索でよく用いられるアルゴリズム。 ダイクストラ法に「現在の点から終点までの推定コスト」を追加することで、より効率的に最短経路を見つけることが可能になる。 その推定コストを計算する関数をヒューリスティック関数というが、その設計がよくないと効率的にはならない。 詳しくは調べたほうが早い。 気が向いたら書く。
ソースコード
ほとんど出席しなかった大学の授業で札幌駅-新千歳空港駅間の最安乗り換えの話がでていたので、 最安乗り換えの一つを求めるプログラムを自分で実装してみた。
#!/usr/bin/env python3 # -*- coding: utf-8 -*- # A* search algorithm # reference: https://en.wikipedia.org/wiki/A*_search_algorithm import sys from collections import defaultdict from dataclasses import dataclass, field from typing import Any, Dict, List from heapq import heappush, heappop @dataclass(order=True) class PrioritizedItem(object): priority: int item: Any=field(compare=False) @dataclass class PriorityQueue(object): que: List['PrioritizedItem']=field(default_factory=list, init=False) def push(self, priority, item): heappush(self.que, PrioritizedItem(priority, item)) def pop(self): return heappop(self.que).item def top(self): return self.que[0].item def __len__(self): return len(self.que) INF = 10000000 NULL = 0 OPEN = 1 CLOSE = 2 @dataclass class Vertex(object): name: str=field(default=None) costs: Dict[int, int]=field(default_factory=dict, init=False) # other fields for the heuristic function def estimate_cost(u, v): return 0 # Dijkstra's algorithm def reconstruct_path(paths, current_id): path_ids = [current_id] while current_id in paths: current_id = paths[current_id] path_ids.append(current_id) path_ids.reverse() return path_ids def a_star(vertices, start_id, goal_id): paths = {} goal = vertices[goal_id] start = vertices[start_id] states = [NULL] * len(vertices) states[start_id] = OPEN start_priority = estimate_cost(start, goal) priorities = { start_id:start_priority } costs = defaultdict(lambda: INF) costs[start_id] = 0 que = PriorityQueue() que.push(start_priority, start_id) while len(que): current_id = que.pop() if states[current_id] == CLOSE: continue current_cost = costs[current_id] if current_id == goal_id: return current_cost, reconstruct_path(paths, current_id) states[current_id] = CLOSE current = vertices[current_id] for neighbor_id, cur_nbr_cost in current.costs.items(): if states[neighbor_id] == CLOSE: continue neighbor = vertices[neighbor_id] neighbor_cost = current_cost + cur_nbr_cost if states[neighbor_id] != OPEN: states[neighbor_id] = OPEN elif neighbor_cost >= costs[neighbor_id]: continue paths[neighbor_id] = current_id costs[neighbor_id] = neighbor_cost priority = neighbor_cost + estimate_cost(neighbor, goal) priorities[neighbor_id] = priority que.push(priority, neighbor_id) return 0, list() def main(): n, start_id, goal_id = map(int, sys.stdin.readline().split()) vertices = [Vertex(sys.stdin.readline().rstrip()) for i in range(n)] while True: line_split = sys.stdin.readline().split() if len(line_split) == 1 and int(line_split[0]) == -1: break i, j, cost = map(int, line_split) vertices[i].costs[j] = cost vertices[j].costs[i] = cost total_cost, path_ids = a_star(vertices, start_id, goal_id) print('from:', vertices[start_id].name) print('to:', vertices[goal_id].name) print('cost:', total_cost) print('path:') for i in path_ids: v = vertices[i] print(v.name) if __name__ == '__main__': main()
実行例
2019-02-27現在の運賃を使って計算した。
入力
8 0 7 新千歳空港 南千歳 千歳 恵庭 北広島 新札幌 白石 札幌 0 1 310 0 2 350 0 3 400 0 4 590 0 5 880 0 6 980 0 7 1070 1 2 170 1 3 260 1 4 450 1 5 640 1 6 740 1 7 840 2 3 220 2 4 360 2 5 640 2 6 740 2 7 840 3 4 260 3 5 450 3 6 540 3 7 640 4 5 260 4 6 360 4 7 450 5 6 210 5 7 260 6 7 210 -1
実行結果
from: 新千歳空港 to: 札幌 cost: 1040 path: 新千歳空港 恵庭 札幌
30円安くなる。時は金なり。