Compile-time Binomial Coefficient

Building a fast n choose k function.

Combinatorics coding problems often require calculating binomial coefficients. Python has math.comb and math.factorial built in, so these problems are easier to solve in Python than in C++. To further complicate things, these problems usually ask for the result modulo a large prime such as 1000000007, since the exact value is too large to fit in C++'s built-in integer types.

In this post, I share some implementations I made for calculating binomial coefficients quickly.

TL;DR. Here’s the helper class RuntimeInitCacheComb, which builds its tables at runtime and works modulo 1e9+7.

I find it performs reasonably well without knowing the details of the test cases in advance.

To use it, first initialize it with N + 1, where N is the largest n you will query:

  • RuntimeInitCacheComb combHelper(N+1);

Then, just call

  • combHelper.comb(n, k);

to get (n choose k) % 1000000007.

#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>
using namespace std;

using ll = long long;
constexpr ll MOD = 1000000007;

// (a * b) % MOD; both operands are below MOD, so the product fits in a long long.
constexpr ll mod_mul(ll a, ll b) { return a * b % MOD; }

class RuntimeInitCacheComb {
public:
    size_t N = 0;

    RuntimeInitCacheComb() = default;
    // Caches i! and (i!)^-1 mod MOD for every 0 <= i < n.
    RuntimeInitCacheComb(size_t n) : N(n), facts(n), inv_fact(n) {
        buildfact();
        for (size_t k = 0; k < N; k++) {
            buildInvFact(k);
        }
    }
    ll fact(ll x) const {
        if (x <= 1) {
            return 1;
        }
        if (x >= (ll)N) {
            throw std::runtime_error("out of bound: " + to_string(x));
        }
        return facts[x];
    }

    // (n choose k) % MOD in O(1); requires n < N.
    ll comb(ll n, ll k) const {
        assert(n < (ll)N && "n out of bound");
        if (k < 0 || k > n) {
            return 0;  // C(n, k) = 0 outside 0 <= k <= n
        }
        return mod_mul(facts[n], mod_mul(inv_fact[k], inv_fact[n-k]));
    }
private:
    vector<ll> facts;
    vector<ll> inv_fact;

    // constexpr is not needed at runtime; combined with std::vector it requires C++20.
    constexpr void buildfact() {
        for (size_t i = 0; i < N; i++) {
            facts[i] = i<=1 ? 1 : mod_mul((ll)i, facts[i-1]);
        }
    }

    // Modular inverse for prime MOD: i^-1 = -(MOD / i) * (MOD % i)^-1 mod MOD.
    constexpr ll inv(ll i) const {
        i %= MOD;
        if (i <= 1) {
            return i;
        }
        return MOD - mod_mul(MOD / i, inv(MOD % i));
    }

    constexpr ll buildInvFact(ll k) {
        inv_fact[k] = inv(facts[k]);
        return inv_fact[k];
    }
};

The Math

I learned all of the math here by following the links in this post.

TL;DR. Here are the essential equations we’ll need. The post goes into detail deriving them.

(n choose k) mod m using inverse factorials:

// factorial[i] = i! % m and inverse_factorial[i] = (i!)^-1 % m, precomputed up to n
long long binomial_coefficient(int n, int k) {
    return factorial[n] * inverse_factorial[k] % m * inverse_factorial[n - k] % m;
}
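Written as a congruence, this is what the snippet computes, assuming m is prime and n < m so that every factorial involved is invertible modulo m:

```latex
\binom{n}{k} = \frac{n!}{k!\,(n-k)!} \equiv n! \cdot (k!)^{-1} \cdot \bigl((n-k)!\bigr)^{-1} \pmod{m}, \qquad 0 \le k \le n < m
```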

Modular inverse equation (m must be prime; it works for any 1 <= i < m, not only factorials):

// Returns i^-1 % m for 1 <= i < m
long long inv(long long i) {
  return i <= 1 ? i : m - (long long)(m/i) * inv(m % i) % m;
}
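The recursion works because writing m = q·i + r, with q = ⌊m/i⌋ and r = m mod i, gives q·i + r ≡ 0 (mod m); multiplying by i⁻¹·r⁻¹ yields i⁻¹ ≡ −q·r⁻¹ (mod m). Since r < i, the recursion always reaches 1. Below is a minimal sketch that runs the same recursion (renamed `mod_inv`, with m fixed at 1000000007 as an assumption) and checks that i · i⁻¹ ≡ 1 for a range of i:

```cpp
#include <cassert>

using ll = long long;
constexpr ll M = 1000000007;  // assumed prime modulus

// Modular inverse of i (1 <= i < M) via i^-1 = -(M / i) * (M % i)^-1 mod M.
// Each step replaces i with M % i < i, so the recursion terminates at 1.
constexpr ll mod_inv(ll i) {
    return i <= 1 ? i : M - (M / i) * mod_inv(M % i) % M;
}

// Returns true if i * mod_inv(i) == 1 (mod M) for every i in [1, limit].
bool check_inverses(ll limit) {
    for (ll i = 1; i <= limit; i++) {
        if (i * mod_inv(i) % M != 1) {
            return false;
        }
    }
    return true;
}
```

The product (M / i) * mod_inv(...) stays below about 5·10^17, so it never overflows a 64-bit long long.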

Runtime Initialization with Cache

The version above caches every factorial below N, along with its modular inverse, at initialization time, so comb(n, k) runs in O(1) at runtime by looking up the cached results.

  • I find this the most versatile version for combinatorics problems on LeetCode.
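For comparison, here is a compact sketch of the same caching idea with a different way of filling the inverse table (a swapped-in technique, not the code above): only the largest factorial is inverted directly, using Fermat’s little theorem (a^(MOD−2) is a⁻¹ for prime MOD), and the rest follow from (i−1)!⁻¹ = i!⁻¹ · i. That makes initialization O(N + log MOD) instead of one modular inverse per entry. The names `FactTable`, `mod_pow` and `matches_pascal` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using ll = long long;
constexpr ll MOD = 1000000007;  // assumed prime

// base^exp % MOD by binary exponentiation.
ll mod_pow(ll base, ll exp) {
    ll result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;
        base = base * base % MOD;
        exp >>= 1;
    }
    return result;
}

struct FactTable {
    std::vector<ll> fact, inv_fact;

    // Caches i! and (i!)^-1 mod MOD for 0 <= i < n (requires n >= 1).
    explicit FactTable(std::size_t n) : fact(n), inv_fact(n) {
        fact[0] = 1;
        for (std::size_t i = 1; i < n; i++) fact[i] = fact[i - 1] * static_cast<ll>(i) % MOD;
        // Fermat: a^(MOD-2) is the inverse of a when MOD is prime.
        inv_fact[n - 1] = mod_pow(fact[n - 1], MOD - 2);
        // Walk down: (i-1)!^-1 = i!^-1 * i.
        for (std::size_t i = n - 1; i > 0; i--) inv_fact[i - 1] = inv_fact[i] * static_cast<ll>(i) % MOD;
    }

    // O(1) lookup of (n choose k) % MOD for 0 <= n < table size.
    ll comb(ll n, ll k) const {
        if (k < 0 || k > n) return 0;
        return fact[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD;
    }
};

// Cross-checks the table against Pascal's rule C(n, k) = C(n-1, k-1) + C(n-1, k) for n < rows.
bool matches_pascal(const FactTable& t, int rows) {
    std::vector<std::vector<ll>> c(rows, std::vector<ll>(rows, 0));
    for (int n = 0; n < rows; n++) {
        c[n][0] = 1;
        for (int k = 1; k <= n; k++) c[n][k] = (c[n - 1][k - 1] + c[n - 1][k]) % MOD;
        for (int k = 0; k <= n; k++) {
            if (t.comb(n, k) != c[n][k]) return false;
        }
    }
    return true;
}
```

Lookups are the same O(1) as in RuntimeInitCacheComb; only the setup cost changes.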

Compile-Time Version (Just for Fun)

You might notice the constexpr specifier in the version above, which the runtime init version does not need at all; I copied it from my compile-time version. constexpr tells the compiler that a function may be evaluated at compile time, but it can still be called at runtime like any other function. See below for the compile-time version.

  • Unfortunately, it won’t compile on LeetCode. With a large N, we need to raise the compiler’s constexpr step limit (Clang’s flag is -fconstexpr-steps=<n>); otherwise we get an error like: constexpr evaluation hit maximum step limit; possible infinite loop?
  • In theory, it is faster than the runtime init version.
// Assumes ll, MOD and a constexpr mod_mul(a, b) = a * b % MOD, as in the runtime version.
template<int N>
class CompileTimeComb {
public:
    constexpr CompileTimeComb() {
        buildfact();
        for (int k = 0; k < N; k++) {
            buildInvFact(k);
        }
    }
    ll fact(ll x) const {
        if (x <= 1) {
            return 1;
        }
        if (x >= N) {
            throw std::runtime_error("out of bound: " + to_string(x));
        }
        return facts[x];
    }

    constexpr ll comb(ll n, ll k) const {
        assert(n < N);
        if (k < 0 || k > n) {
            return 0;  // C(n, k) = 0 outside 0 <= k <= n
        }
        return mod_mul(facts[n], mod_mul(inv_fact[k], inv_fact[n-k]));
    }
private:
    ll facts[N] = {0};
    ll inv_fact[N] = {0};

    constexpr void buildfact() {
        for (int i = 0; i < N; i++) {
            facts[i] = i<=1 ? 1 : mod_mul(i, facts[i-1]);
        }
    }

    // Not cached: inv_fact[i] holds (i!)^-1, not i^-1, so it cannot double as a cache for inv(i).
    constexpr ll inv(ll i) const {
        i %= MOD;
        if (i <= 1) {
            return i;
        }
        return MOD - mod_mul(MOD / i, inv(MOD % i));
    }

    constexpr ll buildInvFact(ll k) {
        inv_fact[k] = inv(facts[k]);
        return inv_fact[k];
    }
};
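One way to guarantee that such a table is built by the compiler is to declare the object constexpr and check it with static_assert; the same object can still be queried with values known only at runtime. The sketch below is a scaled-down variant, not the class above: it assumes C++17 (where the non-const std::array::operator[] is constexpr), uses a small table size so the default constexpr step limit is not hit, and the names `SmallCompileTimeComb`, `table` and `lookup` are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using ll = long long;
constexpr ll MOD = 1000000007;  // assumed prime

constexpr ll mod_mul(ll a, ll b) { return a * b % MOD; }

// Modular inverse for prime MOD (same recursion as in the post).
constexpr ll inv(ll i) {
    return i <= 1 ? i : MOD - mod_mul(MOD / i, inv(MOD % i));
}

template <std::size_t N>
struct SmallCompileTimeComb {
    std::array<ll, N> facts{};
    std::array<ll, N> inv_fact{};

    // Fills i! and (i!)^-1 mod MOD for 0 <= i < N.
    constexpr SmallCompileTimeComb() {
        for (std::size_t i = 0; i < N; i++) {
            facts[i] = i <= 1 ? 1 : mod_mul(static_cast<ll>(i), facts[i - 1]);
            inv_fact[i] = inv(facts[i]);
        }
    }

    // (n choose k) % MOD for 0 <= n < N; 0 when k is out of range.
    constexpr ll comb(ll n, ll k) const {
        if (k < 0 || k > n) return 0;
        return mod_mul(facts[n], mod_mul(inv_fact[k], inv_fact[n - k]));
    }
};

// A constexpr variable must be constant-initialized, so the tables are computed at compile time.
constexpr SmallCompileTimeComb<64> table;

static_assert(table.comb(5, 2) == 10, "C(5, 2) = 10");
static_assert(table.comb(10, 3) == 120, "C(10, 3) = 120");
static_assert(table.comb(3, 5) == 0, "k > n gives 0");

// The same precomputed table, queried with arguments known only at runtime.
ll lookup(ll n, ll k) { return table.comb(n, k); }
```

If any static_assert fails, the program does not compile, which confirms the values were produced by the compiler rather than at startup.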

Test Results

I used LeetCode problem 2539, Count the Number of Good Subsequences, with the input vjegrusjmearasjqegqesrevdduqaqvsduguqgqvqdgmeeeuqgeuaumrjsejqmgrqgjudegagmjgqrrsresuurusrduuae as the benchmark.
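For context, here is a sketch of how binomial coefficients enter that problem, assuming the usual reading of it: a non-empty subsequence is good when every character in it occurs the same number of times f. For each f, every letter contributes either none of its occurrences (1 way) or exactly f of them (C(cnt, f) ways), and the all-empty choice is subtracted. This is not necessarily the solution used in the benchmark; `count_good_subsequences` is a hypothetical name, and it carries its own small factorial cache.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <string>
#include <vector>

using ll = long long;
constexpr ll MOD = 1000000007;  // assumed prime

// Modular inverse for prime MOD (same recursion as in the post).
ll inv(ll i) { return i <= 1 ? i : MOD - (MOD / i) * inv(MOD % i) % MOD; }

ll count_good_subsequences(const std::string& s) {
    const int n = static_cast<int>(s.size());
    // Cache i! and (i!)^-1 mod MOD for 0 <= i <= n.
    std::vector<ll> fact(n + 1, 1), inv_fact(n + 1, 1);
    for (int i = 1; i <= n; i++) fact[i] = fact[i - 1] * i % MOD;
    for (int i = 0; i <= n; i++) inv_fact[i] = inv(fact[i]);
    auto comb = [&](int a, int b) -> ll {
        if (b < 0 || b > a) return 0;
        return fact[a] * inv_fact[b] % MOD * inv_fact[a - b] % MOD;
    };

    // Frequency of each lowercase letter.
    std::array<int, 26> cnt{};
    for (char ch : s) cnt[ch - 'a']++;
    const int max_cnt = *std::max_element(cnt.begin(), cnt.end());

    ll total = 0;
    for (int f = 1; f <= max_cnt; f++) {
        // Each letter: skip it (1 way) or take exactly f of its occurrences.
        ll ways = 1;
        for (int c : cnt) ways = ways * (1 + comb(c, f)) % MOD;
        total = (total + ways - 1 + MOD) % MOD;  // drop the empty choice
    }
    return total;
}
```

For example, it returns 11 for "aabb" and 12 for "leet".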

Here’s the Google Benchmark result. The times do not include initialization or compile time. See source here.

  • The Runtime Only version computes both the factorials and the modular inverses of the factorials at runtime without caching, so unsurprisingly it is by far the slowest.
Comb function type         Average runtime
Compile Time               0.653 ms
Runtime Init with Cache    0.709 ms
Runtime Only               20.0 ms
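The Runtime Only code is not shown in this post. As a rough sketch of what such a version might look like (an assumption: it recomputes n!, k! and (n−k)! in O(n) loops and inverts them on every call), it becomes clear why each call costs O(n) plus the inverse computations instead of an O(1) lookup. `runtime_only_comb` and `fact_mod` are hypothetical names.

```cpp
#include <cassert>

using ll = long long;
constexpr ll MOD = 1000000007;  // assumed prime

// Modular inverse for prime MOD (same recursion as in the post).
ll inv(ll i) { return i <= 1 ? i : MOD - (MOD / i) * inv(MOD % i) % MOD; }

// n! mod MOD, recomputed from scratch on every call: O(n).
ll fact_mod(ll n) {
    ll result = 1;
    for (ll i = 2; i <= n; i++) result = result * i % MOD;
    return result;
}

// (n choose k) % MOD with no caching: three factorial loops and two inverses per call.
ll runtime_only_comb(ll n, ll k) {
    if (k < 0 || k > n) return 0;
    return fact_mod(n) * inv(fact_mod(k)) % MOD * inv(fact_mod(n - k)) % MOD;
}
```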

© 2022. All rights reserved.