KMP

2025-05-18
2482字

KMP算法的切入点很简单,是一个叫做最长相等真前后缀的东西。有点拗口,但是理解起来很简单,以字符串ababa为例:

  • 前缀:字符串ababa的前缀有aababaababababa。真前缀就是不等于自身的那些前缀。
  • 后缀:字符串ababa的后缀有abaabababaababa。真后缀就是不等于自身的那些后缀。
  • 相等真前后缀:真前缀中和真后缀中相等的那些:aaba
  • 最长相等真前后缀:相等真前后缀中最长的那一个:aba

最长相等真前后缀可以用于优化字符串比较。

接下我们从最原始的朴素暴力解开始一步步探索最长相等真前后缀的作用。

朴素字符串匹配算法中,第一步就是把模式串p和主串s左对齐,然后依次比较对应的字符,直到整个模式串都匹配完(匹配到了对应字符串)或出现了不相等的字符。

上图中的红色字位置,模式串和主串的字符不同。这意味着主串从0开始的这个位置,不可能和模式串匹配了。下一步当然就是把模式串后移一格,然后再依次比较对应的字符。如下图:

后移一格之后,我们马上发现模式串第一个字符就和主串对应的位置不同。然后再后移一格:

还是第一个字符就不同。

这似乎走了太多弯路。

其实仔细思考一下,只有下图中紫色的部分对应相等的时候,模式串右移一格才是有意义的,否则这次匹配注定会失败:

所以我们通过瞪眼法很快就可以确定,直接把模式串移动到下图中的位置就行了,这个位置之前的位置一定是徒劳的比较。

仔细观察一下,主串中紫色的ab是字符串abcab的真后缀,模式串中紫色的ab是字符串abcab的真前缀。

这不就是字符串abcab的最长相等真前后缀吗?

所以我们可以得出结论:当字符匹配到s[i]p[j]发生失配时,可以直接将模式串右移到模式串的p[new_j]和主串s[i]对齐的位置。其中new_j是字符串p[0:j](左闭右开,不包括p[j],下同)的最长相等真前后缀的长度

如果能确定模式串中每个p[0:j]的最长相等真前后缀,那么后续比较过程的时间复杂度就是$O(n)$($n$为主串的长度)。因为对于主串的指针i是不会回退的,对于任意一个位置i,要么和模式串对应字符一致,跳转到i = i + 1,要么模式串通过上述结论在常数时间内向右移动出界了(出界之后当然也是要i = i + 1的)。

如果使用暴力方法,计算p[0:j]最长相等真前后缀的时间复杂度是$O(j^2)$,计算出整个模式串每个位置的最长相等真前后缀的时间复杂度就是$O(m^3)$($m$是模式串的长度)。显然这是不可接受的。

幸运的是,我们可以通过动态规划将时间复杂度优化到$O(m)$。

以字符串abacababac为例,我们可以通过瞪眼法得到下面的结果:

字符 a b a c a b a b a c
index 0 1 2 3 4 5 6 7 8 9
PMT 0 0 1 0 1 2

其中PMT(Partial Match Table)表示从字符串开始到当前字符(包含当前字符)形成的子字符串的最长相等真前后缀的长度。例如前三个字符形成的子字符串aba的最长相等真前后缀是1。

在计算PMT[6]时,可以根据PMT[5]来推导:

PMT[5]是2,意味着s[0] == s[4]s[1] == s[5]。所以直接比较s[2]s[6]是否相等,如果s[2] == s[6],说明最长相等真前后缀在两个紫色的基础上增加了一个,因此直接令PMT[6] = PMT[5] + 1

这个逻辑比较好理解。

接下来计算PMT[7]PMT[6]是3,所以比较s[7]s[3]。坏了,它俩不相等。接下来怎么办呢?退回到暴力方法?那倒也不用。看下面这个图(PMT[6] == 3意味着图中浅黄色和浅绿色是相等的):

