【学习笔记】扩展Lucas定理(exLucas)

前言

扩展Lucas定理适用于大数组合数取模,且模数不一定为质数的情况

求解

Step 1

n,m 为非负整数,n\ge mp 为任意正整数。现在要求的是 C_n^m\bmod p

首先将 p 分解为质因子的幂之积:
p=\prod_{i=1}^t p_i^{k_i}
若是能够求出 C_n^m\bmod p_i^{k_i}\quad(i=1\dots t) 的值,则 C_n^m\bmod p 就可通过中国剩余定理(CRT)合并得到

所以在第一步,问题被转化为了:求 C_n^m\bmod p^k,其中 p 为素数,k 为任意正整数

Step 2

由组合数公式:
C_n^m\equiv\frac{n!}{m!(n-m!)}\pmod{p^k}
但由于 m!,(n-m)! 不一定与 p^k 互质,即 m!,(n-m)! 中可能包含因子 p,所以可能不存在 m!,(n-m)! 关于 p^k 的逆元。所以将上式变形,得:
C_n^m\equiv\frac{\frac{n!}{p^x}}{\frac{m!}{p^y}\cdot\frac{(n-m!)}{p^z}}\cdot p^{x-y-z}\pmod{p^k}
其中 \mathrm{gcd}(n!,p^k)=\mathrm{gcd}(m!,p^k)=\mathrm{gcd}((n-m)!,p^k)=1,换句话说,就是对每个阶乘项除以了 p 的幂次,使得该项中不再含有因子 p。此时式中三项均有关于 p^k 的逆元

如果可以求得 \frac{n!}{p^x}\bmod p^k,\frac{m!}{p^y}\bmod p^k,\frac{(n-m)!}{p^z}\bmod p^k 以及 x,y,z,则容易求出 C_n^m\bmod p^k

所以在第二步,问题被转化成了:求 \frac{n!}{p^x}\bmod p^k 以及 x 的值

Step 3

f(n)\equiv\frac{n!}{p^x}\pmod{p^k},其中 x 满足 \mathrm{gcd}(\frac{n!}{p^x},p)=1

现在对 n! 进行变形,将 1\dots np 的倍数分离出来:
\begin{aligned} n!&=(p\cdot 2p\cdots \lfloor\frac{n}{p}\rfloor p)\cdot \prod_{i=1,p\nmid i}^n i\\ &=p^{\lfloor\frac{n}{p}\rfloor}\cdot\lfloor\frac{n}{p}\rfloor!\cdot \prod_{i=1,p\nmid i}^n i \end{aligned}
n! 除去因子 p 时,第一项 p^{\lfloor\frac{n}{p}\rfloor} 一定会被全部除掉,第二项 \lfloor\frac{n}{p}\rfloor! 中可能包含因子 p,而第三项累乘项 \prod_{i=1,p\nmid i}^n i 中显然不可能包含因子 p。这样就得到了 f(n) 的递推式:
f(n)\equiv f(\lfloor\frac{n}{p}\rfloor)\cdot\prod_{i=1,p\nmid i}^n i\pmod{p^k}
边界条件为:当 n=0 时,f(n)=1

若将递推式中累乘项展开,则在模 p^k 意义下,存在长度为 p^k 的循环节。所以该项可写成「循环节的幂次乘余项」的形式:
\prod_{i=1,p\nmid i}^n i\equiv(\prod_{i=1,p\nmid i}^{p^k}i)^{\lfloor\frac{n}{p^k}\rfloor}\cdot\prod_{i=1,p\nmid i}^{n\bmod p^k}i\pmod{p^k}
循环节和余项都能在 O(p^k) 的时间内求出,循环节的幂次使用快速幂也可在 O(\log n) 时间内求出。所以 f(n) 每次递推的复杂度为 O(p^k)。由于递归次数为 \log n,所以求解 f(n) 的复杂度为 O(p^k\log n)

所以在第三步,我们求出了 \frac{n!}{p^x}\bmod p^k 的值

