Exploring Multiple Dispatch

Multiple dispatch is a very powerful technique that can simplify code substantially. In short, it lets you define generic functions that work on a wide variety of types, with different implementations chosen depending on the types of the arguments.

In this notebook, we'll explore how to do so-called automatic differentiation with multiple dispatch, using the Julia language.

Taking Derivatives

Firstly, let's have a look at the problem. Suppose we have a function $f$ that does something unknown. We want to find its derivative, $f'$. How might we approach this problem?

Let's consider an example. Polynomial $p$ below is defined by $p(x) = 2x^3 + 3x^2 + 6x + 6$. We know from calculus that the derivative $p'(x) = 6x^2 + 6x + 6$. But how might we implement this in code? Recall the limit definition of a derivative: \begin{equation} f'(a) = \lim_{x\to a} \frac{f(x) - f(a)}{x-a} \end{equation}

Then, to calculate the derivative at $a$, we might be inclined to choose a value of $x$ very close to $a$, and simply evaluate the above expression. How close? With floating point numbers, we might wish to choose $x = a + a \cdot \mathbf{u}$, where $\mathbf{u}$ is known as "machine epsilon": a bound on the relative error introduced by rounding in floating point operations. (We will see later that this is not necessarily the best choice.)

We'll work with $64$-bit floating-point numbers, so let's first compute our value of $\mathbf{u}$ and define the polynomial $p$:

In [1]:
const u = eps(Float64)
In [2]:
p(x) = 2x^3 + 3x^2 + 6x + 6
Out[2]:
p (generic function with 1 method)

Now let's create our differentiation function. We want this function to take in a function, and return a function that computes its derivative. In Julia, the -> syntax denotes the creation of a function. Think of the below code as mathematically representing \begin{equation} f' = a \mapsto \frac{f(a + a \cdot \mathbf{u}) - f(a)}{a \cdot \mathbf{u}} \end{equation} which is just a fancy way of saying \begin{equation} f'(a) = \frac{f(a + a \cdot \mathbf{u}) - f(a)}{a \cdot \mathbf{u}} \end{equation}

In [3]:
der_u(f) = a -> (f(a + a * u) - f(a)) / (a * u)
Out[3]:
der_u (generic function with 1 method)

To test this function, we'll define p′ (note that this is not an apostrophe, which would be a syntax error; it is a Unicode prime character that looks similar), and evaluate it at several points. For instance, we expect $p'(0) = 6$ and $p'(1) = 18$, since these are the exact results that calculus gives us.

In [4]:
p′ = der_u(p)
p′(1)
Out[4]:
16.0

Yikes. This is a complete disaster. Being off by two is bad enough, and returning a result that looks so innocent is even worse!

In [5]:
p′(0)
Out[5]:
NaN

Oof. If being off by $2$ were bad, being off by NaN must be even worse. Clearly we need to refine our approach. First, we need to enforce a minimum step size: when $a = 0$, the step $a \cdot \mathbf{u}$ is zero, and we end up dividing $0$ by $0$. Second, we should be more conservative about our choice of epsilon. Machine epsilon is too small. How about its square root, which is a much bigger number? We'll call it $\mathbf{v} = \sqrt{\mathbf{u}}$.
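Why the square root? Roughly speaking (this is a standard numerical-analysis estimate, not derived in the original notebook), a forward difference with step $h$ suffers two competing errors: a truncation error of about $\frac{h}{2}\lvert f''(a)\rvert$ from stopping short of the limit, and a rounding error of about $\frac{2\lvert f(a)\rvert\,\mathbf{u}}{h}$ from cancellation in the subtraction. The first shrinks as $h$ does, the second grows, and their sum is smallest when $h$ is on the order of $\sqrt{\mathbf{u}}$ (for $f$ and $f''$ of moderate size). Hence the choice $\mathbf{v} = \sqrt{\mathbf{u}}$.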

In [6]:
const v = sqrt(u)
der_v(f) = function(a)
    diff = a * v + v
    (f(a + diff) - f(a)) / diff
end
Out[6]:
der_v (generic function with 1 method)
In [7]:
p′ = der_v(p)
p′(1)
Out[7]:
18.00000023841858
In [8]:
p′(0)
Out[8]:
6.000000059604645

Phew. This is much better. It could be more exact, but it's quite good as is. Let's try it on some other functions. Let's start with $\mathrm{e}^x$, whose derivative is $\mathrm{e}^x$.

In [9]:
exp′ = der_v(exp)
exp′(0)
Out[9]:
1.0
In [10]:
exp′(1)
Out[10]:
2.7182818800210953
In [11]:
exp′(log(1000))
Out[11]:
1000.0000560290179

So far so good. How about $\sin(x)$, whose derivative is $\cos(x)$?

In [12]:
sin′ = der_v(sin)
sin′(0)
Out[12]:
1.0
In [13]:
sin′(acos(0.9))
Out[13]:
0.8999999961400434

Excellent! Everything is just fine and dandy. Now how about $\log(x)$ (base $\mathrm{e}$), whose derivative is $\frac{1}{x}$?

In [14]:
log′ = der_v(log)
log′(1)
Out[14]:
0.9999999850988391
In [15]:
log′(10)
Out[15]:
0.09999999674883756
In [16]:
log′(1/1000000)
Out[16]:
992622.6095255658

Hmm. The error near $0$ is a bit higher than we might prefer. We could fix that by reducing $\mathbf{v}$, but if we do that too much we end up in trouble. For now let's consider higher order derivatives. What about the second derivative of $\mathrm{e}^x$?

In [17]:
exp′′ = der_v(exp′)
exp′′(0)
Out[17]:
1.9999999701976776
In [18]:
exp′′(1)
Out[18]:
2.1408590227365494
In [19]:
exp′′(log(10))
Out[19]:
9.441047175715648

This is a disaster. The error, which was acceptable initially, just compounded itself. And the problem is not specific to $\mathrm{e}^x$.

In [20]:
sin′′ = der_v(sin′)
sin′′(π)
Out[20]:
0.008254849939365927
In [21]:
sin′′(π/2)
Out[21]:
-0.9078537644150019
In [22]:
log′′ = der_v(log′)
log′′(10)
Out[22]:
0.007438016716729511

You might imagine it to be a problem when the second derivative of $\log$, which is $x\mapsto \frac{-1}{x^2}$, is computed to be a positive number of magnitude comparable to the negative value it is supposed to be. Third and fourth derivatives are, as you might guess, even worse. The error is ridiculous.

In [23]:
exp′′′ = der_v(exp′′)
exp′′′(1)
Out[23]:
1.6777211679570556e7
In [24]:
sin′′′′ = der_v(der_v(sin′′))
sin′′′′(π)
Out[24]:
-0.7602858901604265

It is easy to miss the scientific notation in the computation of the third derivative of $\mathrm{e}^x$ at $1$. The correct answer is $\mathrm{e}=2.718\dots$, and $1.6777$ would be far enough off already. But look closely: it was actually computed to be $1.6777\times10^7$. Seven orders of magnitude off. For your entertainment, I have included the fifth derivative of $\mathrm{e}^x$ as well.

In [25]:
exp′′′′′ = der_v(der_v(exp′′′))
exp′′′′′(1)
Out[25]:
-3.777892623345787e22

Oof. We need a better way. In an ideal world, we could compute exact derivatives. Instead of approximating a limit by picking small numbers, we could compute the limit exactly. Wouldn't that be neat. It would be great if it were possible, but we are working with dumb silicon machines here, which cannot possibly compute values as exact as a human could. One can dream though. And if one dreamt hard enough, and one knew enough linear algebra, then one might come up with dual numbers, the technique that makes this possible.

Dual Numbers

A dual number is like a complex number in that it has two components. But instead of $\mathbb{R}[\sqrt{-1}]$, we will consider $\mathbb{R}[\varepsilon]$, where $\varepsilon$ is defined so that $\varepsilon^2=0$. The intuition behind this definition is that $\varepsilon$ is really small: smaller than any positive real number, but just a little bigger than $0$. It's so small that if we multiply it by itself (really small times really small), we can no longer see it.
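As a tiny worked example (not from the original notebook), square $3 + \varepsilon$: \begin{equation} (3 + \varepsilon)^2 = 9 + 6\varepsilon + \varepsilon^2 = 9 + 6\varepsilon \end{equation} which is exactly $f(3) + f'(3)\varepsilon$ for $f(x) = x^2$. Keep that observation in mind.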

Aha. Now what if, instead of using $a\mathbf{v} + \mathbf{v}$ as our step, which is small but not that small, we used $\varepsilon$ itself? Fat chance, right? If even machine epsilon ($\mathbf{u}$) was too small to work, how could $\varepsilon$ possibly work?

Our first speed bump is that $\varepsilon$ is not something built-in. We need to define it. Let's worry about the details later and build a kind of DualNumber type:

In [26]:
immutable DualNumber{N <: Number} <: Number
    re::N
    ep::N
end

What the above definition says, in short, is that we wish to create a new DualNumber type with a single type parameter, which is a Number. In Julia, Number is an abstract type with many concrete implementations. Mathematically, we want DualNumbers to work with any field; examples of common fields are $\mathbb{R}$ and $\mathbb{C}$. Furthermore, a DualNumber is itself a kind of number. We can create DualNumber{Float64} for DualNumbers with Float64 components, and that's what we want for now. But we could do other cool stuff too. We'll look at that later.

The type itself contains two fields, a re field containing the real (big) part, and the ep field containing the epsilon (small) part. So something like DualNumber(1.0, 2.0) would be equivalent to $1 + 2\varepsilon$.
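For instance (an illustration rather than a cell from the original notebook), the type parameter is inferred from whatever arguments we pass to the constructor:

DualNumber(1.0, 2.0)   # a DualNumber{Float64} with re = 1.0 and ep = 2.0
DualNumber(1, 2)       # a DualNumber{Int64}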

Julia will by default print DualNumbers like DualNumber{Float64}(1.0, 2.0). This is descriptive and useful, but for our eyes' sake we should print DualNumbers of floats the way we expect, as $a+b\varepsilon$, just like how complex numbers are usually printed. We can define the Base.show method on DualNumber{T} for all T <: AbstractFloat. AbstractFloat in Julia is an abstract type, again with many concrete implementations, in particular Float16, Float32, Float64, and BigFloat. We only want to define this show method on floats, because if we had DualNumbers of other types (such as complex numbers), it may not be printed nicely as $a+b\varepsilon$.

In [27]:
import Base: show

function show{T <: AbstractFloat}(io::IO, num::DualNumber{T})
    show(io, num.re)
    if signbit(num.ep)
        print(io, " − ")
        show(io, -num.ep)
    else
        print(io, " + ")
        show(io, num.ep)
    end
    print(io, "ɛ")
end
Out[27]:
show (generic function with 133 methods)

It would be convenient to have an $\varepsilon$ constant for testing and usability purposes. But what type should this constant be? Ideally, we want something that can be losslessly converted to any other numeric type, so that our constant can be combined with any kind of DualNumber. The smallest possible field is $\mathbb{Z}_2$, the field of booleans, and in Julia a Bool promotes to every other numeric type (false becomes $0$ and true becomes $1$). So we want to define $\varepsilon$ as a DualNumber{Bool}. Sounds crazy, but it's mathematically sound.

In [28]:
const ɛ = DualNumber(false, true)
Out[28]:
DualNumber{Bool}(false,true)

Good. It works. But Julia still doesn't know how these numbers work; we've only taught it how to show them, and what $\varepsilon$ is. First, since these things are a vector space over their base field, we need to define addition and scalar multiplication. Scalar multiplication could be defined right now, but since we will be defining multiplication with these things soon enough, we can afford to wait on that briefly. The definition of addition is another example of multiple dispatch.

In [29]:
import Base: +

+{T<:Number}(x::DualNumber{T}, y::DualNumber{T}) = DualNumber{T}(x.re + y.re, x.ep + y.ep)

DualNumber(10.0, 17.0) + DualNumber(5.0, 9.0)
Out[29]:
15.0 + 26.0ɛ

It will also be useful to define some conversions between dual numbers and scalar types. This will let us write, say, 5 + 2.0ɛ and get the right type back. Julia handles conversions and promotions using multiple dispatch. (Get used to it. Multiple dispatch is Julia's best feature, and good language design means heavily relying on your best feature! That's why C relies so much on raw pointer arithmetic.)

In [30]:
import Base: convert

convert{T<:Number}(::Type{DualNumber{T}}, x::DualNumber{T}) = x
convert{T<:Number}(::Type{DualNumber{T}}, x::DualNumber) =
    DualNumber{T}(convert(T, x.re), convert(T, x.ep))
convert{T<:Number}(::Type{DualNumber{T}}, x::Number) =
    DualNumber{T}(convert(T, x), zero(T))
Out[30]:
convert (generic function with 555 methods)

The above three methods describe ways to convert various kinds of things to DualNumbers. The first one says that to convert a dual number type to itself, nothing needs to be done. The second one tells Julia that to convert a DualNumber of one type to another, just convert each component. The third one tells Julia that to convert a regular number to a DualNumber, convert it to the base field and set the epsilon component to $0$. All these methods are mathematically straightforward and sound.
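As a quick sketch of what these methods do (an illustration, not an original notebook cell):

convert(DualNumber{Float64}, 3)   # DualNumber{Float64}(3.0, 0.0), i.e. 3.0 + 0.0ɛ
convert(DualNumber{Float64}, ɛ)   # DualNumber{Float64}(0.0, 1.0), i.e. 0.0 + 1.0ɛ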

Conversions themselves are not enough to add dual numbers with scalars. We also need to tell Julia how to promote the operands to a common, compatible type. In our case, we want to convert the scalars to DualNumbers.

In [31]:
import Base: promote_rule

promote_rule{T<:Number, U<:Number}(::Type{DualNumber{T}}, ::Type{DualNumber{U}}) =
    DualNumber{promote_type(T, U)}
promote_rule{T<:Number, U<:Number}(::Type{DualNumber{T}}, ::Type{U}) =
    DualNumber{promote_type(T, U)}
Out[31]:
promote_rule (generic function with 99 methods)

The first rule here says that two DualNumber types should be promoted to a DualNumber type whose base field is one compatible with both scalar types. The second rule says that a DualNumber and a scalar should be promoted to a DualNumber, with base field again compatible with both scalar types. By now we have encoded all the properties of dual numbers as being $\mathbb{F}[\varepsilon]$ for some field $\mathbb{F}$, but we have yet to define the behaviour of this algebra under multiplication.
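Before we do that, here is a quick sketch (not an original notebook cell) of what the conversions and promotions already buy us:

promote(DualNumber(1.0, 2.0), 3)   # (1.0 + 2.0ɛ, 3.0 + 0.0ɛ): both are now DualNumber{Float64}
1.0 + ɛ                            # 1.0 + 1.0ɛ: the Bool-based ɛ is promoted, then added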

Note \begin{equation} (a + b\varepsilon)(c + d\varepsilon) = ac + ad\varepsilon + bc\varepsilon + bd\varepsilon^2 = ac + (ad + bc)\varepsilon \end{equation} which you might notice resembles the Leibniz product rule. We can easily encode this:

In [32]:
import Base: *

*{T<:Number}(x::DualNumber{T}, y::DualNumber{T}) =
    DualNumber{T}(x.re * y.re, x.re * y.ep + x.ep * y.re)
Out[32]:
* (generic function with 140 methods)

As mentioned above, a vector space must support scalar multiplication. But the above definition, in addition to our promotion rules, has us covered. Have a look:

In [33]:
7(5.0 + 2ɛ)
Out[33]:
35.0 + 14.0ɛ

When you think about it, this is remarkably little code to implement an entire algebra on a new kind of number, especially when you consider that we have already handled tricky things like working with different types and conversions. This type business is a really good example of how powerful multiple dispatch is. We may now implement our derivative function:

In [34]:
der(f) = a -> f(DualNumber(a, one(a))).ep
Out[34]:
der (generic function with 1 method)

A result from linear algebra tells us that, because of how our dual numbers are defined, the real part of any sum or product is computed exactly as it would be on plain real numbers; the epsilon components never feed back into it. So the real component of $f(a + \varepsilon)$ is simply $f(a)$, and there is no point inspecting it unless we also care about the value of $f$ at $a$. The epsilon component alone carries the derivative.
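To see this concretely for our polynomial $p$, expand using $\varepsilon^2 = 0$: \begin{equation} p(a + \varepsilon) = 2(a+\varepsilon)^3 + 3(a+\varepsilon)^2 + 6(a+\varepsilon) + 6 = (2a^3 + 3a^2 + 6a + 6) + (6a^2 + 6a + 6)\varepsilon = p(a) + p'(a)\varepsilon \end{equation} so reading off the epsilon component gives exactly $p'(a)$.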

This may sound magical still, but rest assured that it is all mathematically sound. And it works:

In [35]:
p′ = der(p)
p′(0)
Out[35]:
6
In [36]:
p′(1)
Out[36]:
18

Exact derivatives. Isn't that neat? But before we get too excited, we should note that our dual numbers only work with addition and multiplication right now. We haven't implemented any other operations (though Julia was able to figure out integer exponentiation on its own, which is another testament to the power of multiple dispatch). We can't yet take the derivative of, say, $\mathrm{e}^x$, $\sin(x)$, $\cos(x)$, or $\log(x)$. Nor have we implemented division. Heck, we haven't even implemented subtraction! None of that is very hard though. Observe:

In [37]:
import Base: exp, log, sin, cos, /, -

-{T<:Number}(x::DualNumber{T}) = DualNumber{T}(-x.re, -x.ep)
x::DualNumber - y::DualNumber = x + (-y)
exp(x::DualNumber) = DualNumber(exp(x.re), exp(x.re) * x.ep)
sin(x::DualNumber) = DualNumber(sin(x.re), cos(x.re) * x.ep)
cos(x::DualNumber) = DualNumber(cos(x.re), -sin(x.re) * x.ep)
log(x::DualNumber) = DualNumber(log(x.re), x.ep / x.re)
x::DualNumber / y::DualNumber = DualNumber(x.re / y.re, x.ep / y.re - x.re * y.ep / y.re^2)
Out[37]:
/ (generic function with 49 methods)

Note that each of the functions we implemented is an application of the chain rule, with the exception of / which is an expanded form of the quotient rule. Convince yourself that these are mathematically correct. Then we can compute the derivative of any function, as long as it only calls the elementary functions that we've defined. Any other elementary functions needed can usually be quickly defined.
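To illustrate, here is how the square root, whose derivative is $\frac{1}{2\sqrt{x}}$, might be added in the same style (a sketch of mine, not part of the original notebook):

import Base: sqrt

sqrt(x::DualNumber) = DualNumber(sqrt(x.re), x.ep / (2 * sqrt(x.re)))

der(x -> sqrt(x))(4.0)   # 0.25, as expected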

For instance, here is an implementation of exponentiation, which extends the integer exponentiation Julia inferred:

In [38]:
import Base: ^

x::DualNumber ^ y::DualNumber = exp(y * log(x))
Out[38]:
^ (generic function with 47 methods)

And thus we can calculate the derivative of x^(x^x) like a piece of cake. Easier than Sal Khan could do it! You can check on Wolfram Alpha that this does indeed work.

In [39]:
der(x -> x^x^x)(2)
Out[39]:
107.11041244660137

We might expect that we could not find second derivatives in this way. Surely the function returned by der, which relies on all the dual number arithmetic we developed, is too complicated to be fed back through der itself. Right?

In [40]:
der(der(x -> e^x))(5)
Out[40]:
148.4131591025766
In [41]:
e^5
Out[41]:
148.4131591025766
In [42]:
der(der(x -> sin(x)))(1)
Out[42]:
-0.8414709848078965
In [43]:
-sin(1)
Out[43]:
-0.8414709848078965
In [44]:
der(der(der(x -> log(x))))(10)
Out[44]:
0.002
In [45]:
2/10^3
Out[45]:
0.002

You really can't make this stuff up. Higher order derivatives, for free, and they're as exact as you want them. How? It turns out that the way we've written our code allows DualNumber{DualNumber{Float64}}. That is, we can make DualNumbers of DualNumbers, and get our second derivatives that way. The der function was general enough that it didn't care what type it was working with, as long as it was a subtype of Number, which DualNumber is. Mathematically, this isn't quite kosher, since the DualNumbers are not a field, but it still works for derivatives. Some vector space properties are lost along the way, but we did not really use those anyway.
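To make the nesting concrete, here is a sketch of what der(der(exp))(5.0) builds internally (not an original notebook cell; the names a and nested are mine):

a = DualNumber(5.0, 1.0)         # the dual number created by the outer der: 5.0 + 1.0ɛ
nested = DualNumber(a, one(a))   # created by the inner der: a DualNumber{DualNumber{Float64}}
exp(nested).ep.ep                # ≈ e^5, the second derivative of exp at 5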

The only thing that would make this even better is if we could also compute exact complex derivatives. As you might have guessed, the answer is yes.

In [46]:
der(x -> x^2)(1 + 2im)
Out[46]:
2 + 4im

I hope that this illustrates two important principles: linear algebra is powerful, and multiple dispatch is powerful. When I saw this for the first time, it blew my mind.

So what is multiple dispatch?

It might not have been clear what this magical fairy I was referring to actually is. What is multiple dispatch, and how did it help us achieve our result? I will try to briefly summarize below, but for more information, the Julia documentation gives a very good overview.

When we talk in mathematics about a concept like $+$, we are talking about an abstract operation that has certain properties. For example, we expect it to be associative and commutative. But how this operation actually works is different depending on what exactly we are adding. Adding two integers is not the same operation as adding two dual numbers, as you can probably guess. But at a more abstract level, we think of both as $+$.
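As a toy illustration (hypothetical code, not part of the original notebook), here is one generic function with a method for each combination of argument types; which body runs is decided by the types of all the arguments together:

combine(x::Integer, y::Integer) = x + y                   # add integers
combine(x::AbstractString, y::AbstractString) = x * y     # concatenate strings (* is Julia's string concatenation)
combine(x::Number, y::AbstractString) = string(x, y)      # mixed arguments pick this method

combine(1, 2)           # 3
combine("foo", "bar")   # "foobar"
combine(1.5, "x")       # "1.5x"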

Multiple dispatch allows the computer (and the person who implemented the methods in the first place) to handle the details of exactly which operation $+$ means, while we relax and simply use it. The same function has multiple implementations, called methods, depending on the types of its arguments. We may inspect this by looking at the native code generated for the addition (on my computer; your results may vary):

In [47]:
@code_native 8 + 8
	.text
Filename: int.jl
Source line: 0
	pushq	%rbp
	movq	%rsp, %rbp
Source line: 32
	leaq	(%rdi,%rsi), %rax
	popq	%rbp
	retq
	nopw	(%rax,%rax)
In [48]:
@code_native (1.0 + ɛ) + (2.0 - ɛ)
	.text
Filename: In[29]
Source line: 0
	pushq	%rbp
	movq	%rsp, %rbp
Source line: 3
	movsd	(%rsi), %xmm0           # xmm0 = mem[0],zero
	movsd	8(%rsi), %xmm1          # xmm1 = mem[0],zero
	addsd	(%rdx), %xmm0
	addsd	8(%rdx), %xmm1
	movsd	%xmm1, 8(%rdi)
	movsd	%xmm0, (%rdi)
	movq	%rdi, %rax
	popq	%rbp
	retq
	nopw	%cs:(%rax,%rax)

At the processor level, therefore, these two are very different operations. But the compiler handles those issues for us. We simply use $+$. So in some sense, multiple dispatch allows us to reason about code as we would reason about mathematics. It allows us to abstract away details and perform different operations using the same name. Then we can compose those operations into advanced functionality, and have the compiler generate the right sequence of instructions.

If small building blocks like $+$ are specialized on types, then bigger, higher-level functions don't have to worry about them. That is how we were able to define polynomials like $p$, thinking that they would be applied to real numbers, and then evaluate them on DualNumbers so we could compute their derivatives. That is how Julia was able to compute integer exponentiation for our DualNumbers even though we told it nothing about it. That is how things "just worked" when we built DualNumbers over Complex numbers and over other DualNumbers. That is magic.

Normally such a feature might make things slower. But Julia's compiler specializes code and resolves dispatch at compile time whenever the argument types can be inferred, so multiple dispatch typically comes at no runtime penalty. That's technologically incredible. Almost magic.

In fact, nothing about what we did above would scream "fast, efficient code" in most languages. It's also not the fastest way to do it in Julia. But even so, Julia 0.5 is able to optimize der in some very simple cases. For example, below it was able to constant fold the derivative of $x \mapsto x$ to $1$:

In [49]:
code_native(der(x -> x), (Int,))
	.text
Filename: In[34]
Source line: 0
	pushq	%rbp
	movq	%rsp, %rbp
	movl	$1, %eax
Source line: 1
	popq	%rbp
	retq
	nopl	(%rax,%rax)

Even though we've made no attempt to optimize our code, der is pretty fast. Check this out:

In [51]:
p′(0)
p′(0)

@time p′(5)
@time p′(5)
  0.000002 seconds (4 allocations: 160 bytes)
  0.000003 seconds (4 allocations: 160 bytes)
Out[51]:
186

Note that these times are both overestimates. My system clock cannot distinguish anything shorter than about 0.000005 seconds. More important is the number of bytes allocated, which is pretty reasonable for both.

Bonus Content

Above I mentioned that der is able to take exact derivatives. That's a bold claim. I've done nothing to prove that so far, since we've been working with floating point numbers.

It turns out that I did not exaggerate. All we need is a type that can represent things symbolically! Enter... SymbolicReal:

In [52]:
immutable SymbolicReal <: Number
    expr
end

+(r::SymbolicReal, s::SymbolicReal) = SymbolicReal(:($(r.expr) + $(s.expr)))
*(r::SymbolicReal, s::SymbolicReal) = SymbolicReal(:($(r.expr) * $(s.expr)))
-(r::SymbolicReal, s::SymbolicReal) = SymbolicReal(:($(r.expr) - $(s.expr)))
/(r::SymbolicReal, s::SymbolicReal) = SymbolicReal(:($(r.expr) / $(s.expr)))
^(r::SymbolicReal, s::SymbolicReal) = exp(s * log(r))
sin(r::SymbolicReal) = SymbolicReal(:(sin($(r.expr))))
cos(r::SymbolicReal) = SymbolicReal(:(cos($(r.expr))))
exp(r::SymbolicReal) = SymbolicReal(:(exp($(r.expr))))
log(r::SymbolicReal) = SymbolicReal(:(log($(r.expr))))

convert(::Type{SymbolicReal}, r::SymbolicReal) = r
convert(::Type{SymbolicReal}, n::Number) = SymbolicReal(n)
promote_rule{T<:Number}(::Type{SymbolicReal}, ::Type{T}) = SymbolicReal
Out[52]:
promote_rule (generic function with 100 methods)

Here we bootstrapped off Julia's expressions. Of course, we don't do any simplification above; expressions just keep growing and growing. But it really does work:

In [53]:
der(x -> x^3)(SymbolicReal(:(x)))
Out[53]:
SymbolicReal(:(x * (x * 1 + 1x) + 1 * (x * x)))

A human would simplify the above to $3x^2$, which is the exact derivative.
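Even unsimplified, the result wraps a genuine Julia Expr, so we can evaluate it back into a function and spot-check it (a sketch, not an original notebook cell; ex and dp3 are names introduced here):

ex = der(x -> x^3)(SymbolicReal(:(x))).expr
dp3 = @eval x -> $ex   # turn the expression into a callable function
dp3(2.0)               # 12.0, matching 3x^2 at x = 2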

In [54]:
der(x -> x^x)(SymbolicReal(:(x)))
Out[54]:
SymbolicReal(:(exp(x * log(x)) * (x * (1 / x) + 1 * log(x))))

A human might simplify the above to $x^x (1 + \log(x))$, which is the exact derivative.

DualNumbers.jl

Our implementation of dual numbers is rushed and incomplete. For a solid implementation currently in use today, see DualNumbers.jl.