Compile-time Binomial Coefficient
Building a fast n choose k function.
Combinatorics coding problems often require calculating binomial coefficients. Python has math.comb and math.factorial built in, which makes these problems easier in Python than in C++. To complicate things further, these problems usually ask for the result modulo (%) a large prime, e.g. 1000000007, since the true result is too large to fit in the built-in C++ integer types.
In this post, I share some implementations I wrote for calculating binomial coefficients quickly.
TL;DR. Here's the helper class RuntimeInitCacheComb, which initializes at runtime with modulo 1e9+7.
I find it performs reasonably well without knowing the details of the test cases.
To use it, we first need to initialize it:
RuntimeInitCacheComb combHelper(N+1);
Then, just call
combHelper.comb(n, k);
to get (n choose k) % 1000000007.
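#include <algorithm>
#include <cassert>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>
using namespace std;

// Helpers the class relies on. Their definitions are not shown in this post,
// so these are assumed versions.
using ll = long long;
constexpr ll MOD = 1000000007;
constexpr ll mod_mul(ll a, ll b) { return (a % MOD) * (b % MOD) % MOD; }

class RuntimeInitCacheComb {
public:
    size_t N = 0;
    RuntimeInitCacheComb() = default;
    // Caches i! and (i!)^-1 mod MOD for every 0 <= i < n.
    RuntimeInitCacheComb(size_t n) : N(n) {
        facts.resize(N);
        inv_fact.resize(N);
        std::fill_n(begin(facts), size(facts), -1);
        std::fill_n(begin(inv_fact), size(inv_fact), -1);
        buildfact();
        for (size_t k = 0; k < N; k++) {
            buildInvFact(k);
        }
    }
    ll fact(ll x) const {
        if (x <= 1) {
            return 1;
        }
        if (x >= (ll)N) {
            throw std::runtime_error("out of bound: " + to_string(x));
        }
        return facts[x];
    }
    ll comb(ll n, ll k) const {
        if (n < 0 || n >= (ll)N) {
            cout << "n: " << n << " out of bound";
            assert(false);
        }
        if (k < 0 || k > n) {
            return 0;  // no way to choose k items out of n
        }
        return mod_mul(facts[n], mod_mul(inv_fact[k], inv_fact[n-k]));
    }
private:
    vector<ll> facts;     // facts[i] = i! mod MOD
    vector<ll> inv_fact;  // inv_fact[i] = (i!)^-1 mod MOD
    constexpr void buildfact() {
        for (size_t i = 0; i < N; i++) {
            facts[i] = i <= 1 ? 1 : mod_mul(i, facts[i-1]);
        }
    }
    // Modular inverse of i; requires MOD to be prime.
    constexpr ll inv(ll i) {
        if (i >= MOD) {
            i = i % MOD;
        }
        if (i <= 1) {
            return i;
        }
        return MOD - mod_mul((long long)(MOD / i), inv(MOD % i));
    }
    constexpr ll buildInvFact(ll k) {
        inv_fact[k] = inv(facts[k]);
        return inv_fact[k];
    }
};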
The Math
I learned all the math behind this by following the links in this post.
TL;DR. Here are the essential equations we'll need. The post goes into detail on deriving them.
(n choose k) mod m, computed using inverse factorials:
long long binomial_coefficient(int n, int k) {
    return factorial[n] * inverse_factorial[k] % m * inverse_factorial[n - k] % m;
}
Modular inverse of i mod m (m must be prime), which we apply to the factorials:
long long inv(long long i) {
    return i <= 1 ? i : m - (long long)(m/i) * inv(m % i) % m;
}
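For reference, here are the same two formulas in mathematical notation, with a short derivation of the inverse recurrence. This is my own summary rather than a quote from the linked post, and it assumes m is prime and 1 < i < m, so that m mod i is nonzero and has an inverse.

```latex
\begin{align*}
\binom{n}{k} &\equiv n! \cdot (k!)^{-1} \cdot ((n-k)!)^{-1} \pmod{m} \\[6pt]
m &= \left\lfloor \tfrac{m}{i} \right\rfloor i + (m \bmod i)
  \;\Longrightarrow\;
  \left\lfloor \tfrac{m}{i} \right\rfloor i + (m \bmod i) \equiv 0 \pmod{m} \\
\text{multiplying by } i^{-1} (m \bmod i)^{-1}:\quad
  & \left\lfloor \tfrac{m}{i} \right\rfloor (m \bmod i)^{-1} + i^{-1} \equiv 0 \pmod{m} \\
i^{-1} &\equiv m - \left( \left\lfloor \tfrac{m}{i} \right\rfloor (m \bmod i)^{-1} \bmod m \right) \pmod{m}
\end{align*}
```

A quick sanity check with m = 7 and i = 3: floor(7/3) = 2 and 7 mod 3 = 1, so the inverse is 7 - 2 * 1 = 5, and indeed 3 * 5 = 15 = 1 (mod 7).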
Runtime Initialization with Cache
The version above caches every factorial below N, along with its modular inverse, at initialization time, so that comb(n, k) runs in O(1) at runtime by reading the cached results directly.
- This is the most versatile version for combinatorics problems on LeetCode.
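As a point of comparison, here is a minimal, self-contained sketch of the same caching idea with one change: instead of calling inv() once per entry, it computes a single inverse with Fermat's little theorem and fills the inverse factorials backward. This is a common alternative, not what the class above does, and the names kMod, modPow, and FactorialTable are made up for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical names; kMod matches the 1e9+7 modulus used in this post.
constexpr std::int64_t kMod = 1'000'000'007;

// Fast exponentiation: base^exp mod kMod.
std::int64_t modPow(std::int64_t base, std::int64_t exp) {
    std::int64_t result = 1;
    base %= kMod;
    while (exp > 0) {
        if (exp & 1) result = result * base % kMod;
        base = base * base % kMod;
        exp >>= 1;
    }
    return result;
}

struct FactorialTable {
    std::vector<std::int64_t> fact;     // fact[i]    = i! mod kMod
    std::vector<std::int64_t> invFact;  // invFact[i] = (i!)^-1 mod kMod

    // Precomputes both tables for 0 <= i < size. Assumes 0 < size <= kMod,
    // otherwise some factorial would be 0 mod kMod and have no inverse.
    explicit FactorialTable(std::size_t size) : fact(size, 1), invFact(size, 1) {
        assert(size > 0);
        for (std::size_t i = 1; i < size; ++i) {
            fact[i] = fact[i - 1] * static_cast<std::int64_t>(i) % kMod;
        }
        // One inverse via Fermat's little theorem: a^(p-2) = a^-1 (mod p).
        invFact[size - 1] = modPow(fact[size - 1], kMod - 2);
        // Walk down using ((i-1)!)^-1 = (i!)^-1 * i.
        for (std::size_t i = size - 1; i > 0; --i) {
            invFact[i - 1] = invFact[i] * static_cast<std::int64_t>(i) % kMod;
        }
    }

    // O(1) lookup; returns 0 when k is outside [0, n].
    std::int64_t comb(std::int64_t n, std::int64_t k) const {
        assert(n >= 0 && static_cast<std::size_t>(n) < fact.size());
        if (k < 0 || k > n) return 0;
        return fact[n] * invFact[k] % kMod * invFact[n - k] % kMod;
    }
};
```

Usage mirrors the class above: construct FactorialTable table(N + 1) once, then call table.comb(n, k) for each query.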
CompileTime version (Just for fun)
You might notice the constexpr signature in the version above, which is not needed for the runtime init version at all. I just copied it from my compile-time version. The constexpr specifier tells the compiler that a function may be evaluated at compile time; such a function can still be called at runtime as well.
Please see below for the compile-time version.
- Unfortunately, it won't work when compiling on LeetCode. With a large N, we need to set compiler flags that allow more constexpr steps, otherwise we get an error like: constexpr evaluation hit maximum step limit; possible infinite loop?
- In theory, it is faster than the runtime init version.
template<int N>
class CompileTimeComb {
public:
    constexpr CompileTimeComb() {
        // -1 marks "not computed yet" in the memo table used by inv().
        std::fill_n(begin(inv_cache), size(inv_cache), -1);
        buildfact();
        for (int k = 0; k < N; k++) {
            buildInvFact(k);
        }
    }
    ll fact(ll x) const {
        if (x <= 1) {
            return 1;
        }
        if (x >= N) {
            throw std::runtime_error("out of bound: " + to_string(x));
        }
        return facts[x];
    }
    ll comb(ll n, ll k) const {
        assert(n >= 0 && n < N);
        if (k < 0 || k > n) {
            return 0;  // no way to choose k items out of n
        }
        return mod_mul(facts[n], mod_mul(inv_fact[k], inv_fact[n-k]));
    }
private:
    ll facts[N] = {0};      // facts[i] = i! mod MOD
    ll inv_fact[N] = {0};   // inv_fact[i] = (i!)^-1 mod MOD
    ll inv_cache[N] = {0};  // inv_cache[i] = i^-1 mod MOD, memoized by inv()
    constexpr void buildfact() {
        for (int i = 0; i < N; i++) {
            facts[i] = i <= 1 ? 1 : mod_mul(i, facts[i-1]);
        }
    }
    // Modular inverse of i; requires MOD to be prime.
    constexpr ll inv(ll i) {
        if (i >= MOD) {
            i = i % MOD;
        }
        if (i <= 1) {
            return i;
        }
        // The memo lives in its own array. Storing i^-1 in inv_fact[i] would
        // collide with (i!)^-1, which is a different value once i >= 3.
        if (i < N && inv_cache[i] != -1) {
            return inv_cache[i];
        }
        auto res = MOD - mod_mul((long long)(MOD / i), inv(MOD % i));
        if (i < N) {
            inv_cache[i] = res;
        }
        return res;
    }
    constexpr ll buildInvFact(ll k) {
        inv_fact[k] = inv(facts[k]);
        return inv_fact[k];
    }
};
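To make the compile-time idea easier to try in isolation, here is a much smaller self-contained sketch that builds a 16-entry table and checks it with static_assert, so the compiler has to produce the values during compilation. It assumes C++17, and the names kMod, modInv, SmallCombTable, and kTable are made up for this sketch. For a large table you would likely hit the evaluation limit mentioned above; as far as I know, Clang exposes it as -fconstexpr-steps and GCC as -fconstexpr-ops-limit.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical names; kMod matches the 1e9+7 modulus used in this post.
constexpr std::int64_t kMod = 1'000'000'007;

// Modular inverse using the recurrence from this post (kMod must be prime).
constexpr std::int64_t modInv(std::int64_t i) {
    return i <= 1 ? i : kMod - (kMod / i) * modInv(kMod % i) % kMod;
}

template <std::size_t N>
struct SmallCombTable {
    std::array<std::int64_t, N> fact{};     // fact[i]    = i! mod kMod
    std::array<std::int64_t, N> invFact{};  // invFact[i] = (i!)^-1 mod kMod

    // Runs during constant evaluation when the object is declared constexpr.
    constexpr SmallCombTable() {
        for (std::size_t i = 0; i < N; ++i) {
            fact[i] = i <= 1 ? 1 : fact[i - 1] * static_cast<std::int64_t>(i) % kMod;
            invFact[i] = modInv(fact[i]);
        }
    }

    // Precondition: n < N. Returns 0 when k > n.
    constexpr std::int64_t comb(std::size_t n, std::size_t k) const {
        assert(n < N);
        if (k > n) return 0;
        return fact[n] * invFact[k] % kMod * invFact[n - k] % kMod;
    }
};

// The whole table is computed by the compiler, not at program start.
constexpr SmallCombTable<16> kTable{};
static_assert(kTable.comb(5, 2) == 10, "table is evaluated at compile time");
static_assert(kTable.comb(15, 7) == 6435, "15 choose 7");
```

If either static_assert fails, or the table cannot be evaluated as a constant expression, the build stops, which is a convenient way to confirm that nothing is left for runtime.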
Test Results
I used LeetCode problem 2539, Count the Number of Good Subsequences, with the input vjegrusjmearasjqegqesrevdduqaqvsduguqgqvqdgmeeeuqgeuaumrjsejqmgrqgjudegagmjgqrrsresuurusrduuae
as the benchmark.
Here are the Google Benchmark results. The times do not include initialization time or compile time. See source here.
- The Runtime Only version calculates both the factorial and the modular inverse of a factorial at runtime instead of reading them from a cache, so unsurprisingly it has the worst result.
| Comb function type | Average runtime |
|---|---|
| Compile Time | 0.653 ms |
| Runtime Init with Cache | 0.709 ms |
| Runtime Only | 20.0 ms |
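The Runtime Only code is not shown in this post, so here is a minimal sketch of what such a version might look like, to make the gap in the table easier to picture. It is an assumption about the general shape, not the benchmarked code, and the names kMod, modInv, factorialMod, and combNoCache are made up for this sketch.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical names; kMod matches the 1e9+7 modulus used in this post.
constexpr std::int64_t kMod = 1'000'000'007;

// Modular inverse using the recurrence from this post (kMod must be prime).
std::int64_t modInv(std::int64_t i) {
    return i <= 1 ? i : kMod - (kMod / i) * modInv(kMod % i) % kMod;
}

// Recomputes x! mod kMod from scratch: O(x) multiplications, nothing cached.
std::int64_t factorialMod(std::int64_t x) {
    std::int64_t result = 1;
    for (std::int64_t i = 2; i <= x; ++i) {
        result = result * i % kMod;
    }
    return result;
}

// (n choose k) mod kMod with no precomputed tables. Every call pays for
// three factorials and two modular inverses.
std::int64_t combNoCache(std::int64_t n, std::int64_t k) {
    assert(n >= 0 && n < kMod);
    if (k < 0 || k > n) return 0;
    return factorialMod(n) * modInv(factorialMod(k)) % kMod
           * modInv(factorialMod(n - k)) % kMod;
}
```

Each call here costs O(n) work, compared with three table lookups in the cached versions, which is consistent with the direction of the numbers above.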