Friday, May 24, 2013

Tree iterator in Python

Sometimes you have a tree data structure and want to iterate over the nodes of the tree in breadth-first or depth-first order. This also seems to be a question that is occasionally posed in coding interviews. Since it is fun and especially easy to implement in Python, let's do it. First things first: we design a Node class that represents the nodes of the tree.

  
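class Node(object):

    def __init__(self, value, children=None):
        self.value = value
        # A fresh list per node avoids the shared mutable default argument.
        self.children = children if children is not None else []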

The constructor __init__ takes the value of the node and a list of child nodes that is empty by default. We can now construct a tree such as the following:

  
root = Node(1, [Node(2, [Node(4), Node(5)]), Node(3)])
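One thing to keep in mind with an empty-list default: Python evaluates default argument values once, when the function is defined, so a literal [] default is shared by every node created without children. The following is a minimal sketch of that pitfall and of the usual None-based guard. SharedNode and SafeNode are hypothetical names used only for this illustration.

```python
class SharedNode(object):
    """Hypothetical node with a list literal as default (the pitfall)."""

    def __init__(self, value, children=[]):
        self.value = value
        self.children = children


class SafeNode(object):
    """Hypothetical node that creates a fresh list per instance."""

    def __init__(self, value, children=None):
        self.value = value
        self.children = children if children is not None else []


a, b = SharedNode(1), SharedNode(2)
a.children.append(SharedNode(3))
shared_leak = len(b.children)   # b sees a's child: the default list is shared

c, d = SafeNode(1), SafeNode(2)
c.children.append(SafeNode(3))
safe_leak = len(d.children)     # d is unaffected by the change to c
```

The traversal code in this post never modifies a children list in place, so the shared default stays harmless here, but it becomes a problem as soon as children are appended after construction.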

To be able to display the tree we add two methods to the Node class. __str__ returns a string representation of a node, and showTree performs a recursive, depth-first traversal of the tree and prints the node values indented by their hierarchical level.

  
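class Node(object):

    def __init__(self, value, children=None):
        self.value = value
        # A fresh list per node avoids the shared mutable default argument.
        self.children = children if children is not None else []

    def __str__(self):
        return str(self.value)

    def showTree(self, level=0):
        # Indent each value by two dots per tree level.
        print("%s%s" % ("." * level, self))
        for node in self.children:
            node.showTree(level + 2)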

Running root.showTree() for the tree defined above produces the following output. 1 is the root node with two children 2 and 3, where node 2 has 4 and 5 as children.

1
..2
....4
....5
..3

Now, the iterator. There are two ways (actually there are more) to iterate over the nodes: level by level (breadth first) or the same way showTree works (depth first). Both are equally easy. The only difference is that one implementation uses a queue while the other uses a stack. Let's start with breadth first. The following code implements an iterator class. As such it has to provide two methods: __iter__, which returns the iterator object itself, and next (spelled __next__ in Python 3), which returns the next node every time it is called.

  
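class BreadthFirst(object):

    def __init__(self, node):
        self.queue = [node]

    def __iter__(self):
        return self

    def __next__(self):
        if not self.queue:
            raise StopIteration
        node = self.queue.pop(0)       # dequeue from the front
        self.queue += node.children    # enqueue the children at the back
        return node

    next = __next__  # Python 2 looks for a method named "next"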

The constructor takes a node and puts it into a list that we will treat as a queue. Therefore we name it "queue" and not "nodes" or even worse "list". What happens in the next method? When the queue is empty (not self.queue) a StopIteration exception is raised that stops the iterator. If the queue is not empty a node is dequeued (terrible word) via queue.pop(0) from the beginning of the list and the children of the node are added to the end of the list (enqueued). The iterator is then called as follows and returns the tree nodes in breadth-first order.

  
print(" ".join(str(node) for node in BreadthFirst(root)))

1 2 3 4 5
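Removing from the front of a list with pop(0) shifts all remaining elements, so it costs time proportional to the queue length. As a hedged sketch of an alternative, the same iterator can be backed by collections.deque from the standard library. BreadthFirstDeque is a hypothetical name, and the Node class is repeated only to keep the example self-contained.

```python
from collections import deque


class Node(object):
    def __init__(self, value, children=None):
        self.value = value
        self.children = children if children is not None else []

    def __str__(self):
        return str(self.value)


class BreadthFirstDeque(object):
    """Hypothetical variant of BreadthFirst backed by collections.deque."""

    def __init__(self, node):
        self.queue = deque([node])

    def __iter__(self):
        return self

    def __next__(self):
        if not self.queue:
            raise StopIteration
        node = self.queue.popleft()        # constant-time removal from the front
        self.queue.extend(node.children)   # children go to the back
        return node

    next = __next__  # alias so the class also works on Python 2


root = Node(1, [Node(2, [Node(4), Node(5)]), Node(3)])
order = [node.value for node in BreadthFirstDeque(root)]
```

For a five-node tree the difference is not measurable; the deque only starts to pay off when a level of the tree holds many nodes.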

Using a list as a queue is not very efficient, because pop(0) has to shift every remaining element. The standard library offers collections.deque, whose popleft() runs in constant time. But the beauty lies in the fact that the implementation of the breadth-first iterator is a mirror image of the depth-first iterator. As mentioned before, the only difference is that the list is now used as a stack:

  
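class DepthFirst(object):

    def __init__(self, node):
        self.stack = [node]

    def __iter__(self):
        return self

    def __next__(self):
        if not self.stack:
            raise StopIteration
        node = self.stack.pop()        # pop from the top of the stack
        self.stack += node.children    # push the children on top
        return node

    next = __next__  # Python 2 looks for a method named "next"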

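C'mon, admit it! This IS beautiful. Finally, using the iterator generates the following output. Note that the stack pops the most recently added child first, so siblings are visited from right to left: 3 comes before 2, and 5 before 4.

  
print(" ".join(str(node) for node in DepthFirst(root)))

1 3 2 5 4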
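The stack-based iterator visits the last child of each node first, which differs from the left-to-right order in which showTree prints the tree. If the showTree order is preferred, one possible tweak is to push the children in reverse. This is a sketch under that assumption, not part of the original design; DepthFirstPreorder is a hypothetical name, and the Node class is repeated only to keep the example self-contained.

```python
class Node(object):
    def __init__(self, value, children=None):
        self.value = value
        self.children = children if children is not None else []

    def __str__(self):
        return str(self.value)


class DepthFirstPreorder(object):
    """Hypothetical depth-first iterator that visits children left to right."""

    def __init__(self, node):
        self.stack = [node]

    def __iter__(self):
        return self

    def __next__(self):
        if not self.stack:
            raise StopIteration
        node = self.stack.pop()
        # Push children reversed so the leftmost child ends up on top.
        self.stack.extend(reversed(node.children))
        return node

    next = __next__  # alias so the class also works on Python 2


root = Node(1, [Node(2, [Node(4), Node(5)]), Node(3)])
order = [node.value for node in DepthFirstPreorder(root)]
```

The price is a small loss of symmetry with the breadth-first version, since the two classes now differ in more than the pop call.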

That's it.