Step 4

现在考虑如何求对应的 x

g(n)=x,其中 x 满足 \mathrm{gcd}(\frac{n!}{p^x},p)=1

n! 的变形式,可以类似地得到 g(n) 的递推式:
g(n)=\lfloor\frac{n}{p}\rfloor+g(\lfloor\frac{n}{p}\rfloor)
求解 g(n) 的复杂度为 O(\log n)

所以在第四步,求出来了对应的 x

至此,原问题已得到解决

算法流程

第一步:对 p 分解质因数

第二步:对于每一对 p_i^{k_i},求出 C_n^m\equiv \frac{f(n)}{f(m)f(n-m)}\cdot p^{g(n)-g(m)-g(n-m)}\pmod{p_i^{k_i}}

第三步:用中国剩余定理合并答案

总复杂度 O(p\log n)

Code

Luogu P4720 【模板】扩展卢卡斯

#include <iostream>
#include <stdio.h>
#include <string.h>
#define MAX_L 25
#define int long long

using namespace std;

int n, m, p, tot = 0;
int fac[MAX_L];
int pk[MAX_L];
int ans[MAX_L];

void getfac(int x) {
    int t = x;
    for (int i = 2; i * i <= x; i++) {
        if (!(t % i)) {
            fac[++tot] = i, pk[tot] = 1;
            while (!(t % i)) {
                t /= i, pk[tot] *= i;
            }
        }
    }
    if (t != 1) {
        fac[++tot] = t, pk[tot] = t;
    }
}

void exgcd(int a, int b, int &x, int &y) {
    if (!b) {
        x = 1, y = 0;
        return;
    }
    exgcd(b, a % b, y, x);
    y -= a / b * x;
}

int inv(int a, int p) {
    int x, y;
    exgcd(a, p, x, y);
    return (x % p + p) % p;
}

int fpow(int x, int k, int p) {
    int ans = 1;
    while (k) {
        if (k & 1) {
            ans = ans * x % p;
        }
        x = x * x % p, k >>= 1;
    }
    return ans;
}

int getf(int n, int p, int pk) {
    if (!n) {
        return 1;
    }
    int rep = 1, rem = 1;
    for (int i = 1; i <= pk; i++) {
        if (i % p) {
            rep = rep * i % pk;
        }
    }
    for (int i = 1, t = n % pk; i <= t; i++) {
        if (i % p) {
            rem = rem * i % pk;
        }
    }
    int prod = fpow(rep, n / pk, pk) * rem % pk;
    return getf(n / p, p, pk) * prod % pk;
}

int getg(int n, int p) {
    if (n < p) {
        return 0;
    }
    return getg(n / p, p) + n / p;
}

int solve(int n, int m, int p, int pk) {
    int fn = getf(n, p, pk), fm = getf(m, p, pk), fnm = getf(n - m, p, pk);
    int x = getg(n, p), y = getg(m, p), z = getg(n - m, p);
    int ans = fn * inv(fm, pk) % pk * inv(fnm, pk) % pk * fpow(p, x - y - z, pk) % pk;
    return ans;
}

int crt(int *a, int *b, int n) {
    int x, y, A = 1, ans = 0;
    for (int i = 1; i <= n; i++) {
        A *= a[i];
    }
    for (int i = 1; i <= n; i++) {
        int m = A / a[i];
        exgcd(a[i], m, x, y);
        ans = (ans + m * y % A * b[i]) % A;
    }
    return (ans % A + A) % A;
}

int exlucas(int n, int m, int p) {
    getfac(p);
    for (int i = 1; i <= tot; i++) {
        ans[i] = solve(n, m, fac[i], pk[i]);
    }
    return crt(pk, ans, tot);
}

signed main() {
    scanf("%lld%lld%lld", &n, &m, &p);
    printf("%lld\n", exlucas(n, m, p));
}
点赞

发表评论

电子邮件地址不会被公开。必填项已用 * 标注