区间动态规划例题:统计不同非空回文子序列个数问题
题目描述
给定一个字符串 S,找出 S 中不同的非空回文子序列的个数。结果可能很大,请返回它对 10^9 + 7 取模后的结果。
子序列是从原字符串中删除0个或多个字符后,不改变剩余字符相对顺序形成的新字符串。如果两个子序列对应的字符串不同,就认为它们是不同的。
示例
输入:S = "bccb"
输出:6
解释:6个不同的非空回文子序列为:'b', 'c', 'bb', 'cc', 'bcb', 'bccb'。
解题思路
这是一个典型的区间动态规划问题。我们需要统计一个字符串的所有子序列中,哪些是回文且互不相同。直接枚举所有子序列会超时,因此需要使用动态规划来高效地计数。
-
定义状态
我们定义dp[i][j]为字符串 S 在区间[i, j]内的不同回文子序列的个数。 -
状态转移方程
考虑如何由更小的区间状态推导出dp[i][j]。- 基本情况:当
i == j时,区间只有一个字符,它本身就是一个回文子序列。所以dp[i][j] = 1。 - 一般情况:当
i < j时,我们考虑区间[i, j]的回文子序列。它们可以分为两大类:
a. 不包含 S[i] 和 S[j] 的子序列:这些子序列完全包含在区间[i+1, j-1]内。所以这部分的数量就是dp[i+1][j-1]。
b. 包含 S[i] 或 S[j] 或两者都包含的子序列:这是比较复杂的部分,我们需要仔细处理以避免重复计数,并确保是回文。
更精确的推导如下:
我们考虑四种情况,基于 S[i] 和 S[j] 是否相等:-
情况 1: S[i] != S[j]
此时,区间[i, j]的回文子序列可以由三部分不重叠的集合组成:- 在区间
[i+1, j]内的回文子序列(这些子序列不包含 S[i])。 - 在区间
[i, j-1]内的回文子序列(这些子序列不包含 S[j])。 - 在区间
[i+1, j-1]内的回文子序列(这些子序列既不包含 S[i] 也不包含 S[j]),但这一部分被前两部分重复计算了(因为[i+1, j-1]既包含在[i+1, j]中也包含在[i, j-1]中)。
根据容斥原理,我们有:
dp[i][j] = dp[i+1][j] + dp[i][j-1] - dp[i+1][j-1]
- 在区间
-
情况 2: S[i] == S[j]
这种情况要复杂一些。设c = S[i] = S[j]。
此时,区间[i, j]的回文子序列除了包含情况1中的三部分,还有一类新的回文子序列:它们以字符c开头和结尾。
具体来说,我们可以这样构造:- 首先,区间
[i+1, j-1]内的任何回文子序列,我们都可以在它的左右两边各加上一个字符c,形成一个新的、更长的回文子序列(例如,如果[i+1, j-1]内有 "aa",那么可以形成 "caac")。 - 此外,单个字符
c本身(即子序列 "c")也是一个回文子序列。注意,这个 "c" 可能已经在[i+1, j-1]中存在了,但当我们用两边的c包裹一个空字符串时,实际上得到的就是 "cc"。那么 "c" 本身从哪里来呢?我们需要单独考虑。
一个更清晰的思路是:
令L为i右边第一个字符等于c的位置(如果存在),R为j左边第一个字符等于c的位置(如果存在)。我们需要根据L和R的相对位置来避免重复计数。
但一个更简洁且正确的递推式是:
当S[i] == S[j]时,dp[i][j] = dp[i+1][j-1] * 2 + 2
这个公式的解释是:
dp[i+1][j-1]:这是中间部分的所有回文子序列。对于中间的每一个回文子序列,我们都可以选择 不加 两边的c(得到原来的子序列)或者 加 两边的c(得到一个新的子序列)。所以是乘以 2。+ 2:这对应两种特殊情况:- 加两边的
c到空序列上,得到 "cc"(注意,空序列本身不是非空子序列,所以这里我们实际上是在创造新的序列)。 - 只取一个
c,即子序列 "c"。
但是,这里有一个陷阱:如果中间部分[i+1, j-1]中已经包含了字符c,那么子序列 "c" 可能已经被计算在dp[i+1][j-1]中了。如果我们直接*2 + 2,会导致 "c" 被重复计算。
- 加两边的
因此,正确的做法是找到区间
[i+1, j-1]内最左边和最右边的字符c的位置,记为low和high。- 如果
low > high,说明[i+1, j-1]中没有字符c。那么:
dp[i][j] = dp[i+1][j-1] * 2 + 2
+2对应的是新创建的两个序列:"c" 和 "cc"。 - 如果
low == high,说明[i+1, j-1]中只有一个字符c。那么:
dp[i][j] = dp[i+1][j-1] * 2 + 1
+1对应的是新创建的序列 "cc"。因为序列 "c" 已经存在于dp[i+1][j-1]中了,所以不能重复加。 - 如果
low < high,说明[i+1, j-1]中有至少两个字符c。那么区间[low+1, high-1]的回文子序列已经被dp[low][high]计算过,并且它们在dp[i+1][j-1]中被重复计算了(因为当我们用两边的c包裹[i+1, j-1]时,包裹[low+1, high-1]和直接包裹[i+1, j-1]会产生重复)。所以需要减去这部分重复:
dp[i][j] = dp[i+1][j-1] * 2 - dp[low+1][high-1]
这里不需要+2或+1,因为单个 "c" 和 "cc" 都已经通过包裹更小的区间的方式被正确地创建或排除了。
- 首先,区间
- 基本情况:当
-
初始化
对于所有i > j的情况,区间是无效的,我们设dp[i][j] = 0。
对于i == j的情况,dp[i][j] = 1。 -
计算顺序
由于大区间[i, j]依赖于小区间[i+1, j],[i, j-1],[i+1, j-1],所以我们需要按照区间长度len从小到大的顺序来遍历和计算dp数组。即先计算所有长度为1的区间,然后是长度为2的区间,以此类推,直到长度为n。 -
最终结果
最终我们要求的是整个字符串S的不同回文子序列个数,即dp[0][n-1]。
代码实现(思路伪代码)
- 初始化一个二维数组
dp,大小为n x n,初始化为0。 - 对于所有
i,设置dp[i][i] = 1。 - 对于长度
len从 2 到n:- 对于起点
i从 0 到n - len:- 设置
j = i + len - 1。 - 如果
S[i] == S[j]:- 找到区间
[i+1, j-1]内最左边的S[i]的位置low。 - 找到区间
[i+1, j-1]内最右边的S[i]的位置high。 - 如果
low > high(没有找到):dp[i][j] = dp[i+1][j-1] * 2 + 2
- 否则如果
low == high(找到一个):dp[i][j] = dp[i+1][j-1] * 2 + 1
- 否则(找到多个):
dp[i][j] = dp[i+1][j-1] * 2 - dp[low+1][high-1]
- 找到区间
- 否则(
S[i] != S[j]):dp[i][j] = dp[i+1][j] + dp[i][j-1] - dp[i+1][j-1]
- 对
dp[i][j]取模(防止溢出)。
- 设置
- 对于起点
- 返回
dp[0][n-1]。
这个算法的时间复杂度是 O(n²),空间复杂度也是 O(n²)。通过精细地处理字符相等时的情况,避免了重复计数,确保了结果的正确性。