A*の実装

はじめに

A*を実装してみた。 実装したのはA* だが、ヒューリスティック関数が常に0を返すため、動作はダイクストラ法と変わらない。

A*とは

A*は経路探索でよく用いられるアルゴリズムダイクストラ法に「現在の点から終点までの推定コスト」を追加することで、より効率的に最短経路を見つけることが可能になる。 その推定コストを計算する関数をヒューリスティック関数というが、その設計がよくないと効率的にはならない。 詳しくは調べたほうが早い。 気が向いたら書く。

ソースコード

ほとんど出席しなかった大学の授業で札幌駅-新千歳空港駅間の最安乗り換えの話がでていたので、 最安乗り換えの一つを求めるプログラムを自分で実装してみた。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# A* search algorithm
# reference: https://en.wikipedia.org/wiki/A*_search_algorithm

import sys
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List
from heapq import heappush, heappop

@dataclass(order=True)
class PrioritizedItem(object):
    priority: int
    item: Any=field(compare=False)

@dataclass
class PriorityQueue(object):
    que: List['PrioritizedItem']=field(default_factory=list, init=False)

    def push(self, priority, item):
        heappush(self.que, PrioritizedItem(priority, item))

    def pop(self):
        return heappop(self.que).item

    def top(self):
        return self.que[0].item

    def __len__(self):
        return len(self.que)

INF = 10000000

NULL = 0
OPEN = 1
CLOSE = 2

@dataclass
class Vertex(object):
    name: str=field(default=None)
    costs: Dict[int, int]=field(default_factory=dict, init=False)
    # other fields for the heuristic function

def estimate_cost(u, v):
    return 0 # Dijkstra's algorithm

def reconstruct_path(paths, current_id):
    path_ids = [current_id]
    while current_id in paths:
        current_id = paths[current_id]
        path_ids.append(current_id)
    path_ids.reverse()
    return path_ids

def a_star(vertices, start_id, goal_id):
    paths = {}
    goal = vertices[goal_id]
    start = vertices[start_id]
    states = [NULL] * len(vertices)
    states[start_id] = OPEN
    start_priority = estimate_cost(start, goal)
    priorities = { start_id:start_priority }
    costs = defaultdict(lambda: INF)
    costs[start_id] = 0
    que = PriorityQueue()
    que.push(start_priority, start_id)
    while len(que):
        current_id = que.pop()
        if states[current_id] == CLOSE:
            continue
        current_cost = costs[current_id]
        if current_id == goal_id:
            return current_cost, reconstruct_path(paths, current_id)
        states[current_id] = CLOSE
        current = vertices[current_id]
        for neighbor_id, cur_nbr_cost in current.costs.items():
            if states[neighbor_id] == CLOSE:
                continue
            neighbor = vertices[neighbor_id]
            neighbor_cost = current_cost + cur_nbr_cost
            if states[neighbor_id] != OPEN:
                states[neighbor_id] = OPEN
            elif neighbor_cost >= costs[neighbor_id]:
                continue
            paths[neighbor_id] = current_id
            costs[neighbor_id] = neighbor_cost
            priority = neighbor_cost + estimate_cost(neighbor, goal)
            priorities[neighbor_id] = priority
            que.push(priority, neighbor_id)
    return 0, list()

def main():
    n, start_id, goal_id = map(int, sys.stdin.readline().split())
    vertices = [Vertex(sys.stdin.readline().rstrip()) for i in range(n)]
    while True:
        line_split = sys.stdin.readline().split()
        if len(line_split) == 1 and int(line_split[0]) == -1:
            break
        i, j, cost = map(int, line_split)
        vertices[i].costs[j] = cost
        vertices[j].costs[i] = cost

    total_cost, path_ids = a_star(vertices, start_id, goal_id)
    print('from:', vertices[start_id].name)
    print('to:', vertices[goal_id].name)
    print('cost:', total_cost)
    print('path:')
    for i in path_ids:
        v = vertices[i]
        print(v.name)

if __name__ == '__main__':
    main()

実行例

2019-02-27現在の運賃を使って計算した。

入力

8 0 7
新千歳空港
南千歳
千歳
恵庭
北広島
新札幌
白石
札幌
0 1 310
0 2 350
0 3 400
0 4 590
0 5 880
0 6 980
0 7 1070
1 2 170
1 3 260
1 4 450
1 5 640
1 6 740
1 7 840
2 3 220
2 4 360
2 5 640
2 6 740
2 7 840
3 4 260
3 5 450
3 6 540
3 7 640
4 5 260
4 6 360
4 7 450
5 6 210
5 7 260
6 7 210
-1

実行結果

from: 新千歳空港
to: 札幌
cost: 1040
path:
新千歳空港
恵庭
札幌

30円安くなる。時は金なり。

参考サイト

A* search algorithm - Wikipedia