最长等差数列(进阶版——统计不同最长等差数列的个数)
一、问题描述
给定一个整数数组 nums,找到数组中最长等差数列的长度,并统计不同最长等差数列的个数。
这里的不同等差数列指的是公差不同或者起始元素位置不同的数列。
注意:等差数列至少包含三个元素,公差可以是 0 或负数,并且数组中的元素顺序必须保持原顺序(子序列,不要求连续)。
示例
输入:nums = [2, 4, 6, 8, 10]
输出:长度 = 5, 个数 = 1(只有公差为 2 的最长等差数列)
输入:nums = [7, 7, 7, 7, 7]
输出:长度 = 5, 个数 = 1(只有公差为 0 的最长等差数列)
输入:nums = [2, 2, 3, 4, 5, 6]
输出:长度 = 5, 个数 = 2
解释:
- 公差 1 的最长等差数列:
[2, 3, 4, 5, 6] - 公差 0 的最长等差数列:
[2, 2, 2, 2, 2](注意:从不同位置的 2 可以组成多个相同公差的数列,但这里统计的是“不同的等差数列”,我们会在后面明确定义)
二、问题理解与目标
- 等差数列由三个参数决定:起始位置、公差、长度。
- 我们要找的是最长的长度,并统计所有能达到这个最长长度的不同等差数列个数。
- 这里“不同”等差数列定义为:公差不同 或 起始索引位置不同 的等差数列。
注意:如果公差相同,起始索引位置不同,即使序列元素值相同,也算不同数列(因为是原数组的不同子序列)。
三、基本思路:动态规划定义
我们用一个常见的 DP 思路来求最长等差数列长度:
定义 dp[i][d] 表示以索引 i 结尾,公差为 d 的等差数列的最大长度。
但公差可能是负数,且范围可能很大,所以通常用哈希表数组来存储。
更常用的状态定义是:
- 设
dp[i][j]表示以索引i和j作为最后两个元素的等差数列的长度(其中i < j)。
如果找到更早的k使得nums[k] + nums[j] = 2 * nums[i]吗?不,这样不方便。
正确的推导是:
对于j和i(i < j),看前面是否存在k使得nums[k] + nums[j] = 2 * nums[i]?这不对,应该是nums[i] - nums[k] = nums[j] - nums[i]。
所以我们可以用公差来关联。
实际更简单的定义是:
dp[i][diff] 表示以 i 结尾,公差为 diff 的等差数列的最大长度。
但是这样我们还需要知道倒数第二个元素是谁,才能更新。
因此更好的方法是:
dp[i][j] 表示最后两项是 (i, j) 的等差数列的最大长度(i < j)。
那么公差 diff = nums[j] - nums[i]。
转移时,我们找前面一个索引 k 使得 nums[k] + nums[j] = 2 * nums[i] 吗?不,这样是错的。
正确的是:
对于 i 和 j,要找 k 使得 nums[k] + nums[j] = 2 * nums[i] 吗?不,这会把关系搞反。
正确的关系是:
对于固定的 i, j,公差 diff = nums[j] - nums[i],我们想找前一个元素 nums[k] 满足 nums[i] - nums[k] = diff,即 nums[k] = nums[i] - diff。
所以我们需要知道哪个索引 k 在 i 前面,且值等于 nums[i] - diff。
为了快速查找,我们可以用一个哈希表 pos[value] 记录每个值出现的最后一个索引,但因为可能有重复值,我们需要存所有索引位置,为了查找方便,可以在 DP 过程中动态记录。
但更常见的标准最长等差数列长度解法是:
dp[i][j] 表示最后两项索引是 (i, j) 的等差数列的长度(至少为 2),则:
dp[i][j] = dp[prev][i] + 1,其中 prev 是满足 nums[prev] + nums[j] = 2 * nums[i] 的索引,不对,这是错的。
仔细想:设等差数列 ... x, y, z,则 y - x = z - y,即 x + z = 2y。
如果我们已知最后两个数是 nums[i] 和 nums[j],则前一个数应该是 2*nums[i] - nums[j],我们去找这个值在 i 之前出现的索引 k。
所以状态定义:
dp[i][j] 表示以索引 i, j 为最后两项的等差数列的长度(长度至少为 2),则:
dp[i][j] = dp[k][i] + 1,其中 nums[k] = 2*nums[i] - nums[j] 且 k < i。
如果找不到 k,则 dp[i][j] = 2。
这样我们能计算出所有以 i, j 结尾的等差数列长度。
四、增加计数
我们不仅要长度,还要统计个数。
设:
len[i][j]表示以(i, j)结尾的等差数列的最大长度(和上面dp[i][j]一样)。cnt[i][j]表示以(i, j)结尾的、长度为len[i][j]的等差数列的个数。
但注意:同一个 (i, j) 结尾,可能有多个不同的前驱序列得到相同的最大长度,我们要把这些个数累加。
转移时:
- 找
k满足nums[k] = 2*nums[i] - nums[j]且k < i。 - 如果找到这样的
k,则新长度L = len[k][i] + 1。- 如果
L > len[i][j],则更新len[i][j] = L,并且cnt[i][j] = cnt[k][i]。 - 如果
L == len[i][j],则cnt[i][j] += cnt[k][i]。
- 如果
- 如果找不到
k,则len[i][j] = 2,cnt[i][j] = 1(表示这个长度为 2 的等差数列只有它自己,但它不算有效等差数列,因为长度至少为 3 才算,我们最后统计时只考虑长度 ≥3 的)。
注意:长度为 2 的等差数列我们不视为有效等差数列,但它是中间状态。
五、算法步骤
- 初始化
n = len(nums),如果n < 3,最长长度为 0,个数 0。 - 创建二维数组
len[n][n]和cnt[n][n],初始len[i][j] = 2,cnt[i][j] = 1。 - 创建哈希表
pos,键为数值,值为列表,记录该数值出现的索引(按顺序)。 - 枚举
j从 2 到 n-1(其实从 1 开始也行,但长度至少 2),对于每个j,枚举i从 0 到 j-1:- 计算前一个数值
prev_val = 2*nums[i] - nums[j]。 - 在
pos[prev_val]中二分查找最后一个小于i的索引k。 - 如果找到了
k,则更新len[i][j] = len[k][i] + 1,cnt[i][j] = cnt[k][i]。
(注意:这里假设我们只取最后一个k,因为如果有多个k,我们要考虑所有吗?
实际上,不同k可能产生不同的长度,但我们要的是最大长度,所以应该取所有k中使len[k][i]最大的,并且计数是这些最大长度的cnt[k][i]之和。
但标准解法是:我们只需在遍历过程中,用last[val]记录每个数值最近出现的索引,但这样会漏掉更早的k吗?是的,会漏掉,所以我们要用哈希表存所有索引,然后取最后一个小于i的索引即可,因为更早的k产生的序列长度不会比最近的k更长(在等差数列中,最后一个满足条件的k能延续更长的序列)。)
所以只需取最后一个小于i的k即可。
- 计算前一个数值
- 遍历所有
i,j,找出最大的len[i][j]记为max_len。 - 如果
max_len < 3,返回 0,0。 - 统计所有
len[i][j] == max_len的cnt[i][j]之和,得到总个数。
六、举例
nums = [2, 2, 3, 4, 5, 6]
我们手动推导:
- 公差 1 的等差数列:
[2,3,4,5,6]长度 5,起始于索引 0,1,2,3,4? 不,必须保持顺序,所以只有一个:索引 (0,2,3,4,5) 对应值 2,3,4,5,6。
另一个是公差 0 的:[2,2,2,2,2]但原数组只有两个 2,无法组成长度 5 的相同数序列,所以这个不对。等等,我们看看:
原数组有 2 个 2,所以公差 0 的等差数列最大长度是 2,不是 5。
所以例子似乎有问题。
我们换例子:nums = [2,4,6,8,10,12]
最长长度 6,个数 1(公差 2)。
nums = [2,2,2,2,2]
最长长度 5,个数 1(公差 0)。
nums = [1,2,3,4,5,6,7]
最长长度 7,个数 1(公差 1)。
nums = [1,3,5,7,2,4,6,8]
有两个公差 2 的长度 4 的等差数列:[1,3,5,7] 和 [2,4,6,8]。
所以最大长度 4,个数 2。
七、代码实现(Python 思路)
from collections import defaultdict
from bisect import bisect_right
def longest_arithmetic_subsequence_count(nums):
n = len(nums)
if n < 3:
return 0, 0
# 值到索引列表的映射
pos = defaultdict(list)
for idx, val in enumerate(nums):
pos[val].append(idx)
# 初始化
length = [[2] * n for _ in range(n)]
count = [[0] * n for _ in range(n)]
max_len = 2
# 遍历所有对 (i,j)
for j in range(n):
for i in range(j):
diff = nums[j] - nums[i]
prev_val = nums[i] - diff
if prev_val in pos:
# 找到所有小于 i 的索引 k
idx_list = pos[prev_val]
# 二分找最后一个小于 i 的索引
k_idx = bisect_right(idx_list, i - 1) - 1
if k_idx >= 0:
k = idx_list[k_idx]
length[i][j] = length[k][i] + 1
count[i][j] = count[k][i]
else:
count[i][j] = 1
else:
count[i][j] = 1
max_len = max(max_len, length[i][j])
if max_len < 3:
return 0, 0
total_count = 0
for i in range(n):
for j in range(i+1, n):
if length[i][j] == max_len:
total_count += count[i][j]
return max_len, total_count
八、优化
上面的算法复杂度 O(n² log n),因为每次二分查找。
可以优化到 O(n²) 的做法是:在遍历过程中,用一个哈希表 last_index[val] 记录每个值最后出现的索引,然后从后往前更新。但这样只能找到最近的一个 k,而最近的一个 k 就是能使序列最长的,所以可以省略二分。
但这样必须注意:对于有重复值的情况,最近的一个 k 能保证长度最大,但不会漏掉最优解。
因此可以优化为 O(n²) 时间,O(n²) 空间。
九、最终输出
最后返回最长长度和不同等差数列的个数。
注意:这里“不同等差数列”由起始位置 (k,i,j) 决定,只要起始索引三元组不同,就算不同数列。
而我们的计数方法中,cnt[i][j] 表示以 (i,j) 结尾的最大长度的数列个数,所以不同结尾的 (i,j) 的计数不会重复,可以累加。
十、小结
这个问题是“最长等差数列”的进阶版,不仅要长度,还要计数。
关键点:
- 状态定义:
len[i][j]和cnt[i][j]。 - 转移时找前一个元素索引
k,使得nums[k] = 2*nums[i] - nums[j]。 - 计数转移:当更新最大长度时重置计数,当长度相等时累加计数。
- 最终统计所有达到最大长度的 (i,j) 的计数之和。
通过这个解法,我们能在 O(n²) 时间内解决该问题。