【学习笔记】Lucas定理

定义

对于非负整数 n,m 和素数 p,设 n,mp 进制下的展开为:
n=\sum_{i=1}^k n_ip^{k_i}\\m=\sum_{i=1}^k m_ip^{k_i}\\
则有同余式:
C(n,m)\equiv\prod_{i=1}^k C(n_i,m_i)\pmod{p}
这就是Lucas定理。

写成递推式的形式,即为:
\mathrm{lucas}(n,m)=C(n\bmod p,m\bmod p)\cdot\mathrm{lucas}(\lfloor n/p\rfloor,\lfloor m/p\rfloor)
边界条件为:当 m=0 时,\mathrm{lucas}(n,m)=1

应用

Lucas定理适用于求解大数组合数取模的问题。在求解时,根据递推式及边界条件递归即可

总复杂度为 O(f+tg\log n),其中 f 为预处理组合数的复杂度,g 为单次求组合数的复杂度,t 为Lucas定理使用的次数

于是有了下面几种求解方案:

  1. 杨辉三角预处理组合数。此时 f=p^2,g=1,总复杂度 O(p^2+t\log n),适用于 p 不太大且 t 较大的情况
  2. 预处理阶乘及阶乘逆元,公式求组合数。此时 f=p,g=1,总复杂度 O(p+t\log n),适用于 t 较大的情况
  3. 暴力计算组合数。此时 f=0,g=p,总复杂度 O(tp\log n),适用于 t 较小的情况

综上,方案2应该是最普适最常用的方案

Code

Luogu P3807 【模板】卢卡斯定理

预处理阶乘及阶乘逆元方案:

#include <iostream>
#include <stdio.h>
#include <string.h>
#define MAX_P 100005
#define inv(x) fpow(x, p - 2)
#define int long long

using namespace std;

int T, n, m, p;
int fac[MAX_P];
int inf[MAX_P];

int fpow(int x, int k) {
    int ans = 1;
    while (k) {
        if (k & 1) {
            ans = ans * x % p;
        }
        x = x * x % p, k >>= 1;
    }
    return ans;
}

int getc(int n, int m) {
    if (n < m) {
        return 0;
    }
    return fac[n] * inf[m] % p * inf[n - m] % p;
}

int lucas(int n, int m) {
    if (!m) {
        return 1;
    }
    return getc(n % p, m % p) * lucas(n / p, m / p) % p;
}

signed main() {
    scanf("%lld", &T);
    while (T--) {
        scanf("%lld%lld%lld", &n, &m, &p);
        fac[0] = 1;
        for (int i = 1; i < p; i++) {
            fac[i] = fac[i - 1] * i % p;
        }
        inf[p - 1] = inv(fac[p - 1]);
        for (int i = p - 2; i >= 0; i--) {
            inf[i] = inf[i + 1] * (i + 1) % p;
        }
        printf("%lld\n", lucas(n + m, m));
    }
}

暴力计算组合数方案:

#include <iostream>
#include <stdio.h>
#include <string.h>
#define inv(x) fpow(x, p - 2)
#define int long long

using namespace std;

int T, n, m, p;

int fpow(int x, int k) {
    int ans = 1;
    while (k) {
        if (k & 1) {
            ans = ans * x % p;
        }
        x = x * x % p, k >>= 1;
    }
    return ans;
}

int getfac(int n) {
    int ans = 1;
    for (int i = 2; i <= n; i++) {
        ans = ans * i % p;
    }
    return ans;
}

int getc(int n, int m) {
    if (n < m) {
        return 0;
    }
    return getfac(n) * inv(getfac(m) * getfac(n - m) % p) % p;
}

int lucas(int n, int m) {
    if (!m) {
        return 1;
    }
    return getc(n % p, m % p) * lucas(n / p, m / p) % p;
}

signed main() {
    scanf("%lld", &T);
    while (T--) {
        scanf("%lld%lld%lld", &n, &m, &p);
        printf("%lld\n", lucas(n + m, m));
    }
}
点赞

发表评论

电子邮件地址不会被公开。必填项已用 * 标注