分岐限定法

前回の「深さ優先探索」の続きです。

Javaによる知能プログラミング入門』の「2.探索とパターン照合」にある分岐限定法のソースコードをPythonで実装してみました。

有向グラフとして表現されている状態空間の、初期状態から目標状態までの探索を行います。

ノード間の□の中の数字は、各ノード間のコストを表します。
分岐限定法では、コストが最も小さいノードからたどっていきます。

ポイントはソースコードの次の部分です。
ノードのg_valueは、初期状態からそのノードへの最小コストを表します。

# node を展開して子節点をすべて求める
for m in node.children:
    # 子節点mがopenにもclosedにも含まれていなければ
    if m not in open and m not in closed:
        # m から node へのポインタを付ける.
        m.pointer = node
        # nodeまでの評価値とnode->mのコストを足したものをmの評価値とする
        m.g_value = node.g_value + node.get_cost(m)
        open.append(m)
    # 子節点mがopenに含まれているならば,
    if m in open:
        gmn = node.g_value + node.get_cost(m)
        if gmn < m.g_value:
            m.g_value = gmn
            m.pointer = node
# Nodeをg_valueの昇順(小さい順)に列べ換える
sort_upper_by_g_value(open)

探索するノードが見つかると、g_valueを計算します。
そして、g_valueでソートしてコストの小さいものから探索します。

#! /usr/bin/python
# -*- coding: Shift_JIS -*-

def printNodes(nodes):
    """
    Nodeのリストの出力用文字列を作成する
    nodes Nodeのリスト
    """
    return map(str, nodes)

class Node(object):
    def __init__(self, name):
        self.name = name
        self.children = [] #遷移できるノードのリスト
        self.children_costs = {} #<Node, Integer>
        self.pointer = None #解表示のためのポインタ
        self.__g_value = 0 # コスト
        self.has_g_value = False

    def get_g_value(self):
        return self.__g_value
    def set_g_value(self, value):
        self.has_g_value = True
        self.__g_value = value
    g_value = property(get_g_value, set_g_value)

    def add_child(self, child, cost):
        """
        @param child: この節点の子節点
        @param cost:  その子節点までのコスト
        """
        self.children.append(child)
        self.children_costs[child] = cost

    def get_cost(self, child):
        """
        子節点までのコストを取得する
        @param child: この節点の子節点
        @return: 子節点までのコスト
        """
        return self.children_costs[child]

    def __str__(self):
        result = self.name
        if self.has_g_value:
            result = "%s(g:%d)" % (result, self.__g_value)
        return result

def nodes_str(nodes):
    """
    Nodeのリストの出力用文字列を作成する
    @param nodes: Nodeのリスト
    """
    return map(str, nodes)

def make_state_space():
    """
    状態空間の生成
    @return: ノードのリストを返す。
    """
    node = [Node("L.A.Airport"),
            Node("UCLA"),
            Node("Hoolywood"),
            Node("Anaheim"),
            Node("GrandCanyon"),
            Node("SanDiego"),
            Node("Downtown"),
            Node("Pasadena"),
            Node("DisneyLand"),
            Node("Las Vegas")]
    node[0].add_child(node[1],1)
    node[0].add_child(node[2],3)
    node[1].add_child(node[2],1)
    node[1].add_child(node[6],6)
    node[2].add_child(node[3],6)
    node[2].add_child(node[6],6)
    node[2].add_child(node[7],3)
    node[3].add_child(node[4],5)
    node[3].add_child(node[7],2)
    node[3].add_child(node[8],4)
    node[4].add_child(node[8],2)
    node[4].add_child(node[9],1)
    node[5].add_child(node[1],1)
    node[6].add_child(node[5],7)
    node[6].add_child(node[7],2)
    node[7].add_child(node[8],3)
    node[7].add_child(node[9],7)
    node[8].add_child(node[9],5)
    return node

def print_solution(node):
    """
    解の表示
    """
    if node.pointer == None:
        print node
    else:
        print node, "<-",
        print_solution(node.pointer)

def sort_upper_by_g_value(open):
    """
    Nodeをg_valueの昇順(小さい順)に列べ換える
    @param open: Nodeのリスト
    """
    def g_value_compare(x, y):
        return x.g_value - y.g_value
    open.sort(g_value_compare)

def branch_and_bound(start, goal):
    """
    分岐限定法
    @param start: 探索開始ノード
    @param goal: 探索終了ノード
    """
    open = [start] #探索予定のノードのリスト
    start.g_value = 0
    closed = [] #探索を終了したノードのリスト
    success = False
    step = 0

    while 1:
        step += 1
        print "STEP:", step
        print "OPEN:", nodes_str(open)
        print "closed:", nodes_str(closed)

        if len(open) == 0:
            success = False
            break

        node = open.pop(0)
        if node == goal:
            success = True
            break

        closed.append(node)
        # node を展開して子節点をすべて求める.
        for m in node.children:
            # 子節点mがopenにもclosedにも含まれていなければ,
            if m not in open and m not in closed:
                # m から node へのポインタを付ける.
                m.pointer = node
                # nodeまでの評価値とnode->mのコストを足したものをmの評価値とする
                m.g_value = node.g_value + node.get_cost(m)
                open.append(m)
            # 子節点mがopenに含まれているならば,
            if m in open:
                gmn = node.g_value + node.get_cost(m)
                if gmn < m.g_value:
                    m.g_value = gmn
                    m.pointer = node
        # Nodeをg_valueの昇順(小さい順)に列べ換える
        sort_upper_by_g_value(open)
    if success:
        print "*** Solution ***"
        print_solution(goal)

if __name__ == "__main__":
    nodes = make_state_space()
    branch_and_bound(nodes[0], nodes[-1])

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください