numerically stable way to multiply log probability matrices in numpy

logsumexp works by evaluating the right-hand side of the equation

log(∑ exp[a]) = max(a) + log(∑ exp[a - max(a)])

I.e., it pulls out the max before starting to sum, to prevent overflow in exp.
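
A minimal sketch of that identity (hand-rolled here for illustration; scipy.special.logsumexp provides a full-featured version):

import numpy as np

def logsumexp(a):
    # Shift by the max so exp cannot overflow, then add the shift back outside the log.
    a_max = np.max(a)
    return a_max + np.log(np.sum(np.exp(a - a_max)))

The same trick can be applied before doing vector dot products: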

log(exp[a] ⋅ exp[b])
 = log(∑ exp[a] × exp[b])
 = log(∑ exp[a + b])
 = max(a + b) + log(∑ exp[a + b - max(a + b)])     { this is logsumexp(a + b) }
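
For two log-probability vectors this route is a single call to scipy.special.logsumexp. A quick sketch; the naive right-hand side of the check is only safe because the example values are moderate:

import numpy as np
from scipy.special import logsumexp

a = np.log(np.random.dirichlet(np.ones(5)))   # a vector of log probabilities
b = np.log(np.random.dirichlet(np.ones(5)))
# logsumexp(a + b) == log(exp(a) . exp(b)); the naive right-hand side would
# overflow or underflow for extreme inputs, but is fine for these values.
assert np.allclose(logsumexp(a + b), np.log(np.exp(a) @ np.exp(b)))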

But by taking a different turn in the derivation, we obtain

log(∑ exp[a] × exp[b])
 = max(a) + max(b) + log(∑ exp[a - max(a)] × exp[b - max(b)])
 = max(a) + max(b) + log(exp[a - max(a)] ⋅ exp[b - max(b)])

The final form has a vector dot product in its innards. It also extends readily to matrix multiplication, since each entry of a matrix product is exactly such a dot product, so we get the algorithm

import numpy as np

def logdotexp(A, B):
    # Shift both matrices by their global maxima so that exp cannot overflow,
    # multiply in ordinary space, then undo the shifts in log space.
    max_A = np.max(A)
    max_B = np.max(B)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C
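
A quick sanity check on small random log-probability matrices, small enough that the naive computation is safe to compare against (assumes numpy as np and the logdotexp above are in scope):

A = np.log(np.random.dirichlet(np.ones(4), size=3))   # 3x4, each row a log-probability distribution
B = np.log(np.random.dirichlet(np.ones(5), size=4))   # 4x5
assert np.allclose(logdotexp(A, B), np.log(np.exp(A) @ np.exp(B)))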

logdotexp as written creates two A-sized temporaries and two B-sized ones, but one of each can be eliminated by

exp_A = A - max_A
np.exp(exp_A, out=exp_A)

and similarly for B. (If the input matrices may be modified by the function, all the temporaries can be eliminated.)
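
Putting those pieces together, a sketch of the reduced-temporary version (the name logdotexp_lowmem is made up here; like the original, it leaves its inputs untouched):

def logdotexp_lowmem(A, B):
    max_A = np.max(A)
    max_B = np.max(B)
    exp_A = A - max_A            # the single A-sized temporary
    np.exp(exp_A, out=exp_A)     # exponentiate in place rather than allocating again
    exp_B = B - max_B            # the single B-sized temporary
    np.exp(exp_B, out=exp_B)
    C = np.dot(exp_A, exp_B)
    np.log(C, out=C)
    C += max_A + max_B
    return C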