这个图和图1简直一模一样。此时s[7] != p[3],接下来当然是把“模式串”往右移。移多少呢?根据前面的结论,移到p[new_j]p[PMT[3 - 1]]p[1]s[7]对应。注意:这里用到的PMT一定是前序过程计算过的。

如果“模式串”右移出界都没有找到相等真前后缀,那么这个位置的PMT值就是0。

根据上面的算法,很容易补全PMT:

字符 a b a c a b a b a c
index 0 1 2 3 4 5 6 7 8 9
PMT 0 0 1 0 1 2 3 2 3 4

虽然PMT也能用,但是每次s[i] != p[j]时,都要去找PMT[j - 1],不仅麻烦,而且还要处理j - 1可能的下标越界问题。如果把PMT集体右移得到next数组就更好用了:

字符 a b a c a b a b a c
index 0 1 2 3 4 5 6 7 8 9
PMT 0 0 1 0 1 2 3 2 3 4
next -1 0 0 1 0 1 2 3 2 3

前面补一个-1倒是无所谓,但是最后一个PMT移出去了怎么呢?不办。因为它根本就用不上。

有了next数组,就可以轻松实现KMP算法了:

def kmp(s: str, p: str) -> int:
    if len(s) < len(p):
        return -1

    nxt = get_next_arr(p)

    n, m = len(s), len(p)
    i = j = 0

    while i < n and j < m:
        if s[i] == p[j]:
            i += 1
            j += 1
        else:
            if j == 0:
                i += 1
            else:
                j = nxt[j]

    return i - m if j == m else -1

get_next_arr怎么实现呢?

我们先根据计算PMT的步骤写出计算PMT的算法:

def get_pmt(p: str) -> list[int]:
    pmt = [0] * len(p)

    for i in range(1, len(p)):
        k = pmt[i - 1]
        while p[i] != p[k]:
            if k == 0:
                k = -1
                break
            else:
                k = pmt[k - 1]
        pmt[i] = k + 1

    return pmt

这个算法完全按照推导步骤,在计算pmt[j]的时候用到了pmt[j - 1],但是这种写法不利于迁移到计算next中,可以采用通过pmt[j]推导pmt[j + 1]的方式重写这个算法:

def get_pmt_forward(p: str) -> list[int]:
    pmt = [0] * len(p)

    for i in range(len(p) - 1):
        k = pmt[i]
        while p[i + 1] != p[k]:
            if k == 0:
                k = -1
                break
            else:
                k = pmt[k - 1]
        pmt[i + 1] = k + 1

    return pmt

如果仔细观察代码,在循环末尾pmt[i + 1] = k + 1,在下一次循环开始又k = pmt[i],也就是k = k + 1,有点蠢,优化掉:

def get_pmt_forward_rev(p: str) -> list[int]:
    pmt = [0] * len(p)
    k = 0

    for i in range(len(p) - 1):
        while p[i + 1] != p[k]:
            if k == 0:
                k = -1
                break
            else:
                k = pmt[k - 1]
        k += 1
        pmt[i + 1] = k

    return pmt

从这个算法就很容易推导出计算next数组的算法了:

def get_next_arr(p: str) -> list[int]:
    nxt = [-1] * len(p)
    k = -1

    for i in range(len(p) - 1):
        while k != -1:
            if p[i] == p[k]:
                break
            else:
                k = nxt[k]
        k += 1
        nxt[i + 1] = k

    return nxt

除了调整循环的终止条件(防止下标越界)和k = pmt[k - 1]变为k = nxt[k](PMT和next的关系决定的),其余部分完全一致。

你也可以对代码稍作修改,变成下面这个最普遍的样子:

def get_next_arr_rev(p: str) -> list[int]:
    nxt = [-1] * len(p)
    k = -1
    i, n_1 = 0, len(p) - 1

    while i < n_1:
        if k == -1 or p[i] == p[k]:
            k += 1  # 因为nxt[0] == -1,k += 1可以推广到整个数组
            i += 1
            nxt[i] = k
        else:
            k = nxt[k]

    return nxt