Iterators
Lua doesn’t have a lot of control structures. There’s the obvious if statement, the while loop and repeat/until loop, and the for loop. Mostly, the for gets used to iterate over tables:
    t = {'a','b','c','d','e'}
    for i, v in ipairs(t) do
      print(i, v)
    end
It’s annoying to have to remember to type ipairs every time. I’ve forgotten more than once. But that minor annoyance is a good trade for the benefit of what the for statement actually does: generic iteration.
What’s an iterator?
The for statement isn’t just used for looping across tables; it loops through any sort of sequence represented by an iterator. An iterator in Lua is any function that conforms to a certain interface, and a lot of problems can be made simpler by writing your own iterators instead of just using pairs and ipairs.
Let’s look at a standard loop:
    t = {a=1, b=2, c=3, d=4}
    for k, v in pairs(t) do
      print(k, v)
    end
This is using the standard library iterator pairs. Let’s take a look at what that call to pairs actually returns:
    print(pairs(t))  -- function: 0x418590  table: 0x141b600  nil
Not to keep you in suspense: the table it returns is t itself, and the function is the builtin function next. What next does is take a table and a key in the table, and return the “next” key / value pair. It’s guaranteed to go over all the keys, in no particular order:
    next(t, nil)  -- c 3
    next(t, 'c')  -- b 2
    next(t, 'b')  -- a 1
    next(t, 'a')  -- d 4
    next(t, 'd')  -- nil
So let’s write that original for loop in a more explicit way:
    for k, v in next, t, nil do
      print(k, v)
    end
Try it: it does the same thing. In fact, if you just want to go over part of the table, you can pass in an initial argument for next:
    for k, v in next, t, 'b' do
      print(k, v)
    end
That will print out just the ‘a’ and ‘d’ keys (yours may vary; next iterates in an arbitrary order). That’s all the for statement does: it takes a function, an “invariant” first argument (the table), and an initial second argument. It calls the function repeatedly, making the second argument of each call the first return value of the previous call, until the function’s first return value is nil.
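To make that concrete, here is roughly what the generic for expands to, written as a plain while loop. This is a sketch, not the interpreter’s actual implementation; the helper name expanded_for and its body callback are made up for illustration.

```lua
-- Hypothetical helper: calls body(k, v) once per step, the way
-- "for k, v in f, s, init do ... end" would run its loop body.
function expanded_for(f, s, init, body)
  local control = init
  while true do
    local k, v = f(s, control)
    if k == nil then break end   -- a nil first return value ends the loop
    control = k                  -- otherwise it feeds the next call
    body(k, v)
  end
end

local t = {a = 1, b = 2, c = 3, d = 4}
expanded_for(next, t, nil, function(k, v)
  print(k, v)
end)
```

Passing the three values returned by ipairs(t) instead of next, t, nil should walk an array in order the same way.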
Making an iterator
With that in mind, we can write our own iterators. Here’s one that loops over the triangular numbers:
    function triangular(_, n)
      if not n then n = 0 end
      n = n + 1
      return n, n * (n+1) / 2
    end

    for n, v in triangular do
      print(v)
      if n == 10 then break end
    end
Since this iterator will never end on its own, we insert a break statement after a while.
Note how easy calling the iterator is compared to writing it, not that writing it is very hard. This is pretty common with iterators; you spend some effort writing one in order to make the rest of the code simpler.
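The break isn’t the only way to stop. One option, sketched here, is to use the otherwise-ignored invariant argument to carry a limit, so the iterator ends on its own by returning nil. The names step and triangular_upto are made up for this example.

```lua
-- Stateless step function: `limit` arrives as the invariant argument,
-- `n` as the control variable (the previous first return value).
local function step(limit, n)
  n = n + 1
  if n > limit then return nil end   -- returning nil ends the for loop
  return n, n * (n + 1) / 2
end

-- Hypothetical factory returning the (function, invariant, initial) triple.
function triangular_upto(limit)
  return step, limit, 0
end

for n, v in triangular_upto(5) do
  print(n, v)
end
```

Because the loop state lives entirely in the arguments, the same step function can serve any number of loops at once without them interfering.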
A more complex example: depth-first traversal
So let’s write some iterators that do actually-useful things, like traversing a tree. What we’d like to be able to do is visit every node of a tree, and get the value of that node and the path down to it from the root. For example:
    tree = {a = {p = {p = {l = {e = {}}}}, n = {t = {}}}}
    for n, path in dft(tree) do
      print(n, inspect(path))
    end
(I’m using inspect.lua to print the path, a really handy library).
This iterator is a little different from the other one: it has complex state, rather than just a single number, so we’ll have dft just return a function that keeps the state in a closure. You can do that; for doesn’t care if your invariant state is nil. It’ll still pass it in every time, but you can just ignore it. So, here’s the code:
    function dft(tree)
      local value_stack = {}
      local node_stack = {}
      return function()
        -- These represent the current node:
        local value, node = value_stack[#value_stack], node_stack[#node_stack]
        -- Now, to find the next node:
        if not next(node_stack) then
          -- Node stack empty, push the root node:
          table.insert(value_stack, (next(tree)))
          table.insert(node_stack, tree)
        elseif next(node[value]) then
          -- Otherwise, if the current node has children, push them on to the stack:
          table.insert(value_stack, (next(node[value])))
          table.insert(node_stack, node[value])
        elseif next(node, value) then
          -- Otherwise, if there's a right sibling, alter the stack to show it:
          value_stack[#value_stack] = next(node, value)
        else
          -- Otherwise, pop the stack and find the next node of our parent
          while true do
            table.remove(node_stack)
            table.remove(value_stack)
            -- Must be the end of the tree:
            if not next(node_stack) then return nil end
            local value, node = value_stack[#value_stack], node_stack[#node_stack]
            if next(node, value) then
              value_stack[#value_stack] = next(node, value)
              break
            end
          end
        end
        -- Return the top of the value stack, and the current value stack
        return value_stack[#value_stack], value_stack
      end
    end
This is pretty straightforward. We keep a stack of the node labels / values we’ve visited, going back up to the root. We start with empty stacks, and find the next node in the traversal:
- If the stacks are empty, the next node is the root.
- If the current node (top of the stacks) has children, the next node is the first child.
- If the current node has no children but has a next sibling, the sibling is next.
- Finally, if none of those are true, we pop back up to the parent and look for a next sibling there, repeating until we find one or the stacks are empty.
After all that, the value stack has a path to the current node, so we return its value, and the value stack itself. When I run it as above, I get this:
    a    { "a" }
    n    { "a", "n" }
    t    { "a", "n", "t" }
    p    { "a", "p" }
    p    { "a", "p", "p" }
    l    { "a", "p", "p", "l" }
    e    { "a", "p", "p", "l", "e" }
Coroutine iterators
Doing it iteratively like that is somewhat of a pain, though. It’s more natural to traverse a tree recursively. But how do we recurse in an iterator?
First, let’s write this as a coroutine. Forget about iterators for now, let’s just write a coroutine that will yield all the nodes / paths:
    stack = {}
    function traverse(node)
      for k, v in pairs(node) do
        table.insert(stack, k)
        coroutine.yield(k, stack)
        if type(v) == 'table' then traverse(v) end
        table.remove(stack)
      end
    end
    co = coroutine.create(traverse)
Then, we can call it like this:
    repeat
      local _, node, path = coroutine.resume(co, tree)
      print(node, inspect(path))
    until not node
(We pass in the same tree every time but that’s okay, because the argument is ignored every time after the first.)
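If the resume / yield handshake is unfamiliar, this small sketch shows what coroutine.resume hands back at each step, and what happens to arguments passed after the first call. The function name countdown is invented for the example.

```lua
function countdown(n)
  while n > 0 do
    coroutine.yield(n)   -- pause here, handing n to whoever resumed us
    n = n - 1
  end
end

local co = coroutine.create(countdown)
print(coroutine.resume(co, 2))   -- true  2   (the argument 2 becomes n)
print(coroutine.resume(co, 99))  -- true  1   (99 comes back from yield and is dropped)
print(coroutine.resume(co))      -- true      (the function returned; nothing yielded)
print(coroutine.resume(co))      -- false  cannot resume dead coroutine
```

The leading boolean is the status flag, which is why the loop above assigns it to _ and only looks at the values after it.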
So, let’s now turn this general pattern into an iterator:
    function co_dft(tree)
      local stack = {}
      local function traverse(node)
        for k, v in pairs(node) do
          table.insert(stack, k)
          coroutine.yield(k, stack)
          if type(v) == 'table' then traverse(v) end
          table.remove(stack)
        end
      end
      local co = coroutine.create(function() traverse(tree) end)
      return function()
        local _, value, stack = coroutine.resume(co)
        return value, stack
      end
    end
It’s a pretty simple transformation. Move everything into the local scope of the iterator, and return a function that’s a wrapper for coroutine.resume. We can do this to make an iterator out of any coroutine, actually. And now that it’s an iterator, we can call it just like the iterative version:
    for node, path in co_dft(tree) do
      print(node, inspect(path))
    end
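The standard library has this wrapper built in: coroutine.wrap creates a coroutine and returns a function that resumes it on each call. Here is a sketch of the same traversal using it. The name wrap_dft is made up for this example, and it prints paths with table.concat so that it runs without inspect.lua. One difference worth knowing: coroutine.wrap propagates any error raised inside the coroutine to the caller instead of returning a status flag.

```lua
function wrap_dft(tree)
  local stack = {}
  local function traverse(node)
    for k, v in pairs(node) do
      table.insert(stack, k)
      coroutine.yield(k, stack)
      if type(v) == 'table' then traverse(v) end
      table.remove(stack)
    end
  end
  -- coroutine.wrap returns a function that resumes the coroutine and
  -- returns whatever it yields, without the leading status boolean.
  return coroutine.wrap(function() traverse(tree) end)
end

local tree = {a = {p = {p = {l = {e = {}}}}, n = {t = {}}}}
for node, path in wrap_dft(tree) do
  -- path is the live stack, so join (or copy) it now if you want to keep it
  print(node, table.concat(path, '/'))
end
```

When the traversal finishes, the wrapped function returns nothing, so the for loop sees nil and stops.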
Iterators are powerful
So, that’s how iterators work. It’s more than just an awkward syntax for a for loop; it’s actually an incredibly powerful feature of Lua.
As always, the code for this is available on GitHub.