AtCoder ABC 441-E 题解

题目描述

给定一个由 ABC 三种字符组成、长度为 N 的字符串 S。求 S 的所有非空连续子串中,包含的 A 的数量多于 B 的数量的子串个数。

输入格式

1
2
N
S

输出格式

输出满足条件的子串个数。

约束条件

  • 1N5×1051 \leq N \leq 5 \times 10^5
  • S 是由 ABC 组成的长度为 N 的字符串

解题思路

前缀和转换

对于 i=0,1,2,,Ni = 0, 1, 2, \ldots, N,定义:

  • AiA_i:S 前 i 个字符中 A 的数量
  • BiB_i:S 前 i 个字符中 B 的数量

考虑 S 的第 i 到第 j 个字符构成的子串(iji \leq j),该子串满足条件当且仅当:

AjAi1>BjBi1A_j - A_{i-1} > B_j - B_{i-1}

等价于:

(AjBj)>(Ai1Bi1)(A_j - B_j) > (A_{i-1} - B_{i-1})

差分数组

定义差分数组 D=(D0,D1,,DN)D = (D_0, D_1, \ldots, D_N),其中 Di=AiBiD_i = A_i - B_i

那么问题转化为:求满足 Di<DjD_i < D_j0i<jN0 \leq i < j \leq N 的整数对 (i,j)(i, j) 的数量。

这类似于求逆序对的问题,但这里求的是顺序对的数量。

前缀统计

维护一个数组 counter[d + N],表示在当前位置之前,满足 Dj=dD_j = djj 的个数(由于 dd 的范围是 [N,N][-N, N],我们需要偏移 N 使索引非负)。

同时维护 sum,表示所有小于当前 DiD_icounter[d] 之和。

扫描过程

初始时 D0=0D_0 = 0counter[N] = 1(表示 D0=0D_0 = 0 出现了一次)。

对于每个字符 c:

  1. 如果 c = ADi=Di1+1D_i = D_{i-1} + 1
  2. 如果 c = BDi=Di11D_i = D_{i-1} - 1
  3. 如果 c = CDi=Di1D_i = D_{i-1}

对于每个新的 DiD_i,将当前满足 Dj<DiD_j < D_i 的所有 j 都可以作为左端点,i 作为右端点,构成一个满足条件的子串。

具体步骤

  1. 遇到 ADiD_i 增加 1,sum 增加 counter[D](因为新的更大的 DiD_i 可以与之前较小的 DjD_j 配对)
  2. 遇到 BDiD_i 减少 1,sum 减少 counter[D](因为新的更小的 DiD_i 不能与之前较小的 DjD_j 配对)
  3. 更新 counter[D_i],将当前 DiD_i 记录下来

代码实现

C++ 版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include <iostream>
#include <vector>
using namespace std;

int main() {
unsigned N;
string S;
cin >> N >> S;

vector<unsigned> counter(2 * N + 1); // counter[d + N] = D_j = d 的数量
unsigned D{N}; // D 的初始值偏移为 N
++counter[D]; // D_0 = 0

unsigned long sum{}; // 小于当前 D 的 counter 之和
unsigned long ans{};

for (const auto c : S) {
if (c == 'A') {
sum += counter[D++];
} else if (c == 'B') {
sum -= counter[--D];
}
// C 不影响 D 值

++counter[D];
ans += sum;
}

cout << ans << endl;
return 0;
}

Python 版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def solve():
N = int(input())
S = input().strip()

# D 的范围是 [-N, N],偏移 N 后为 [0, 2N]
counter = [0] * (2 * N + 1)
D = N # 初始 D_0 = 0,偏移后为 N
counter[D] = 1

sum_less = 0
ans = 0

for c in S:
if c == 'A':
sum_less += counter[D]
D += 1
elif c == 'B':
D -= 1
sum_less -= counter[D]
# C 不改变 D

counter[D] += 1
ans += sum_less

print(ans)

if __name__ == "__main__":
solve()

复杂度分析

  • 时间复杂度O(N)O(N),只需扫描一遍字符串
  • 空间复杂度O(N)O(N),需要大小为 2N+12N+1 的计数数组

示例解释

示例 1

1
2
输入:ACBBCABCAB
输出:8

满足条件的子串(标号从 1 开始):

  • A:位置 1
  • AC:位置 1-2
  • CA:位置 5-6
  • CABCA:位置 5-9
  • A:位置 6
  • ABCA:位置 6-9
  • CA:位置 8-9
  • A:位置 9

示例 2

1
2
输入:CCBC
输出:0

所有子串中 A 的数量都不少于 B 的数量。


总结

这道题的关键在于将问题转化为差分数组的顺序对计数问题。通过巧妙的前缀统计,我们可以在 O(N)O(N) 时间内解决问题,而不需要使用 O(N2)O(N^2) 的暴力枚举。


参考AtCoder ABC 441 E 官方题解


AtCoder ABC 441-E 题解
http://tingshuoyiqing.top/2026/01/18/AtCoder-ABC-441-E-题解/
作者
ting
发布于
2026年1月18日
许可协议