【学习笔记】逆元

定义

ax\equiv 1\pmod{p},则称 xa 关于 p 的逆元,记作 a^{-1}

逆元存在的充要条件:a,p 互质,即 \mathrm{gcd}(a,p)=1

求解

exgcd求逆元

由于同余方程 ax\equiv 1\pmod{p} 与不定方程 ax+py=1 等价,因此可用exgcd求出不定方程的一组特解 x,y,则 x 的最小非负整数值即为所求逆元

单次复杂度 O(\log n)

费马小定理求逆元

费马小定理:若 a 是一个整数,p 是一个质数,则 a^p \equiv a\pmod{p}

如果 a 不是 p 的倍数,则费马小定理也可写为:a^{p-1}\equiv 1\pmod{p}

所以有 a\cdot a^{p-2}\equiv 1\pmod{p},显然 a^{p-2}a 的一个逆元,用快速幂即可求得

单次复杂度 O(\log n)

需要注意的是,费马小定理求逆元只适用于 p 为质数的情况

线性递推求逆元

p/i=k\dots r,即 p 整除 i 得到的商和余数分别为 k,r

显然有 ki+r=p,改写成同余方程即为 ki+r\equiv 0\pmod{p}

给同余方程两边同乘 i^{-1}r^{-1},得 kr^{-1}+i^{-1}\equiv 0\pmod{p}

所以 i^{-1}\equiv -kr^{-1}\equiv -\lfloor p/i\rfloor\cdot (p\%i)^{-1}

这样就得到了逆元的递推式,可在 O(n) 时间内求解1到n的逆元

线性求阶乘逆元

若要求出 1!\dots n! 的逆元,首先在 O(\log n) 时间内求出 n! 的逆元,然后由递推式 (n!)^{-1}\equiv(n+1)\cdot((n+1)!)^{-1}\pmod{p} 从后往前递推即可。

总复杂度 O(n)

Code

Luogu P3811 【模板】乘法逆元

exgcd求逆元:

#include <iostream>
#include <stdio.h>
#include <string.h>
#define int long long

using namespace std;

int n, p;

void exgcd(int a, int b, int &x, int &y) {
    if (!b) {
        x = 1, y = 0;
        return;
    }
    exgcd(b, a % b, y, x);
    y -= a / b * x;
}

int inv(int a) {
    int x, y;
    exgcd(a, p, x, y);
    return (x % p + p) % p;
}

signed main() {
    scanf("%lld%lld", &n, &p);
    for (int i = 1; i <= n; i++) {
        printf("%lld\n", inv(i));
    }
}

费马小定理求逆元:

#include <iostream>
#include <stdio.h>
#include <string.h>
#define int long long

using namespace std;

int n, p;

int fpow(int x, int k) {
    int ans = 1;
    while (k) {
        if (k & 1) {
            ans = ans * x % p;
        }
        x = x * x % p, k >>= 1;
    }
    return ans;
}

signed main() {
    scanf("%lld%lld", &n, &p);
    for (int i = 1; i <= n; i++) {
        printf("%lld\n", fpow(i, p - 2));
    }
}

线性递推求逆元:

#include <iostream>
#include <stdio.h>
#include <string.h>
#define MAX_N 3000005
#define int long long

using namespace std;

int n, p;
int inv[MAX_N];

signed main() {
    scanf("%lld%lld", &n, &p);
    inv[1] = 1;
    for (int i = 2; i <= n; i++) {
        inv[i] = (p - p / i) * inv[p % i] % p;
    }
    for (int i = 1; i <= n; i++) {
        printf("%lld\n", inv[i]);
    }
}
点赞

发表评论

电子邮件地址不会被公开。必填项已用 * 标注