给定正整数\(a, b, p\),计算\(a^b\ mod\ p\)是RSA的加密和解密阶段均要使用的算法。 同时,在算法竞赛题目中,经常出现由于指数运算和排列组合运算的增长速度太快,一般的整形变量不足以表达结果的情况。此时,题目一般会要求将结果对一个特定的数取模。如果没有掌握相应的算法,在面对这种情况时,我们就容易遇到计算速度过慢导致超时,或者数值过大导致溢出的问题。

本文介绍高效计算\(a^b\ mod\ p\)和\(C(n, k)\ mod\ p\)并避免溢出的方法,希望能为读者带来帮助。

为方便起见,本文使用基于模运算符号$$mod$$的写法代替模同余的写法。例如,我们使用\(17\ mod\ 4 = 1\)来表示\(17 \equiv 1 (mod 4)\)

1. 快速幂算法

快速幂算法可以在\(O(log\ b)\)的时间复杂度内完成\(a^b\)的计算,相比较而言,直接计算则需要进行\(b\)次乘法。

快速幂算法利用了幂运算的法则:\(a^{b+c} = a^b * a^c\)。假设我们要计算\(3^{16}\),注意到\(3^{16} = 3^{8 + 8} = 3^8 * 3^8\)。也就是说,我们只要算出\(3^{(16 // 2)}\),就可以使用1次额外的乘法算出\(3^{16}\),额外时间复杂度仅为\(O(1)\)。 当指数为奇数时,这个规律依然成立。假设我们要计算计算\(3^{17}\),注意到\(3^{17} = 3^{8+8+1} = 3^8 * 3^8 * 3\),在算出\(3^{(17 // 2)}\)之后,我们只需要使用2次额外的乘法即可算出\(3^{17}\),额外时间复杂度也是\(O(1)\)。

从上述两个例子可见,我们可以将计算\(a^b\)的问题,在\(O(1)\)的时间内转换成计算\(a^{(b//2)}\)的问题,这也就意味着,我们可以在\(O(log\ b)\)的时间复杂度内将问题转换成常数时间复杂度的问题。这就是快速幂算法。

了解了快速幂算法的原理之后,我们可以用递归的方法实现快速幂算法。递归版本的python实现如下:

def quick_pow(a: int, b: int) -> int:
    """使用O(log b)时间复杂度计算a^b的算法(递归版本)。"""
    if b == 0:
        return 1
    
    h = quick_pow(a, b // 2)
    if b % 2 == 1:
        return h * h * a
    
    return h * h

我们也可以从另一个角度看待快速幂计算。设想我们要计算\(a^{13}\),\(13\)的二进制表示是\(1101\),即\(13=2^0*1 +2^1*0+2^2*1+2^3*1\),,因此,\(a^{13}=a^{2^0*1}*a^{2^1*0}*a^{2^2*1}*a^{2^3*1}= (a)^1*(a^2)^0*((a^2)^2)^1*(((a^2)^2)^2)^1\)。容易注意到,从第二项开始,每一项的底数是上一项的平方,而指数则与\(13\)的相应二进制位相同。 因此在一般情况下,要计算\(a^b\),我们可以将结果初始化为\(1\),乘数初始化为\(a\)。然后,我们从低到高遍历\(b\)的每一个二进制位,如果当前位的值是\(1\),则将当前的乘数乘到结果上。然后不论当前位的值是\(0\)还是\(1\),都将乘数更新为它的平方。

基于这个角度,我们可以实现循环版本的快速幂算法。python实现如下:

def quick_pow(a: int, b: int) -> int:
    """使用O(log b)时间复杂度计算a^b的算法(循环版本)。"""
    result = 1
    while b > 0:
        if b % 2 == 1:
            result *= a
        a *= a
        b //= 2
        
    return result

2. 快速幂取模算法

RSA需要将幂运算的结果对特定数字取模。同时,幂运算的增长速度很快,在算法竞赛中经常遇到指数运算的结果是一个天文数字的情况,此时,题目一般也会要求将结果对一个特定的数字取模。

基于快速幂算法,快速幂取模算法可以在\(O(log\ b)\)的时间复杂度内完成\(a^b\ mod\ p\)的计算。而且,当\(p\)较小时,无论\(a\)和\(b\)的值多大,都不会发生溢出。

快速幂取模算法的基本原理是取模运算的两个性质:\(a^b\ mod\ p = (a\ mod\ p)^b\ mod\ p\)以及\((a * b)\ mod\ p = (a\ mod\ p * b\ mod\ p)\ mod\ p\) 。我们依次证明这两个性质。

  1. 令\(m=a\ mod\ p\),此时有\(a=kp+m\)。\(a^b = (kp + m)^b\)。根据二项式定理,\((kp + m)^b\)的展开式中除了最后一项\(m^b\)中\(p\)的次数为\(0\)外,其它的项均为\(p\)的倍数,因此,\(a^b\ mod\ p = (kp + m)^b\ mod\ p=m^b\ mod\ p\),得证。
  2. 令\(m_1=a\ mod\ p\),\(m_2=b\ mod\ p\),此时有\(a=k_1p+m_1, b=k_2p+m_2\)。\(a*b = k_1k_2p^2+k_1m_2p+k_2m_1p+m_1m_2\)。除\(m_1m_2\)外,其余各项均为\(p\)的倍数。因此,\((a*b)\ mod\ p=m_1m_2\ mod\ p\),得证。

基于这两个原理,我们在快速幂算法的基础上,首先对输入中的底数取模,并在每次乘法运算时,先对被乘数和乘数取模,并将结果相乘,再对结果取模。这就是快速幂取模算法。

快速幂取模算法的python实现如下:

def quick_pow_mod(a: int, b: int, p: int) -> int:
    """快速幂取模算法,计算a^b % p"""
    a = a % p
    result = 1
    while b > 0:
        if b % 2 == 1:
            result = (result * a) % p
        a = (a * a) % p
        b //= 2

    return result

3. 组合数取模算法

在算法竞赛中,除了幂运算之外,组合数运算也可能产生较大的天文数字,需要对结果进行取模防止溢出。下面介绍:计算组合数\(C(n,k)\)对符合一定条件的正整数\(p\)取模的方法。

3.1 组合数取模的基础算法

我们知道,\(C(n,k)=\frac{n!}{(n-k)!*k!}=\frac{n * (n - 1) * \cdots * (n - k + 1)}{1 * 2 * \cdots * k}\)。如果我们直接计算分子和分母的值,不仅速度很慢,而且当\(n\)和\(k\)较大时会发生溢出,一般的整型无法表示出分子和分母的准确数值。注意到分子和分母内部只包含乘法运算,我们可以使用模运算的乘法性质,分别求出分子和分母模\(p\)的余数。 此时,问题就被转换为:有两个数\(a\)和\(b\),我们不知道\(a\)和\(b\)的具体数值,只知道它们各自模另一个整数\(p\)的余数,且\(a\)是\(b\)的倍数。求\(a/b\)模\(p\)的余数。

当\(p\)和\(b\)不互质时,仅凭\(a\ mod\ p\)和\(b\ mod\ p\),无法唯一确定\((a/b)\ mod\ p\)的值,一个反例:\(a=14, b=2, p=6\)和\(a=8, b=2, p=6\)。而当\(p\)是质数,且\(p\)和\(a, b\)均互质时,这个问题是有高效的求解方法的。这也是算法竞赛中普遍选择大质数\(10^9+7\)作为\(p\)的值的原因。

注意,在下面的讨论中,我们默认\(p\)是质数,且\(p\)和\(a, b\)互质。由于组合数的分子\(a\)一定是分母\(b\)的倍数,不妨设\(a=kb,即a/b=k\)。

设想:如果我们能找到一个数\(b'\),使得\(b * b'\ mod\ p =1\),则有:

\[\begin{aligned} (a * b')\ mod\ p &= (k * b * b')\ mod\ p \\ &=(k\ mod\ p)*(b*b'\ mod\ p)\ mod\ p \\ &= (k\ mod\ p * 1)\ mod\ p \\ &= k\ mod\ p \\ &= (a/b)\ mod\ p \end{aligned}\]

也就是说,如果我们能找到这样的\(b'\),我们就能把模运算的除法问题转换为模运算的乘法问题!满足\(b * b'\ mod\ p=1\)的\(b'\)称为\(b\)的模\(p\)逆元。模逆元存在的充要条件是\(b\)和\(p\)互质,如果\(b\)和\(p\)的最大公约数不是1,则不论\(k\)取什么整数,\(kb\ mod\ p\)永远是最大公约数的倍数,不可能为1。

费马小定理可以帮助我们找到这样的\(b'\)。费马小定理指出:如果\(b\)是一个整数,\(p\)是一个质数,且\(b\)不是\(p\)的倍数,则有\(b^{p-1}\equiv p(mod\ 1)\)。例如5是一个整数,3是一个质数,5不是3的倍数,则\(5^{3-1}=25\)对3取模的结果是1。

由于\(b\)和\(b^{p-2}\)的乘积正好是\(b^{p-1}\),而\(b^{p-1}\ mod\ p=1\),这也就意味着,\(b\)的一个模\(p\)逆元就是\(b^{p-2}\)!由模运算的乘法性质可知,如果\(b^{p-2}\)是\(b\)的一个模\(p\)逆元,那\(b^{p-2}\ mod\ p\)也是\(b\)的一个模\(p\)逆元。尽管\(b\)和\(p\)的值可能很大,但我们可以使用快速幂取模算法,高效求出\(b^{p-2}\ mod\ p\)的值。

得到了\(b\)的模\(p\)逆元\(b'\),我们就可以将\(a / b\ mod\ p\)转换为\(a * b'\ mod\ p\),并利用模运算的乘法性质,在不发生溢出的情况下求出它的值了。

3.2 组合数取模的高效算法:Lucas定理

Lucas定理指出:

\[C(n, k)\ mod\ p=C(n // p, k // p)\ * C(n\ mod\ p, k\ mod\ p)\ mod\ p\]

当\(n\)和\(k\)很大时,\(n\ mod\ p\)和\(k\ mod\ p\)都小于\(p\),而\(C(n // p, k // p)\)可以用Lucas定理进一步化简。因此,当问题规模很大时,Lucas定理可以帮助我们高效地化简问题。